Skip to content

Commit b3f6cb5

Browse files
committed
feat: add BlockRefinementScheduler for commit-by-confidence scheduling
Extract the confidence-based token commit logic from BlockRefinementPipeline into a dedicated BlockRefinementScheduler, following diffusers conventions.

The scheduler owns:
- Transfer schedule computation (get_num_transfer_tokens)
- Timestep management (set_timesteps)
- Step logic: confidence-based mask-filling and optional token editing

The pipeline now delegates scheduling to self.scheduler.step() and accepts a scheduler parameter in __init__.
1 parent 929522f commit b3f6cb5

11 files changed

Lines changed: 308 additions & 73 deletions

File tree

docs/source/en/api/pipelines/block_refinement.md

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,16 @@ You can set default sampling parameters when creating the pipeline. Passing `Non
2121
falls back to `pipe.config`.
2222

2323
```py
24-
from diffusers import BlockRefinementPipeline
24+
from diffusers import BlockRefinementPipeline, BlockRefinementScheduler
2525

26+
scheduler = BlockRefinementScheduler()
2627
pipe = BlockRefinementPipeline(
2728
model=model,
29+
scheduler=scheduler,
2830
tokenizer=tokenizer,
29-
gen_length=256,
30-
block_length=32,
31-
steps=16,
32-
temperature=0.8,
33-
sampling_method="multinomial",
3431
)
3532

