Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Unreleased
### Added
### Fixed
- `Expr.__array_ufunc__` can't handle 0-dim array
### Changed
### Removed

Expand Down
16 changes: 12 additions & 4 deletions src/pyscipopt/expr.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -203,16 +203,24 @@ cdef class ExprLike:
)

if method == "__call__":
if arrays := [a for a in args if isinstance(a, np.ndarray)]:
if arrays := [a for a in args if isinstance(a, np.ndarray) and a.ndim >= 1]:
if any(a.dtype.kind not in "fiub" for a in arrays):
return NotImplemented
# If the np.ndarray is of numeric type, all arguments are converted to
# MatrixExpr or MatrixGenExpr and then the ufunc is applied.
return ufunc(*[_ensure_matrix(a) for a in args], **kwargs)

# Convert `np.generic` to native Python types to stop __array_ufunc__
# recursion from `np.generic + MatrixExpr`.
args = [a.item() if isinstance(a, np.generic) else a for a in args]
# Convert `np.generic` and 0-dim `np.ndarray` to native Python types to stop
# __array_ufunc__ recursion from `np.generic + MatrixExpr/Expr`.
args = [
a.item()
if (
isinstance(a, np.generic)
or (isinstance(a, np.ndarray) and a.ndim == 0)
)
else a
for a in args
]

if ufunc is np.add:
return args[0] + args[1]
Expand Down
12 changes: 12 additions & 0 deletions tests/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,18 @@ def test_binary_ufunc(model):
assert str(np.greater_equal(a, x)) == "[ExprCons(Expr({Term(x): 1.0}), None, 2.0)]"


def test_np_generic_cmp_with_expr():
# test #1218
m = Model()
x = m.addVar(name="x")
value = np.float64(5.0)

assert str(x <= -value) == "ExprCons(Expr({Term(x): 1.0}), None, -5.0)"
assert str(x <= value ) == "ExprCons(Expr({Term(x): 1.0}), None, 5.0)"
assert str(-value <= x) == "ExprCons(Expr({Term(x): 1.0}), -5.0, None)"
assert str(value <= x) == "ExprCons(Expr({Term(x): 1.0}), 5.0, None)"


def test_mul():
m = Model()
x = m.addVar(name="x")
Expand Down
Loading