Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import warnings
from collections.abc import Callable
from functools import singledispatch, wraps
Expand All @@ -13,7 +14,10 @@
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.type import Type
from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump
from pytensor.link.numba.cache import (
compile_numba_function_src,
hash_from_pickle_dump,
)
from pytensor.link.utils import (
fgraph_to_python,
)
Expand All @@ -33,7 +37,7 @@ def _filter_numba_warnings():
"ignore",
message=(
"(\x1b\\[1m)*" # ansi escape code for bold text
'Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals'
'Cannot cache compiled function "numba_funcified_fgraph.*" as it uses dynamic globals'
),
category=NumbaWarning,
)
Expand Down Expand Up @@ -446,12 +450,20 @@ def numba_funcify_ensure_cache(op, *args, **kwargs) -> tuple[Callable, str | Non
print(f"{op} of type {type(op)} will not be cached by PyTensor.\n") # noqa: T201
return jitable_func, None
else:
op_name = jitable_func.__name__
full_cache_key = f"{cache_key}_fastmath{int(config.numba__fastmath)}"
# Include cache_key in the wrapper name to ensure unique LLVM symbol names.
# Without this, functions with the same __name__ but different behavior
# (e.g. all DimShuffle ops produce "dimshuffle")
# get identical mangled names when numba's UID counter overlaps after os.fork().
# This could cause compilation errors or silent bugs.
# See https://github.com/numba/numba/issues/10486
safe_key = re.sub(r"[^a-zA-Z0-9_]", "_", full_cache_key)
op_name = f"{jitable_func.__name__}_{safe_key}"
cached_func = compile_numba_function_src(
src=f"def {op_name}(*args): return jitable_func(*args)",
function_name=op_name,
global_env=globals() | {"jitable_func": jitable_func},
cache_key=f"{cache_key}_fastmath{int(config.numba__fastmath)}",
cache_key=full_cache_key,
)
return numba_njit(cached_func, cache=True), cache_key

Expand Down
2 changes: 1 addition & 1 deletion tests/compile/function/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
pytestmark = pytest.mark.filterwarnings(
"error",
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
)


Expand Down
2 changes: 1 addition & 1 deletion tests/link/numba/signal/test_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
pytestmark = pytest.mark.filterwarnings(
"error",
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
)


Expand Down
105 changes: 82 additions & 23 deletions tests/link/numba/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import copy
import os
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any
from unittest import mock
Expand Down Expand Up @@ -29,6 +30,7 @@
_filter_numba_warnings,
cache_key_for_constant,
numba_funcify_and_cache_key,
numba_njit,
)
from pytensor.link.numba.linker import NumbaLinker
from pytensor.scalar.basic import Composite, ScalarOp, as_scalar
Expand Down Expand Up @@ -426,14 +428,13 @@ def test_shared_updates():


def test_config_options_fastmath():
x = pt.dvector()

