pytensor/link/numba/dispatch/random.py (47 changes: 32 additions & 15 deletions)

```diff
@@ -216,14 +216,16 @@ def random_fn(rng, mean, cov):

 @numba_core_rv_funcify.register(ptr.DirichletRV)
 def core_DirichletRV(op, node):
+    dtype = op.dtype
+
     @numba_basic.numba_njit
     def random_fn(rng, alpha):
-        y = np.empty_like(alpha)
+        y = np.empty_like(alpha, dtype=dtype)
```
> **Member Author** (on the `y = np.empty_like(alpha, dtype=dtype)` line): If `y` was created as integers, the line `y[i] = ...` would implicitly cast the gamma draws back to integers, but those must be floats for the output to make sense.
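To see the failure mode the comment describes, here is a minimal plain-NumPy sketch (not part of the PR): `np.empty_like` inherits the integer dtype of `alpha`, so each float gamma draw is truncated on assignment and the normalized result degenerates.

```python
import numpy as np

rng = np.random.default_rng(0)
alpha = np.array([1, 1, 1], dtype=np.int64)

# Buggy: the output buffer inherits alpha's int64 dtype,
# so float gamma draws are truncated on assignment.
y_int = np.empty_like(alpha)
for i in range(len(alpha)):
    y_int[i] = rng.gamma(alpha[i], 1.0)

# Fixed: an explicit float dtype preserves the draws.
y_float = np.empty_like(alpha, dtype=np.float64)
for i in range(len(alpha)):
    y_float[i] = rng.gamma(alpha[i], 1.0)

print(y_int)                    # truncated draws, e.g. mostly 0s and 1s
print(y_float / y_float.sum())  # a valid Dirichlet sample summing to 1
```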
```diff
         for i in range(len(alpha)):
             y[i] = rng.gamma(alpha[i], 1.0)
         return y / y.sum()

-    return random_fn
+    return random_fn, 1


 @numba_core_rv_funcify.register(ptr.GumbelRV)
```
```diff
@@ -410,7 +412,7 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
     rv_op: RandomVariable = rv_node.op

     try:
-        core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
+        core_rv_fn_and_cache_key = numba_core_rv_funcify(rv_op, rv_node)
     except NotImplementedError:
         py_impl = generate_fallback_impl(rv_op, node=rv_node, **kwargs)

@@ -420,6 +422,16 @@ def fallback_rv(_core_shape, *args):

         return fallback_rv, None

+    match core_rv_fn_and_cache_key:
+        case (core_rv_fn, (int() | None) as core_cache_key):
+            pass
+        case (_core_rv_fn, invalid_core_cache_key):
+            raise ValueError(
+                f"Invalid core_cache_key returned from numba_core_rv_funcify: {invalid_core_cache_key}. Must be int or None."
+            )
+        case core_rv_fn:
+            core_cache_key = "__None__"
+
     size = rv_op.size_param(rv_node)
     dist_params = rv_op.dist_params(rv_node)
     size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)
```
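For context, registered core-RV funcifiers can now return either a bare function or a `(function, cache_key)` pair, as `core_DirichletRV`'s `return random_fn, 1` above does. A simplified standalone sketch of how this match classifies the three accepted shapes (the `classify` name and inputs are illustrative, not the PR's API):

```python
def classify(result):
    # Simplified mirror of the match in numba_funcify_RandomVariable.
    match result:
        case (fn, (int() | None) as key):
            # (fn, 1): cacheable, versioned; (fn, None): explicitly uncacheable
            return fn, key
        case (fn, bad_key):
            raise ValueError(f"Invalid core_cache_key: {bad_key}. Must be int or None.")
        case fn:
            # Bare function (the pre-existing convention): keyed with a sentinel
            return fn, "__None__"


assert classify((abs, 1)) == (abs, 1)
assert classify((abs, None)) == (abs, None)
assert classify(abs) == (abs, "__None__")
```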
```diff
@@ -469,16 +481,21 @@ def impl(core_shape, rng, size, *dist_params):

         return impl

-    rv_op_props_dict = rv_op.props_dict() if hasattr(rv_op, "props_dict") else {}
-    random_rv_key_contents = (
-        type(op),
-        type(rv_op),
-        rv_op,
-        tuple(rv_op_props_dict.items()),
-        size_len,
-        core_shape_len,
-        inplace,
-        input_bc_patterns,
-    )
-    random_rv_key = sha256(str(random_rv_key_contents).encode()).hexdigest()
+    if core_cache_key is None:
+        # If the core RV can't be cached, then the whole RV can't be cached
+        random_rv_key = None  # type: ignore[unreachable]
+    else:
+        rv_op_props_dict = rv_op.props_dict() if hasattr(rv_op, "props_dict") else {}
+        random_rv_key_contents = (
+            type(op),
+            type(rv_op),
+            rv_op,
+            tuple(rv_op_props_dict.items()),
+            size_len,
+            core_shape_len,
+            inplace,
+            input_bc_patterns,
+            core_cache_key,
+        )
+        random_rv_key = sha256(str(random_rv_key_contents).encode()).hexdigest()
     return random, random_rv_key
```
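The cache key hashes everything that shapes the generated implementation, so adding `core_cache_key` to the tuple means bumping the version int in a core funcifier (e.g. the `1` returned above) invalidates stale cached binaries. A standalone sketch of the construction, with hypothetical stand-in values:

```python
from hashlib import sha256

# Hypothetical stand-ins for the real ingredients (op types, props,
# size/core-shape lengths, inplace flag, broadcast patterns, core key):
key_contents = ("RandomVariableWithCoreShape", "DirichletRV", (), 1, 1, False, ((False,),), 1)
key = sha256(str(key_contents).encode()).hexdigest()

# Any change to an ingredient (e.g. bumping the core cache key to 2)
# yields a different digest, forcing recompilation instead of a stale hit.
print(key)
```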
tests/link/numba/test_random.py (10 changes: 10 additions & 0 deletions)

```diff
@@ -679,6 +679,16 @@ def test_DirichletRV(a, size, cm):
     assert np.allclose(res, exp_res, atol=1e-4)


+def test_dirichlet_discrete_alpha():
+    alpha = pt.lvector()
+    g = ptr.dirichlet(alpha, size=100)
+    fn = function([alpha], g, mode=numba_mode)
+    res = fn(np.array([1, 1, 1], dtype=np.int64))
+    assert res.dtype == np.float64
+    np.testing.assert_allclose(res.sum(-1), 1.0)
+    assert np.unique(res).size > 2  # Make sure we have more than just 0s and 1s
+
+
 def test_rv_inside_ofg():
     rng_np = np.random.default_rng(562)
     rng = shared(rng_np)
```