[KSM][Optimization] renormalized logprobs when using keep sampling mask #6966
Changes from all commits
```diff
@@ -370,6 +370,27 @@ def post_process_normal(
     )
     # 3. Transmit the model's output and stop generation signal via message queue.
     # In the future, we will abandon this approach.
+    # Renormalize logprobs to match truncated sampling distribution (when enabled).
+    if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None:
+        # logprobs_tensors.logprobs: [B, max_num_logprobs + 1]
+        logprobs = sampler_output.logprobs_tensors.logprobs
+        # logz_per_batch: [B], log(sum(probs in candidate set K)) for each request
+        logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype)
+        # Renormalize: log π_masked = log π_full - log Z_K
+        # Only normalize valid candidates; padding positions use -inf
+        valid_mask = paddle.isfinite(logprobs)
+        normalized_logprobs = paddle.where(
+            valid_mask,
+            logprobs - logz.unsqueeze(1),  # broadcast subtraction
+            paddle.full_like(logprobs, float("-inf")),
+        )
+        # Update logprobs_tensors with normalized values
+        sampler_output.logprobs_tensors = LogprobsTensors(
+            logprob_token_ids=sampler_output.logprobs_tensors.logprob_token_ids,
+            logprobs=normalized_logprobs,
+            selected_token_ranks=sampler_output.logprobs_tensors.selected_token_ranks,
+        )
```
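The renormalization in this hunk implements log π_masked(t) = log π_full(t) − log Z_K, where Z_K is the probability mass of the candidate set K that survived the sampling mask. The following numpy sketch (illustrative only; it stands in for the Paddle operations and uses made-up toy values) shows why subtracting logZ per request makes the finite logprobs sum to 1 in probability space:

```python
import numpy as np

def renormalize_logprobs(logprobs: np.ndarray, logz: np.ndarray) -> np.ndarray:
    """Subtract per-request log Z_K from each finite logprob; keep -inf padding."""
    valid = np.isfinite(logprobs)
    return np.where(valid, logprobs - logz[:, None], -np.inf)

# Toy batch: two requests, three logprob slots (last slot is -inf padding).
logprobs = np.array([
    [np.log(0.5), np.log(0.3), -np.inf],
    [np.log(0.4), np.log(0.2), -np.inf],
])
logz = np.log(np.array([0.8, 0.6]))  # mass of candidate set K per request

renorm = renormalize_logprobs(logprobs, logz)
# Finite entries of each row now sum to 1 in probability space;
# padding stays at -inf.
```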
Comment on lines +387 to +392
Suggested change:

```diff
-        # Update logprobs_tensors with normalized values
-        sampler_output.logprobs_tensors = LogprobsTensors(
-            logprob_token_ids=sampler_output.logprobs_tensors.logprob_token_ids,
-            logprobs=normalized_logprobs,
-            selected_token_ranks=sampler_output.logprobs_tensors.selected_token_ranks,
-        )
+        # Update existing logprobs_tensors with normalized values in place
+        sampler_output.logprobs_tensors.logprobs = normalized_logprobs
```
Copilot AI commented on Mar 22, 2026:

On the speculative path, `sampler_output.logprobs_tensors.logprobs` is very likely still on the GPU (that branch's `gather_logprobs()` has no `.cpu()`), while the tensor created here with `paddle.to_tensor(sampler_output.logz_per_batch)` may land on the CPU or the default device. This can cause a place mismatch or extra D2H/H2D copies. Consider explicitly creating `logz` on, or moving it to, `logprobs.place`.
Copilot AI commented on Mar 22, 2026:

Same as the normal branch: `isfinite(logprobs)` cannot distinguish tokens inside the candidate set from tokens outside it. When the returned top-logprobs include tokens outside the candidate set, those entries get normalized to finite values instead of -inf. If the goal is to match the truncated sampling distribution, consider combining this with each token's `sampling_mask` to set out-of-candidate logprobs to -inf.
Copilot AI commented on Mar 23, 2026:

The speculative path has the same issue: the top-k returned by `gather_logprobs()` may contain tokens outside candidate set K, and subtracting only `logZ_K` gives those tokens a nonzero probability under the truncated distribution. Consider using the `sampling_mask` (flattened list) corresponding to each accepted token to set the logprobs of tokens not in K to -inf, or compute logprobs directly on the truncated distribution on the sampler side.
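The fix the reviewer suggests can be sketched as follows: mask out tokens that are not in the candidate set K before the -logZ shift, so they carry zero probability under the truncated distribution. This is an illustrative numpy sketch (the helper name `mask_then_renormalize`, the toy shapes, and the values are hypothetical, not the PR's actual code):

```python
import numpy as np

def mask_then_renormalize(logprobs, token_ids, sampling_mask, logz):
    """Hypothetical helper: set logprobs of tokens outside candidate set K
    to -inf first, THEN subtract log Z_K from the surviving entries."""
    batch_idx = np.arange(logprobs.shape[0])[:, None]
    in_k = sampling_mask[batch_idx, token_ids]  # [B, top_k] bool, True if token in K
    masked = np.where(in_k, logprobs, -np.inf)
    return np.where(np.isfinite(masked), masked - logz[:, None], -np.inf)

# One request, vocab of 5: tokens 0 and 1 are in K (mass 0.8), token 4 is not,
# yet it appears in the returned top-3 logprobs.
logprobs = np.array([[np.log(0.5), np.log(0.3), np.log(0.1)]])
token_ids = np.array([[0, 1, 4]])
sampling_mask = np.array([[True, True, False, False, False]])
logz = np.array([np.log(0.8)])

out = mask_then_renormalize(logprobs, token_ids, sampling_mask, logz)
# Token 4 is excluded (-inf); the in-K mass renormalizes to exactly 1.
```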
```diff
@@ -193,6 +193,10 @@ class SamplerOutput:
     # check whether the current path is speculative or non-speculative when
     # interpreting the dimension.
     sampling_mask: Optional[List[np.ndarray]] = None
+    # logZ_K for each request: log(sum(probs in candidate set K))
+    # Used for renormalizing logprobs to match the truncated sampling distribution.
+    # Shape: [num_reqs]
```
Suggested change:

```diff
-    # Shape: [num_reqs]
+    # Shape:
+    # - Non-speculative decoding: [num_reqs]
+    # - Speculative decoding: [total_accepted_tokens], aligned with the first
+    #   dimension of logprobs / sampling_mask and regrouped by request in
+    #   post-processing.
```
The renormalization logic here may stack on top of the existing `top_p_normalized_logprobs` option: when `compute_logprobs()` has already normalized over the top_p set (prob sum = 1), subtracting `logZ_K` again double-normalizes, and the returned logprobs no longer correspond to any valid distribution. Consider passing or detecting, per request, whether top_p normalization has already been applied and skipping this step for such requests, or computing logprobs for the combined "top_p + top_k" truncated distribution in a single place.
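The double-normalization the reviewer describes is easy to see numerically. In this illustrative sketch (toy values, not the project's code), the probabilities are already renormalized over a top_p set, so subtracting logZ_K a second time inflates the total mass above 1:

```python
import numpy as np

# Probabilities already renormalized over the top_p set by compute_logprobs()
# (they sum to 1 by construction).
p_top_p = np.array([0.625, 0.375])

# Mass of the original candidate set K -- already accounted for above.
logz_k = np.log(0.8)

# Subtracting logZ_K again divides every probability by 0.8 a second time.
double = np.exp(np.log(p_top_p) - logz_k)
# double.sum() is about 1.25, so this is no longer a valid distribution.
```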