Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ For training, we use [PyTorch Lightning](https://lightning.ai/docs/pytorch/stabl
## Installation:
<a name="installation"></a>

**NOTE:** This is tested under `python3.8` and `python3.10`. For other Python versions, you might encounter version conflicts.

#### 1. Clone the repo

```shell
Expand All @@ -52,29 +54,18 @@ cd generative-models

This assumes you have navigated to the `generative-models` root after cloning it.

**NOTE:** This is tested under `python3.8` and `python3.10`. For other python versions, you might encounter version conflicts.


**PyTorch 1.13**

```shell
# install required packages from pypi
python3 -m venv .pt1
source .pt1/bin/activate
pip3 install wheel
pip3 install -r requirements_pt13.txt
python3 -m venv venv
source venv/bin/activate
pip install -U setuptools wheel
```

**PyTorch 2.0**
Then, depending on your use case, choose a set of requirements to install.


```shell
# install required packages from pypi
python3 -m venv .pt2
source .pt2/bin/activate
pip3 install wheel
pip3 install -r requirements_pt2.txt
```
* `pip install -r requirements-demo-streamlit.txt`: Demo inference dependencies, enough to run the Streamlit demo
* `pip install -r requirements-demo-minimal.txt`: Demo inference dependencies, enough to run the minimal txt2img script
* `pip install -r requirements_pt2.txt`: PyTorch 2, including training dependencies
* `pip install -r requirements_pt13.txt`: PyTorch 1.13, including training dependencies

## Packaging

Expand All @@ -93,7 +84,17 @@ You will find the built package in `dist/`. You can install the wheel with `pip
Note that the package does **not** currently specify dependencies; you will need to manually install the
required packages for your use case and PyTorch version.

## Inference:
## Inference

### Minimal txt2img demo

There is a minimal SDXL 0.9 text-to-image demo available as `txt2img.py`:

```shell
python txt2img.py --prompt "Big fluffy cat in a cereal bowl" --steps 25 --seed 1050
```

### Streamlit demo

We provide a [Streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling in `scripts/demo/sampling.py`. The following models are currently supported:
- [SD-XL 0.9-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9)
Expand Down
8 changes: 8 additions & 0 deletions requirements-demo-minimal.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
einops
kornia~=0.6.12
omegaconf
open-clip-torch
pytorch-lightning~=2.0.5
safetensors~=0.3.1
torchvision~=0.15.2
transformers~=4.31.0
4 changes: 4 additions & 0 deletions requirements-demo-streamlit.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-r requirements-demo-minimal.txt
-e git+https://github.com/openai/CLIP.git@main#egg=clip
invisible-watermark
streamlit
13 changes: 10 additions & 3 deletions sgm/models/diffusion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from typing import Any, Dict, List, Tuple, Union

import pytorch_lightning as pl
Expand All @@ -14,6 +14,7 @@
from ..util import (
default,
disabled_train,
get_default_device_name,
get_obj_from_str,
instantiate_from_config,
log_txt_as_img,
Expand Down Expand Up @@ -117,16 +118,22 @@ def get_input(self, batch):
# image tensors should be scaled to -1 ... 1 and in bchw format
return batch[self.input_key]

def _first_stage_autocast_context(self):
device = get_default_device_name()
if device not in ("cpu", "cuda"):
return nullcontext()
return torch.autocast(device, enabled=not self.disable_first_stage_autocast)

@torch.no_grad()
def decode_first_stage(self, z):
z = 1.0 / self.scale_factor * z
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
with self._first_stage_autocast_context():
out = self.first_stage_model.decode(z)
return out

@torch.no_grad()
def encode_first_stage(self, x):
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
with self._first_stage_autocast_context():
z = self.first_stage_model.encode(x)
z = self.scale_factor * z
return z
Expand Down
7 changes: 4 additions & 3 deletions sgm/modules/autoencoding/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
import torch
import torch.nn as nn
from einops import rearrange
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss

from ....util import default, instantiate_from_config

Expand All @@ -30,6 +27,7 @@ def __init__(
scale_tgt_to_input_size=False,
perceptual_weight_on_inputs=0.0,
):
from taming.modules.losses.lpips import LPIPS # late import to avoid extra dependency
super().__init__()
self.scale_input_to_tgt_size = scale_input_to_tgt_size
self.scale_tgt_to_input_size = scale_tgt_to_input_size
Expand Down Expand Up @@ -105,6 +103,9 @@ def __init__(
learn_logvar: bool = False,
regularization_weights: Union[None, dict] = None,
):
from taming.modules.losses.lpips import LPIPS # late import to avoid extra dependency
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init # late import to avoid extra dependency
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss # late import to avoid extra dependency
super().__init__()
self.dims = dims
if self.dims > 2:
Expand Down
3 changes: 2 additions & 1 deletion sgm/modules/diffusionmodules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import torch
import torch.nn as nn
from omegaconf import ListConfig
from taming.modules.losses.lpips import LPIPS

from ...util import append_dims, instantiate_from_config

Expand All @@ -26,6 +25,8 @@ def __init__(
self.offset_noise_level = offset_noise_level

if type == "lpips":
from taming.modules.losses.lpips import LPIPS # late import to avoid extra dependency

self.lpips = LPIPS().eval()

if not batch2model_keys:
Expand Down
9 changes: 5 additions & 4 deletions sgm/modules/diffusionmodules/openaimodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
timestep_embedding,
zero_module,
)
from ...util import default, exists
from ...util import default, exists, get_default_device_name

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1244,6 +1244,7 @@ def __init__(self, in_channels=3, model_channels=64):
]
)

device = get_default_device_name()
model = UNetModel(
use_checkpoint=True,
image_size=64,
Expand All @@ -1258,8 +1259,8 @@ def __init__(self, in_channels=3, model_channels=64):
use_linear_in_transformer=True,
transformer_depth=1,
legacy=False,
).cuda()
x = th.randn(11, 4, 64, 64).cuda()
t = th.randint(low=0, high=10, size=(11,), device="cuda")
).to(device)
x = th.randn(11, 4, 64, 64).to(device)
t = th.randint(low=0, high=10, size=(11,), device=device)
o = model(x, t)
print("done.")
6 changes: 4 additions & 2 deletions sgm/modules/diffusionmodules/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
to_neg_log_sigma,
to_sigma,
)
from ...util import append_dims, default, instantiate_from_config
from ...util import append_dims, default, instantiate_from_config, get_default_device_name

DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}

Expand All @@ -28,8 +28,10 @@ def __init__(
num_steps: Union[int, None] = None,
guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
verbose: bool = False,
device: str = "cuda",
device: Union[str, None] = None,
):
if device is None:
device = get_default_device_name()
self.num_steps = num_steps
self.discretization = instantiate_from_config(discretization_config)
self.guider = instantiate_from_config(
Expand Down
2 changes: 1 addition & 1 deletion sgm/modules/diffusionmodules/sampling_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import torch
from scipy import integrate

from ...util import append_dims

Expand All @@ -10,6 +9,7 @@ def __call__(self, uncond, cond, scale):


def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
from scipy import integrate # late import to avoid extra dependency
if order - 1 > i:
raise ValueError(f"Order {order} too high for step {i}")

Expand Down
26 changes: 18 additions & 8 deletions sgm/modules/encoders/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
default,
disabled_train,
expand_dims_like,
get_default_device_name,
instantiate_from_config,
)

Expand Down Expand Up @@ -239,7 +240,9 @@ def forward(self, c):
c = c[:, None, :]
return c

def get_unconditional_conditioning(self, bs, device="cuda"):
def get_unconditional_conditioning(self, bs, device=None):
if device is None:
device = get_default_device_name()
uc_class = (
self.n_classes - 1
) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
Expand All @@ -264,9 +267,10 @@ class FrozenT5Embedder(AbstractEmbModel):
"""Uses the T5 transformer encoder for text"""

def __init__(
self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
self, version="google/t5-v1_1-xxl", device=None, max_length=77, freeze=True
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__()
device = device or get_default_device_name()
self.tokenizer = T5Tokenizer.from_pretrained(version)
self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device
Expand Down Expand Up @@ -307,9 +311,10 @@ class FrozenByT5Embedder(AbstractEmbModel):
"""

def __init__(
self, version="google/byt5-base", device="cuda", max_length=77, freeze=True
self, version="google/byt5-base", device=None, max_length=77, freeze=True
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__()
device = device or get_default_device_name()
self.tokenizer = ByT5Tokenizer.from_pretrained(version)
self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device
Expand Down Expand Up @@ -351,14 +356,15 @@ class FrozenCLIPEmbedder(AbstractEmbModel):
def __init__(
self,
version="openai/clip-vit-large-patch14",
device="cuda",
device=None,
max_length=77,
freeze=True,
layer="last",
layer_idx=None,
always_return_pooled=False,
): # clip-vit-base-patch32
super().__init__()
device = device or get_default_device_name()
assert layer in self.LAYERS
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
Expand Down Expand Up @@ -419,14 +425,15 @@ def __init__(
self,
arch="ViT-H-14",
version="laion2b_s32b_b79k",
device="cuda",
device=None,
max_length=77,
freeze=True,
layer="last",
always_return_pooled=False,
legacy=True,
):
super().__init__()
device = device or get_default_device_name()
assert layer in self.LAYERS
model, _, _ = open_clip.create_model_and_transforms(
arch,
Expand Down Expand Up @@ -521,12 +528,13 @@ def __init__(
self,
arch="ViT-H-14",
version="laion2b_s32b_b79k",
device="cuda",
device=None,
max_length=77,
freeze=True,
layer="last",
):
super().__init__()
device = device or get_default_device_name()
assert layer in self.LAYERS
model, _, _ = open_clip.create_model_and_transforms(
arch, device=torch.device("cpu"), pretrained=version
Expand Down Expand Up @@ -591,7 +599,7 @@ def __init__(
self,
arch="ViT-H-14",
version="laion2b_s32b_b79k",
device="cuda",
device=None,
max_length=77,
freeze=True,
antialias=True,
Expand All @@ -602,6 +610,7 @@ def __init__(
output_tokens=False,
):
super().__init__()
device = device or get_default_device_name()
model, _, _ = open_clip.create_model_and_transforms(
arch,
device=torch.device("cpu"),
Expand Down Expand Up @@ -747,11 +756,12 @@ def __init__(
self,
clip_version="openai/clip-vit-large-patch14",
t5_version="google/t5-v1_1-xl",
device="cuda",
device=None,
clip_max_length=77,
t5_max_length=77,
):
super().__init__()
device = device or get_default_device_name()
self.clip_encoder = FrozenCLIPEmbedder(
clip_version, device, max_length=clip_max_length
)
Expand Down
Loading