Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions labml_nn/sampling/nucleus.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

$$\sum_{x_i \in V^{(p)}} P(x_i | x_{1:i-1}) \ge p$$

That is, we pick the highest probable tokens until the sum of their probabilities is less that $p$.
That is, we pick the highest probable tokens until the sum of their probabilities is at least $p$.

Then we sample from the selected tokens.

Expand Down Expand Up @@ -61,7 +61,7 @@ def __call__(self, logits: torch.Tensor):
# Find the cumulative sums less than $p$.
nucleus = cum_sum_probs < self.p
# Prepend ones so that we add one token after the minimum number
# of tokens with cumulative probability less that $p$.
# of tokens with cumulative probability less than $p$.
nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)

# Get log probabilities and mask out the non-nucleus
Expand Down