34 | 34 |
35 | 35 | # Overview |
36 | 36 |
37 | | -MaxDiffusion is a collection of reference implementations of various latent diffusion models written in pure Python/Jax that run on XLA devices including Cloud TPUs and GPUs. MaxDiffusion aims to be a launching off point for ambitious Diffusion projects both in research and production. We encourage you to start by experimenting with MaxDiffusion out of the box and then fork and modify MaxDiffusion to meet your needs. |
| 37 | +MaxDiffusion is a collection of reference implementations of various latent diffusion models written in pure Python/Jax that run on XLA devices including Cloud TPUs. MaxDiffusion aims to be a launching off point for ambitious Diffusion projects both in research and production. We encourage you to start by experimenting with MaxDiffusion out of the box and then fork and modify MaxDiffusion to meet your needs. |
38 | 38 |
39 | | -The goal of this project is to provide reference implementations for latent diffusion models that help developers get started with training, tuning, and serving solutions on XLA devices including Cloud TPUs and GPUs. We started with Stable Diffusion inference on TPUs, but welcome code contributions to grow. |
| 39 | +The goal of this project is to provide reference implementations for latent diffusion models that help developers get started with training, tuning, and serving solutions on XLA devices including Cloud TPUs. We started with Stable Diffusion inference on TPUs, but welcome code contributions to grow. |
40 | 40 |
41 | 41 | MaxDiffusion supports |
42 | 42 | * Stable Diffusion 2 base (inference) |
@@ -72,7 +72,6 @@ MaxDiffusion supports |
72 | 72 | - [Wan](#wan-models) |
73 | 73 | - [LTX-Video](#ltx-video) |
74 | 74 | - [Flux](#flux) |
75 | | - - [Fused Attention for GPU](#fused-attention-for-gpu) |
76 | 75 | - [SDXL](#stable-diffusion-xl) |
77 | 76 | - [SD 2 base](#stable-diffusion-2-base) |
78 | 77 | - [SD 2.1](#stable-diffusion-21) |
@@ -257,9 +256,6 @@ After installation completes, run the training script. |
257 | 256 | - In Wan2.1, the ici_fsdp_parallelism axis is used for sequence parallelism, and the ici_tensor_parallelism axis is used for head parallelism.
258 | 257 | - You can enable both, keeping in mind that Wan2.1 has 40 heads and 40 must be evenly divisible by ici_tensor_parallelism. |
259 | 258 | - For sequence parallelism, the code pads the sequence so that its length divides evenly across the ici_fsdp_parallelism axis. Try out different ici_fsdp_parallelism values; we find 2 and 4 to be the best right now (a combined example is sketched after this list).
260 | | - - For use on GPU it is recommended to enable the cudnn_te_flash attention kernel for optimal performance. |
261 | | - - Best performance is achieved with the use of batch parallelism, which can be enabled by using the ici_fsdp_batch_parallelism axis. Note that this parallelism strategy does not support fractional batch sizes. |
262 | | - - ici_fsdp_batch_parallelism and ici_fsdp_parallelism can be combined to allow for fractional batch sizes. However, padding is not currently supported for the cudnn_te_flash attention kernel and it is therefore required that the sequence length is divisible by the number of devices in the ici_fsdp_parallelism axis. |
263 | 259 |
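As a rough illustration of combining the two axes described in the list above on an 8-device slice, the sketch below sets ici_fsdp_parallelism=2 (sequence parallelism) and ici_tensor_parallelism=4 (head parallelism, since 40 heads divide evenly by 4). The script and config paths are placeholders patterned on the other examples in this README, not a confirmed Wan 2.1 entry point; substitute the actual training command shown earlier.

```bash
# Sketch only: swap in the actual Wan 2.1 training script/config from the command above.
# 8 devices total: 2-way sequence parallelism x 4-way head parallelism (40 % 4 == 0).
python src/maxdiffusion/train.py src/maxdiffusion/configs/base_wan_14b.yml \
  run_name="my_wan_run" output_dir="gs://your-bucket/" \
  per_device_batch_size=1 \
  ici_fsdp_parallelism=2 \
  ici_tensor_parallelism=4
```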
264 | 260 | You should eventually see a training run as: |
265 | 261 |
@@ -400,14 +396,6 @@ After installation completes, run the training script. |
400 | 396 | python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_xl_run" output_dir="gs://your-bucket/" per_device_batch_size=1 |
401 | 397 | ``` |
402 | 398 |
403 | | - On GPUS with Fused Attention: |
404 | | - |
405 | | - First install Transformer Engine by following the [instructions here](#fused-attention-for-gpu). |
406 | | - |
407 | | - ```bash |
408 | | - NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml hardware=gpu run_name='test-sdxl-train' output_dir=/tmp/ train_new_unet=true train_text_encoder=false cache_latents_text_encoder_outputs=true max_train_steps=200 weights_dtype=bfloat16 resolution=512 per_device_batch_size=1 attention="cudnn_flash_te" jit_initializers=False |
409 | | - ``` |
410 | | - |
411 | 399 | To generate images with a trained checkpoint, run: |
412 | 400 |
413 | 401 | ```bash |
@@ -560,23 +548,7 @@ To generate images, run the following command: |
560 | 548 | ```bash |
561 | 549 | python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False |
562 | 550 | ``` |
563 | | - ## Fused Attention for GPU: |
564 | | - Fused Attention for GPU is supported via TransformerEngine. Installation instructions: |
565 | 551 |
566 | | - ```bash |
567 | | - cd maxdiffusion |
568 | | - pip install -U "jax[cuda12]" |
569 | | - pip install -r requirements.txt |
570 | | - pip install --upgrade torch torchvision |
571 | | - pip install "transformer_engine[jax]"
572 | | - pip install . |
573 | | - ``` |
574 | | -
575 | | - Now run the command: |
576 | | -
577 | | - ```bash |
578 | | - NVTE_FUSED_ATTN=1 HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 split_head_dim=True per_device_batch_size=1 attention="cudnn_flash_te" hardware=gpu |
579 | | - ``` |
580 | 552 | ## Wan LoRA |
581 | 553 |
582 | 554 | Disclaimer: not all LoRA formats have been tested. ComfyUI and AI Toolkit formats are currently supported. If there is a specific LoRA that doesn't load, please let us know.
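If you are unsure which convention a LoRA checkpoint follows, one quick check is to list its tensor key prefixes with the safetensors library, since ComfyUI and AI Toolkit exports name their keys differently. This is only a minimal sketch; the file name below is a placeholder for a locally downloaded LoRA.

```bash
# Assumes a locally downloaded LoRA; replace my_wan_lora.safetensors with your checkpoint.
# Prints the top-level key prefixes so you can see which naming convention the file uses.
python -c "
from safetensors import safe_open
with safe_open('my_wan_lora.safetensors', framework='numpy') as f:
    print(sorted({k.split('.')[0] for k in f.keys()}))
"
```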