Skip to content

Commit a2eae29

Browse files
Numba IfElse: respect view flag
Co-authored-by: Ricardo Vieira <28983449+ricardov94@users.noreply.github.com>
1 parent 79a4bc1 commit a2eae29

File tree

2 files changed

+68
-20
lines changed

2 files changed

+68
-20
lines changed

pytensor/link/numba/dispatch/compile_ops.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from copy import deepcopy
22
from hashlib import sha256
3+
from textwrap import dedent
34

45
import numba
56
import numpy as np
@@ -10,6 +11,7 @@
1011
from pytensor.compile.mode import NUMBA
1112
from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
1213
from pytensor.ifelse import IfElse
14+
from pytensor.link.numba.cache import compile_numba_function_src
1315
from pytensor.link.numba.dispatch import basic as numba_basic
1416
from pytensor.link.numba.dispatch.basic import (
1517
numba_funcify_and_cache_key,
@@ -106,30 +108,35 @@ def deepcopy(x):
106108
@register_funcify_default_op_cache_key(IfElse)
107109
def numba_funcify_IfElse(op, **kwargs):
108110
n_outs = op.n_outs
111+
as_view = op.as_view
109112

110-
if n_outs > 1:
111-
112-
@numba_basic.numba_njit
113-
def ifelse(cond, *args):
114-
if cond:
115-
res = args[:n_outs]
116-
else:
117-
res = args[n_outs:]
118-
119-
return res
113+
true_names = [f"t{i}" for i in range(n_outs)]
114+
false_names = [f"f{i}" for i in range(n_outs)]
115+
arg_list = ", ".join((*true_names, *false_names))
120116

117+
if as_view:
118+
true_returns = ", ".join(true_names)
121119
else:
120+
true_returns = ", ".join(f"{name}.copy()" for name in true_names)
121+
# We only ever view (alias) variables from the true branch. False branch variables must always be copied.
122+
false_returns = ", ".join(f"{name}.copy()" for name in false_names)
123+
124+
func_src = dedent(
125+
f"""
126+
def ifelse(cond, {arg_list}):
127+
if cond:
128+
return {true_returns}
129+
else:
130+
return {false_returns}
131+
"""
132+
)
122133

123-
@numba_basic.numba_njit
124-
def ifelse(cond, *args):
125-
if cond:
126-
res = args[:n_outs]
127-
else:
128-
res = args[n_outs:]
129-
130-
return res[0]
134+
ifelse_func = numba_basic.numba_njit(
135+
compile_numba_function_src(func_src, "ifelse", globals())
136+
)
131137

132-
return ifelse
138+
cache_version = 1
139+
return ifelse_func, cache_version
133140

134141

135142
@register_funcify_and_cache_key(CheckAndRaise)

tests/link/numba/test_compile_ops.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import numpy as np
22
import pytest
33

4-
from pytensor import Mode, OpFromGraph, config, function, ifelse, scan
4+
from pytensor import In, Mode, OpFromGraph, Out, config, function, ifelse, scan
55
from pytensor import tensor as pt
66
from pytensor.compile import ViewOp
77
from pytensor.graph import vectorize_graph
8+
from pytensor.ifelse import IfElse
89
from pytensor.raise_op import assert_op
910
from pytensor.scalar import Add
1011
from pytensor.scan.op import Scan
@@ -231,3 +232,43 @@ def test_ofg_with_inner_scan_rewrite():
231232
cholesky_op = scan_op.fgraph.outputs[0].owner.op
232233
assert isinstance(cholesky_op, Blockwise)
233234
assert isinstance(cholesky_op.core_op, Cholesky)
235+
236+
237+
@pytest.mark.parametrize("as_view", [True, False])
238+
def test_ifelse_single_output(as_view, single_out=True):
239+
x = pt.vector("x")
240+
y = pt.vector("y")
241+
if single_out:
242+
outs = [x]
243+
else:
244+
outs = [x, y]
245+
246+
op = IfElse(as_view=as_view, n_outs=len(outs))
247+
outs = op(x.sum() > 0, *outs, *outs, return_list=True)
248+
249+
fn = function(
250+
[In(x, borrow=True), In(y, borrow=True)],
251+
[Out(out, borrow=True) for out in outs],
252+
mode=Mode("numba", optimizer=None),
253+
accept_inplace=True,
254+
on_unused_input="ignore",
255+
)
256+
257+
# FALSE branch
258+
test_x = np.zeros(3)
259+
test_y = np.ones(5)
260+
res_false = fn(test_x, test_y)
261+
for test_inp, res_out in zip([test_x, test_y], res_false, strict=False):
262+
np.testing.assert_array_equal(test_inp, res_out)
263+
# IfElse only views on the true branch variates
264+
assert res_out is not test_inp
265+
266+
# TRUE branch
267+
test_x = np.ones(3)
268+
res_true = fn(test_x, test_y)
269+
for test_inp, res_out in zip([test_x, test_y], res_true, strict=False):
270+
np.testing.assert_array_equal(test_inp, res_out)
271+
if as_view:
272+
assert res_out is test_inp
273+
else:
274+
assert res_out is not test_inp

0 commit comments

Comments
 (0)