Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions pytensor/scan/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,7 @@ def wrap_into_list(x):
n_untraced_sit_sot = 0
untraced_sit_sot_scan_inputs = []
untraced_sit_sot_inner_inputs = []
untraced_sit_sot_inner_slices = []
untraced_sit_sot_inner_outputs = []
untraced_sit_sot_rightOrder = []

Expand Down Expand Up @@ -717,6 +718,7 @@ def wrap_into_list(x):
UserWarning,
)
untraced_sit_sot_scan_inputs.append(actual_arg)
untraced_sit_sot_inner_slices.append(actual_arg)
untraced_sit_sot_inner_inputs.append(arg)
n_untraced_sit_sot += 1
untraced_sit_sot_rightOrder.append(i)
Expand Down Expand Up @@ -783,9 +785,14 @@ def wrap_into_list(x):
_ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_inputs[idx]]

for idx in range(n_untraced_sit_sot):
_ordered_args[untraced_sit_sot_rightOrder[idx]] = [
untraced_sit_sot_inner_inputs[idx]
]
if single_step_requested:
_ordered_args[untraced_sit_sot_rightOrder[idx]] = [
untraced_sit_sot_inner_slices[idx]
]
else:
_ordered_args[untraced_sit_sot_rightOrder[idx]] = [
untraced_sit_sot_inner_inputs[idx]
]

ordered_args = list(chain.from_iterable(_ordered_args))
if single_step_requested:
Expand Down
7 changes: 7 additions & 0 deletions pytensor/tensor/random/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,5 +376,12 @@ def sidestep_unused_rng_consumer(fgraph, node):
if any(fgraph.clients[out] for out in non_rng_outputs):
return None

# Don't sidestep if any input RNG has other clients.
# The graph likely has duplicate nodes that will be merged later,
# and sidestepping now would destroy the RNG update prematurely.
rng_inputs = set(update_map.keys())
if any(len(fgraph.clients[rng_in]) > 1 for rng_in in rng_inputs):
return None

# Bypass: map each RNG output back to its corresponding RNG input
return {rng_out: rng_in for rng_in, rng_out in update_map.items()}
24 changes: 24 additions & 0 deletions tests/scan/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4166,6 +4166,30 @@ def rng_between_xs(rng):
assert isinstance(final_rng.type, RandomGeneratorType)


def test_single_step_untraced_sit_sot():
"""Single-step scan unrolling must connect untraced sit_sot to actual inputs."""
rng = random_generator_type("rng")

def step(rng):
next_rng, x = rng.normal()
return next_rng, x

unrolled_next_rng, unrolled_draw = scan(
fn=step,
outputs_info=[rng, None],
n_steps=1,
return_updates=False,
)

del rng.tag.used # Avoid reuse warning
expected_next_rng, draw = step(rng)
expected_draw = draw[None] # unrolling adds an expand_dims
assert_equal_computations(
[unrolled_next_rng, unrolled_draw],
[expected_next_rng, expected_draw],
)


@pytest.mark.parametrize("mode", [Mode(linker="py"), Mode(linker="cvm"), "numba"])
def test_zero_steps_untraced_sit_sot(mode):
"""Regression test: 0-step scan with untraced sit_sot must return initial state, not None.
Expand Down
35 changes: 35 additions & 0 deletions tests/tensor/random/rewriting/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import (
EquilibriumGraphRewriter,
MergeOptimizer,
check_stack_trace,
dfs_rewriter,
)
Expand All @@ -35,6 +36,7 @@
local_rv_size_lift,
local_subtensor_rv_lift,
)
from pytensor.tensor.random.rewriting.basic import sidestep_unused_rng_consumer
from pytensor.tensor.random.rewriting.numba import introduce_explicit_core_shape_rv
from pytensor.tensor.random.type import random_generator_type
from pytensor.tensor.rewriting.shape import ShapeFeature, ShapeOptimizer
Expand Down Expand Up @@ -979,6 +981,39 @@ def test_Dimshuffle_lift_rename(ds_order, lifted, dist_op, dist_params, size, rt
assert new_out.name == "test_name_lifted"


def test_sidestep_unused_rng_consumer_with_duplicate_node():
"""Sidestep rewrite must not fire when input RNG has another client.

Reproduces a bug where graph duplication (e.g., inside scan) creates two
copies of the same RV node: one whose draw is used and one whose RNG
update is used. The rewrite would see the update-only copy's draw as dead
and replace its RNG output with the input — destroying the update.
"""
rng = random_generator_type("rng")
next_rng_a, draw_a = rng.normal(0, 1, size=3)

# Simulate scan-like duplication: use draw from clone A, RNG update from clone B
next_rng_b, draw_b = draw_a.owner.clone().outputs
assert next_rng_a is not next_rng_b
assert draw_a is not draw_b
fg = FunctionGraph([rng], [draw_a, next_rng_b], clone=False)

# Clone B's draw is unused, but sidestep should NOT fire because rng has
# another client whose draw IS used.
dfs_rewriter(sidestep_unused_rng_consumer).rewrite(fg)
assert sum(isinstance(n.op, NormalRV) for n in fg.apply_nodes) == 2

# When neither clone's draw is used, sidestep alone won't fire (rng has multiple clients)
fg2 = FunctionGraph([rng], [next_rng_a, next_rng_b], clone=False)
dfs_rewriter(sidestep_unused_rng_consumer).rewrite(fg2)
assert sum(isinstance(n.op, NormalRV) for n in fg2.apply_nodes) == 2

# But after merge collapses them into one node it does.
MergeOptimizer().rewrite(fg2)
dfs_rewriter(sidestep_unused_rng_consumer).rewrite(fg2)
assert sum(isinstance(n.op, NormalRV) for n in fg2.apply_nodes) == 0


def test_unused_rng():
rng = random_generator_type("rng")
next_rng, x = rng.normal([0], [1], size=3)
Expand Down
Loading