[Bug] train_dreambooth_lora_flux2_klein.py: batch size mismatch with --with_prior_preservation #13292

@vishk23

Description

When using --with_prior_preservation with train_dreambooth_lora_flux2_klein.py,
the prompt embedding repeat logic doubles the embedding batch, so it no longer
matches the latent batch.

The line:
num_repeat_elements = len(prompts)

should be:
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)

With prior preservation, prompts already contains both the instance and class
samples from collate_fn, so repeating by len(prompts) creates 4 embeddings for
a batch of 2 latents.

Reproducible with: train_batch_size=1, with_prior_preservation=True
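
To make the arithmetic concrete, here is a minimal sketch that uses plain Python lists as stand-ins for the tensors. It is not the script's actual code: the helper names, the example prompts, and the assumption that the embeddings are already stacked as instance + class before the repeat are all hypothetical and only meant to illustrate the count mismatch.

```python
def num_repeat_elements(prompts, with_prior_preservation):
    """Hypothetical helper mirroring the proposed one-line fix."""
    if with_prior_preservation:
        # prompts holds instance + class prompts, so halve the count
        return len(prompts) // 2
    return len(prompts)


def repeat_along_batch(embeds, times):
    # Stand-in for repeating a tensor along its batch dimension
    return embeds * times


# train_batch_size=1 with prior preservation: the batch has one instance
# sample and one class sample (example prompts are made up)
prompts = ["a photo of sks dog", "a photo of a dog"]
latents = ["instance_latent", "class_latent"]

# Assumption: embeddings are already stacked as instance + class
prompt_embeds = ["instance_embed", "class_embed"]

buggy = repeat_along_batch(prompt_embeds, len(prompts))
fixed = repeat_along_batch(prompt_embeds, num_repeat_elements(prompts, True))

print(len(buggy), len(fixed), len(latents))  # prints "4 2 2"
```

With the current line the embedding batch is 4 against 2 latents; with the proposed change both are 2.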

Diffusers version: 0.38.0.dev0 (main branch)
