From dede2bf0ba2f2be24cca79e1e05b21d22f32b46f Mon Sep 17 00:00:00 2001 From: Taksh Date: Mon, 6 Apr 2026 05:51:09 +0530 Subject: [PATCH] Fix docstring in nucleus sampling - Line 22: "is less that $p$" -> "is at least $p$" (fixes both typo and semantics) - Line 64: "less that" -> "less than" (fixes typo) Co-Authored-By: Claude Opus 4.6 (1M context) --- labml_nn/sampling/nucleus.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/labml_nn/sampling/nucleus.py b/labml_nn/sampling/nucleus.py index 6de9c719e..cab29560d 100644 --- a/labml_nn/sampling/nucleus.py +++ b/labml_nn/sampling/nucleus.py @@ -19,7 +19,7 @@ $$\sum_{x_i \in V^{(p)}} P(x_i | x_{1:i-1}) \ge p$$ -That is, we pick the highest probable tokens until the sum of their probabilities is less that $p$. +That is, we pick the highest probable tokens until the sum of their probabilities is at least $p$. Then we sample from the selected tokens. @@ -61,7 +61,7 @@ def __call__(self, logits: torch.Tensor): # Find the cumulative sums less than $p$. nucleus = cum_sum_probs < self.p # Prepend ones so that we add one token after the minimum number - # of tokens with cumulative probability less that $p$. + # of tokens with cumulative probability less than $p$. nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1) # Get log probabilities and mask out the non-nucleus