
[BugFix][Engine] Fix functional gaps in ZMQ token processing path vs legacy batch output path#6954

Draft
Copilot wants to merge 2 commits into develop from copilot/fix-token-processor-gaps

Conversation


Copilot AI commented Mar 20, 2026

The ZMQ-based token post-processing path (_process_batch_output_use_zmq + _process_per_token) had multiple functional gaps compared to the legacy _process_batch_output path, causing behavioral divergence in production scenarios.

Modifications

_process_per_token

  • FD_ENABLE_INTERNAL_ADAPTER eos filtering: eos tokens were unconditionally appended to result.outputs.token_ids; now filtered out when internal adapter is enabled (still appended to task.output_token_ids)
  • _compute_speculative_status missing arg: called without result, causing a runtime error in speculative decoding paths
  • cache_output_tokens missing: output token KV cache not persisted when enable_prefix_caching + enable_output_caching are both on
  • Completion log missing TTFT_S: added ttft_s = ttft + task.metrics.time_in_queue to match legacy log format
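The TTFT_S addition can be sketched roughly as follows. This is an illustrative sketch only: the field names on `RequestMetrics` are assumptions for the example, not the actual `task.metrics` schema in FastDeploy.

```python
from dataclasses import dataclass

@dataclass
class RequestMetrics:
    # Illustrative fields; the real task.metrics schema may differ.
    inference_start_time: float  # when the worker began running the request
    first_token_time: float      # when the first output token was produced
    time_in_queue: float         # seconds the request waited before inference

def completion_ttft_s(metrics: RequestMetrics) -> float:
    """Time-to-first-token as logged at completion: model-side TTFT plus
    queue wait, mirroring ttft_s = ttft + task.metrics.time_in_queue."""
    ttft = metrics.first_token_time - metrics.inference_start_time
    return ttft + metrics.time_in_queue
```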

_process_batch_output_use_zmq

  • RequestOutput missing fields: output_type=3 and prompt_token_ids_len were absent; downstream usage stats and serialization depended on these
  • num_cached_tokens set only on first token: moved outside the tokens_counter == 0 block so it reflects the current value on every step
  • prefill_chunk_info not handled: chunked prefill would produce premature intermediate results before all chunks completed
  • scheduler_metrics_logger not notified: decode token metrics were never reported via ZMQ path
  • draft_token_ids not populated for multi-token prefill: splitwise prefill scenarios missing draft token passthrough

Checklist

  • Add at least a tag in the PR title.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests. — No unit test infrastructure exists for this internal token post-processing path; the changes are logic-parity fixes validated by code review against the legacy path.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.
Original prompt

Background

In PR #6879 (feat/zmq_mtp_new branch from sunlei1024/FastDeploy), token_processor.py introduces a new ZMQ-based token post-processing path via process_model_runner_output() + _process_per_token(), intended to be functionally equivalent to the legacy _process_batch_output() path. However, a detailed comparison reveals multiple functional gaps and missing logic in the new path. These need to be fixed to ensure behavioral parity.

File to modify

fastdeploy/output/token_processor.py

The file can be viewed at: https://github.com/sunlei1024/FastDeploy/blob/c8eaaec4ff504248cf86c04fac422d44822c4f3c/fastdeploy/output/token_processor.py

Specific Issues to Fix

1. Missing prompt_token_ids_len and output_type in RequestOutput construction (process_model_runner_output, ~L1006-L1017)

The old path (_process_batch_output, L656-L669) constructs RequestOutput with output_type=mtype and prompt_token_ids_len=task.prompt_token_ids_len. The new path is missing both fields.

Fix: Add output_type=model_output.decode_mode and prompt_token_ids_len=task.prompt_token_ids_len to the RequestOutput constructor in process_model_runner_output.
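A minimal sketch of the intended constructor change. The class shape here is a simplified stand-in for FastDeploy's `RequestOutput` (the real class has many more fields); only the two added fields matter.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class RequestOutput:
    # Simplified stand-in; the real RequestOutput carries more state.
    request_id: str
    output_type: Optional[int] = None           # decode-mode / message-type tag
    prompt_token_ids_len: Optional[int] = None  # consumed by downstream usage stats

def build_result(task_id: str, prompt_len: int, decode_mode: int) -> RequestOutput:
    # The ZMQ path must carry both fields through, matching the legacy path.
    return RequestOutput(
        request_id=task_id,
        output_type=decode_mode,
        prompt_token_ids_len=prompt_len,
    )
```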

2. Missing FD_ENABLE_INTERNAL_ADAPTER filtering in _process_per_token (~L1070-L1072)

The old path (L685-L689) filters eos tokens from result.outputs.token_ids when FD_ENABLE_INTERNAL_ADAPTER is enabled:

if token_id != RECOVERY_STOP_SIGNAL:
    if not (envs.FD_ENABLE_INTERNAL_ADAPTER and token_id in task.eos_token_ids):
        result.outputs.token_ids.append(token_id)
    task.output_token_ids.append(token_id)

The new path unconditionally appends.

Fix: Add the same FD_ENABLE_INTERNAL_ADAPTER check in _process_per_token:

if token_id != RECOVERY_STOP_SIGNAL:
    if not (envs.FD_ENABLE_INTERNAL_ADAPTER and token_id in task.eos_token_ids):
        result.outputs.token_ids.append(token_id)
    task.output_token_ids.append(token_id)

3. Missing cache_output_tokens call in _process_per_token (~L1074-L1092)

The old path (L758-L765) caches output tokens when prefix caching + output caching is enabled. The new path has no such call before _recycle_resources.

Fix: Add the same output caching logic in _process_per_token before the _recycle_resources call:

if (
    envs.ENABLE_V1_KVCACHE_SCHEDULER
    and self.cfg.cache_config.enable_prefix_caching
    and self.cfg.cache_config.enable_output_caching
):
    self.resource_manager.cache_output_tokens(task)

4. Missing num_cached_tokens assignment outside first-token block (process_model_runner_output, ~L1037-L1044)

The old path sets num_cached_tokens unconditionally (outside the tokens_counter == 0 block). The new path only sets it inside the first-token block.

Fix: Move result.num_cached_tokens = task.num_cached_tokens outside the if self.tokens_counter[task_id] == 0 block. Keep multimodal_inputs handling inside the first-token block (matching old behavior).
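The corrected ordering can be sketched as below. Object shapes are simplified stand-ins; the attribute names come from the PR text, and `multimodal_inputs` handling is shown only to illustrate what stays inside the first-token guard.

```python
from types import SimpleNamespace

def apply_step_fields(result, task, tokens_counter, task_id):
    """Sketch of the fix: num_cached_tokens is refreshed on every step,
    while first-token-only work stays inside the counter guard."""
    if tokens_counter[task_id] == 0:
        # attached only once, with the first token (matching old behavior)
        result.multimodal_inputs = getattr(task, "multimodal_inputs", None)
    # set unconditionally so each step reflects the current value
    result.num_cached_tokens = task.num_cached_tokens
    return result
```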

5. Missing prefill_chunk_info handling (process_model_runner_output)

The old path (L623-L628) handles chunked prefill. The new path has no equivalent logic.

Fix: Add prefill_chunk_info handling in process_model_runner_output, after the abort handling and before the metrics section:

if task.get("prefill_chunk_info", None) is not None:
    prefill_chunk_num = task.get("prefill_chunk_num", 0)
    task.prefill_chunk_num = prefill_chunk_num + 1
    if task.prefill_chunk_num < len(task.prefill_chunk_info):
        continue

6. Missing scheduler_metrics_logger callback (process_model_runner_output)

The old path (L620-L621) notifies the scheduler metrics logger. The new path has no equivalent.

Fix: Add in process_model_runner_output, after abort handling and chunk handling, before metrics:

if self.scheduler_metrics_logger and self._is_decode_stage(task):
    self.scheduler_metrics_logger.on_decode_tokens(len(token_ids))

7. Missing draft_token_ids for prefill scenario

The old path (L679-L680) populates draft_token_ids during prefill. The new path has no equivalent.

Fix: Add in process_model_runner_output, after constructing the result and before calling _process_per_token:

if is_prefill and len(token_ids) > 1:
    result.outputs.draft_token_ids = copy.deepcopy(token_ids)

8. Missing _record_speculative_decoding_metrics in ZMQ loop

The old path calls global speculative decoding metrics recording after processing all batches. The new ZMQ path never calls it.

Fix: In process_sampling_results_use_zmq, after process_model_runner_output returns, compute accept_num from cu_num_generated_tokens and call _record_speculative_decoding_metrics:

if self.cfg.speculative_config.method and batch_result:
    accept_num = [
        model_runner_output.cu_num_generated_tokens[i + 1] - model_runner_output.cu_num_generated_tokens[i]
        for i in r...
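The snippet above is cut off in the captured prompt. The accept-count computation it gestures at can be sketched independently; this is a sketch only, assuming `cu_num_generated_tokens` is a cumulative (prefix-sum) token count per request, with the resulting list then passed to `_record_speculative_decoding_metrics` at the real call site.

```python
def accept_counts(cu_num_generated_tokens):
    """Recover per-request accepted-token counts from the cumulative
    (prefix-sum) layout: count[i] = cu[i + 1] - cu[i]."""
    return [
        cu_num_generated_tokens[i + 1] - cu_num_generated_tokens[i]
        for i in range(len(cu_num_generated_tokens) - 1)
    ]
```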




*This pull request was created from Copilot chat.*

@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.

@paddle-bot

paddle-bot bot commented Mar 20, 2026

Thanks for your contribution!

@paddle-bot paddle-bot bot added the contributor External developers label Mar 20, 2026
Copilot AI changed the title [WIP] Fix functional gaps in new ZMQ-based token processing [BugFix][Engine] Fix functional gaps in ZMQ token processing path vs legacy batch output path Mar 20, 2026
Copilot AI requested a review from Jiang-Jia-Jun March 20, 2026 09:00
