normalized logprobs when using keep sampling mask #6966
DesmonDay wants to merge 1 commit into PaddlePaddle:release/2.4 from
Conversation
Thanks for your contribution!
Pull request overview
This PR renormalizes the returned token-level logprobs when keep_sampling_mask (top-k/top-p truncation of the candidate set) is enabled, so that they match the truncated sampling distribution, and threads each sample's logZ_K through SamplerOutput.
Changes:
- In `_compute_sampling_mask()`, additionally compute and return per-sample `logZ_K` (the log of the summed candidate-set probabilities), and add a new `logz_per_batch` field to `SamplerOutput` to carry it.
- In `post_process_normal()` / `post_process_specualate()`, renormalize `logprobs_tensors.logprobs` based on `logZ_K`.
- Change the return type of `_compute_sampling_mask()` and update its callers so it returns both the sparse mask and `logz_per_batch`.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| fastdeploy/worker/output.py | Add a `logz_per_batch` field to `SamplerOutput`, used to renormalize logprobs during post-processing. |
| fastdeploy/model_executor/pre_and_post_process.py | Renormalize logprobs by subtracting `logZ_K` in the normal/speculative post-process. |
| fastdeploy/model_executor/layers/sample/sampler.py | Compute and return `logZ_K` while building the sampling mask, and update the call chain. |
```python
# Z_K = sum(probs[i] * final_mask[i]) for each request i
# logZ_K = log(Z_K), with small constant to avoid log(0)
# ------------------------------------------------------------------
candidate_probs = paddle.where(
    final_mask,
    probs,
    paddle.zeros_like(probs)
```
final_mask is constructed in the sorted space of sorted_probs/renorm_sorted_probs (aligned with sorted_indices), but using it here to mask the original, unsorted probs misaligns the selected token probabilities, producing a wrong Z_K/logZ_K and a systematic bias in the subsequent logprobs renormalization. Consider computing Z_K in the sorted space (e.g., summing sorted_probs under final_mask), or scattering final_mask back to the original vocab order before aligning it with probs.
```diff
-# Z_K = sum(probs[i] * final_mask[i]) for each request i
+# Z_K = sum(sorted_probs[i] * final_mask[i]) for each request i
 # logZ_K = log(Z_K), with small constant to avoid log(0)
 # ------------------------------------------------------------------
 candidate_probs = paddle.where(
     final_mask,
-    probs,
-    paddle.zeros_like(probs)
+    sorted_probs,
+    paddle.zeros_like(sorted_probs)
```
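To see why the spaces must match, here is a minimal NumPy stand-in with hypothetical values: `final_mask` is built in sorted space, so applying it to the unsorted `probs` sums the wrong entries.

```python
import numpy as np

probs = np.array([0.1, 0.5, 0.2, 0.2])            # original vocab order
sorted_idx = np.argsort(-probs)                    # descending sort order
sorted_probs = probs[sorted_idx]                   # [0.5, 0.2, 0.2, 0.1]
final_mask = np.array([True, True, False, False])  # top-2 kept, in SORTED space

# Wrong: masking unsorted probs with a sorted-space mask
# picks probs[0] + probs[1] = 0.1 + 0.5 = 0.6.
z_wrong = probs[final_mask].sum()

# Right: mask in the sorted space the mask was built in;
# the true top-2 mass is 0.5 + 0.2 = 0.7.
z_right = sorted_probs[final_mask].sum()
```

The two sums differ whenever the kept tokens are not already the first ones in vocab order, which is exactly the situation top-k/top-p sorting creates.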
There is indeed a subtlety here:
- The sampled token: it is inside the candidate set K, so logZ_K is correct for it and its normalized value is correct ✓
- Top-k candidates: they are inside K, so logZ_K is correct for them and their normalized values are correct ✓
- Tokens outside the top-k set: even if logZ_K is subtracted from them, it does not matter, because those values are never returned.
Since keep sampling mask currently only needs to return a single logprob, the logprobs inside the candidate set are guaranteed to be normalized correctly, and the remaining values can be ignored.
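As a sanity check on the renormalization itself, a small NumPy sketch with hypothetical numbers: subtracting log Z_K from the full-distribution logprobs turns the candidate-set entries back into a proper probability distribution.

```python
import numpy as np

probs = np.array([0.5, 0.2, 0.2, 0.1])        # full softmax distribution
logprobs = np.log(probs)
mask = np.array([True, True, False, False])   # candidate set K (top-2)

# Z_K is the probability mass kept by the truncation: 0.5 + 0.2 = 0.7.
log_z_k = np.log(probs[mask].sum())

# log pi_masked = log pi_full - log Z_K
renormed = logprobs - log_z_k

# Inside K the renormalized probabilities sum to 1 again.
in_k_mass = np.exp(renormed[mask]).sum()
```

Entries outside K also become finite after the subtraction, which is harmless as long as they are never returned, and is exactly the concern raised for the multi-logprob case below.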
```python
logprobs = sampler_output.logprobs_tensors.logprobs
# logz_per_batch: [B], log(sum(probs in candidate set K)) for each request
logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype)
# Renormalize: log π_masked = log π_full - log Z_K
# Only normalize valid candidates; padding positions use -inf
valid_mask = paddle.isfinite(logprobs)
normalized_logprobs = paddle.where(
    valid_mask,
    logprobs - logz.unsqueeze(1),  # broadcast subtraction
    paddle.full_like(logprobs, float('-inf'))
```
logz = paddle.to_tensor(...) here does not specify a place. In the current implementation, the non-speculative path's gather_logprobs() moves logprobs to CPU (.cpu()), while the default device is usually GPU; this makes logprobs - logz hit a place mismatch (or incur the overhead of an implicit copy). Consider using place=logprobs.place (or explicitly logz = logz.astype(logprobs.dtype).to(logprobs.place), etc.) to keep it on the same device as logprobs.
```python
# Only normalize valid candidates; padding positions use -inf
valid_mask = paddle.isfinite(logprobs)
normalized_logprobs = paddle.where(
    valid_mask,
    logprobs - logz.unsqueeze(1),  # broadcast subtraction
    paddle.full_like(logprobs, float('-inf'))
```
Using paddle.isfinite(logprobs) as the "valid candidate" test is inaccurate: gather_logprobs() takes the top-k over the full vocab, so when a request's num_logprobs exceeds the size of the candidate set left after top_p/top_k truncation, the logprobs of tokens outside the candidate set also get logZ_K subtracted and end up finite; under the truncated sampling distribution those tokens should have probability 0 (logprob = -inf). Consider combining sampler_output.sampling_mask (the candidate token ids) to set positions outside the candidate set to -inf, or computing/truncating the top-logprobs against the candidate set directly at sampling time.
```diff
-# Only normalize valid candidates; padding positions use -inf
-valid_mask = paddle.isfinite(logprobs)
-normalized_logprobs = paddle.where(
-    valid_mask,
-    logprobs - logz.unsqueeze(1),  # broadcast subtraction
-    paddle.full_like(logprobs, float('-inf'))
+# Start from finite positions (exclude paddings / invalid values).
+valid_mask = paddle.isfinite(logprobs)
+# If sampling_mask is available, further restrict to candidate tokens.
+if getattr(sampler_output, "sampling_mask", None) is not None:
+    # logprob_token_ids: [B, max_num_logprobs + 1]
+    logprob_token_ids = sampler_output.logprobs_tensors.logprob_token_ids
+    # Initialize candidate mask as all False.
+    candidate_mask = paddle.zeros_like(logprobs, dtype="bool")
+    # sampling_mask is List[np.ndarray], one per batch/request.
+    for b_idx, mask_ids in enumerate(sampler_output.sampling_mask):
+        if mask_ids is None or len(mask_ids) == 0:
+            continue
+        # Convert candidate ids for this batch to tensor.
+        candidate_ids = paddle.to_tensor(mask_ids, dtype=logprob_token_ids.dtype)
+        # token_row: [1, N], candidate_ids.unsqueeze(1): [K, 1]
+        # After broadcasting, eq: [K, N]; any over K gives [N].
+        token_row = logprob_token_ids[b_idx : b_idx + 1, :]
+        eq = token_row == candidate_ids.unsqueeze(1)
+        row_mask = paddle.any(eq, axis=0)
+        candidate_mask[b_idx] = row_mask
+    # Only positions that are both finite and in the candidate set remain valid.
+    valid_mask = paddle.logical_and(valid_mask, candidate_mask)
+normalized_logprobs = paddle.where(
+    valid_mask,
+    logprobs - logz.unsqueeze(1),  # broadcast subtraction
+    paddle.full_like(logprobs, float("-inf")),
```
```python
# Renormalize logprobs to match truncated sampling distribution (when enabled).
if (
    sampler_output.logprobs_tensors is not None
    and sampler_output.logz_per_batch is not None
```
This renormalization currently runs as soon as logprobs_tensors is non-None, but in gpu_model_runner there is a case where enable_logprob=True and no request in the batch needs logprobs, in which an empty LogprobsTensors is created on CPU (paddle.empty_like, contents uninitialized). If keep_sampling_mask is also enabled, the subtraction here would operate on uninitialized data and produce random values/NaN in the output. Consider additionally checking sampling_metadata.max_num_logprobs is not None (or another explicit "this batch really returns logprobs" flag) before renormalizing.
```diff
-    and sampler_output.logz_per_batch is not None
+    and sampler_output.logz_per_batch is not None
+    # Only renormalize when this batch actually requests logprobs.
+    and getattr(getattr(sampler_output, "sampling_metadata", None), "max_num_logprobs", None)
+    is not None
```
```python
logprobs = sampler_output.logprobs_tensors.logprobs
logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype)
valid_mask = paddle.isfinite(logprobs)
normalized_logprobs = paddle.where(
    valid_mask,
    logprobs - logz.unsqueeze(1),
    paddle.full_like(logprobs, float('-inf'))
)
```
On the speculative path, sampler_output.logprobs_tensors.logprobs is most likely still on GPU (that branch's gather_logprobs() has no .cpu()), while the tensor created here by paddle.to_tensor(sampler_output.logz_per_batch) with the default place may land on CPU or the default device, easily causing a place mismatch or extra D2H/H2D copies. Consider explicitly creating/moving logz on logprobs.place.
```python
valid_mask = paddle.isfinite(logprobs)
normalized_logprobs = paddle.where(
    valid_mask,
    logprobs - logz.unsqueeze(1),
    paddle.full_like(logprobs, float('-inf'))
)
sampler_output.logprobs_tensors = LogprobsTensors(
```
Same as the normal branch: isfinite(logprobs) cannot distinguish tokens inside the candidate set from those outside it; when the returned top-logprobs include out-of-candidate tokens, they get normalized to finite values instead of -inf. If the goal is to match the truncated sampling distribution, consider using the corresponding tokens' sampling_mask to set out-of-candidate logprobs to -inf.
```python
sampling_mask: Optional[List[np.ndarray]] = None
# logZ_K for each request: log(sum(probs in candidate set K))
# Used for renormalizing logprobs to match the truncated sampling distribution.
# Shape: [num_reqs]
```
The comment on logz_per_batch says one entry per request (Shape: [num_reqs]), but on the speculative-decoding path _compute_sampling_mask() computes/returns one entry per accepted token (length total_accepted), and post-processing also broadcasts it against the first dimension of logprobs. Consider extending the comment to "non-spec: [num_reqs]; spec: [total_accepted_tokens]" to avoid misuse.
```diff
-# Shape: [num_reqs]
+# Shape:
+#   - Non-speculative decoding: [num_reqs]
+#   - Speculative decoding: [total_accepted_tokens], aligned with the first
+#     dimension of logprobs / sampling_mask and regrouped by request in
+#     post-processing.
```
Codecov Report
❌ Patch coverage is
Additional details and impacted files
```
@@            Coverage Diff             @@
##           release/2.4    #6966   +/-   ##
==============================================
  Coverage             ?   56.34%
==============================================
  Files                ?      333
  Lines                ?    42524
  Branches             ?     6475
==============================================
  Hits                 ?    23959
  Misses               ?    16685
  Partials             ?     1880
```
Flags with carried forward coverage won't be shown.
Motivation
Modifications
Usage or Command
Accuracy Tests
Checklist
- Add at least one module tag: [FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]
- Run pre-commit before commit.
- For a release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.