From ea5ae2734b2799c856437a5f7986a28983a0e8de Mon Sep 17 00:00:00 2001 From: cmcamdy Date: Fri, 20 Mar 2026 03:10:37 +0000 Subject: [PATCH 01/12] [XPU] add verify draft tokens --- .../xpu_ops/src/ops/mtp/verify_draft_token.cc | 205 +++++ custom_ops/xpu_ops/src/ops/pybind/pybind.cc | 56 ++ .../xpu_ops/src/plugin/include/xpu/plugin.h | 39 + .../mtp_kernel/verify_draft_tokens.xpu | 349 ++++++++ .../wrapper/mtp_wrapper/speculate_verify.cpp | 3 - .../mtp_wrapper/verify_draft_tokens.cpp | 613 ++++++++++++++ .../xpu_ops/test/test_verify_draft_tokens.py | 771 ++++++++++++++++++ 7 files changed, 2033 insertions(+), 3 deletions(-) create mode 100644 custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc create mode 100644 custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu create mode 100644 custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/verify_draft_tokens.cpp create mode 100644 custom_ops/xpu_ops/test/test_verify_draft_tokens.py diff --git a/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc b/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc new file mode 100644 index 00000000000..fa685e30a68 --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc @@ -0,0 +1,205 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Verification kernel — outputs step_output_ids + step_output_len, +// and performs EOS / max_dec_len detection (read-only on step_idx). +// step_idx is NOT modified here; all state updates (including step_idx) +// are handled by unified_update_model_status. +// +// Verification strategies: +// 0 = TOPP : draft token in top-p candidate set (+ verify_window +// fallback) 1 = GREEDY : draft token == top-1 token (strict argmax +// match) 2 = TARGET_MATCH : draft token == target model's sampled token + +#include +#include "paddle/extension.h" +#include "xpu/plugin.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +namespace api = baidu::xpu::api; + +// ============================================================ +// Host function +// ============================================================ +void VerifyDraftTokens( + // Core I/O + const paddle::Tensor &step_output_ids, + const paddle::Tensor &step_output_len, + const paddle::Tensor &step_input_ids, + // Target model outputs (optional, required for TARGET_MATCH) + const paddle::optional &target_tokens, + // Candidate set (optional, required for TOPP/GREEDY) + const paddle::optional &candidate_ids, + const paddle::optional &candidate_scores, + const paddle::optional &candidate_lens, + // Sampling params + const paddle::Tensor &topp, + // Metadata + const paddle::Tensor &stop_flags, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &end_tokens, + const paddle::Tensor &is_block_step, + const paddle::Tensor &cu_seqlens_q_output, + const paddle::Tensor &reasoning_status, + // max_dec_len / step_idx for EOS/max-len detection + const paddle::Tensor &max_dec_len, + const paddle::Tensor &step_idx, + int max_seq_len, + int verify_window, + int verify_strategy, + bool reject_all, + bool accept_all) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = 
paddle::experimental::DeviceContextPool::Instance().Get(place); + api::Context *ctx = + static_cast(dev_ctx)->x_context(); + bool xpu_ctx_flag = true; + if (step_output_ids.is_cpu()) { + ctx = new api::Context(api::kCPU); + xpu_ctx_flag = false; + } + + auto bsz = step_output_ids.shape()[0]; + auto real_bsz = seq_lens_this_time.shape()[0]; + auto max_step_tokens = step_input_ids.shape()[1]; + auto end_length = end_tokens.shape()[0]; + // max_candidate_len: 1 if candidate_ids not provided, else from shape + int max_candidate_len = candidate_ids ? candidate_ids->shape()[1] : 1; + + // curand state: only needed for TOPP(0) strategy (stochastic sampling) + int random_seed = 0; + std::vector infer_seed(bsz, random_seed); + std::uniform_real_distribution dist(0.0, 1.0); + std::vector dev_curand_states_cpu; + for (int i = 0; i < bsz; i++) { + std::mt19937_64 engine(infer_seed[i]); + dev_curand_states_cpu.push_back(dist(engine)); + } + float *dev_curand_states_xpu; + if (xpu_ctx_flag) { + xpu::ctx_guard RAII_GUARD(ctx); + dev_curand_states_xpu = + RAII_GUARD.alloc(dev_curand_states_cpu.size()); + xpu_memcpy(dev_curand_states_xpu, + dev_curand_states_cpu.data(), + dev_curand_states_cpu.size() * sizeof(float), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + } + + // Get data pointers (nullptr if optional not provided) + const int64_t *target_tokens_ptr = + target_tokens ? target_tokens->data() : nullptr; + const int64_t *candidate_ids_ptr = + candidate_ids ? candidate_ids->data() : nullptr; + const float *candidate_scores_ptr = + candidate_scores ? candidate_scores->data() : nullptr; + const int *candidate_lens_ptr = + candidate_lens ? candidate_lens->data() : nullptr; + + // Validate parameters based on verify_strategy. + // Note: empty_input_forward may lead to empty optional tensors — only + // validate when bsz > 0 (i.e. there are active sequences). 
+ if (bsz > 0) { + if (verify_strategy == 0 /* TOPP */) { + if (!candidate_ids_ptr || !candidate_scores_ptr || !candidate_lens_ptr) { + PD_THROW( + "verify_strategy=TOPP (0) requires candidate_ids, " + "candidate_scores, candidate_lens"); + } + } else if (verify_strategy == 1 /* GREEDY */) { + if (!target_tokens_ptr) { + PD_THROW("verify_strategy=GREEDY (1) requires target_tokens (argmax)"); + } + } else if (verify_strategy == 2 /* TARGET_MATCH */) { + if (!target_tokens_ptr) { + PD_THROW( + "verify_strategy=TARGET_MATCH (2) requires target_tokens " + "(sampled)"); + } + } + } + int ret = fastdeploy::plugin::verify_draft_tokens( + ctx, + // Core I/O + const_cast(step_output_ids.data()), + const_cast(step_output_len.data()), + step_input_ids.data(), + // Target model outputs + target_tokens_ptr, + // Candidate set + candidate_ids_ptr, + candidate_scores_ptr, + candidate_lens_ptr, + // Sampling params + dev_curand_states_xpu, + topp.data(), + // Metadata + stop_flags.data(), + seq_lens_encoder.data(), + seq_lens_this_time.data(), + end_tokens.data(), + is_block_step.data(), + cu_seqlens_q_output.data(), + reasoning_status.data(), + // max_dec_len / step_idx + max_dec_len.data(), + step_idx.data(), + // Dimensions and config + bsz, // max_bsz + real_bsz, // real_bsz + max_step_tokens, + end_length, + max_seq_len, + max_candidate_len, + verify_window, + verify_strategy, + reject_all, + accept_all); + PD_CHECK(ret == 0, "verify_draft_tokens failed."); + if (step_output_ids.is_cpu()) { + delete ctx; + } +} + +PD_BUILD_STATIC_OP(verify_draft_tokens) + .Inputs({"step_output_ids", + "step_output_len", + "step_input_ids", + paddle::Optional("target_tokens"), + paddle::Optional("candidate_ids"), + paddle::Optional("candidate_scores"), + paddle::Optional("candidate_lens"), + "topp", + "stop_flags", + "seq_lens_encoder", + "seq_lens_this_time", + "end_tokens", + "is_block_step", + "cu_seqlens_q_output", + "reasoning_status", + "max_dec_len", + "step_idx"}) + 
.Outputs({"step_output_ids_out", "step_output_len_out"}) + .Attrs({"max_seq_len: int", + "verify_window: int", + "verify_strategy: int", + "reject_all: bool", + "accept_all: bool"}) + .SetInplaceMap({{"step_output_ids", "step_output_ids_out"}, + {"step_output_len", "step_output_len_out"}}) + .SetKernelFn(PD_KERNEL(VerifyDraftTokens)); diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc index 14468ddda48..058bf3c4d88 100644 --- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc +++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc @@ -678,6 +678,36 @@ std::vector WeightQuantize(const paddle::Tensor& x, const int32_t arch, const int32_t group_size); +void VerifyDraftTokens( + // Core I/O + const paddle::Tensor& step_output_ids, + const paddle::Tensor& step_output_len, + const paddle::Tensor& step_input_ids, + // Target model outputs (optional, required for TARGET_MATCH) + const paddle::optional& target_tokens, + // Candidate set (optional, required for TOPP/GREEDY) + const paddle::optional& candidate_ids, + const paddle::optional& candidate_scores, + const paddle::optional& candidate_lens, + // Sampling params + const paddle::Tensor& topp, + // Metadata + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& end_tokens, + const paddle::Tensor& is_block_step, + const paddle::Tensor& cu_seqlens_q_output, + const paddle::Tensor& reasoning_status, + // max_dec_len / step_idx for EOS/max-len detection + const paddle::Tensor& max_dec_len, + const paddle::Tensor& step_idx, + int max_seq_len, + int verify_window, + int verify_strategy, + bool reject_all, + bool accept_all); + PYBIND11_MODULE(fastdeploy_ops, m) { m.def("adjust_batch", &AdjustBatch, @@ -1193,6 +1223,32 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("accept_all_drafts"), "Perform speculative verification for decoding"); + m.def("verify_draft_tokens", + &VerifyDraftTokens, + 
py::arg("step_output_ids"), + py::arg("step_output_len"), + py::arg("step_input_ids"), + py::arg("target_tokens"), + py::arg("candidate_ids"), + py::arg("candidate_scores"), + py::arg("candidate_lens"), + py::arg("topp"), + py::arg("stop_flags"), + py::arg("seq_lens_encoder"), + py::arg("seq_lens_this_time"), + py::arg("end_tokens"), + py::arg("is_block_step"), + py::arg("cu_seqlens_q_output"), + py::arg("reasoning_status"), + py::arg("max_dec_len"), + py::arg("step_idx"), + py::arg("max_seq_len"), + py::arg("verify_window"), + py::arg("verify_strategy"), + py::arg("reject_all"), + py::arg("accept_all"), + "Perform speculative verification for decoding v2"); + m.def("speculate_save_output", &SpeculateSaveWithOutputMsgStatic, py::arg("accept_tokens"), diff --git a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h index 1cbd7a8029b..cb32d555af1 100644 --- a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h +++ b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h @@ -724,6 +724,45 @@ DLL_EXPORT int speculate_limit_thinking_content_length_kernel( const int eos_token_id_len, const int inject_len, const bool splitwise_role_is_decode); + +DLL_EXPORT int verify_draft_tokens( + api::Context* ctx, + // Core I/O + int64_t* step_output_ids, + int* step_output_len, + const int64_t* step_input_ids, // draft tokens + // Target model outputs (strategy-dependent interpretation) + const int64_t* + target_tokens, // GREEDY:argmax, TARGET_MATCH:sampled, TOPP:unused + // Candidate set for TOPP/GREEDY (TARGET_MATCH: unused) + const int64_t* candidate_ids, + const float* candidate_scores, + const int* candidate_lens, + // Sampling params + const float* curand_states, // nullptr for GREEDY/TARGET_MATCH + const float* topp, + // Metadata + const bool* stop_flags, + const int* seq_lens_encoder, + const int* seq_lens_this_time, + const int64_t* end_tokens, + const bool* is_block_step, + const int* cu_seqlens_q_output, + const int* 
reasoning_status, + // max_dec_len / step_idx for EOS/max-len detection (read-only) + const int64_t* max_dec_len, + const int64_t* step_idx, + // Dimensions and config + const int max_bsz, + const int real_bsz, + const int max_step_tokens, + const int end_length, + const int max_seq_len, + const int max_candidate_len, + const int verify_window, + const int verify_strategy, // 0=TOPP, 1=GREEDY, 2=TARGET_MATCH + const bool reject_all, + const bool accept_all); /*--------------------------------------- MTP end * --------------------------------------------*/ diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu new file mode 100644 index 00000000000..3519a963442 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu @@ -0,0 +1,349 @@ +#include "xpu/kernel/cluster_debug.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/xtdk.h" +#include "xpu/kernel/xtdk_math.h" +#include "xpu/kernel/xtdk_simd.h" + +namespace fd_xpu3 { + +static inline __device__ int v_reduce(int32x16_t &v0, int32x16_t &v1) { + int res; + v1 = vvadd_int32x16(v0, v1); + auto v = vsrlp_int32x16(256, v1); + v1 = vvadd_int32x16(v, v1); + v = vsrlp_int32x16(128, v1); + v1 = vvadd_int32x16(v, v1); + v = vsrlp_int32x16(64, v1); + v1 = vvadd_int32x16(v, v1); + v = vsrlp_int32x16(32, v1); + v1 = vvadd_int32x16(v, v1); + res = vextract_int32x16(v1, 1); + return res; +} +static inline __device__ int ClusterReduce( + const _shared_ptr_ int *stop_flag_now_int_sm, int len) { + int sum = 0; + if (core_id() == 0) { + int32x16_t vec_x_0; + int32x16_t vec_x_1; + int32x16_t vec_y_0 = vzero(); + int32x16_t vec_y_1 = vzero(); + for (int i = 0; i < len; i += 32) { + vload2_sm(stop_flag_now_int_sm + i, vec_x_0, vec_x_1); + vec_y_0 = vvadd_int32x16(vec_y_0, vec_x_0); + vec_y_1 = vvadd_int32x16(vec_y_1, vec_x_1); + } + sum = 
v_reduce(vec_y_0, vec_y_1); + } + return sum; +} +__device__ bool is_in_end(const int64_t id, + __global_ptr__ const int64_t *end_ids, + int length) { + bool flag = false; + for (int i = 0; i < length; i++) { + if (id == end_ids[i]) { + return true; + } + } + return flag; +} +__device__ inline bool is_in(__global_ptr__ const int64_t *candidates, + const int64_t draft, + const int candidate_len) { + for (int i = 0; i < candidate_len; i++) { + if (draft == candidates[i]) { + return true; + } + } + return false; +} +// static __device__ inline unsigned int xorwow(unsigned int& state) { +// state ^= state >> 7; +// state ^= state << 9; +// state ^= state >> 13; +// return state; +// } +static __device__ inline unsigned int xorwow(unsigned int &state) { + state ^= state >> 7; + state ^= state << 9; + state ^= state >> 13; + return state; +} + +__device__ int64_t +topp_sampling_kernel(__global_ptr__ const int64_t *candidate_ids, + __global_ptr__ const float *candidate_scores, + __global_ptr__ const float *dev_curand_states, + const int candidate_len, + const float topp) { + const int tid = core_id(); + float sum_scores = 0.0f; + float rand_top_p = *dev_curand_states * topp; + // printf("debug rand_top_p:%f\n",rand_top_p); + for (int i = 0; i < candidate_len; i++) { + sum_scores += candidate_scores[i]; + if (rand_top_p <= sum_scores) { + return candidate_ids[i]; + } + } + return candidate_ids[0]; +} + +// GREEDY / TARGET_MATCH: exact single-token match +__device__ inline bool verify_one_match(int64_t target_token, + int64_t draft_token) { + return target_token == draft_token; +} + +__device__ inline bool verify_one_topp( + __global_ptr__ const int64_t *verify_tokens_row, + int64_t draft_token, + int actual_cand_len) { + return is_in(verify_tokens_row, draft_token, actual_cand_len); +} + +// ============================================================ +// VerifyContext — per-batch mutable state + accept helpers. 
+// Eliminates repeated EOS/max_dec_len check and output write +// patterns across Phase 1 and Phase 2. +// ============================================================ +struct VerifyContext { + // Immutable per-batch (set once at kernel entry) + int bid; + int max_step_tokens; + int end_length; + __global_ptr__ const int64_t *end_tokens; + __global_ptr__ const int64_t *max_dec_len; + __global_ptr__ const int64_t *step_input_ids_now; + __global_ptr__ int64_t *step_output_ids; + + // Mutable per-batch state + int64_t cur_step_idx; + int output_len_now; + bool stopped; + + // Emit a token at position `pos` to output in Phase 1. + // Performs: step_idx check, EOS detection, token replacement, output write. + // Returns true if this sequence should stop (EOS or max_dec_len hit). + __device__ bool emit_token(int pos, int64_t token) { + cur_step_idx++; + bool is_eos = is_in_end(token, end_tokens, end_length); + bool max_len_hit = (cur_step_idx >= max_dec_len[bid]); + if ((is_eos || max_len_hit) && !is_eos) { + token = end_tokens[0]; + } + step_output_ids[bid * max_step_tokens + pos] = token; + output_len_now++; + if (is_eos || max_len_hit) { + stopped = true; + return true; + } + return false; + } + + // Emit the final token at position `pos` in Phase 2. + // Same EOS/max_dec_len logic. Increments output_len_now since + // Phase 2 produces one additional token. + __device__ void emit_final_token(int pos, int64_t token) { + cur_step_idx++; + bool is_eos = is_in_end(token, end_tokens, end_length); + bool max_len_hit = (cur_step_idx >= max_dec_len[bid]); + if ((is_eos || max_len_hit) && !is_eos) { + token = end_tokens[0]; + } + step_output_ids[bid * max_step_tokens + pos] = token; + output_len_now++; + } + + // TOPP-only: verify_window bulk-accept fallback. + // + // When draft token is NOT in top-p set but IS the top-2 token, + // check verify_window consecutive positions for top-1 match. + // If all match, bulk-accept from position i through ii. 
+ // + // Returns the new loop position (i) after handling. + // Sets *rejected=true if fallback was not triggered (caller should break). + __device__ int try_verify_window_fallback( + int i, + bool *rejected, + __global_ptr__ const int64_t *verify_tokens_now, + int seq_len_this_time, + int max_candidate_len, + int verify_window) { + int ii = i; + if (max_candidate_len >= 2 && + verify_tokens_now[ii * max_candidate_len + 1] == + step_input_ids_now[ii + 1]) { + // top-2 matches — scan verify_window consecutive top-1 matches + int j = 0; + ii += 1; + for (; j < verify_window && ii < seq_len_this_time - 1; j++, ii++) { + if (verify_tokens_now[ii * max_candidate_len] != + step_input_ids_now[ii + 1]) { + break; + } + } + if (j >= verify_window) { + // Bulk accept all tokens from i to ii + for (; i < ii; i++) { + if (emit_token(i, step_input_ids_now[i + 1])) return i; + } + return i; // continue outer loop from position ii + } + } + // Fallback not triggered or insufficient window — reject + *rejected = true; + return i; + } +}; + +__global__ void verify_draft_tokens( + // Core I/O + int64_t *step_output_ids, + int *step_output_len, + const int64_t *step_input_ids, // draft tokens + // Target model outputs (strategy-dependent interpretation) + const int64_t + *target_tokens, // GREEDY:argmax, TARGET_MATCH:sampled, TOPP:unused + // Candidate set for TOPP/GREEDY (TARGET_MATCH: unused) + const int64_t *candidate_ids, + const float *candidate_scores, + const int *candidate_lens, + // Sampling params + const float *curand_states, // nullptr for GREEDY/TARGET_MATCH + const float *topp, + // Metadata + const bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_this_time, + const int64_t *end_tokens, + const bool *is_block_step, + const int *cu_seqlens_q_output, + const int *reasoning_status, + // max_dec_len / step_idx for EOS/max-len detection (read-only) + const int64_t *max_dec_len, + const int64_t *step_idx, + // Dimensions and config + const int max_bsz, 
+ const int real_bsz, + const int max_step_tokens, + const int end_length, + const int max_seq_len, + const int max_candidate_len, + const int verify_window, + const int verify_strategy, // 0=TOPP, 1=GREEDY, 2=TARGET_MATCH + const bool reject_all, + const bool accept_all) { + const int64_t tid = core_id() * cluster_num() + cluster_id(); + const int64_t nthreads = cluster_num() * core_num(); + for (int64_t bid = tid; bid < real_bsz; bid += nthreads) { + step_output_len[bid] = 0; + if (bid >= real_bsz || is_block_step[bid] || stop_flags[bid]) continue; + const int start_token_id = cu_seqlens_q_output[bid]; + // Pointers are strategy-dependent (may be nullptr for unused params) + auto *candidate_ids_now = + candidate_ids ? candidate_ids + start_token_id * max_candidate_len + : nullptr; + auto *candidate_scores_now = + candidate_scores ? candidate_scores + start_token_id * max_candidate_len + : nullptr; + auto *candidate_lens_now = + candidate_lens ? candidate_lens + start_token_id : nullptr; + auto *target_tokens_now = + target_tokens ? 
target_tokens + start_token_id : nullptr; + + // Initialize per-batch verification context + VerifyContext ctx; + ctx.bid = bid; + ctx.max_step_tokens = max_step_tokens; + ctx.end_length = end_length; + ctx.end_tokens = end_tokens; + ctx.max_dec_len = max_dec_len; + ctx.step_input_ids_now = step_input_ids + bid * max_step_tokens; + ctx.step_output_ids = step_output_ids; + ctx.cur_step_idx = step_idx[bid]; + ctx.output_len_now = 0; + ctx.stopped = false; + + // ======== Phase 1: Verify draft tokens ======== + int i = 0; + for (; i < seq_lens_this_time[bid] - 1; i++) { + // Early exit conditions: reject-all, prefill, reasoning + if (reject_all || seq_lens_encoder[bid] != 0 || + reasoning_status[bid] == 1) { + break; + } + + // Accept-all override (debug/warmup) + if (accept_all) { + if (ctx.emit_token(i, ctx.step_input_ids_now[i + 1])) break; + continue; + } + + // Strategy dispatch + bool accepted = false; + switch (verify_strategy) { + case 0: { // TOPP + auto actual_cand_len = candidate_lens_now[i] > max_candidate_len + ? 
max_candidate_len + : candidate_lens_now[i]; + accepted = verify_one_topp(candidate_ids_now + i * max_candidate_len, + ctx.step_input_ids_now[i + 1], + actual_cand_len); + if (!accepted) { + bool rejected = false; + i = ctx.try_verify_window_fallback(i, + &rejected, + candidate_ids_now, + seq_lens_this_time[bid], + max_candidate_len, + verify_window); + if (ctx.stopped || rejected) goto phase1_done; + continue; // bulk accept succeeded, continue from new i + } + break; + } + case 1: // GREEDY + case 2: // TARGET_MATCH + accepted = verify_one_match(target_tokens_now[i], + ctx.step_input_ids_now[i + 1]); + break; + } + + if (accepted) { + if (ctx.emit_token(i, ctx.step_input_ids_now[i + 1])) break; + } else { + break; // reject + } + } + phase1_done: + + // ======== Phase 2: Output token for rejected/last position ======== + if (!ctx.stopped) { + int64_t output_token; + switch (verify_strategy) { + case 0: { // TOPP — stochastic sampling from candidate set + auto actual_cand_len = candidate_lens_now[i] > max_candidate_len + ? 
max_candidate_len + : candidate_lens_now[i]; + output_token = + topp_sampling_kernel(candidate_ids_now + i * max_candidate_len, + candidate_scores_now + i * max_candidate_len, + curand_states, + actual_cand_len, + topp[bid]); + break; + } + case 1: // GREEDY — deterministic argmax from target_tokens + case 2: // TARGET_MATCH — target model's sampled token + output_token = target_tokens_now[i]; + break; + } + ctx.emit_final_token(i, output_token); + } + step_output_len[bid] = ctx.output_len_now; + } +} + +} // namespace fd_xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp index 0d44020581f..6fa25220c3c 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp @@ -19,7 +19,6 @@ #include "xpu/refactor/impl_public/wrapper_check.h" namespace fd_xpu3 { -typedef uint32_t curandStatePhilox4_32_10_t; template __attribute__((global)) void speculate_verify( @@ -87,8 +86,6 @@ static inline unsigned int xorwow(unsigned int &state) { // NOLINT return state; } -typedef uint32_t curandStatePhilox4_32_10_t; - static int64_t topp_sampling_kernel(const int64_t *candidate_ids, const float *candidate_scores, const float *dev_curand_states, diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/verify_draft_tokens.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/verify_draft_tokens.cpp new file mode 100644 index 00000000000..76376356948 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/verify_draft_tokens.cpp @@ -0,0 +1,613 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace fd_xpu3 { +__attribute__((global)) void verify_draft_tokens( + // Core I/O + int64_t *step_output_ids, + int *step_output_len, + const int64_t *step_input_ids, // draft tokens + // Target model outputs (strategy-dependent interpretation) + const int64_t + *target_tokens, // GREEDY:argmax, TARGET_MATCH:sampled, TOPP:unused + // Candidate set for TOPP/GREEDY (TARGET_MATCH: unused) + const int64_t *candidate_ids, + const float *candidate_scores, + const int *candidate_lens, + // Sampling params + const float *curand_states, // nullptr for GREEDY/TARGET_MATCH + const float *topp, + // Metadata + const bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_this_time, + const int64_t *end_tokens, + const bool *is_block_step, + const int *cu_seqlens_q_output, + const int *reasoning_status, + // max_dec_len / step_idx for EOS/max-len detection (read-only) + const int64_t *max_dec_len, + const int64_t *step_idx, + // Dimensions and config + const int max_bsz, + const int real_bsz, + const int max_step_tokens, + const int end_length, + const int max_seq_len, + const int max_candidate_len, + const int verify_window, + const int verify_strategy, // 0=TOPP, 1=GREEDY, 2=TARGET_MATCH + const bool reject_all, + const bool accept_all); +} // namespace fd_xpu3 + +namespace fastdeploy { +namespace plugin { + +// ============================================================ +// Phase 1 helpers — single-step draft token verification +// 
============================================================ + +// Check if draft_token appears in the candidate set +static inline bool is_in(const int64_t *candidates, + const int64_t draft, + const int candidate_len) { + for (int i = 0; i < candidate_len; i++) { + if (draft == candidates[i]) { + return true; + } + } + return false; +} +// TOPP: draft in top-p filtered candidate set +static inline bool verify_one_topp(const int64_t *verify_tokens_row, + int64_t draft_token, + int actual_cand_len) { + return is_in(verify_tokens_row, draft_token, actual_cand_len); +} + +// GREEDY / TARGET_MATCH: exact single-token match +static inline bool verify_one_match(int64_t target_token, int64_t draft_token) { + return target_token == draft_token; +} + +static inline bool is_in_end(const int64_t id, + const int64_t *end_ids, + int length) { + bool flag = false; + for (int i = 0; i < length; i++) { + if (id == end_ids[i]) { + return true; + } + } + return flag; +} + +// ============================================================ +// VerifyContext — per-batch mutable state + accept helpers. +// Eliminates repeated EOS/max_dec_len check and output write +// patterns across Phase 1 and Phase 2. +// ============================================================ +struct VerifyContext { + // Immutable per-batch (set once at kernel entry) + int bid; + int max_step_tokens; + int end_length; + const int64_t *end_tokens; + const int64_t *max_dec_len; + const int64_t *step_input_ids_now; + int64_t *step_output_ids; + + // Mutable per-batch state + int64_t cur_step_idx; + int output_len_now; + bool stopped; + + // Emit a token at position `pos` to output in Phase 1. + // Performs: step_idx check, EOS detection, token replacement, output write. + // Returns true if this sequence should stop (EOS or max_dec_len hit). 
+ bool emit_token(int pos, int64_t token) { + cur_step_idx++; + bool is_eos = is_in_end(token, end_tokens, end_length); + bool max_len_hit = (cur_step_idx >= max_dec_len[bid]); + if ((is_eos || max_len_hit) && !is_eos) { + token = end_tokens[0]; + } + step_output_ids[bid * max_step_tokens + pos] = token; + output_len_now++; + if (is_eos || max_len_hit) { + stopped = true; + return true; + } + return false; + } + + // Emit the final token at position `pos` in Phase 2. + // Same EOS/max_dec_len logic. Increments output_len_now since + // Phase 2 produces one additional token. + void emit_final_token(int pos, int64_t token) { + cur_step_idx++; + bool is_eos = is_in_end(token, end_tokens, end_length); + bool max_len_hit = (cur_step_idx >= max_dec_len[bid]); + if ((is_eos || max_len_hit) && !is_eos) { + token = end_tokens[0]; + } + step_output_ids[bid * max_step_tokens + pos] = token; + output_len_now++; + } + + // TOPP-only: verify_window bulk-accept fallback. + // + // When draft token is NOT in top-p set but IS the top-2 token, + // check verify_window consecutive positions for top-1 match. + // If all match, bulk-accept from position i through ii. + // + // Returns the new loop position (i) after handling. + // Sets *rejected=true if fallback was not triggered (caller should break). 
+ int try_verify_window_fallback(int i, + bool *rejected, + const int64_t *verify_tokens_now, + int seq_len_this_time, + int max_candidate_len, + int verify_window) { + int ii = i; + if (max_candidate_len >= 2 && + verify_tokens_now[ii * max_candidate_len + 1] == + step_input_ids_now[ii + 1]) { + // top-2 matches — scan verify_window consecutive top-1 matches + int j = 0; + ii += 1; + for (; j < verify_window && ii < seq_len_this_time - 1; j++, ii++) { + if (verify_tokens_now[ii * max_candidate_len] != + step_input_ids_now[ii + 1]) { + break; + } + } + if (j >= verify_window) { + // Bulk accept all tokens from i to ii + for (; i < ii; i++) { + if (emit_token(i, step_input_ids_now[i + 1])) return i; + } + return i; // continue outer loop from position ii + } + } + // Fallback not triggered or insufficient window — reject + *rejected = true; + return i; + } +}; + +static int64_t topp_sampling_kernel(const int64_t *candidate_ids, + const float *candidate_scores, + const float *dev_curand_states, + const int candidate_len, + const float topp, + int tid) { + // const int tid = core_id(); + float sum_scores = 0.0f; + float rand_top_p = *dev_curand_states * topp; + for (int i = 0; i < candidate_len; i++) { + // printf("debug cpu sample i:%d scores:%f,ids:%ld + // rand_top_p:%f,candidate_len:%d\n", + // i,candidate_scores[i],candidate_ids[i],rand_top_p,candidate_len); + sum_scores += candidate_scores[i]; + if (rand_top_p <= sum_scores) { + return candidate_ids[i]; + } + } + return candidate_ids[0]; +} + +static int cpu_wrapper( + api::Context *ctx, + // Core I/O + int64_t *step_output_ids, + int *step_output_len, + const int64_t *step_input_ids, // draft tokens + // Target model outputs (strategy-dependent interpretation) + const int64_t + *target_tokens, // GREEDY:argmax, TARGET_MATCH:sampled, TOPP:unused + // Candidate set for TOPP/GREEDY (TARGET_MATCH: unused) + const int64_t *candidate_ids, + const float *candidate_scores, + const int *candidate_lens, + // Sampling 
params + const float *curand_states, // nullptr for GREEDY/TARGET_MATCH + const float *topp, + // Metadata + const bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_this_time, + const int64_t *end_tokens, + const bool *is_block_step, + const int *cu_seqlens_q_output, + const int *reasoning_status, + // max_dec_len / step_idx for EOS/max-len detection (read-only) + const int64_t *max_dec_len, + const int64_t *step_idx, + // Dimensions and config + const int max_bsz, + const int real_bsz, + const int max_step_tokens, + const int end_length, + const int max_seq_len, + const int max_candidate_len, + const int verify_window, + const int verify_strategy, // 0=TOPP, 1=GREEDY, 2=TARGET_MATCH + const bool reject_all, + const bool accept_all) { + for (int bid = 0; bid < max_bsz; bid++) { + step_output_len[bid] = 0; + + if (bid >= real_bsz || is_block_step[bid] || stop_flags[bid]) continue; + + const int start_token_id = cu_seqlens_q_output[bid]; + // Pointers are strategy-dependent (may be nullptr for unused params) + auto *candidate_ids_now = + candidate_ids ? candidate_ids + start_token_id * max_candidate_len + : nullptr; + auto *candidate_scores_now = + candidate_scores ? candidate_scores + start_token_id * max_candidate_len + : nullptr; + auto *candidate_lens_now = + candidate_lens ? candidate_lens + start_token_id : nullptr; + auto *target_tokens_now = + target_tokens ? 
target_tokens + start_token_id : nullptr; + + // Initialize per-batch verification context + VerifyContext v_ctx; + v_ctx.bid = bid; + v_ctx.max_step_tokens = max_step_tokens; + v_ctx.end_length = end_length; + v_ctx.end_tokens = end_tokens; + v_ctx.max_dec_len = max_dec_len; + v_ctx.step_input_ids_now = step_input_ids + bid * max_step_tokens; + v_ctx.step_output_ids = step_output_ids; + v_ctx.cur_step_idx = step_idx[bid]; + v_ctx.output_len_now = 0; + v_ctx.stopped = false; + + // ======== Phase 1: Verify draft tokens ======== + int i = 0; + for (; i < seq_lens_this_time[bid] - 1; i++) { + // Early exit conditions: reject-all, prefill, reasoning + if (reject_all || seq_lens_encoder[bid] != 0 || + reasoning_status[bid] == 1) { + break; + } + + // Accept-all override (debug/warmup) + if (accept_all) { + if (v_ctx.emit_token(i, v_ctx.step_input_ids_now[i + 1])) break; + continue; + } + + // Strategy dispatch + bool accepted = false; + switch (verify_strategy) { + case 0: { // TOPP + auto actual_cand_len = candidate_lens_now[i] > max_candidate_len + ? 
max_candidate_len + : candidate_lens_now[i]; + accepted = verify_one_topp(candidate_ids_now + i * max_candidate_len, + v_ctx.step_input_ids_now[i + 1], + actual_cand_len); + if (!accepted) { + bool rejected = false; + i = v_ctx.try_verify_window_fallback(i, + &rejected, + candidate_ids_now, + seq_lens_this_time[bid], + max_candidate_len, + verify_window); + if (v_ctx.stopped || rejected) goto phase1_done; + continue; // bulk accept succeeded, continue from new i + } + break; + } + case 1: // GREEDY + case 2: // TARGET_MATCH + accepted = verify_one_match(target_tokens_now[i], + v_ctx.step_input_ids_now[i + 1]); + break; + } + + if (accepted) { + if (v_ctx.emit_token(i, v_ctx.step_input_ids_now[i + 1])) break; + } else { + break; // reject + } + } + phase1_done: + + // ======== Phase 2: Output token for rejected/last position ======== + if (!v_ctx.stopped) { + int64_t output_token = 0; + switch (verify_strategy) { + case 0: { // TOPP — stochastic sampling from candidate set + auto actual_cand_len = candidate_lens_now[i] > max_candidate_len + ? 
max_candidate_len + : candidate_lens_now[i]; + output_token = + topp_sampling_kernel(candidate_ids_now + i * max_candidate_len, + candidate_scores_now + i * max_candidate_len, + curand_states + i, + actual_cand_len, + topp[bid], + bid); + break; + } + case 1: // GREEDY — deterministic argmax from target_tokens + case 2: // TARGET_MATCH — target model's sampled token + output_token = target_tokens_now[i]; + break; + } + v_ctx.emit_final_token(i, output_token); + } + step_output_len[bid] = v_ctx.output_len_now; + } + + return api::SUCCESS; +} + +static int xpu3_wrapper( + api::Context *ctx, + // Core I/O + int64_t *step_output_ids, + int *step_output_len, + const int64_t *step_input_ids, // draft tokens + // Target model outputs (strategy-dependent interpretation) + const int64_t + *target_tokens, // GREEDY:argmax, TARGET_MATCH:sampled, TOPP:unused + // Candidate set for TOPP/GREEDY (TARGET_MATCH: unused) + const int64_t *candidate_ids, + const float *candidate_scores, + const int *candidate_lens, + // Sampling params + const float *curand_states, // nullptr for GREEDY/TARGET_MATCH + const float *topp, + // Metadata + const bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_this_time, + const int64_t *end_tokens, + const bool *is_block_step, + const int *cu_seqlens_q_output, + const int *reasoning_status, + // max_dec_len / step_idx for EOS/max-len detection (read-only) + const int64_t *max_dec_len, + const int64_t *step_idx, + // Dimensions and config + const int max_bsz, + const int real_bsz, + const int max_step_tokens, + const int end_length, + const int max_seq_len, + const int max_candidate_len, + const int verify_window, + const int verify_strategy, // 0=TOPP, 1=GREEDY, 2=TARGET_MATCH + const bool reject_all, + const bool accept_all) { + using XPU_INT64 = typename api::XPUIndexType::type; + int32_t ret_xre = + fd_xpu3::verify_draft_tokens<<ncluster(), 64, ctx->xpu_stream>>>( + reinterpret_cast(step_output_ids), + step_output_len, + 
reinterpret_cast(step_input_ids), + reinterpret_cast(target_tokens), + reinterpret_cast(candidate_ids), + candidate_scores, + candidate_lens, + curand_states, + topp, + stop_flags, + seq_lens_encoder, + seq_lens_this_time, + reinterpret_cast(end_tokens), + is_block_step, + cu_seqlens_q_output, + reasoning_status, + reinterpret_cast(max_dec_len), + reinterpret_cast(step_idx), + max_bsz, + real_bsz, + max_step_tokens, + end_length, + max_seq_len, + max_candidate_len, + verify_window, + verify_strategy, + reject_all, + accept_all); + KERNEL_ASSERT_SUCCESS(ctx, ret_xre); + return api::SUCCESS; +} + +int verify_draft_tokens( + api::Context *ctx, + // Core I/O + int64_t *step_output_ids, + int *step_output_len, + const int64_t *step_input_ids, // draft tokens + // Target model outputs (strategy-dependent interpretation) + const int64_t + *target_tokens, // GREEDY:argmax, TARGET_MATCH:sampled, TOPP:unused + // Candidate set for TOPP/GREEDY (TARGET_MATCH: unused) + const int64_t *candidate_ids, + const float *candidate_scores, + const int *candidate_lens, + // Sampling params + const float *curand_states, // nullptr for GREEDY/TARGET_MATCH + const float *topp, + // Metadata + const bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_this_time, + const int64_t *end_tokens, + const bool *is_block_step, + const int *cu_seqlens_q_output, + const int *reasoning_status, + // max_dec_len / step_idx for EOS/max-len detection (read-only) + const int64_t *max_dec_len, + const int64_t *step_idx, + // Dimensions and config + const int max_bsz, + const int real_bsz, + const int max_step_tokens, + const int end_length, + const int max_seq_len, + const int max_candidate_len, + const int verify_window, + const int verify_strategy, // 0=TOPP, 1=GREEDY, 2=TARGET_MATCH + const bool reject_all, + const bool accept_all) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "verify_draft_tokens", int64_t); + WRAPPER_DUMP_PARAM6(ctx, + step_output_ids, + step_output_len, 
+ step_input_ids, + target_tokens, + candidate_ids, + candidate_scores); + WRAPPER_DUMP_PARAM6(ctx, + candidate_lens, + curand_states, + topp, + stop_flags, + seq_lens_encoder, + seq_lens_this_time); + + WRAPPER_DUMP_PARAM6(ctx, + end_tokens, + is_block_step, + cu_seqlens_q_output, + reasoning_status, + max_dec_len, + step_idx); + + WRAPPER_DUMP_PARAM6(ctx, + max_bsz, + real_bsz, + max_step_tokens, + end_length, + max_seq_len, + max_candidate_len); + + WRAPPER_DUMP_PARAM4( + ctx, verify_window, verify_strategy, reject_all, accept_all); + WRAPPER_DUMP(ctx); + + WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * max_step_tokens, step_output_ids); + WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * max_step_tokens, step_input_ids); + // len(target_tokens) = cu_seqlens_q_output[-1] + WRAPPER_CHECK_PTR_OR_NULL(ctx, int64_t, real_bsz, target_tokens); + WRAPPER_CHECK_PTR_OR_NULL(ctx, int64_t, real_bsz, candidate_lens); + WRAPPER_CHECK_PTR_OR_NULL( + ctx, int64_t, real_bsz * max_candidate_len, candidate_ids); + WRAPPER_CHECK_PTR_OR_NULL( + ctx, float, real_bsz *max_candidate_len, candidate_scores); + + WRAPPER_CHECK_PTR(ctx, float, real_bsz, curand_states); + WRAPPER_CHECK_PTR(ctx, float, real_bsz, topp); + WRAPPER_CHECK_PTR(ctx, bool, real_bsz, stop_flags); + WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_encoder); + WRAPPER_CHECK_PTR(ctx, float, real_bsz, seq_lens_this_time); + WRAPPER_CHECK_PTR(ctx, int64_t, end_length, end_tokens); + + WRAPPER_CHECK_PTR(ctx, bool, real_bsz, is_block_step); + WRAPPER_CHECK_PTR(ctx, bool, real_bsz, cu_seqlens_q_output); + WRAPPER_CHECK_PTR(ctx, bool, real_bsz, reasoning_status); + WRAPPER_CHECK_PTR(ctx, bool, real_bsz, max_dec_len); + WRAPPER_CHECK_PTR(ctx, bool, real_bsz, step_idx); + // param check sm size limit + WRAPPER_ASSERT_GT(ctx, real_bsz, 0); + WRAPPER_ASSERT_LE(ctx, real_bsz, 1024); + WRAPPER_ASSERT_LE(ctx, real_bsz * max_candidate_len, 2048); + WRAPPER_ASSERT_LE(ctx, verify_window * max_candidate_len, 128); + + if (ctx->dev().type() == 
api::kCPU) { + return cpu_wrapper(ctx, + step_output_ids, + step_output_len, + step_input_ids, + target_tokens, + candidate_ids, + candidate_scores, + candidate_lens, + + curand_states, + topp, + stop_flags, + seq_lens_encoder, + seq_lens_this_time, + end_tokens, + is_block_step, + cu_seqlens_q_output, + reasoning_status, + max_dec_len, + step_idx, + max_bsz, + real_bsz, + max_step_tokens, + end_length, + max_seq_len, + max_candidate_len, + verify_window, + verify_strategy, + reject_all, + accept_all); + } + if (ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + step_output_ids, + step_output_len, + step_input_ids, + target_tokens, + candidate_ids, + candidate_scores, + candidate_lens, + curand_states, + topp, + stop_flags, + seq_lens_encoder, + seq_lens_this_time, + end_tokens, + is_block_step, + cu_seqlens_q_output, + reasoning_status, + max_dec_len, + step_idx, + max_bsz, + real_bsz, + max_step_tokens, + end_length, + max_seq_len, + max_candidate_len, + verify_window, + verify_strategy, + reject_all, + accept_all); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace fastdeploy diff --git a/custom_ops/xpu_ops/test/test_verify_draft_tokens.py b/custom_ops/xpu_ops/test/test_verify_draft_tokens.py new file mode 100644 index 00000000000..ce3d4f71cb6 --- /dev/null +++ b/custom_ops/xpu_ops/test/test_verify_draft_tokens.py @@ -0,0 +1,771 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for verify_draft_tokens kernel. + +Verification strategies: +- TOPP (0): Verify draft token is in top-p candidate set +- GREEDY (1): Verify draft token matches target model's argmax +- TARGET_MATCH (2): Verify draft token matches target model's sampled token +""" + +import random +import unittest +from typing import Any, Dict + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.xpu import static_op_verify_draft_tokens + +verify_draft_tokens = static_op_verify_draft_tokens +from fastdeploy.spec_decode import VerifyStrategy + +CPU_PLACE = paddle.CPUPlace() +CUDA_PLACE = paddle.XPUPlace(0) if paddle.is_compiled_with_xpu() else paddle.CPUPlace() + + +# ============================================================ +# Helpers: tensor creation / kernel invocation / comparison +# ============================================================ + + +def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]: + """Convert numpy input dict to paddle tensors on GPU.""" + paddle_inputs = {} + for k, v in inputs.items(): + if isinstance(v, (int, bool, float, str)): + paddle_inputs[k] = v + elif v is not None: + paddle_inputs[k] = paddle.to_tensor(v, place=CUDA_PLACE) + else: + paddle_inputs[k] = None + return paddle_inputs + + +def run_kernel(paddle_inputs: Dict[str, Any], inputs: Dict[str, Any]): + """Call verify_draft_tokens kernel.""" + verify_draft_tokens( + paddle_inputs["step_output_ids"], + paddle_inputs["step_output_len"], + paddle_inputs["step_input_ids"], + paddle_inputs["target_tokens"], + paddle_inputs["candidate_ids"], + paddle_inputs["candidate_scores"], + paddle_inputs["candidate_lens"], + paddle_inputs["topp"], + paddle_inputs["stop_flags"], + paddle_inputs["seq_lens_encoder"], + paddle_inputs["seq_lens_this_time"], + paddle_inputs["end_tokens"], + paddle_inputs["is_block_step"], + paddle_inputs["cu_seqlens_q_output"], + 
paddle_inputs["reasoning_status"], + paddle_inputs["max_dec_len"], + paddle_inputs["step_idx"], + inputs["max_seq_len"], + inputs["verify_window"], + inputs["verify_strategy"], + inputs["reject_all"], + inputs["accept_all"], + ) + + +def run_ref(inputs: Dict[str, Any]): + """Run reference implementation on deep-copied inputs, return (output_ids, output_len).""" + ref = {k: v.copy() if isinstance(v, np.ndarray) else v for k, v in inputs.items()} + return verify_draft_tokens_ref( + ref["step_output_ids"], + ref["step_output_len"], + ref["step_input_ids"], + ref["target_tokens"], + ref["candidate_ids"], + ref["candidate_scores"], + ref["candidate_lens"], + ref["topp"], + ref["stop_flags"], + ref["seq_lens_encoder"], + ref["seq_lens_this_time"], + ref["end_tokens"], + ref["is_block_step"], + ref["cu_seqlens_q_output"], + ref["reasoning_status"], + ref["max_dec_len"], + ref["step_idx"], + ref["max_seq_len"], + ref["verify_window"], + ref["verify_strategy"], + ref["reject_all"], + ref["accept_all"], + ) + + +def compare_results( + paddle_inputs: Dict[str, Any], + step_output_ids_ref: np.ndarray, + step_output_len_ref: np.ndarray, + inputs: Dict[str, Any], + label: str = "unknown", +): + """Compare GPU kernel output vs reference.""" + gpu_ids = paddle_inputs["step_output_ids"].numpy() + gpu_len = paddle_inputs["step_output_len"].numpy() + np.testing.assert_array_equal( + gpu_len, + step_output_len_ref, + err_msg=f"step_output_len mismatch ({label})", + ) + + if inputs["verify_strategy"] == 0: # TOPP — Phase 2 is stochastic + real_bsz = inputs["seq_lens_this_time"].shape[0] + for bid in range(real_bsz): + ref_len = int(step_output_len_ref[bid]) + if ref_len > 1: + print(gpu_ids[bid, : ref_len - 1], step_output_ids_ref[bid, : ref_len - 1]) + np.testing.assert_array_equal( + gpu_ids[bid, : ref_len - 1], + step_output_ids_ref[bid, : ref_len - 1], + err_msg=f"step_output_ids (accepted) mismatch at bid={bid} ({label})", + ) + else: + np.testing.assert_array_equal( + gpu_ids, + 
step_output_ids_ref, + err_msg=f"step_output_ids mismatch ({label})", + ) + + +# ============================================================ +# Reference helpers +# ============================================================ + + +def topp_sampling_kernel(candidate_ids, candidate_scores, curand_value, candidate_len, topp, tid=0): + rand_top_p = curand_value * topp + sum_scores = 0.0 + for i in range(candidate_len): + sum_scores += candidate_scores[i] + if rand_top_p <= sum_scores: + return int(candidate_ids[i]) + return int(candidate_ids[0]) + + +def is_in_end(token, end_tokens, end_length): + return token in end_tokens[:end_length] + + +def is_in(candidate_list, token, length): + return token in candidate_list[:length] + + +class _VerifyContext: + """Python mirror of the CUDA VerifyContext struct for reference testing.""" + + def __init__( + self, + bid, + max_step_tokens, + end_length, + end_tokens, + max_dec_len, + step_input_ids_now, + step_output_ids_flat, + cur_step_idx, + ): + self.bid = bid + self.max_step_tokens = max_step_tokens + self.end_length = end_length + self.end_tokens = end_tokens + self.max_dec_len = max_dec_len + self.step_input_ids_now = step_input_ids_now + self.step_output_ids_flat = step_output_ids_flat + self.cur_step_idx = cur_step_idx + self.output_len_now = 0 + self.stopped = False + + def emit_token(self, pos, token): + """Emit a token to output. Returns True if sequence should stop.""" + self.cur_step_idx += 1 + eos = is_in_end(token, self.end_tokens, self.end_length) + max_hit = self.cur_step_idx >= int(self.max_dec_len[self.bid]) + if (eos or max_hit) and not eos: + token = int(self.end_tokens[0]) + self.step_output_ids_flat[self.bid * self.max_step_tokens + pos] = token + self.output_len_now += 1 + if eos or max_hit: + self.stopped = True + return True + return False + + def emit_final_token(self, pos, token): + """Emit the Phase 2 final token. 
Increments output_len_now.""" + self.cur_step_idx += 1 + eos = is_in_end(token, self.end_tokens, self.end_length) + max_hit = self.cur_step_idx >= int(self.max_dec_len[self.bid]) + if (eos or max_hit) and not eos: + token = int(self.end_tokens[0]) + self.step_output_ids_flat[self.bid * self.max_step_tokens + pos] = token + self.output_len_now += 1 + + +def verify_draft_tokens_ref( + step_output_ids, + step_output_len, + step_input_ids, + target_tokens, + candidate_ids, + candidate_scores, + candidate_lens, + topp, + stop_flags, + seq_lens_encoder, + seq_lens_this_time, + end_tokens, + is_block_step, + cu_seqlens_q_output, + reasoning_status, + max_dec_len, + step_idx, + max_seq_len, + verify_window, + verify_strategy, + reject_all, + accept_all, +): + """Reference implementation of verify_draft_tokens in Python.""" + real_bsz = seq_lens_this_time.shape[0] + max_step_tokens = step_input_ids.shape[1] + end_length = end_tokens.shape[0] + max_candidate_len = candidate_ids.shape[1] if candidate_ids is not None else 1 + + dev_curand_states = [random.Random(0).random() for _ in range(max_step_tokens)] + + step_output_ids_flat = step_output_ids.reshape(-1) + step_input_ids_flat = step_input_ids.reshape(-1) + candidate_ids_flat = candidate_ids.reshape(-1) if candidate_ids is not None else None + candidate_scores_flat = candidate_scores.reshape(-1) if candidate_scores is not None else None + + for bid in range(real_bsz): + start_token_id = cu_seqlens_q_output[bid] + + if is_block_step[bid] or stop_flags[bid]: + step_output_len[bid] = 0 + continue + + step_input_ids_now = step_input_ids_flat[bid * max_step_tokens :] + target_tokens_now = target_tokens[start_token_id:] if target_tokens is not None else None + candidate_ids_now = ( + candidate_ids_flat[start_token_id * max_candidate_len :] if candidate_ids_flat is not None else None + ) + candidate_lens_now = candidate_lens[start_token_id:] if candidate_lens is not None else None + candidate_scores_now = ( + 
candidate_scores_flat[start_token_id * max_candidate_len :] if candidate_scores_flat is not None else None + ) + + ctx = _VerifyContext( + bid, + max_step_tokens, + end_length, + end_tokens, + max_dec_len, + step_input_ids_now, + step_output_ids_flat, + int(step_idx[bid]), + ) + + # Phase 1: Verify + i = 0 + while i < seq_lens_this_time[bid] - 1: + if reject_all or seq_lens_encoder[bid] != 0 or reasoning_status[bid] == 1: + break + if accept_all: + if ctx.emit_token(i, step_input_ids_now[i + 1]): + break + i += 1 + continue + + accepted = False + if verify_strategy == 0: # TOPP + actual_cand_len = min(candidate_lens_now[i], max_candidate_len) + accepted = is_in( + candidate_ids_now[i * max_candidate_len : (i + 1) * max_candidate_len], + step_input_ids_now[i + 1], + actual_cand_len, + ) + if not accepted: + # verify_window fallback + ii = i + if ( + max_candidate_len >= 2 + and candidate_ids_now[ii * max_candidate_len + 1] == step_input_ids_now[ii + 1] + ): + j, ii = 0, ii + 1 + while j < verify_window and ii < seq_lens_this_time[bid] - 1: + if candidate_ids_now[ii * max_candidate_len] != step_input_ids_now[ii + 1]: + break + j += 1 + ii += 1 + if j >= verify_window: + for k in range(i, ii): + if ctx.emit_token(k, step_input_ids_now[k + 1]): + i = k + break + if ctx.stopped: + break + i = ii + continue + break + elif verify_strategy in (1, 2): # GREEDY / TARGET_MATCH + accepted = target_tokens_now[i] == step_input_ids_now[i + 1] + + if accepted: + if ctx.emit_token(i, step_input_ids_now[i + 1]): + break + else: + break + i += 1 + + # Phase 2: Sample for rejected/last position + if not ctx.stopped: + if verify_strategy == 0: + if candidate_lens_now is not None and len(candidate_lens_now) > i: + actual_cand_len = min(candidate_lens_now[i], max_candidate_len) + accept_token = topp_sampling_kernel( + candidate_ids_now[i * max_candidate_len : (i + 1) * max_candidate_len], + candidate_scores_now[i * max_candidate_len : (i + 1) * max_candidate_len], + dev_curand_states[i], 
+ actual_cand_len, + topp[bid], + ) + else: + accept_token = int(step_input_ids_now[0]) + elif verify_strategy in (1, 2): + accept_token = ( + int(target_tokens_now[i]) + if target_tokens_now is not None and len(target_tokens_now) > i + else int(step_input_ids_now[0]) + ) + else: + accept_token = ( + int(candidate_ids_now[i * max_candidate_len]) + if candidate_ids_now is not None + else int(step_input_ids_now[0]) + ) + ctx.emit_final_token(i, accept_token) + + step_output_len[bid] = ctx.output_len_now + + return step_output_ids, step_output_len + + +# ============================================================ +# Input generation +# ============================================================ + + +def gen_verify_draft_tokens_inputs( + real_bsz: int = 32, + max_draft_tokens: int = 16, + max_seq_len: int = 256, + max_candidate_len: int = 8, + verify_window: int = 2, + end_length: int = 4, + verify_strategy: int = 1, + reject_all: bool = False, + accept_all: bool = False, + match_ratio: float = 0.0, + seed: int = 2025, +) -> Dict[str, Any]: + """Generate test inputs for verify_draft_tokens kernel. + + Args: + match_ratio: Fraction of draft token positions where target/candidates + are forced to match step_input_ids, so the acceptance path is exercised. + 0.0 = fully random (mostly rejects), 1.0 = all positions match. 
+ """ + rng = np.random.default_rng(seed) + + seq_lens_encoder = np.zeros(real_bsz, dtype=np.int32) + seq_lens_this_time = rng.integers(1, max_draft_tokens + 1, size=real_bsz, dtype=np.int32) + step_input_ids = rng.integers(0, 1000, size=(real_bsz, max_draft_tokens), dtype=np.int64) + + sum_seq = int(np.sum(seq_lens_this_time)) + + if verify_strategy in (1, 2): # GREEDY / TARGET_MATCH + target_tokens = rng.integers(0, 1000, size=(sum_seq,), dtype=np.int64) + candidate_ids = None + candidate_scores = None + candidate_lens = None + else: # TOPP + target_tokens = None + candidate_ids = rng.integers(0, 1000, size=(sum_seq, max_candidate_len), dtype=np.int64) + candidate_scores = rng.random(size=(sum_seq, max_candidate_len)).astype(np.float32) + candidate_scores = candidate_scores / candidate_scores.sum(axis=1, keepdims=True) + candidate_lens = rng.integers(1, max_candidate_len + 1, size=sum_seq, dtype=np.int32) + + end_tokens = rng.integers(1, 1000, size=end_length, dtype=np.int64) + is_block_step = rng.integers(0, 2, size=real_bsz, dtype=bool) + + cu_seqlens_q_output = np.zeros(real_bsz + 1, dtype=np.int32) + for i in range(real_bsz): + cu_seqlens_q_output[i + 1] = cu_seqlens_q_output[i] + seq_lens_this_time[i] + cu_seqlens_q_output = cu_seqlens_q_output[:real_bsz].astype(np.int32) + + topp = rng.uniform(0.8, 1.0, size=real_bsz).astype(np.float32) + reasoning_status = np.zeros(real_bsz, dtype=np.int32) + step_output_ids = np.zeros((real_bsz, max_draft_tokens), dtype=np.int64) + step_output_len = np.zeros(real_bsz, dtype=np.int32) + stop_flags = np.zeros(real_bsz, dtype=bool) + + # Force match_ratio fraction of positions so acceptance path is tested + if match_ratio > 0.0: + offset = 0 + for bid in range(real_bsz): + slt = int(seq_lens_this_time[bid]) + n_match = max(1, int((slt - 1) * match_ratio)) # slt-1 verify positions + for pos in range(min(n_match, slt - 1)): + draft_token = int(step_input_ids[bid, pos + 1]) + # Ensure draft_token is not an end_token (would 
cause early stop) + while draft_token in end_tokens[:end_length]: + draft_token = (draft_token + 1) % 1000 + step_input_ids[bid, pos + 1] = draft_token + if verify_strategy in (1, 2) and target_tokens is not None: + target_tokens[offset + pos] = draft_token + elif verify_strategy == 0 and candidate_ids is not None: + candidate_ids[offset + pos, 0] = draft_token + candidate_lens[offset + pos] = max(candidate_lens[offset + pos], 1) + offset += slt + + return { + "step_output_ids": step_output_ids, + "step_output_len": step_output_len, + "step_input_ids": step_input_ids, + "target_tokens": target_tokens, + "candidate_ids": candidate_ids, + "candidate_scores": candidate_scores, + "candidate_lens": candidate_lens, + "topp": topp, + "stop_flags": stop_flags, + "seq_lens_encoder": seq_lens_encoder, + "seq_lens_this_time": seq_lens_this_time, + "end_tokens": end_tokens, + "is_block_step": is_block_step, + "cu_seqlens_q_output": cu_seqlens_q_output, + "reasoning_status": reasoning_status, + "max_dec_len": rng.integers(50, 200, size=real_bsz, dtype=np.int64), + "step_idx": rng.integers(0, 30, size=real_bsz, dtype=np.int64), + "max_seq_len": max_seq_len, + "verify_window": verify_window, + "verify_strategy": verify_strategy, + "reject_all": reject_all, + "accept_all": accept_all, + } + + +# ============================================================ +# Test configs +# ============================================================ + +TEST_CONFIGS = [ + # --- strategy coverage (random, mostly rejects) --- + { + "name": "greedy_small_batch", + "real_bsz": 1, + "max_draft_tokens": 9, + "max_seq_len": 11, + "max_candidate_len": 4, + "verify_window": 2, + "end_length": 5, + "verify_strategy": VerifyStrategy.GREEDY.value, + "seed": 42, + }, + { + "name": "greedy_medium_batch", + "real_bsz": 33, + "max_draft_tokens": 5, + "max_seq_len": 10111, + "max_candidate_len": 5, + "verify_window": 2, + "end_length": 6, + "verify_strategy": VerifyStrategy.GREEDY.value, + "seed": 42, + }, + { + 
"name": "topp_small_batch", + "real_bsz": 6, + "max_draft_tokens": 4, + "max_seq_len": 10001, + "max_candidate_len": 6, + "verify_window": 2, + "end_length": 7, + "verify_strategy": VerifyStrategy.TOPP.value, + "seed": 42, + }, + { + "name": "target_match_medium", + "real_bsz": 7, + "max_draft_tokens": 3, + "max_seq_len": 777, + "max_candidate_len": 7, + "verify_window": 2, + "end_length": 5, + "verify_strategy": VerifyStrategy.TARGET_MATCH.value, + "seed": 42, + }, + { + "name": "greedy_large_batch", + "real_bsz": 55, + "max_draft_tokens": 5, + "max_seq_len": 31, + "max_candidate_len": 9, + "verify_window": 2, + "end_length": 3, + "verify_strategy": VerifyStrategy.GREEDY.value, + "seed": 42, + }, + # --- partial acceptance (match_ratio forces draft tokens to match target/candidates) --- + { + "name": "greedy_half_accept", + "real_bsz": 8, + "max_draft_tokens": 8, + "max_seq_len": 256, + "max_candidate_len": 4, + "verify_window": 2, + "end_length": 3, + "verify_strategy": VerifyStrategy.GREEDY.value, + "seed": 42, + "match_ratio": 0.5, + }, + { + "name": "greedy_full_accept", + "real_bsz": 8, + "max_draft_tokens": 8, + "max_seq_len": 256, + "max_candidate_len": 4, + "verify_window": 2, + "end_length": 3, + "verify_strategy": VerifyStrategy.GREEDY.value, + "seed": 42, + "match_ratio": 1.0, + }, + { + "name": "topp_half_accept", + "real_bsz": 8, + "max_draft_tokens": 8, + "max_seq_len": 256, + "max_candidate_len": 6, + "verify_window": 2, + "end_length": 3, + "verify_strategy": VerifyStrategy.TOPP.value, + "seed": 42, + "match_ratio": 0.5, + }, + { + "name": "topp_full_accept", + "real_bsz": 8, + "max_draft_tokens": 8, + "max_seq_len": 256, + "max_candidate_len": 6, + "verify_window": 2, + "end_length": 3, + "verify_strategy": VerifyStrategy.TOPP.value, + "seed": 42, + "match_ratio": 1.0, + }, + { + "name": "target_match_accept", + "real_bsz": 8, + "max_draft_tokens": 6, + "max_seq_len": 256, + "max_candidate_len": 4, + "verify_window": 2, + "end_length": 3, + 
"verify_strategy": VerifyStrategy.TARGET_MATCH.value, + "seed": 42, + "match_ratio": 0.7, + }, + # --- reject_all / accept_all (kernel-level flags) --- + { + "name": "reject_all", + "real_bsz": 8, + "max_draft_tokens": 5, + "max_seq_len": 100, + "max_candidate_len": 5, + "verify_window": 2, + "end_length": 3, + "verify_strategy": VerifyStrategy.GREEDY.value, + "seed": 42, + "reject_all": True, + }, + { + "name": "accept_all", + "real_bsz": 8, + "max_draft_tokens": 5, + "max_seq_len": 100, + "max_candidate_len": 5, + "verify_window": 2, + "end_length": 3, + "verify_strategy": VerifyStrategy.TOPP.value, + "seed": 42, + "accept_all": True, + }, + # --- edge cases --- + { + "name": "empty_batch", + "real_bsz": 1, + "max_draft_tokens": 1, + "max_seq_len": 10, + "max_candidate_len": 2, + "verify_window": 1, + "end_length": 4, + "verify_strategy": VerifyStrategy.GREEDY.value, + "seed": 42, + }, +] + + +# ============================================================ +# Test suite +# ============================================================ + + +class TestVerifyDraftTokens(unittest.TestCase): + + def setUp(self): + pass + # if not paddle.is_compiled_with_cuda(): + # self.skipTest("Requires CUDA") + + # ------ shared run + check helper ------ + + def _run_and_compare(self, inputs: Dict[str, Any], label: str = ""): + """Convert→run kernel→run ref→compare.""" + paddle_inputs = to_paddle_inputs(inputs) + # print("paddle_inputs: ", paddle_inputs) + run_kernel(paddle_inputs, inputs) + ids_ref, len_ref = run_ref(inputs) + compare_results(paddle_inputs, ids_ref, len_ref, inputs, label) + return paddle_inputs + + # ------ test cases ------ + + def test_verify_configs(self): + """Test all configs in TEST_CONFIGS (strategies, reject/accept, edge cases).""" + for cfg in TEST_CONFIGS: + with self.subTest(name=cfg["name"]): + test_cfg = {k: v for k, v in cfg.items() if k != "name"} + inputs = gen_verify_draft_tokens_inputs(**test_cfg) + self._run_and_compare(inputs, label=cfg["name"]) 
+ + def test_eos_handling(self): + """Test EOS token in draft triggers early stop.""" + inputs = gen_verify_draft_tokens_inputs( + real_bsz=4, max_draft_tokens=5, verify_strategy=VerifyStrategy.GREEDY.value, seed=42 + ) + inputs["step_input_ids"][0, 2] = inputs["end_tokens"][0] + self._run_and_compare(inputs, label="eos_handling") + + def test_max_dec_len_truncation(self): + """Test max_dec_len causes token replacement with end_tokens[0].""" + inputs = gen_verify_draft_tokens_inputs( + real_bsz=4, max_draft_tokens=5, verify_strategy=VerifyStrategy.GREEDY.value, seed=42, match_ratio=1.0 + ) + # Set step_idx close to max_dec_len so it triggers during verification + inputs["step_idx"][:] = [48, 10, 10, 10] + inputs["max_dec_len"][:] = [50, 200, 200, 200] + inputs["is_block_step"][:] = False + inputs["stop_flags"][:] = False + # Ensure no accidental EOS in draft tokens + for bid in range(4): + for j in range(5): + while inputs["step_input_ids"][bid, j] in inputs["end_tokens"]: + inputs["step_input_ids"][bid, j] = (inputs["step_input_ids"][bid, j] + 1) % 1000 + self._run_and_compare(inputs, label="max_dec_len_truncation") + + def test_verify_strategy_enum(self): + self.assertEqual(VerifyStrategy.TOPP.value, 0) + self.assertEqual(VerifyStrategy.GREEDY.value, 1) + self.assertEqual(VerifyStrategy.TARGET_MATCH.value, 2) + + def test_verify_strategy_from_string(self): + self.assertEqual(VerifyStrategy.from_string("topp"), VerifyStrategy.TOPP) + self.assertEqual(VerifyStrategy.from_string("TOPP"), VerifyStrategy.TOPP) + self.assertEqual(VerifyStrategy.from_string("greedy"), VerifyStrategy.GREEDY) + self.assertEqual(VerifyStrategy.from_string("target_match"), VerifyStrategy.TARGET_MATCH) + with self.assertRaises(ValueError): + VerifyStrategy.from_string("invalid") + + def test_topp_verify_window_fallback(self): + """Test TOPP verify_window fallback: top-2 match + consecutive top-1 matches.""" + real_bsz, max_draft_tokens, max_candidate_len, verify_window = 1, 8, 4, 2 + + 
inputs = gen_verify_draft_tokens_inputs( + real_bsz=real_bsz, + max_draft_tokens=max_draft_tokens, + verify_strategy=VerifyStrategy.TOPP.value, + max_candidate_len=max_candidate_len, + verify_window=verify_window, + seed=42, + ) + + # Rebuild arrays for full seq_lens_this_time + new_slt = max_draft_tokens + 1 + inputs["seq_lens_this_time"] = np.array([new_slt], dtype=np.int32) + inputs["cu_seqlens_q_output"] = np.array([0], dtype=np.int32) + + rng = np.random.default_rng(42) + sum_seq = new_slt + inputs["candidate_ids"] = rng.integers(0, 1000, size=(sum_seq, max_candidate_len), dtype=np.int64) + inputs["candidate_scores"] = rng.random(size=(sum_seq, max_candidate_len)).astype(np.float32) + inputs["candidate_scores"] /= inputs["candidate_scores"].sum(axis=1, keepdims=True) + inputs["candidate_lens"] = rng.integers(1, max_candidate_len + 1, size=sum_seq, dtype=np.int32) + + # Draft tokens + draft_tokens = [100, 200, 300, 400, 500, 600, 700] + for i, token in enumerate(draft_tokens): + inputs["step_input_ids"][0, i + 1] = token + + # Position 0: draft NOT in candidates, but top-2 matches draft + inputs["candidate_ids"][0] = [999, 100, 998, 997] + # Positions 1,2: top-1 matches next draft tokens + inputs["candidate_ids"][1] = [200, 888, 777, 666] + inputs["candidate_ids"][2] = [300, 555, 444, 333] + inputs["candidate_lens"][:3] = 4 + inputs["is_block_step"] = np.zeros(real_bsz, dtype=bool) + + self._run_and_compare(inputs, label="verify_window_fallback") + + def test_topp_verify_window_no_fallback(self): + """Test TOPP when verify_window fallback does NOT trigger.""" + inputs = gen_verify_draft_tokens_inputs( + real_bsz=1, + max_draft_tokens=5, + verify_strategy=VerifyStrategy.TOPP.value, + max_candidate_len=4, + verify_window=2, + seed=42, + ) + + inputs["step_input_ids"][0, 1:] = [999, 998, 997, 996] + inputs["candidate_ids"][:] = 0 + inputs["candidate_ids"][0] = [1, 2, 3, 4] + inputs["candidate_lens"][0] = 4 + inputs["seq_lens_this_time"][0] = 5 + + 
self._run_and_compare(inputs, label="verify_window_no_fallback") + + +if __name__ == "__main__": + unittest.main() From ffcec19b1cfc0d925833a40290661122db95a42a Mon Sep 17 00:00:00 2001 From: cmcamdy Date: Fri, 20 Mar 2026 06:10:29 +0000 Subject: [PATCH 02/12] fix test --- .../xpu_ops/test/test_verify_draft_tokens.py | 274 +++++++++++++++++- 1 file changed, 271 insertions(+), 3 deletions(-) diff --git a/custom_ops/xpu_ops/test/test_verify_draft_tokens.py b/custom_ops/xpu_ops/test/test_verify_draft_tokens.py index ce3d4f71cb6..cfe6a6214f6 100644 --- a/custom_ops/xpu_ops/test/test_verify_draft_tokens.py +++ b/custom_ops/xpu_ops/test/test_verify_draft_tokens.py @@ -28,9 +28,7 @@ import numpy as np import paddle -from fastdeploy.model_executor.ops.xpu import static_op_verify_draft_tokens - -verify_draft_tokens = static_op_verify_draft_tokens +from fastdeploy.model_executor.ops.xpu import verify_draft_tokens from fastdeploy.spec_decode import VerifyStrategy CPU_PLACE = paddle.CPUPlace() @@ -621,6 +619,54 @@ def gen_verify_draft_tokens_inputs( "seed": 42, "accept_all": True, }, + { + "name": "reject_all_topp", + "real_bsz": 8, + "max_draft_tokens": 5, + "max_seq_len": 100, + "max_candidate_len": 5, + "verify_window": 2, + "end_length": 3, + "verify_strategy": VerifyStrategy.TOPP.value, + "seed": 42, + "reject_all": True, + }, + { + "name": "reject_all_target_match", + "real_bsz": 8, + "max_draft_tokens": 5, + "max_seq_len": 100, + "max_candidate_len": 5, + "verify_window": 2, + "end_length": 3, + "verify_strategy": VerifyStrategy.TARGET_MATCH.value, + "seed": 42, + "reject_all": True, + }, + { + "name": "accept_all_greedy", + "real_bsz": 8, + "max_draft_tokens": 5, + "max_seq_len": 100, + "max_candidate_len": 5, + "verify_window": 2, + "end_length": 3, + "verify_strategy": VerifyStrategy.GREEDY.value, + "seed": 42, + "accept_all": True, + }, + { + "name": "accept_all_target_match", + "real_bsz": 8, + "max_draft_tokens": 5, + "max_seq_len": 100, + "max_candidate_len": 5, 
+ "verify_window": 2, + "end_length": 3, + "verify_strategy": VerifyStrategy.TARGET_MATCH.value, + "seed": 42, + "accept_all": True, + }, # --- edge cases --- { "name": "empty_batch", @@ -766,6 +812,228 @@ def test_topp_verify_window_no_fallback(self): self._run_and_compare(inputs, label="verify_window_no_fallback") + def test_stop_flags_skip(self): + """Test that sequences with stop_flags=True are skipped (output_len=0).""" + inputs = gen_verify_draft_tokens_inputs( + real_bsz=4, max_draft_tokens=5, verify_strategy=VerifyStrategy.GREEDY.value, seed=42, match_ratio=1.0 + ) + inputs["is_block_step"][:] = False + inputs["stop_flags"][:] = [True, False, True, False] + self._run_and_compare(inputs, label="stop_flags_skip") + # Double-check stopped sequences produce output_len=0 + paddle_inputs = to_paddle_inputs(inputs) + run_kernel(paddle_inputs, inputs) + gpu_len = paddle_inputs["step_output_len"].numpy() + self.assertEqual(gpu_len[0], 0, "stopped seq bid=0 should have output_len=0") + self.assertEqual(gpu_len[2], 0, "stopped seq bid=2 should have output_len=0") + + def test_prefill_skip(self): + """Test that prefill requests (seq_lens_encoder != 0) skip Phase 1, only output 1 token.""" + inputs = gen_verify_draft_tokens_inputs( + real_bsz=4, max_draft_tokens=6, verify_strategy=VerifyStrategy.GREEDY.value, seed=42, match_ratio=1.0 + ) + inputs["is_block_step"][:] = False + inputs["stop_flags"][:] = False + # Set bid 0 and 2 as prefill requests + inputs["seq_lens_encoder"][0] = 10 + inputs["seq_lens_encoder"][2] = 5 + self._run_and_compare(inputs, label="prefill_skip") + + def test_reasoning_status_skip(self): + """Test that reasoning_status=1 skips Phase 1, only outputs 1 token.""" + inputs = gen_verify_draft_tokens_inputs( + real_bsz=4, max_draft_tokens=6, verify_strategy=VerifyStrategy.GREEDY.value, seed=42, match_ratio=1.0 + ) + inputs["is_block_step"][:] = False + inputs["stop_flags"][:] = False + # Set bid 1 and 3 as reasoning mode + 
inputs["reasoning_status"][1] = 1 + inputs["reasoning_status"][3] = 1 + self._run_and_compare(inputs, label="reasoning_status_skip") + + def test_reject_all_and_accept_all_priority(self): + """Test that reject_all takes priority over accept_all when both are True.""" + inputs = gen_verify_draft_tokens_inputs( + real_bsz=4, + max_draft_tokens=5, + verify_strategy=VerifyStrategy.GREEDY.value, + seed=42, + match_ratio=1.0, + reject_all=True, + accept_all=True, + ) + inputs["is_block_step"][:] = False + inputs["stop_flags"][:] = False + self._run_and_compare(inputs, label="reject_all_and_accept_all") + # All sequences should produce exactly 1 token (Phase 2 only) + paddle_inputs = to_paddle_inputs(inputs) + run_kernel(paddle_inputs, inputs) + gpu_len = paddle_inputs["step_output_len"].numpy() + for bid in range(4): + self.assertEqual(gpu_len[bid], 1, f"reject_all should produce exactly 1 token at bid={bid}") + + def test_mixed_batch_heterogeneous(self): + """Test a batch with mixed states: normal, stopped, prefill, reasoning, block_step.""" + inputs = gen_verify_draft_tokens_inputs( + real_bsz=6, max_draft_tokens=6, verify_strategy=VerifyStrategy.GREEDY.value, seed=42, match_ratio=0.8 + ) + # bid 0: normal decode + inputs["is_block_step"][0] = False + inputs["stop_flags"][0] = False + inputs["seq_lens_encoder"][0] = 0 + inputs["reasoning_status"][0] = 0 + # bid 1: stopped + inputs["is_block_step"][1] = False + inputs["stop_flags"][1] = True + inputs["seq_lens_encoder"][1] = 0 + inputs["reasoning_status"][1] = 0 + # bid 2: prefill + inputs["is_block_step"][2] = False + inputs["stop_flags"][2] = False + inputs["seq_lens_encoder"][2] = 8 + inputs["reasoning_status"][2] = 0 + # bid 3: reasoning mode + inputs["is_block_step"][3] = False + inputs["stop_flags"][3] = False + inputs["seq_lens_encoder"][3] = 0 + inputs["reasoning_status"][3] = 1 + # bid 4: block step + inputs["is_block_step"][4] = True + inputs["stop_flags"][4] = False + inputs["seq_lens_encoder"][4] = 0 + 
inputs["reasoning_status"][4] = 0 + # bid 5: normal decode + inputs["is_block_step"][5] = False + inputs["stop_flags"][5] = False + inputs["seq_lens_encoder"][5] = 0 + inputs["reasoning_status"][5] = 0 + self._run_and_compare(inputs, label="mixed_batch_heterogeneous") + + def test_single_token_sequence(self): + """Test seq_lens_this_time=1: Phase 1 is skipped entirely, only Phase 2 outputs 1 token.""" + for strategy in [VerifyStrategy.GREEDY.value, VerifyStrategy.TOPP.value, VerifyStrategy.TARGET_MATCH.value]: + with self.subTest(strategy=strategy): + inputs = gen_verify_draft_tokens_inputs( + real_bsz=4, max_draft_tokens=8, verify_strategy=strategy, seed=42 + ) + inputs["seq_lens_this_time"][:] = 1 + # Recompute cu_seqlens_q_output for all-1 seq_lens + inputs["cu_seqlens_q_output"] = np.array([0, 1, 2, 3], dtype=np.int32) + # Regenerate target/candidate arrays for new sum_seq=4 + sum_seq = 4 + rng = np.random.default_rng(42) + if strategy in (1, 2): + inputs["target_tokens"] = rng.integers(0, 1000, size=(sum_seq,), dtype=np.int64) + else: + max_candidate_len = 8 + inputs["candidate_ids"] = rng.integers(0, 1000, size=(sum_seq, max_candidate_len), dtype=np.int64) + inputs["candidate_scores"] = rng.random(size=(sum_seq, max_candidate_len)).astype(np.float32) + inputs["candidate_scores"] /= inputs["candidate_scores"].sum(axis=1, keepdims=True) + inputs["candidate_lens"] = rng.integers(1, max_candidate_len + 1, size=sum_seq, dtype=np.int32) + inputs["is_block_step"][:] = False + inputs["stop_flags"][:] = False + self._run_and_compare(inputs, label=f"single_token_strategy_{strategy}") + + def test_max_dec_len_exact_boundary(self): + """Test step_idx == max_dec_len - 1: first emit triggers max_len_hit immediately.""" + inputs = gen_verify_draft_tokens_inputs( + real_bsz=4, max_draft_tokens=6, verify_strategy=VerifyStrategy.GREEDY.value, seed=42, match_ratio=1.0 + ) + inputs["is_block_step"][:] = False + inputs["stop_flags"][:] = False + # Set step_idx = max_dec_len - 1, 
so first emit_token increments past max_dec_len + inputs["max_dec_len"][:] = 50 + inputs["step_idx"][:] = 49 + # Ensure no accidental EOS in draft tokens + for bid in range(4): + for j in range(6): + while inputs["step_input_ids"][bid, j] in inputs["end_tokens"]: + inputs["step_input_ids"][bid, j] = (inputs["step_input_ids"][bid, j] + 1) % 1000 + self._run_and_compare(inputs, label="max_dec_len_exact_boundary") + # All sequences should produce exactly 1 token (first emit triggers stop) + paddle_inputs = to_paddle_inputs(inputs) + run_kernel(paddle_inputs, inputs) + gpu_len = paddle_inputs["step_output_len"].numpy() + for bid in range(4): + self.assertEqual(gpu_len[bid], 1, f"max_dec_len boundary should produce 1 token at bid={bid}") + + def test_eos_during_verify_window_bulk_accept(self): + """Test EOS token in the middle of verify_window bulk-accept range stops correctly.""" + real_bsz, max_draft_tokens, max_candidate_len, verify_window = 1, 10, 4, 2 + inputs = gen_verify_draft_tokens_inputs( + real_bsz=real_bsz, + max_draft_tokens=max_draft_tokens, + verify_strategy=VerifyStrategy.TOPP.value, + max_candidate_len=max_candidate_len, + verify_window=verify_window, + seed=42, + ) + + new_slt = max_draft_tokens + inputs["seq_lens_this_time"] = np.array([new_slt], dtype=np.int32) + inputs["cu_seqlens_q_output"] = np.array([0], dtype=np.int32) + + rng = np.random.default_rng(42) + sum_seq = new_slt + inputs["candidate_ids"] = rng.integers(0, 1000, size=(sum_seq, max_candidate_len), dtype=np.int64) + inputs["candidate_scores"] = rng.random(size=(sum_seq, max_candidate_len)).astype(np.float32) + inputs["candidate_scores"] /= inputs["candidate_scores"].sum(axis=1, keepdims=True) + inputs["candidate_lens"] = np.full(sum_seq, max_candidate_len, dtype=np.int32) + inputs["is_block_step"] = np.zeros(real_bsz, dtype=bool) + inputs["stop_flags"] = np.zeros(real_bsz, dtype=bool) + inputs["max_dec_len"][:] = 200 + + eos_token = int(inputs["end_tokens"][0]) + # Draft tokens: 100, 
200, EOS, 400, 500, ... + draft_tokens = [100, 200, eos_token, 400, 500, 600, 700, 800, 900] + for i, token in enumerate(draft_tokens): + inputs["step_input_ids"][0, i + 1] = token + + # Position 0: draft NOT in top-1, but top-2 matches draft -> verify_window triggers + inputs["candidate_ids"][0] = [999, 100, 998, 997] + # Position 1: top-1 matches next draft + inputs["candidate_ids"][1] = [200, 888, 777, 666] + # Position 2: top-1 matches next draft (which is EOS) + inputs["candidate_ids"][2] = [eos_token, 555, 444, 333] + # Position 3 onwards: top-1 matches (shouldn't be reached due to EOS) + inputs["candidate_ids"][3] = [400, 222, 111, 100] + + self._run_and_compare(inputs, label="eos_during_verify_window") + + def test_topp_max_candidate_len_1(self): + """Test TOPP with max_candidate_len=1: verify_window fallback cannot trigger.""" + inputs = gen_verify_draft_tokens_inputs( + real_bsz=4, + max_draft_tokens=6, + verify_strategy=VerifyStrategy.TOPP.value, + max_candidate_len=1, + verify_window=2, + seed=42, + match_ratio=0.5, + ) + inputs["is_block_step"][:] = False + inputs["stop_flags"][:] = False + self._run_and_compare(inputs, label="topp_max_candidate_len_1") + + def test_phase2_eos_token(self): + """Test Phase 2 target token is an EOS token.""" + inputs = gen_verify_draft_tokens_inputs( + real_bsz=4, max_draft_tokens=5, verify_strategy=VerifyStrategy.GREEDY.value, seed=42 + ) + inputs["is_block_step"][:] = False + inputs["stop_flags"][:] = False + # Make all draft tokens NOT match target (all reject at position 0) + inputs["step_input_ids"][:, 1:] = 999 + if inputs["target_tokens"] is not None: + inputs["target_tokens"][:] = 888 + # Now set the Phase 2 token (target_tokens at position 0 for each bid) to EOS + eos_token = int(inputs["end_tokens"][0]) + offset = 0 + for bid in range(4): + inputs["target_tokens"][offset] = eos_token + offset += int(inputs["seq_lens_this_time"][bid]) + self._run_and_compare(inputs, label="phase2_eos_token") + if __name__ == 
"__main__": unittest.main() From e94441a0fd5069c6fe5dc972fab4d1d6cd8cf77a Mon Sep 17 00:00:00 2001 From: cmcamdy Date: Fri, 20 Mar 2026 06:11:59 +0000 Subject: [PATCH 03/12] fix code style --- .../kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu index 3519a963442..066298e7028 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu @@ -58,12 +58,7 @@ __device__ inline bool is_in(__global_ptr__ const int64_t *candidates, } return false; } -// static __device__ inline unsigned int xorwow(unsigned int& state) { -// state ^= state >> 7; -// state ^= state << 9; -// state ^= state >> 13; -// return state; -// } + static __device__ inline unsigned int xorwow(unsigned int &state) { state ^= state >> 7; state ^= state << 9; From 025fa3fe85fab2ec63d5e1359fd2f8d294d8c7ec Mon Sep 17 00:00:00 2001 From: cmcamdy Date: Wed, 8 Apr 2026 11:56:20 +0000 Subject: [PATCH 04/12] use sync cpy --- .../xpu_ops/src/ops/mtp/verify_draft_token.cc | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc b/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc index fa685e30a68..6c32fd6183b 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc @@ -90,15 +90,19 @@ void VerifyDraftTokens( std::mt19937_64 engine(infer_seed[i]); dev_curand_states_cpu.push_back(dist(engine)); } - float *dev_curand_states_xpu; + float *dev_curand_states = dev_curand_states_cpu.data(); + auto dev_curand_states_tensor = + paddle::empty({static_cast(dev_curand_states_cpu.size())}, + 
paddle::DataType::FLOAT32, + draft_tokens.place()); + int ret; if (xpu_ctx_flag) { - xpu::ctx_guard RAII_GUARD(ctx); - dev_curand_states_xpu = - RAII_GUARD.alloc(dev_curand_states_cpu.size()); - xpu_memcpy(dev_curand_states_xpu, - dev_curand_states_cpu.data(), - dev_curand_states_cpu.size() * sizeof(float), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); + ret = xpu::do_host2device(ctx, + dev_curand_states_cpu.data(), + dev_curand_states_tensor.data(), + dev_curand_states_cpu.size() * sizeof(float)); + PD_CHECK(ret == 0, "do_host2device failed."); + dev_curand_states = dev_curand_states_tensor.data(); } // Get data pointers (nullptr if optional not provided) @@ -146,7 +150,7 @@ void VerifyDraftTokens( candidate_scores_ptr, candidate_lens_ptr, // Sampling params - dev_curand_states_xpu, + dev_curand_states, topp.data(), // Metadata stop_flags.data(), From e929f07500788f9914fd0f46612f0b46f9e6af3d Mon Sep 17 00:00:00 2001 From: cmcamdy Date: Wed, 8 Apr 2026 12:38:51 +0000 Subject: [PATCH 05/12] fix code style --- custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc b/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc index 6c32fd6183b..7351ccb7f7d 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc @@ -94,7 +94,7 @@ void VerifyDraftTokens( auto dev_curand_states_tensor = paddle::empty({static_cast(dev_curand_states_cpu.size())}, paddle::DataType::FLOAT32, - draft_tokens.place()); + seq_lens_this_time.place()); int ret; if (xpu_ctx_flag) { ret = xpu::do_host2device(ctx, @@ -137,7 +137,7 @@ void VerifyDraftTokens( } } } - int ret = fastdeploy::plugin::verify_draft_tokens( + ret = fastdeploy::plugin::verify_draft_tokens( ctx, // Core I/O const_cast(step_output_ids.data()), From f6d5b5250985a831472f5d35df9bb96c96f0f62c Mon Sep 17 00:00:00 2001 From: cmcamdy Date: Mon, 13 Apr 2026 
03:17:41 +0000 Subject: [PATCH 06/12] fix kernel check --- .../src/wrapper/mtp_wrapper/verify_draft_tokens.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/verify_draft_tokens.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/verify_draft_tokens.cpp index 76376356948..e8c5c075459 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/verify_draft_tokens.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/verify_draft_tokens.cpp @@ -529,14 +529,14 @@ int verify_draft_tokens( WRAPPER_CHECK_PTR(ctx, float, real_bsz, topp); WRAPPER_CHECK_PTR(ctx, bool, real_bsz, stop_flags); WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_encoder); - WRAPPER_CHECK_PTR(ctx, float, real_bsz, seq_lens_this_time); + WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time); WRAPPER_CHECK_PTR(ctx, int64_t, end_length, end_tokens); WRAPPER_CHECK_PTR(ctx, bool, real_bsz, is_block_step); - WRAPPER_CHECK_PTR(ctx, bool, real_bsz, cu_seqlens_q_output); - WRAPPER_CHECK_PTR(ctx, bool, real_bsz, reasoning_status); - WRAPPER_CHECK_PTR(ctx, bool, real_bsz, max_dec_len); - WRAPPER_CHECK_PTR(ctx, bool, real_bsz, step_idx); + WRAPPER_CHECK_PTR(ctx, int, real_bsz, cu_seqlens_q_output); + WRAPPER_CHECK_PTR(ctx, int, real_bsz, reasoning_status); + WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, max_dec_len); + WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, step_idx); // param check sm size limit WRAPPER_ASSERT_GT(ctx, real_bsz, 0); WRAPPER_ASSERT_LE(ctx, real_bsz, 1024); From c595cf018c328e20fa3a93c460c55ebc96cea386 Mon Sep 17 00:00:00 2001 From: cmcamdy Date: Mon, 13 Apr 2026 05:10:04 +0000 Subject: [PATCH 07/12] fix random seed --- custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc | 14 ++++++++++---- .../xpu_ops/src/ops/mtp/verify_draft_token.cc | 13 ++++++++++--- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc 
b/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc index 6749518d89b..d504917fe44 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc @@ -26,6 +26,10 @@ namespace api = baidu::xpu::api; +// Persistent seed/offset — mirrors GPU curand state lifecycle. +static uint64_t g_seed = 0; +static uint64_t g_offset = 0; + void SpeculateVerify(const paddle::Tensor &sampled_token_ids, const paddle::Tensor &accept_tokens, const paddle::Tensor &accept_num, @@ -82,13 +86,15 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids, prefill_one_step_stop = true; } } - // random - int random_seed = 0; - std::vector infer_seed(bsz, random_seed); + // random — use persistent seed/offset so each call and batch element + // produce distinct random numbers (mirrors GPU curand lifecycle). + uint64_t cur_seed = g_seed++; + uint64_t cur_offset = g_offset++; std::uniform_real_distribution dist(0.0, 1.0); std::vector dev_curand_states_cpu; for (int i = 0; i < bsz; i++) { - std::mt19937_64 engine(infer_seed[i]); + std::mt19937_64 engine(cur_seed + i); + engine.discard(cur_offset); dev_curand_states_cpu.push_back(dist(engine)); } float *dev_curand_states_xpu; diff --git a/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc b/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc index 7351ccb7f7d..5a07be2db0b 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc @@ -32,6 +32,10 @@ namespace api = baidu::xpu::api; +// Persistent seed/offset — mirrors GPU curand state lifecycle. +static uint64_t g_seed = 0; +static uint64_t g_offset = 0; + // ============================================================ // Host function // ============================================================ @@ -82,12 +86,15 @@ void VerifyDraftTokens( int max_candidate_len = candidate_ids ? 
candidate_ids->shape()[1] : 1; // curand state: only needed for TOPP(0) strategy (stochastic sampling) - int random_seed = 0; - std::vector infer_seed(bsz, random_seed); + // Use persistent seed/offset (mirrors GPU curand lifecycle) so that + // each call and each batch element produce distinct random numbers. + uint64_t cur_seed = g_seed++; + uint64_t cur_offset = g_offset++; std::uniform_real_distribution dist(0.0, 1.0); std::vector dev_curand_states_cpu; for (int i = 0; i < bsz; i++) { - std::mt19937_64 engine(infer_seed[i]); + std::mt19937_64 engine(cur_seed + i); + engine.discard(cur_offset); dev_curand_states_cpu.push_back(dist(engine)); } float *dev_curand_states = dev_curand_states_cpu.data(); From f2b119260f45259c18ca88fd00d478d432707a2a Mon Sep 17 00:00:00 2001 From: cmcamdy Date: Mon, 13 Apr 2026 05:11:09 +0000 Subject: [PATCH 08/12] fix test --- custom_ops/xpu_ops/test/test_verify_draft_tokens.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/custom_ops/xpu_ops/test/test_verify_draft_tokens.py b/custom_ops/xpu_ops/test/test_verify_draft_tokens.py index cfe6a6214f6..384d6788ca7 100644 --- a/custom_ops/xpu_ops/test/test_verify_draft_tokens.py +++ b/custom_ops/xpu_ops/test/test_verify_draft_tokens.py @@ -32,7 +32,7 @@ from fastdeploy.spec_decode import VerifyStrategy CPU_PLACE = paddle.CPUPlace() -CUDA_PLACE = paddle.XPUPlace(0) if paddle.is_compiled_with_xpu() else paddle.CPUPlace() +DEVICE_PLACE = paddle.XPUPlace(0) if paddle.is_compiled_with_xpu() else paddle.CPUPlace() # ============================================================ @@ -47,7 +47,7 @@ def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]: if isinstance(v, (int, bool, float, str)): paddle_inputs[k] = v elif v is not None: - paddle_inputs[k] = paddle.to_tensor(v, place=CUDA_PLACE) + paddle_inputs[k] = paddle.to_tensor(v, place=DEVICE_PLACE) else: paddle_inputs[k] = None return paddle_inputs From 97e960bca3631db63bca4f7be8f97dbd872c2f15 Mon Sep 17 00:00:00 
2001 From: cmcamdy Date: Mon, 13 Apr 2026 05:17:47 +0000 Subject: [PATCH 09/12] fix check --- custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc b/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc index 5a07be2db0b..f0a3472e063 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc @@ -181,10 +181,10 @@ void VerifyDraftTokens( verify_strategy, reject_all, accept_all); - PD_CHECK(ret == 0, "verify_draft_tokens failed."); if (step_output_ids.is_cpu()) { delete ctx; } + PD_CHECK(ret == 0, "verify_draft_tokens failed."); } PD_BUILD_STATIC_OP(verify_draft_tokens) From 662101db3c7f55fa2167668e8f64e8a761f880aa Mon Sep 17 00:00:00 2001 From: cmcamdy Date: Mon, 13 Apr 2026 07:55:07 +0000 Subject: [PATCH 10/12] fix eos set --- .../src/kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu index 066298e7028..a0c4d736ad9 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu @@ -125,7 +125,7 @@ struct VerifyContext { cur_step_idx++; bool is_eos = is_in_end(token, end_tokens, end_length); bool max_len_hit = (cur_step_idx >= max_dec_len[bid]); - if ((is_eos || max_len_hit) && !is_eos) { + if (max_len_hit && !is_eos) { token = end_tokens[0]; } step_output_ids[bid * max_step_tokens + pos] = token; @@ -144,7 +144,7 @@ struct VerifyContext { cur_step_idx++; bool is_eos = is_in_end(token, end_tokens, end_length); bool max_len_hit = (cur_step_idx >= max_dec_len[bid]); - if ((is_eos || max_len_hit) && !is_eos) { + if 
(max_len_hit && !is_eos) { token = end_tokens[0]; } step_output_ids[bid * max_step_tokens + pos] = token; From 59d705974b0b299f5d44b7be930c2a66ed458743 Mon Sep 17 00:00:00 2001 From: cmcamdy Date: Tue, 14 Apr 2026 03:26:33 +0000 Subject: [PATCH 11/12] fix verify --- custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc | 5 +++-- custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc | 7 ++++--- .../kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu | 2 +- .../plugin/src/wrapper/mtp_wrapper/verify_draft_tokens.cpp | 3 ++- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc index d504917fe44..be5f78addd1 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include "paddle/common/flags.h" @@ -27,8 +28,8 @@ namespace api = baidu::xpu::api; // Persistent seed/offset — mirrors GPU curand state lifecycle. -static uint64_t g_seed = 0; -static uint64_t g_offset = 0; + static std::atomic g_seed{0}; + static std::atomic g_offset{0}; void SpeculateVerify(const paddle::Tensor &sampled_token_ids, const paddle::Tensor &accept_tokens, diff --git a/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc b/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc index f0a3472e063..ae0058bf8c6 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc @@ -22,6 +22,7 @@ // fallback) 1 = GREEDY : draft token == top-1 token (strict argmax // match) 2 = TARGET_MATCH : draft token == target model's sampled token +#include #include #include "paddle/extension.h" #include "xpu/plugin.h" @@ -33,8 +34,8 @@ namespace api = baidu::xpu::api; // Persistent seed/offset — mirrors GPU curand state lifecycle. 
-static uint64_t g_seed = 0; -static uint64_t g_offset = 0; +static std::atomic g_seed{0}; +static std::atomic g_offset{0}; // ============================================================ // Host function @@ -104,7 +105,7 @@ void VerifyDraftTokens( seq_lens_this_time.place()); int ret; if (xpu_ctx_flag) { - ret = xpu::do_host2device(ctx, + ret = api::do_host2device(ctx, dev_curand_states_cpu.data(), dev_curand_states_tensor.data(), dev_curand_states_cpu.size() * sizeof(float)); diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu index a0c4d736ad9..645c69469d9 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu @@ -325,7 +325,7 @@ __global__ void verify_draft_tokens( output_token = topp_sampling_kernel(candidate_ids_now + i * max_candidate_len, candidate_scores_now + i * max_candidate_len, - curand_states, + curand_states + bid, actual_cand_len, topp[bid]); break; diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/verify_draft_tokens.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/verify_draft_tokens.cpp index e8c5c075459..0233497ecd9 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/verify_draft_tokens.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/verify_draft_tokens.cpp @@ -348,7 +348,7 @@ static int cpu_wrapper( output_token = topp_sampling_kernel(candidate_ids_now + i * max_candidate_len, candidate_scores_now + i * max_candidate_len, - curand_states + i, + curand_states + bid, actual_cand_len, topp[bid], bid); @@ -542,6 +542,7 @@ int verify_draft_tokens( WRAPPER_ASSERT_LE(ctx, real_bsz, 1024); WRAPPER_ASSERT_LE(ctx, real_bsz * max_candidate_len, 2048); WRAPPER_ASSERT_LE(ctx, verify_window * max_candidate_len, 128); + 
WRAPPER_CHECK_PTR(ctx, int, real_bsz, step_output_len); if (ctx->dev().type() == api::kCPU) { return cpu_wrapper(ctx, From d7fccfe41a571ce6c9630eef9f26127e5a5770f6 Mon Sep 17 00:00:00 2001 From: cmcamdy Date: Tue, 14 Apr 2026 03:29:48 +0000 Subject: [PATCH 12/12] fix verify --- custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc index be5f78addd1..c2bb21c8929 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc @@ -28,8 +28,8 @@ namespace api = baidu::xpu::api; // Persistent seed/offset — mirrors GPU curand state lifecycle. - static std::atomic g_seed{0}; - static std::atomic g_offset{0}; +static std::atomic g_seed{0}; +static std::atomic g_offset{0}; void SpeculateVerify(const paddle::Tensor &sampled_token_ids, const paddle::Tensor &accept_tokens,