Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions fastdeploy/model_executor/layers/sample/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _compute_sampling_mask(
top_p: paddle.Tensor,
top_k: Optional[paddle.Tensor] = None,
top_k_list: Optional[list] = None,
) -> List[np.ndarray]:
) -> tuple[List[np.ndarray], np.ndarray]:
"""
Compute a combined top-k + top-p (nucleus) sampling mask as sparse
retained-token indices.
Expand All @@ -124,8 +124,11 @@ def _compute_sampling_mask(
top-k filtering is needed at all.

Returns:
List of length num_reqs; element i is a 1-D int64 numpy array of the
retained vocab indices for request i.
Tuple of (sparse_indices, logz_per_batch):
- sparse_indices: List of length num_reqs; element i is a 1-D int64
numpy array of the retained vocab indices for request i.
- logz_per_batch: 1-D numpy array of shape [num_reqs] containing
log(Z_K) where Z_K is the sum of probabilities in the candidate set.
"""
real_bsz = probs.shape[0]
vocab_size = probs.shape[1]
Expand Down Expand Up @@ -179,11 +182,21 @@ def _compute_sampling_mask(
k_per_row = final_mask.astype("int32").sum(axis=-1) # [B]
max_k = int(k_per_row.max().item())

# ------------------------------------------------------------------
# Stage 5: compute logZ_K for renormalization
# Z_K = sum(probs[i] * final_mask[i]) for each request i
# logZ_K = log(Z_K), with small constant to avoid log(0)
# ------------------------------------------------------------------
candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs))
z_k = candidate_probs.sum(axis=-1) # [B]
logz_per_batch = paddle.log(z_k + 1e-10).cpu().numpy() # [B]

# Transfer only the leading max_k columns — typically max_k << vocab_size.
indices_window_cpu = sorted_indices[:, :max_k].cpu().numpy() # [B, max_k]
mask_window_cpu = final_mask[:, :max_k].cpu().numpy() # [B, max_k]

return [indices_window_cpu[i, mask_window_cpu[i]] for i in range(real_bsz)]
sparse_indices = [indices_window_cpu[i, mask_window_cpu[i]] for i in range(real_bsz)]
return sparse_indices, logz_per_batch


class GuidedDecoding:
Expand Down Expand Up @@ -622,8 +635,9 @@ def forward_cuda(
# Compute sampling mask BEFORE top_k_top_p_sampling modifies probs.
# Sparse form: list of length num_reqs; element i is a numpy array of the vocab
# indices retained by top_k/top_p for request i (not a dense [num_reqs, vocab_size] mask).
sampling_mask = None
logz_per_batch = None
if sampling_metadata.keep_sampling_mask:
sampling_mask = _compute_sampling_mask(
sampling_mask, logz_per_batch = _compute_sampling_mask(
probs,
sampling_metadata.top_p,
top_k=sampling_metadata.top_k,
Expand Down Expand Up @@ -654,6 +668,7 @@ def forward_cuda(
logprobs_tensors=logprobs_tensors,
logits=logits,
sampling_mask=sampling_mask,
logz_per_batch=logz_per_batch,
)

return sampler_output
Expand Down Expand Up @@ -970,6 +985,7 @@ def forward_cuda(
# Compute sampling mask at accepted token positions.
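# Sparse form: list of length total_accepted_tokens; element i is a numpy array of
# the retained vocab indices for accepted token i (not a dense bool matrix).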
sampling_mask = None
logz_per_batch = None
if keep_sampling_mask:
# Expand top_p from [batch, 1] to [total_accepted, 1].
accept_top_p = sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(accept_nums).unsqueeze(1)
Expand All @@ -982,7 +998,7 @@ def forward_cuda(
accept_top_k = (
sampling_metadata.top_k[:real_bsz].squeeze(1).repeat_interleave(accept_nums).unsqueeze(1)
)
sampling_mask = _compute_sampling_mask(
sampling_mask, logz_per_batch = _compute_sampling_mask(
target_probs,
accept_top_p,
top_k=accept_top_k,
Expand All @@ -996,6 +1012,7 @@ def forward_cuda(
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
logits=logits,
sampling_mask=sampling_mask,
logz_per_batch=logz_per_batch,
)

return sampler_output
Expand Down
34 changes: 34 additions & 0 deletions fastdeploy/model_executor/pre_and_post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,27 @@ def post_process_normal(
)
# 3. Transmit the model's output and stop generation signal via message queue.
# In the future, we will abandon this approach.
# Renormalize logprobs to match truncated sampling distribution (when enabled).
if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None:
# logprobs_tensors.logprobs: [B, max_num_logprobs + 1]
logprobs = sampler_output.logprobs_tensors.logprobs
# logz_per_batch: [B], log(sum(probs in candidate set K)) for each request
logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype)
# Renormalize: log π_masked = log π_full - log Z_K
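# Every finite entry is shifted by -logZ_K; entries that are already -inf (e.g. padding) stay -inf.
# NOTE(review): isfinite() does not tell whether a token is inside the candidate set K,
# so returned tokens outside K are shifted too rather than set to -inf.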
valid_mask = paddle.isfinite(logprobs)
normalized_logprobs = paddle.where(
valid_mask,
logprobs - logz.unsqueeze(1), # broadcast subtraction
paddle.full_like(logprobs, float("-inf")),
Comment on lines +373 to +385
Copy link

Copilot AI Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的重归一化逻辑可能与现有的 top_p_normalized_logprobs 选项叠加:当 compute_logprobs() 已经对 top_p 做过归一化(prob sum=1)时,再减去 logZ_K 会产生二次归一化,导致返回 logprobs 不再对应任何有效分布。建议为每个 request 传递/判断是否已做 top_p 归一化,并对这类 request 跳过该步骤,或统一在同一位置完成“top_p + top_k”截断分布的 logprobs 计算。

Copilot uses AI. Check for mistakes.
)
# Update logprobs_tensors with normalized values
sampler_output.logprobs_tensors = LogprobsTensors(
logprob_token_ids=sampler_output.logprobs_tensors.logprob_token_ids,
logprobs=normalized_logprobs,
selected_token_ranks=sampler_output.logprobs_tensors.selected_token_ranks,
)
Comment on lines +387 to +392
Copy link

Copilot AI Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议补充用例覆盖 enable_keep_sampling_mask + logprobs/top_logprobs 场景下的 logprobs 归一化正确性(例如 top_logprobs=-1 返回全量时,验证 sampling_mask 内 exp(logprob) 求和≈1 且 sampling_mask 外为 0/-inf;同时覆盖 speculative 路径)。目前仓库已有 keep_sampling_mask 的 e2e 测试,但没有覆盖本次新增的 logprobs 重归一化行为,容易回归。

Suggested change
# Update logprobs_tensors with normalized values
sampler_output.logprobs_tensors = LogprobsTensors(
logprob_token_ids=sampler_output.logprobs_tensors.logprob_token_ids,
logprobs=normalized_logprobs,
selected_token_ranks=sampler_output.logprobs_tensors.selected_token_ranks,
)
# Update existing logprobs_tensors with normalized values in place
sampler_output.logprobs_tensors.logprobs = normalized_logprobs

Copilot uses AI. Check for mistakes.

if not skip_save_output:
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
if save_each_rank or model_output.mp_rank == 0:
Expand Down Expand Up @@ -483,6 +504,19 @@ def post_process_specualate(
model_output.stop_nums,
model_output.mask_rollback,
)
# Renormalize logprobs to match truncated sampling distribution (when enabled).
if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None:
logprobs = sampler_output.logprobs_tensors.logprobs
logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype)
valid_mask = paddle.isfinite(logprobs)
normalized_logprobs = paddle.where(
valid_mask, logprobs - logz.unsqueeze(1), paddle.full_like(logprobs, float("-inf"))
)
Comment on lines +509 to +514
Copy link

Copilot AI Mar 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

speculative 路径下 sampler_output.logprobs_tensors.logprobs 很可能仍在 GPU(该分支的 gather_logprobs() 没有 .cpu()),这里同样用 paddle.to_tensor(sampler_output.logz_per_batch) 默认创建的张量可能在 CPU 或默认 device,容易造成 place mismatch 或额外 D2H/H2D 拷贝。建议显式把 logz 创建/搬到 logprobs.place

Copilot uses AI. Check for mistakes.
sampler_output.logprobs_tensors = LogprobsTensors(
Comment on lines +511 to +515
Copy link

Copilot AI Mar 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

与 normal 分支相同,这里用 isfinite(logprobs) 不能区分“候选集内/外”的 token;当返回的 top-logprobs 包含候选集外 token 时,会被归一化成有限值而不是 -inf。如果目标是“匹配截断采样分布”,建议结合对应 token 的 sampling_mask 把候选集外的 logprob 置为 -inf

Copilot uses AI. Check for mistakes.
logprob_token_ids=sampler_output.logprobs_tensors.logprob_token_ids,
logprobs=normalized_logprobs,
selected_token_ranks=sampler_output.logprobs_tensors.selected_token_ranks,
)
Comment on lines +507 to +519
Copy link

Copilot AI Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Speculative 路径同样存在:gather_logprobs() 返回的 topk 可能包含候选集 K 之外的 token,当前仅做 - logZ_K 会让这些 token 获得“非零”的截断分布概率。建议利用每个 accepted token 对应的 sampling_mask(flatten 列表)把不在 K 的 token logprob 置为 -inf,或在 sampler 侧直接对截断分布计算 logprobs。

Copilot uses AI. Check for mistakes.

if not skip_save_output:
if sampler_output.logprobs_tensors is None:
Expand Down
4 changes: 4 additions & 0 deletions fastdeploy/worker/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ class SamplerOutput:
# check whether the current path is speculative or non-speculative when
# interpreting the dimension.
sampling_mask: Optional[List[np.ndarray]] = None
# logZ_K for each request: log(sum(probs in candidate set K))
# Used for renormalizing logprobs to match the truncated sampling distribution.
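# Shape: one entry per row of the probs passed to _compute_sampling_mask:
#   - non-speculative decoding: [num_reqs]
#   - speculative decoding: [total_accepted_tokens]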
Copy link

Copilot AI Mar 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logz_per_batch 的注释写的是每个 request 一条(Shape: [num_reqs]),但在 speculative decoding 路径里 _compute_sampling_mask() 是按每个 accepted token 计算/返回的(长度为 total_accepted),并且在 post_process 中也是按 logprobs 的第一维对齐做 broadcast。建议把这里的注释补充为“non-spec: [num_reqs];spec: [total_accepted_tokens]”以避免误用。

Suggested change
# Shape: [num_reqs]
# Shape:
# - Non-speculative decoding: [num_reqs]
# - Speculative decoding: [total_accepted_tokens], aligned with the first
# dimension of logprobs / sampling_mask and regrouped by request in
# post-processing.

Copilot uses AI. Check for mistakes.
logz_per_batch: Optional[np.ndarray] = None


@dataclass
Expand Down
Loading