
normalized logprobs when using keep sampling mask#6966

Open
DesmonDay wants to merge 1 commit into PaddlePaddle:release/2.4 from DesmonDay:release/2.4

Conversation

@DesmonDay

Motivation

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)


Modifications

Usage or Command

Accuracy Tests

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code; run pre-commit before committing.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings March 22, 2026 15:47
@paddle-bot

paddle-bot bot commented Mar 22, 2026

Thanks for your contribution!

Contributor

Copilot AI left a comment

Pull request overview

This PR renormalizes the returned token-level logprobs when keep_sampling_mask (top-k/top-p truncation of the candidate set) is enabled, so that they are consistent with the truncated sampling distribution, and passes each request's logZ_K through SamplerOutput.

Changes:

  • _compute_sampling_mask() additionally computes and returns logZ_K (the log of the summed candidate-set probabilities) for each request, and a new logz_per_batch field on SamplerOutput passes it through.
  • post_process_normal() / post_process_specualate() renormalize logprobs_tensors.logprobs based on logZ_K.
  • The return type of _compute_sampling_mask() and its callers are adjusted so it returns both the sparse mask and logz_per_batch.
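The renormalization described above can be sketched as follows (an illustrative NumPy stand-in for the Paddle code; `renormalize_logprobs` is a hypothetical helper, not a function from this PR):

```python
import numpy as np

def renormalize_logprobs(logprobs: np.ndarray, candidate_mask: np.ndarray) -> np.ndarray:
    """Renormalize full-vocab logprobs to the truncated candidate set K."""
    probs = np.exp(logprobs)
    # Z_K = sum of probabilities inside the candidate set K
    z_k = np.where(candidate_mask, probs, 0.0).sum()
    logz_k = np.log(z_k + 1e-9)  # small constant guards against log(0)
    # log pi_masked = log pi_full - log Z_K inside K, -inf outside
    return np.where(candidate_mask, logprobs - logz_k, -np.inf)
```

After this step the candidate-set probabilities sum to 1, i.e. the returned logprobs describe the truncated sampling distribution rather than the full softmax.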

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 7 comments.

File Description
fastdeploy/worker/output.py Adds a logz_per_batch field to SamplerOutput, used to renormalize logprobs in the post-processing stage.
fastdeploy/model_executor/pre_and_post_process.py Renormalizes logprobs by subtracting logZ_K in the normal/speculative post-process.
fastdeploy/model_executor/layers/sample/sampler.py Adds the logZ_K computation when computing the sampling mask and returns it, updating the call chain accordingly.

Comment on lines +187 to +193
# Z_K = sum(probs[i] * final_mask[i]) for each request i
# logZ_K = log(Z_K), with small constant to avoid log(0)
# ------------------------------------------------------------------
candidate_probs = paddle.where(
    final_mask,
    probs,
    paddle.zeros_like(probs)

Copilot AI Mar 22, 2026


final_mask is constructed in the sorted space of sorted_probs/renorm_sorted_probs (aligned with sorted_indices), but using it here to mask the original, unsorted probs misaligns the selected token probabilities, yielding an incorrect Z_K/logZ_K and biasing the subsequent logprobs renormalization as a whole. Consider computing Z_K in the sorted space (e.g. summing sorted_probs under final_mask), or scattering final_mask back to the original vocab order before aligning it with probs.

Suggested change

# Z_K = sum(probs[i] * final_mask[i]) for each request i
# logZ_K = log(Z_K), with small constant to avoid log(0)
# ------------------------------------------------------------------
candidate_probs = paddle.where(
    final_mask,
    probs,
    paddle.zeros_like(probs)

# Z_K = sum(sorted_probs[i] * final_mask[i]) for each request i
# logZ_K = log(Z_K), with small constant to avoid log(0)
# ------------------------------------------------------------------
candidate_probs = paddle.where(
    final_mask,
    sorted_probs,
    paddle.zeros_like(sorted_probs)

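The mismatch the reviewer describes can be demonstrated with a small NumPy sketch (illustrative values, not code from the PR): a mask built in sorted order selects the wrong entries when applied to the unsorted probabilities.

```python
import numpy as np

probs = np.array([0.1, 0.5, 0.15, 0.25])           # unsorted vocab probabilities
order = np.argsort(-probs)                          # plays the role of sorted_indices
sorted_probs = probs[order]                         # [0.5, 0.25, 0.15, 0.1]
final_mask = np.array([True, True, False, False])   # keep top-2, built in sorted space

# Masking the unsorted probs picks positions 0 and 1 of the original array,
# not the actual top-2 tokens:
z_wrong = np.where(final_mask, probs, 0.0).sum()          # 0.1 + 0.5 = 0.6
# Masking in sorted space sums the true top-2 probabilities:
z_right = np.where(final_mask, sorted_probs, 0.0).sum()   # 0.5 + 0.25 = 0.75
```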
Author

@DesmonDay DesmonDay Mar 22, 2026

There is indeed a subtle point here:

  1. The sampled token: it is inside the candidate set K, so logZ_K is correct for it, and its normalized value is correct ✓
  2. Top-k candidates: inside the candidate set K, so logZ_K is correct for them, and their normalized values are correct ✓
  3. Tokens outside the top-k set: even though logZ_K is subtracted from them too, it does not matter, because those values are never returned.

Since keep sampling mask currently only needs to return a single logprob, the logprobs inside the candidate set are guaranteed to be correctly normalized, and the other values need not be considered.
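The author's point — that subtracting logZ_K is correct for any token inside K regardless of what happens outside K — can be checked numerically (NumPy sketch with illustrative values):

```python
import numpy as np

probs = np.array([0.5, 0.3, 0.2])      # full softmax distribution
in_k = np.array([True, True, False])   # candidate set K = {0, 1}
logz_k = np.log(probs[in_k].sum())     # log Z_K = log(0.8)

# For the sampled token (index 0, inside K), the renormalized logprob equals
# the exact truncated-distribution value log(0.5 / 0.8):
renorm_0 = np.log(probs[0]) - logz_k
```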

Comment on lines +379 to +388
logprobs = sampler_output.logprobs_tensors.logprobs
# logz_per_batch: [B], log(sum(probs in candidate set K)) for each request
logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype)
# Renormalize: log π_masked = log π_full - log Z_K
# Only normalize valid candidates; padding positions use -inf
valid_mask = paddle.isfinite(logprobs)
normalized_logprobs = paddle.where(
    valid_mask,
    logprobs - logz.unsqueeze(1),  # broadcast subtraction
    paddle.full_like(logprobs, float('-inf'))

Copilot AI Mar 22, 2026


logz = paddle.to_tensor(...) here does not specify a place. In the current implementation, the non-speculative path's gather_logprobs() moves logprobs to CPU (.cpu()), while the default device is usually GPU; this makes logprobs - logz hit a place mismatch (or incur extra overhead from an implicit copy). Consider using place=logprobs.place (or explicitly logz = logz.astype(logprobs.dtype).to(logprobs.place), etc.) to keep it on the same device as logprobs.

Comment on lines +383 to +388
# Only normalize valid candidates; padding positions use -inf
valid_mask = paddle.isfinite(logprobs)
normalized_logprobs = paddle.where(
    valid_mask,
    logprobs - logz.unsqueeze(1),  # broadcast subtraction
    paddle.full_like(logprobs, float('-inf'))

Copilot AI Mar 22, 2026


Using paddle.isfinite(logprobs) as the "valid candidate" test is not accurate: gather_logprobs() takes the top-k over the full vocab, so when a request's num_logprobs exceeds the size of the top_p/top_k-truncated candidate set, logZ_K is also subtracted from the logprobs of tokens outside the candidate set, yielding finite values; but under the truncated sampling distribution those tokens should have probability 0 (logprob = -inf). Consider using sampler_output.sampling_mask (the candidate token ids) to set positions outside the candidate set to -inf, or computing/truncating the top logprobs against the candidate set directly at sampling time.

Suggested change

# Only normalize valid candidates; padding positions use -inf
valid_mask = paddle.isfinite(logprobs)
normalized_logprobs = paddle.where(
    valid_mask,
    logprobs - logz.unsqueeze(1),  # broadcast subtraction
    paddle.full_like(logprobs, float('-inf'))

# Start from finite positions (exclude paddings / invalid values).
valid_mask = paddle.isfinite(logprobs)
# If sampling_mask is available, further restrict to candidate tokens.
if getattr(sampler_output, "sampling_mask", None) is not None:
    # logprob_token_ids: [B, max_num_logprobs + 1]
    logprob_token_ids = sampler_output.logprobs_tensors.logprob_token_ids
    # Initialize candidate mask as all False.
    candidate_mask = paddle.zeros_like(logprobs, dtype="bool")
    # sampling_mask is List[np.ndarray], one per batch/request.
    for b_idx, mask_ids in enumerate(sampler_output.sampling_mask):
        if mask_ids is None or len(mask_ids) == 0:
            continue
        # Convert candidate ids for this batch to tensor.
        candidate_ids = paddle.to_tensor(mask_ids, dtype=logprob_token_ids.dtype)
        # token_row: [1, N], candidate_ids.unsqueeze(1): [K, 1]
        # After broadcasting, eq: [K, N]; any over K gives [N].
        token_row = logprob_token_ids[b_idx : b_idx + 1, :]
        eq = token_row == candidate_ids.unsqueeze(1)
        row_mask = paddle.any(eq, axis=0)
        candidate_mask[b_idx] = row_mask
    # Only positions that are both finite and in the candidate set remain valid.
    valid_mask = paddle.logical_and(valid_mask, candidate_mask)
normalized_logprobs = paddle.where(
    valid_mask,
    logprobs - logz.unsqueeze(1),  # broadcast subtraction
    paddle.full_like(logprobs, float("-inf")),
# Renormalize logprobs to match truncated sampling distribution (when enabled).
if (
    sampler_output.logprobs_tensors is not None
    and sampler_output.logz_per_batch is not None

Copilot AI Mar 22, 2026


This renormalization currently runs whenever logprobs_tensors is non-None, but in gpu_model_runner there is a case where enable_logprob=True yet no request in the batch needs logprobs; an empty LogprobsTensors is then created on CPU (paddle.empty_like, contents uninitialized). If keep_sampling_mask is also enabled, the subtraction here operates on uninitialized data, producing random values/NaN in the output. Consider additionally checking sampling_metadata.max_num_logprobs is not None (or another explicit "this batch actually returns logprobs" flag) before renormalizing.

Suggested change

    and sampler_output.logz_per_batch is not None

    and sampler_output.logz_per_batch is not None
    # Only renormalize when this batch actually requests logprobs.
    and getattr(getattr(sampler_output, "sampling_metadata", None), "max_num_logprobs", None) is not None

Comment on lines +515 to +522
logprobs = sampler_output.logprobs_tensors.logprobs
logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype)
valid_mask = paddle.isfinite(logprobs)
normalized_logprobs = paddle.where(
    valid_mask,
    logprobs - logz.unsqueeze(1),
    paddle.full_like(logprobs, float('-inf'))
)

Copilot AI Mar 22, 2026


On the speculative path, sampler_output.logprobs_tensors.logprobs is most likely still on GPU (this branch's gather_logprobs() has no .cpu()), while the tensor created here with paddle.to_tensor(sampler_output.logz_per_batch) defaults to CPU or the default device, which can easily cause a place mismatch or extra D2H/H2D copies. Consider explicitly creating/moving logz on logprobs.place.

Comment on lines +517 to +523
valid_mask = paddle.isfinite(logprobs)
normalized_logprobs = paddle.where(
    valid_mask,
    logprobs - logz.unsqueeze(1),
    paddle.full_like(logprobs, float('-inf'))
)
sampler_output.logprobs_tensors = LogprobsTensors(

Copilot AI Mar 22, 2026


Same as the normal branch: isfinite(logprobs) cannot distinguish tokens inside vs. outside the candidate set; when the returned top logprobs include out-of-set tokens, they get normalized to finite values instead of -inf. If the goal is to match the truncated sampling distribution, consider using the corresponding tokens' sampling_mask to set out-of-set logprobs to -inf.

sampling_mask: Optional[List[np.ndarray]] = None
# logZ_K for each request: log(sum(probs in candidate set K))
# Used for renormalizing logprobs to match the truncated sampling distribution.
# Shape: [num_reqs]

Copilot AI Mar 22, 2026


The comment on logz_per_batch says one entry per request (Shape: [num_reqs]), but on the speculative decoding path _compute_sampling_mask() computes/returns one entry per accepted token (length total_accepted), and post_process likewise broadcasts against the first dimension of logprobs. Consider amending the comment to "non-spec: [num_reqs]; spec: [total_accepted_tokens]" to avoid misuse.

Suggested change
# Shape: [num_reqs]
# Shape:
# - Non-speculative decoding: [num_reqs]
# - Speculative decoding: [total_accepted_tokens], aligned with the first
# dimension of logprobs / sampling_mask and regrouped by request in
# post-processing.

@codecov-commenter

Codecov Report

❌ Patch coverage is 40.90909% with 13 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (release/2.4@7ec091f). Learn more about missing BASE report.

Files with missing lines Patch % Lines
fastdeploy/model_executor/pre_and_post_process.py 0.00% 10 Missing and 2 partials ⚠️
fastdeploy/model_executor/layers/sample/sampler.py 88.88% 1 Missing ⚠️
Additional details and impacted files
@@              Coverage Diff               @@
##             release/2.4    #6966   +/-   ##
==============================================
  Coverage               ?   56.34%           
==============================================
  Files                  ?      333           
  Lines                  ?    42524           
  Branches               ?     6475           
==============================================
  Hits                   ?    23959           
  Misses                 ?    16685           
  Partials               ?     1880           
Flag Coverage Δ
GPU 56.34% <40.90%> (?)

Flags with carried forward coverage won't be shown.



3 participants