|
1 | 1 | import contextlib |
2 | 2 | import copy |
| 3 | +import os |
3 | 4 | from collections.abc import Callable, Iterable |
4 | 5 | from typing import TYPE_CHECKING, Any |
5 | 6 | from unittest import mock |
@@ -787,3 +788,59 @@ def test_constant_output(self): |
787 | 788 | assert self.generate_and_validate_key(fg_pi) != self.generate_and_validate_key( |
788 | 789 | fg_e |
789 | 790 | ) |
| 791 | + |
| 792 | + |
@pytest.mark.skipif(not hasattr(os, "fork"), reason="Test requires os.fork (Unix only)")
def test_fork_cache_no_type_mismatch(tmp_path, monkeypatch):
    """Regression test for fork-safety of the numba disk cache.

    After os.fork(), numba's internal UID counter (FunctionIdentity._unique_ids)
    is shared between parent and child. If two exec()-created wrapper functions
    with the same qualname get the same UID in different processes, their LLVM
    mangled names collide. When they have different return types (e.g. 3D vs 4D
    array), this causes a ValueError during LLVM lowering.

    PyTensor prevents this by including the cache key in the wrapper function
    name, ensuring unique LLVM symbols even when UIDs overlap after fork.

    See: https://github.com/numba/numba/issues/10486
    """
    import pytensor.link.numba.cache as cache_mod

    # Point the on-disk numba cache at a throwaway directory for this test.
    monkeypatch.setattr(cache_mod, "NUMBA_CACHE_PATH", tmp_path)

    def run_in_fork(workload):
        child_pid = os.fork()
        if child_pid != 0:
            # Parent: wait for the child and report its exit code.
            _, wait_status = os.waitpid(child_pid, 0)
            return os.WEXITSTATUS(wait_status)
        # Child: run the workload and leave via os._exit so we never unwind
        # back into pytest's machinery inside the forked interpreter.
        try:
            workload()
        except BaseException:
            os._exit(1)
        os._exit(0)

    def graph_a():
        # Single-output graph: a 3D transpose compiled through numba.
        inp = pt.tensor3("x")
        compiled = function([inp], inp.transpose(2, 0, 1), mode="NUMBA")
        assert compiled(np.zeros((2, 3, 4))).shape == (4, 2, 3)

    def graph_b():
        # Two outputs with different ranks (3D transpose and a 4D expand-dims)
        # so colliding LLVM symbols would surface as a type mismatch.
        inp = pt.tensor3("x")
        compiled = function([inp], [inp.transpose(2, 0, 1), inp[None]], mode="NUMBA")
        out3d, out4d = compiled(np.zeros((2, 3, 4)))
        assert out3d.shape == (4, 2, 3)
        assert out4d.shape == (1, 2, 3, 4)

    # Fork child compiles graph_a (transpose only).
    assert run_in_fork(graph_a) == 0, "Fork child failed"

    # Parent compiles graph_b (transpose + expand dims); this loads the
    # fork-populated cache and also compiles fresh ops.
    graph_b()

    # Running in further forks is also fine.
    assert run_in_fork(graph_a) == 0, "Fork child 1 failed"
    assert run_in_fork(graph_b) == 0, "Fork child 2 failed"
0 commit comments