Skip to content

Commit f808419

Browse files
committed
Fix numba cache mangling in forked processes
1 parent 4582fc9 commit f808419

7 files changed

Lines changed: 78 additions & 9 deletions

File tree

pytensor/link/numba/dispatch/basic.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import re
12
import warnings
23
from collections.abc import Callable
34
from functools import singledispatch, wraps
@@ -13,7 +14,10 @@
1314
from pytensor.graph.basic import Apply, Constant, Variable
1415
from pytensor.graph.fg import FunctionGraph
1516
from pytensor.graph.type import Type
16-
from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump
17+
from pytensor.link.numba.cache import (
18+
compile_numba_function_src,
19+
hash_from_pickle_dump,
20+
)
1721
from pytensor.link.utils import (
1822
fgraph_to_python,
1923
)
@@ -33,7 +37,7 @@ def _filter_numba_warnings():
3337
"ignore",
3438
message=(
3539
"(\x1b\\[1m)*" # ansi escape code for bold text
36-
'Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals'
40+
'Cannot cache compiled function "numba_funcified_fgraph.*" as it uses dynamic globals'
3741
),
3842
category=NumbaWarning,
3943
)
@@ -446,12 +450,20 @@ def numba_funcify_ensure_cache(op, *args, **kwargs) -> tuple[Callable, str | Non
446450
print(f"{op} of type {type(op)} will not be cached by PyTensor.\n") # noqa: T201
447451
return jitable_func, None
448452
else:
449-
op_name = jitable_func.__name__
453+
full_cache_key = f"{cache_key}_fastmath{int(config.numba__fastmath)}"
454+
# Include cache_key in the wrapper name to ensure unique LLVM symbol names.
455+
# Without this, functions with the same __name__ but different behavior
456+
# (e.g. all DimShuffle ops produce "dimshuffle")
457+
# get identical mangled names when numba's UID counter overlaps after os.fork().
458+
# This could cause compilation errors or silent bugs.
459+
# See https://github.com/numba/numba/issues/10486
460+
safe_key = re.sub(r"[^a-zA-Z0-9_]", "_", full_cache_key)
461+
op_name = f"{jitable_func.__name__}_{safe_key}"
450462
cached_func = compile_numba_function_src(
451463
src=f"def {op_name}(*args): return jitable_func(*args)",
452464
function_name=op_name,
453465
global_env=globals() | {"jitable_func": jitable_func},
454-
cache_key=f"{cache_key}_fastmath{int(config.numba__fastmath)}",
466+
cache_key=full_cache_key,
455467
)
456468
return numba_njit(cached_func, cache=True), cache_key
457469

tests/compile/function/test_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
pytestmark = pytest.mark.filterwarnings(
4141
"error",
4242
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
43-
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning",
43+
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
4444
)
4545

4646

tests/link/numba/signal/test_conv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
pytestmark = pytest.mark.filterwarnings(
1414
"error",
1515
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
16-
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning",
16+
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
1717
)
1818

1919

tests/link/numba/test_basic.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import contextlib
22
import copy
3+
import os
34
from collections.abc import Callable, Iterable
45
from typing import TYPE_CHECKING, Any
56
from unittest import mock
@@ -787,3 +788,59 @@ def test_constant_output(self):
787788
assert self.generate_and_validate_key(fg_pi) != self.generate_and_validate_key(
788789
fg_e
789790
)
791+
792+
793+
@pytest.mark.skipif(not hasattr(os, "fork"), reason="Test requires os.fork (Unix only)")
794+
def test_fork_cache_no_type_mismatch(tmp_path, monkeypatch):
795+
"""Regression test for fork-safety of the numba disk cache.
796+
797+
After os.fork(), numba's internal UID counter (FunctionIdentity._unique_ids)
798+
is shared between parent and child. If two exec()-created wrapper functions
799+
with the same qualname get the same UID in different processes, their LLVM
800+
mangled names collide. When they have different return types (e.g. 3D vs 4D
801+
array), this causes a ValueError during LLVM lowering.
802+
803+
PyTensor prevents this by including the cache key in the wrapper function
804+
name, ensuring unique LLVM symbols even when UIDs overlap after fork.
805+
806+
See: https://github.com/numba/numba/issues/10486
807+
"""
808+
import pytensor.link.numba.cache as cache_mod
809+
810+
# Use a temporary cache for this test
811+
monkeypatch.setattr(cache_mod, "NUMBA_CACHE_PATH", tmp_path)
812+
813+
def run_in_fork(func):
814+
pid = os.fork()
815+
if pid == 0:
816+
try:
817+
func()
818+
os._exit(0)
819+
except BaseException:
820+
os._exit(1)
821+
else:
822+
_, status = os.waitpid(pid, 0)
823+
return os.WEXITSTATUS(status)
824+
825+
def graph_a():
826+
x = pt.tensor3("x")
827+
fn = function([x], x.transpose(2, 0, 1), mode="NUMBA")
828+
assert fn(np.zeros((2, 3, 4))).shape == (4, 2, 3)
829+
830+
def graph_b():
831+
x = pt.tensor3("x")
832+
fn = function([x], [x.transpose(2, 0, 1), x[None]], mode="NUMBA")
833+
r1, r2 = fn(np.zeros((2, 3, 4)))
834+
assert r1.shape == (4, 2, 3)
835+
assert r2.shape == (1, 2, 3, 4)
836+
837+
# Fork child compiles graph_a (transpose only)
838+
assert run_in_fork(graph_a) == 0, "Fork child failed"
839+
840+
# Parent compiles graph_b (transpose + expand dims)
841+
# This loads fork's cache and also compiles fresh ops
842+
graph_b()
843+
844+
# Running in another fork is also fine
845+
assert run_in_fork(graph_a) == 0, "Fork child 1 failed"
846+
assert run_in_fork(graph_b) == 0, "Fork child 2 failed"

tests/tensor/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@
155155
pytestmark = pytest.mark.filterwarnings(
156156
"error",
157157
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
158-
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning",
158+
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
159159
r"ignore::numba.NumbaPerformanceWarning",
160160
)
161161

tests/tensor/test_einsum.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
pytestmark = pytest.mark.filterwarnings(
2121
"error",
2222
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
23-
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning",
23+
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
2424
r"ignore::numba.NumbaPerformanceWarning",
2525
)
2626

tests/tensor/test_variable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
pytestmark = pytest.mark.filterwarnings(
4949
"error",
5050
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
51-
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning",
51+
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
5252
)
5353

5454

0 commit comments

Comments
 (0)