
Commit 1069de1

Remove all GPU related references

1 parent 3074581 commit 1069de1

22 files changed (+49, -526 lines)

.github/workflows/UploadDockerImages.yml

Lines changed: 0 additions & 13 deletions

````diff
@@ -40,16 +40,3 @@ jobs:
       - name: build maxdiffusion jax nightly image
         run: |
           bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly
-
-  build-gpu-image:
-    runs-on: ["self-hosted", "e2", "cpu"]
-    steps:
-      - uses: actions/checkout@v3
-      - name: Cleanup old docker images
-        run: docker system prune --all --force
-      - name: build maxdiffusion jax stable stack gpu image
-        run: |
-          bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_gpu MODE=stable PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_gpu DEVICE=gpu
-      - name: build maxdiffusion jax nightly image
-        run: |
-          bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly_gpu MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly DEVICE=gpu
````
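For context, the TPU nightly step that survives this change calls the repository's build-and-upload helper. The invocation below only restates the retained step's command in a copy-pasteable form; the helper script itself is not part of this diff, so the argument semantics are inferred from the names alone.

```bash
# The build step the workflow keeps after this commit: a TPU-only nightly
# image build, with no DEVICE=gpu argument (the removed GPU jobs passed one).
bash .github/workflows/build_and_upload_images.sh \
  CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly \
  MODE=nightly \
  PROJECT=tpu-prod-env-multipod \
  LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly
```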

README.md

Lines changed: 2 additions & 30 deletions
````diff
@@ -34,9 +34,9 @@
 
 # Overview
 
-MaxDiffusion is a collection of reference implementations of various latent diffusion models written in pure Python/Jax that run on XLA devices including Cloud TPUs and GPUs. MaxDiffusion aims to be a launching off point for ambitious Diffusion projects both in research and production. We encourage you to start by experimenting with MaxDiffusion out of the box and then fork and modify MaxDiffusion to meet your needs.
+MaxDiffusion is a collection of reference implementations of various latent diffusion models written in pure Python/Jax that run on XLA devices including Cloud TPUs. MaxDiffusion aims to be a launching off point for ambitious Diffusion projects both in research and production. We encourage you to start by experimenting with MaxDiffusion out of the box and then fork and modify MaxDiffusion to meet your needs.
 
-The goal of this project is to provide reference implementations for latent diffusion models that help developers get started with training, tuning, and serving solutions on XLA devices including Cloud TPUs and GPUs. We started with Stable Diffusion inference on TPUs, but welcome code contributions to grow.
+The goal of this project is to provide reference implementations for latent diffusion models that help developers get started with training, tuning, and serving solutions on XLA devices including Cloud TPUs. We started with Stable Diffusion inference on TPUs, but welcome code contributions to grow.
 
 MaxDiffusion supports
 * Stable Diffusion 2 base (inference)
````
````diff
@@ -72,7 +72,6 @@ MaxDiffusion supports
 - [Wan](#wan-models)
 - [LTX-Video](#ltx-video)
 - [Flux](#flux)
-- [Fused Attention for GPU](#fused-attention-for-gpu)
 - [SDXL](#stable-diffusion-xl)
 - [SD 2 base](#stable-diffusion-2-base)
 - [SD 2.1](#stable-diffusion-21)
````
````diff
@@ -257,9 +256,6 @@ After installation completes, run the training script.
 - In Wan2.1, the ici_fsdp_parallelism axis is used for sequence parallelism, the ici_tensor_parallelism axis is used for head parallelism.
 - You can enable both, keeping in mind that Wan2.1 has 40 heads and 40 must be evenly divisible by ici_tensor_parallelism.
 - For sequence parallelism, the code pads the sequence length to evenly divide the sequence. Try out different ici_fsdp_parallelism numbers, but we find 2 and 4 to be the best right now.
-- For use on GPU it is recommended to enable the cudnn_te_flash attention kernel for optimal performance.
-- Best performance is achieved with the use of batch parallelism, which can be enabled by using the ici_fsdp_batch_parallelism axis. Note that this parallelism strategy does not support fractional batch sizes.
-- ici_fsdp_batch_parallelism and ici_fsdp_parallelism can be combined to allow for fractional batch sizes. However, padding is not currently supported for the cudnn_te_flash attention kernel and it is therefore required that the sequence length is divisible by the number of devices in the ici_fsdp_parallelism axis.
 
 You should eventually see a training run as:
 
````
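The retained bullets pin down the legal flag combinations. As a minimal sketch only: the entry point and config file below are hypothetical (the actual Wan2.1 training command sits outside this hunk's context window), but the parallelism flags are the ones the bullets describe.

```bash
# Hypothetical Wan2.1 launch illustrating the constraints above.
# ici_tensor_parallelism must evenly divide the 40 attention heads (40 % 8 == 0);
# ici_fsdp_parallelism=4 is one of the values the README reports working best.
# The script path and config name are assumptions, not taken from this diff.
python -m src.maxdiffusion.train src/maxdiffusion/configs/base_wan_14b.yml \
  run_name="wan_parallelism_test" output_dir="gs://your-bucket/" \
  ici_fsdp_parallelism=4 ici_tensor_parallelism=8
```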
````diff
@@ -400,14 +396,6 @@ After installation completes, run the training script.
 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_xl_run" output_dir="gs://your-bucket/" per_device_batch_size=1
 ```
 
-On GPUs with Fused Attention:
-
-First install Transformer Engine by following the [instructions here](#fused-attention-for-gpu).
-
-```bash
-NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml hardware=gpu run_name='test-sdxl-train' output_dir=/tmp/ train_new_unet=true train_text_encoder=false cache_latents_text_encoder_outputs=true max_train_steps=200 weights_dtype=bfloat16 resolution=512 per_device_batch_size=1 attention="cudnn_flash_te" jit_initializers=False
-```
-
 To generate images with a trained checkpoint, run:
 
 ```bash
````
````diff
@@ -560,23 +548,7 @@ To generate images, run the following command:
 ```bash
 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
 ```
-## Fused Attention for GPU:
-Fused Attention for GPU is supported via TransformerEngine. Installation instructions:
 
-```bash
-cd maxdiffusion
-pip install -U "jax[cuda12]"
-pip install -r requirements.txt
-pip install --upgrade torch torchvision
-pip install "transformer_engine[jax]"
-pip install .
-```
-
-Now run the command:
-
-```bash
-NVTE_FUSED_ATTN=1 HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 split_head_dim=True per_device_batch_size=1 attention="cudnn_flash_te" hardware=gpu
-```
 ## Wan LoRA
 
 Disclaimer: not all LoRA formats have been tested. Currently supports ComfyUI and AI Toolkit formats. If there is a specific LoRA that doesn't load, please let us know.
````

docker_build_dependency_image.sh

Lines changed: 18 additions & 27 deletions
````diff
@@ -63,31 +63,22 @@ COMMIT_HASH=$(git rev-parse --short HEAD)
 
 echo "Building MaxDiffusion with MODE=${MODE} at commit hash ${COMMIT_HASH} . . ."
 
-if [[ ${DEVICE} == "gpu" ]]; then
-  if [[ ${MODE} == "pinned" ]]; then
-    export BASEIMAGE=ghcr.io/nvidia/jax:base-2024-10-17
-  else
-    export BASEIMAGE=ghcr.io/nvidia/jax:base
+if [[ ${MODE} == "stable_stack" || ${MODE} == "jax_ai_image" ]]; then
+  if [[ ! -v BASEIMAGE ]]; then
+    echo "Erroring out because BASEIMAGE is unset, please set it!"
+    exit 1
   fi
-  docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxdiffusion_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
-else
-  if [[ ${MODE} == "stable_stack" || ${MODE} == "jax_ai_image" ]]; then
-    if [[ ! -v BASEIMAGE ]]; then
-      echo "Erroring out because BASEIMAGE is unset, please set it!"
-      exit 1
-    fi
-    docker build --no-cache \
-      --build-arg JAX_AI_IMAGE_BASEIMAGE=${BASEIMAGE} \
-      --build-arg COMMIT_HASH=${COMMIT_HASH} \
-      --network=host \
-      -t ${LOCAL_IMAGE_NAME} \
-      -f maxdiffusion_jax_ai_image_tpu.Dockerfile .
-  else
-    docker build --no-cache \
-      --network=host \
-      --build-arg MODE=${MODE} \
-      --build-arg JAX_VERSION=${JAX_VERSION} \
-      -t ${LOCAL_IMAGE_NAME} \
-      -f maxdiffusion_dependencies.Dockerfile .
-  fi
-fi
+  docker build --no-cache \
+    --build-arg JAX_AI_IMAGE_BASEIMAGE=${BASEIMAGE} \
+    --build-arg COMMIT_HASH=${COMMIT_HASH} \
+    --network=host \
+    -t ${LOCAL_IMAGE_NAME} \
+    -f maxdiffusion_jax_ai_image_tpu.Dockerfile .
+else
+  docker build --no-cache \
+    --network=host \
+    --build-arg MODE=${MODE} \
+    --build-arg JAX_VERSION=${JAX_VERSION} \
+    -t ${LOCAL_IMAGE_NAME} \
+    -f maxdiffusion_dependencies.Dockerfile .
+fi
````
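After this change the script branches only on MODE: the stable_stack/jax_ai_image path requires BASEIMAGE and builds from maxdiffusion_jax_ai_image_tpu.Dockerfile, while every other mode builds from maxdiffusion_dependencies.Dockerfile. A usage sketch under the assumption that the script reads these settings from the environment (the argument-passing mechanism is not visible in this diff):

```bash
# JAX AI image path: BASEIMAGE is mandatory; the script exits 1 without it.
# The BASEIMAGE value below is a placeholder, not taken from this commit.
BASEIMAGE=us-docker.pkg.dev/your-project/jax-ai-image:latest \
MODE=jax_ai_image LOCAL_IMAGE_NAME=maxdiffusion_base \
bash docker_build_dependency_image.sh

# Default path: builds from maxdiffusion_dependencies.Dockerfile.
# JAX_VERSION is a placeholder version string.
MODE=nightly JAX_VERSION=0.4.35 LOCAL_IMAGE_NAME=maxdiffusion_base \
bash docker_build_dependency_image.sh
```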

docs/dgx_spark.md

Lines changed: 0 additions & 196 deletions
This file was deleted.
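Since the commit message claims all GPU references are gone, a repository-side search of the files shown above makes a reasonable sanity check; this command is an editorial addition, not part of the commit.

```bash
# Case-insensitive search of the touched files; if the removal is complete,
# none of these should still match on "gpu".
git grep -in "gpu" -- .github/workflows/UploadDockerImages.yml \
  README.md docker_build_dependency_image.sh
```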
