13 changes: 0 additions & 13 deletions .github/workflows/UploadDockerImages.yml
@@ -40,16 +40,3 @@ jobs:
- name: build maxdiffusion jax nightly image
run: |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly

build-gpu-image:
runs-on: ["self-hosted", "e2", "cpu"]
steps:
- uses: actions/checkout@v3
- name: Cleanup old docker images
run: docker system prune --all --force
- name: build maxdiffusion jax stable stack gpu image
run: |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_gpu MODE=stable PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_gpu DEVICE=gpu
- name: build maxdiffusion jax nightly image
run: |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly_gpu MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly DEVICE=gpu
3 changes: 1 addition & 2 deletions .github/workflows/build_and_upload_images.sh
@@ -34,15 +34,14 @@ for ARGUMENT in "$@"; do
echo "$KEY"="$VALUE"
done

export DEVICE="${DEVICE:-tpu}"

if [[ ! -v CLOUD_IMAGE_NAME ]] || [[ ! -v PROJECT ]] || [[ ! -v MODE ]] ; then
echo "You must set CLOUD_IMAGE_NAME, PROJECT and MODE"
exit 1
fi

gcloud auth configure-docker us-docker.pkg.dev --quiet
bash docker_build_dependency_image.sh LOCAL_IMAGE_NAME=$LOCAL_IMAGE_NAME MODE=$MODE DEVICE=$DEVICE
bash docker_build_dependency_image.sh LOCAL_IMAGE_NAME=$LOCAL_IMAGE_NAME MODE=$MODE
image_date=$(date +%Y-%m-%d)

# Upload only dependencies image
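For context, the loop shown above relies on the script's KEY=VALUE argument parsing, whose body is elided from this diff. A minimal sketch of how such parsing is typically written — the cut-based split is an assumption, not the script's verbatim code:

```bash
# Sketch only: parse KEY=VALUE arguments and export each pair.
# The exact implementation is elided from the diff above.
for ARGUMENT in "$@"; do
    KEY=$(echo "$ARGUMENT" | cut -f1 -d=)     # text before the first '='
    VALUE=$(echo "$ARGUMENT" | cut -f2- -d=)  # everything after it, '=' included
    export "$KEY"="$VALUE"
    echo "$KEY"="$VALUE"
done
```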
32 changes: 2 additions & 30 deletions README.md
@@ -34,9 +34,9 @@

# Overview

MaxDiffusion is a collection of reference implementations of various latent diffusion models written in pure Python/Jax that run on XLA devices including Cloud TPUs and GPUs. MaxDiffusion aims to be a launching off point for ambitious Diffusion projects both in research and production. We encourage you to start by experimenting with MaxDiffusion out of the box and then fork and modify MaxDiffusion to meet your needs.
MaxDiffusion is a collection of reference implementations of various latent diffusion models written in pure Python/Jax that run on XLA devices including Cloud TPUs. MaxDiffusion aims to be a launching off point for ambitious Diffusion projects both in research and production. We encourage you to start by experimenting with MaxDiffusion out of the box and then fork and modify MaxDiffusion to meet your needs.

The goal of this project is to provide reference implementations for latent diffusion models that help developers get started with training, tuning, and serving solutions on XLA devices including Cloud TPUs and GPUs. We started with Stable Diffusion inference on TPUs, but welcome code contributions to grow.
The goal of this project is to provide reference implementations for latent diffusion models that help developers get started with training, tuning, and serving solutions on XLA devices including Cloud TPUs. We started with Stable Diffusion inference on TPUs, but welcome code contributions to grow.

MaxDiffusion supports
* Stable Diffusion 2 base (inference)
@@ -72,7 +72,6 @@ MaxDiffusion supports
- [Wan](#wan-models)
- [LTX-Video](#ltx-video)
- [Flux](#flux)
- [Fused Attention for GPU](#fused-attention-for-gpu)
- [SDXL](#stable-diffusion-xl)
- [SD 2 base](#stable-diffusion-2-base)
- [SD 2.1](#stable-diffusion-21)
@@ -257,9 +256,6 @@ After installation completes, run the training script.
- In Wan2.1, the ici_fsdp_parallelism axis is used for sequence parallelism, and the ici_tensor_parallelism axis is used for head parallelism.
- You can enable both, keeping in mind that Wan2.1 has 40 attention heads, so 40 must be evenly divisible by ici_tensor_parallelism.
- For sequence parallelism, the code pads the sequence length so it divides evenly across devices. Try out different ici_fsdp_parallelism values; we currently find 2 and 4 to work best. A sketch of a combined invocation follows this list.
- For use on GPU it is recommended to enable the cudnn_te_flash attention kernel for optimal performance.
- Best performance is achieved with the use of batch parallelism, which can be enabled by using the ici_fsdp_batch_parallelism axis. Note that this parallelism strategy does not support fractional batch sizes.
- ici_fsdp_batch_parallelism and ici_fsdp_parallelism can be combined to allow for fractional batch sizes. However, padding is not currently supported for the cudnn_te_flash attention kernel and it is therefore required that the sequence length is divisible by the number of devices in the ici_fsdp_parallelism axis.
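As an illustration of the flags described above, a hypothetical Wan2.1 training invocation combining sequence and head parallelism might look as follows; the entry point and config path are assumptions based on the repository's naming pattern, not verbatim from this diff:

```bash
# Hypothetical invocation: 4-way sequence parallelism (ici_fsdp_parallelism)
# and 2-way head parallelism (ici_tensor_parallelism; 40 heads / 2 divides evenly).
python -m src.maxdiffusion.train src/maxdiffusion/configs/base_wan_14b.yml \
    run_name="my_wan_run" output_dir="gs://your-bucket/" \
    ici_fsdp_parallelism=4 ici_tensor_parallelism=2
```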

You should eventually see a training run as:

@@ -400,14 +396,6 @@ After installation completes, run the training script.
python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_xl_run" output_dir="gs://your-bucket/" per_device_batch_size=1
```

On GPUs with Fused Attention:

First install Transformer Engine by following the [instructions here](#fused-attention-for-gpu).

```bash
NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml hardware=gpu run_name='test-sdxl-train' output_dir=/tmp/ train_new_unet=true train_text_encoder=false cache_latents_text_encoder_outputs=true max_train_steps=200 weights_dtype=bfloat16 resolution=512 per_device_batch_size=1 attention="cudnn_flash_te" jit_initializers=False
```

To generate images with a trained checkpoint, run:

```bash
@@ -560,23 +548,7 @@ To generate images, run the following command:
```bash
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
```
## Fused Attention for GPU:
Fused Attention for GPU is supported via TransformerEngine. Installation instructions:

```bash
cd maxdiffusion
pip install -U "jax[cuda12]"
pip install -r requirements.txt
pip install --upgrade torch torchvision
pip install "transformer_engine[jax]
pip install .
```

Now run the command:

```bash
NVTE_FUSED_ATTN=1 HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 split_head_dim=True per_device_batch_size=1 attention="cudnn_flash_te" hardware=gpu
```
## Wan LoRA

Disclaimer: not all LoRA formats have been tested. Currently supports ComfyUI and AI Toolkit formats. If there is a specific LoRA that doesn't load, please let us know.
50 changes: 18 additions & 32 deletions docker_build_dependency_image.sh
@@ -48,11 +48,6 @@ if [[ -z ${MODE} ]]; then
echo "Default MODE=${MODE}"
fi

if [[ -z ${DEVICE} ]]; then
export DEVICE=tpu
echo "Default DEVICE=${DEVICE}"
fi
echo "DEVICE=${DEVICE}"

if [[ -z ${JAX_VERSION+x} ]] ; then
export JAX_VERSION=NONE
@@ -63,31 +58,22 @@ COMMIT_HASH=$(git rev-parse --short HEAD)

echo "Building MaxDiffusion with MODE=${MODE} at commit hash ${COMMIT_HASH} . . ."

if [[ ${DEVICE} == "gpu" ]]; then
if [[ ${MODE} == "pinned" ]]; then
export BASEIMAGE=ghcr.io/nvidia/jax:base-2024-10-17
else
export BASEIMAGE=ghcr.io/nvidia/jax:base
fi
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxdiffusion_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
else
if [[ ${MODE} == "stable_stack" || ${MODE} == "jax_ai_image" ]]; then
if [[ ! -v BASEIMAGE ]]; then
echo "Erroring out because BASEIMAGE is unset, please set it!"
exit 1
fi
docker build --no-cache \
--build-arg JAX_AI_IMAGE_BASEIMAGE=${BASEIMAGE} \
--build-arg COMMIT_HASH=${COMMIT_HASH} \
--network=host \
-t ${LOCAL_IMAGE_NAME} \
-f maxdiffusion_jax_ai_image_tpu.Dockerfile .
else
docker build --no-cache \
--network=host \
--build-arg MODE=${MODE} \
--build-arg JAX_VERSION=${JAX_VERSION} \
-t ${LOCAL_IMAGE_NAME} \
-f maxdiffusion_dependencies.Dockerfile .
if [[ ${MODE} == "stable_stack" || ${MODE} == "jax_ai_image" ]]; then
if [[ ! -v BASEIMAGE ]]; then
echo "Erroring out because BASEIMAGE is unset, please set it!"
exit 1
fi
fi
docker build --no-cache \
--build-arg JAX_AI_IMAGE_BASEIMAGE=${BASEIMAGE} \
--build-arg COMMIT_HASH=${COMMIT_HASH} \
--network=host \
-t ${LOCAL_IMAGE_NAME} \
-f maxdiffusion_jax_ai_image_tpu.Dockerfile .
else
docker build --no-cache \
--network=host \
--build-arg MODE=${MODE} \
--build-arg JAX_VERSION=${JAX_VERSION} \
-t ${LOCAL_IMAGE_NAME} \
-f maxdiffusion_dependencies.Dockerfile .
fi
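With the GPU branch removed, the script now selects between just the two TPU build paths. Hypothetical invocations under that assumption (the image names and base image below are placeholders, not values from this diff):

```bash
# Default path: build the TPU dependency image from maxdiffusion_dependencies.Dockerfile.
bash docker_build_dependency_image.sh LOCAL_IMAGE_NAME=maxdiffusion_base MODE=stable

# JAX AI image path: BASEIMAGE must be set explicitly, or the script exits with an error.
bash docker_build_dependency_image.sh LOCAL_IMAGE_NAME=maxdiffusion_jax_ai \
    MODE=jax_ai_image BASEIMAGE=us-docker.pkg.dev/your-project/jax-ai/base:latest
```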
196 changes: 0 additions & 196 deletions docs/dgx_spark.md

This file was deleted.

2 changes: 1 addition & 1 deletion docs/getting_started/first_run.md
Expand Up @@ -21,7 +21,7 @@ cd maxdiffusion
1. Within the root directory of the MaxDiffusion `git` repo, install dependencies by running:
```
# If a Python 3.12+ virtual environment doesn't already exist, you'll need to run the install command three times.
bash setup.sh MODE=stable DEVICE=tpu
bash setup.sh MODE=stable
```

1. Activate your virtual environment:
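The activation command itself is elided from this diff. Assuming no virtual environment exists yet, the first-run sequence implied by the notes above might look like this (the venv path is hypothetical):

```bash
# Create and activate a Python 3.12+ virtual environment, then install.
python3.12 -m venv .venv
source .venv/bin/activate
# Per the note above, setup.sh may need to be re-run after the venv is first created.
bash setup.sh MODE=stable   # the DEVICE flag is no longer needed after this change
```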