Skip to content

Commit 7ced77d

Browse files
committed
Fix numba cache mangling in forked processes
1 parent 21d7545 commit 7ced77d

2 files changed

Lines changed: 72 additions & 4 deletions

File tree

pytensor/link/numba/dispatch/basic.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import re
12
import warnings
23
from collections.abc import Callable
34
from functools import singledispatch, wraps
@@ -13,7 +14,10 @@
1314
from pytensor.graph.basic import Apply, Constant, Variable
1415
from pytensor.graph.fg import FunctionGraph
1516
from pytensor.graph.type import Type
16-
from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump
17+
from pytensor.link.numba.cache import (
18+
compile_numba_function_src,
19+
hash_from_pickle_dump,
20+
)
1721
from pytensor.link.utils import (
1822
fgraph_to_python,
1923
)
@@ -33,7 +37,7 @@ def _filter_numba_warnings():
3337
"ignore",
3438
message=(
3539
"(\x1b\\[1m)*" # ansi escape code for bold text
36-
'Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals'
40+
'Cannot cache compiled function "numba_funcified_fgraph.*" as it uses dynamic globals'
3741
),
3842
category=NumbaWarning,
3943
)
@@ -446,12 +450,19 @@ def numba_funcify_ensure_cache(op, *args, **kwargs) -> tuple[Callable, str | Non
446450
print(f"{op} of type {type(op)} will not be cached by PyTensor.\n") # noqa: T201
447451
return jitable_func, None
448452
else:
449-
op_name = jitable_func.__name__
453+
full_cache_key = f"{cache_key}_fastmath{int(config.numba__fastmath)}"
454+
# Include cache_key in the wrapper name to ensure unique LLVM symbol
455+
# names. Without this, functions with the same __name__ (e.g. all
456+
# DimShuffle ops produce "dimshuffle") but different return types get
457+
# identical mangled names when numba's UID counter overlaps after
458+
# os.fork(), causing LLVM type mismatch errors.
459+
safe_key = re.sub(r"[^a-zA-Z0-9_]", "_", full_cache_key)
460+
op_name = f"{jitable_func.__name__}_{safe_key}"
450461
cached_func = compile_numba_function_src(
451462
src=f"def {op_name}(*args): return jitable_func(*args)",
452463
function_name=op_name,
453464
global_env=globals() | {"jitable_func": jitable_func},
454-
cache_key=f"{cache_key}_fastmath{int(config.numba__fastmath)}",
465+
cache_key=full_cache_key,
455466
)
456467
return numba_njit(cached_func, cache=True), cache_key
457468

tests/link/numba/test_basic.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import contextlib
22
import copy
3+
import os
34
from collections.abc import Callable, Iterable
45
from typing import TYPE_CHECKING, Any
56
from unittest import mock
@@ -787,3 +788,59 @@ def test_constant_output(self):
787788
assert self.generate_and_validate_key(fg_pi) != self.generate_and_validate_key(
788789
fg_e
789790
)
791+
792+
793+
@pytest.mark.skipif(not hasattr(os, "fork"), reason="Test requires os.fork (Unix only)")
def test_fork_cache_no_type_mismatch(tmp_path, monkeypatch):
    """Regression test for fork-safety of the numba disk cache.

    After os.fork(), numba's internal UID counter (FunctionIdentity._unique_ids)
    is shared between parent and child. If two exec()-created wrapper functions
    with the same qualname get the same UID in different processes, their LLVM
    mangled names collide. When they have different return types (e.g. 3D vs 4D
    array), this causes a ValueError during LLVM lowering.

    PyTensor prevents this by including the cache key in the wrapper function
    name, ensuring unique LLVM symbols even when UIDs overlap after fork.

    See: https://github.com/numba/numba/issues/10486
    """
    import pytensor.link.numba.cache as cache_mod

    # Use a temporary cache for this test so runs don't pollute (or depend on)
    # the user's real on-disk numba cache.
    monkeypatch.setattr(cache_mod, "NUMBA_CACHE_PATH", tmp_path)

    def run_in_fork(func):
        # Run `func` in a forked child; return the child's exit status
        # (0 on success, 1 on any exception).
        pid = os.fork()
        if pid == 0:
            try:
                func()
                os._exit(0)
            except BaseException:
                os._exit(1)
        else:
            _, status = os.waitpid(pid, 0)
            return os.WEXITSTATUS(status)

    def graph_a():
        x = pt.tensor3("x")
        fn = function([x], x.transpose(2, 0, 1), mode="NUMBA")
        assert fn(np.zeros((2, 3, 4))).shape == (4, 2, 3)

    def graph_b():
        x = pt.tensor3("x")
        # Same transpose op as graph_a plus an expand-dims, so the two graphs
        # share a cached wrapper but differ in a second op's return type.
        fn = function([x], [x.transpose(2, 0, 1), x[None]], mode="NUMBA")
        r1, r2 = fn(np.zeros((2, 3, 4)))
        assert r1.shape == (4, 2, 3)
        assert r2.shape == (1, 2, 3, 4)

    # Fork child compiles graph_a (transpose only)
    assert run_in_fork(graph_a) == 0, "Fork child failed"

    # Parent compiles graph_b (transpose + expand dims)
    # This loads fork's cache and also compiles fresh ops
    graph_b()

    # Running in another fork is also fine
    assert run_in_fork(graph_a) == 0, "Fork child 1 failed"
    assert run_in_fork(graph_b) == 0, "Fork child 2 failed"

0 commit comments

Comments (0)