Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/diffusers/pipelines/lumina2/pipeline_lumina2.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,8 @@ def __call__(

# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
patch_size = self.transformer.config.patch_size
image_seq_len = (latents.shape[-2] // patch_size) * (latents.shape[-1] // patch_size)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
Expand Down
62 changes: 62 additions & 0 deletions tests/pipelines/lumina2/test_pipeline_lumina2.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,68 @@ def get_dummy_components(self):
}
return components

def test_image_seq_len_uses_spatial_dimensions(self):
"""Test that image_seq_len is computed from spatial dims, not channel dim.

Lumina2 latents have shape (batch, channels, height, width) and are NOT
packed before image_seq_len is computed. The transformer patchifies
internally with patch_size=2, so the correct sequence length is
(H // patch_size) * (W // patch_size).

Previously, the code used latents.shape[1] which gives the channel
count (e.g. 4) instead of the spatial sequence length (e.g. 64 for
16x16 latents with patch_size=2). This caused calculate_shift() to
compute a completely wrong mu value for the scheduler.
"""
components = self.get_dummy_components()
pipe = Lumina2Pipeline(**components)
pipe.to(torch.device("cpu"))

patch_size = pipe.transformer.config.patch_size # 2

# Use height=32, width=32 -> latent size 4x4 (vae downscale 8x)
# With patch_size=2: seq_len = (4//2)*(4//2) = 4
# Channel dim = 4, which would be wrong if used as seq_len
# Use a larger size to make the distinction clearer
height, width = 64, 64
latent_h, latent_w = height // 8, width // 8 # 8, 8
expected_seq_len = (latent_h // patch_size) * (latent_w // patch_size) # 16

# The channel dimension is 4 (from vae latent_channels)
# If the bug were present, image_seq_len would be 4 instead of 16
channels = components["vae"].config.latent_channels # 4
self.assertNotEqual(channels, expected_seq_len, "Test needs channels != expected_seq_len to be meaningful")

# Capture the mu value passed to the scheduler
captured = {}
original_set_timesteps = pipe.scheduler.set_timesteps

def capture_mu_set_timesteps(*args, **kwargs):
captured["mu"] = kwargs.get("mu")
return original_set_timesteps(*args, **kwargs)

pipe.scheduler.set_timesteps = capture_mu_set_timesteps

# Run pipeline with specific dimensions
generator = torch.Generator(device="cpu").manual_seed(0)
pipe(
prompt="test",
height=height,
width=width,
num_inference_steps=1,
generator=generator,
output_type="latent",
)

# Verify mu was computed using spatial seq_len, not channel dim
from diffusers.pipelines.lumina2.pipeline_lumina2 import calculate_shift

correct_mu = calculate_shift(expected_seq_len)
wrong_mu = calculate_shift(channels)

self.assertAlmostEqual(captured["mu"], correct_mu, places=5, msg="mu should use spatial sequence length")
self.assertNotAlmostEqual(captured["mu"], wrong_mu, places=5, msg="mu should NOT use channel dimension")

def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
Expand Down