From 4582fc9712e8b690c8814531b83f1f4f63e0bd05 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 9 Mar 2026 16:14:03 +0100 Subject: [PATCH 1/2] Test numba config flags directly --- tests/link/numba/test_basic.py | 48 ++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 5dad9eb8e3..fb29eda9c3 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -29,6 +29,7 @@ _filter_numba_warnings, cache_key_for_constant, numba_funcify_and_cache_key, + numba_njit, ) from pytensor.link.numba.linker import NumbaLinker from pytensor.scalar.basic import Composite, ScalarOp, as_scalar @@ -426,14 +427,13 @@ def test_shared_updates(): def test_config_options_fastmath(): - x = pt.dvector() - with config.change_flags(numba__fastmath=True): - pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode) - numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__[ - "jitable_func" - ].py_func.__globals__["impl_sum"] - assert numba_sum_fn.targetoptions["fastmath"] == { + + @numba_njit + def fn_fast(x): + return x + 1 + + assert fn_fast.targetoptions["fastmath"] == { "afn", "arcp", "contract", @@ -442,28 +442,30 @@ def test_config_options_fastmath(): } with config.change_flags(numba__fastmath=False): - pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode) - numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__[ - "jitable_func" - ].py_func.__globals__["impl_sum"] - assert numba_sum_fn.targetoptions["fastmath"] is False + @numba_njit + def fn_nofast(x): + return x + 1 + + assert fn_nofast.targetoptions["fastmath"] is False -def test_config_options_cached(): - x = pt.dvector() +def test_config_options_cached(): with config.change_flags(numba__cache=True): - pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode) - numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__[ - "jitable_func" - ].py_func.__globals__["impl_sum"] - assert not isinstance(numba_sum_fn._cache, numba.core.caching.NullCache) + + @numba_njit(cache=True) + def fn_cached(x): + return x + 1 + + assert not isinstance(fn_cached._cache, numba.core.caching.NullCache) with config.change_flags(numba__cache=False): - pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode) - # Without caching we don't wrap the function in jitable_func - numba_sum_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"] - assert isinstance(numba_sum_fn._cache, numba.core.caching.NullCache) + + @numba_njit + def fn_uncached(x): + return x + 1 + + assert isinstance(fn_uncached._cache, numba.core.caching.NullCache) def test_scalar_return_value_conversion(): From f808419a4b9de8d5def25b1fefdaa6889cd06abe Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 18 Mar 2026 23:56:00 +0100 Subject: [PATCH 2/2] Fix numba cache mangling in forked processes --- pytensor/link/numba/dispatch/basic.py | 20 ++++++++-- tests/compile/function/test_types.py | 2 +- tests/link/numba/signal/test_conv.py | 2 +- tests/link/numba/test_basic.py | 57 +++++++++++++++++++++++++++ tests/tensor/test_basic.py | 2 +- tests/tensor/test_einsum.py | 2 +- tests/tensor/test_variable.py | 2 +- 7 files changed, 78 insertions(+), 9 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 0f1a148110..145fc72ca7 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -1,3 +1,4 @@ +import re import warnings from collections.abc import Callable from functools import singledispatch, wraps @@ -13,7 +14,10 @@ from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.fg import FunctionGraph from pytensor.graph.type import Type -from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump +from pytensor.link.numba.cache import ( + compile_numba_function_src, + hash_from_pickle_dump, +) from pytensor.link.utils import ( fgraph_to_python, ) @@ -33,7 +37,7 @@ def _filter_numba_warnings(): "ignore", message=( "(\x1b\\[1m)*" # ansi escape code for bold text - 'Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals' + 'Cannot cache compiled function "numba_funcified_fgraph.*" as it uses dynamic globals' ), category=NumbaWarning, ) @@ -446,12 +450,20 @@ def numba_funcify_ensure_cache(op, *args, **kwargs) -> tuple[Callable, str | Non print(f"{op} of type {type(op)} will not be cached by PyTensor.\n") # noqa: T201 return jitable_func, None else: - op_name = jitable_func.__name__ + full_cache_key = f"{cache_key}_fastmath{int(config.numba__fastmath)}" + # Include cache_key in the wrapper name to ensure unique LLVM symbol names. + # Without this, functions with the same __name__ but different behavior + # (e.g. all DimShuffle ops produce "dimshuffle") + # get identical mangled names when numba's UID counter overlaps after os.fork(). + # This could cause compilation errors or silent bugs. + # See https://github.com/numba/numba/issues/10486 + safe_key = re.sub(r"[^a-zA-Z0-9_]", "_", full_cache_key) + op_name = f"{jitable_func.__name__}_{safe_key}" cached_func = compile_numba_function_src( src=f"def {op_name}(*args): return jitable_func(*args)", function_name=op_name, global_env=globals() | {"jitable_func": jitable_func}, - cache_key=f"{cache_key}_fastmath{int(config.numba__fastmath)}", + cache_key=full_cache_key, ) return numba_njit(cached_func, cache=True), cache_key diff --git a/tests/compile/function/test_types.py b/tests/compile/function/test_types.py index 59588548f7..a7a3089881 100644 --- a/tests/compile/function/test_types.py +++ b/tests/compile/function/test_types.py @@ -40,7 +40,7 @@ pytestmark = pytest.mark.filterwarnings( "error", r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning", - r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning", + r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning", ) diff --git a/tests/link/numba/signal/test_conv.py b/tests/link/numba/signal/test_conv.py index 1234ec9831..a7d5c52e5a 100644 --- a/tests/link/numba/signal/test_conv.py +++ b/tests/link/numba/signal/test_conv.py @@ -13,7 +13,7 @@ pytestmark = pytest.mark.filterwarnings( "error", r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning", - r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning", + r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning", ) diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index fb29eda9c3..6ec60f935c 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -1,5 +1,6 @@ import contextlib import copy +import os from collections.abc import Callable, Iterable from typing import TYPE_CHECKING, Any from unittest import mock @@ -787,3 +788,59 @@ def test_constant_output(self): assert self.generate_and_validate_key(fg_pi) != self.generate_and_validate_key( fg_e ) + + +@pytest.mark.skipif(not hasattr(os, "fork"), reason="Test requires os.fork (Unix only)") +def test_fork_cache_no_type_mismatch(tmp_path, monkeypatch): + """Regression test for fork-safety of the numba disk cache. + + After os.fork(), numba's internal UID counter (FunctionIdentity._unique_ids) + is shared between parent and child. If two exec()-created wrapper functions + with the same qualname get the same UID in different processes, their LLVM + mangled names collide. When they have different return types (e.g. 3D vs 4D + array), this causes a ValueError during LLVM lowering. + + PyTensor prevents this by including the cache key in the wrapper function + name, ensuring unique LLVM symbols even when UIDs overlap after fork. + + See: https://github.com/numba/numba/issues/10486 + """ + import pytensor.link.numba.cache as cache_mod + + # Use a temporary cache for this test + monkeypatch.setattr(cache_mod, "NUMBA_CACHE_PATH", tmp_path) + + def run_in_fork(func): + pid = os.fork() + if pid == 0: + try: + func() + os._exit(0) + except BaseException: + os._exit(1) + else: + _, status = os.waitpid(pid, 0) + return os.WEXITSTATUS(status) + + def graph_a(): + x = pt.tensor3("x") + fn = function([x], x.transpose(2, 0, 1), mode="NUMBA") + assert fn(np.zeros((2, 3, 4))).shape == (4, 2, 3) + + def graph_b(): + x = pt.tensor3("x") + fn = function([x], [x.transpose(2, 0, 1), x[None]], mode="NUMBA") + r1, r2 = fn(np.zeros((2, 3, 4))) + assert r1.shape == (4, 2, 3) + assert r2.shape == (1, 2, 3, 4) + + # Fork child compiles graph_a (transpose only) + assert run_in_fork(graph_a) == 0, "Fork child failed" + + # Parent compiles graph_b (transpose + expand dims) + # This loads fork's cache and also compiles fresh ops + graph_b() + + # Running in another fork is also fine + assert run_in_fork(graph_a) == 0, "Fork child 1 failed" + assert run_in_fork(graph_b) == 0, "Fork child 2 failed" diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index d4c2ff24d9..54495b8b98 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -155,7 +155,7 @@ pytestmark = pytest.mark.filterwarnings( "error", r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning", - r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning", + r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning", r"ignore::numba.NumbaPerformanceWarning", ) diff --git a/tests/tensor/test_einsum.py b/tests/tensor/test_einsum.py index 1a3620e46b..1943d7a394 100644 --- a/tests/tensor/test_einsum.py +++ b/tests/tensor/test_einsum.py @@ -20,7 +20,7 @@ pytestmark = pytest.mark.filterwarnings( "error", r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning", - r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning", + r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning", r"ignore::numba.NumbaPerformanceWarning", ) diff --git a/tests/tensor/test_variable.py b/tests/tensor/test_variable.py index 130b104746..f157f23878 100644 --- a/tests/tensor/test_variable.py +++ b/tests/tensor/test_variable.py @@ -48,7 +48,7 @@ pytestmark = pytest.mark.filterwarnings( "error", r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning", - r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning", + r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning", )