From 7dd7d9c686d21973d6179d0261f50a6890bde0b1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 8 May 2026 10:03:14 +0000 Subject: [PATCH 1/8] [Discrete Diffusion] Add DFlash pipeline Adds DFlashPipeline + DFlashTokenDiffusionScheduler for block-diffusion speculative decoding with a draft DFlash model and a target causal LM. Verified against the six bug patterns surfaced in the LLaDA2 review (#13598). DFlash sidesteps most of them by being batch_size=1 only and relying on the causal default for attention; the applicable patterns (#3 callback bindings, #4 EOS at first generated position, #6 inner progress-bar config preservation) are pinned by regression tests. Public surface mirrors the LLaDA2 / SDAR / IDLM conventions: lazy import, dummy objects, scheduler + output dataclass, pipeline + output dataclass, fast tests for both, scheduler doc page, pipeline doc page. Sample/train scripts under examples/discrete_diffusion/. --- docs/source/en/_toctree.yml | 4 + docs/source/en/api/pipelines/dflash.md | 24 + .../api/schedulers/dflash_token_diffusion.md | 22 + examples/discrete_diffusion/sample_dflash.py | 145 +++++ examples/discrete_diffusion/train_dflash.py | 319 ++++++++++ src/diffusers/__init__.py | 8 + src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/dflash/__init__.py | 47 ++ .../pipelines/dflash/pipeline_dflash.py | 552 ++++++++++++++++++ src/diffusers/schedulers/__init__.py | 8 + .../scheduling_dflash_token_diffusion.py | 277 +++++++++ src/diffusers/utils/dummy_pt_objects.py | 30 + .../dummy_torch_and_transformers_objects.py | 30 + tests/pipelines/dflash/__init__.py | 1 + tests/pipelines/dflash/test_dflash.py | 448 ++++++++++++++ .../test_scheduler_dflash_token_diffusion.py | 310 ++++++++++ 16 files changed, 2227 insertions(+) create mode 100644 docs/source/en/api/pipelines/dflash.md create mode 100644 docs/source/en/api/schedulers/dflash_token_diffusion.md create mode 100644 examples/discrete_diffusion/sample_dflash.py create mode 100644 examples/discrete_diffusion/train_dflash.py create mode 100644 src/diffusers/pipelines/dflash/__init__.py create mode 100644 src/diffusers/pipelines/dflash/pipeline_dflash.py create mode 100644 src/diffusers/schedulers/scheduling_dflash_token_diffusion.py create mode 100644 tests/pipelines/dflash/__init__.py create mode 100644 tests/pipelines/dflash/test_dflash.py create mode 100644 tests/schedulers/test_scheduler_dflash_token_diffusion.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 8e8776d4a8c2..dc934c832b8e 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -648,6 +648,8 @@ title: Z-Image title: Image - sections: + - local: api/pipelines/dflash + title: DFlash - local: api/pipelines/llada2 title: LLaDA2 title: Text @@ -711,6 +713,8 @@ title: DDPMScheduler - local: api/schedulers/deis title: DEISMultistepScheduler + - local: api/schedulers/dflash_token_diffusion + title: DFlashTokenDiffusionScheduler - local: api/schedulers/multistep_dpm_solver_inverse title: DPMSolverMultistepInverse - local: api/schedulers/multistep_dpm_solver diff --git a/docs/source/en/api/pipelines/dflash.md b/docs/source/en/api/pipelines/dflash.md new file mode 100644 index 000000000000..95847e1fdd82 --- /dev/null +++ b/docs/source/en/api/pipelines/dflash.md @@ -0,0 +1,24 @@ + + +# DFlash + +`DFlashPipeline` performs block-diffusion speculative decoding using a diffusion draft model and a target causal LM. 
+The draft model is conditioned on target hidden features extracted during prefill and verification steps. + +## DFlashPipeline +[[autodoc]] DFlashPipeline + - all + - __call__ + +## DFlashPipelineOutput +[[autodoc]] pipelines.DFlashPipelineOutput diff --git a/docs/source/en/api/schedulers/dflash_token_diffusion.md b/docs/source/en/api/schedulers/dflash_token_diffusion.md new file mode 100644 index 000000000000..c98b11bc9963 --- /dev/null +++ b/docs/source/en/api/schedulers/dflash_token_diffusion.md @@ -0,0 +1,22 @@ + + +# DFlashTokenDiffusionScheduler + +`DFlashTokenDiffusionScheduler` implements the acceptance and posterior sampling logic used in DFlash-style block +diffusion speculative decoding. + +## DFlashTokenDiffusionScheduler +[[autodoc]] DFlashTokenDiffusionScheduler + +## DFlashTokenDiffusionSchedulerOutput +[[autodoc]] schedulers.scheduling_dflash_token_diffusion.DFlashTokenDiffusionSchedulerOutput diff --git a/examples/discrete_diffusion/sample_dflash.py b/examples/discrete_diffusion/sample_dflash.py new file mode 100644 index 000000000000..a10899a0d052 --- /dev/null +++ b/examples/discrete_diffusion/sample_dflash.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Sample script for DFlash speculative decoding. + +Example: + python sample_dflash.py \ + --draft_model_id z-lab/Qwen3-8B-DFlash-b16 \ + --target_model_id Qwen/Qwen3-8B \ + --prompt "How many positive whole-number divisors does 196 have?" 
\ + --max_new_tokens 256 +""" + +import argparse + +import torch +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer + +from diffusers import DFlashPipeline + + +def main(): + parser = argparse.ArgumentParser(description="Run DFlash speculative decoding.") + parser.add_argument( + "--draft_model_id", + type=str, + default="z-lab/Qwen3-8B-DFlash-b16", + help="Draft model ID or local path.", + ) + parser.add_argument( + "--target_model_id", + type=str, + default="Qwen/Qwen3-8B", + help="Target model ID or local path.", + ) + parser.add_argument( + "--prompt", + type=str, + default="How many positive whole-number divisors does 196 have?", + help="Prompt text to generate from.", + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=2048, + help="Maximum number of new tokens to generate.", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.0, + help="Sampling temperature.", + ) + parser.add_argument( + "--use_chat_template", + action="store_true", + help="Use the tokenizer chat template for the prompt.", + ) + parser.add_argument( + "--add_generation_prompt", + action="store_true", + help="Add the generation prompt when using the chat template.", + ) + parser.add_argument( + "--enable_thinking", + action="store_true", + help="Enable chat-template thinking mode if supported by the tokenizer.", + ) + parser.add_argument( + "--mask_token", + type=str, + default="<|MASK|>", + help="Mask token to add if the tokenizer does not define one.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device to run inference on.", + ) + parser.add_argument( + "--dtype", + type=str, + default="auto", + choices=["auto", "float32", "float16", "bfloat16"], + help="Model dtype.", + ) + + args = parser.parse_args() + + dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16} + torch_dtype = dtype_map.get(args.dtype) + + print(f"Loading draft model: {args.draft_model_id}") + print(f"Loading target model: {args.target_model_id}") + dtype_arg = torch_dtype if torch_dtype is not None else "auto" + # Draft model is a custom DFlashDraftModel; use AutoModel so trust_remote_code routes to the class in `auto_map`. 
+ draft_model = AutoModel.from_pretrained( + args.draft_model_id, + trust_remote_code=True, + dtype=dtype_arg, + device_map=args.device, + ) + target_model = AutoModelForCausalLM.from_pretrained( + args.target_model_id, + dtype=dtype_arg, + device_map=args.device, + ) + tokenizer = AutoTokenizer.from_pretrained(args.target_model_id) + if tokenizer.mask_token is None: + tokenizer.add_special_tokens({"mask_token": args.mask_token}) + pipe = DFlashPipeline(draft_model=draft_model, target_model=target_model, tokenizer=tokenizer) + + chat_kwargs = {"enable_thinking": args.enable_thinking} + + print(f"\nPrompt: {args.prompt}") + output = pipe( + prompt=args.prompt, + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + use_chat_template=args.use_chat_template, + add_generation_prompt=args.add_generation_prompt, + chat_template_kwargs=chat_kwargs, + ) + + print("\nGenerated text:") + print(output.texts[0]) + print(f"\nGenerated {output.sequences.shape[1]} tokens") + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/train_dflash.py b/examples/discrete_diffusion/train_dflash.py new file mode 100644 index 000000000000..673a2173a058 --- /dev/null +++ b/examples/discrete_diffusion/train_dflash.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
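+
+"""
+Training script for a DFlash draft model.
+
+The target LM stays frozen; the draft model learns to predict a block of future tokens from the
+target's intermediate hidden states and a mask-initialized block (see the training loop below).
+
+Example (defaults mirror the argparse options defined in this script):
+    accelerate launch train_dflash.py \
+        --draft_model_id z-lab/Qwen3-4B-DFlash-b16 \
+        --target_model_id Qwen/Qwen3-4B \
+        --dataset_name wikitext \
+        --dataset_config_name wikitext-2-raw-v1 \
+        --output_dir dflash-output \
+        --max_train_steps 1000
+"""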
+ +import argparse +import math +import os +from dataclasses import asdict, dataclass +from typing import Dict, Optional + +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from torch.utils.data import DataLoader +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, get_scheduler + + +logger = get_logger(__name__) + + +@dataclass +class TrainConfig: + draft_model_id: str + target_model_id: str + dataset_name: str + dataset_config_name: Optional[str] + text_column: str + + output_dir: str + seed: int + max_train_steps: int + checkpointing_steps: int + logging_steps: int + + per_device_train_batch_size: int + gradient_accumulation_steps: int + learning_rate: float + weight_decay: float + lr_scheduler: str + lr_warmup_steps: int + + max_length: int + block_size: int + mask_token: str + + +def parse_args() -> TrainConfig: + parser = argparse.ArgumentParser(description="Fine-tune a DFlash draft model with target-conditioned blocks.") + + parser.add_argument("--draft_model_id", type=str, default="z-lab/Qwen3-4B-DFlash-b16") + parser.add_argument("--target_model_id", type=str, default="Qwen/Qwen3-4B") + parser.add_argument("--dataset_name", type=str, default="wikitext") + parser.add_argument("--dataset_config_name", type=str, default="wikitext-2-raw-v1") + parser.add_argument("--text_column", type=str, default="text") + + parser.add_argument("--output_dir", type=str, default="dflash-output") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--max_train_steps", type=int, default=1000) + parser.add_argument("--checkpointing_steps", type=int, default=500) + parser.add_argument("--logging_steps", type=int, default=50) + + parser.add_argument("--per_device_train_batch_size", type=int, default=2) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--learning_rate", type=float, default=2e-5) + parser.add_argument("--weight_decay", type=float, default=0.0) + parser.add_argument( + "--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"] + ) + parser.add_argument("--lr_warmup_steps", type=int, default=100) + + parser.add_argument("--max_length", type=int, default=512) + parser.add_argument( + "--block_size", type=int, default=0, help="Override draft block size (0 uses the model config)." 
+ ) + parser.add_argument("--mask_token", type=str, default="<|MASK|>") + + args = parser.parse_args() + return TrainConfig(**vars(args)) + + +def tokenize_fn(examples: Dict, tokenizer, text_column: str, max_length: int): + texts = examples[text_column] + texts = [t for t in texts if isinstance(t, str) and len(t.strip()) > 0] + return tokenizer(texts, truncation=True, padding=False, max_length=max_length) + + +def build_target_layer_ids(num_target_layers: int, num_draft_layers: int): + if num_draft_layers == 1: + return [int(num_target_layers // 2)] + start = 1 + end = int(num_target_layers) - 3 + span = end - start + return [int(round(start + (i * span) / (num_draft_layers - 1))) for i in range(int(num_draft_layers))] + + +def extract_context_feature(hidden_states, layer_ids): + offset = 1 + selected_states = [hidden_states[layer_id + offset] for layer_id in layer_ids] + return torch.cat(selected_states, dim=-1) + + +def get_target_input_embeddings(model: torch.nn.Module) -> torch.nn.Module: + embeddings = model.get_input_embeddings() + if embeddings is None: + base = getattr(model, "model", None) + embeddings = getattr(base, "embed_tokens", None) + if embeddings is None: + raise ValueError("Target model must expose input embeddings.") + return embeddings + + +def get_target_output_embeddings(model: torch.nn.Module) -> torch.nn.Module: + embeddings = model.get_output_embeddings() + if embeddings is None: + embeddings = getattr(model, "lm_head", None) + if embeddings is None: + raise ValueError("Target model must expose output embeddings.") + return embeddings + + +def main(): + cfg = parse_args() + + project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs")) + accelerator = Accelerator( + gradient_accumulation_steps=cfg.gradient_accumulation_steps, + project_config=project_config, + ) + + if accelerator.is_main_process: + os.makedirs(cfg.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + set_seed(cfg.seed) + logger.info("Training configuration: %s", asdict(cfg)) + + tokenizer = AutoTokenizer.from_pretrained(cfg.target_model_id, use_fast=True) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + if tokenizer.mask_token_id is None: + tokenizer.add_special_tokens({"mask_token": cfg.mask_token}) + + draft_model = AutoModel.from_pretrained(cfg.draft_model_id, trust_remote_code=True) + target_model = AutoModelForCausalLM.from_pretrained(cfg.target_model_id) + target_model.eval() + target_model.requires_grad_(False) + + mask_token_id = tokenizer.mask_token_id + if mask_token_id is None: + raise ValueError("Tokenizer must define a mask token for DFlash training.") + + input_embeddings = get_target_input_embeddings(target_model) + output_embeddings = get_target_output_embeddings(target_model) + + block_size = int(cfg.block_size) + if block_size <= 0: + block_size = getattr(draft_model, "block_size", None) or getattr( + getattr(draft_model, "config", None), "block_size", None + ) + if block_size is None: + raise ValueError("Draft model must define `block_size` or pass --block_size.") + block_size = int(block_size) + if block_size < 2: + raise ValueError("`block_size` must be at least 2 for DFlash training.") + + layer_ids = getattr(draft_model, "target_layer_ids", None) + if layer_ids is None: + cfg_draft = getattr(draft_model, "config", None) + num_target_layers = getattr(cfg_draft, "num_target_layers", None) + num_hidden_layers = getattr(cfg_draft, "num_hidden_layers", None) + if num_target_layers 
is None or num_hidden_layers is None: + raise ValueError("Draft model must expose `target_layer_ids` or `num_target_layers` in config.") + layer_ids = build_target_layer_ids(int(num_target_layers), int(num_hidden_layers)) + + raw_datasets = load_dataset(cfg.dataset_name, cfg.dataset_config_name) + if "train" not in raw_datasets: + raise ValueError(f"Dataset {cfg.dataset_name} has no 'train' split.") + + with accelerator.main_process_first(): + tokenized = raw_datasets["train"].map( + lambda ex: tokenize_fn(ex, tokenizer, cfg.text_column, cfg.max_length), + batched=True, + remove_columns=raw_datasets["train"].column_names, + desc="Tokenizing", + ) + + collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="pt") + train_dataloader = DataLoader( + tokenized, shuffle=True, collate_fn=collator, batch_size=cfg.per_device_train_batch_size, drop_last=True + ) + + optimizer = torch.optim.AdamW(draft_model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps) + num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch) + + lr_scheduler = get_scheduler( + name=cfg.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.lr_warmup_steps, + num_training_steps=cfg.max_train_steps, + ) + + draft_model, optimizer, train_dataloader, lr_scheduler, target_model = accelerator.prepare( + draft_model, optimizer, train_dataloader, lr_scheduler, target_model + ) + input_embeddings = get_target_input_embeddings(target_model) + output_embeddings = get_target_output_embeddings(target_model) + + global_step = 0 + draft_model.train() + + for epoch in range(num_train_epochs): + for batch in train_dataloader: + with accelerator.accumulate(draft_model): + input_ids = batch["input_ids"] + attention_mask = batch.get("attention_mask", torch.ones_like(input_ids)) + + valid_lengths = attention_mask.sum(dim=1) + min_valid = int(valid_lengths.min().item()) + if min_valid <= block_size: + continue + + max_start = min_valid - block_size + start = torch.randint(1, max_start + 1, (1,), device=input_ids.device).item() + + block_output_ids = torch.full( + (input_ids.shape[0], block_size), + int(mask_token_id), + device=input_ids.device, + dtype=torch.long, + ) + block_output_ids[:, 0] = input_ids[:, start] + block_targets = input_ids[:, start + 1 : start + block_size] + block_mask = attention_mask[:, start + 1 : start + block_size] + + position_ids = torch.arange(start, start + block_size, device=input_ids.device).unsqueeze(0) + position_ids = position_ids.expand(input_ids.shape[0], -1) + + with torch.no_grad(): + target_out = target_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + target_hidden = extract_context_feature(target_out.hidden_states, layer_ids) + target_hidden = target_hidden[:, :start, :] + + noise_embedding = input_embeddings(block_output_ids) + draft_hidden = draft_model( + target_hidden=target_hidden, + noise_embedding=noise_embedding, + position_ids=position_ids, + use_cache=False, + is_causal=False, + ) + if not torch.is_tensor(draft_hidden): + draft_hidden = getattr(draft_hidden, "last_hidden_state", draft_hidden[0]) + + logits = output_embeddings(draft_hidden[:, -block_size + 1 :, :]) + vocab_size = logits.shape[-1] + loss = F.cross_entropy(logits.view(-1, vocab_size), block_targets.reshape(-1), reduction="none") + loss = loss.view(block_targets.shape[0], -1) + loss = (loss * 
block_mask.to(loss.dtype)).sum() / block_mask.sum().clamp_min(1) + + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + global_step += 1 + + if global_step % cfg.logging_steps == 0 and accelerator.is_main_process: + logger.info("step=%d loss=%.4f lr=%.6g", global_step, loss.item(), lr_scheduler.get_last_lr()[0]) + + if cfg.checkpointing_steps > 0 and global_step % cfg.checkpointing_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + save_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}") + os.makedirs(save_dir, exist_ok=True) + unwrapped = accelerator.unwrap_model(draft_model) + unwrapped.save_pretrained(save_dir, save_function=accelerator.save) + tokenizer.save_pretrained(save_dir) + + if global_step >= cfg.max_train_steps: + break + + if global_step >= cfg.max_train_steps: + break + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + final_dir = os.path.join(cfg.output_dir, "final") + os.makedirs(final_dir, exist_ok=True) + unwrapped = accelerator.unwrap_model(draft_model) + unwrapped.save_pretrained(final_dir, save_function=accelerator.save) + tokenizer.save_pretrained(final_dir) + + logger.info("Done.") + + +if __name__ == "__main__": + main() diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 1b1f6b3032b3..f8e683bcd76a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -371,6 +371,8 @@ "DDPMScheduler", "DDPMWuerstchenScheduler", "DEISMultistepScheduler", + "DFlashTokenDiffusionScheduler", + "DFlashTokenDiffusionSchedulerOutput", "DPMSolverMultistepInverseScheduler", "DPMSolverMultistepScheduler", "DPMSolverSinglestepScheduler", @@ -539,6 +541,8 @@ "CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline", "CycleDiffusionPipeline", + "DFlashPipeline", + "DFlashPipelineOutput", "EasyAnimateControlPipeline", "EasyAnimateInpaintPipeline", "EasyAnimatePipeline", @@ -1189,6 +1193,8 @@ DDPMScheduler, DDPMWuerstchenScheduler, DEISMultistepScheduler, + DFlashTokenDiffusionScheduler, + DFlashTokenDiffusionSchedulerOutput, DPMSolverMultistepInverseScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, @@ -1336,6 +1342,8 @@ CosmosTextToWorldPipeline, CosmosVideoToWorldPipeline, CycleDiffusionPipeline, + DFlashPipeline, + DFlashPipelineOutput, EasyAnimateControlPipeline, EasyAnimateInpaintPipeline, EasyAnimatePipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index f0fc7585bf31..cd47185bb6a6 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -319,6 +319,7 @@ ] ) _import_structure["latte"] = ["LattePipeline"] + _import_structure["dflash"] = ["DFlashPipeline", "DFlashPipelineOutput"] _import_structure["llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"] _import_structure["ltx"] = [ "LTXPipeline", @@ -693,6 +694,7 @@ WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) + from .dflash import DFlashPipeline, DFlashPipelineOutput from .easyanimate import ( EasyAnimateControlPipeline, EasyAnimateInpaintPipeline, diff --git a/src/diffusers/pipelines/dflash/__init__.py b/src/diffusers/pipelines/dflash/__init__.py new file mode 100644 index 000000000000..c5d0f5fae4cd --- /dev/null +++ b/src/diffusers/pipelines/dflash/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + 
is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_dflash"] = ["DFlashPipeline", "DFlashPipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_dflash import DFlashPipeline, DFlashPipelineOutput +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/dflash/pipeline_dflash.py b/src/diffusers/pipelines/dflash/pipeline_dflash.py new file mode 100644 index 000000000000..e8b0276db109 --- /dev/null +++ b/src/diffusers/pipelines/dflash/pipeline_dflash.py @@ -0,0 +1,552 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + +import torch +from tqdm.auto import tqdm +from transformers import DynamicCache + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...schedulers import DFlashTokenDiffusionScheduler +from ...utils import BaseOutput, logging, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline + + +logger = logging.get_logger(__name__) + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import DFlashPipeline + >>> from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer + + >>> draft = AutoModel.from_pretrained( + ... "z-lab/Qwen3-8B-DFlash-b16", trust_remote_code=True, torch_dtype=torch.bfloat16 + ... 
) + >>> target = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B", torch_dtype=torch.bfloat16) + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") + >>> pipe = DFlashPipeline(draft_model=draft, target_model=target, tokenizer=tokenizer) + >>> out = pipe(prompt="How many positive whole-number divisors does 196 have?") + >>> print(out.texts[0]) + ``` +""" + + +@dataclass +class DFlashPipelineOutput(BaseOutput): + sequences: torch.LongTensor + texts: list[str] | None = None + + +def _build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> list[int]: + if num_draft_layers == 1: + return [int(num_target_layers // 2)] + start = 1 + end = int(num_target_layers) - 3 + span = end - start + return [int(round(start + (i * span) / (num_draft_layers - 1))) for i in range(int(num_draft_layers))] + + +def _extract_context_feature(hidden_states: list[torch.Tensor], layer_ids: list[int]) -> torch.Tensor: + offset = 1 + selected_states = [hidden_states[layer_id + offset] for layer_id in layer_ids] + return torch.cat(selected_states, dim=-1) + + +class DFlashPipeline(DiffusionPipeline): + r""" + Block diffusion pipeline for speculative decoding with a DFlash draft model and a target causal LM. + """ + + draft_model: Any + target_model: Any + tokenizer: Any + scheduler: DFlashTokenDiffusionScheduler + _callback_tensor_inputs = ["block_output_ids", "draft_logits", "accepted_length", "next_token", "output_ids"] + + def __init__( + self, + draft_model: torch.nn.Module, + target_model: torch.nn.Module, + tokenizer: Any | None = None, + scheduler: DFlashTokenDiffusionScheduler | None = None, + ): + super().__init__() + if scheduler is None: + scheduler = DFlashTokenDiffusionScheduler() + self.register_modules( + draft_model=draft_model, target_model=target_model, tokenizer=tokenizer, scheduler=scheduler + ) + + # --- Prompt encoding --- + + def _prepare_input_ids( + self, + *, + prompt: str | list[str] | None, + messages: list[dict[str, str]] | None, + input_ids: torch.LongTensor | None, + use_chat_template: bool, + add_generation_prompt: bool, + chat_template_kwargs: dict[str, Any] | None, + ) -> torch.LongTensor: + """Convert prompt/messages/input_ids to a `[batch, seq]` LongTensor.""" + if input_ids is not None: + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + if input_ids.ndim != 2: + raise ValueError(f"`input_ids` must be 2D, got shape {tuple(input_ids.shape)}.") + if input_ids.dtype != torch.long: + raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.") + return input_ids + + if self.tokenizer is None: + raise ValueError("Tokenizer is required when `input_ids` is not provided.") + + if messages is not None and prompt is not None: + raise ValueError("Provide either `prompt` or `messages`, not both.") + if messages is None and prompt is None: + raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.") + + chat_template_kwargs = chat_template_kwargs or {} + + if messages is not None: + encoded = self.tokenizer.apply_chat_template( + messages, + add_generation_prompt=add_generation_prompt, + tokenize=True, + return_tensors="pt", + return_dict=True, + **chat_template_kwargs, + ) + return encoded["input_ids"] + + if use_chat_template and getattr(self.tokenizer, "chat_template", None): + if isinstance(prompt, list): + raise ValueError("`prompt` must be a string when `use_chat_template=True`.") + encoded = self.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + 
add_generation_prompt=add_generation_prompt, + tokenize=True, + return_tensors="pt", + return_dict=True, + **chat_template_kwargs, + ) + return encoded["input_ids"] + + encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list)) + return encoded["input_ids"] + + def check_inputs( + self, + prompt: str | list[str] | None, + messages: list[dict[str, str]] | None, + input_ids: torch.LongTensor | None, + max_new_tokens: int, + output_type: str, + callback_on_step_end: Callable | PipelineCallback | MultiPipelineCallbacks | None, + callback_on_step_end_tensor_inputs: list[str] | None, + ): + # Input source validation + if prompt is None and messages is None and input_ids is None: + raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.") + if prompt is not None and messages is not None: + raise ValueError("Provide either `prompt` or `messages`, not both.") + if input_ids is not None: + if input_ids.ndim not in (1, 2): + raise ValueError(f"`input_ids` must be 1D or 2D, got shape {tuple(input_ids.shape)}.") + if input_ids.dtype != torch.long: + raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.") + if prompt is not None and input_ids is None and self.tokenizer is None: + raise ValueError("Tokenizer is required when `input_ids` is not provided.") + if messages is not None and input_ids is None and self.tokenizer is None: + raise ValueError("Tokenizer is required when `input_ids` is not provided.") + + # Generation parameter validation + if max_new_tokens <= 0: + raise ValueError(f"`max_new_tokens` must be > 0, got {max_new_tokens}.") + if output_type not in {"seq", "text"}: + raise ValueError(f"`output_type` must be 'seq' or 'text', got {output_type!r}.") + + # Callback validation + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + def prepare_latents( + self, + max_length: int, + block_size: int, + mask_token_id: int, + device: torch.device, + ) -> torch.LongTensor: + return torch.full( + (1, max_length + int(block_size)), + int(mask_token_id), + dtype=torch.long, + device=device, + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + messages: list[dict[str, str]] | None = None, + input_ids: torch.LongTensor | None = None, + max_new_tokens: int = 2048, + temperature: float = 0.0, + stop_token_ids: list[int] | None = None, + mask_token_id: int | None = None, + use_chat_template: bool = True, + add_generation_prompt: bool = True, + chat_template_kwargs: dict[str, object] | None = None, + output_type: str = "text", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int, dict], None] + | PipelineCallback + | MultiPipelineCallbacks + | None = None, + callback_on_step_end_tensor_inputs: list[str] | None = None, + ) -> DFlashPipelineOutput | tuple[torch.LongTensor, list[str] | None]: + """ + Generate text using block-diffusion speculative decoding. + + Args: + prompt (`str` or `list[str]`, *optional*): + Prompt text. 
When `use_chat_template` is `True` (default) and a tokenizer with a chat template is + available, the prompt is wrapped in a chat message before tokenization. + messages (`list[dict[str, str]]`, *optional*): + Chat messages to encode. Takes precedence over `prompt` when provided. + input_ids (`torch.LongTensor`, *optional*): + Pre-tokenized input IDs. Takes precedence over `prompt` and `messages`. + max_new_tokens (`int`): + Maximum number of new tokens to generate. + temperature (`float`): + Sampling temperature. + stop_token_ids (`list[int]`, *optional*): + Token IDs that signal generation should stop. + mask_token_id (`int`, *optional*): + Mask token ID for the draft model. + use_chat_template (`bool`, defaults to `True`): + Whether to wrap the prompt in a chat template. + add_generation_prompt (`bool`, defaults to `True`): + Whether to add the generation prompt when using chat templates. + chat_template_kwargs (`dict[str, object]`, *optional*): + Additional keyword arguments for the chat template. + output_type (`str`, defaults to `"text"`): + Output format. `"text"` decodes sequences into strings (requires a tokenizer). `"seq"` returns raw + token ID sequences only. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`DFlashPipelineOutput`] instead of a tuple. + callback_on_step_end (`Callable` or `PipelineCallback`, *optional*): + Callback executed after each speculative decoding step. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*): + Tensor keys to pass to the callback. + + Examples: + """ + # 1. Check inputs early + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is None: + callback_on_step_end_tensor_inputs = ["block_output_ids"] + + self.check_inputs( + prompt=prompt, + messages=messages, + input_ids=input_ids, + max_new_tokens=max_new_tokens, + output_type=output_type, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + # 2. 
Prepare input IDs from prompt/messages/input_ids + input_ids = self._prepare_input_ids( + prompt=prompt, + messages=messages, + input_ids=input_ids, + use_chat_template=use_chat_template, + add_generation_prompt=add_generation_prompt, + chat_template_kwargs=chat_template_kwargs, + ) + + if mask_token_id is None: + mask_token_id = getattr(getattr(self, "tokenizer", None), "mask_token_id", None) + if mask_token_id is None: + # DFlash models store mask_token_id in config.dflash_config + dflash_config = getattr(getattr(self.draft_model, "config", None), "dflash_config", None) + if dflash_config is not None: + mask_token_id = dflash_config.get("mask_token_id", None) + if mask_token_id is None: + raise ValueError("`mask_token_id` must be provided (or available on the tokenizer/model config).") + if input_ids.shape[0] != 1: + raise ValueError("DFlashPipeline currently supports batch_size=1 input_ids.") + + target_params = list(self.target_model.parameters()) if hasattr(self.target_model, "parameters") else [] + device = target_params[0].device if len(target_params) > 0 else torch.device("cpu") + input_ids = input_ids.to(device=device) + draft_params = list(self.draft_model.parameters()) if hasattr(self.draft_model, "parameters") else [] + draft_device = draft_params[0].device if len(draft_params) > 0 else device + if draft_device != device: + logger.warning( + "Draft model is on %s while target model is on %s. For best performance, place both on the same device.", + draft_device, + device, + ) + + if stop_token_ids is None: + eos_token_id = getattr(getattr(self, "tokenizer", None), "eos_token_id", None) + stop_token_ids = [int(eos_token_id)] if eos_token_id is not None else None + if stop_token_ids is not None: + stop_token_ids = [int(token_id) for token_id in stop_token_ids] + + # 3. Setup scheduler and resolve model attributes + self.scheduler.set_timesteps(1, device=device) + + block_size = self._get_block_size() + + # Resolve target layer IDs from draft model config + layer_ids = getattr(self.draft_model, "target_layer_ids", None) + if layer_ids is not None: + target_layer_ids = list(layer_ids) + else: + cfg = getattr(self.draft_model, "config", None) + num_target_layers = getattr(cfg, "num_target_layers", None) + num_hidden_layers = getattr(cfg, "num_hidden_layers", None) + if num_target_layers is None or num_hidden_layers is None: + raise ValueError( + "`draft_model` must define `target_layer_ids` or expose `num_target_layers` in config." + ) + target_layer_ids = _build_target_layer_ids(int(num_target_layers), int(num_hidden_layers)) + + input_embeddings = self.target_model.get_input_embeddings() + output_embeddings = self.target_model.get_output_embeddings() + + num_input_tokens = input_ids.shape[1] + max_length = num_input_tokens + int(max_new_tokens) + + output_ids = self.prepare_latents(max_length, block_size, int(mask_token_id), device) + position_ids = torch.arange(output_ids.shape[1], device=device).unsqueeze(0) + + target_config = getattr(self.target_model, "config", None) + draft_config = getattr(self.draft_model, "config", None) + + # Fast path: some draft models (e.g. z-lab/Qwen3-8B-DFlash-b16) ship a self-contained + # `spec_generate` method. Delegate when available — it's the upstream-canonical loop and + # avoids re-implementing rollback. Newer drafts (Qwen3.5-4B-DFlash) drop this method, so + # fall back to the explicit pipeline loop below. 
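+        # Note: the delegated path returns directly, so `callback_on_step_end`, the progress bar, and the
+        # scheduler-driven verification loop below are not exercised when `spec_generate` is available.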
+ spec_generate = getattr(self.draft_model, "spec_generate", None) + if callable(spec_generate): + generated = spec_generate( + input_ids=input_ids, + max_new_tokens=int(max_new_tokens), + temperature=float(temperature), + target=self.target_model, + stop_token_ids=stop_token_ids, + ) + sequences = generated[:, input_ids.shape[1] :] + texts = None + if output_type == "text" and getattr(self, "tokenizer", None) is not None: + texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True) + if not return_dict: + return sequences, texts + return DFlashPipelineOutput(sequences=sequences, texts=texts) + + # Pass `config=` only when it looks like a real PretrainedConfig — hybrid-attention models + # (Qwen3.5) need it so `DynamicCache` instantiates the right per-layer cache types + # (linear vs full), but bare dummy configs in tests don't implement `get_text_config`. + def _new_cache(cfg): + if cfg is not None and hasattr(cfg, "get_text_config"): + try: + return DynamicCache(config=cfg) + except Exception: + pass + return DynamicCache() + + past_key_values_target = _new_cache(target_config) + past_key_values_draft = _new_cache(draft_config) + + # 4. Prefill step + output = self._target_forward( + input_ids=input_ids, + position_ids=position_ids[:, :num_input_tokens], + past_key_values=past_key_values_target, + output_hidden_states=True, + logits_to_keep=1, + ) + output_ids[:, :num_input_tokens] = input_ids + output_ids[:, num_input_tokens : num_input_tokens + 1] = self.scheduler.sample( + output.logits[:, -1:], temperature=temperature + ) + target_hidden = _extract_context_feature(output.hidden_states, target_layer_ids) + + start = num_input_tokens + global_step = 0 + num_blocks = (max_length - num_input_tokens + block_size - 1) // block_size + + # 5. Block-wise speculative decoding loop + block_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy() + block_progress_bar_config["position"] = 0 + block_progress_bar_config["desc"] = "Blocks" + block_iter = tqdm(range(num_blocks), **block_progress_bar_config) + + for _block_idx in block_iter: + if start >= max_length: + break + + block_output_ids = output_ids[:, start : start + int(block_size)].clone() + block_position_ids = position_ids[:, start : start + int(block_size)] + noise_embedding = input_embeddings(block_output_ids) + draft_hidden = self.draft_model( + target_hidden=target_hidden, + noise_embedding=noise_embedding, + position_ids=position_ids[:, past_key_values_draft.get_seq_length() : start + int(block_size)], + past_key_values=past_key_values_draft, + use_cache=True, + is_causal=False, + ) + if not torch.is_tensor(draft_hidden): + draft_hidden = getattr(draft_hidden, "last_hidden_state", draft_hidden[0]) + draft_logits = output_embeddings(draft_hidden[:, -int(block_size) + 1 :, :]) + past_key_values_draft.crop(start) + block_output_ids[:, 1:] = self.scheduler.sample(draft_logits, temperature=temperature) + + # For hybrid-attention targets (Qwen3.5 etc.), linear-attention cache layers silently + # no-op on `.crop()`, so rejected speculative tokens would permanently contaminate the + # recurrent state. Snapshot before the verify forward so we can roll back on partial-accept. 
+ target_needs_rollback = self.scheduler.cache_has_linear_attention(past_key_values_target) + target_snapshot = self.scheduler.snapshot_cache(past_key_values_target) if target_needs_rollback else None + + output = self._target_forward( + input_ids=block_output_ids, + position_ids=block_position_ids, + past_key_values=past_key_values_target, + output_hidden_states=True, + logits_to_keep=None, + ) + step_output = self.scheduler.step( + model_output=output.logits, + timestep=global_step, + sample=block_output_ids, + temperature=temperature, + return_dict=True, + ) + accepted_length = step_output.accepted_length + next_token = step_output.next_token + acceptance_length = int(step_output.accepted_length[0].item()) + output_ids[:, start : start + acceptance_length + 1] = block_output_ids[:, : acceptance_length + 1] + output_ids[:, start + acceptance_length + 1] = step_output.next_token + start += acceptance_length + 1 + partial_accept = acceptance_length + 1 < int(block_size) + if target_needs_rollback and partial_accept: + # Restore linear-attn recurrent state (and full-attn KVs) to pre-verify, then re-run + # target on just the accepted prefix to advance all layer types cleanly to `start`. + self.scheduler.restore_cache(past_key_values_target, target_snapshot) + accepted_ids = block_output_ids[:, : acceptance_length + 1] + accepted_pos = block_position_ids[:, : acceptance_length + 1] + self._target_forward( + input_ids=accepted_ids, + position_ids=accepted_pos, + past_key_values=past_key_values_target, + output_hidden_states=False, + logits_to_keep=1, + ) + elif not target_needs_rollback: + # Full-attn-only cache: cheap crop is fine. + past_key_values_target.crop(start) + target_hidden = _extract_context_feature(output.hidden_states, target_layer_ids)[ + :, : acceptance_length + 1, : + ] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, global_step, 0, callback_kwargs) + output_ids = callback_outputs.pop("output_ids", output_ids) + global_step += 1 + + if self.scheduler.check_should_stop(output_ids, stop_token_ids, num_input_tokens): + break + + # 6. 
Post-process output + output_ids = output_ids[:, :max_length] + output_ids = output_ids[:, output_ids[0] != int(mask_token_id)] + if stop_token_ids is not None: + stop_tensor = torch.tensor(stop_token_ids, device=device, dtype=torch.long) + stop_positions = torch.isin(output_ids[0, num_input_tokens:], stop_tensor).nonzero(as_tuple=True)[0] + if stop_positions.numel() > 0: + output_ids = output_ids[:, : num_input_tokens + int(stop_positions[0].item()) + 1] + + prompt_len = input_ids.shape[1] + sequences = output_ids[:, prompt_len:] + + texts = None + if output_type == "text" and getattr(self, "tokenizer", None) is not None: + texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True) + + if not return_dict: + return sequences, texts + return DFlashPipelineOutput(sequences=sequences, texts=texts) + + def _get_block_size(self) -> int: + cfg = getattr(self.draft_model, "config", None) + block_size = getattr(cfg, "block_size", None) + if block_size is None: + raise ValueError("`draft_model.config` must define `block_size`.") + return int(block_size) + + def _target_forward( + self, + *, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + past_key_values: DynamicCache, + output_hidden_states: bool, + logits_to_keep: int | None, + ): + kwargs = { + "input_ids": input_ids, + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": True, + "output_hidden_states": output_hidden_states, + } + if logits_to_keep is not None: + try: + return self.target_model(**kwargs, logits_to_keep=logits_to_keep) + except TypeError: + pass + return self.target_model(**kwargs) + + +__all__ = ["DFlashPipeline", "DFlashPipelineOutput"] diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index b1f75bed7dc5..6321d189fc17 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -51,6 +51,10 @@ _import_structure["scheduling_ddpm_parallel"] = ["DDPMParallelScheduler"] _import_structure["scheduling_ddpm_wuerstchen"] = ["DDPMWuerstchenScheduler"] _import_structure["scheduling_deis_multistep"] = ["DEISMultistepScheduler"] + _import_structure["scheduling_dflash_token_diffusion"] = [ + "DFlashTokenDiffusionScheduler", + "DFlashTokenDiffusionSchedulerOutput", + ] _import_structure["scheduling_dpm_cogvideox"] = ["CogVideoXDPMScheduler"] _import_structure["scheduling_dpmsolver_multistep"] = ["DPMSolverMultistepScheduler"] _import_structure["scheduling_dpmsolver_multistep_inverse"] = ["DPMSolverMultistepInverseScheduler"] @@ -157,6 +161,10 @@ from .scheduling_ddpm_parallel import DDPMParallelScheduler from .scheduling_ddpm_wuerstchen import DDPMWuerstchenScheduler from .scheduling_deis_multistep import DEISMultistepScheduler + from .scheduling_dflash_token_diffusion import ( + DFlashTokenDiffusionScheduler, + DFlashTokenDiffusionSchedulerOutput, + ) from .scheduling_dpm_cogvideox import CogVideoXDPMScheduler from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler diff --git a/src/diffusers/schedulers/scheduling_dflash_token_diffusion.py b/src/diffusers/schedulers/scheduling_dflash_token_diffusion.py new file mode 100644 index 000000000000..e90b7cfde09f --- /dev/null +++ b/src/diffusers/schedulers/scheduling_dflash_token_diffusion.py @@ -0,0 +1,277 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class DFlashTokenDiffusionSchedulerOutput(BaseOutput): + """ + Output class for DFlash-style speculative token scheduling. + + Args: + prev_sample (`torch.LongTensor` of shape `(batch_size, block_size)`): + The proposed block tokens from the draft model. + accepted_length (`torch.LongTensor` of shape `(batch_size,)`): + Number of consecutive accepted tokens from the block. + next_token (`torch.LongTensor` of shape `(batch_size,)`): + Next token sampled from the target posterior at the first rejection. + posterior (`torch.LongTensor` of shape `(batch_size, block_size)`): + Sampled tokens from the target posterior used for acceptance checks. + """ + + prev_sample: torch.LongTensor + accepted_length: torch.LongTensor + next_token: torch.LongTensor + posterior: torch.LongTensor + + +class DFlashTokenDiffusionScheduler(SchedulerMixin, ConfigMixin): + """ + Scheduler for DFlash-style block diffusion speculative decoding. + + This scheduler samples target posteriors and computes acceptance lengths for draft blocks. + """ + + order = 1 + + @register_to_config + def __init__(self): + self.num_inference_steps = 1 + self.timesteps = torch.tensor([0], dtype=torch.long) + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + self.num_inference_steps = int(num_inference_steps) + self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, device=device, dtype=torch.long) + + def sample(self, logits: torch.Tensor, temperature: float = 0.0) -> torch.LongTensor: + if temperature < 1e-5: + return torch.argmax(logits, dim=-1) + bsz, seq_len, vocab_size = logits.shape + flat = logits.view(-1, vocab_size) / float(temperature) + probs = torch.softmax(flat, dim=-1) + return torch.multinomial(probs, num_samples=1).view(bsz, seq_len) + + def step( + self, + model_output: torch.Tensor, + timestep: int | torch.Tensor, + sample: torch.LongTensor, + *, + temperature: float = 0.0, + return_dict: bool = True, + ) -> ( + DFlashTokenDiffusionSchedulerOutput + | tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor] + ): + """ + Perform a single speculative decoding verification step. + + Args: + model_output (`torch.Tensor` of shape `(batch_size, block_size, vocab_size)`): + Raw logits from the target model for the current block. + timestep (`int` or `torch.Tensor`): + Current step index (unused for single-step DFlash, kept for interface compatibility). + sample (`torch.LongTensor` of shape `(batch_size, block_size)`): + Draft token IDs proposed by the draft model. 
+ temperature (`float`): + Sampling temperature for the target posterior. + return_dict (`bool`): + Whether to return a `DFlashTokenDiffusionSchedulerOutput` or a tuple. + """ + posterior = self.sample(model_output, temperature=temperature) + if sample.shape[1] > 1: + matches = sample[:, 1:] == posterior[:, :-1] + accepted_length = matches.int().cumprod(dim=1).sum(dim=1) + else: + accepted_length = torch.zeros((sample.shape[0],), device=sample.device, dtype=torch.long) + + next_token = posterior.gather(1, accepted_length.unsqueeze(1)).squeeze(1) + + if not return_dict: + return sample, accepted_length, next_token, posterior + return DFlashTokenDiffusionSchedulerOutput( + prev_sample=sample, + accepted_length=accepted_length, + next_token=next_token, + posterior=posterior, + ) + + @staticmethod + def cache_has_linear_attention(cache) -> bool: + """ + Detect whether a `DynamicCache` contains any linear-attention layers (e.g. Qwen3.5's gated-delta-net layers). + The spec-decoding loop needs this to know whether a partial-accept block requires snapshot/restore rather than + a plain `.crop()` — transformers' `DynamicCache.crop()` silently no-ops on linear-attention layers, so rejected + speculative tokens would otherwise permanently contaminate the recurrent state. + + Duck-typed on `recurrent_states`/`conv_states` attributes to avoid importing transformers. + """ + for layer in getattr(cache, "layers", []): + if hasattr(layer, "recurrent_states") and hasattr(layer, "conv_states"): + return True + return False + + @staticmethod + def snapshot_cache(cache) -> list[dict]: + """ + Clone the full per-layer cache state so a speculative target forward can be rolled back. + + Handles both full-attention `DynamicLayer` (keys/values) and linear-attention layers + (conv_states/recurrent_states plus their init flags). Mirrors upstream DFlash's MLX `_GDNStateCapture` + rollback, but via full-layer restore rather than kernel-level replay. Pair with `restore_cache()`; no-op if the + caller only ever fully-accepts. + """ + snapshots: list[dict] = [] + for layer in getattr(cache, "layers", []): + snap: dict = {"cls": type(layer)} + if hasattr(layer, "keys") and layer.keys is not None: + snap["keys"] = layer.keys.clone() + snap["values"] = layer.values.clone() + if hasattr(layer, "recurrent_states"): + snap["has_previous_state"] = bool(getattr(layer, "has_previous_state", False)) + snap["is_recurrent_states_initialized"] = bool( + getattr(layer, "is_recurrent_states_initialized", False) + ) + snap["is_conv_states_initialized"] = bool(getattr(layer, "is_conv_states_initialized", False)) + snap["recurrent_states"] = ( + layer.recurrent_states.clone() if getattr(layer, "recurrent_states", None) is not None else None + ) + snap["conv_states"] = ( + layer.conv_states.clone() if getattr(layer, "conv_states", None) is not None else None + ) + snapshots.append(snap) + return snapshots + + @staticmethod + def restore_cache(cache, snapshots: list[dict]) -> None: + """ + Restore a cache to the state captured by `snapshot_cache()`. After this call, the caller should re-advance the + cache (e.g. by re-running the target model on just the accepted prefix) so both full- and linear-attention + layers end up at the committed token count. + """ + for layer, snap in zip(cache.layers, snapshots): + if "keys" in snap: + # DynamicLayer: reassign (shapes will have grown during the verify forward, so + # in-place copy is not safe here). 
+ layer.keys = snap["keys"] + layer.values = snap["values"] + if "recurrent_states" in snap: + # LinearAttentionLayer: in-place copy preserves any static-address assumption + # (e.g. for cudagraph capture) on the live tensors. + layer.has_previous_state = snap["has_previous_state"] + layer.is_recurrent_states_initialized = snap["is_recurrent_states_initialized"] + layer.is_conv_states_initialized = snap["is_conv_states_initialized"] + if snap["recurrent_states"] is not None and getattr(layer, "recurrent_states", None) is not None: + layer.recurrent_states.copy_(snap["recurrent_states"]) + elif snap["recurrent_states"] is not None: + layer.recurrent_states = snap["recurrent_states"].clone() + if snap["conv_states"] is not None and getattr(layer, "conv_states", None) is not None: + layer.conv_states.copy_(snap["conv_states"]) + elif snap["conv_states"] is not None: + layer.conv_states = snap["conv_states"].clone() + + @staticmethod + def check_should_stop( + output_ids: torch.LongTensor, + stop_token_ids: list[int] | None, + num_input_tokens: int, + ) -> bool: + """ + Check whether any stop token has been generated in the output sequence. + + Args: + output_ids (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Current output token IDs including prompt and generated tokens. + stop_token_ids (`list[int]` or `None`): + Token IDs that signal generation should stop. + num_input_tokens (`int`): + Number of prompt tokens at the start of the sequence. + + Returns: + `bool`: `True` if generation should stop, `False` otherwise. + """ + if stop_token_ids is None: + return False + stop_tensor = torch.tensor(stop_token_ids, device=output_ids.device, dtype=torch.long) + return torch.isin(output_ids[:, num_input_tokens:], stop_tensor).any().item() + + def add_noise( + self, + original_samples: torch.LongTensor, + attention_mask: torch.LongTensor, + *, + prompt_length: int, + block_size: int, + mask_token_id: int, + generator: torch.Generator | None = None, + ) -> tuple[torch.LongTensor, torch.BoolTensor]: + """ + Apply the forward (noising) process for DFlash-style block diffusion training. + + For each block after the prompt, a random fraction of valid (non-padding) tokens are replaced with + `mask_token_id`. + + Args: + original_samples (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Clean token IDs. + attention_mask (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Padding mask (1 for valid, 0 for padding). + prompt_length (`int`): + Number of leading prompt tokens to keep unmasked. + block_size (`int`): + Block size for masking. + mask_token_id (`int`): + Token ID to use for masked positions. + generator (`torch.Generator`, *optional*): + RNG for reproducibility. + + Returns: + `tuple[torch.LongTensor, torch.BoolTensor]`: + `(noisy, masked)` -- the noisy sequence and the boolean mask indicating which positions were masked. 
+ """ + batch_size, seq_len = original_samples.shape + device = original_samples.device + + noisy = original_samples.clone() + masked = torch.zeros_like(original_samples, dtype=torch.bool) + + valid = attention_mask.to(dtype=torch.bool) + for block_start in range(prompt_length, seq_len, block_size): + block_end = min(seq_len, block_start + block_size) + seg_len = block_end - block_start + if seg_len <= 0: + continue + + p_mask = torch.rand((batch_size, 1), device=device, generator=generator) + seg = torch.rand((batch_size, seg_len), device=device, generator=generator) < p_mask + seg = seg & valid[:, block_start:block_end] + + masked[:, block_start:block_end] = seg + + noisy = torch.where(masked, torch.full_like(noisy, mask_token_id), noisy) + return noisy, masked + + +__all__ = ["DFlashTokenDiffusionScheduler", "DFlashTokenDiffusionSchedulerOutput"] diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 9bfb73c1999e..792942dd1928 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2882,6 +2882,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class DFlashTokenDiffusionScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DFlashTokenDiffusionSchedulerOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DPMSolverMultistepInverseScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index cfa1318783f3..9ab1581c0045 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1337,6 +1337,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class DFlashPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class DFlashPipelineOutput(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class EasyAnimateControlPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/dflash/__init__.py b/tests/pipelines/dflash/__init__.py new file mode 100644 index 000000000000..8b137891791f --- /dev/null +++ b/tests/pipelines/dflash/__init__.py @@ -0,0 +1 @@ + diff 
--git a/tests/pipelines/dflash/test_dflash.py b/tests/pipelines/dflash/test_dflash.py new file mode 100644 index 000000000000..7f1f11011c18 --- /dev/null +++ b/tests/pipelines/dflash/test_dflash.py @@ -0,0 +1,448 @@ +import unittest + +import torch + +from diffusers import DFlashPipeline, DFlashTokenDiffusionScheduler + + +class _DummyModelOutput: + def __init__(self, logits, hidden_states=None): + self.logits = logits + self.hidden_states = hidden_states + + +class _DummyConfig: + def __init__(self, block_size, num_target_layers, num_hidden_layers): + self.block_size = block_size + self.num_target_layers = num_target_layers + self.num_hidden_layers = num_hidden_layers + + +class _DummyTargetModel(torch.nn.Module): + """Minimal target (causal LM) model that returns logits and hidden_states.""" + + def __init__(self, vocab_size: int, hidden_dim: int, num_layers: int): + super().__init__() + self.vocab_size = vocab_size + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.embed = torch.nn.Embedding(vocab_size, hidden_dim) + self.lm_head = torch.nn.Linear(hidden_dim, vocab_size, bias=False) + + def get_input_embeddings(self): + return self.embed + + def get_output_embeddings(self): + return self.lm_head + + def forward( + self, + input_ids, + position_ids=None, + past_key_values=None, + use_cache=False, + output_hidden_states=False, + logits_to_keep=None, + **kwargs, + ): + bsz, seq_len = input_ids.shape + h = self.embed(input_ids) + # Create hidden_states list: one entry per layer + 1 for the embedding layer + hidden_states = [h] * (self.num_layers + 1) if output_hidden_states else None + logits = self.lm_head(h) + # Make token 0 the most likely so acceptance is deterministic + logits[:, :, 0] = 10.0 + return _DummyModelOutput(logits=logits, hidden_states=hidden_states) + + def parameters(self): + return super().parameters() + + +class _DummyDraftModel(torch.nn.Module): + """Minimal draft model that returns hidden states of the expected shape.""" + + def __init__(self, hidden_dim: int, num_target_layers: int, block_size: int): + super().__init__() + self.block_size = block_size + self.config = _DummyConfig( + block_size=block_size, + num_target_layers=num_target_layers, + num_hidden_layers=1, + ) + # The draft model receives concatenated hidden states from num_target_layers target layers, + # each of dim hidden_dim, and produces a hidden state of dim hidden_dim. + self.proj = torch.nn.Linear(hidden_dim * num_target_layers, hidden_dim, bias=False) + self.register_buffer("_device_anchor", torch.empty(0)) + + @property + def device(self): + return self._device_anchor.device + + def forward( + self, + target_hidden, + noise_embedding, + position_ids=None, + past_key_values=None, + use_cache=False, + is_causal=False, + **kwargs, + ): + # Return a tensor with shape (batch, seq_len, hidden_dim) + bsz = noise_embedding.shape[0] + seq_len = position_ids.shape[1] if position_ids is not None else noise_embedding.shape[1] + h = torch.zeros(bsz, seq_len, self.proj.out_features, device=noise_embedding.device) + return h + + +def _make_pipeline(tokenizer=None, vocab_size=32, hidden_dim=16, num_target_layers=4, block_size=4): + target = _DummyTargetModel(vocab_size=vocab_size, hidden_dim=hidden_dim, num_layers=num_target_layers) + draft = _DummyDraftModel(hidden_dim=hidden_dim, num_target_layers=1, block_size=block_size) + # Set target_layer_ids directly so we skip the config-based computation. 
+ draft.target_layer_ids = [1] + scheduler = DFlashTokenDiffusionScheduler() + return DFlashPipeline(draft_model=draft, target_model=target, tokenizer=tokenizer, scheduler=scheduler) + + +class DFlashPipelineTest(unittest.TestCase): + # ------------------------------------------------------------------ + # Pipeline runs + # ------------------------------------------------------------------ + def test_pipeline_runs_with_input_ids(self): + pipe = _make_pipeline() + input_ids = torch.tensor([[5, 6, 7, 8]], dtype=torch.long) + + out = pipe( + input_ids=input_ids, + max_new_tokens=8, + temperature=0.0, + mask_token_id=31, + stop_token_ids=None, + output_type="seq", + ) + + self.assertIsNotNone(out.sequences) + self.assertEqual(out.sequences.ndim, 2) + self.assertEqual(out.sequences.shape[0], 1) + # Generated tokens should not be longer than max_new_tokens + self.assertLessEqual(out.sequences.shape[1], 8) + + # ------------------------------------------------------------------ + # output_type="seq" + # ------------------------------------------------------------------ + def test_output_type_seq(self): + pipe = _make_pipeline() + input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long) + + out = pipe( + input_ids=input_ids, + max_new_tokens=8, + temperature=0.0, + mask_token_id=31, + output_type="seq", + ) + + self.assertIsNotNone(out.sequences) + self.assertIsNone(out.texts) + + # ------------------------------------------------------------------ + # output_type="text" with mock tokenizer + # ------------------------------------------------------------------ + def test_output_type_text_with_tokenizer(self): + tok = type( + "Tok", + (), + { + "eos_token_id": None, + "mask_token_id": 31, + "batch_decode": lambda self, seqs, **kw: [f"decoded_{len(s)}" for s in seqs], + }, + )() + pipe = _make_pipeline(tokenizer=tok) + + out = pipe( + input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long), + max_new_tokens=8, + temperature=0.0, + output_type="text", + ) + + self.assertIsNotNone(out.sequences) + self.assertIsNotNone(out.texts) + self.assertEqual(len(out.texts), 1) + self.assertTrue(out.texts[0].startswith("decoded_")) + + def test_output_type_text_without_tokenizer(self): + """output_type='text' without a tokenizer should return texts=None.""" + pipe = _make_pipeline(tokenizer=None) + + out = pipe( + input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long), + max_new_tokens=8, + temperature=0.0, + mask_token_id=31, + output_type="text", + ) + + self.assertIsNotNone(out.sequences) + self.assertIsNone(out.texts) + + # ------------------------------------------------------------------ + # output_type invalid + # ------------------------------------------------------------------ + def test_output_type_invalid_raises(self): + pipe = _make_pipeline() + + with self.assertRaises(ValueError): + pipe( + input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long), + max_new_tokens=8, + mask_token_id=31, + output_type="invalid", + ) + + # ------------------------------------------------------------------ + # return_dict=False + # ------------------------------------------------------------------ + def test_pipeline_return_tuple(self): + pipe = _make_pipeline() + input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long) + + result = pipe( + input_ids=input_ids, + max_new_tokens=8, + temperature=0.0, + mask_token_id=31, + output_type="seq", + return_dict=False, + ) + + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 2) + sequences, texts = result + self.assertIsNotNone(sequences) + self.assertIsNone(texts) 
+ + # ------------------------------------------------------------------ + # check_inputs validation + # ------------------------------------------------------------------ + def test_check_inputs_no_inputs_raises(self): + pipe = _make_pipeline() + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt=None, + messages=None, + input_ids=None, + max_new_tokens=16, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_both_prompt_and_messages_raises(self): + pipe = _make_pipeline() + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt="hello", + messages=[{"role": "user", "content": "hi"}], + input_ids=None, + max_new_tokens=16, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_invalid_input_ids_ndim_raises(self): + pipe = _make_pipeline() + bad_ids = torch.zeros(2, 3, 4, dtype=torch.long) + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt=None, + messages=None, + input_ids=bad_ids, + max_new_tokens=16, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_invalid_input_ids_dtype_raises(self): + pipe = _make_pipeline() + bad_ids = torch.zeros(1, 4, dtype=torch.float32) + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt=None, + messages=None, + input_ids=bad_ids, + max_new_tokens=16, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_invalid_max_new_tokens_raises(self): + pipe = _make_pipeline() + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt=None, + messages=None, + input_ids=torch.tensor([[1, 2]], dtype=torch.long), + max_new_tokens=0, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_invalid_output_type_raises(self): + pipe = _make_pipeline() + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt=None, + messages=None, + input_ids=torch.tensor([[1, 2]], dtype=torch.long), + max_new_tokens=16, + output_type="bad", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_prompt_without_tokenizer_raises(self): + pipe = _make_pipeline(tokenizer=None) + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt="hello", + messages=None, + input_ids=None, + max_new_tokens=16, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_messages_without_tokenizer_raises(self): + pipe = _make_pipeline(tokenizer=None) + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt=None, + messages=[{"role": "user", "content": "hi"}], + input_ids=None, + max_new_tokens=16, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_valid_input_ids_passes(self): + pipe = _make_pipeline() + # Should not raise. 
+ pipe.check_inputs( + prompt=None, + messages=None, + input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long), + max_new_tokens=16, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + # ------------------------------------------------------------------ + # _prepare_input_ids + # ------------------------------------------------------------------ + def test_prepare_input_ids_from_tensor(self): + pipe = _make_pipeline() + ids = torch.tensor([[1, 2, 3]], dtype=torch.long) + result = pipe._prepare_input_ids( + prompt=None, + messages=None, + input_ids=ids, + use_chat_template=False, + add_generation_prompt=False, + chat_template_kwargs=None, + ) + self.assertTrue(torch.equal(result, ids)) + + def test_prepare_input_ids_from_1d_tensor(self): + pipe = _make_pipeline() + ids = torch.tensor([1, 2, 3], dtype=torch.long) + result = pipe._prepare_input_ids( + prompt=None, + messages=None, + input_ids=ids, + use_chat_template=False, + add_generation_prompt=False, + chat_template_kwargs=None, + ) + self.assertEqual(result.shape, (1, 3)) + + # ------------------------------------------------------------------ + # prepare_latents + # ------------------------------------------------------------------ + def test_prepare_latents(self): + pipe = _make_pipeline() + mask_token_id = 99 + latents = pipe.prepare_latents( + max_length=10, block_size=4, mask_token_id=mask_token_id, device=torch.device("cpu") + ) + self.assertEqual(latents.shape, (1, 14)) # 10 + 4 + self.assertTrue((latents == mask_token_id).all().item()) + self.assertEqual(latents.dtype, torch.long) + + +class DFlashRegressionTest(unittest.TestCase): + """Pin the bug patterns surfaced in https://github.com/huggingface/diffusers/issues/13598 + (LLaDA2 review) for any that are relevant to DFlash. + + DFlash is batch_size=1 only and does not pass an `attention_mask` to the target model, so + issues #1 (padding mask), #2 (block_length scheduler routing), and #5 (batched EOS row freeze) + don't apply. The applicable patterns are #3 (callback keys must resolve), #4 (EOS at the first + generated position), and #6 (progress-bar config must not be mutated by `__call__`). + """ + + def test_callback_tensor_inputs_advertised_keys_resolve(self): + """Issue #3: every advertised callback key must be a bound local at callback time.""" + observed: list[str] = [] + + def cb(pipe, step, timestep, kwargs): + observed.extend(sorted(kwargs.keys())) + return {} + + pipe = _make_pipeline() + keys = list(pipe._callback_tensor_inputs) + pipe( + input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long), + max_new_tokens=8, + temperature=0.0, + mask_token_id=31, + stop_token_ids=None, + output_type="seq", + callback_on_step_end=cb, + callback_on_step_end_tensor_inputs=keys, + ) + self.assertEqual(set(observed), set(keys)) + + def test_stop_token_at_first_generated_position_triggers_stop(self): + """Issue #4 analogue: a stop token at index `num_input_tokens` (the first generated + position) must terminate generation. Verified at the scheduler level — `check_should_stop` + searches positions starting at `num_input_tokens`, inclusive.""" + # Sequence layout: prompt = [1, 2] (length 2), first generated token (index 2) is the stop. 
+ output_ids = torch.tensor([[1, 2, 99, 0, 0]], dtype=torch.long) + self.assertTrue(DFlashTokenDiffusionScheduler.check_should_stop(output_ids, [99], 2)) + + def test_progress_bar_disable_is_preserved_after_call(self): + """Issue #6: calling the pipeline must not mutate `_progress_bar_config`.""" + pipe = _make_pipeline() + pipe.set_progress_bar_config(disable=True) + before = dict(pipe._progress_bar_config) + pipe( + input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long), + max_new_tokens=8, + temperature=0.0, + mask_token_id=31, + stop_token_ids=None, + output_type="seq", + ) + self.assertEqual(pipe._progress_bar_config, before) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/schedulers/test_scheduler_dflash_token_diffusion.py b/tests/schedulers/test_scheduler_dflash_token_diffusion.py new file mode 100644 index 000000000000..560571907d77 --- /dev/null +++ b/tests/schedulers/test_scheduler_dflash_token_diffusion.py @@ -0,0 +1,310 @@ +import tempfile +import unittest + +import torch + +from diffusers import DFlashTokenDiffusionScheduler + + +class DFlashTokenDiffusionSchedulerTest(unittest.TestCase): + def get_scheduler(self): + return DFlashTokenDiffusionScheduler() + + # ------------------------------------------------------------------ + # set_timesteps + # ------------------------------------------------------------------ + def test_set_timesteps(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(4) + self.assertEqual(scheduler.num_inference_steps, 4) + self.assertEqual(len(scheduler.timesteps), 4) + self.assertEqual(scheduler.timesteps[0].item(), 3) + self.assertEqual(scheduler.timesteps[-1].item(), 0) + + def test_set_timesteps_single(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(1) + self.assertEqual(scheduler.num_inference_steps, 1) + self.assertEqual(len(scheduler.timesteps), 1) + self.assertEqual(scheduler.timesteps[0].item(), 0) + + def test_set_timesteps_invalid(self): + scheduler = self.get_scheduler() + with self.assertRaises(ValueError): + scheduler.set_timesteps(0) + with self.assertRaises(ValueError): + scheduler.set_timesteps(-1) + + # ------------------------------------------------------------------ + # Config round-trip + # ------------------------------------------------------------------ + def test_save_load_config_round_trip(self): + scheduler = self.get_scheduler() + with tempfile.TemporaryDirectory() as tmpdir: + scheduler.save_config(tmpdir) + loaded = DFlashTokenDiffusionScheduler.from_pretrained(tmpdir) + # The scheduler has no user-configurable params, but it should survive the round-trip. 
+ self.assertIsInstance(loaded, DFlashTokenDiffusionScheduler) + self.assertEqual(loaded.order, 1) + + def test_from_config(self): + scheduler = self.get_scheduler() + new_scheduler = DFlashTokenDiffusionScheduler.from_config(scheduler.config) + self.assertIsInstance(new_scheduler, DFlashTokenDiffusionScheduler) + self.assertEqual(new_scheduler.order, 1) + + # ------------------------------------------------------------------ + # sample() – greedy + # ------------------------------------------------------------------ + def test_sample_greedy(self): + scheduler = self.get_scheduler() + logits = torch.tensor([[[1.0, 5.0, 2.0], [3.0, 1.0, 4.0]]]) # (1, 2, 3) + tokens = scheduler.sample(logits, temperature=0.0) + self.assertEqual(tokens.shape, (1, 2)) + self.assertEqual(tokens[0, 0].item(), 1) # argmax of [1,5,2] + self.assertEqual(tokens[0, 1].item(), 2) # argmax of [3,1,4] + + def test_sample_greedy_batched(self): + scheduler = self.get_scheduler() + logits = torch.tensor( + [ + [[10.0, 0.0], [0.0, 10.0]], + [[0.0, 10.0], [10.0, 0.0]], + ] + ) # (2, 2, 2) + tokens = scheduler.sample(logits, temperature=0.0) + self.assertEqual(tokens.shape, (2, 2)) + self.assertEqual(tokens[0, 0].item(), 0) + self.assertEqual(tokens[0, 1].item(), 1) + self.assertEqual(tokens[1, 0].item(), 1) + self.assertEqual(tokens[1, 1].item(), 0) + + # ------------------------------------------------------------------ + # sample() – multinomial + # ------------------------------------------------------------------ + def test_sample_multinomial(self): + scheduler = self.get_scheduler() + # One token has overwhelming probability; multinomial should pick it. + logits = torch.tensor([[[0.0, 100.0, -100.0]]]) # (1, 1, 3) + tokens = scheduler.sample(logits, temperature=1.0) + self.assertEqual(tokens.shape, (1, 1)) + self.assertEqual(tokens[0, 0].item(), 1) + + # ------------------------------------------------------------------ + # step() – return dict + # ------------------------------------------------------------------ + def test_step_all_accepted(self): + """All draft tokens match the posterior => accepted_length == block_size - 1.""" + scheduler = self.get_scheduler() + batch_size, block_size, vocab_size = 1, 4, 8 + + # Draft tokens: [0, 3, 3, 3] + draft_tokens = torch.tensor([[0, 3, 3, 3]], dtype=torch.long) + # Target logits: make argmax = [3, 3, 3, X] so posterior[:, :-1] matches draft[:, 1:] + logits = torch.zeros(batch_size, block_size, vocab_size) + logits[:, 0, 3] = 10.0 + logits[:, 1, 3] = 10.0 + logits[:, 2, 3] = 10.0 + logits[:, 3, 5] = 10.0 # last posterior token (next_token candidate) + + out = scheduler.step(logits, 0, draft_tokens, temperature=0.0, return_dict=True) + + self.assertEqual(out.prev_sample.shape, (1, 4)) + self.assertEqual(out.accepted_length.shape, (1,)) + self.assertEqual(out.accepted_length[0].item(), 3) # all 3 comparisons match + self.assertEqual(out.next_token.shape, (1,)) + self.assertEqual(out.next_token[0].item(), 5) + self.assertEqual(out.posterior.shape, (1, 4)) + + def test_step_none_accepted(self): + """First draft token already mismatches => accepted_length == 0.""" + scheduler = self.get_scheduler() + batch_size, block_size, vocab_size = 1, 4, 8 + + draft_tokens = torch.tensor([[0, 1, 2, 3]], dtype=torch.long) + logits = torch.zeros(batch_size, block_size, vocab_size) + logits[:, 0, 5] = 10.0 # posterior[0] = 5, but draft[1] = 1 => mismatch + logits[:, 1, 2] = 10.0 + logits[:, 2, 3] = 10.0 + logits[:, 3, 4] = 10.0 + + out = scheduler.step(logits, 0, draft_tokens, temperature=0.0, 
return_dict=True) + + self.assertEqual(out.accepted_length[0].item(), 0) + self.assertEqual(out.next_token[0].item(), 5) # posterior at index 0 + + def test_step_partial_accepted(self): + """First two match, third does not => accepted_length == 2.""" + scheduler = self.get_scheduler() + batch_size, block_size, vocab_size = 1, 5, 8 + + # draft: [0, 3, 4, 7, 2] + draft_tokens = torch.tensor([[0, 3, 4, 7, 2]], dtype=torch.long) + logits = torch.zeros(batch_size, block_size, vocab_size) + logits[:, 0, 3] = 10.0 # match draft[1]=3 + logits[:, 1, 4] = 10.0 # match draft[2]=4 + logits[:, 2, 0] = 10.0 # mismatch draft[3]=7 + logits[:, 3, 2] = 10.0 + logits[:, 4, 6] = 10.0 + + out = scheduler.step(logits, 0, draft_tokens, temperature=0.0, return_dict=True) + + self.assertEqual(out.accepted_length[0].item(), 2) + self.assertEqual(out.next_token[0].item(), 0) # posterior at index 2 + + def test_step_single_token_block(self): + """Block with a single token => accepted_length == 0.""" + scheduler = self.get_scheduler() + draft_tokens = torch.tensor([[5]], dtype=torch.long) + logits = torch.zeros(1, 1, 8) + logits[:, 0, 3] = 10.0 + + out = scheduler.step(logits, 0, draft_tokens, temperature=0.0, return_dict=True) + self.assertEqual(out.accepted_length[0].item(), 0) + self.assertEqual(out.next_token[0].item(), 3) + + # ------------------------------------------------------------------ + # step() – return tuple + # ------------------------------------------------------------------ + def test_step_return_tuple(self): + scheduler = self.get_scheduler() + draft_tokens = torch.tensor([[0, 1, 2]], dtype=torch.long) + logits = torch.randn(1, 3, 8) + + result = scheduler.step(logits, 0, draft_tokens, temperature=0.0, return_dict=False) + + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 4) + prev_sample, accepted_length, next_token, posterior = result + self.assertEqual(prev_sample.shape, (1, 3)) + self.assertEqual(accepted_length.shape, (1,)) + self.assertEqual(next_token.shape, (1,)) + self.assertEqual(posterior.shape, (1, 3)) + + # ------------------------------------------------------------------ + # step() – batched + # ------------------------------------------------------------------ + def test_step_batched(self): + scheduler = self.get_scheduler() + batch_size, block_size, vocab_size = 3, 4, 16 + draft_tokens = torch.randint(0, vocab_size, (batch_size, block_size)) + logits = torch.randn(batch_size, block_size, vocab_size) + + out = scheduler.step(logits, 0, draft_tokens, temperature=0.0, return_dict=True) + + self.assertEqual(out.prev_sample.shape, (batch_size, block_size)) + self.assertEqual(out.accepted_length.shape, (batch_size,)) + self.assertEqual(out.next_token.shape, (batch_size,)) + self.assertEqual(out.posterior.shape, (batch_size, block_size)) + + # ------------------------------------------------------------------ + # check_should_stop() + # ------------------------------------------------------------------ + def test_check_should_stop_no_stop_tokens(self): + output_ids = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long) + self.assertFalse(DFlashTokenDiffusionScheduler.check_should_stop(output_ids, None, 2)) + + def test_check_should_stop_found(self): + # Stop token 99 is in the generated portion (after num_input_tokens=2). 
+ output_ids = torch.tensor([[1, 2, 3, 99, 5]], dtype=torch.long) + self.assertTrue(DFlashTokenDiffusionScheduler.check_should_stop(output_ids, [99], 2)) + + def test_check_should_stop_only_in_prompt(self): + # Stop token 1 is only in the prompt portion => should NOT stop. + output_ids = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long) + self.assertFalse(DFlashTokenDiffusionScheduler.check_should_stop(output_ids, [1], 2)) + + def test_check_should_stop_multiple_stop_tokens(self): + output_ids = torch.tensor([[10, 20, 30, 40, 50]], dtype=torch.long) + self.assertTrue(DFlashTokenDiffusionScheduler.check_should_stop(output_ids, [40, 99], 2)) + self.assertFalse(DFlashTokenDiffusionScheduler.check_should_stop(output_ids, [99, 100], 2)) + + # ------------------------------------------------------------------ + # add_noise() + # ------------------------------------------------------------------ + def test_add_noise_prompt_preserved(self): + scheduler = self.get_scheduler() + original = torch.tensor([[10, 11, 12, 13, 14, 15, 16, 17]], dtype=torch.long) + attention_mask = torch.ones_like(original) + mask_token_id = 99 + prompt_length = 3 + + gen = torch.Generator().manual_seed(42) + noisy, masked = scheduler.add_noise( + original, + attention_mask, + prompt_length=prompt_length, + block_size=4, + mask_token_id=mask_token_id, + generator=gen, + ) + + # Prompt positions should never be masked. + self.assertFalse(masked[0, :prompt_length].any().item()) + # Prompt tokens should be unchanged. + self.assertTrue(torch.equal(noisy[0, :prompt_length], original[0, :prompt_length])) + + def test_add_noise_masked_positions(self): + scheduler = self.get_scheduler() + original = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.long) + attention_mask = torch.ones_like(original) + mask_token_id = 99 + + gen = torch.Generator().manual_seed(0) + noisy, masked = scheduler.add_noise( + original, + attention_mask, + prompt_length=2, + block_size=3, + mask_token_id=mask_token_id, + generator=gen, + ) + + # Where masked is True, noisy should equal mask_token_id. + self.assertTrue((noisy[masked] == mask_token_id).all().item()) + # Where masked is False, noisy should equal original. + self.assertTrue(torch.equal(noisy[~masked], original[~masked])) + + def test_add_noise_respects_attention_mask(self): + scheduler = self.get_scheduler() + original = torch.tensor([[1, 2, 3, 4, 0, 0]], dtype=torch.long) + attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]], dtype=torch.long) + mask_token_id = 99 + + gen = torch.Generator().manual_seed(42) + noisy, masked = scheduler.add_noise( + original, + attention_mask, + prompt_length=1, + block_size=3, + mask_token_id=mask_token_id, + generator=gen, + ) + + # Padding positions (attention_mask=0) should never be masked. 
+ self.assertFalse(masked[0, 4].item()) + self.assertFalse(masked[0, 5].item()) + + def test_add_noise_output_shapes(self): + scheduler = self.get_scheduler() + batch_size, seq_len = 2, 10 + original = torch.randint(0, 50, (batch_size, seq_len)) + attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long) + mask_token_id = 99 + + noisy, masked = scheduler.add_noise( + original, + attention_mask, + prompt_length=2, + block_size=4, + mask_token_id=mask_token_id, + ) + + self.assertEqual(noisy.shape, (batch_size, seq_len)) + self.assertEqual(masked.shape, (batch_size, seq_len)) + self.assertEqual(noisy.dtype, torch.long) + self.assertEqual(masked.dtype, torch.bool) + + +if __name__ == "__main__": + unittest.main() From 715ca5046d82f1128215cf965903fbc1f0d260fc Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 8 May 2026 10:09:55 +0000 Subject: [PATCH 2/8] [Docs] flesh out DFlash pipeline + scheduler pages --- docs/source/en/api/pipelines/dflash.md | 75 ++++++++++++++++++- .../api/schedulers/dflash_token_diffusion.md | 15 +++- 2 files changed, 86 insertions(+), 4 deletions(-) diff --git a/docs/source/en/api/pipelines/dflash.md b/docs/source/en/api/pipelines/dflash.md index 95847e1fdd82..b1caea59f82a 100644 --- a/docs/source/en/api/pipelines/dflash.md +++ b/docs/source/en/api/pipelines/dflash.md @@ -12,8 +12,79 @@ specific language governing permissions and limitations under the License. # DFlash -`DFlashPipeline` performs block-diffusion speculative decoding using a diffusion draft model and a target causal LM. -The draft model is conditioned on target hidden features extracted during prefill and verification steps. +[DFlash](https://huggingface.co/collections/z-lab/dflash) is a block-diffusion speculative decoding scheme. A small +diffusion *draft* model proposes a block of tokens conditioned on hidden features extracted from intermediate layers +of a frozen *target* causal LM; the target then verifies the proposed block in a single forward pass and accepts the +longest matching prefix. The draft model is shared with the target's tokenizer, so no calibration is needed. + +`DFlashPipeline` ties the two models together: prefill on the target, draft a block, verify against the target's +posterior via [`DFlashTokenDiffusionScheduler`], commit the accepted prefix and the next-token resample, and repeat +until `max_new_tokens` or a stop token. Compatible draft/target pairs include `z-lab/Qwen3-8B-DFlash-b16` with +`Qwen/Qwen3-8B`, and `z-lab/Qwen3.5-4B-DFlash` with `Qwen/Qwen3.5-4B` (the latter is a hybrid-attention target — see +the rollback note below). + +## Usage + +```py +import torch +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer + +from diffusers import DFlashPipeline + +draft = AutoModel.from_pretrained( + "z-lab/Qwen3.5-4B-DFlash", trust_remote_code=True, dtype=torch.bfloat16, device_map="auto" +) +target = AutoModelForCausalLM.from_pretrained( + "Qwen/Qwen3.5-4B", trust_remote_code=True, dtype=torch.bfloat16, device_map="auto" +) +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-4B", trust_remote_code=True) + +pipe = DFlashPipeline(draft_model=draft, target_model=target, tokenizer=tokenizer) +output = pipe( + prompt="What is 2 + 2? Answer in one sentence.", + max_new_tokens=128, + temperature=0.0, + chat_template_kwargs={"enable_thinking": False}, +) +print(output.texts[0]) +``` + +`DFlashPipeline` currently runs `batch_size=1` only. Multi-prompt batching requires per-row partial-accept tracking +and is not yet supported. 
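+
+Under the hood, each block is committed according to the scheduler's acceptance rule: the longest prefix of the
+drafted block that matches the target posterior, followed by the resampled `next_token`. A minimal, self-contained
+sketch of that rule on toy tensors (calling [`DFlashTokenDiffusionScheduler`] directly, which the pipeline otherwise
+does for you):
+
+```py
+import torch
+
+from diffusers import DFlashTokenDiffusionScheduler
+
+scheduler = DFlashTokenDiffusionScheduler()
+
+# Toy block: batch size 1, block size 4, vocabulary size 8.
+block_tokens = torch.tensor([[0, 3, 3, 3]])
+
+# Target logits over the same positions: the greedy posterior is [3, 3, 3, 5], so
+# block_tokens[:, 1:] matches posterior[:, :-1] and the whole draft is accepted.
+block_logits = torch.zeros(1, 4, 8)
+block_logits[0, :3, 3] = 10.0
+block_logits[0, 3, 5] = 10.0
+
+out = scheduler.step(block_logits, 0, block_tokens, temperature=0.0, return_dict=True)
+n = int(out.accepted_length.item())  # 3
+committed = torch.cat([block_tokens[:, 1 : 1 + n], out.next_token[:, None]], dim=1)
+print(committed)  # tensor([[3, 3, 3, 5]])
+```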
+ +## Hybrid-attention targets + +For target models with linear-attention layers (e.g. Qwen3.5's gated-delta-net), `DynamicCache.crop()` silently +no-ops on those layers, so a partial-accept block would otherwise leak rejected speculative tokens into the +recurrent state. The pipeline detects linear-attention caches via +[`DFlashTokenDiffusionScheduler.cache_has_linear_attention`] and uses a snapshot/restore + accepted-prefix +re-forward pattern to advance both layer types cleanly. This adds one extra target forward per partial-accept +block but is required for correctness. + +## Fast path + +When the draft model exposes a `spec_generate(...)` method (e.g. `z-lab/Qwen3-8B-DFlash-b16`), the pipeline +delegates to it — that loop is the upstream-canonical implementation and avoids re-running the rollback bookkeeping. +Newer drafts (`z-lab/Qwen3.5-4B-DFlash`) drop `spec_generate`; the pipeline falls back to its explicit verify loop. + +## Callbacks + +Callbacks run after each block-verify step. Pass `callback_on_step_end_tensor_inputs` to select which tensors are +included in `callback_kwargs`. Allowed keys: `block_output_ids` (the drafted block), `draft_logits`, +`accepted_length`, `next_token`, and `output_ids` (the running output buffer). Return `{"output_ids": ...}` from the +callback to replace the buffer. + +```py +def on_step_end(pipe, step, timestep, callback_kwargs): + output_ids = callback_kwargs["output_ids"] + return {"output_ids": output_ids} + +out = pipe( + prompt="...", + callback_on_step_end=on_step_end, + callback_on_step_end_tensor_inputs=["output_ids"], +) +``` ## DFlashPipeline [[autodoc]] DFlashPipeline diff --git a/docs/source/en/api/schedulers/dflash_token_diffusion.md b/docs/source/en/api/schedulers/dflash_token_diffusion.md index c98b11bc9963..faa5c2405c87 100644 --- a/docs/source/en/api/schedulers/dflash_token_diffusion.md +++ b/docs/source/en/api/schedulers/dflash_token_diffusion.md @@ -12,8 +12,19 @@ specific language governing permissions and limitations under the License. # DFlashTokenDiffusionScheduler -`DFlashTokenDiffusionScheduler` implements the acceptance and posterior sampling logic used in DFlash-style block -diffusion speculative decoding. +[`DFlashTokenDiffusionScheduler`] implements the verification step for DFlash-style block-diffusion speculative +decoding. It samples a posterior block from the target logits, computes the acceptance length as the longest prefix +where the draft proposal matches the posterior, and exposes the resampled `next_token` for the first rejected +position. Used by [`DFlashPipeline`]. + +The scheduler also owns three helpers used by the pipeline's verify loop on hybrid-attention targets: + +- `cache_has_linear_attention(cache)` — detect whether a `DynamicCache` contains any linear-attention layers. +- `snapshot_cache(cache)` / `restore_cache(cache, snapshot)` — clone and restore the full per-layer state so a + partial-accept block can be rolled back and the target re-advanced on just the accepted prefix. + +These exist because `DynamicCache.crop()` silently no-ops on linear-attention layers, which would otherwise let +rejected speculative tokens permanently contaminate the recurrent state. 
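+
+Because the helpers are duck-typed on layer attributes, the snapshot/restore round-trip can be sketched without a
+real model. The snippet below uses `types.SimpleNamespace` stand-ins for one full-attention layer and one
+linear-attention layer; in the pipeline these are the layers of the target's `DynamicCache`:
+
+```py
+import torch
+from types import SimpleNamespace
+
+from diffusers import DFlashTokenDiffusionScheduler
+
+full_attn = SimpleNamespace(keys=torch.randn(1, 2, 4, 8), values=torch.randn(1, 2, 4, 8))
+linear_attn = SimpleNamespace(
+    keys=None,
+    recurrent_states=torch.zeros(1, 4, 8, 8),
+    conv_states=torch.zeros(1, 8, 4),
+    has_previous_state=False,
+    is_recurrent_states_initialized=False,
+    is_conv_states_initialized=False,
+)
+cache = SimpleNamespace(layers=[full_attn, linear_attn])
+
+assert DFlashTokenDiffusionScheduler.cache_has_linear_attention(cache)
+
+snapshot = DFlashTokenDiffusionScheduler.snapshot_cache(cache)
+linear_attn.recurrent_states += 1.0  # stand-in for a speculative verify forward mutating the recurrent state
+DFlashTokenDiffusionScheduler.restore_cache(cache, snapshot)  # roll the rejected speculative tokens back
+assert torch.equal(linear_attn.recurrent_states, torch.zeros(1, 4, 8, 8))
+```
+
+In the real verify loop, after `restore_cache` the pipeline re-runs the target on just the accepted prefix so both
+layer types advance to the committed token count.
+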
## DFlashTokenDiffusionScheduler [[autodoc]] DFlashTokenDiffusionScheduler From 4c0e3dd594c4bbe4345d6a6b9cff518a835bf66a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 8 May 2026 10:41:18 +0000 Subject: [PATCH 3/8] [DFlash] fix train_dflash position_ids + clarify trust_remote_code - Training: `position_ids` must span `[0, start + block_size)` so the draft's attention RoPE cos/sin covers both `k_ctx` (target_hidden, length `start`) and `k_noise` (noise_embedding, length `block_size`). Previously we passed only `arange(start, start + block_size)` which triggered a K-side broadcast mismatch on the very first batch. - Docs/examples: target loads as plain Qwen3 / Qwen3.5 (no remote code), but the draft's custom DFlashDraftModel class lives in the Hub repo's `auto_map`, so `trust_remote_code=True` is required for draft loads only. Updated the example docstring, pipeline doc page, sample script, train script, and the GPU verify script. Smoke-tested via srun on z-lab/Qwen3.5-4B-DFlash + Qwen/Qwen3.5-4B (H100): 3 steps complete, final checkpoint saved. --- docs/source/en/api/pipelines/dflash.md | 7 +++---- examples/discrete_diffusion/sample_dflash.py | 2 +- examples/discrete_diffusion/train_dflash.py | 5 ++++- src/diffusers/pipelines/dflash/pipeline_dflash.py | 4 ++-- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/docs/source/en/api/pipelines/dflash.md b/docs/source/en/api/pipelines/dflash.md index b1caea59f82a..ab9153f5170e 100644 --- a/docs/source/en/api/pipelines/dflash.md +++ b/docs/source/en/api/pipelines/dflash.md @@ -31,13 +31,12 @@ from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer from diffusers import DFlashPipeline +# Draft ships custom modeling code via `auto_map` — `trust_remote_code=True` is required. draft = AutoModel.from_pretrained( "z-lab/Qwen3.5-4B-DFlash", trust_remote_code=True, dtype=torch.bfloat16, device_map="auto" ) -target = AutoModelForCausalLM.from_pretrained( - "Qwen/Qwen3.5-4B", trust_remote_code=True, dtype=torch.bfloat16, device_map="auto" -) -tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-4B", trust_remote_code=True) +target = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-4B", dtype=torch.bfloat16, device_map="auto") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-4B") pipe = DFlashPipeline(draft_model=draft, target_model=target, tokenizer=tokenizer) output = pipe( diff --git a/examples/discrete_diffusion/sample_dflash.py b/examples/discrete_diffusion/sample_dflash.py index a10899a0d052..b1a3088a971c 100644 --- a/examples/discrete_diffusion/sample_dflash.py +++ b/examples/discrete_diffusion/sample_dflash.py @@ -107,7 +107,7 @@ def main(): print(f"Loading draft model: {args.draft_model_id}") print(f"Loading target model: {args.target_model_id}") dtype_arg = torch_dtype if torch_dtype is not None else "auto" - # Draft model is a custom DFlashDraftModel; use AutoModel so trust_remote_code routes to the class in `auto_map`. + # Draft model is a custom DFlashDraftModel; trust_remote_code routes to the class in `auto_map`. 
draft_model = AutoModel.from_pretrained( args.draft_model_id, trust_remote_code=True, diff --git a/examples/discrete_diffusion/train_dflash.py b/examples/discrete_diffusion/train_dflash.py index 673a2173a058..6538a8db5f55 100644 --- a/examples/discrete_diffusion/train_dflash.py +++ b/examples/discrete_diffusion/train_dflash.py @@ -248,7 +248,10 @@ def main(): block_targets = input_ids[:, start + 1 : start + block_size] block_mask = attention_mask[:, start + 1 : start + block_size] - position_ids = torch.arange(start, start + block_size, device=input_ids.device).unsqueeze(0) + # The draft's attention concatenates `k_ctx` (target_hidden, length `start`) with + # `k_noise` (noise_embedding, length `block_size`); RoPE needs cos/sin covering the + # full range `[0, start + block_size)` so the K-side broadcast works. + position_ids = torch.arange(start + block_size, device=input_ids.device).unsqueeze(0) position_ids = position_ids.expand(input_ids.shape[0], -1) with torch.no_grad(): diff --git a/src/diffusers/pipelines/dflash/pipeline_dflash.py b/src/diffusers/pipelines/dflash/pipeline_dflash.py index e8b0276db109..94b81dd6b863 100644 --- a/src/diffusers/pipelines/dflash/pipeline_dflash.py +++ b/src/diffusers/pipelines/dflash/pipeline_dflash.py @@ -38,9 +38,9 @@ >>> from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer >>> draft = AutoModel.from_pretrained( - ... "z-lab/Qwen3-8B-DFlash-b16", trust_remote_code=True, torch_dtype=torch.bfloat16 + ... "z-lab/Qwen3-8B-DFlash-b16", trust_remote_code=True, dtype=torch.bfloat16 ... ) - >>> target = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B", torch_dtype=torch.bfloat16) + >>> target = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B", dtype=torch.bfloat16) >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") >>> pipe = DFlashPipeline(draft_model=draft, target_model=target, tokenizer=tokenizer) >>> out = pipe(prompt="How many positive whole-number divisors does 196 have?") From cd0ce7b6254a643d2a4b305550327cc3e9f40fe9 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 8 May 2026 11:19:08 +0000 Subject: [PATCH 4/8] [DFlash] remove spec_generate fast path; explicit loop handles all targets MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The pipeline previously short-circuited to `draft.spec_generate(...)` when the draft model exposed it (e.g. z-lab/Qwen3-8B-DFlash-b16). That path is the upstream `dflash_generate` loop, which calls `past_key_values_target.crop()` unconditionally — fine for full-attention targets, but on hybrid targets it silently corrupts the linear-attention recurrent state. Confirmed in transformers 5.8.0.dev0 at cache_utils.py:759-761: def crop(self, max_length: int): # We don't crop the linear attention cache, so simply do nothing here pass `LinearAttentionCacheLayerMixin.crop` is documented as a no-op, so any verify loop that relies on `cache.crop()` for rollback is wrong on hybrid attention targets. Our explicit loop already handles this via `DFlashTokenDiffusionScheduler.snapshot_cache` / `restore_cache` plus an accepted-prefix re-forward, and reduces to a plain `.crop()` on full-attn targets. Verified end-to-end on GPU after the removal: - z-lab/Qwen3.5-4B-DFlash + Qwen/Qwen3.5-4B (hybrid attn): "2 + 2 equals 4." - z-lab/Qwen3-8B-DFlash-b16 + Qwen/Qwen3-8B (full attn): "2 + 2 equals 4." Fast tests: 43 passed. 
--- docs/source/en/api/pipelines/dflash.md | 18 ++++++---------- .../pipelines/dflash/pipeline_dflash.py | 21 ------------------- 2 files changed, 6 insertions(+), 33 deletions(-) diff --git a/docs/source/en/api/pipelines/dflash.md b/docs/source/en/api/pipelines/dflash.md index ab9153f5170e..215c07530c9e 100644 --- a/docs/source/en/api/pipelines/dflash.md +++ b/docs/source/en/api/pipelines/dflash.md @@ -53,18 +53,12 @@ and is not yet supported. ## Hybrid-attention targets -For target models with linear-attention layers (e.g. Qwen3.5's gated-delta-net), `DynamicCache.crop()` silently -no-ops on those layers, so a partial-accept block would otherwise leak rejected speculative tokens into the -recurrent state. The pipeline detects linear-attention caches via -[`DFlashTokenDiffusionScheduler.cache_has_linear_attention`] and uses a snapshot/restore + accepted-prefix -re-forward pattern to advance both layer types cleanly. This adds one extra target forward per partial-accept -block but is required for correctness. - -## Fast path - -When the draft model exposes a `spec_generate(...)` method (e.g. `z-lab/Qwen3-8B-DFlash-b16`), the pipeline -delegates to it — that loop is the upstream-canonical implementation and avoids re-running the rollback bookkeeping. -Newer drafts (`z-lab/Qwen3.5-4B-DFlash`) drop `spec_generate`; the pipeline falls back to its explicit verify loop. +For target models with linear-attention layers (e.g. Qwen3.5's gated-delta-net), `DynamicCache.crop()` is a +documented no-op on those layers (see `transformers.cache_utils.LinearAttentionCacheLayerMixin.crop`), so a +partial-accept block would otherwise leak rejected speculative tokens into the recurrent state. The pipeline +detects linear-attention caches via [`DFlashTokenDiffusionScheduler.cache_has_linear_attention`] and uses a +snapshot/restore + accepted-prefix re-forward pattern to advance both layer types cleanly. This adds one extra +target forward per partial-accept block on hybrid targets; full-attention targets use a plain `cache.crop()`. ## Callbacks diff --git a/src/diffusers/pipelines/dflash/pipeline_dflash.py b/src/diffusers/pipelines/dflash/pipeline_dflash.py index 94b81dd6b863..214ee0c44052 100644 --- a/src/diffusers/pipelines/dflash/pipeline_dflash.py +++ b/src/diffusers/pipelines/dflash/pipeline_dflash.py @@ -360,27 +360,6 @@ def __call__( target_config = getattr(self.target_model, "config", None) draft_config = getattr(self.draft_model, "config", None) - # Fast path: some draft models (e.g. z-lab/Qwen3-8B-DFlash-b16) ship a self-contained - # `spec_generate` method. Delegate when available — it's the upstream-canonical loop and - # avoids re-implementing rollback. Newer drafts (Qwen3.5-4B-DFlash) drop this method, so - # fall back to the explicit pipeline loop below. 
- spec_generate = getattr(self.draft_model, "spec_generate", None) - if callable(spec_generate): - generated = spec_generate( - input_ids=input_ids, - max_new_tokens=int(max_new_tokens), - temperature=float(temperature), - target=self.target_model, - stop_token_ids=stop_token_ids, - ) - sequences = generated[:, input_ids.shape[1] :] - texts = None - if output_type == "text" and getattr(self, "tokenizer", None) is not None: - texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True) - if not return_dict: - return sequences, texts - return DFlashPipelineOutput(sequences=sequences, texts=texts) - # Pass `config=` only when it looks like a real PretrainedConfig — hybrid-attention models # (Qwen3.5) need it so `DynamicCache` instantiates the right per-layer cache types # (linear vs full), but bare dummy configs in tests don't implement `get_text_config`. From a70e329fee17ba25ea83a25596d5aa033f64b685 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 9 May 2026 10:33:38 +0000 Subject: [PATCH 5/8] [DFlash] add num_timesteps property for parity with LLaDA2 --- src/diffusers/pipelines/dflash/pipeline_dflash.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusers/pipelines/dflash/pipeline_dflash.py b/src/diffusers/pipelines/dflash/pipeline_dflash.py index 214ee0c44052..99b5e1883360 100644 --- a/src/diffusers/pipelines/dflash/pipeline_dflash.py +++ b/src/diffusers/pipelines/dflash/pipeline_dflash.py @@ -95,6 +95,10 @@ def __init__( draft_model=draft_model, target_model=target_model, tokenizer=tokenizer, scheduler=scheduler ) + @property + def num_timesteps(self): + return self._num_timesteps + # --- Prompt encoding --- def _prepare_input_ids( @@ -391,6 +395,7 @@ def _new_cache(cfg): start = num_input_tokens global_step = 0 num_blocks = (max_length - num_input_tokens + block_size - 1) // block_size + self._num_timesteps = int(num_blocks) # 5. Block-wise speculative decoding loop block_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy() From 471afa98fc7d588d7cd1e914e6e369aa012ea514 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 9 May 2026 10:36:35 +0000 Subject: [PATCH 6/8] [DFlash] document examples in discrete_diffusion README --- examples/discrete_diffusion/README.md | 41 +++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/examples/discrete_diffusion/README.md b/examples/discrete_diffusion/README.md index a3a8253b1927..0b8fe38316c5 100644 --- a/examples/discrete_diffusion/README.md +++ b/examples/discrete_diffusion/README.md @@ -48,3 +48,44 @@ python examples/discrete_diffusion/sample_llada2.py \ --use_chat_template \ --add_generation_prompt ``` + +## DFlash + +[DFlash](https://huggingface.co/collections/z-lab/dflash) is a block-diffusion **speculative decoding** scheme: a small diffusion *draft* model, conditioned on hidden features from a frozen *target* causal LM, proposes a block of tokens that the target verifies in a single forward pass. The pipeline accepts the longest matching prefix and resamples the next token at the rejection point. + +### Sample + +The published draft pairs with a stock target (no `trust_remote_code` for the target): + +```bash +python examples/discrete_diffusion/sample_dflash.py \ + --draft_model_id z-lab/Qwen3.5-4B-DFlash \ + --target_model_id Qwen/Qwen3.5-4B \ + --prompt "How many positive whole-number divisors does 196 have?" 
\ + --max_new_tokens 4096 +``` + +The draft ships a custom `DFlashDraftModel` class via `auto_map`, so the sample script loads it with `trust_remote_code=True`; the target loads as a stock Qwen3 / Qwen3.5 model. Per-draft thinking-mode defaults from the upstream model cards: + +| Draft | `enable_thinking` | +|---|---| +| `z-lab/Qwen3.5-*-DFlash` | `True` | +| `z-lab/Qwen3-*-DFlash-b16` | `False` (drafts are non-thinking-trained) | + +### Train + +The training loop conditions the draft on intermediate target hidden states and predicts the next `block_size − 1` tokens of each block: + +```bash +accelerate launch examples/discrete_diffusion/train_dflash.py \ + --draft_model_id z-lab/Qwen3-4B-DFlash-b16 \ + --target_model_id Qwen/Qwen3-4B \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --text_column text \ + --output_dir dflash-output \ + --max_train_steps 1000 \ + --learning_rate 2e-5 +``` + +`--block_size 0` (default) reads the block size from the draft model's config (16 for the b16 drafts, 16 for `Qwen3.5-*-DFlash`). From cab4413f54207688840e04938ed29297ba94b622 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 10 May 2026 11:31:15 +0000 Subject: [PATCH 7/8] [DFlash] align train recipe with paper --- examples/discrete_diffusion/train_dflash.py | 43 ++++++++++++++++++--- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/examples/discrete_diffusion/train_dflash.py b/examples/discrete_diffusion/train_dflash.py index 6538a8db5f55..83da7d7c8432 100644 --- a/examples/discrete_diffusion/train_dflash.py +++ b/examples/discrete_diffusion/train_dflash.py @@ -52,10 +52,13 @@ class TrainConfig: weight_decay: float lr_scheduler: str lr_warmup_steps: int + lr_warmup_ratio: float + max_grad_norm: float max_length: int block_size: int mask_token: str + loss_decay_gamma: float def parse_args() -> TrainConfig: @@ -75,18 +78,32 @@ def parse_args() -> TrainConfig: parser.add_argument("--per_device_train_batch_size", type=int, default=2) parser.add_argument("--gradient_accumulation_steps", type=int, default=1) - parser.add_argument("--learning_rate", type=float, default=2e-5) + parser.add_argument("--learning_rate", type=float, default=6e-4) parser.add_argument("--weight_decay", type=float, default=0.0) parser.add_argument( "--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"] ) - parser.add_argument("--lr_warmup_steps", type=int, default=100) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=0, + help="Absolute warmup steps. Ignored when --lr_warmup_ratio > 0 (default).", + ) + parser.add_argument("--lr_warmup_ratio", type=float, default=0.04) + parser.add_argument("--max_grad_norm", type=float, default=1.0) - parser.add_argument("--max_length", type=int, default=512) + parser.add_argument("--max_length", type=int, default=3072) parser.add_argument( "--block_size", type=int, default=0, help="Override draft block size (0 uses the model config)." ) parser.add_argument("--mask_token", type=str, default="<|MASK|>") + parser.add_argument( + "--loss_decay_gamma", + type=float, + default=0.0, + help="Per-position loss decay γ for w_k = exp(-(k-1)/γ). 0 selects the paper default for the " + "draft block size (γ=7 for block 16, γ=5 for block 10, γ=4 for block 8, else block_size/2).", + ) args = parser.parse_args() return TrainConfig(**vars(args)) @@ -177,6 +194,14 @@ def main(): if block_size < 2: raise ValueError("`block_size` must be at least 2 for DFlash training.") + # Eq. 
4 in the DFlash paper: w_k = exp(-(k-1)/γ) over predicted positions k=1..block_size-1. + # Defaults from Appendix A.1. + if cfg.loss_decay_gamma > 0.0: + loss_gamma = float(cfg.loss_decay_gamma) + else: + loss_gamma = {16: 7.0, 10: 5.0, 8: 4.0}.get(block_size, max(2.0, block_size / 2.0)) + pos_weights = torch.exp(-torch.arange(block_size - 1, dtype=torch.float32) / loss_gamma) + layer_ids = getattr(draft_model, "target_layer_ids", None) if layer_ids is None: cfg_draft = getattr(draft_model, "config", None) @@ -208,10 +233,14 @@ def main(): num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps) num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch) + if cfg.lr_warmup_ratio > 0.0: + num_warmup_steps = int(cfg.lr_warmup_ratio * cfg.max_train_steps) + else: + num_warmup_steps = cfg.lr_warmup_steps lr_scheduler = get_scheduler( name=cfg.lr_scheduler, optimizer=optimizer, - num_warmup_steps=cfg.lr_warmup_steps, + num_warmup_steps=num_warmup_steps, num_training_steps=cfg.max_train_steps, ) @@ -220,6 +249,7 @@ def main(): ) input_embeddings = get_target_input_embeddings(target_model) output_embeddings = get_target_output_embeddings(target_model) + pos_weights = pos_weights.to(accelerator.device) global_step = 0 draft_model.train() @@ -279,9 +309,12 @@ def main(): vocab_size = logits.shape[-1] loss = F.cross_entropy(logits.view(-1, vocab_size), block_targets.reshape(-1), reduction="none") loss = loss.view(block_targets.shape[0], -1) - loss = (loss * block_mask.to(loss.dtype)).sum() / block_mask.sum().clamp_min(1) + weights = pos_weights.to(loss.dtype)[None, :].expand_as(loss) * block_mask.to(loss.dtype) + loss = (loss * weights).sum() / weights.sum().clamp_min(1) accelerator.backward(loss) + if accelerator.sync_gradients and cfg.max_grad_norm > 0: + accelerator.clip_grad_norm_(draft_model.parameters(), cfg.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) From 9a5fb11929ce3a12d30f4ca15b8e4f5511aa065f Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 10 May 2026 11:35:46 +0000 Subject: [PATCH 8/8] [DFlash] cite paper in docs --- docs/source/en/api/pipelines/dflash.md | 16 +++++++++------- examples/discrete_diffusion/README.md | 2 +- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/docs/source/en/api/pipelines/dflash.md b/docs/source/en/api/pipelines/dflash.md index 215c07530c9e..ca5fbe29e1ea 100644 --- a/docs/source/en/api/pipelines/dflash.md +++ b/docs/source/en/api/pipelines/dflash.md @@ -12,16 +12,18 @@ specific language governing permissions and limitations under the License. # DFlash -[DFlash](https://huggingface.co/collections/z-lab/dflash) is a block-diffusion speculative decoding scheme. A small -diffusion *draft* model proposes a block of tokens conditioned on hidden features extracted from intermediate layers -of a frozen *target* causal LM; the target then verifies the proposed block in a single forward pass and accepts the -longest matching prefix. The draft model is shared with the target's tokenizer, so no calibration is needed. +[DFlash: Block Diffusion for Flash Speculative Decoding](https://huggingface.co/papers/2602.06036) is by Jian Chen, Yesheng Liang, and Zhijian Liu. + +The abstract from the paper is: + +*Autoregressive large language models (LLMs) deliver strong performance but require inherently sequential decoding, leading to high inference latency and poor GPU utilization. 
Speculative decoding mitigates this bottleneck by using a fast draft model whose outputs are verified in parallel by the target LLM. However, existing methods still rely on autoregressive drafting, which remains sequential and constrains practical speedups. Diffusion LLMs offer a promising alternative by enabling parallel generation, but current diffusion models typically underperform compared with autoregressive models. In this paper, we introduce DFlash, a speculative decoding framework that employs a lightweight block diffusion model for parallel drafting. We show that speculative decoding provides a natural and effective setting for diffusion models. By generating draft tokens in a single forward pass, DFlash enables efficient drafting, and by conditioning the draft model on context features extracted from the target model, it achieves high-quality drafts with higher acceptance rates. Experiments show that DFlash achieves over 6× lossless acceleration across a range of models and tasks, delivering up to 2.5× higher speedup than the state-of-the-art speculative decoding method EAGLE-3.* `DFlashPipeline` ties the two models together: prefill on the target, draft a block, verify against the target's posterior via [`DFlashTokenDiffusionScheduler`], commit the accepted prefix and the next-token resample, and repeat -until `max_new_tokens` or a stop token. Compatible draft/target pairs include `z-lab/Qwen3-8B-DFlash-b16` with -`Qwen/Qwen3-8B`, and `z-lab/Qwen3.5-4B-DFlash` with `Qwen/Qwen3.5-4B` (the latter is a hybrid-attention target — see -the rollback note below). +until `max_new_tokens` or a stop token. Pretrained draft/target pairs are available in the +[z-lab/dflash collection](https://huggingface.co/collections/z-lab/dflash); compatible pairs include +`z-lab/Qwen3-8B-DFlash-b16` with `Qwen/Qwen3-8B`, and `z-lab/Qwen3.5-4B-DFlash` with `Qwen/Qwen3.5-4B` (the latter is +a hybrid-attention target — see the rollback note below). ## Usage diff --git a/examples/discrete_diffusion/README.md b/examples/discrete_diffusion/README.md index 0b8fe38316c5..0050632de255 100644 --- a/examples/discrete_diffusion/README.md +++ b/examples/discrete_diffusion/README.md @@ -51,7 +51,7 @@ python examples/discrete_diffusion/sample_llada2.py \ ## DFlash -[DFlash](https://huggingface.co/collections/z-lab/dflash) is a block-diffusion **speculative decoding** scheme: a small diffusion *draft* model, conditioned on hidden features from a frozen *target* causal LM, proposes a block of tokens that the target verifies in a single forward pass. The pipeline accepts the longest matching prefix and resamples the next token at the rejection point. +[DFlash](https://huggingface.co/papers/2602.06036) ([model collection](https://huggingface.co/collections/z-lab/dflash)) is a block-diffusion **speculative decoding** scheme: a small diffusion *draft* model, conditioned on hidden features from a frozen *target* causal LM, proposes a block of tokens that the target verifies in a single forward pass. The pipeline accepts the longest matching prefix and resamples the next token at the rejection point. ### Sample