diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index 38d90f848e..f8ffa4adf9 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -216,14 +216,16 @@ def random_fn(rng, mean, cov): @numba_core_rv_funcify.register(ptr.DirichletRV) def core_DirichletRV(op, node): + dtype = op.dtype + @numba_basic.numba_njit def random_fn(rng, alpha): - y = np.empty_like(alpha) + y = np.empty_like(alpha, dtype=dtype) for i in range(len(alpha)): y[i] = rng.gamma(alpha[i], 1.0) return y / y.sum() - return random_fn + return random_fn, 1 @numba_core_rv_funcify.register(ptr.GumbelRV) @@ -410,7 +412,7 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs rv_op: RandomVariable = rv_node.op try: - core_rv_fn = numba_core_rv_funcify(rv_op, rv_node) + core_rv_fn_and_cache_key = numba_core_rv_funcify(rv_op, rv_node) except NotImplementedError: py_impl = generate_fallback_impl(rv_op, node=rv_node, **kwargs) @@ -420,6 +422,16 @@ def fallback_rv(_core_shape, *args): return fallback_rv, None + match core_rv_fn_and_cache_key: + case (core_rv_fn, (int() | None) as core_cache_key): + pass + case (_core_rv_fn, invalid_core_cache_key): + raise ValueError( + f"Invalid core_cache_key returned from numba_core_rv_funcify: {invalid_core_cache_key}. Must be int or None." + ) + case core_rv_fn: + core_cache_key = "__None__" + size = rv_op.size_param(rv_node) dist_params = rv_op.dist_params(rv_node) size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size) @@ -469,16 +481,21 @@ def impl(core_shape, rng, size, *dist_params): return impl - rv_op_props_dict = rv_op.props_dict() if hasattr(rv_op, "props_dict") else {} - random_rv_key_contents = ( - type(op), - type(rv_op), - rv_op, - tuple(rv_op_props_dict.items()), - size_len, - core_shape_len, - inplace, - input_bc_patterns, - ) - random_rv_key = sha256(str(random_rv_key_contents).encode()).hexdigest() + if core_cache_key is None: + # If the core RV can't be cached, then the whole RV can't be cached + random_rv_key = None # type: ignore[unreachable] + else: + rv_op_props_dict = rv_op.props_dict() if hasattr(rv_op, "props_dict") else {} + random_rv_key_contents = ( + type(op), + type(rv_op), + rv_op, + tuple(rv_op_props_dict.items()), + size_len, + core_shape_len, + inplace, + input_bc_patterns, + core_cache_key, + ) + random_rv_key = sha256(str(random_rv_key_contents).encode()).hexdigest() return random, random_rv_key diff --git a/tests/link/numba/test_random.py b/tests/link/numba/test_random.py index 20b1026e07..fd7d61e232 100644 --- a/tests/link/numba/test_random.py +++ b/tests/link/numba/test_random.py @@ -679,6 +679,16 @@ def test_DirichletRV(a, size, cm): assert np.allclose(res, exp_res, atol=1e-4) +def test_dirichlet_discrete_alpha(): + alpha = pt.lvector() + g = ptr.dirichlet(alpha, size=100) + fn = function([alpha], g, mode=numba_mode) + res = fn(np.array([1, 1, 1], dtype=np.int64)) + assert res.dtype == np.float64 + np.testing.assert_allclose(res.sum(-1), 1.0) + assert np.unique(res).size > 2 # Make sure we have more than just 0s and 1s + + def test_rv_inside_ofg(): rng_np = np.random.default_rng(562) rng = shared(rng_np)