Add Qwen3.5-4B math RL recipes (full + delta) + Qwen3.5 enablement #11

Open
iamziyuzhao wants to merge 1 commit into Infini-AI-Lab:main from iamziyuzhao:add-qwen3.5-recipe-normal

@iamziyuzhao iamziyuzhao commented Jun 2, 2026

PR: Add Qwen3.5-4B math RL recipes (full + delta) + Qwen3.5 enablement

Branch: add-qwen3.5-recipe-normal · Base: main

Summary

Adds two new math RL recipes — examples/math/qwen3.5-4b-m2po-full and
examples/math/qwen3.5-4b-m2po-delta — plus the minimal framework changes needed
to train Qwen3.5 (a dense, hybrid Gated-DeltaNet model, model_type: qwen3_5)
with M2PO on AstraFlow, mirroring the existing qwen3-8b-m2po-{full,delta} recipes.

Verified end-to-end on NVIDIA L40 (Ada, sm_89). A full run trained 86+ steps
with zero crashes and steadily rising eval:

| metric | step 0 | step 80 | Δ (pp) |
| --- | --- | --- | --- |
| overall avg@k | 47.8% | 57.4% | +9.6 |
| overall pass@k | 56.5% | 67.4% | +10.9 |

(eval every 10 steps over AIME24 / AIME25 / AMC / Minerva / MATH500; monotonic climb,
grad_norm stable ~0.8–0.9. Per-benchmark trend in the attached CSV.)
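
For readers unfamiliar with the two metrics, here is a minimal sketch of one common way to compute them from per-sample correctness flags. It assumes avg@k is the mean accuracy over the k samples of each problem and pass@k is the share of problems with at least one correct sample. The PR does not spell out its definitions, so both the definitions and the function name `avg_and_pass_at_k` are assumptions, not AstraFlow's evaluation code.

```python
from typing import Sequence


def avg_and_pass_at_k(correct: Sequence[Sequence[bool]]) -> tuple[float, float]:
    """Return (avg@k, pass@k) for a list of problems, each with k sampled answers."""
    if not correct:
        raise ValueError("need at least one problem")
    # avg@k (assumed definition): per-problem mean accuracy, averaged over problems.
    avg = sum(sum(row) / len(row) for row in correct) / len(correct)
    # pass@k (assumed definition): share of problems where any sample is correct.
    passed = sum(any(row) for row in correct) / len(correct)
    return avg, passed


# Two problems with k = 4 samples each (hypothetical data, not from the PR).
flags = [[True, False, False, True], [False, False, False, False]]
avg_k, pass_k = avg_and_pass_at_k(flags)
print(f"avg@k={avg_k:.3f} pass@k={pass_k:.3f}")
```

Under these definitions pass@k can never fall below avg@k, which matches the ordering of the two rows reported above.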

Note: Qwen3.5 ships only as a vision-language checkpoint
(Qwen3_5ForConditionalGeneration, has vision_config); these recipes train it
text-only for math. The text backbone is dense (non-MoE) with a hybrid
3:1 Gated-DeltaNet / full-attention stack. Requires transformers>=5.8
(+ flash-linear-attention for the GDN layers) and SGLang main (the qwen3_5 model).

The key recipe change: SGLang attention backend (fa3 → flashinfer)

This is the most important recipe-level adjustment; it comes with matching memory caps.

Qwen3.5 is a hybrid Gated-DeltaNet model. With the default fa3 (FlashAttention-3),
the GDN path dispatches a Hopper-only kernel (hopper/flash_fwd_launch_template.h)
that fails on Ada GPUs such as the L40 (sm_89) with CUDA error: invalid argument
under real load. (Plain dense Qwen3 with full-attention runs fine with fa3 on L40 — this
is specific to the fa3 × Qwen3.5-GDN combination, not a general L40 limitation.)

On non-Hopper archs SGLang auto-selects flashinfer (full-attention) + triton
(the GDN/linear-attn + mamba layers). Empirically on L40 with Qwen3.5-4B:
flashinfer ✓, triton ✓, fa3 ✗. The recipes set attention_backend: flashinfer
explicitly (SGLang's literal auto-default on Ada/L40); triton is an equally valid
alternative (both verified). slime/Miles don't force fa3 for Qwen3.5 either; they
defer to SGLang's per-arch default.
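
The per-architecture choice described above can be sketched as a small helper. This is an illustration only, not SGLang's actual selection logic: the function name `pick_attention_backend` is hypothetical, and the rule "fa3 only on sm_90, flashinfer elsewhere" is a simplification of what the PR reports.

```python
def pick_attention_backend(compute_capability: tuple[int, int]) -> str:
    """Choose an attention backend name from a CUDA compute capability (sketch)."""
    major, _minor = compute_capability
    # fa3 dispatches Hopper-only kernels, so only offer it on sm_90.
    if major == 9:
        return "fa3"
    # Everything else (e.g. Ada sm_89 on the L40) gets flashinfer, as in the recipes.
    return "flashinfer"


print(pick_attention_backend((8, 9)))  # L40 (Ada, sm_89)
print(pick_attention_backend((9, 0)))  # Hopper (sm_90)
```

On a live machine the tuple could come from `torch.cuda.get_device_capability()`. The recipes pin the value in YAML instead, which keeps the choice explicit and reproducible.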

Gated-DeltaNet keeps both a KV cache and a Mamba state cache, so without a concurrency
limit the engine retracts and then CUDA-OOMs on a 44 GB L40. The recipes therefore set
max_running_requests: 32 and mem_fraction_static: 0.7 on the inference side, and
max_tokens_per_mb: 8192 with FSDP dp=4 on the trainer side.
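
As a rough worked example of what the inference-side cap means on this card, the arithmetic below assumes that mem_fraction_static is the share of GPU memory SGLang reserves up front for weights plus cache pools, and uses the 44 GB figure quoted above. It is a back-of-the-envelope sketch, not a measurement from the run.

```python
GPU_MEMORY_GB = 44.0        # usable L40 memory quoted in the PR
MEM_FRACTION_STATIC = 0.7   # recipe value

# Memory reserved statically (assumed: weights + KV cache + Mamba state pools).
static_gb = GPU_MEMORY_GB * MEM_FRACTION_STATIC
# What is left for activations, CUDA graphs, and other transient allocations.
headroom_gb = GPU_MEMORY_GB - static_gb

print(f"static pool: {static_gb:.1f} GB, headroom: {headroom_gb:.1f} GB")
```

With these numbers roughly 30.8 GB is reserved and about 13.2 GB stays free; max_running_requests: 32 then bounds how many requests compete for the reserved pools.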

Files

New recipes (mirror the existing qwen3-8b-m2po-{full,delta} structure exactly)

Each recipe = yaml/{experiment,raas}.yaml + scripts/{1_astraflow,2_raas,3_trainer_model0,run_*}.sh.

  • examples/math/qwen3.5-4b-m2po-full/ — TCP full weight transfer.
  • examples/math/qwen3.5-4b-m2po-delta/ — TCP delta weight transfer
    (weight_transfer_strategies: delta, delta_full_sync_interval: 10).
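
To illustrate how the two delta settings might interact, here is a minimal sketch. It assumes that delta_full_sync_interval: 10 means "send full weights every 10th step and deltas otherwise", and that a delta is the new-minus-old difference of changed parameters. Both semantics are assumptions, and `transfer_mode` and `weight_delta` are hypothetical names rather than AstraFlow functions.

```python
def transfer_mode(step: int, full_sync_interval: int = 10) -> str:
    """Return "full" on periodic resync steps and "delta" otherwise (assumed rule)."""
    if full_sync_interval <= 0:
        raise ValueError("full_sync_interval must be positive")
    return "full" if step % full_sync_interval == 0 else "delta"


def weight_delta(old: dict[str, float], new: dict[str, float]) -> dict[str, float]:
    """Keep only the parameters that changed, as new-minus-old differences."""
    return {name: new[name] - old[name] for name in new if new[name] != old[name]}


schedule = [transfer_mode(step) for step in range(12)]
print(schedule)
print(weight_delta({"w": 1.0, "b": 0.5}, {"w": 1.25, "b": 0.5}))
```

A periodic full sync of this kind bounds how far accumulated deltas can drift from the trainer's weights, at the cost of one larger transfer per interval.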

Both: Qwen3.5-4B, M2PO (m2_threshold 0.01), ctx 8k, lr 5e-6, 800 steps,
train_batch_size 256, n_samples 8, DeepScaleR train set, rlvr/math_verify,
eval on AIME24/AIME25/AMC/Minerva/MATH500, RaaS sglang dp4 + trainer FSDP dp4,
attention_backend: flashinfer, attn_impl: sdpa.

Core changes (4 files) — Qwen3.5 / transformers-5 compatibility

  1. train_worker/utils/model.py: register qwen3_5 as a valid vision model and add
    is_qwen3_5_model().
  2. train_worker/engine/fsdp_engine.py: pass attention_mask=None for qwen3_5. In
    transformers>=5, create_causal_mask calls .ndim on the mask, so the old dict
    form raised AttributeError.
  3. train_worker/utils/fsdp/__init__.py: apply_fsdp2 normalizes
    _no_split_modules to a list. Qwen3.5 exposes it as a set (multiple decoder
    layer classes); the code indexed [0], which raised
    TypeError: 'set' object is not subscriptable.
  4. core/workflow/impl/rlvr.py: unwrap the BatchEncoding returned by
    apply_chat_template, so the call works on both transformers 4.x (list) and
    5.x (BatchEncoding).
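
The _no_split_modules fix in item 3 can be sketched as below. This is a standalone illustration, not the code in apply_fsdp2: the helper name `normalize_no_split_modules` and the layer class names are hypothetical, and sorting sets is a choice made here for determinism.

```python
from typing import Iterable, Optional


def normalize_no_split_modules(value: Optional[Iterable[str]]) -> list[str]:
    """Return _no_split_modules as a list, whatever container the model used."""
    if value is None:
        return []
    # Sets have no order of their own, so sort them for a deterministic result.
    if isinstance(value, (set, frozenset)):
        return sorted(value)
    return list(value)


print(normalize_no_split_modules({"LayerB", "LayerA"}))  # set, as Qwen3.5 exposes it
print(normalize_no_split_modules(["OnlyLayer"]))          # list, as older models do
```

After normalization, indexing with [0] is safe for both container types, though a model with several decoder layer classes may need every entry rather than only the first.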

All four are guarded/minimal and do not change behavior for existing models. The
trainer uses the standard packed training forward (no model-specific forward path).
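
A guarded change of the kind described can be sketched as follows. The set of model types comes from this PR's discussion (qwen3_5, alongside qwen3_moe and qwen3_vl). Keying on the model_type string, the signature of `is_qwen3_5_model`, and the helper `select_attention_mask` are assumptions for illustration, not the actual fsdp_engine.py code.

```python
from typing import Any, Optional

# Model types that should receive attention_mask=None (taken from the PR discussion).
MASK_FREE_MODEL_TYPES = frozenset({"qwen3_5", "qwen3_moe", "qwen3_vl"})


def is_qwen3_5_model(model_type: str) -> bool:
    """Hypothetical signature: True only for the qwen3_5 model type."""
    return model_type == "qwen3_5"


def select_attention_mask(model_type: str, mask: Any) -> Optional[Any]:
    """Drop the mask for listed model types; pass it through unchanged otherwise."""
    return None if model_type in MASK_FREE_MODEL_TYPES else mask


existing_mask = {"full_attention": "placeholder"}  # stand-in for the old dict form
print(select_attention_mask("qwen3_5", existing_mask))
print(select_attention_mask("qwen3", existing_mask) is existing_mask)
```

Because unlisted model types get their mask back untouched, the guard leaves behavior for existing models unchanged.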

Not included / caveats

  • Requires transformers>=5.8 + flash-linear-attention + SGLang main; dependency
    pin bumps to the repo are out of scope for this recipe PR (called out so maintainers
    can reproduce).
  • Recipes default to flashinfer; switch to attention_backend: triton if preferred
    (both verified).

New recipes examples/math/qwen3.5-4b-m2po-{full,delta} for training Qwen3.5
(dense, hybrid Gated-DeltaNet text backbone; model_type qwen3_5) with M2PO on
AstraFlow, mirroring the existing qwen3-8b-m2po recipe structure. Trained
text-only for math (the checkpoint ships as Qwen3_5ForConditionalGeneration).

Verified end-to-end on NVIDIA L40 (Ada, sm_89): a full run trained 86+ steps
with no crash and steadily rising eval — overall avg@k 47.8% -> 57.4% (+9.6) and
pass@k 56.5% -> 67.4% over the first 80 steps (AIME24/AIME25/AMC/Minerva/MATH500,
eval every 10 steps).

Minimal framework changes for Qwen3.5 / transformers>=5 compatibility:
- model.py: register qwen3_5 + is_qwen3_5_model()
- fsdp_engine.py: pass attention_mask=None for qwen3_5 (transformers>=5
  create_causal_mask calls .ndim; the old dict form raised AttributeError)
- fsdp/__init__.py: normalize _no_split_modules set->list (qwen3_5 exposes a set)
- rlvr.py: unwrap BatchEncoding from apply_chat_template (transformers>=5)

Recipes use the standard packed training forward. attention_backend=flashinfer
(fa3 dispatches a Hopper-only kernel that fails on Ada/L40 for the GDN path;
flashinfer + triton both verified); max_running_requests=32, mem_fraction_static=0.7
on inference and FSDP dp=4 + max_tokens_per_mb=8192 on the trainer to fit 44GB L40.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@iamziyuzhao iamziyuzhao force-pushed the add-qwen3.5-recipe-normal branch from 208dd60 to 2d8c5f0 Compare June 8, 2026 23:21

iamziyuzhao commented Jun 8, 2026

Rebased onto latest main (316756b, v0.1.1) — conflicts resolved

Two files conflicted because upstream had independently fixed the same transformers>=5 issues, so this PR no longer needs to touch them:

  • core/workflow/impl/rlvr.py — upstream added return_dict=False to apply_chat_template (same effect as our BatchEncoding unwrap) → dropped our redundant block, now identical to upstream.
  • train_worker/utils/fsdp/__init__.py — upstream now normalizes _no_split_modules via list(...) → adopted upstream's form.
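
For context, the dropped BatchEncoding unwrap had roughly the shape below. This is a sketch under assumptions: `to_token_ids` is a hypothetical helper, and a plain dict stands in for BatchEncoding so the example runs without transformers installed.

```python
from collections.abc import Mapping
from typing import Any


def to_token_ids(result: Any) -> list[int]:
    """Accept either a plain list of ids or a mapping with an "input_ids" entry."""
    if isinstance(result, Mapping):
        # Dict-like return value (the transformers 5.x case in this PR).
        result = result["input_ids"]
    return list(result)


print(to_token_ids([1, 2, 3]))                # list return (4.x case)
print(to_token_ids({"input_ids": [1, 2, 3]}))  # dict-like stand-in for BatchEncoding
```

Passing return_dict=False at the call site, as upstream does, makes this kind of post-processing unnecessary.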

Net effect: the PR's core footprint shrank from 4 files to 2. Remaining changes:

| file | change |
| --- | --- |
| train_worker/utils/model.py | register qwen3_5 + is_qwen3_5_model() |
| train_worker/engine/fsdp_engine.py | pass attention_mask=None for qwen3_5 (like qwen3_moe / qwen3_vl) |
| examples/math/qwen3.5-4b-m2po-{full,delta}/ | new recipes (unchanged) |

Re-validated end-to-end on the rebased code (L40, RaaS sglang dp4 + trainer FSDP dp4)

| metric | step 0 | step 40 | Δ (pp) |
| --- | --- | --- | --- |
| overall avg@k | 47.8% | 53.7% | +5.9 |
| overall pass@k | 57.1% | 65.6% | +8.5 |

@iamziyuzhao

Exact validated environment (for reproducibility)

Pinned versions this PR was validated against (conda env on NVIDIA L40, Ada sm_89):

| component | version |
| --- | --- |
| AstraFlow | 0.1.1 (base v0.1.1 = 316756b; this PR HEAD 2d8c5f0) |
| Python | 3.12.13 |
| PyTorch | 2.11.0+cu130 (CUDA 13.0, cuDNN 9.19.00) |
| transformers | 5.8.1 |
| SGLang | 0.5.6.post3.dev5643+g373cadc92 |
| flash-linear-attention (fla) | 0.5.0 |
| flashinfer-python | 0.6.11.post1 |
| triton | 3.6.0 |
| datasets | 4.8.5 |
| accelerate | 1.13.0 |
| numpy | 2.3.5 |
| GPU / driver | NVIDIA L40, 580.105.08 |

Notes:

  • transformers>=5.8 is required (the qwen3_5 / Qwen3_5ForConditionalGeneration model type + the create_causal_mask API this PR adapts to).
  • SGLang must be a build that includes the qwen3_5 model; on Ada (sm_89) it auto-selects flashinfer (full-attn) + triton (GDN/linear-attn) — fa3 is Hopper-only and crashes here.
  • flashinfer-python (not the legacy flashinfer package) is what's installed.
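
To compare a local environment against these pins, a small check along the following lines may help. It is a sketch, not part of the PR: the helper names are hypothetical, only a subset of the pinned packages is listed, and the distribution names are assumed to match the component names in the table.

```python
from importlib import metadata

# Subset of the pinned versions listed in this comment.
PINNED = {
    "transformers": "5.8.1",
    "flash-linear-attention": "0.5.0",
    "flashinfer-python": "0.6.11.post1",
    "triton": "3.6.0",
}


def installed_versions(names):
    """Look up installed distribution versions; missing packages map to None."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def mismatches(pinned, installed):
    """Return {name: (wanted, found)} for every package that differs from its pin."""
    return {
        name: (wanted, installed.get(name))
        for name, wanted in pinned.items()
        if installed.get(name) != wanted
    }


if __name__ == "__main__":
    for name, (wanted, found) in mismatches(PINNED, installed_versions(PINNED)).items():
        print(f"{name}: pinned {wanted}, found {found}")
```

The check uses exact equality, so it reports any difference from the validated versions; it does not judge whether a newer version would also work.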
