diff --git a/README.md b/README.md index e309c39b6..6392a720f 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,8 @@ For training, we use [PyTorch Lightning](https://lightning.ai/docs/pytorch/stabl ## Installation: +**NOTE:** This is tested under `python3.8` and `python3.10`. For other Python versions, you might encounter version conflicts. + #### 1. Clone the repo ```shell @@ -52,29 +54,18 @@ cd generative-models This is assuming you have navigated to the `generative-models` root after cloning it. -**NOTE:** This is tested under `python3.8` and `python3.10`. For other python versions, you might encounter version conflicts. - - -**PyTorch 1.13** - ```shell -# install required packages from pypi -python3 -m venv .pt1 -source .pt1/bin/activate -pip3 install wheel -pip3 install -r requirements_pt13.txt +python3 -m venv venv +source venv/bin/activate +pip install -U setuptools wheel ``` -**PyTorch 2.0** +Then, depending on your use case, choose a set of requirements to install. - -```shell -# install required packages from pypi -python3 -m venv .pt2 -source .pt2/bin/activate -pip3 install wheel -pip3 install -r requirements_pt2.txt -``` +* `pip install -r requirements-demo-streamlit.txt`: Demo inference dependencies, enough to run the Streamlit demo +* `pip install -r requirements-demo-minimal.txt`: Demo inference dependencies, enough to run the minimal txt2img script +* `pip install -r requirements_pt2.txt`: PyTorch 2, including training dependencies +* `pip install -r requirements_pt13.txt`: PyTorch 1.13, including training dependencies ## Packaging @@ -93,7 +84,17 @@ You will find the built package in `dist/`. You can install the wheel with `pip Note that the package does **not** currently specify dependencies; you will need to install the required packages, depending on your use case and PyTorch version, manually. -## Inference: +## Inference + +### Minimal txt2img demo + +There is a minimal SDXL 0.9 text-to-image demo available as `txt2img.py`: + +``` +python txt2img.py --prompt "Big fluffy cat in a cereal bowl" --steps 25 --seed 1050 +``` + +### Streamlit demo We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling in `scripts/demo/sampling.py`. The following models are currently supported: - [SD-XL 0.9-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9) diff --git a/requirements-demo-minimal.txt b/requirements-demo-minimal.txt new file mode 100644 index 000000000..a00ea4b00 --- /dev/null +++ b/requirements-demo-minimal.txt @@ -0,0 +1,8 @@ +einops +kornia~=0.6.12 +omegaconf +open-clip-torch +pytorch-lightning~=2.0.5 +safetensors~=0.3.1 +torchvision~=0.15.2 +transformers~=4.31.0 diff --git a/requirements-demo-streamlit.txt b/requirements-demo-streamlit.txt new file mode 100644 index 000000000..1f48e638d --- /dev/null +++ b/requirements-demo-streamlit.txt @@ -0,0 +1,4 @@ +-r requirements-demo-minimal.txt +-e git+https://github.com/openai/CLIP.git@main#egg=clip +invisible-watermark +streamlit diff --git a/sgm/models/diffusion.py b/sgm/models/diffusion.py index 8b7cc1740..c465fb987 100644 --- a/sgm/models/diffusion.py +++ b/sgm/models/diffusion.py @@ -1,5 +1,5 @@ import logging -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from typing import Any, Dict, List, Tuple, Union import pytorch_lightning as pl @@ -14,6 +14,7 @@ from ..util import ( default, disabled_train, + get_default_device_name, get_obj_from_str, instantiate_from_config, log_txt_as_img, @@ -117,16 +118,22 @@ def get_input(self, batch): # image tensors should be scaled to -1 ... 1 and in bchw format return batch[self.input_key] + def _first_stage_autocast_context(self): + device = get_default_device_name() + if device not in ("cpu", "cuda"): + return nullcontext() + return torch.autocast(device, enabled=not self.disable_first_stage_autocast) + @torch.no_grad() def decode_first_stage(self, z): z = 1.0 / self.scale_factor * z - with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): + with self._first_stage_autocast_context(): out = self.first_stage_model.decode(z) return out @torch.no_grad() def encode_first_stage(self, x): - with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): + with self._first_stage_autocast_context(): z = self.first_stage_model.encode(x) z = self.scale_factor * z return z diff --git a/sgm/modules/autoencoding/losses/__init__.py b/sgm/modules/autoencoding/losses/__init__.py index c4964f739..1374325d9 100644 --- a/sgm/modules/autoencoding/losses/__init__.py +++ b/sgm/modules/autoencoding/losses/__init__.py @@ -4,9 +4,6 @@ import torch import torch.nn as nn from einops import rearrange -from taming.modules.discriminator.model import NLayerDiscriminator, weights_init -from taming.modules.losses.lpips import LPIPS -from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss from ....util import default, instantiate_from_config @@ -30,6 +27,7 @@ def __init__( scale_tgt_to_input_size=False, perceptual_weight_on_inputs=0.0, ): + from taming.modules.losses.lpips import LPIPS # late import to avoid extra dependency super().__init__() self.scale_input_to_tgt_size = scale_input_to_tgt_size self.scale_tgt_to_input_size = scale_tgt_to_input_size @@ -105,6 +103,9 @@ def __init__( learn_logvar: bool = False, regularization_weights: Union[None, dict] = None, ): + from taming.modules.losses.lpips import LPIPS # late import to avoid extra dependency + from taming.modules.discriminator.model import NLayerDiscriminator, weights_init # late import to avoid extra dependency + from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss # late import to avoid extra dependency super().__init__() self.dims = dims if self.dims > 2: diff --git a/sgm/modules/diffusionmodules/loss.py b/sgm/modules/diffusionmodules/loss.py index 555abc1c3..68a3463fc 100644 --- a/sgm/modules/diffusionmodules/loss.py +++ b/sgm/modules/diffusionmodules/loss.py @@ -3,7 +3,6 @@ import torch import torch.nn as nn from omegaconf import ListConfig -from taming.modules.losses.lpips import LPIPS from ...util import append_dims, instantiate_from_config @@ -26,6 +25,8 @@ def __init__( self.offset_noise_level = offset_noise_level if type == "lpips": + from taming.modules.losses.lpips import LPIPS # late import to avoid extra dependency + self.lpips = LPIPS().eval() if not batch2model_keys: diff --git a/sgm/modules/diffusionmodules/openaimodel.py b/sgm/modules/diffusionmodules/openaimodel.py index 1c87442a9..f9b38bc31 100644 --- a/sgm/modules/diffusionmodules/openaimodel.py +++ b/sgm/modules/diffusionmodules/openaimodel.py @@ -20,7 +20,7 @@ timestep_embedding, zero_module, ) -from ...util import default, exists +from ...util import default, exists, get_default_device_name logger = logging.getLogger(__name__) @@ -1244,6 +1244,7 @@ def __init__(self, in_channels=3, model_channels=64): ] ) + device = get_default_device_name() model = UNetModel( use_checkpoint=True, image_size=64, @@ -1258,8 +1259,8 @@ def __init__(self, in_channels=3, model_channels=64): use_linear_in_transformer=True, transformer_depth=1, legacy=False, - ).cuda() - x = th.randn(11, 4, 64, 64).cuda() - t = th.randint(low=0, high=10, size=(11,), device="cuda") + ).to(device) + x = th.randn(11, 4, 64, 64).to(device) + t = th.randint(low=0, high=10, size=(11,), device=device) o = model(x, t) print("done.") diff --git a/sgm/modules/diffusionmodules/sampling.py b/sgm/modules/diffusionmodules/sampling.py index 6346829c8..93a251541 100644 --- a/sgm/modules/diffusionmodules/sampling.py +++ b/sgm/modules/diffusionmodules/sampling.py @@ -16,7 +16,7 @@ to_neg_log_sigma, to_sigma, ) -from ...util import append_dims, default, instantiate_from_config +from ...util import append_dims, default, instantiate_from_config, get_default_device_name DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"} @@ -28,8 +28,10 @@ def __init__( num_steps: Union[int, None] = None, guider_config: Union[Dict, ListConfig, OmegaConf, None] = None, verbose: bool = False, - device: str = "cuda", + device: Union[str, None] = None, ): + if device is None: + device = get_default_device_name() self.num_steps = num_steps self.discretization = instantiate_from_config(discretization_config) self.guider = instantiate_from_config( diff --git a/sgm/modules/diffusionmodules/sampling_utils.py b/sgm/modules/diffusionmodules/sampling_utils.py index 7cca6361c..3911c5c13 100644 --- a/sgm/modules/diffusionmodules/sampling_utils.py +++ b/sgm/modules/diffusionmodules/sampling_utils.py @@ -1,5 +1,4 @@ import torch -from scipy import integrate from ...util import append_dims @@ -10,6 +9,7 @@ def __call__(self, uncond, cond, scale): def linear_multistep_coeff(order, t, i, j, epsrel=1e-4): + from scipy import integrate # late import to avoid extra dependency if order - 1 > i: raise ValueError(f"Order {order} too high for step {i}") diff --git a/sgm/modules/encoders/modules.py b/sgm/modules/encoders/modules.py index afb4bb6e1..1e2721ac0 100644 --- a/sgm/modules/encoders/modules.py +++ b/sgm/modules/encoders/modules.py @@ -30,6 +30,7 @@ default, disabled_train, expand_dims_like, + get_default_device_name, instantiate_from_config, ) @@ -239,7 +240,9 @@ def forward(self, c): c = c[:, None, :] return c - def get_unconditional_conditioning(self, bs, device="cuda"): + def get_unconditional_conditioning(self, bs, device=None): + if device is None: + device = get_default_device_name() uc_class = ( self.n_classes - 1 ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) @@ -264,9 +267,10 @@ class FrozenT5Embedder(AbstractEmbModel): """Uses the T5 transformer encoder for text""" def __init__( - self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True + self, version="google/t5-v1_1-xxl", device=None, max_length=77, freeze=True ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl super().__init__() + device = device or get_default_device_name() self.tokenizer = T5Tokenizer.from_pretrained(version) self.transformer = T5EncoderModel.from_pretrained(version) self.device = device @@ -307,9 +311,10 @@ class FrozenByT5Embedder(AbstractEmbModel): """ def __init__( - self, version="google/byt5-base", device="cuda", max_length=77, freeze=True + self, version="google/byt5-base", device=None, max_length=77, freeze=True ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl super().__init__() + device = device or get_default_device_name() self.tokenizer = ByT5Tokenizer.from_pretrained(version) self.transformer = T5EncoderModel.from_pretrained(version) self.device = device @@ -351,7 +356,7 @@ class FrozenCLIPEmbedder(AbstractEmbModel): def __init__( self, version="openai/clip-vit-large-patch14", - device="cuda", + device=None, max_length=77, freeze=True, layer="last", @@ -359,6 +364,7 @@ def __init__( always_return_pooled=False, ): # clip-vit-base-patch32 super().__init__() + device = device or get_default_device_name() assert layer in self.LAYERS self.tokenizer = CLIPTokenizer.from_pretrained(version) self.transformer = CLIPTextModel.from_pretrained(version) @@ -419,7 +425,7 @@ def __init__( self, arch="ViT-H-14", version="laion2b_s32b_b79k", - device="cuda", + device=None, max_length=77, freeze=True, layer="last", @@ -427,6 +433,7 @@ def __init__( legacy=True, ): super().__init__() + device = device or get_default_device_name() assert layer in self.LAYERS model, _, _ = open_clip.create_model_and_transforms( arch, @@ -521,12 +528,13 @@ def __init__( self, arch="ViT-H-14", version="laion2b_s32b_b79k", - device="cuda", + device=None, max_length=77, freeze=True, layer="last", ): super().__init__() + device = device or get_default_device_name() assert layer in self.LAYERS model, _, _ = open_clip.create_model_and_transforms( arch, device=torch.device("cpu"), pretrained=version @@ -591,7 +599,7 @@ def __init__( self, arch="ViT-H-14", version="laion2b_s32b_b79k", - device="cuda", + device=None, max_length=77, freeze=True, antialias=True, @@ -602,6 +610,7 @@ def __init__( output_tokens=False, ): super().__init__() + device = device or get_default_device_name() model, _, _ = open_clip.create_model_and_transforms( arch, device=torch.device("cpu"), @@ -747,11 +756,12 @@ def __init__( self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", - device="cuda", + device=None, clip_max_length=77, t5_max_length=77, ): super().__init__() + device = device or get_default_device_name() self.clip_encoder = FrozenCLIPEmbedder( clip_version, device, max_length=clip_max_length ) diff --git a/sgm/util.py b/sgm/util.py index e23616dca..fb866996a 100644 --- a/sgm/util.py +++ b/sgm/util.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import functools import importlib import logging @@ -14,6 +16,10 @@ logger = logging.getLogger(__name__) +def get_default_device_name() -> str: + return os.environ.get("SGM_DEFAULT_DEVICE", "cuda" if torch.cuda.is_available() else "cpu") + + def disabled_train(self, mode=True): """Overwrite model.train with this function to make sure train/eval mode does not change anymore.""" @@ -202,19 +208,25 @@ def append_dims(x, target_dims): return x[(...,) + (None,) * dims_to_append] -def load_model_from_config(config, ckpt, verbose=True, freeze=True): - logger.info(f"Loading model from {ckpt}") - if ckpt.endswith("ckpt"): - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - logger.debug(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] - elif ckpt.endswith("safetensors"): - sd = load_safetensors(ckpt) - else: - raise NotImplementedError - +def load_model_from_config( + config, + ckpt: str | None, + verbose=True, + freeze=True, + device="cpu", +): model = instantiate_from_config(config.model) + if ckpt: + print(f"Loading model from {ckpt}") + if ckpt.endswith("ckpt"): + pl_sd = torch.load(ckpt, map_location=device) + if verbose and "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + elif ckpt.endswith("safetensors"): + sd = load_safetensors(ckpt, device=device) + else: + raise NotImplementedError missing, unexpected = model.load_state_dict(sd, strict=False) @@ -228,7 +240,6 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True): for param in model.parameters(): param.requires_grad = False - model.eval() return model diff --git a/txt2img.py b/txt2img.py new file mode 100644 index 000000000..b8bfb123e --- /dev/null +++ b/txt2img.py @@ -0,0 +1,160 @@ +""" +This is a very minimal txt2img example for SD-XL only. +""" +import argparse +import logging +import time +from pathlib import Path +from unittest.mock import patch + +import numpy as np +import torch +from PIL import Image +import einops +import omegaconf +import pytorch_lightning +from sgm.modules.diffusionmodules.sampling import EulerEDMSampler +from sgm.util import load_model_from_config, get_default_device_name + + +def run_txt2img( + *, + model, + prompt: str, + steps: int = 10, + width: int = 1024, + height: int = 1024, + cfg_scale=5.0, + num_samples=1, + seed: int, + device: str, +): + discretization_config = { + "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", + } + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG", + "params": { + "scale": cfg_scale, + "dyn_thresh_config": { + "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding", + }, + }, + } + sampler = EulerEDMSampler( + num_steps=steps, + discretization_config=discretization_config, + guider_config=guider_config, + s_churn=0, + s_tmin=0, + s_tmax=999, + s_noise=1.0, + verbose=False, + ) + C = 4 # SD-XL value + F = 8 # SD-XL value + + with torch.no_grad(), model.ema_scope(): + pytorch_lightning.seed_everything(seed) + batch = { + "txt": [prompt] * num_samples, + "crop_coords_top_left": torch.tensor([0, 0]) + .to(device) + .repeat(num_samples, 1), + "original_size_as_tuple": torch.tensor([1024, 1024]) + .to(device) + .repeat(num_samples, 1), # SD-XL values + "target_size_as_tuple": torch.tensor([width, width]) + .to(device) + .repeat(num_samples, 1), + } + c, uc = model.conditioner.get_unconditional_conditioning( + batch, force_uc_zero_embeddings=["txt"] + ) + for k in c: + if k != "crossattn": + c[k] = c[k][:num_samples].to(device) + uc[k] = uc[k][:num_samples].to(device) + + shape = (num_samples, C, height // F, width // F) + initial_latent = torch.randn(shape).to(device) + + def denoiser(input, sigma, c): + return model.denoiser(model.model, input, sigma, c) + + latent_samples = sampler(denoiser, initial_latent, cond=c, uc=uc) + decoded_samples = model.decode_first_stage(latent_samples) + samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) + return samples + + +@torch.no_grad() +def fast_load(*, config, ckpt, device): + config = omegaconf.OmegaConf.load(config) + # This patch is borrowed from AUTOMATIC1111's stable-diffusion-webui; + # we don't need to initialize the weights just for them to be overwritten + # by the checkpoint. + with ( + patch.object(torch.nn.init, "kaiming_uniform_"), + patch.object(torch.nn.init, "_no_grad_normal_"), + patch.object(torch.nn.init, "_no_grad_uniform_"), + ): + model = load_model_from_config( + config, ckpt=ckpt, device="cpu", freeze=True, verbose=False + ) + model.to(device) + model.eval() + return model + + +def main(): + logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(name)s: %(message)s") + # Quiesce some uninformative CLIP and attention logging. + logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR) + logging.getLogger("sgm.modules.attention").setLevel(logging.ERROR) + + ap = argparse.ArgumentParser() + ap.add_argument("--device", default=get_default_device_name()) + ap.add_argument( + "--prompt", + default="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + ) + ap.add_argument("--seed", type=int, default=42) + ap.add_argument("--steps", type=int, default=20) + ap.add_argument("--width", type=int, default=1024) + ap.add_argument("--height", type=int, default=1024) + ap.add_argument("--cfg-scale", type=float, default=5.0) + ap.add_argument("--num-samples", type=int, default=1) + args = ap.parse_args() + model = fast_load( + config="configs/inference/sd_xl_base.yaml", + ckpt="checkpoints/sd_xl_base_0.9.safetensors", + device=args.device, + ) + + samples = run_txt2img( + model=model, + prompt=args.prompt, + steps=args.steps, + width=args.width, + height=args.height, + cfg_scale=args.cfg_scale, + num_samples=args.num_samples, + device=args.device, + seed=args.seed, + ) + + out_path = Path("output") + out_path.mkdir(exist_ok=True) + + prefix = int(time.time()) + + for i, sample in enumerate(samples, 1): + filename = out_path / f"{prefix}-{i:04}.png" + print(f"Saving {i}/{len(samples)}: {filename}") + sample = 255.0 * einops.rearrange(sample, "c h w -> h w c") + Image.fromarray(sample.cpu().numpy().astype(np.uint8)).save(filename) + + +if __name__ == "__main__": + main()