Fixing issue "Samples are outside the support for DiscreteUniform dist…" #1835
Deathn0t wants to merge 8 commits into pyro-ppl:master from …
Conversation
numpyro/infer/mixed_hmc.py
Outdated
        lambda idx, support: support[idx],
        z_discrete,
        self._support_enumerates,
    )
Doing this might return in-support values, but I worry that the algorithm is wrong. To compute the potential energy correctly, the algorithm needs to work with in-support values. I think you can pass support_enumerates into self._discrete_proposal_fn and change the proposal logic there:
proposal = random.randint(rng_proposal, (), minval=0, maxval=support_size)
# z_new_flat = z_discrete_flat.at[idx].set(proposal)
z_new_flat = z_discrete_flat.at[idx].set(support_enumerate[proposal])
or, for the modified RW proposal:
i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1)
# proposal = jnp.where(i >= z_discrete_flat[idx], i + 1, i)
# proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, proposal)
proposal_index = jnp.where(support_size[i] == z_discrete_flat[idx], support_size - 1, i)
proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, support_size[proposal_index])
z_new_flat = z_discrete_flat.at[idx].set(proposal)
or, at the discrete Gibbs proposal:
proposal_index = jnp.where(support_enumerate[i] == z_init_flat[idx], support_size - 1, i)
z_new_flat = z_init_flat.at[idx].set(support_enumerate[proposal_index])
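As a sanity check on the Gibbs-style indexing trick above (swap in the last slot when the drawn slot holds the current value), here is a minimal NumPy sketch; the site value and support are made up, and np.where/rng stand in for jnp.where/jax.random:

```python
import numpy as np

rng = np.random.default_rng(0)

# hypothetical single flat site: DiscreteUniform(10, 12)
support_enumerate = np.array([10, 11, 12])
support_size = 3
z_current = 11  # current in-support value

proposals = []
for _ in range(100):
    # draw over support_size - 1 slots, i.e. everything but one value
    i = rng.integers(0, support_size - 1)
    # if the drawn slot holds the current value, use the last slot instead
    proposal_index = np.where(support_enumerate[i] == z_current, support_size - 1, i)
    proposals.append(int(support_enumerate[proposal_index]))

# every proposal is in the support and differs from the current value
assert set(proposals) <= {10, 12}
```

The point is that proposals are drawn as indices into the enumerated support, so they can never leave it.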
Ok, thank you for the feedback. I will try this.
@fehiepsi how do you debug in numpyro? I tried jax.debug. but nothing happens.
I use print most of the time. When actual values are needed, I sometimes use jax.disable_jit().
@fehiepsi I have issues with passing enumerated supports as traced values, since the support arrays can have different sizes. I was thinking of just passing the "lower bound of the support" as an offset; combined with support_sizes, that should do the trick. Are there discrete variables whose support is not a simple discrete range with step 1 between values?
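The "lower bound as offset" idea can be sketched in plain NumPy (the bounds are hypothetical; this only works when the support is a contiguous step-1 range, which is exactly the open question here):

```python
import numpy as np

rng = np.random.default_rng(0)

# assumed DiscreteUniform(10, 12): support is low + {0, ..., size - 1}
low, high = 10, 12
support_size = high - low + 1  # 3

# an offset plus the support size is enough to index the support
samples = low + rng.integers(0, support_size, size=1000)
assert samples.min() >= low and samples.max() <= high
```

For non-contiguous supports an explicit enumerated-support array would still be needed.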
For modified_rw_proposal, I think you used support_size in place of support_enumerate. Shouldn't it be:
i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1)
# proposal = jnp.where(i >= z_discrete_flat[idx], i + 1, i)
# proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, proposal)
proposal_index = jnp.where(support_enumerate[i] == z_discrete_flat[idx], support_size - 1, i)
proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, support_enumerate[proposal_index])
z_new_flat = z_discrete_flat.at[idx].set(proposal)
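A quick NumPy check of the corrected logic above (the site value and stay probability are made up; np.where and a NumPy generator stand in for jnp.where and jax.random, and the stay branch keeps the current value):

```python
import numpy as np

rng = np.random.default_rng(0)

support_enumerate = np.array([10, 11, 12])  # assumed DiscreteUniform(10, 12)
support_size = 3
z_current = 11
stay_prob = 0.3

for _ in range(200):
    i = rng.integers(0, support_size - 1)
    # swap in the last slot when slot i holds the current value
    proposal_index = np.where(support_enumerate[i] == z_current, support_size - 1, i)
    stay = rng.random() < stay_prob
    proposal = np.where(stay, z_current, support_enumerate[proposal_index])
    # the proposal is always one of the enumerated support values
    assert int(proposal) in (10, 11, 12)
```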
Thanks! Your solutions are super cool! I hadn't thought of different support sizes previously.
numpyro/infer/hmc_gibbs.py
Outdated
self._support_enumerates = np.zeros(
    (len(self._support_sizes), max_length_support_enumerates), dtype=int
)
for i, (name, site) in enumerate(self._prototype_trace.items()):
Great solution! I just have a couple of comments:
- it might be better to loop over names in support_sizes and get the site via site = self._prototype_trace[name]
- we use ravel_pytree to flatten support_sizes, so we might want to keep the same behavior here. I don't have a great solution for this; maybe:

support_enumerates = {}
for name, support_size in self._support_sizes.items():
    site = self._prototype_trace[name]
    enumerate_support = site["fn"].enumerate_support(False)
    padded_enumerate_support = np.pad(enumerate_support, (0, max_length_support_enumerates - enumerate_support.shape[0]))
    padded_enumerate_support = np.broadcast_to(padded_enumerate_support, support_size.shape + (max_length_support_enumerates,))
    support_enumerates[name] = padded_enumerate_support
self._support_enumerates = jax.vmap(lambda x: ravel_pytree(x)[0], in_axes=1, out_axes=1)(support_enumerates)
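For reference, a standalone NumPy illustration of the padding and broadcasting steps used in the snippet above (the support values and batch shape are made up):

```python
import numpy as np

max_length = 4
enum_a = np.array([10, 11, 12])  # e.g. support of DiscreteUniform(10, 12)

# pad the enumerated support with zeros up to the common max length
padded = np.pad(enum_a, (0, max_length - enum_a.shape[0]))
assert padded.tolist() == [10, 11, 12, 0]

# broadcast the padded row across the site's batch shape, here (2,)
batch = np.broadcast_to(padded, (2, max_length))
assert batch.shape == (2, 4)
assert batch[1].tolist() == [10, 11, 12, 0]
```

The zero padding is harmless as long as proposals only index the first support_size entries of each row.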
@fehiepsi it worked fine with …
I think we need to ravel along the first axis. The second axis corresponds to …
numpyro/infer/hmc_gibbs.py
Outdated
for site in self._prototype_trace.values()
if site["type"] == "sample"
and site["fn"].has_enumerate_support
and not site["is_observed"]
nit: it is better to loop over support_sizes: for name, site in self._prototype_trace.items() if name in support_sizes
the first axis is …
we vmap over the batch axis, which is the second axis, i.e. in_axes=1
Could you also add a simple test (as in the issue) for this? You can run …
I applied the lint/format and I added a test.
Ok, but the … So the following line: support_size.shape + (max_length_support_enumerates,) is just equivalent to … Maybe you have an example where …
That is a good point. I thought support sizes contained flattened arrays. Sorry for the confusion. I guess we need to move the enumerate dimension to the first axis before vmapping, like you did.
I tried the following direction:

max_length_support_enumerates = np.max(
    [size for size in self._support_sizes.values()]
)
support_enumerates = {}
for name, support_size in self._support_sizes.items():
    site = self._prototype_trace[name]
    enumerate_support = site["fn"].enumerate_support(True).T
    # Only the last dimension that corresponds to support size is padded
    pad_width = [(0, 0) for _ in range(len(enumerate_support.shape) - 1)] + [
        (0, max_length_support_enumerates - enumerate_support.shape[-1])
    ]
    padded_enumerate_support = np.pad(enumerate_support, pad_width)
    support_enumerates[name] = padded_enumerate_support
self._support_enumerates = jax.vmap(
    lambda x: ravel_pytree(x)[0], in_axes=len(support_size.shape), out_axes=1
)(support_enumerates)

which works with the following cases:

def model_1():
    numpyro.sample("x0", dist.DiscreteUniform(10, 12))
    numpyro.sample("x1", dist.Categorical(np.asarray([0.25, 0.25, 0.25, 0.25])))

def model_2():
    numpyro.sample("x0", dist.Categorical(0.25 * jnp.ones((4,))))
    numpyro.sample("x1", dist.Categorical(0.1 * jnp.ones((10,))))

def model_3():
    numpyro.sample("x0", dist.Categorical(0.25 * jnp.ones((3, 4))))
    numpyro.sample("x1", dist.Categorical(0.1 * jnp.ones((3, 10))))

But it fails when I try to batch:

def model_4():
    numpyro.sample("x1", dist.DiscreteUniform(10 * jnp.ones((3,)), 19 * jnp.ones((3,))))

with the following exception, which comes before the code I added (when the …
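The pad_width construction from the snippet above can be exercised on its own in NumPy (the shapes are made up, mirroring a batch of 3 DiscreteUniform(10, 19) sites after the transpose):

```python
import numpy as np

# shape (3, 10): batch axis first, enumerate axis last (after the .T)
enumerate_support = np.tile(np.arange(10, 20), (3, 1))
max_length = 12

# only the last axis (the support-size axis) gets padded
pad_width = [(0, 0) for _ in range(enumerate_support.ndim - 1)] + [
    (0, max_length - enumerate_support.shape[-1])
]
padded = np.pad(enumerate_support, pad_width)
assert padded.shape == (3, 12)
assert padded[0, :10].tolist() == list(range(10, 20))
assert padded[0, 10:].tolist() == [0, 0]
```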
The … By the way, maybe we need to use …
Hmm, there seems to be a bug in DiscreteUniform.enumerate_support.
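For context, here is a NumPy sketch of what a batched enumerate_support would typically be expected to return (assumed semantics, modeled on how Categorical enumerates: the enumerate axis comes first, followed by the batch shape):

```python
import numpy as np

# assumed batched DiscreteUniform(10 * ones(3), 19 * ones(3))
low = 10 * np.ones(3, dtype=int)
high = 19 * np.ones(3, dtype=int)
size = int((high - low + 1)[0])  # support size must be homogeneous to enumerate

# expected layout: (support_size,) + batch_shape, here (10, 3)
values = low + np.arange(size)[:, None]
assert values.shape == (10, 3)
assert values[:, 0].tolist() == list(range(10, 20))
```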
@fehiepsi sorry for the delay... other things happened and I couldn't follow up. Yes, let me test this now!
…tests are passing when using changes from PR pyro-ppl#1859
This fixes issue #1834, where MixedHMC sampling with a DiscreteUniform distribution samples outside the support because enumerate_support is not used.