 import torch
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...schedulers import BlockRefinementScheduler
 from ...utils import BaseOutput
 from ..pipeline_utils import DiffusionPipeline, DiscreteDiffusionPipelineMixin
 
@@ -30,16 +31,6 @@ class BlockRefinementPipelineOutput(BaseOutput):
     texts: Optional[List[str]] = None
 
 
-def _get_num_transfer_tokens(block_length: int, steps: int) -> torch.LongTensor:
-    if steps <= 0:
-        return torch.zeros((0,), dtype=torch.long)
-    base = int(block_length) // int(steps)
-    remainder = int(block_length) % int(steps)
-    out = torch.full((int(steps),), base, dtype=torch.long)
-    out[:remainder] += 1
-    return out
-
-
 class BlockRefinementPipeline(DiffusionPipeline, DiscreteDiffusionPipelineMixin):
     """
     Block-wise iterative refinement pipeline for token generation.
@@ -52,17 +43,19 @@ class BlockRefinementPipeline(DiffusionPipeline, DiscreteDiffusionPipelineMixin)
5243 """
5344
5445 model : Any
46+ scheduler : BlockRefinementScheduler
5547 tokenizer : Any
5648
5749 _callback_tensor_inputs = ["cur_x" , "x0" , "x0_p" , "transfer_index" , "confidence" , "active_block" ]
5850
5951 def __init__ (
6052 self ,
6153 model : Any ,
54+ scheduler : BlockRefinementScheduler ,
6255 tokenizer : Optional [Any ] = None ,
6356 ):
6457 super ().__init__ ()
65- self .register_modules (model = model , tokenizer = tokenizer )
58+ self .register_modules (model = model , scheduler = scheduler , tokenizer = tokenizer )
6659
6760 @property
6861 def num_timesteps (self ):
@@ -310,6 +303,8 @@ def __call__(
 
         steps = min(int(steps), int(gen_length) // int(minimal_topk))
 
+        self.scheduler.set_timesteps(steps, device=model_device)
+
         num_blocks = (prompt_length + int(gen_length) + int(block_length) - 1) // int(block_length)
         total_length = int(num_blocks) * int(block_length)
 
@@ -333,7 +328,6 @@ def __call__(
 
         prefill_blocks = prompt_length // int(block_length)
         self._num_timesteps = int(steps) * max(int(num_blocks) - int(prefill_blocks), 0)
-        transfer_schedule = _get_num_transfer_tokens(int(block_length), int(steps)).to(device=model_device)
 
         finished = torch.zeros((batch_size,), device=model_device, dtype=torch.bool)
         resolved_attention_mode: str = str(attention_mask_mode)
@@ -362,8 +356,9 @@ def __call__(
             if finished.all():
                 break
 
-            active_block = cur_x[:, -int(block_length) :] == int(mask_token_id)
-            masks_remaining = active_block.sum() > 0
+            block_tokens = cur_x[:, -int(block_length) :]
+            active_block = block_tokens == int(mask_token_id)
+            masks_remaining = active_block.any()
 
             if not masks_remaining and not editing_enabled:
                 break
@@ -390,47 +385,26 @@ def __call__(
                 use_multinomial=use_multinomial,
             )
 
-            # --- Mask-filling transfer ---
-            transfer_index = torch.zeros_like(x0, dtype=torch.bool)
-            if masks_remaining and step_idx < int(steps):
-                clamped_step = min(step_idx, len(transfer_schedule) - 1)
-                num_to_transfer = int(transfer_schedule[clamped_step].item())
-
-                confidence = torch.where(
-                    active_block,
-                    x0_p.to(dtype=torch.float32),
-                    torch.full_like(x0_p, -torch.inf, dtype=torch.float32),
-                )
-
-                for b in range(batch_size):
-                    if finished[b]:
-                        continue
-                    high_conf = confidence[b] > float(threshold)
-                    if high_conf.sum().item() >= num_to_transfer:
-                        transfer_index[b] = high_conf
-                    else:
-                        k = min(num_to_transfer, int(active_block[b].sum().item()))
-                        if k > 0:
-                            _, idx = torch.topk(confidence[b], k=k)
-                            transfer_index[b, idx] = True
-
-            # --- Editing transfer (non-mask, non-prompt positions) ---
-            editing_transfer_index = torch.zeros_like(x0, dtype=torch.bool)
-            if editing_enabled:
-                old_block_tokens = cur_x[:, -int(block_length) :]
-                editable = (~active_block) & (~prompt_mask_in_block.unsqueeze(0))
-                editing_conf = torch.where(
-                    editable, x0_p.to(dtype=torch.float32), torch.full_like(x0_p, -torch.inf, dtype=torch.float32)
-                )
-                high_conf_edit = editing_conf > float(editing_threshold)
-                token_changed = x0 != old_block_tokens
-                editing_transfer_index = high_conf_edit & token_changed & editable
+            scheduler_output = self.scheduler.step(
+                sampled_tokens=x0,
+                sampled_probs=x0_p,
+                timestep=step_idx,
+                sample=block_tokens,
+                mask_token_id=int(mask_token_id),
+                threshold=float(threshold),
+                editing_threshold=editing_threshold,
+                minimal_topk=int(minimal_topk),
+                prompt_mask=prompt_mask_in_block,
+                generator=generator,
+                return_dict=True,
+            )
 
+            transfer_index = scheduler_output.transfer_index
+            editing_transfer_index = scheduler_output.editing_transfer_index
             final_transfer = transfer_index | editing_transfer_index
+
             if final_transfer.any():
-                updated = cur_x[:, -int(block_length) :].clone()
-                updated[final_transfer] = x0[final_transfer]
-                cur_x[:, -int(block_length) :] = updated
+                cur_x[:, -int(block_length) :] = scheduler_output.prev_sample
 
             # Break if no masks remain and no edits were made.
             if not masks_remaining and not editing_transfer_index.any():
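For context: the transfer logic deleted above does not disappear, it moves behind `BlockRefinementScheduler.step`. Below is a minimal sketch of a scheduler that would satisfy the call sites in this diff, reconstructed from the deleted pipeline lines. Only the `set_timesteps`/`step` call signatures and the output attributes `transfer_index`, `editing_transfer_index`, and `prev_sample` are fixed by the pipeline code; the class internals, the `BlockRefinementSchedulerOutput` name, and the use of a non-None `editing_threshold` as the editing switch are assumptions.

```python
# Hypothetical sketch only; the real scheduler in this PR may differ.
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch


@dataclass
class BlockRefinementSchedulerOutput:
    # Attribute names are pinned down by how the pipeline consumes the result.
    prev_sample: torch.Tensor
    transfer_index: torch.Tensor
    editing_transfer_index: torch.Tensor


class BlockRefinementScheduler:
    def set_timesteps(self, steps: int, device=None):
        self.num_inference_steps = int(steps)

    def _transfer_schedule(self, block_length: int, steps: int) -> torch.LongTensor:
        # Same arithmetic as the deleted _get_num_transfer_tokens helper:
        # spread block_length unmask events as evenly as possible over the steps.
        if steps <= 0:
            return torch.zeros((0,), dtype=torch.long)
        base, remainder = divmod(int(block_length), int(steps))
        out = torch.full((int(steps),), base, dtype=torch.long)
        out[:remainder] += 1
        return out

    def step(
        self,
        sampled_tokens: torch.Tensor,  # x0: newly proposed tokens for the block
        sampled_probs: torch.Tensor,   # x0_p: per-position confidence of those tokens
        timestep: int,
        sample: torch.Tensor,          # current tokens of the active block
        mask_token_id: int,
        threshold: float,
        editing_threshold: Optional[float] = None,
        minimal_topk: int = 1,                        # accepted for API parity; unused here
        prompt_mask: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,  # likewise unused in this sketch
        return_dict: bool = True,
    ) -> Union[BlockRefinementSchedulerOutput, Tuple[torch.Tensor, ...]]:
        active = sample == mask_token_id
        neg_inf = torch.full_like(sampled_probs, -torch.inf, dtype=torch.float32)

        # --- Mask-filling transfer (was inline in the pipeline) ---
        schedule = self._transfer_schedule(sample.shape[-1], self.num_inference_steps)
        num_to_transfer = int(schedule[min(int(timestep), len(schedule) - 1)].item())
        confidence = torch.where(active, sampled_probs.to(torch.float32), neg_inf)

        transfer_index = torch.zeros_like(active)
        for b in range(sample.shape[0]):
            # Unmask everything above `threshold`, or fall back to top-k by confidence.
            high_conf = confidence[b] > float(threshold)
            if int(high_conf.sum()) >= num_to_transfer:
                transfer_index[b] = high_conf
            else:
                k = min(num_to_transfer, int(active[b].sum()))
                if k > 0:
                    _, idx = torch.topk(confidence[b], k=k)
                    transfer_index[b, idx] = True

        # --- Editing transfer over already-unmasked, non-prompt positions ---
        editing_transfer_index = torch.zeros_like(active)
        if editing_threshold is not None:  # assumption: None means editing disabled
            editable = ~active
            if prompt_mask is not None:    # prompt_mask assumed boolean, as in the old code
                editable = editable & ~prompt_mask.unsqueeze(0)
            editing_conf = torch.where(editable, sampled_probs.to(torch.float32), neg_inf)
            editing_transfer_index = (
                (editing_conf > float(editing_threshold)) & (sampled_tokens != sample) & editable
            )

        # Merge both transfer sets into the updated block.
        final = transfer_index | editing_transfer_index
        prev_sample = torch.where(final, sampled_tokens, sample)

        if not return_dict:
            return (prev_sample, transfer_index, editing_transfer_index)
        return BlockRefinementSchedulerOutput(prev_sample, transfer_index, editing_transfer_index)
```

One behavioral detail worth checking in the real scheduler: the deleted pipeline code skipped batch entries with `finished[b]` set, but `step` does not receive `finished`, so either the pipeline must neutralize those rows before calling `step` or filter the returned indices afterwards.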