Skip to content

Commit 4b2ace5

Browse files
committed
Fix unrolled scan with untraced_sit_sot
1 parent 4b7a512 commit 4b2ace5

2 files changed

Lines changed: 32 additions & 3 deletions

File tree

pytensor/scan/basic.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -783,9 +783,14 @@ def wrap_into_list(x):
783783
_ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_inputs[idx]]
784784

785785
for idx in range(n_untraced_sit_sot):
786-
_ordered_args[untraced_sit_sot_rightOrder[idx]] = [
787-
untraced_sit_sot_inner_inputs[idx]
788-
]
786+
if single_step_requested:
787+
_ordered_args[untraced_sit_sot_rightOrder[idx]] = [
788+
untraced_sit_sot_scan_inputs[idx]
789+
]
790+
else:
791+
_ordered_args[untraced_sit_sot_rightOrder[idx]] = [
792+
untraced_sit_sot_inner_inputs[idx]
793+
]
789794

790795
ordered_args = list(chain.from_iterable(_ordered_args))
791796
if single_step_requested:

tests/scan/test_basic.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4166,6 +4166,30 @@ def rng_between_xs(rng):
41664166
assert isinstance(final_rng.type, RandomGeneratorType)
41674167

41684168

4169+
def test_single_step_untraced_sit_sot():
4170+
"""Single-step scan unrolling must connect untraced sit_sot to actual inputs."""
4171+
rng = random_generator_type("rng")
4172+
4173+
def step(rng):
4174+
next_rng, x = rng.normal()
4175+
return next_rng, x
4176+
4177+
unrolled_next_rng, unrolled_draw = scan(
4178+
fn=step,
4179+
outputs_info=[rng, None],
4180+
n_steps=1,
4181+
return_updates=False,
4182+
)
4183+
4184+
del rng.tag.used # Avoid reuse warning
4185+
expected_next_rng, draw = step(rng)
4186+
expected_draw = draw[None] # unrolling adds an expand_dims
4187+
assert_equal_computations(
4188+
[unrolled_next_rng, unrolled_draw],
4189+
[expected_next_rng, expected_draw],
4190+
)
4191+
4192+
41694193
@pytest.mark.parametrize("mode", [Mode(linker="py"), Mode(linker="cvm"), "numba"])
41704194
def test_zero_steps_untraced_sit_sot(mode):
41714195
"""Regression test: 0-step scan with untraced sit_sot must return initial state, not None.

0 commit comments

Comments
 (0)