From 78b54916a353afa4796fcdad38293b531378c2b8 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 10 Apr 2026 13:32:07 +0200 Subject: [PATCH 1/2] Don't eagerly sidestep RNGs that have multiple clients Could destroy update information in duplicated graphs (like Scan likes to produce) --- pytensor/tensor/random/rewriting/basic.py | 7 +++++ tests/tensor/random/rewriting/test_basic.py | 35 +++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/pytensor/tensor/random/rewriting/basic.py b/pytensor/tensor/random/rewriting/basic.py index b3e1d219c5..28b462f7e1 100644 --- a/pytensor/tensor/random/rewriting/basic.py +++ b/pytensor/tensor/random/rewriting/basic.py @@ -376,5 +376,12 @@ def sidestep_unused_rng_consumer(fgraph, node): if any(fgraph.clients[out] for out in non_rng_outputs): return None + # Don't sidestep if any input RNG has other clients. + # The graph likely has duplicate nodes that will be merged later, + # and sidestepping now would destroy the RNG update prematurely. 
+ rng_inputs = set(update_map.keys()) + if any(len(fgraph.clients[rng_in]) > 1 for rng_in in rng_inputs): + return None + # Bypass: map each RNG output back to its corresponding RNG input return {rng_out: rng_in for rng_in, rng_out in update_map.items()} diff --git a/tests/tensor/random/rewriting/test_basic.py b/tests/tensor/random/rewriting/test_basic.py index 8d50b2c63f..68315219a8 100644 --- a/tests/tensor/random/rewriting/test_basic.py +++ b/tests/tensor/random/rewriting/test_basic.py @@ -12,6 +12,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import ( EquilibriumGraphRewriter, + MergeOptimizer, check_stack_trace, dfs_rewriter, ) @@ -35,6 +36,7 @@ local_rv_size_lift, local_subtensor_rv_lift, ) +from pytensor.tensor.random.rewriting.basic import sidestep_unused_rng_consumer from pytensor.tensor.random.rewriting.numba import introduce_explicit_core_shape_rv from pytensor.tensor.random.type import random_generator_type from pytensor.tensor.rewriting.shape import ShapeFeature, ShapeOptimizer @@ -979,6 +981,39 @@ def test_Dimshuffle_lift_rename(ds_order, lifted, dist_op, dist_params, size, rt assert new_out.name == "test_name_lifted" +def test_sidestep_unused_rng_consumer_with_duplicate_node(): + """Sidestep rewrite must not fire when input RNG has another client. + + Reproduces a bug where graph duplication (e.g., inside scan) creates two + copies of the same RV node: one whose draw is used and one whose RNG + update is used. The rewrite would see the update-only copy's draw as dead + and replace its RNG output with the input — destroying the update. 
+ """ + rng = random_generator_type("rng") + next_rng_a, draw_a = rng.normal(0, 1, size=3) + + # Simulate scan-like duplication: use draw from clone A, RNG update from clone B + next_rng_b, draw_b = draw_a.owner.clone().outputs + assert next_rng_a is not next_rng_b + assert draw_a is not draw_b + fg = FunctionGraph([rng], [draw_a, next_rng_b], clone=False) + + # Clone B's draw is unused, but sidestep should NOT fire because rng has + # another client whose draw IS used. + dfs_rewriter(sidestep_unused_rng_consumer).rewrite(fg) + assert sum(isinstance(n.op, NormalRV) for n in fg.apply_nodes) == 2 + + # When neither clone's draw is used, sidestep alone won't fire (rng has multiple clients) + fg2 = FunctionGraph([rng], [next_rng_a, next_rng_b], clone=False) + dfs_rewriter(sidestep_unused_rng_consumer).rewrite(fg2) + assert sum(isinstance(n.op, NormalRV) for n in fg2.apply_nodes) == 2 + + # But after merge collapses them into one node it does. + MergeOptimizer().rewrite(fg2) + dfs_rewriter(sidestep_unused_rng_consumer).rewrite(fg2) + assert sum(isinstance(n.op, NormalRV) for n in fg2.apply_nodes) == 0 + + def test_unused_rng(): rng = random_generator_type("rng") next_rng, x = rng.normal([0], [1], size=3) From b453b1974016c66acbf0da008822295dbfd17066 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 10 Apr 2026 13:39:42 +0200 Subject: [PATCH 2/2] Fix unrolled scan with untraced_sit_sot --- pytensor/scan/basic.py | 13 ++++++++++--- tests/scan/test_basic.py | 24 ++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/pytensor/scan/basic.py b/pytensor/scan/basic.py index a1cc51c5db..2e111f2ab1 100644 --- a/pytensor/scan/basic.py +++ b/pytensor/scan/basic.py @@ -664,6 +664,7 @@ def wrap_into_list(x): n_untraced_sit_sot = 0 untraced_sit_sot_scan_inputs = [] untraced_sit_sot_inner_inputs = [] + untraced_sit_sot_inner_slices = [] untraced_sit_sot_inner_outputs = [] untraced_sit_sot_rightOrder = [] @@ -717,6 +718,7 @@ def wrap_into_list(x): 
UserWarning, ) untraced_sit_sot_scan_inputs.append(actual_arg) + untraced_sit_sot_inner_slices.append(actual_arg) untraced_sit_sot_inner_inputs.append(arg) n_untraced_sit_sot += 1 untraced_sit_sot_rightOrder.append(i) @@ -783,9 +785,14 @@ def wrap_into_list(x): _ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_inputs[idx]] for idx in range(n_untraced_sit_sot): - _ordered_args[untraced_sit_sot_rightOrder[idx]] = [ - untraced_sit_sot_inner_inputs[idx] - ] + if single_step_requested: + _ordered_args[untraced_sit_sot_rightOrder[idx]] = [ + untraced_sit_sot_inner_slices[idx] + ] + else: + _ordered_args[untraced_sit_sot_rightOrder[idx]] = [ + untraced_sit_sot_inner_inputs[idx] + ] ordered_args = list(chain.from_iterable(_ordered_args)) if single_step_requested: diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index 7d57473965..b669a7c56a 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -4166,6 +4166,30 @@ def rng_between_xs(rng): assert isinstance(final_rng.type, RandomGeneratorType) +def test_single_step_untraced_sit_sot(): + """Single-step scan unrolling must connect untraced sit_sot to actual inputs.""" + rng = random_generator_type("rng") + + def step(rng): + next_rng, x = rng.normal() + return next_rng, x + + unrolled_next_rng, unrolled_draw = scan( + fn=step, + outputs_info=[rng, None], + n_steps=1, + return_updates=False, + ) + + del rng.tag.used # Avoid reuse warning + expected_next_rng, draw = step(rng) + expected_draw = draw[None] # unrolling adds an expand_dims + assert_equal_computations( + [unrolled_next_rng, unrolled_draw], + [expected_next_rng, expected_draw], + ) + + @pytest.mark.parametrize("mode", [Mode(linker="py"), Mode(linker="cvm"), "numba"]) def test_zero_steps_untraced_sit_sot(mode): """Regression test: 0-step scan with untraced sit_sot must return initial state, not None.