with config.change_flags(numba__fastmath=True):
pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__[
"jitable_func"
].py_func.__globals__["impl_sum"]
assert numba_sum_fn.targetoptions["fastmath"] == {

@numba_njit
def fn_fast(x):
return x + 1

assert fn_fast.targetoptions["fastmath"] == {
"afn",
"arcp",
"contract",
Expand All @@ -442,28 +443,30 @@ def test_config_options_fastmath():
}

with config.change_flags(numba__fastmath=False):
pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__[
"jitable_func"
].py_func.__globals__["impl_sum"]
assert numba_sum_fn.targetoptions["fastmath"] is False

@numba_njit
def fn_nofast(x):
return x + 1

def test_config_options_cached():
x = pt.dvector()
assert fn_nofast.targetoptions["fastmath"] is False


def test_config_options_cached():
with config.change_flags(numba__cache=True):
pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__[
"jitable_func"
].py_func.__globals__["impl_sum"]
assert not isinstance(numba_sum_fn._cache, numba.core.caching.NullCache)

@numba_njit(cache=True)
def fn_cached(x):
return x + 1

assert not isinstance(fn_cached._cache, numba.core.caching.NullCache)

with config.change_flags(numba__cache=False):
pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
# Without caching we don't wrap the function in jitable_func
numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert isinstance(numba_sum_fn._cache, numba.core.caching.NullCache)

@numba_njit
def fn_uncached(x):
return x + 1

assert isinstance(fn_uncached._cache, numba.core.caching.NullCache)


def test_scalar_return_value_conversion():
Expand Down Expand Up @@ -785,3 +788,59 @@ def test_constant_output(self):
assert self.generate_and_validate_key(fg_pi) != self.generate_and_validate_key(
fg_e
)


@pytest.mark.skipif(not hasattr(os, "fork"), reason="Test requires os.fork (Unix only)")
def test_fork_cache_no_type_mismatch(tmp_path, monkeypatch):
    """Regression test for fork-safety of the numba disk cache.

    After os.fork(), numba's internal UID counter (FunctionIdentity._unique_ids)
    is shared between parent and child. If two exec()-created wrapper functions
    with the same qualname get the same UID in different processes, their LLVM
    mangled names collide. When they have different return types (e.g. 3D vs 4D
    array), this causes a ValueError during LLVM lowering.

    PyTensor prevents this by including the cache key in the wrapper function
    name, ensuring unique LLVM symbols even when UIDs overlap after fork.

    See: https://github.com/numba/numba/issues/10486
    """
    import pytensor.link.numba.cache as cache_mod

    # Use a temporary cache for this test
    # (isolates the on-disk numba cache so stale entries from other test runs
    # cannot mask or trigger the collision this test exercises).
    monkeypatch.setattr(cache_mod, "NUMBA_CACHE_PATH", tmp_path)

    def run_in_fork(func):
        # Run `func` in a forked child and return the child's exit status
        # (0 on success, 1 on any exception).
        pid = os.fork()
        if pid == 0:
            # Child process: os._exit skips pytest/interpreter teardown,
            # which must not run twice after a fork.
            try:
                func()
                os._exit(0)
            except BaseException:
                os._exit(1)
        else:
            # Parent process: block until the child exits and decode its status.
            _, status = os.waitpid(pid, 0)
            return os.WEXITSTATUS(status)

    def graph_a():
        # Compile/run a graph whose funcified op returns a 3D array.
        x = pt.tensor3("x")
        fn = function([x], x.transpose(2, 0, 1), mode="NUMBA")
        assert fn(np.zeros((2, 3, 4))).shape == (4, 2, 3)

    def graph_b():
        # Same transpose op plus an expand-dims output (4D) — the differing
        # return types are what surfaced the LLVM symbol collision.
        x = pt.tensor3("x")
        fn = function([x], [x.transpose(2, 0, 1), x[None]], mode="NUMBA")
        r1, r2 = fn(np.zeros((2, 3, 4)))
        assert r1.shape == (4, 2, 3)
        assert r2.shape == (1, 2, 3, 4)

    # Fork child compiles graph_a (transpose only)
    assert run_in_fork(graph_a) == 0, "Fork child failed"

    # Parent compiles graph_b (transpose + expand dims)
    # This loads fork's cache and also compiles fresh ops
    graph_b()

    # Running in another fork is also fine
    assert run_in_fork(graph_a) == 0, "Fork child 1 failed"
    assert run_in_fork(graph_b) == 0, "Fork child 2 failed"
2 changes: 1 addition & 1 deletion tests/tensor/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@
pytestmark = pytest.mark.filterwarnings(
"error",
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
r"ignore::numba.NumbaPerformanceWarning",
)

Expand Down
2 changes: 1 addition & 1 deletion tests/tensor/test_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
pytestmark = pytest.mark.filterwarnings(
"error",
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
r"ignore::numba.NumbaPerformanceWarning",
)

Expand Down
2 changes: 1 addition & 1 deletion tests/tensor/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
pytestmark = pytest.mark.filterwarnings(
"error",
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
)


Expand Down
Loading