diff --git a/pytensor/assumptions/diagonal.py b/pytensor/assumptions/diagonal.py index 36f160e287..ae8b4ded19 100644 --- a/pytensor/assumptions/diagonal.py +++ b/pytensor/assumptions/diagonal.py @@ -49,8 +49,13 @@ def _diagonal_from_constant(var: TensorConstant) -> FactState: if m != n: result = FactState.FALSE else: - eye_mask = np.eye(n, dtype=bool) - result = FactState.FALSE if np.any(data * ~eye_mask) else FactState.TRUE + # Off-diagonal has a nonzero iff there are more nonzeros in ``data`` + # than on its diagonal. + diag = np.diagonal(data, axis1=-2, axis2=-1) + if np.count_nonzero(data) == np.count_nonzero(diag): + result = FactState.TRUE + else: + result = FactState.FALSE var.tag.is_diagonal = result return result diff --git a/pytensor/assumptions/permutation.py b/pytensor/assumptions/permutation.py index bed9013a89..890f80fc89 100644 --- a/pytensor/assumptions/permutation.py +++ b/pytensor/assumptions/permutation.py @@ -38,13 +38,18 @@ def _permutation_from_constant(var: TensorConstant) -> FactState: if data.ndim < 2 or data.shape[-1] != data.shape[-2]: result = FactState.FALSE else: - # The row/column-sum reductions are cheaper and short-circuit the full-size binary - # scan; a doubly-stochastic matrix shows the binary check is still required. - is_permutation = ( - bool(np.all(data.sum(axis=-2) == 1)) - and bool(np.all(data.sum(axis=-1) == 1)) - and bool(np.all((data == 0) | (data == 1))) - ) + with np.errstate(invalid="ignore"): + if not (data.sum(axis=-2) == 1).all(): + is_permutation = False + elif not (data.sum(axis=-1) == 1).all(): + is_permutation = False + elif data.dtype.kind in "ub": + is_permutation = True + elif data.dtype.kind == "i": + is_permutation = data.min(initial=0) >= 0 + else: + n = data.shape[-1] + is_permutation = np.count_nonzero(data) == (data.size // n if n else 0) result = FactState.TRUE if is_permutation else FactState.FALSE var.tag.is_permutation = result diff --git a/pytensor/assumptions/selection.py b/pytensor/assumptions/selection.py index 6698b3b29b..ee45479644 100644 --- a/pytensor/assumptions/selection.py +++ b/pytensor/assumptions/selection.py @@ -28,12 +28,18 @@ def _selection_from_constant(var: TensorConstant) -> FactState: if data.ndim < 2: result = FactState.FALSE else: - if not (data.sum(axis=-2) == 1).all(): - is_selection = False - elif data.dtype.kind in "uib": - is_selection = data.max(initial=1) <= 1 and data.min(initial=0) >= 0 - else: - is_selection = bool(((data == 0) | (data == 1)).all()) + with np.errstate(invalid="ignore"): + if not (data.sum(axis=-2) == 1).all(): + is_selection = False + elif data.dtype.kind in "ub": + is_selection = True + elif data.dtype.kind == "i": + is_selection = data.min(initial=0) >= 0 + else: + n_rows = data.shape[-2] + is_selection = np.count_nonzero(data) == ( + data.size // n_rows if n_rows else 0 + ) result = FactState.TRUE if is_selection else FactState.FALSE var.tag.is_selection = result