From 4d98f004d495965894a30548e6e605bf8d9a62bb Mon Sep 17 00:00:00 2001 From: emekaokoli19 Date: Wed, 3 Dec 2025 00:06:51 +0100 Subject: [PATCH] Numba IfElse: respect view flag Co-authored-by: Ricardo Vieira <28983449+ricardov94@users.noreply.github.com> --- pytensor/link/numba/dispatch/compile_ops.py | 45 ++++++++++++--------- tests/link/numba/test_compile_ops.py | 43 +++++++++++++++++++- 2 files changed, 68 insertions(+), 20 deletions(-) diff --git a/pytensor/link/numba/dispatch/compile_ops.py b/pytensor/link/numba/dispatch/compile_ops.py index 8cba4a2e05..b91e4aa95a 100644 --- a/pytensor/link/numba/dispatch/compile_ops.py +++ b/pytensor/link/numba/dispatch/compile_ops.py @@ -1,5 +1,6 @@ from copy import deepcopy from hashlib import sha256 +from textwrap import dedent import numba import numpy as np @@ -10,6 +11,7 @@ from pytensor.compile.mode import NUMBA from pytensor.compile.ops import DeepCopyOp, TypeCastingOp from pytensor.ifelse import IfElse +from pytensor.link.numba.cache import compile_numba_function_src from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( numba_funcify_and_cache_key, @@ -106,30 +108,35 @@ def deepcopy(x): @register_funcify_default_op_cache_key(IfElse) def numba_funcify_IfElse(op, **kwargs): n_outs = op.n_outs + as_view = op.as_view - if n_outs > 1: - - @numba_basic.numba_njit - def ifelse(cond, *args): - if cond: - res = args[:n_outs] - else: - res = args[n_outs:] - - return res + true_names = [f"t{i}" for i in range(n_outs)] + false_names = [f"f{i}" for i in range(n_outs)] + arg_list = ", ".join((*true_names, *false_names)) + if as_view: + true_returns = ", ".join(true_names) else: + true_returns = ", ".join(f"{name}.copy()" for name in true_names) + # We only ever view (alias) variables from the true branch. False branch variables must always be copied. + false_returns = ", ".join(f"{name}.copy()" for name in false_names) + + func_src = dedent( + f""" + def ifelse(cond, {arg_list}): + if cond: + return {true_returns} + else: + return {false_returns} + """ + ) - @numba_basic.numba_njit - def ifelse(cond, *args): - if cond: - res = args[:n_outs] - else: - res = args[n_outs:] - - return res[0] + ifelse_func = numba_basic.numba_njit( + compile_numba_function_src(func_src, "ifelse", globals()) + ) - return ifelse + cache_version = 1 + return ifelse_func, cache_version @register_funcify_and_cache_key(CheckAndRaise) diff --git a/tests/link/numba/test_compile_ops.py b/tests/link/numba/test_compile_ops.py index f6af836029..2adece0b0f 100644 --- a/tests/link/numba/test_compile_ops.py +++ b/tests/link/numba/test_compile_ops.py @@ -1,10 +1,11 @@ import numpy as np import pytest -from pytensor import Mode, OpFromGraph, config, function, ifelse, scan +from pytensor import In, Mode, OpFromGraph, Out, config, function, ifelse, scan from pytensor import tensor as pt from pytensor.compile import ViewOp from pytensor.graph import vectorize_graph +from pytensor.ifelse import IfElse from pytensor.raise_op import assert_op from pytensor.scalar import Add from pytensor.scan.op import Scan @@ -231,3 +232,43 @@ def test_ofg_with_inner_scan_rewrite(): cholesky_op = scan_op.fgraph.outputs[0].owner.op assert isinstance(cholesky_op, Blockwise) assert isinstance(cholesky_op.core_op, Cholesky) + + +@pytest.mark.parametrize("as_view", [True, False]) +def test_ifelse_single_output(as_view, single_out=True): + x = pt.vector("x") + y = pt.vector("y") + if single_out: + outs = [x] + else: + outs = [x, y] + + op = IfElse(as_view=as_view, n_outs=len(outs)) + outs = op(x.sum() > 0, *outs, *outs, return_list=True) + + fn = function( + [In(x, borrow=True), In(y, borrow=True)], + [Out(out, borrow=True) for out in outs], + mode=Mode("numba", optimizer=None), + accept_inplace=True, + on_unused_input="ignore", + ) + + # FALSE branch + test_x = np.zeros(3) + test_y = np.ones(5) + res_false = fn(test_x, test_y) + for test_inp, res_out in zip([test_x, test_y], res_false, strict=False): + np.testing.assert_array_equal(test_inp, res_out) + # IfElse only views on the true branch variates + assert res_out is not test_inp + + # TRUE branch + test_x = np.ones(3) + res_true = fn(test_x, test_y) + for test_inp, res_out in zip([test_x, test_y], res_true, strict=False): + np.testing.assert_array_equal(test_inp, res_out) + if as_view: + assert res_out is test_inp + else: + assert res_out is not test_inp