diff --git a/src/diffusers/pipelines/llada2/pipeline_llada2.py b/src/diffusers/pipelines/llada2/pipeline_llada2.py
index a6ba6e8ff689..0093a9d80414 100644
--- a/src/diffusers/pipelines/llada2/pipeline_llada2.py
+++ b/src/diffusers/pipelines/llada2/pipeline_llada2.py
@@ -71,7 +71,14 @@ class LLaDA2Pipeline(DiffusionPipeline):
     scheduler: BlockRefinementScheduler
     tokenizer: Any

-    _callback_tensor_inputs = ["block_x", "x0", "x0_p", "transfer_index", "confidence", "active_block"]
+    _callback_tensor_inputs = [
+        "block_x",
+        "transfer_index",
+        "editing_transfer_index",
+        "sampled_tokens",
+        "sampled_probs",
+        "active_block",
+    ]

     def __init__(
         self,
@@ -99,8 +106,9 @@ def _prepare_input_ids(
         use_chat_template: bool,
         add_generation_prompt: bool,
         chat_template_kwargs: dict[str, Any] | None,
-    ) -> torch.LongTensor:
-        """Convert prompt/messages/input_ids to a [batch, seq] LongTensor."""
+        attention_mask: torch.LongTensor | None = None,
+    ) -> tuple[torch.LongTensor, torch.LongTensor]:
+        """Convert prompt/messages/input_ids to `(input_ids, attention_mask)` tensors of shape `[batch, seq]`."""
         if input_ids is not None:
             if input_ids.ndim == 1:
                 input_ids = input_ids.unsqueeze(0)
@@ -108,7 +116,18 @@ def _prepare_input_ids(
                 raise ValueError(f"`input_ids` must be 2D, got shape {tuple(input_ids.shape)}.")
             if input_ids.dtype != torch.long:
                 raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.")
-            return input_ids
+            if attention_mask is None:
+                attention_mask = torch.ones_like(input_ids, dtype=torch.long)
+            else:
+                if attention_mask.ndim == 1:
+                    attention_mask = attention_mask.unsqueeze(0)
+                if attention_mask.shape != input_ids.shape:
+                    raise ValueError(
+                        f"`attention_mask` shape {tuple(attention_mask.shape)} must match `input_ids` shape "
+                        f"{tuple(input_ids.shape)}."
+                    )
+                attention_mask = attention_mask.to(dtype=torch.long)
+            return input_ids, attention_mask

         if self.tokenizer is None:
             raise ValueError("Tokenizer is required when `input_ids` is not provided.")
@@ -129,7 +148,11 @@ def _prepare_input_ids(
                 return_dict=True,
                 **chat_template_kwargs,
             )
-            return encoded["input_ids"]
+            ids = encoded["input_ids"]
+            mask = encoded.get("attention_mask")
+            if mask is None:
+                mask = torch.ones_like(ids, dtype=torch.long)
+            return ids, mask.to(dtype=torch.long)

         if use_chat_template and getattr(self.tokenizer, "chat_template", None):
             if isinstance(prompt, list):
@@ -142,10 +165,18 @@ def _prepare_input_ids(
                 return_dict=True,
                 **chat_template_kwargs,
             )
-            return encoded["input_ids"]
+            ids = encoded["input_ids"]
+            mask = encoded.get("attention_mask")
+            if mask is None:
+                mask = torch.ones_like(ids, dtype=torch.long)
+            return ids, mask.to(dtype=torch.long)

         encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list))
-        return encoded["input_ids"]
+        ids = encoded["input_ids"]
+        mask = encoded.get("attention_mask")
+        if mask is None:
+            mask = torch.ones_like(ids, dtype=torch.long)
+        return ids, mask.to(dtype=torch.long)

     def check_inputs(
         self,
@@ -215,6 +246,7 @@ def __call__(
         prompt: str | list[str] | None = None,
         messages: list[dict[str, str]] | None = None,
         input_ids: torch.LongTensor | None = None,
+        attention_mask: torch.LongTensor | None = None,
         use_chat_template: bool = True,
         add_generation_prompt: bool = True,
         gen_length: int = 2048,
@@ -252,6 +284,11 @@ def __call__(
                 when provided. Requires a tokenizer with `apply_chat_template`.
             input_ids (`torch.LongTensor`, *optional*):
                 Pre-tokenized input IDs. Takes precedence over `prompt` and `messages`.
+            attention_mask (`torch.LongTensor`, *optional*):
+                Per-token mask (1 for valid prompt tokens, 0 for padding) matching the shape of `input_ids`. Only used
+                when `input_ids` is provided. When omitted (and `input_ids` is given), all positions are treated as
+                valid. When constructing inputs from `prompt` / `messages`, the tokenizer's mask is carried through
+                automatically.
             use_chat_template (`bool`, defaults to `True`):
                 Whether to wrap the prompt in a chat template.
             add_generation_prompt (`bool`, defaults to `True`):
@@ -299,8 +336,8 @@ def __call__(
                 Callback executed after each refinement step with signature
                 `callback_on_step_end(self, step: int, timestep: int, callback_kwargs: Dict)`.
             callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
-                Tensor keys to pass to the callback. Allowed keys: `block_x`, `x0`, `x0_p`, `transfer_index`,
-                `confidence`, `active_block`.
+                Tensor keys to pass to the callback. Allowed keys: `block_x`, `transfer_index`,
+                `editing_transfer_index`, `sampled_tokens`, `sampled_probs`, `active_block`.

         Examples:
         """
@@ -328,10 +365,11 @@ def __call__(
         )

         # 2. Prepare input IDs from prompt/messages/input_ids
-        prompt_ids = self._prepare_input_ids(
+        prompt_ids, prompt_attention_mask = self._prepare_input_ids(
             prompt=prompt,
             messages=messages,
             input_ids=input_ids,
+            attention_mask=attention_mask,
             use_chat_template=use_chat_template,
             add_generation_prompt=add_generation_prompt,
             chat_template_kwargs=None,
@@ -342,6 +380,7 @@ def __call__(
         if prompt_ids.ndim == 1:
             prompt_ids = prompt_ids.unsqueeze(0)
         prompt_ids = prompt_ids.to(device=device)
+        prompt_attention_mask = prompt_attention_mask.to(device=device)
         batch_size, prompt_length = prompt_ids.shape

         if eos_token_id is None:
@@ -353,14 +392,18 @@ def __call__(
         num_inference_steps = min(num_inference_steps, gen_length // minimal_topk)

-        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        self.scheduler.set_timesteps(num_inference_steps, device=device, block_length=block_length)

         # 3. Build attention mask and position IDs
         num_blocks = (prompt_length + gen_length + block_length - 1) // block_length
         total_length = num_blocks * block_length

-        # 2D attention mask (no padding) — the model handles backend-specific conversion internally.
-        attn_mask = torch.ones((batch_size, total_length), device=device, dtype=torch.long)
+        # 2D attention mask: prompt tokenizer mask + ones over generated positions + zeros over the
+        # block-aligned tail past `prompt_length + gen_length`. The model handles backend-specific
+        # conversion internally; this just tells it which positions are real context.
+        attn_mask = torch.zeros((batch_size, total_length), device=device, dtype=torch.long)
+        attn_mask[:, :prompt_length] = prompt_attention_mask
+        attn_mask[:, prompt_length : prompt_length + gen_length] = 1

         position_ids = torch.arange(total_length, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)
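Note: to make the rebuilt mask layout concrete, here is a small sketch that mirrors the three assignments above with toy sizes. The numbers are illustrative only and are not taken from this PR; only the indexing pattern matches the diff.

    # Toy illustration of the new 2D attention mask (illustrative sizes, not from the PR).
    import torch

    prompt_attention_mask = torch.tensor([[1, 1, 1, 0]])  # one padded prompt position
    prompt_length, gen_length, block_length = 4, 5, 4

    num_blocks = (prompt_length + gen_length + block_length - 1) // block_length
    total_length = num_blocks * block_length  # 12: rounded up to a multiple of block_length

    attn_mask = torch.zeros((1, total_length), dtype=torch.long)
    attn_mask[:, :prompt_length] = prompt_attention_mask
    attn_mask[:, prompt_length : prompt_length + gen_length] = 1

    print(attn_mask)
    # tensor([[1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0]])
    #   prompt (padding masked) | generated span | block-aligned tail stays 0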
@@ -377,9 +420,8 @@ def __call__(
         global_step = 0

         # 5. Block-wise refinement loop
-        block_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy()
-        block_progress_bar_config["position"] = 0
-        block_progress_bar_config["desc"] = "Blocks"
+        outer_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy()
+        block_progress_bar_config = {**outer_progress_bar_config, "position": 0, "desc": "Blocks"}
         for num_block in tqdm(range(prefill_blocks, num_blocks), **block_progress_bar_config):
             current_window_end = (num_block + 1) * block_length
             block_x = x[:, :current_window_end]
@@ -396,8 +438,13 @@ def __call__(
             post_steps = 0
             step_idx = 0
             should_continue = True
-            self.set_progress_bar_config(position=1, leave=False, desc=f"Block {num_block} Inference Steps")
-            progress_bar = self.progress_bar(total=num_inference_steps)
+            inner_progress_bar_config = {
+                **outer_progress_bar_config,
+                "position": 1,
+                "leave": False,
+                "desc": f"Block {num_block} Inference Steps",
+            }
+            progress_bar = tqdm(total=num_inference_steps, **inner_progress_bar_config)

             while should_continue:
                 block_tokens = block_x[:, -block_length:]
@@ -428,10 +475,19 @@ def __call__(
                 transfer_index = scheduler_output.transfer_index
                 editing_transfer_index = scheduler_output.editing_transfer_index
+                sampled_tokens = scheduler_output.sampled_tokens
+                sampled_probs = scheduler_output.sampled_probs
+                active_block = block_tokens == mask_token_id

                 final_transfer = transfer_index | editing_transfer_index
+                # Freeze rows that already emitted EOS so further blocks don't extend them.
+                if eos_early_stop and finished.any():
+                    final_transfer = final_transfer & ~finished[:, None]

                 if final_transfer.any():
-                    block_x[:, -block_length:] = scheduler_output.prev_sample
+                    block_x[:, -block_length:] = torch.where(
+                        final_transfer, scheduler_output.prev_sample, block_tokens
+                    )

                 if eos_early_stop and eos_token_id is not None:
                     finished = self.scheduler.check_eos_finished(
@@ -474,14 +530,21 @@ def __call__(
         # 6. Post-process output
         generated = x[:, : prompt_length + gen_length]
         sequences = generated[:, prompt_length:]
-        if eos_token_id is not None and batch_size == 1:
-            eos_positions = (sequences[0] == eos_token_id).nonzero(as_tuple=True)[0]
-            if len(eos_positions) > 0:
-                sequences = sequences[:, : int(eos_positions[0].item()) + 1]
+
+        # For decode, trim each row at the first EOS so post-EOS positions (which may still hold
+        # mask tokens or refined content for unfinished blocks) don't leak into the decoded text.
+        decode_sequences: list[torch.LongTensor] | torch.LongTensor = sequences
+        if eos_token_id is not None:
+            decode_sequences = [
+                seq[: int((seq == eos_token_id).nonzero(as_tuple=True)[0][0]) + 1]
+                if (seq == eos_token_id).any()
+                else seq
+                for seq in sequences
+            ]

         texts = None
         if output_type == "text" and self.tokenizer is not None:
-            texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
+            texts = self.tokenizer.batch_decode(decode_sequences, skip_special_tokens=True)

         if not return_dict:
             return sequences.to(device=device), texts
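Note: a minimal end-to-end usage sketch of the new `attention_mask` argument and the updated callback keys. The checkpoint id and the top-level `from diffusers import LLaDA2Pipeline` export are assumptions for illustration and are not part of this diff; the call arguments follow the pipeline signature above.

    # Hypothetical usage sketch (checkpoint id and import path are assumptions, not from this PR).
    import torch
    from diffusers import LLaDA2Pipeline  # assumed public export of the class above

    pipe = LLaDA2Pipeline.from_pretrained("org/llada2-checkpoint")  # placeholder checkpoint id

    # Pre-tokenized, right-padded batch; the mask marks real prompt tokens (1) vs padding (0).
    input_ids = torch.tensor([[10, 11, 12, 0], [20, 21, 0, 0]], dtype=torch.long)
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.long)

    def on_step_end(pipeline, step, timestep, callback_kwargs):
        # `transfer_index` marks the positions committed this step; `sampled_probs` holds
        # the probabilities of the sampled tokens at those positions.
        committed = int(callback_kwargs["transfer_index"].sum())
        print(f"step {step}: committed {committed} tokens")
        return {}

    out = pipe(
        input_ids=input_ids,
        attention_mask=attention_mask,
        gen_length=128,
        block_length=32,
        num_inference_steps=32,
        output_type="text",
        callback_on_step_end=on_step_end,
        callback_on_step_end_tensor_inputs=["transfer_index", "sampled_probs"],
    )
    print(out.texts)  # decoded text, trimmed at the first EOS per row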
diff --git a/src/diffusers/schedulers/scheduling_block_refinement.py b/src/diffusers/schedulers/scheduling_block_refinement.py
index 296ad1b6a5fe..3b4d737767ce 100644
--- a/src/diffusers/schedulers/scheduling_block_refinement.py
+++ b/src/diffusers/schedulers/scheduling_block_refinement.py
@@ -75,12 +75,21 @@ def __init__(
         self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, dtype=torch.long)
         self._transfer_schedule: torch.LongTensor | None = None

-    def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
+    def set_timesteps(
+        self,
+        num_inference_steps: int,
+        device: str | torch.device | None = None,
+        block_length: int | None = None,
+    ) -> None:
         if num_inference_steps <= 0:
             raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
+        if block_length is None:
+            block_length = self.config.block_length
+        elif block_length <= 0:
+            raise ValueError(f"`block_length` must be > 0, got {block_length}.")
         self.num_inference_steps = num_inference_steps
         self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, device=device, dtype=torch.long)
-        self._transfer_schedule = self.get_num_transfer_tokens(self.config.block_length, self.num_inference_steps).to(
+        self._transfer_schedule = self.get_num_transfer_tokens(block_length, self.num_inference_steps).to(
             device=device if device is not None else "cpu"
         )

@@ -343,7 +352,8 @@ def check_eos_finished(
             if len(eos_pos[0]) == 0:
                 continue
             eos_pos = int(eos_pos[0][0].item())
-            if prompt_length >= eos_pos:
+            # The first generated token sits at index `prompt_length`; allow EOS there.
+            if eos_pos < prompt_length:
                 continue
             if (cur_x[b, prompt_length:eos_pos] != mask_token_id).all().item():
                 finished[b] = True
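Note: the `set_timesteps` change is easiest to see in isolation. The sketch below assumes the scheduler is importable from the module path shown in this diff and inspects the internal `_transfer_schedule` attribute purely for illustration; the expected outcome follows the Issue #2 regression test below.

    # Sketch of the block_length routing fixed above (import path inferred from the file path;
    # `_transfer_schedule` is internal state, inspected here only to show the effect).
    from diffusers.schedulers.scheduling_block_refinement import BlockRefinementScheduler

    scheduler = BlockRefinementScheduler()

    # Previously the schedule was always built from config.block_length, so a per-call
    # block_length could disagree with the number of positions actually refined per block.
    scheduler.set_timesteps(num_inference_steps=8, device="cpu", block_length=8)

    print(scheduler._transfer_schedule)
    # expected: eight entries summing to block_length, i.e. one committed token per step here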
diff --git a/tests/pipelines/llada2/test_llada2.py b/tests/pipelines/llada2/test_llada2.py
index c3511918fe67..6b00e133c7b1 100644
--- a/tests/pipelines/llada2/test_llada2.py
+++ b/tests/pipelines/llada2/test_llada2.py
@@ -178,7 +178,7 @@ def test_output_type_invalid_raises(self):
     def test_prepare_input_ids_from_tensor(self):
         pipe = _make_pipeline()
         ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
-        result = pipe._prepare_input_ids(
+        result_ids, result_mask = pipe._prepare_input_ids(
             prompt=None,
             messages=None,
             input_ids=ids,
@@ -186,12 +186,14 @@ def test_prepare_input_ids_from_tensor(self):
             add_generation_prompt=False,
             chat_template_kwargs=None,
         )
-        self.assertTrue(torch.equal(result, ids))
+        self.assertTrue(torch.equal(result_ids, ids))
+        self.assertEqual(result_mask.shape, ids.shape)
+        self.assertTrue((result_mask == 1).all().item())

     def test_prepare_input_ids_from_1d_tensor(self):
         pipe = _make_pipeline()
         ids = torch.tensor([1, 2, 3], dtype=torch.long)
-        result = pipe._prepare_input_ids(
+        result_ids, result_mask = pipe._prepare_input_ids(
             prompt=None,
             messages=None,
             input_ids=ids,
@@ -199,7 +201,8 @@ def test_prepare_input_ids_from_1d_tensor(self):
             add_generation_prompt=False,
             chat_template_kwargs=None,
         )
-        self.assertEqual(result.shape, (1, 3))
+        self.assertEqual(result_ids.shape, (1, 3))
+        self.assertEqual(result_mask.shape, (1, 3))

     def test_prepare_input_ids_no_tokenizer_raises(self):
         pipe = _make_pipeline(tokenizer=None)
@@ -241,5 +244,176 @@ def test_prepare_input_ids_neither_raises(self):
             )


+class LLaDA2RegressionTest(unittest.TestCase):
+    """Pin the regressions identified in https://github.com/huggingface/diffusers/issues/13598."""
+
+    def test_attention_mask_carried_through_for_pre_tokenized_input(self):
+        """Issue #1: explicit `attention_mask` must reach the model and zero out padded prompt
+        positions and the block-aligned tail past `prompt_length + gen_length`."""
+        captured: list[torch.Tensor] = []
+
+        class _MaskCapturingModel(_DummyCausalLM):
+            def forward(self, input_ids, attention_mask=None, position_ids=None, **kwargs):
+                captured.append(attention_mask.detach().cpu().clone() if attention_mask is not None else None)
+                return super().forward(input_ids, attention_mask=attention_mask, position_ids=position_ids)
+
+        model = _MaskCapturingModel(vocab_size=32)
+        scheduler = BlockRefinementScheduler()
+        pipe = LLaDA2Pipeline(model=model, scheduler=scheduler).to("cpu")
+
+        input_ids = torch.tensor([[10, 11, 12, 0], [20, 0, 0, 0]], dtype=torch.long)
+        attention_mask = torch.tensor([[1, 1, 1, 0], [1, 0, 0, 0]], dtype=torch.long)
+
+        pipe(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            use_chat_template=False,
+            gen_length=4,
+            block_length=4,
+            num_inference_steps=2,
+            threshold=2.0,
+            mask_token_id=31,
+            eos_token_id=None,
+            eos_early_stop=False,
+            output_type="seq",
+        )
+
+        self.assertGreater(len(captured), 0)
+        first_mask = captured[0]
+        # Padded prompt positions stay zero in the runtime mask (Issue #1).
+        self.assertEqual(first_mask[0, 3].item(), 0)
+        self.assertEqual(first_mask[1, 1].item(), 0)
+        self.assertEqual(first_mask[1, 2].item(), 0)
+        self.assertEqual(first_mask[1, 3].item(), 0)
+        # Real prompt positions stay one.
+        self.assertEqual(first_mask[0, 0].item(), 1)
+        self.assertEqual(first_mask[1, 0].item(), 1)
+
+    def test_block_length_routes_into_scheduler_transfer_schedule(self):
+        """Issue #2: the per-call `block_length` must drive the scheduler's `_transfer_schedule`."""
+        commits: list[int] = []
+
+        def cb(pipe, step, timestep, kwargs):
+            commits.append(int(kwargs["transfer_index"].sum()))
+            return {}
+
+        pipe = _make_pipeline().to("cpu")
+        pipe(
+            input_ids=torch.empty((1, 0), dtype=torch.long),
+            use_chat_template=False,
+            gen_length=8,
+            block_length=8,
+            num_inference_steps=8,
+            threshold=2.0,
+            mask_token_id=31,
+            eos_token_id=None,
+            eos_early_stop=False,
+            output_type="seq",
+            callback_on_step_end=cb,
+            callback_on_step_end_tensor_inputs=["transfer_index"],
+        )
+        # With block_length=num_inference_steps=8 the schedule commits exactly one token per step.
+        self.assertEqual(commits[0], 1)
+        self.assertEqual(commits[1], 1)
+        self.assertEqual(commits[2], 1)
+
+    def test_callback_tensor_inputs_advertised_keys_resolve(self):
+        """Issue #3: every advertised callback key must be a bound local at callback time."""
+        observed: list[str] = []
+
+        def cb(pipe, step, timestep, kwargs):
+            observed.extend(sorted(kwargs.keys()))
+            return {}
+
+        pipe = _make_pipeline().to("cpu")
+        keys = list(pipe._callback_tensor_inputs)
+        pipe(
+            input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
+            use_chat_template=False,
+            gen_length=8,
+            block_length=8,
+            num_inference_steps=4,
+            threshold=2.0,
+            mask_token_id=31,
+            eos_token_id=None,
+            eos_early_stop=False,
+            output_type="seq",
+            callback_on_step_end=cb,
+            callback_on_step_end_tensor_inputs=keys,
+        )
+        self.assertEqual(set(observed), set(keys))
+
+    def test_eos_at_first_generated_position_triggers_finished(self):
+        """Issue #4: EOS exactly at index `prompt_length` must mark the row finished."""
+        cur_x = torch.tensor([[10, 2, 99]])
+        sampled_tokens = torch.tensor([[0, 2]])
+        final_transfer = torch.tensor([[False, True]])
+        finished = BlockRefinementScheduler.check_eos_finished(
+            cur_x=cur_x,
+            sampled_tokens=sampled_tokens,
+            final_transfer=final_transfer,
+            finished=torch.tensor([False]),
+            eos_token_id=2,
+            mask_token_id=99,
+            prompt_length=1,
+        )
+        self.assertTrue(bool(finished[0].item()))
+
+    def test_finished_rows_are_frozen_for_subsequent_blocks(self):
+        """Issue #5: once a row emits EOS, later blocks must not overwrite its committed tokens."""
+
+        class _EosThenJunkModel(_DummyCausalLM):
+            """Row 0 commits EOS in the first block, then later blocks would emit token 7. Row 1 keeps emitting token 6."""
+
+            def forward(self, input_ids, attention_mask=None, position_ids=None, **kwargs):
+                batch_size, seq_len = input_ids.shape
+                logits = torch.zeros((batch_size, seq_len, self.vocab_size), device=input_ids.device)
+                # First block (seq_len <= 3): row 0 emits 5 then EOS=2; row 1 emits 6.
+                if seq_len <= 3:
+                    logits[0, :, 5] = 10
+                    logits[0, 2, 2] = 20  # strong EOS at last block position
+                    logits[1, :, 6] = 10
+                else:
+                    logits[0, :, 7] = 10  # would overwrite row 0's prior tokens if not frozen
+                    logits[1, :, 6] = 10
+                return _DummyModelOutput(logits=logits)
+
+        model = _EosThenJunkModel(vocab_size=32)
+        pipe = LLaDA2Pipeline(model=model, scheduler=BlockRefinementScheduler()).to("cpu")
+        out = pipe(
+            input_ids=torch.tensor([[10], [20]], dtype=torch.long),
+            use_chat_template=False,
+            gen_length=5,
+            block_length=3,
+            num_inference_steps=3,
+            threshold=2.0,
+            mask_token_id=31,
+            eos_token_id=2,
+            eos_early_stop=True,
+            output_type="seq",
+        )
+        # Row 0's first generated tokens must not be overwritten by later-block sampling (token 7).
+        self.assertNotIn(7, out.sequences[0].tolist()[:2])
+
+    def test_progress_bar_disable_is_preserved_after_call(self):
+        """Issue #6: calling the pipeline must not mutate `_progress_bar_config`."""
+        pipe = _make_pipeline().to("cpu")
+        pipe.set_progress_bar_config(disable=True)
+        before = dict(pipe._progress_bar_config)
+        pipe(
+            input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
+            use_chat_template=False,
+            gen_length=8,
+            block_length=8,
+            num_inference_steps=2,
+            threshold=2.0,
+            mask_token_id=31,
+            eos_token_id=None,
+            eos_early_stop=False,
+            output_type="seq",
+        )
+        self.assertEqual(pipe._progress_bar_config, before)
+
+
 if __name__ == "__main__":
     unittest.main()
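Note: the row-freezing behaviour pinned by Issue #5 can also be seen on toy tensors. The sketch below reproduces only the masking pattern from the refinement loop above; every tensor value is made up for illustration.

    # Standalone illustration of the freeze-finished-rows masking added to the refinement loop
    # (values are made up; only the masking pattern mirrors the diff).
    import torch

    mask_token_id = 31
    block_tokens = torch.tensor([[31, 31, 31],   # row 0: already emitted EOS in an earlier block
                                 [31, 31, 31]])  # row 1: still generating
    prev_sample = torch.tensor([[7, 7, 7],
                                [6, 6, 6]])
    transfer_index = torch.tensor([[True, True, False],
                                   [True, False, False]])
    editing_transfer_index = torch.zeros_like(transfer_index)
    finished = torch.tensor([True, False])
    eos_early_stop = True

    final_transfer = transfer_index | editing_transfer_index
    # Freeze rows that already emitted EOS so later blocks cannot overwrite their tokens.
    if eos_early_stop and finished.any():
        final_transfer = final_transfer & ~finished[:, None]

    # Only committed positions are written; everything else keeps its current token.
    block_tokens = torch.where(final_transfer, prev_sample, block_tokens)
    print(block_tokens)
    # expected: row 0 stays all mask tokens (31); row 1 commits token 6 only at position 0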