Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 26 additions & 19 deletions pytensor/link/numba/dispatch/compile_ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from copy import deepcopy
from hashlib import sha256
from textwrap import dedent

import numba
import numpy as np
Expand All @@ -10,6 +11,7 @@
from pytensor.compile.mode import NUMBA
from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from pytensor.ifelse import IfElse
from pytensor.link.numba.cache import compile_numba_function_src
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
numba_funcify_and_cache_key,
Expand Down Expand Up @@ -106,30 +108,35 @@ def deepcopy(x):
@register_funcify_default_op_cache_key(IfElse)
def numba_funcify_IfElse(op, **kwargs):
n_outs = op.n_outs
as_view = op.as_view

if n_outs > 1:

@numba_basic.numba_njit
def ifelse(cond, *args):
if cond:
res = args[:n_outs]
else:
res = args[n_outs:]

return res
true_names = [f"t{i}" for i in range(n_outs)]
false_names = [f"f{i}" for i in range(n_outs)]
arg_list = ", ".join((*true_names, *false_names))

if as_view:
true_returns = ", ".join(true_names)
else:
true_returns = ", ".join(f"{name}.copy()" for name in true_names)
# We only ever view (alias) variables from the true branch. False branch variables must always be copied.
false_returns = ", ".join(f"{name}.copy()" for name in false_names)

func_src = dedent(
f"""
def ifelse(cond, {arg_list}):
if cond:
return {true_returns}
else:
return {false_returns}
"""
)

@numba_basic.numba_njit
def ifelse(cond, *args):
if cond:
res = args[:n_outs]
else:
res = args[n_outs:]

return res[0]
ifelse_func = numba_basic.numba_njit(
compile_numba_function_src(func_src, "ifelse", globals())
)

return ifelse
cache_version = 1
return ifelse_func, cache_version


@register_funcify_and_cache_key(CheckAndRaise)
Expand Down
43 changes: 42 additions & 1 deletion tests/link/numba/test_compile_ops.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import numpy as np
import pytest

from pytensor import Mode, OpFromGraph, config, function, ifelse, scan
from pytensor import In, Mode, OpFromGraph, Out, config, function, ifelse, scan
from pytensor import tensor as pt
from pytensor.compile import ViewOp
from pytensor.graph import vectorize_graph
from pytensor.ifelse import IfElse
from pytensor.raise_op import assert_op
from pytensor.scalar import Add
from pytensor.scan.op import Scan
Expand Down Expand Up @@ -231,3 +232,43 @@ def test_ofg_with_inner_scan_rewrite():
cholesky_op = scan_op.fgraph.outputs[0].owner.op
assert isinstance(cholesky_op, Blockwise)
assert isinstance(cholesky_op.core_op, Cholesky)


@pytest.mark.parametrize("as_view", [True, False])
def test_ifelse_single_output(as_view, single_out=True):
x = pt.vector("x")
y = pt.vector("y")
if single_out:
outs = [x]
else:
outs = [x, y]

op = IfElse(as_view=as_view, n_outs=len(outs))
outs = op(x.sum() > 0, *outs, *outs, return_list=True)

fn = function(
[In(x, borrow=True), In(y, borrow=True)],
[Out(out, borrow=True) for out in outs],
mode=Mode("numba", optimizer=None),
accept_inplace=True,
on_unused_input="ignore",
)

# FALSE branch
test_x = np.zeros(3)
test_y = np.ones(5)
res_false = fn(test_x, test_y)
for test_inp, res_out in zip([test_x, test_y], res_false, strict=False):
np.testing.assert_array_equal(test_inp, res_out)
# IfElse only views on the true branch variates
assert res_out is not test_inp

# TRUE branch
test_x = np.ones(3)
res_true = fn(test_x, test_y)
for test_inp, res_out in zip([test_x, test_y], res_true, strict=False):
np.testing.assert_array_equal(test_inp, res_out)
if as_view:
assert res_out is test_inp
else:
assert res_out is not test_inp
Loading