From 0d1c361db71c585dc90e22e8acf0d370f73fd277 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 26 May 2026 13:46:31 -0400 Subject: [PATCH 1/6] more robust digonal constant check --- pytensor/assumptions/diagonal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/assumptions/diagonal.py b/pytensor/assumptions/diagonal.py index 36f160e287..a114df17c7 100644 --- a/pytensor/assumptions/diagonal.py +++ b/pytensor/assumptions/diagonal.py @@ -50,7 +50,7 @@ def _diagonal_from_constant(var: TensorConstant) -> FactState: result = FactState.FALSE else: eye_mask = np.eye(n, dtype=bool) - result = FactState.FALSE if np.any(data * ~eye_mask) else FactState.TRUE + result = FactState.FALSE if np.any(data[..., ~eye_mask]) else FactState.TRUE var.tag.is_diagonal = result return result From e769fdab8162f37a4c9ed347554e5e3d74a4385b Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 26 May 2026 13:50:00 -0400 Subject: [PATCH 2/6] use `np.any(x, where=mask)` for perf --- pytensor/assumptions/diagonal.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytensor/assumptions/diagonal.py b/pytensor/assumptions/diagonal.py index a114df17c7..4c3c2175ab 100644 --- a/pytensor/assumptions/diagonal.py +++ b/pytensor/assumptions/diagonal.py @@ -50,7 +50,9 @@ def _diagonal_from_constant(var: TensorConstant) -> FactState: result = FactState.FALSE else: eye_mask = np.eye(n, dtype=bool) - result = FactState.FALSE if np.any(data[..., ~eye_mask]) else FactState.TRUE + result = ( + FactState.FALSE if np.any(data, where=~eye_mask) else FactState.TRUE + ) var.tag.is_diagonal = result return result From ee99138dcb110e2c9f33a9b6926d78d1f3c6878d Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 26 May 2026 13:55:22 -0400 Subject: [PATCH 3/6] use `np.count_nonzero` for more perf --- pytensor/assumptions/diagonal.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pytensor/assumptions/diagonal.py b/pytensor/assumptions/diagonal.py index 4c3c2175ab..ae8b4ded19 100644 --- a/pytensor/assumptions/diagonal.py +++ b/pytensor/assumptions/diagonal.py @@ -49,10 +49,13 @@ def _diagonal_from_constant(var: TensorConstant) -> FactState: if m != n: result = FactState.FALSE else: - eye_mask = np.eye(n, dtype=bool) - result = ( - FactState.FALSE if np.any(data, where=~eye_mask) else FactState.TRUE - ) + # Off-diagonal has a nonzero iff there are more nonzeros in ``data`` + # than on its diagonal. + diag = np.diagonal(data, axis1=-2, axis2=-1) + if np.count_nonzero(data) == np.count_nonzero(diag): + result = FactState.TRUE + else: + result = FactState.FALSE var.tag.is_diagonal = result return result From 89156a66ffe6793dda0e9fef77c97eb5108adef8 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 26 May 2026 15:09:33 -0400 Subject: [PATCH 4/6] Add speed up selection/permutation worst-case constant checks using count_nonzero --- pytensor/assumptions/permutation.py | 19 ++++++++++++------- pytensor/assumptions/selection.py | 7 ++++++- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/pytensor/assumptions/permutation.py b/pytensor/assumptions/permutation.py index bed9013a89..4bc9e4fc4f 100644 --- a/pytensor/assumptions/permutation.py +++ b/pytensor/assumptions/permutation.py @@ -37,14 +37,19 @@ def _permutation_from_constant(var: TensorConstant) -> FactState: data = np.asarray(var.data) if data.ndim < 2 or data.shape[-1] != data.shape[-2]: result = FactState.FALSE + elif not bool(np.all(data.sum(axis=-2) == 1)): + result = FactState.FALSE + elif not bool(np.all(data.sum(axis=-1) == 1)): + result = FactState.FALSE else: - # The row/column-sum reductions are cheaper and short-circuit the full-size binary - # scan; a doubly-stochastic matrix shows the binary check is still required. - is_permutation = ( - bool(np.all(data.sum(axis=-2) == 1)) - and bool(np.all(data.sum(axis=-1) == 1)) - and bool(np.all((data == 0) | (data == 1))) - ) + if data.dtype.kind in "uib": + # Fast check, only valid for integer/bool + is_permutation = data.max(initial=1) <= 1 and data.min(initial=0) >= 0 + else: + # Otherwise a matrix is permutation iff there is exactly 1 nonzero entry per row + # and column. That non-zero value can only be 1 due to previous checks. + n = data.shape[-1] + is_permutation = np.count_nonzero(data) == (data.size // n if n else 0) result = FactState.TRUE if is_permutation else FactState.FALSE var.tag.is_permutation = result diff --git a/pytensor/assumptions/selection.py b/pytensor/assumptions/selection.py index 6698b3b29b..7f804bc76d 100644 --- a/pytensor/assumptions/selection.py +++ b/pytensor/assumptions/selection.py @@ -33,7 +33,12 @@ def _selection_from_constant(var: TensorConstant) -> FactState: elif data.dtype.kind in "uib": is_selection = data.max(initial=1) <= 1 and data.min(initial=0) >= 0 else: - is_selection = bool(((data == 0) | (data == 1)).all()) + # Column sums all == 1 plus ``count_nonzero == k`` (per slice) force one + # nonzero per column equal to 1 + n_rows = data.shape[-2] + is_selection = np.count_nonzero(data) == ( + data.size // n_rows if n_rows else 0 + ) result = FactState.TRUE if is_selection else FactState.FALSE var.tag.is_selection = result From 1cd347d50f3927e33cb8d615db8c92f88d580297 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Thu, 28 May 2026 11:04:17 +0200 Subject: [PATCH 5/6] Faster permutation/selection from constants --- pytensor/assumptions/permutation.py | 23 +++++++++++------------ pytensor/assumptions/selection.py | 20 +++++++++----------- 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/pytensor/assumptions/permutation.py b/pytensor/assumptions/permutation.py index 4bc9e4fc4f..c750f4e290 100644 --- a/pytensor/assumptions/permutation.py +++ b/pytensor/assumptions/permutation.py @@ -37,19 +37,18 @@ def _permutation_from_constant(var: TensorConstant) -> FactState: data = np.asarray(var.data) if data.ndim < 2 or data.shape[-1] != data.shape[-2]: result = FactState.FALSE - elif not bool(np.all(data.sum(axis=-2) == 1)): - result = FactState.FALSE - elif not bool(np.all(data.sum(axis=-1) == 1)): - result = FactState.FALSE else: - if data.dtype.kind in "uib": - # Fast check, only valid for integer/bool - is_permutation = data.max(initial=1) <= 1 and data.min(initial=0) >= 0 - else: - # Otherwise a matrix is permutation iff there is exactly 1 nonzero entry per row - # and column. That non-zero value can only be 1 due to previous checks. - n = data.shape[-1] - is_permutation = np.count_nonzero(data) == (data.size // n if n else 0) + with np.errstate(invalid="ignore"): + if not (data.sum(axis=-2) == 1).all(): + is_permutation = False + elif not (data.sum(axis=-1) == 1).all(): + is_permutation = False + elif data.dtype.kind in "ub": + is_permutation = True + elif data.dtype.kind == "i": + is_permutation = data.min(initial=0) >= 0 + else: + is_permutation = bool(((data == 0) | (data == 1)).all()) result = FactState.TRUE if is_permutation else FactState.FALSE var.tag.is_permutation = result diff --git a/pytensor/assumptions/selection.py b/pytensor/assumptions/selection.py index 7f804bc76d..b8849af469 100644 --- a/pytensor/assumptions/selection.py +++ b/pytensor/assumptions/selection.py @@ -28,17 +28,15 @@ def _selection_from_constant(var: TensorConstant) -> FactState: if data.ndim < 2: result = FactState.FALSE else: - if not (data.sum(axis=-2) == 1).all(): - is_selection = False - elif data.dtype.kind in "uib": - is_selection = data.max(initial=1) <= 1 and data.min(initial=0) >= 0 - else: - # Column sums all == 1 plus ``count_nonzero == k`` (per slice) force one - # nonzero per column equal to 1 - n_rows = data.shape[-2] - is_selection = np.count_nonzero(data) == ( - data.size // n_rows if n_rows else 0 - ) + with np.errstate(invalid="ignore"): + if not (data.sum(axis=-2) == 1).all(): + is_selection = False + elif data.dtype.kind in "ub": + is_selection = True + elif data.dtype.kind == "i": + is_selection = data.min(initial=0) >= 0 + else: + is_selection = bool(((data == 0) | (data == 1)).all()) result = FactState.TRUE if is_selection else FactState.FALSE var.tag.is_selection = result From 61dbf7aa7028f43d8305bad9c7bb95df2e23dcb2 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Thu, 28 May 2026 18:00:12 +0200 Subject: [PATCH 6/6] nonzero is back --- pytensor/assumptions/permutation.py | 3 ++- pytensor/assumptions/selection.py | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pytensor/assumptions/permutation.py b/pytensor/assumptions/permutation.py index c750f4e290..890f80fc89 100644 --- a/pytensor/assumptions/permutation.py +++ b/pytensor/assumptions/permutation.py @@ -48,7 +48,8 @@ def _permutation_from_constant(var: TensorConstant) -> FactState: elif data.dtype.kind == "i": is_permutation = data.min(initial=0) >= 0 else: - is_permutation = bool(((data == 0) | (data == 1)).all()) + n = data.shape[-1] + is_permutation = np.count_nonzero(data) == (data.size // n if n else 0) result = FactState.TRUE if is_permutation else FactState.FALSE var.tag.is_permutation = result diff --git a/pytensor/assumptions/selection.py b/pytensor/assumptions/selection.py index b8849af469..ee45479644 100644 --- a/pytensor/assumptions/selection.py +++ b/pytensor/assumptions/selection.py @@ -36,7 +36,10 @@ def _selection_from_constant(var: TensorConstant) -> FactState: elif data.dtype.kind == "i": is_selection = data.min(initial=0) >= 0 else: - is_selection = bool(((data == 0) | (data == 1)).all()) + n_rows = data.shape[-2] + is_selection = np.count_nonzero(data) == ( + data.size // n_rows if n_rows else 0 + ) result = FactState.TRUE if is_selection else FactState.FALSE var.tag.is_selection = result