diff --git a/labml_nn/sampling/nucleus.py b/labml_nn/sampling/nucleus.py index 6de9c719e..cab29560d 100644 --- a/labml_nn/sampling/nucleus.py +++ b/labml_nn/sampling/nucleus.py @@ -19,7 +19,7 @@ $$\sum_{x_i \in V^{(p)}} P(x_i | x_{1:i-1}) \ge p$$ -That is, we pick the highest probable tokens until the sum of their probabilities is less that $p$. +That is, we pick the highest probable tokens until the sum of their probabilities is at least $p$. Then we sample from the selected tokens. @@ -61,7 +61,7 @@ def __call__(self, logits: torch.Tensor): # Find the cumulative sums less than $p$. nucleus = cum_sum_probs < self.p # Prepend ones so that we add one token after the minimum number - # of tokens with cumulative probability less that $p$. + # of tokens with cumulative probability less than $p$. nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1) # Get log probabilities and mask out the non-nucleus