
[Feature] Support NVFP4 Flashinfer-cutedsl MoE on SM100#6963

Open
mpgemm wants to merge 7 commits into PaddlePaddle:develop from mpgemm:fp4-moe

Conversation


@mpgemm mpgemm commented Mar 21, 2026

Motivation

FastDeploy integrates the flashinfer-cutedsl NVFP4 grouped masked GEMM, supporting both apply_ep_prefill and apply_ep_decode.

Modifications

The changes fall into two areas: working around format incompatibilities between the imported flashinfer-cutedsl block-scaled GEMM and Paddle, and integrating it into the FastDeploy (FD) framework.

To resolve the three Paddle-incompatibility issues, edit nvidia-dsl and flashinfer in place under miniconda/envs/yourname/lib/python3.10/site-packages/:
1: nvidia_cutlass_dsl/python_packages/cutlass/torch.py: change torch.device to "torch.device" (Ctrl+F search and replace).
2: flashinfer/utils.py: change get_compute_capability to:
@functools.cache
def get_compute_capability(device: torch.device) -> Tuple[int, int]:
    return torch.cuda.get_device_capability(device)
    # The early return above bypasses the original body below:
    if device.type != "cuda":
        raise ValueError("device must be a cuda device")
    return torch.cuda.get_device_capability(device.index)
Note: if you hit other device-related errors, replacing A.place with A.device resolves most of them.
3: flashinfer/cute_dsl/blockscaled_gemm.py:
First add import cuda.bindings.driver as cuda, then replace cutlass_torch.current_stream() with cuda.CUstream(torch.cuda.current_stream().stream_base.raw_stream) (Ctrl+F search and replace).
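Instead of hand-editing files under site-packages, fix 2 can also be applied as a runtime monkeypatch. The sketch below is illustrative only: it uses a stub in place of flashinfer.utils and a lambda in place of torch.cuda.get_device_capability, so only the function name comes from the PR text.

```python
import functools
import types

def make_patched_get_compute_capability(get_device_capability):
    """Build the patched lookup: cache results and return immediately,
    skipping the original 'device must be a cuda device' check."""
    @functools.cache
    def get_compute_capability(device):
        # Early return reproduces the in-place edit described in step 2.
        return get_device_capability(device)
    return get_compute_capability

# Stub standing in for flashinfer.utils so the sketch is self-contained;
# real code would assign onto the imported flashinfer.utils module instead.
flashinfer_utils = types.SimpleNamespace(
    get_compute_capability=make_patched_get_compute_capability(
        # Pretend SM100; real code calls torch.cuda.get_device_capability.
        lambda device: (10, 0)
    )
)
```

Assigning the patched function onto the imported module at startup avoids re-editing site-packages after every reinstall.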

On the FD side, nvfp4.py was modified; apply_ep_prefill and apply_ep_decode are both supported. Known issue: the call_depermute_prefill_combine op only supports top-k = 4 or 8, but eb-45-fp4 uses top-k = 6, so when top-k is neither 4 nor 8, prefill falls back to a Python implementation with very low performance.
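The fallback described above can be sketched as a simple dispatch. The names `fused_combine` and `python_combine` are illustrative stand-ins, not FastDeploy APIs: the fused call_depermute_prefill_combine kernel only handles top_k in {4, 8}, so eb-45-fp4 (top_k = 6) takes the slow path.

```python
def combine_prefill(hidden_states, top_k, fused_combine, python_combine):
    # Fused kernel only supports these top-k values today.
    if top_k in (4, 8):
        return fused_combine(hidden_states, top_k)   # fast fused kernel
    return python_combine(hidden_states, top_k)      # slow Python fallback
```

Extending the fused kernel to top-k = 6 would remove the slow path for eb-45-fp4.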

Two utils.py files were also modified so that weights load correctly.

A unit test was added.

Usage or Command

End-to-end test script:
export PYTHONPATH="/root/paddlejob/workspace/output/dcc/FastDeploy":$PYTHONPATH
export MODEL_PATH="/raid0/ERNIE-4.5-21B-A3B-FP4"

export FD_USE_PFCC_DEEP_EP=1
export FD_MOE_BACKEND="flashinfer-cutedsl"
export CUDA_VISIBLE_DEVICES=4,5,6,7

python -m fastdeploy.entrypoints.openai.multi_api_server --ports "9811,9812,9813,9814" --num-servers 4 --args --model "$MODEL_PATH" --ep-prefill-use-worst-num-tokens --disable-custom-all-reduce --tensor-parallel-size 1 --data-parallel-size 4 --no-enable-prefix-caching --max-model-len 65536 --enable-expert-parallel --num-gpu-blocks-override 2048 --max-num-seqs 4 --gpu-memory-utilization 0.9 --max-num-batched-tokens 8192 --graph-optimization-config '{"use_cudagraph":false}'

Accuracy Tests

Unit test script:
export PYTHONPATH="/root/paddlejob/workspace/output/dcc/FastDeploy":$PYTHONPATH

export FD_MOE_BACKEND="flashinfer-cutedsl"
export FD_USE_PFCC_DEEP_EP=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

NVFP4_TEST_MODE=decode NVFP4_TEST_ITERS=2 python -m paddle.distributed.launch --gpus 0,1,2,3,4,5,6,7 FastDeploy/tests/layers/test_nvfp4_fusedmoe.py

NVFP4_TEST_MODE=prefill NVFP4_TEST_ITERS=2 python -m paddle.distributed.launch --gpus 0,1,2,3,4,5,6,7 FastDeploy/tests/layers/test_nvfp4_fusedmoe.py

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code; run pre-commit before committing.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
You have signed the CLA already but the status is still pending? Let us recheck it.

@paddle-bot

paddle-bot bot commented Mar 21, 2026

Thanks for your contribution!

@paddle-bot paddle-bot bot added the contributor External developers label Mar 21, 2026
zhoutianzi666
zhoutianzi666 previously approved these changes Mar 22, 2026
layer.up_gate_proj_weight.set_value(paddle.concat([b, a], axis=1))
[a, b] = layer.up_gate_proj_weight_scale.split(2, axis=1)
layer.up_gate_proj_weight_scale.set_value(paddle.concat([b, a], axis=1))
# [a, b] = layer.up_gate_proj_weight.split(2, axis=1)
Collaborator

Why was this block removed?

Author
@mpgemm mpgemm Mar 22, 2026

FlashInfer CUTLASS expects [up, gate] input, but the checkpoint stores [gate, up], so a swap is needed there. However, cutedsl's grouped_gemm_nt_masked reads the weights directly in [gate, up] (the checkpoint's original order), which matches the ordering of its internal silu_and_mul, so no swap is required.
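The ordering difference can be sketched with plain NumPy as a stand-in for the Paddle split/concat in the diff above: splitting the fused weight in two along axis 1 and concatenating the halves in reverse order turns [gate, up] into [up, gate].

```python
import numpy as np

# Checkpoint layout: [gate | up] concatenated along axis 1.
gate = np.full((2, 3), 1.0)
up = np.full((2, 3), 2.0)
w_ckpt = np.concatenate([gate, up], axis=1)   # what's stored on disk

# Swap for FlashInfer CUTLASS, which expects [up | gate]:
a, b = np.split(w_ckpt, 2, axis=1)
w_cutlass = np.concatenate([b, a], axis=1)    # now [up | gate]

# cutedsl's grouped_gemm_nt_masked consumes w_ckpt as-is (no swap).
```

The same split/concat pattern is what the removed Paddle code applied to up_gate_proj_weight and its scale tensor.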

Author

I will add a backend check to decide whether to swap.

Collaborator

What about the cutlass case? The swap should still be needed there; the qwen3-30b-a3b checkpoint was verified to need it at the time. Add an if check?


# flashinfer-trtllm
return output
return paddle.empty_like(x)
Collaborator

Both conditions above already return; why does this return here?

Author

It returns here when the backend is neither flashinfer-cutlass nor flashinfer-cutedsl.

@mpgemm mpgemm changed the title fp4-moe [Feature] Support NVFP4 Flashinfer-cutedsl MoE on SM100 Mar 22, 2026
layer.up_gate_proj_weight.set_value(paddle.concat([b, a], axis=1))
[a, b] = layer.up_gate_proj_weight_scale.split(2, axis=1)
layer.up_gate_proj_weight_scale.set_value(paddle.concat([b, a], axis=1))
if self.backend != "flashinfer-cutedsl":
Collaborator

Make this check explicitly flashinfer_cutlass; other backends may be added later.
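The reviewer's suggestion amounts to testing for the backend that needs the swap, rather than excluding the one that doesn't, so a future backend cannot silently take the swap path. A minimal sketch, with the hypothetical helper name `needs_up_gate_swap` (not FastDeploy code):

```python
def needs_up_gate_swap(backend: str) -> bool:
    # Positive check: only flashinfer-cutlass needs the
    # [gate, up] -> [up, gate] weight reorder. Any new backend
    # defaults to "no swap" until explicitly added here.
    return backend == "flashinfer-cutlass"
```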

@codecov-commenter

Codecov Report

❌ Patch coverage is 16.58537% with 171 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@0b4c1cb). Learn more about missing BASE report.

Files with missing lines Patch % Lines
...deploy/model_executor/layers/quantization/nvfp4.py 15.00% 149 Missing and 4 partials ⚠️
fastdeploy/model_executor/layers/utils.py 29.16% 15 Missing and 2 partials ⚠️
fastdeploy/model_executor/utils.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #6963   +/-   ##
==========================================
  Coverage           ?   73.44%           
==========================================
  Files              ?      399           
  Lines              ?    56002           
  Branches           ?     8846           
==========================================
  Hits               ?    41131           
  Misses             ?    11951           
  Partials           ?     2920           
Flag Coverage Δ
GPU 73.44% <16.58%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.


Labels

contributor External developers