36-
out = pipe(prompt="Explain gradient descent.")
33+
out = pipe(prompt="Explain gradient descent.", gen_length=256, block_length=32, steps=16, temperature=0.8)
3734
print(out.texts[0])
3835
```
3936

@@ -61,3 +58,9 @@ out = pipe(
6158

6259
## BlockRefinementPipelineOutput
6360
[[autodoc]] pipelines.BlockRefinementPipelineOutput
61+
62+
## BlockRefinementScheduler
63+
[[autodoc]] BlockRefinementScheduler
64+
65+
## BlockRefinementSchedulerOutput
66+
[[autodoc]] schedulers.scheduling_block_refinement.BlockRefinementSchedulerOutput

examples/discrete_diffusion/sample_block_refinement.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
from transformers import AutoModelForCausalLM, AutoTokenizer
77

8-
from diffusers import BlockRefinementPipeline
8+
from diffusers import BlockRefinementPipeline, BlockRefinementScheduler
99

1010

1111
def main():
@@ -38,7 +38,8 @@ def main():
3838
if tokenizer.mask_token_id is None:
3939
raise ValueError("Tokenizer must have `mask_token_id` for block refinement sampling.")
4040

41-
pipe = BlockRefinementPipeline(model=model, tokenizer=tokenizer).to(args.device)
41+
scheduler = BlockRefinementScheduler()
42+
pipe = BlockRefinementPipeline(model=model, scheduler=scheduler, tokenizer=tokenizer).to(args.device)
4243
gen = torch.Generator(device=args.device).manual_seed(args.seed)
4344

4445
prompt_ids = tokenizer(args.prompt, return_tensors="pt")["input_ids"].to(args.device)

examples/discrete_diffusion/sample_llada2.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import torch
3030
from transformers import AutoModelForCausalLM, AutoTokenizer
3131

32-
from diffusers import LLaDA2Pipeline
32+
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
3333
from diffusers.hooks import apply_group_offloading
3434

3535

@@ -207,7 +207,8 @@ def main():
207207
model.eval()
208208

209209
# Create pipeline
210-
pipe = LLaDA2Pipeline(model=model, tokenizer=tokenizer)
210+
scheduler = BlockRefinementScheduler()
211+
pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
211212

212213
# Apply sequential CPU offload if requested
213214
if args.offload == "sequential":

examples/discrete_diffusion/train_block_refinement_cap.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from accelerate.utils import ProjectConfiguration, set_seed
2727
from torch.utils.data import DataLoader, Dataset
2828

29-
from diffusers import BlockRefinementPipeline
29+
from diffusers import BlockRefinementPipeline, BlockRefinementScheduler
3030
from diffusers.training_utils import compute_confidence_aware_loss
3131

3232

@@ -249,7 +249,8 @@ def main():
249249
dataloader = DataLoader(dataset, batch_size=cfg.per_device_train_batch_size, shuffle=True, drop_last=True)
250250

251251
model = TinyBlockRefinementLM(vocab_size=cfg.vocab_size)
252-
pipe = BlockRefinementPipeline(model=model, tokenizer=None)
252+
scheduler = BlockRefinementScheduler()
253+
pipe = BlockRefinementPipeline(model=model, scheduler=scheduler, tokenizer=None)
253254

254255
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
255256

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,8 @@
347347
_import_structure["schedulers"].extend(
348348
[
349349
"AmusedScheduler",
350+
"BlockRefinementScheduler",
351+
"BlockRefinementSchedulerOutput",
350352
"CMStochasticIterativeScheduler",
351353
"CogVideoXDDIMScheduler",
352354
"CogVideoXDPMScheduler",
@@ -1121,6 +1123,8 @@
11211123
from .quantizers import DiffusersQuantizer
11221124
from .schedulers import (
11231125
AmusedScheduler,
1126+
BlockRefinementScheduler,
1127+
BlockRefinementSchedulerOutput,
11241128
CMStochasticIterativeScheduler,
11251129
CogVideoXDDIMScheduler,
11261130
CogVideoXDPMScheduler,

src/diffusers/pipelines/block_refinement/pipeline_block_refinement.py

Lines changed: 26 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import torch
2121

2222
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
23+
from ...schedulers import BlockRefinementScheduler
2324
from ...utils import BaseOutput
2425
from ..pipeline_utils import DiffusionPipeline, DiscreteDiffusionPipelineMixin
2526

@@ -30,16 +31,6 @@ class BlockRefinementPipelineOutput(BaseOutput):
3031
texts: Optional[List[str]] = None
3132

3233

33-
def _get_num_transfer_tokens(block_length: int, steps: int) -> torch.LongTensor:
34-
if steps <= 0:
35-
return torch.zeros((0,), dtype=torch.long)
36-
base = int(block_length) // int(steps)
37-
remainder = int(block_length) % int(steps)
38-
out = torch.full((int(steps),), base, dtype=torch.long)
39-
out[:remainder] += 1
40-
return out
41-
42-
4334
class BlockRefinementPipeline(DiffusionPipeline, DiscreteDiffusionPipelineMixin):
4435
"""
4536
Block-wise iterative refinement pipeline for token generation.
@@ -52,17 +43,19 @@ class BlockRefinementPipeline(DiffusionPipeline, DiscreteDiffusionPipelineMixin)
5243
"""
5344

5445
model: Any
46+
scheduler: BlockRefinementScheduler
5547
tokenizer: Any
5648

5749
_callback_tensor_inputs = ["cur_x", "x0", "x0_p", "transfer_index", "confidence", "active_block"]
5850

5951
def __init__(
6052
self,
6153
model: Any,
54+
scheduler: BlockRefinementScheduler,
6255
tokenizer: Optional[Any] = None,
6356
):
6457
super().__init__()
65-
self.register_modules(model=model, tokenizer=tokenizer)
58+
self.register_modules(model=model, scheduler=scheduler, tokenizer=tokenizer)
6659

6760
@property
6861
def num_timesteps(self):
@@ -310,6 +303,8 @@ def __call__(
310303

311304
steps = min(int(steps), int(gen_length) // int(minimal_topk))
312305

306+
self.scheduler.set_timesteps(steps, device=model_device)
307+
313308
num_blocks = (prompt_length + int(gen_length) + int(block_length) - 1) // int(block_length)
314309
total_length = int(num_blocks) * int(block_length)
315310

@@ -333,7 +328,6 @@ def __call__(
333328

334329
prefill_blocks = prompt_length // int(block_length)
335330
self._num_timesteps = int(steps) * max(int(num_blocks) - int(prefill_blocks), 0)
336-
transfer_schedule = _get_num_transfer_tokens(int(block_length), int(steps)).to(device=model_device)
337331

338332
finished = torch.zeros((batch_size,), device=model_device, dtype=torch.bool)
339333
resolved_attention_mode: str = str(attention_mask_mode)
@@ -362,8 +356,9 @@ def __call__(
362356
if finished.all():
363357
break
364358

365-
active_block = cur_x[:, -int(block_length) :] == int(mask_token_id)
366-
masks_remaining = active_block.sum() > 0
359+
block_tokens = cur_x[:, -int(block_length) :]
360+
active_block = block_tokens == int(mask_token_id)
361+
masks_remaining = active_block.any()
367362

368363
if not masks_remaining and not editing_enabled:
369364
break
@@ -390,47 +385,26 @@ def __call__(
390385
use_multinomial=use_multinomial,
391386
)
392387

393-
# --- Mask-filling transfer ---
394-
transfer_index = torch.zeros_like(x0, dtype=torch.bool)
395-
if masks_remaining and step_idx < int(steps):
396-
clamped_step = min(step_idx, len(transfer_schedule) - 1)
397-
num_to_transfer = int(transfer_schedule[clamped_step].item())
398-
399-
confidence = torch.where(
400-
active_block,
401-
x0_p.to(dtype=torch.float32),
402-
torch.full_like(x0_p, -torch.inf, dtype=torch.float32),
403-
)
404-
405-
for b in range(batch_size):
406-
if finished[b]:
407-
continue
408-
high_conf = confidence[b] > float(threshold)
409-
if high_conf.sum().item() >= num_to_transfer:
410-
transfer_index[b] = high_conf
411-
else:
412-
k = min(num_to_transfer, int(active_block[b].sum().item()))
413-
if k > 0:
414-
_, idx = torch.topk(confidence[b], k=k)
415-
transfer_index[b, idx] = True
416-
417-
# --- Editing transfer (non-mask, non-prompt positions) ---
418-
editing_transfer_index = torch.zeros_like(x0, dtype=torch.bool)
419-
if editing_enabled:
420-
old_block_tokens = cur_x[:, -int(block_length) :]
421-
editable = (~active_block) & (~prompt_mask_in_block.unsqueeze(0))
422-
editing_conf = torch.where(
423-
editable, x0_p.to(dtype=torch.float32), torch.full_like(x0_p, -torch.inf, dtype=torch.float32)
424-
)
425-
high_conf_edit = editing_conf > float(editing_threshold)
426-
token_changed = x0 != old_block_tokens
427-
editing_transfer_index = high_conf_edit & token_changed & editable
388+
scheduler_output = self.scheduler.step(
389+
sampled_tokens=x0,
390+
sampled_probs=x0_p,
391+
timestep=step_idx,
392+
sample=block_tokens,
393+
mask_token_id=int(mask_token_id),
394+
threshold=float(threshold),
395+
editing_threshold=editing_threshold,
396+
minimal_topk=int(minimal_topk),
397+
prompt_mask=prompt_mask_in_block,
398+
generator=generator,
399+
return_dict=True,
400+
)
428401

402+
transfer_index = scheduler_output.transfer_index
403+
editing_transfer_index = scheduler_output.editing_transfer_index
429404
final_transfer = transfer_index | editing_transfer_index
405+
430406
if final_transfer.any():
431-
updated = cur_x[:, -int(block_length) :].clone()
432-
updated[final_transfer] = x0[final_transfer]
433-
cur_x[:, -int(block_length) :] = updated
407+
cur_x[:, -int(block_length) :] = scheduler_output.prev_sample
434408

435409
# Break if no masks remain and no edits were made.
436410
if not masks_remaining and not editing_transfer_index.any():

src/diffusers/pipelines/llada2/pipeline_llada2.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,15 @@
3333
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
3434
>>> from diffusers import LLaDA2Pipeline
3535
36+
>>> from diffusers import BlockRefinementScheduler
37+
3638
>>> model_id = "inclusionAI/LLaDA2.0-mini"
3739
>>> model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16)
3840
>>> tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
3941
>>> model = model.to("cuda")
42+
>>> scheduler = BlockRefinementScheduler()
4043
41-
>>> pipe = LLaDA2Pipeline(model=model, tokenizer=tokenizer)
44+
>>> pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
4245
>>> output = pipe(prompt="What is the meaning of life?", gen_length=256)
4346
>>> print(output.texts[0])
4447
```

src/diffusers/schedulers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
else:
4141
_import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"]
4242
_import_structure["scheduling_amused"] = ["AmusedScheduler"]
43+
_import_structure["scheduling_block_refinement"] = ["BlockRefinementScheduler", "BlockRefinementSchedulerOutput"]
4344
_import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"]
4445
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
4546
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
@@ -145,6 +146,7 @@
145146
else:
146147
from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler
147148
from .scheduling_amused import AmusedScheduler
149+
from .scheduling_block_refinement import BlockRefinementScheduler, BlockRefinementSchedulerOutput
148150
from .scheduling_consistency_decoder import ConsistencyDecoderScheduler
149151
from .scheduling_consistency_models import CMStochasticIterativeScheduler
150152
from .scheduling_ddim import DDIMScheduler

0 commit comments

Comments (0)