Add a (rolling) KV cache for Wan models to enable autoregressive rollouts ("SelfForcing-style").#13681

Open
gueraf wants to merge 1 commit into huggingface:main from gueraf:wan-rolling-kv-cache

Conversation


@gueraf gueraf commented May 5, 2026

What does this PR do?

  • Implements a simple (rolling) KV cache for Wan models to enable autoregressive generation.
  • Mirrors the KV-cache pattern in transformer_flux2.py and transformers' DynamicCache as closely as possible.
  • Videos and byte-level equivalence against upstream Self Forcing are tested in https://github.com/gueraf/self-forcing-diffusers/ (see the videos attached to the release, and the inference script there).
  • This initial PR does not yet implement sink-frame pinning, lacks some model-level adjustments (Self Forcing has cross-attention QK norms and per-frame timestep modulation), and does not cache cross-attention K/V (easy to add, but cross-attention is negligible GPU time in practice, and caching it is often a small regression).
  • Adds tests for cache append/overwrite and window-eviction behavior.

Motivation

This is a tightly scoped follow-up to #12773 and a first step toward #12600. The previous draft explored similar functionality but also included Krea-specific experiments and broader integration work.

As for practical use, we (https://odyssey.ml/) would like to rely on the Hugging Face Diffusers ecosystem to ship Self-Forcing-like models without having to ship many custom modules, ideally none.

Progresses #12600

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul thanks for offering help with this :)

@github-actions github-actions Bot added the models, tests, and size/L (PR with diff > 200 LOC) labels and removed the size/L label May 5, 2026
@gueraf gueraf changed the title Add a (rolling) KV cache for Wan models to enable autoregressive rollouts ("SelfForcing"). Add a (rolling) KV cache for Wan models to enable autoregressive rollouts ("SelfForcing-style"). May 6, 2026
@gueraf gueraf force-pushed the wan-rolling-kv-cache branch 2 times, most recently from fb97f37 to b2f85fa May 6, 2026 19:58
@github-actions github-actions Bot added the documentation Improvements or additions to documentation label May 6, 2026
@gueraf gueraf force-pushed the wan-rolling-kv-cache branch 5 times, most recently from 123314b to 7c6255e May 6, 2026 20:25
WanKVCache is a per-block self-attention KV cache that lets a Wan
transformer generate video chunk by chunk while reusing the K/V tensors
computed for prior chunks instead of re-running the full attention over the
whole prefix on every step.

API:
- ``WanKVCache(num_blocks, window_size=-1)`` — one cache per transformer
  instance. ``window_size=-1`` keeps the full prefix; a finite window
  evicts the oldest tokens once the cap is reached.
- ``cache.enable_append_mode()`` / ``cache.enable_overwrite_mode()`` — pick
  the write semantics for the next forward pass. Append grows the cache
  (or rolls when full); overwrite replaces the newest chunk in place — used
  for additional denoising steps that re-do the most recent chunk.
- ``cache.update(block_idx, key, value)`` — called from ``WanAttnProcessor``
  during self-attention to merge the current chunk into the per-block
  cache and return the K/V to attend over.
- ``cache.reset()`` — clear all blocks between videos.
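
A minimal sketch of these semantics (illustrative names, not the actual
``WanKVCache`` implementation; tensors are assumed to be
``(batch, heads, seq, head_dim)``)::

    import torch

    class RollingKVCache:
        def __init__(self, num_blocks: int, window_size: int = -1):
            self.window_size = window_size
            self.keys = [None] * num_blocks
            self.values = [None] * num_blocks
            self.chunk_lens = [0] * num_blocks  # newest-chunk length per block
            self.overwrite = False

        def enable_append_mode(self):
            self.overwrite = False

        def enable_overwrite_mode(self):
            self.overwrite = True

        def update(self, block_idx, key, value):
            k, v = self.keys[block_idx], self.values[block_idx]
            if k is None:
                k, v = key, value
            elif self.overwrite:
                # Drop the newest chunk and write the fresh K/V in its place
                # (extra denoising steps re-do the most recent chunk).
                n = self.chunk_lens[block_idx]
                k = torch.cat([k[:, :, :-n], key], dim=2)
                v = torch.cat([v[:, :, :-n], value], dim=2)
            else:
                k = torch.cat([k, key], dim=2)
                v = torch.cat([v, value], dim=2)
            if self.window_size > 0 and k.shape[2] > self.window_size:
                # Evict the oldest tokens once the cap is reached.
                k = k[:, :, -self.window_size:]
                v = v[:, :, -self.window_size:]
            self.keys[block_idx], self.values[block_idx] = k, v
            self.chunk_lens[block_idx] = min(key.shape[2], k.shape[2])
            return k, v

        def reset(self):
            self.keys = [None] * len(self.keys)
            self.values = [None] * len(self.values)
            self.chunk_lens = [0] * len(self.chunk_lens)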

Wan plumbing:
- ``WanTransformer3DModel.forward`` accepts ``frame_offset: int = 0`` and
  forwards ``kv_cache`` (extracted from ``attention_kwargs``) plus
  ``block_idx`` to each transformer block.
- ``WanRotaryPosEmbed.forward`` takes ``frame_offset`` so RoPE can address
  positions in the original (uncached) sequence even when the latent input
  is just one chunk.
- ``WanAttnProcessor.__call__`` receives ``kv_cache`` / ``block_idx``;
  on self-attention it calls ``cache.update(...)`` and uses the returned
  K/V for SDPA. Cross-attention is unaffected.
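
A sketch of that self-attention path (projections, QK norm, and RoPE
omitted; ``cached_self_attention`` is an illustrative name, not the
processor's API)::

    import torch.nn.functional as F

    def cached_self_attention(query, key, value, kv_cache=None, block_idx=0):
        if kv_cache is not None:
            # Merge this chunk's K/V into the per-block cache and attend
            # over the full (possibly windowed) prefix it returns.
            key, value = kv_cache.update(block_idx, key, value)
        # Query covers only the current chunk; K/V cover the cached prefix.
        return F.scaled_dot_product_attention(query, key, value)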

Caller usage::

    cache = WanKVCache(num_blocks=len(transformer.blocks))
    for chunk_idx, latent_chunk in enumerate(chunks):
        cache.enable_append_mode()
        for step_idx, t in enumerate(denoising_steps):
            if step_idx > 0:
                cache.enable_overwrite_mode()
            transformer(
                hidden_states=latent_chunk,
                timestep=t,
                encoder_hidden_states=prompt_embeds,
                frame_offset=chunk_idx * patch_frames_per_chunk,
                attention_kwargs={"kv_cache": cache},
            )

Tests cover unbounded append, windowed append (with eviction across one and
multiple chunks), in-place overwrite of the newest chunk, the
read-from-prior-context contract, reset, and frame_offset's effect on RoPE.
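
For example, the eviction contract can be pinned down with a test along
these lines (written against the illustrative ``RollingKVCache`` sketch
above, not the PR's actual test suite)::

    import torch

    def test_windowed_append_evicts_oldest():
        cache = RollingKVCache(num_blocks=1, window_size=4)
        cache.enable_append_mode()
        first = torch.arange(4.0).reshape(1, 1, 4, 1)        # tokens 0..3
        second = torch.arange(4.0, 6.0).reshape(1, 1, 2, 1)  # tokens 4..5
        cache.update(0, first, first)
        k, _ = cache.update(0, second, second)
        # With a window of 4, tokens 0 and 1 are evicted; 2..5 remain.
        assert torch.equal(k.flatten(), torch.tensor([2.0, 3.0, 4.0, 5.0]))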
@gueraf gueraf force-pushed the wan-rolling-kv-cache branch from 7c6255e to 5d28b7a May 6, 2026 20:35
@gueraf gueraf marked this pull request as ready for review May 6, 2026 20:35
@rootonchair
Contributor

Hi, I'm very interested in this line of work. It would be great to have an example inference script for running Wan in Self-Forcing style.

@gueraf
Author

gueraf commented May 7, 2026

```
uv run python scripts/autoregressive_video_generation.py \
  --prompt "A cat walks on the grass, realistic style, high quality" \
  --output ./autoregressive.mp4
```

is a starting point if you have a GPU with enough HBM (I tested it on an RTX 6000).

Caveats:
- It runs with window_size=-1, i.e. the KV cache grows as generation proceeds (see the bounded-window snippet below).
- Some Wan model layers need minor patches for full parity with the Self Forcing codebase (norm precisions, norm order, etc.). I'll submit them separately.
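
For bounded memory, a finite window should work along these lines (illustrative numbers only; derive them from your latent shape and patch size):

```python
tokens_per_chunk = 1560    # tokens per latent chunk after patchification
max_cached_chunks = 8      # how many past chunks stay attendable
cache = WanKVCache(
    num_blocks=len(transformer.blocks),
    window_size=tokens_per_chunk * max_cached_chunks,
)
```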

@sayakpaul sayakpaul requested review from DN6 and yiyixuxu May 8, 2026 06:14
@sayakpaul
Member

@zucchini-nlp FYI.

