[Qwen3.5] Fix caching bug in GDN layer for autoregressive mode #3907

Open
Rohan-Bierneni wants to merge 1 commit into main from rbierneni-test-q3-caching

Rohan-Bierneni (Collaborator) commented May 14, 2026

Description

Previously, when we enabled caching in the GDN layer during Qwen3-Next model bring-up, autoregressive mode had a bug that showed up when running inference with api_server, our benchmarking tool. The way api_server managed the GDN caches caused repeated tokens in the decoded output.
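To illustrate the symptom, here is a toy sketch of how a recurrent state that is never written back to its cache can produce repeated tokens. This is not the MaxText or api_server code and does not claim to reproduce the actual root cause; `PHRASE`, `step`, and `decode` are hypothetical names, and the "model" simply recites a fixed phrase using its state as a position.

```python
# Toy illustration only: a recurrent "model" whose output depends on carried state.
PHRASE = ["the", "cat", "sat", "down"]


def step(state):
    """Emit the word at position `state` and return the advanced state."""
    return state + 1, PHRASE[state % len(PHRASE)]


def decode(num_steps, write_back):
    """Run `num_steps` decode steps, optionally persisting the state in the cache."""
    cache = {"recurrent_state": 0}
    out = []
    for _ in range(num_steps):
        new_state, word = step(cache["recurrent_state"])
        if write_back:
            # Correct behaviour: the next step sees the updated state.
            cache["recurrent_state"] = new_state
        out.append(word)
    return out


good = decode(4, write_back=True)   # state advances each step
bad = decode(4, write_back=False)   # state is stale, so every step repeats
```

With the write-back in place the phrase is emitted in order; without it, every step starts from the same state and emits the same word.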

This PR fixes the bug by reusing the existing kvcache class in kvcache.py to store the GDN recurrent_state and conv_state. The kvcache class already seems to be well integrated with api_server for batching, padding, updating, etc.
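As a rough sketch of the idea (not the kvcache class in kvcache.py, whose API is not shown here), a single per-slot container can hold both pieces of state. The class and method names below are hypothetical, and the sketch assumes conv_state is a fixed-width window of recent inputs for a short causal convolution while recurrent_state is overwritten on every decode step.

```python
from dataclasses import dataclass, field


@dataclass
class LinearAttentionCache:
    """Hypothetical per-slot cache; NOT the kvcache.py class from this PR."""

    batch_size: int
    conv_width: int
    conv_state: list = field(init=False)
    recurrent_state: list = field(init=False)

    def __post_init__(self):
        # One fixed-size rolling window and one scalar state per batch slot.
        self.conv_state = [[0.0] * self.conv_width for _ in range(self.batch_size)]
        self.recurrent_state = [0.0] * self.batch_size

    def insert_prefill(self, slot, conv_window, state):
        """Copy the state produced by prefill into one batch slot."""
        if len(conv_window) != self.conv_width:
            raise ValueError("conv_window must have conv_width entries")
        self.conv_state[slot] = list(conv_window)
        self.recurrent_state[slot] = state

    def update_decode(self, slot, new_input, new_state):
        """Advance one decode step: roll the window and overwrite the state."""
        self.conv_state[slot] = self.conv_state[slot][1:] + [new_input]
        self.recurrent_state[slot] = new_state


cache = LinearAttentionCache(batch_size=2, conv_width=3)
cache.insert_prefill(slot=1, conv_window=[1.0, 2.0, 3.0], state=0.5)
cache.update_decode(slot=1, new_input=4.0, new_state=0.75)
```

Keeping both states behind the same insert and update entry points is what lets the serving loop treat them like any other cache when batching requests; slot 0 above is untouched by the updates to slot 1.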

For new models with special cache structures, reusing the kvcache class as much as possible seems to be the fastest route to functional decoding, since it is already integrated into our inference framework.

Tests

I ran the standard decode.py and also ran an inference benchmark on api_server for Qwen3.5, which uses the GDN layer.

decode.py output: https://paste.googleplex.com/6084567757881344

Test run with api_server for Qwen3.5-35b-a3b:

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov Bot commented May 14, 2026

Codecov Report

❌ Patch coverage is 26.47059% with 25 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/models/qwen3.py | 25.80% | 21 Missing and 2 partials ⚠️ |
| src/maxtext/layers/decoders.py | 0.00% | 2 Missing ⚠️ |

Rohan-Bierneni force-pushed the rbierneni-test-q3-caching branch from be2561a to 77c4785 on May 14, 2026 16:40
Add mini config model support for q3.5

Wrong config name updated

Remove special casing for caching since using existing kvcache class

return kvcache instead of active_cache

Use kvcache class and remove extra logic in decoders.py

Add logic for proper batching of gdn caches

Update for nnx issue when batch size > 1

Remove GDN specific cache

Fixed linter issues

Run linter on qwen3.py