Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

## Unreleased
### Added
- `Expr` and `GenExpr` support NumPy unary functions (`np.sin`, `np.cos`, `np.sqrt`, `np.exp`, `np.log`, `np.absolute`)
- `Expr` and `GenExpr` support NumPy unary functions (`np.sin`, `np.cos`, `np.sqrt`, `np.exp`, `np.log`, `np.absolute`, `np.negative`)
- `Expr` and `GenExpr` support NumPy binary functions (`np.add`, `np.subtract`, `np.multiply`, `np.divide`, `np.true_divide`, `np.power`, `np.less_equal`, `np.greater_equal`, `np.equal`)
- Added `getBase()` and `setBase()` methods to `LP` class for getting/setting basis status
- Added `getMemUsed()`, `getMemTotal()`, and `getMemExternEstim()` methods
### Fixed
Expand Down
33 changes: 32 additions & 1 deletion src/pyscipopt/expr.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,32 @@ cdef class ExprLike:
)

if method == "__call__":
if ufunc is np.absolute:
if arrays := [a for a in args if type(a) is np.ndarray]:
if any(a.dtype.kind not in "fiub" for a in arrays):
return NotImplemented
# If the np.ndarray is of numeric type, all arguments are converted to
# MatrixExpr or MatrixGenExpr and then the ufunc is applied.
return ufunc(*[_ensure_matrix(a) for a in args], **kwargs)

if ufunc is np.add:
return args[0] + args[1]
elif ufunc is np.subtract:
return args[0] - args[1]
elif ufunc is np.multiply:
return args[0] * args[1]
elif ufunc in {np.divide, np.true_divide}:
return args[0] / args[1]
elif ufunc is np.power:
return args[0] ** args[1]
elif ufunc is np.negative:
return -args[0]
elif ufunc is np.less_equal:
return args[0] <= args[1]
elif ufunc is np.greater_equal:
return args[0] >= args[1]
elif ufunc is np.equal:
return args[0] == args[1]
elif ufunc is np.absolute:
return args[0].__abs__()
elif ufunc is np.exp:
return args[0].exp()
Expand Down Expand Up @@ -1031,6 +1056,12 @@ cdef inline object _wrap_ufunc(object x, object ufunc):
return res.view(MatrixGenExpr) if isinstance(res, np.ndarray) else res
return ufunc(_to_const(x))

cdef inline object _ensure_matrix(object arg):
if type(arg) is np.ndarray:
return arg.view(MatrixExpr)
matrix = MatrixExpr if isinstance(arg, Expr) else MatrixGenExpr
return np.array(arg, dtype=object).view(matrix)


def expr_to_nodes(expr):
'''transforms tree to an array of nodes. each node is an operator and the position of the
Expand Down
41 changes: 40 additions & 1 deletion tests/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def test_getVal_with_GenExpr():
m.getVal(1 / z)


def test_unary(model):
def test_unary_ufunc(model):
m, x, y, z = model

res = "abs(sum(0.0,prod(1.0,x)))"
Expand Down Expand Up @@ -276,6 +276,45 @@ def test_unary(model):
# forbid modifying Variable/Expr/GenExpr in-place via out parameter
np.sin(x, out=np.array([0]))

# test np.negative
assert str(np.negative(x)) == "Expr({Term(x): -1.0})"


def test_binary_ufunc(model):
m, x, y, z = model

# test np.add
assert str(np.add(x, 1)) == "Expr({Term(x): 1.0, Term(): 1.0})"
assert str(np.add(1, x)) == "Expr({Term(x): 1.0, Term(): 1.0})"
a = np.array([1])
assert str(np.add(x, a)) == "[Expr({Term(x): 1.0, Term(): 1.0})]"
assert str(np.add(a, x)) == "[Expr({Term(x): 1.0, Term(): 1.0})]"

# test np.subtract
assert str(np.subtract(x, 1)) == "Expr({Term(x): 1.0, Term(): -1.0})"
assert str(np.subtract(1, x)) == "Expr({Term(x): -1.0, Term(): 1.0})"
assert str(np.subtract(x, a)) == "[Expr({Term(x): 1.0, Term(): -1.0})]"
assert str(np.subtract(a, x)) == "[Expr({Term(x): -1.0, Term(): 1.0})]"

# test np.multiply
a = np.array([2])
assert str(np.multiply(x, 2)) == "Expr({Term(x): 2.0})"
assert str(np.multiply(2, x)) == "Expr({Term(x): 2.0})"
assert str(np.multiply(x, a)) == "[Expr({Term(x): 2.0})]"
assert str(np.multiply(a, x)) == "[Expr({Term(x): 2.0})]"

# test np.divide
assert str(np.divide(x, 2)) == "Expr({Term(x): 0.5})"
assert str(np.divide(2, x)) == "prod(2.0,**(sum(0.0,prod(1.0,x)),-1))"
assert str(np.divide(x, a)) == "[Expr({Term(x): 0.5})]"
assert str(np.divide(a, x)) == "[prod(2.0,**(sum(0.0,prod(1.0,x)),-1))]"

# test np.power
assert str(np.power(x, 2)) == "Expr({Term(x, x): 1.0})"
assert str(np.power(2, x)) == "exp(prod(1.0,sum(0.0,prod(1.0,x)),log(2.0)))"
assert str(np.power(x, a)) == "[Expr({Term(x, x): 1.0})]"
assert str(np.power(a, x)) == "[exp(prod(1.0,sum(0.0,prod(1.0,x)),log(2.0)))]"


def test_mul():
m = Model()
Expand Down
Loading