From 620067b32df2d604096929cd51f76def4ac45f4e Mon Sep 17 00:00:00 2001 From: michelle-yooh Date: Fri, 20 Feb 2026 23:00:28 +0000 Subject: [PATCH] Remove all GPU related references --- .github/workflows/UploadDockerImages.yml | 13 -- .github/workflows/build_and_upload_images.sh | 3 +- README.md | 32 +-- docker_build_dependency_image.sh | 50 ++--- docs/dgx_spark.md | 196 ------------------ docs/getting_started/first_run.md | 2 +- gpu_multi_process_run.sh | 156 -------------- maxdiffusion_gpu_dependencies.Dockerfile | 6 +- setup.sh | 59 ++---- src/maxdiffusion/configs/base14.yml | 2 +- src/maxdiffusion/configs/base21.yml | 2 +- src/maxdiffusion/configs/base_2_base.yml | 2 +- src/maxdiffusion/configs/base_flux_dev.yml | 2 +- .../configs/base_flux_dev_multi_res.yml | 2 +- .../configs/base_flux_schnell.yml | 2 +- src/maxdiffusion/configs/base_wan_14b.yml | 2 +- src/maxdiffusion/configs/base_wan_1_3b.yml | 2 +- src/maxdiffusion/configs/base_wan_27b.yml | 2 +- src/maxdiffusion/configs/base_wan_i2v_14b.yml | 2 +- src/maxdiffusion/configs/base_wan_i2v_27b.yml | 2 +- src/maxdiffusion/configs/base_xl.yml | 2 +- .../configs/base_xl_lightning.yml | 2 +- src/maxdiffusion/max_utils.py | 22 -- src/maxdiffusion/models/attention_flax.py | 22 -- src/maxdiffusion/utils/import_utils.py | 2 - src/maxdiffusion/utils/testing_utils.py | 5 - 26 files changed, 52 insertions(+), 542 deletions(-) delete mode 100644 docs/dgx_spark.md delete mode 100644 gpu_multi_process_run.sh diff --git a/.github/workflows/UploadDockerImages.yml b/.github/workflows/UploadDockerImages.yml index 64c363b3..53e25bf6 100644 --- a/.github/workflows/UploadDockerImages.yml +++ b/.github/workflows/UploadDockerImages.yml @@ -40,16 +40,3 @@ jobs: - name: build maxdiffusion jax nightly image run: | bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly - - build-gpu-image: - runs-on: ["self-hosted", "e2", "cpu"] - steps: - - uses: actions/checkout@v3 - - name: Cleanup old docker images - run: docker system prune --all --force - - name: build maxdiffusion jax stable stack gpu image - run: | - bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_gpu MODE=stable PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_gpu DEVICE=gpu - - name: build maxdiffusion jax nightly image - run: | - bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly_gpu MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly DEVICE=gpu diff --git a/.github/workflows/build_and_upload_images.sh b/.github/workflows/build_and_upload_images.sh index 9718bf84..724866a7 100644 --- a/.github/workflows/build_and_upload_images.sh +++ b/.github/workflows/build_and_upload_images.sh @@ -34,7 +34,6 @@ for ARGUMENT in "$@"; do echo "$KEY"="$VALUE" done -export DEVICE="${DEVICE:-tpu}" if [[ ! -v CLOUD_IMAGE_NAME ]] || [[ ! -v PROJECT ]] || [[ ! -v MODE ]] ; then echo "You must set CLOUD_IMAGE_NAME, PROJECT and MODE" @@ -42,7 +41,7 @@ if [[ ! -v CLOUD_IMAGE_NAME ]] || [[ ! -v PROJECT ]] || [[ ! -v MODE ]] ; then fi gcloud auth configure-docker us-docker.pkg.dev --quiet -bash docker_build_dependency_image.sh LOCAL_IMAGE_NAME=$LOCAL_IMAGE_NAME MODE=$MODE DEVICE=$DEVICE +bash docker_build_dependency_image.sh LOCAL_IMAGE_NAME=$LOCAL_IMAGE_NAME MODE=$MODE image_date=$(date +%Y-%m-%d) # Upload only dependencies image diff --git a/README.md b/README.md index 15667de4..346074d4 100644 --- a/README.md +++ b/README.md @@ -34,9 +34,9 @@ # Overview -MaxDiffusion is a collection of reference implementations of various latent diffusion models written in pure Python/Jax that run on XLA devices including Cloud TPUs and GPUs. MaxDiffusion aims to be a launching off point for ambitious Diffusion projects both in research and production. We encourage you to start by experimenting with MaxDiffusion out of the box and then fork and modify MaxDiffusion to meet your needs. +MaxDiffusion is a collection of reference implementations of various latent diffusion models written in pure Python/Jax that run on XLA devices including Cloud TPUs. MaxDiffusion aims to be a launching off point for ambitious Diffusion projects both in research and production. We encourage you to start by experimenting with MaxDiffusion out of the box and then fork and modify MaxDiffusion to meet your needs. -The goal of this project is to provide reference implementations for latent diffusion models that help developers get started with training, tuning, and serving solutions on XLA devices including Cloud TPUs and GPUs. We started with Stable Diffusion inference on TPUs, but welcome code contributions to grow. +The goal of this project is to provide reference implementations for latent diffusion models that help developers get started with training, tuning, and serving solutions on XLA devices including Cloud TPUs. We started with Stable Diffusion inference on TPUs, but welcome code contributions to grow. MaxDiffusion supports * Stable Diffusion 2 base (inference) @@ -72,7 +72,6 @@ MaxDiffusion supports - [Wan](#wan-models) - [LTX-Video](#ltx-video) - [Flux](#flux) - - [Fused Attention for GPU](#fused-attention-for-gpu) - [SDXL](#stable-diffusion-xl) - [SD 2 base](#stable-diffusion-2-base) - [SD 2.1](#stable-diffusion-21) @@ -257,9 +256,6 @@ After installation completes, run the training script. - In Wan2.1, the ici_fsdp_parallelism axis is used for sequence parallelism, the ici_tensor_parallelism axis is used for head parallelism. - You can enable both, keeping in mind that Wan2.1 has 40 heads and 40 must be evenly divisible by ici_tensor_parallelism. - For Sequence parallelism, the code pads the sequence length to evenly divide the sequence. Try out different ici_fsdp_parallelism numbers, but we find 2 and 4 to be the best right now. - - For use on GPU it is recommended to enable the cudnn_te_flash attention kernel for optimal performance. - - Best performance is achieved with the use of batch parallelism, which can be enabled by using the ici_fsdp_batch_parallelism axis. Note that this parallelism strategy does not support fractional batch sizes. - - ici_fsdp_batch_parallelism and ici_fsdp_parallelism can be combined to allow for fractional batch sizes. However, padding is not currently supported for the cudnn_te_flash attention kernel and it is therefore required that the sequence length is divisible by the number of devices in the ici_fsdp_parallelism axis. You should eventually see a training run as: @@ -400,14 +396,6 @@ After installation completes, run the training script. python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_xl_run" output_dir="gs://your-bucket/" per_device_batch_size=1 ``` - On GPUS with Fused Attention: - - First install Transformer Engine by following the [instructions here](#fused-attention-for-gpu). - - ```bash - NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml hardware=gpu run_name='test-sdxl-train' output_dir=/tmp/ train_new_unet=true train_text_encoder=false cache_latents_text_encoder_outputs=true max_train_steps=200 weights_dtype=bfloat16 resolution=512 per_device_batch_size=1 attention="cudnn_flash_te" jit_initializers=False - ``` - To generate images with a trained checkpoint, run: ```bash @@ -560,23 +548,7 @@ To generate images, run the following command: ```bash python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False ``` - ## Fused Attention for GPU: - Fused Attention for GPU is supported via TransformerEngine. Installation instructions: - ```bash - cd maxdiffusion - pip install -U "jax[cuda12]" - pip install -r requirements.txt - pip install --upgrade torch torchvision - pip install "transformer_engine[jax] - pip install . - ``` - - Now run the command: - - ```bash - NVTE_FUSED_ATTN=1 HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 split_head_dim=True per_device_batch_size=1 attention="cudnn_flash_te" hardware=gpu - ``` ## Wan LoRA Disclaimer: not all LoRA formats have been tested. Currently supports ComfyUI and AI Toolkit formats. If there is a specific LoRA that doesn't load, please let us know. diff --git a/docker_build_dependency_image.sh b/docker_build_dependency_image.sh index d8838ec2..b0ed50a5 100644 --- a/docker_build_dependency_image.sh +++ b/docker_build_dependency_image.sh @@ -48,11 +48,6 @@ if [[ -z ${MODE} ]]; then echo "Default MODE=${MODE}" fi -if [[ -z ${DEVICE} ]]; then - export DEVICE=tpu - echo "Default DEVICE=${DEVICE}" -fi -echo "DEVICE=${DEVICE}" if [[ -z ${JAX_VERSION+x} ]] ; then export JAX_VERSION=NONE @@ -63,31 +58,22 @@ COMMIT_HASH=$(git rev-parse --short HEAD) echo "Building MaxDiffusion with MODE=${MODE} at commit hash ${COMMIT_HASH} . . ." -if [[ ${DEVICE} == "gpu" ]]; then - if [[ ${MODE} == "pinned" ]]; then - export BASEIMAGE=ghcr.io/nvidia/jax:base-2024-10-17 - else - export BASEIMAGE=ghcr.io/nvidia/jax:base - fi - docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxdiffusion_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . -else - if [[ ${MODE} == "stable_stack" || ${MODE} == "jax_ai_image" ]]; then - if [[ ! -v BASEIMAGE ]]; then - echo "Erroring out because BASEIMAGE is unset, please set it!" - exit 1 - fi - docker build --no-cache \ - --build-arg JAX_AI_IMAGE_BASEIMAGE=${BASEIMAGE} \ - --build-arg COMMIT_HASH=${COMMIT_HASH} \ - --network=host \ - -t ${LOCAL_IMAGE_NAME} \ - -f maxdiffusion_jax_ai_image_tpu.Dockerfile . - else - docker build --no-cache \ - --network=host \ - --build-arg MODE=${MODE} \ - --build-arg JAX_VERSION=${JAX_VERSION} \ - -t ${LOCAL_IMAGE_NAME} \ - -f maxdiffusion_dependencies.Dockerfile . +if [[ ${MODE} == "stable_stack" || ${MODE} == "jax_ai_image" ]]; then + if [[ ! -v BASEIMAGE ]]; then + echo "Erroring out because BASEIMAGE is unset, please set it!" + exit 1 fi -fi \ No newline at end of file + docker build --no-cache \ + --build-arg JAX_AI_IMAGE_BASEIMAGE=${BASEIMAGE} \ + --build-arg COMMIT_HASH=${COMMIT_HASH} \ + --network=host \ + -t ${LOCAL_IMAGE_NAME} \ + -f maxdiffusion_jax_ai_image_tpu.Dockerfile . +else + docker build --no-cache \ + --network=host \ + --build-arg MODE=${MODE} \ + --build-arg JAX_VERSION=${JAX_VERSION} \ + -t ${LOCAL_IMAGE_NAME} \ + -f maxdiffusion_dependencies.Dockerfile . +fi diff --git a/docs/dgx_spark.md b/docs/dgx_spark.md deleted file mode 100644 index c0b3efa0..00000000 --- a/docs/dgx_spark.md +++ /dev/null @@ -1,196 +0,0 @@ -# MaxDiffusion on Nvidia DGX Spark GPU: A complete User Guide - -This guide provides a detailed step-by-step walkthrough for setting up and running the maxdiffusion library within a custom Docker environment on an ARM-based machine with NVIDIA GPU support. We will cover everything from building the optimized Docker image to generating your first image and retrieving it successfully. - -## Prerequisites - -Before you begin, ensure you have the following: - -- Access to [Nvidia DGX Spark Box](https://www.nvidia.com/en-us/products/workstations/dgx-spark/). -- The maxdiffusion source code cloned onto the machine. - - Branch: dgx_spark -- An internet connection for the initial Docker build and for downloading models (if not cached). - -## Part 1: Building the Optimized Docker Image - -The foundation of a smooth workflow is a well-built Docker image. The following Dockerfile is optimized for build speed by caching dependencies, ensuring that code changes don't require a full reinstall of all libraries. - -### Step1: Create the Dockerfile - -In the root directory of your maxdiffusion project, create a file named box.Dockerfile and paste the following content into it. - -```docker -# Nvidia Base image for ARM64 with CUDA support -# As JAX AI Image as it currently doesn't support ARM builds. -FROM nvcr.io/nvidia/cuda-dl-base@sha256:3631d968c12ef22b1dfe604de63dbc71a55f3ffcc23a085677a6d539d98884a4 - -# Set environment variables (these rarely change) -ENV PIP_BREAK_SYSTEM_PACKAGES=1 -ENV DEBIAN_FRONTEND=noninteractive - -# Install system-level dependencies (these change very infrequently) -RUN apt-get update && apt-get install -y python3 python3-pip -RUN update-alternatives --install /usr/bin/python python /usr/bin/python3 1 && \ - update-alternatives --install /usr/bin/pip pip /usr/bin/pip3 1 - -WORKDIR /app - -# --- Dependency Installation Layer --- -# First, copy only the requirements file to leverage caching -COPY requirements.txt . - -# Install dependencies from requirements.txt -RUN pip install -r requirements.txt - -# Install other major Python libraries in separate layers for better caching -RUN pip install "jax[cuda13-local]==0.7.2" - -# --- Application Code Layer --- -# Now, copy your application source code. This layer is rebuilt only when your code changes. -COPY . . - -# Install the maxdiffusion package from the copied source -RUN pip install . - -# Set a default command to keep the container running for interactive use -CMD ["/bin/bash"] -``` - -### Step2: Build the Image - -Open your terminal on DGX Spark, navigate to the root directory of the maxdiffusion project, and run the build command: - -```bash -docker build -f box.Dockerfile -t maxdiffusion-arm-gpu . -``` - -This command will execute the steps in your Dockerfile, download the necessary layers, install all dependencies, and create a local Docker image named `maxdiffusion-arm-gpu`. The first build may take some time. Subsequent builds will be much faster if you only change the source code. - -## Part 2: Running the Container for Image Generation - -To run the image generator effectively, we need to connect our local machine's folders to the container. This prevents re-downloading models and makes it easy to retrieve the output images. - -### Step 1: Create a Local Output Directory - -On your DGX Spark, create a directory to store the generated images. - -```bash -mkdir -p ~/maxdiffusion_output -``` - -### Step 2a: Launch the Container with Volume Mounts - -Run the following command to start an interactive session inside your container. This command links your Hugging Face cache (to avoid re-downloading models) and the output directory you just created. - -```bash -docker run -it --gpus all \ --v ~/.cache/huggingface:/root/.cache/huggingface \ --v ~/maxdiffusion_output:/tmp \ -maxdiffusion-arm-gpu -``` -Your terminal prompt will change, indicating you are now inside the running container. - -#### Step 2b: Log in to Hugging Face (First-Time Setup) - -You must do this once to download the required model weights. - -```bash -# [Inside the Docker Container] -huggingface-cli login -``` - -You will be prompted to paste a Hugging Face User Access Token. - -1. Go to[ huggingface.co/settings/tokens](https://huggingface.co/settings/tokens) in your web browser. - -2. Copy your token (or create a new one with write permissions). - -3. Paste the token into the terminal and press Enter. - - -## Part 3: Generating Your First Image - -Now that you are inside the container's interactive shell, you can execute the image generation script. Run the following command: - -```bash -NVTE_FRAMEWORK=JAX NVTE_FUSED_ATTN=1 HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 split_head_dim=True per_device_batch_size=1 attention="cudnn_flash_te" hardware=gpu -``` -The script will initialize, use the models from your mounted cache, and begin the generation process. - -## Part 4: Accessing Your Generated Image - -The generation script saves the final image to its working directory (/app) inside the container. Here is the complete workflow to get that image onto your Laptop. - -### Step 1: Copy the Image from Container to DGX Spark - -Open a new terminal window. Do not close the terminal where the container is running. -First, find your container's ID: - -```bash -docker ps -``` - -Look for the container with the image maxdiffusion-arm-gpu and note its ID (e.g., 9049895399fc). -Now, copy the image from the container to a temporary location on DGX Spark and fix its permissions. - -```bash -# Copy the file to the /tmp/ directory on DGX Spark -docker cp 9049895399fc:/app/flux_0.png /tmp/flux_0.png - -# Change the file's owner to your user to avoid permission errors -sudo chown username:username /tmp/flux_0.png -``` - -### Step 2: Copy the Image from DGX Spark to Your Laptop - -Now, open the Terminal app on your Laptop and use the scp (secure copy) command to download the file from DGX Spark. - -```bash -scp username@spark:/tmp/flux_0.png . -``` - -This command will download flux_0.png to the current directory on your Laptop. You can now view your generated image! - -## Troubleshooting and Common Pitfalls - -Here are solutions to common issues you might encounter: -- Error: `pip: command not found` during Docker build. - - **Cause**: The base Docker image doesn't have pip in the system's default PATH. - - **Solution**: The provided Dockerfile fixes this by explicitly installing python3-pip and using update-alternatives to create the necessary symbolic links. -- Error: `externally-managed-environment` during `pip install`. - - **Cause**: Newer versions of Debian/Ubuntu protect system Python packages from being modified by pip. - - **Solution**: The `ENV PIP_BREAK_SYSTEM_PACKAGES=1` line in the `Dockerfile` safely bypasses this protection within the container's isolated environment. -- Error: `OSError: ...is not a local folder and is not a valid model identifier` - - **Cause**: The script is trying to download models from the Hugging Face Hub because it cannot find them locally. - - **Solution**: This is solved by launching the container with the `-v ~/.cache/huggingface:/root/.cache/huggingface` flag, which gives the container access to your local model cache. -- Error: `open ... permission denied` when trying to access a copied file. - - **Cause**: Files copied from a Docker container with docker cp are owned by the root user by default. - - **Solution**: After copying the file to the DGX Spark, immediately run `sudo chown your_user:your_user /path/to/file` to take ownership before trying to access or transfer it. -- Can't find the generated image. - - **Cause**: The script may not be saving the image to the directory specified by the output_dir argument. - - **Solution**: Always check the script's source code to confirm the final save location. As we discovered, generate_flux.py saves to the current working directory (/app), not /tmp. Knowing this allows you to copy the file from the correct location. -- If a process requires more memory than the available RAM, your system will crash with an "Out-of-Memory" (OOM) error. - - `Swap memory is your safety net.` It's a designated space on your hard drive that the operating system uses as a "virtual" extension of your RAM. When RAM is full, the system moves less active data to the slower swap space, freeing up RAM for the immediate task. While it's slower than RAM, it's infinitely better than a system crash, ensuring your long-running training or generation jobs can complete successfully. For a machine with 119GB of RAM, adding 64GB of swap provides a robust buffer for memory-intensive operations. - - Step 1: Create a 64GB Swap File - - Run these commands on your DGX Spark to create, format, and enable a permanent 64GB swap file. - - ```bash - # Instantly allocate a 64GB file - sudo fallocate -l 64G /swapfile - # Set secure permissions (only root can access) - sudo chmod 600 /swapfile - # Format the file as swap space - sudo mkswap /swapfile - # Enable the swap file for the current session - sudo swapon /swapfile - # Add the swap file to the system's startup configuration to make it permanent - echo '/swapfile none swap sw 0 0' | sudo tee -a /etc/fstab - ``` - - - Step 2: Verify Swap is Active - - Check that the swap space is correctly configured. - - ```bash - free -h - # The output should now show a 64GB total for Swap. - ``` diff --git a/docs/getting_started/first_run.md b/docs/getting_started/first_run.md index 7ba0d106..7f389934 100644 --- a/docs/getting_started/first_run.md +++ b/docs/getting_started/first_run.md @@ -21,7 +21,7 @@ cd maxdiffusion 1. Within the root directory of the MaxDiffusion `git` repo, install dependencies by running: ``` # If a Python 3.12+ virtual environment doesn't already exist, you'll need to run the install command three times. -bash setup.sh MODE=stable DEVICE=tpu +bash setup.sh MODE=stable ``` 1. Active your virtual environment: diff --git a/gpu_multi_process_run.sh b/gpu_multi_process_run.sh deleted file mode 100644 index b689347f..00000000 --- a/gpu_multi_process_run.sh +++ /dev/null @@ -1,156 +0,0 @@ -#! /bin/bash -set -e -set -u -set -o pipefail - -: "${NNODES:?Must set NNODES}" -: "${NODE_RANK:?Must set NODE_RANK}" -: "${JAX_COORDINATOR_PORT:?Must set JAX_COORDINATOR_PORT}" -: "${JAX_COORDINATOR_ADDRESS:?Must set JAX_COORDINATOR_ADDRESS}" -: "${GPUS_PER_NODE:?Must set GPUS_PER_NODE}" -: "${COMMAND:?Must set COMMAND}" - - -export GPUS_PER_NODE=$GPUS_PER_NODE -export JAX_COORDINATOR_PORT=$JAX_COORDINATOR_PORT -export JAX_COORDINATOR_ADDRESS=$JAX_COORDINATOR_ADDRESS - -set_nccl_gpudirect_tcpx_specific_configuration() { - if [[ "$USE_GPUDIRECT" == "tcpx" ]] || [[ "$USE_GPUDIRECT" == "fastrak" ]]; then - export CUDA_DEVICE_MAX_CONNECTIONS=1 - export NCCL_CROSS_NIC=0 - export NCCL_DEBUG=INFO - export NCCL_DYNAMIC_CHUNK_SIZE=524288 - export NCCL_NET_GDR_LEVEL=PIX - export NCCL_NVLS_ENABLE=0 - export NCCL_P2P_NET_CHUNKSIZE=524288 - export NCCL_P2P_NVL_CHUNKSIZE=1048576 - export NCCL_P2P_PCI_CHUNKSIZE=524288 - export NCCL_PROTO=Simple - export NCCL_SOCKET_IFNAME=eth0 - export NVTE_FUSED_ATTN=1 - export TF_CPP_MAX_LOG_LEVEL=100 - export TF_CPP_VMODULE=profile_guided_latency_estimator=10 - export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85 - shopt -s globstar nullglob - IFS=:$IFS - set -- /usr/local/cuda-*/compat - export LD_LIBRARY_PATH="${1+:"$*"}:${LD_LIBRARY_PATH}:/usr/local/tcpx/lib64" - IFS=${IFS#?} - shopt -u globstar nullglob - - if [[ "$USE_GPUDIRECT" == "tcpx" ]]; then - echo "Using GPUDirect-TCPX" - export NCCL_ALGO=Ring - export NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV,TUNING,NET,VERSION - export NCCL_GPUDIRECTTCPX_CTRL_DEV=eth0 - export NCCL_GPUDIRECTTCPX_FORCE_ACK=0 - export NCCL_GPUDIRECTTCPX_PROGRAM_FLOW_STEERING_WAIT_MICROS=1000000 - export NCCL_GPUDIRECTTCPX_RX_BINDINGS="eth1:22-35,124-139;eth2:22-35,124-139;eth3:74-87,178-191;eth4:74-87,178-191" - export NCCL_GPUDIRECTTCPX_SOCKET_IFNAME=eth1,eth2,eth3,eth4 - export NCCL_GPUDIRECTTCPX_TX_BINDINGS="eth1:8-21,112-125;eth2:8-21,112-125;eth3:60-73,164-177;eth4:60-73,164-177" - export NCCL_GPUDIRECTTCPX_TX_COMPLETION_NANOSLEEP=1000 - export NCCL_MAX_NCHANNELS=12 - export NCCL_MIN_NCHANNELS=12 - export NCCL_NSOCKS_PERTHREAD=4 - export NCCL_P2P_PXN_LEVEL=0 - export NCCL_SOCKET_NTHREADS=1 - elif [[ "$USE_GPUDIRECT" == "fastrak" ]]; then - echo "Using GPUDirect-TCPFasTrak" - export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - export NCCL_ALGO=Ring,Tree - export NCCL_BUFFSIZE=8388608 - export NCCL_FASTRAK_CTRL_DEV=eth0 - export NCCL_FASTRAK_ENABLE_CONTROL_CHANNEL=0 - export NCCL_FASTRAK_ENABLE_HOTPATH_LOGGING=0 - export NCCL_FASTRAK_IFNAME=eth1,eth2,eth3,eth4,eth5,eth6,eth7,eth8 - export NCCL_FASTRAK_NUM_FLOWS=2 - export NCCL_FASTRAK_USE_LLCM=1 - export NCCL_FASTRAK_USE_SNAP=1 - export NCCL_MIN_NCHANNELS=4 - export NCCL_SHIMNET_GUEST_CONFIG_CHECKER_CONFIG_FILE=/usr/local/nvidia/lib64/a3plus_guest_config.textproto - export NCCL_TUNER_CONFIG_PATH=/usr/local/nvidia/lib64/a3plus_tuner_config.textproto - export NCCL_TUNER_PLUGIN=libnccl-tuner.so - fi - else - echo "NOT using GPUDirect" - fi -} - -echo "LD_LIBRARY_PATH ${LD_LIBRARY_PATH}" - -set_nccl_gpudirect_tcpx_specific_configuration - -wait_all_success_or_exit() { - # https://www.baeldung.com/linux/background-process-get-exit-code - local pids=("$@") - while [[ ${#pids[@]} -ne 0 ]]; do - all_success="true" - for pid in "${pids[@]}"; do - code=$(non_blocking_wait "$pid") - if [[ $code -ne 127 ]]; then - if [[ $code -ne 0 ]]; then - echo "PID $pid failed with exit code $code" - exit "$code" - fi - else - all_success="false" - fi - done - if [[ $all_success == "true" ]]; then - echo "All pids succeeded" - break - fi - sleep 5 - done -} -non_blocking_wait() { - # https://www.baeldung.com/linux/background-process-get-exit-code - local pid=$1 - local code=127 # special code to indicate not-finished - if [[ ! -d "/proc/$pid" ]]; then - wait "$pid" - code=$? - fi - echo $code -} - -resolve_coordinator_ip() { - local lookup_attempt=1 - local max_coordinator_lookups=500 - local coordinator_found=false - local coordinator_ip_address="" - - echo "Coordinator Address $JAX_COORDINATOR_ADDRESS" - - while [[ "$coordinator_found" = false && $lookup_attempt -le $max_coordinator_lookups ]]; do - coordinator_ip_address=$(nslookup "$JAX_COORDINATOR_ADDRESS" 2>/dev/null | awk '/^Address: / { print $2 }' | head -n 1) - if [[ -n "$coordinator_ip_address" ]]; then - coordinator_found=true - echo "Coordinator IP address: $coordinator_ip_address" - export JAX_COORDINATOR_IP=$coordinator_ip_address - return 0 - else - echo "Failed to recognize coordinator address $JAX_COORDINATOR_ADDRESS on attempt $lookup_attempt, retrying..." - ((lookup_attempt++)) - sleep 1 - fi - done - - if [[ "$coordinator_found" = false ]]; then - echo "Failed to resolve coordinator address after $max_coordinator_lookups attempts." - return 1 - fi -} - -# Resolving coordinator IP -set +e -resolve_coordinator_ip -set -e - -PIDS=() -eval ${COMMAND} & -PID=$! -PIDS+=($PID) - -wait_all_success_or_exit "${PIDS[@]}" \ No newline at end of file diff --git a/maxdiffusion_gpu_dependencies.Dockerfile b/maxdiffusion_gpu_dependencies.Dockerfile index 45f03354..39af67e8 100644 --- a/maxdiffusion_gpu_dependencies.Dockerfile +++ b/maxdiffusion_gpu_dependencies.Dockerfile @@ -30,8 +30,6 @@ ENV ENV_MODE=$MODE ARG JAX_VERSION ENV ENV_JAX_VERSION=$JAX_VERSION -ARG DEVICE -ENV ENV_DEVICE=$DEVICE RUN mkdir -p /deps @@ -42,7 +40,7 @@ WORKDIR /deps COPY . . RUN ls . -RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}" -RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE} +RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION +RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} WORKDIR /deps \ No newline at end of file diff --git a/setup.sh b/setup.sh index ab81c4a7..5f04b28b 100644 --- a/setup.sh +++ b/setup.sh @@ -15,7 +15,7 @@ # limitations under the License. # Description: -# bash setup.sh MODE={stable,nightly} DEVICE={tpu,gpu} +# bash setup.sh MODE={stable,nightly} # You need to specify a MODE, default value stable. # For MODE=stable you may additionally specify JAX_VERSION, e.g. JAX_VERSION=0.4.33 @@ -87,11 +87,6 @@ for ARGUMENT in "$@"; do export "$KEY"="$VALUE" done -# Default device is TPU -if [[ -z "$DEVICE" ]]; then - export DEVICE="tpu" -fi - # Unset JAX_VERSION if set to "NONE" if [[ $JAX_VERSION == NONE ]]; then unset JAX_VERSION @@ -109,47 +104,23 @@ pip3 install -U -r requirements.txt || echo "Failed to install dependencies in t # Install JAX and JAXlib based on the specified mode if [[ "$MODE" == "stable" || ! -v MODE ]]; then # Stable mode - if [[ $DEVICE == "tpu" ]]; then - echo "Installing stable jax, jaxlib for tpu" - if [[ -n "$JAX_VERSION" ]]; then - echo "Installing stable jax, jaxlib, libtpu version ${JAX_VERSION}" - pip3 install "jax[tpu]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - else - echo "Installing stable jax, jaxlib, libtpu - for tpu" - pip3 install 'jax[tpu]>0.4' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - fi - elif [[ $DEVICE == "gpu" ]]; then - echo "Installing stable jax, jaxlib for NVIDIA gpu" - if [[ -n "$JAX_VERSION" ]]; then - echo "Installing stable jax, jaxlib ${JAX_VERSION}" - pip3 install -U "jax[cuda12]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - else - echo "Installing stable jax, jaxlib, libtpu for NVIDIA gpu" - pip3 install "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - fi - export NVTE_FRAMEWORK=jax - pip3 install transformer_engine[jax]==2.1.0 + echo "Installing stable jax, jaxlib for tpu" + if [[ -n "$JAX_VERSION" ]]; then + echo "Installing stable jax, jaxlib, libtpu version ${JAX_VERSION}" + pip3 install "jax[tpu]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + else + echo "Installing stable jax, jaxlib, libtpu for tpu" + pip3 install 'jax[tpu]>0.4' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html fi - elif [[ $MODE == "nightly" ]]; then # Nightly mode - if [[ $DEVICE == "gpu" ]]; then - echo "Installing jax-nightly, jaxlib-nightly" - # Install jax-nightly - pip install -U --pre jax jaxlib jax-cuda12-plugin[with_cuda] jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html - # Install Transformer Engine - export NVTE_FRAMEWORK=jax - pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable - elif [[ $DEVICE == "tpu" ]]; then - echo "Installing jax-nightly,jaxlib-nightly" - # Install jax-nightly - pip3 install --pre -U jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html - # Install jaxlib-nightly - pip3 install --pre -U jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html - # Install libtpu-nightly - pip3 install --pre -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - fi + echo "Installing jax-nightly,jaxlib-nightly" + # Install jax-nightly + pip3 install --pre -U jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html + # Install jaxlib-nightly + pip3 install --pre -U jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html + # Install libtpu-nightly + pip3 install --pre -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html echo "Installing nightly tensorboard plugin profile" pip3 install tbp-nightly --upgrade else diff --git a/src/maxdiffusion/configs/base14.yml b/src/maxdiffusion/configs/base14.yml index ca2579d9..ce730359 100644 --- a/src/maxdiffusion/configs/base14.yml +++ b/src/maxdiffusion/configs/base14.yml @@ -100,7 +100,7 @@ diffusion_scheduler_config: { } # Hardware -hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +hardware: 'tpu' # Only supported hardware type is 'tpu' at the moment skip_jax_distributed_system: False base_output_directory: "" diff --git a/src/maxdiffusion/configs/base21.yml b/src/maxdiffusion/configs/base21.yml index 65e7d19e..09acbd95 100644 --- a/src/maxdiffusion/configs/base21.yml +++ b/src/maxdiffusion/configs/base21.yml @@ -100,7 +100,7 @@ diffusion_scheduler_config: { } # Hardware -hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +hardware: 'tpu' # Only supported hardware type is 'tpu' at the moment skip_jax_distributed_system: False # Output directory diff --git a/src/maxdiffusion/configs/base_2_base.yml b/src/maxdiffusion/configs/base_2_base.yml index 16948296..95cd71b7 100644 --- a/src/maxdiffusion/configs/base_2_base.yml +++ b/src/maxdiffusion/configs/base_2_base.yml @@ -113,7 +113,7 @@ diffusion_scheduler_config: { } # Hardware -hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +hardware: 'tpu' # Only supported hardware type is 'tpu' at the moment skip_jax_distributed_system: False # Output directory diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index 7a80ea58..2b381862 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -128,7 +128,7 @@ diffusion_scheduler_config: { base_output_directory: "" # Hardware -hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +hardware: 'tpu' # Only supported hardware type is 'tpu' at the moment skip_jax_distributed_system: False # Parallelism diff --git a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml index 1aba7431..cad64b1c 100644 --- a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml +++ b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml @@ -128,7 +128,7 @@ diffusion_scheduler_config: { base_output_directory: "" # Hardware -hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +hardware: 'tpu' # Only supported hardware type is 'tpu' at the moment skip_jax_distributed_system: False # Parallelism diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml index 9ae39971..4154beb6 100644 --- a/src/maxdiffusion/configs/base_flux_schnell.yml +++ b/src/maxdiffusion/configs/base_flux_schnell.yml @@ -136,7 +136,7 @@ diffusion_scheduler_config: { base_output_directory: "" # Hardware -hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +hardware: 'tpu' # Only supported hardware type is 'tpu' at the moment skip_jax_distributed_system: False # Parallelism diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index bb9a8a3b..492de7db 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -147,7 +147,7 @@ diffusion_scheduler_config: { base_output_directory: "" # Hardware -hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +hardware: 'tpu' # Only supported hardware type is 'tpu' at the moment skip_jax_distributed_system: False # Parallelism diff --git a/src/maxdiffusion/configs/base_wan_1_3b.yml b/src/maxdiffusion/configs/base_wan_1_3b.yml index ffd2864a..7c6f6c76 100644 --- a/src/maxdiffusion/configs/base_wan_1_3b.yml +++ b/src/maxdiffusion/configs/base_wan_1_3b.yml @@ -124,7 +124,7 @@ diffusion_scheduler_config: { base_output_directory: "" # Hardware -hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +hardware: 'tpu' # Only supported hardware type is 'tpu' at the moment skip_jax_distributed_system: False # Parallelism diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index 022b18c9..3eca85ed 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -135,7 +135,7 @@ diffusion_scheduler_config: { base_output_directory: "" # Hardware -hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +hardware: 'tpu' # Only supported hardware type is 'tpu' at the moment skip_jax_distributed_system: False # Parallelism diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index 2a5b0338..f4834f99 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -130,7 +130,7 @@ diffusion_scheduler_config: { base_output_directory: "" # Hardware -hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +hardware: 'tpu' # Only supported hardware type is 'tpu' at the moment skip_jax_distributed_system: False # Parallelism diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index 0bd6a27f..f1aebda9 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -131,7 +131,7 @@ diffusion_scheduler_config: { base_output_directory: "" # Hardware -hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +hardware: 'tpu' # Only supported hardware type is 'tpu' at the moment skip_jax_distributed_system: False # Parallelism diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml index 3dbb1578..d9d89cf9 100644 --- a/src/maxdiffusion/configs/base_xl.yml +++ b/src/maxdiffusion/configs/base_xl.yml @@ -103,7 +103,7 @@ diffusion_scheduler_config: { base_output_directory: "" # Hardware -hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +hardware: 'tpu' # Only supported hardware type is 'tpu' at the moment skip_jax_distributed_system: False # Parallelism mesh_axes: ['data', 'fsdp', 'context', 'tensor'] diff --git a/src/maxdiffusion/configs/base_xl_lightning.yml b/src/maxdiffusion/configs/base_xl_lightning.yml index e487559a..46d71ad0 100644 --- a/src/maxdiffusion/configs/base_xl_lightning.yml +++ b/src/maxdiffusion/configs/base_xl_lightning.yml @@ -79,7 +79,7 @@ diffusion_scheduler_config: { } # Hardware -hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +hardware: 'tpu' # Only supported hardware type is 'tpu' at the moment skip_jax_distributed_system: False # Output directory # Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 04b3869f..e51fa3b6 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -638,31 +638,9 @@ def get_global_batch_size(per_device_batch_size): return per_device_batch_size * jax.device_count() -def is_gpu_backend(raw_keys): - """Determine whether Maxdiffusion is intended to run on a GPU backend.""" - return raw_keys["hardware"] == "gpu" - - -def initialize_jax_for_gpu(): - """Jax distribute initialize for GPUs.""" - if os.environ.get("JAX_COORDINATOR_IP") is not None: - coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP")) - coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT")) - jax.distributed.initialize( - coordinator_address=f"{coordinator_ip}:{coordinator_port}", - num_processes=int(os.getenv("NNODES")), - process_id=int(os.getenv("NODE_RANK")), - ) - max_logging.log(f"JAX global devices: {jax.devices()}") - - def maybe_initialize_jax_distributed_system(raw_keys): if raw_keys["skip_jax_distributed_system"]: max_logging.log("Skipping jax distributed system due to skip_jax_distributed_system=True flag.") return - if is_gpu_backend(raw_keys): - max_logging.log("Attempting to initialize the jax distributed system for GPU backend...") - initialize_jax_for_gpu() - max_logging.log("Jax distributed system initialized on GPU!") else: jax.distributed.initialize() diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 0e3e24e5..15f4c377 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -480,26 +480,6 @@ def _apply_attention_dot( return hidden_states -def _cudnn_flash_attention(query: Array, key: Array, value: Array, heads: int, mesh: Mesh, dpa_layer: Callable) -> Array: - """CUDNN Flash Attention with Transformer Engine. - 1. Stable API, supports GQA - 2. Supports head_dim till 128; head_dim=256 support will be added soon - """ - # These imports are only meant to work in a GPU build. - # copied from tpu_flash_attention - query = _reshape_data_for_cudnn_flash(query, heads) - key = _reshape_data_for_cudnn_flash(key, heads) - value = _reshape_data_for_cudnn_flash(value, heads) - - axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD, D_KV)) - query = jax.lax.with_sharding_constraint(query, axis_names) - key = jax.lax.with_sharding_constraint(key, axis_names) - value = jax.lax.with_sharding_constraint(value, axis_names) - - out = dpa_layer(query, key, value, mask=None) - return _reshape_data_from_cudnn_flash(out) - - def _apply_attention( query: Array, key: Array, @@ -569,8 +549,6 @@ def _apply_attention( attention_kernel, mask_padding_tokens=mask_padding_tokens, ) - elif attention_kernel == "cudnn_flash_te": - return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer) else: raise ValueError(f"Unexpected attention kernel {attention_kernel=}.") diff --git a/src/maxdiffusion/utils/import_utils.py b/src/maxdiffusion/utils/import_utils.py index 05ef72ec..1efb128d 100644 --- a/src/maxdiffusion/utils/import_utils.py +++ b/src/maxdiffusion/utils/import_utils.py @@ -143,8 +143,6 @@ def _is_package_available(pkg_name: str): if _onnx_available: candidates = ( "onnxruntime", - "onnxruntime-gpu", - "ort_nightly_gpu", "onnxruntime-directml", "onnxruntime-openvino", "ort_nightly_directml", diff --git a/src/maxdiffusion/utils/testing_utils.py b/src/maxdiffusion/utils/testing_utils.py index 55be62ac..ce860ff0 100644 --- a/src/maxdiffusion/utils/testing_utils.py +++ b/src/maxdiffusion/utils/testing_utils.py @@ -212,11 +212,6 @@ def require_torch_2(test_case): return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")(test_case) -def require_torch_gpu(test_case): - """Decorator marking a test that requires CUDA and PyTorch.""" - return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(test_case) - - def skip_mps(test_case): """Decorator marking a test to skip if torch_device is 'mps'""" return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)