Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 4 additions & 34 deletions sgm/modules/attention.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import math
from inspect import isfunction
from typing import Any, Optional

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from packaging import version
from torch import nn
from ..util import exists, default


if version.parse(torch.__version__) >= version.parse("2.0.0"):
SDP_IS_AVAILABLE = True
Expand Down Expand Up @@ -53,25 +54,11 @@
from .diffusionmodules.util import checkpoint


def exists(val):
return val is not None


def uniq(arr):
def uniq(arr): # TODO: this seems unused
return {el: True for el in arr}.keys()


def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d


def max_neg_value(t):
return -torch.finfo(t.dtype).max


def init_(tensor):
def init_(tensor): # TODO: this seems unused
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
Expand Down Expand Up @@ -251,23 +238,6 @@ def forward(

q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

## old
"""
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
del q, k

if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)

# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)

out = einsum('b i j, b j d -> b i d', sim, v)
"""
## new
with sdp_kernel(**BACKEND_MAP[self.backend]):
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
out = F.scaled_dot_product_attention(
Expand Down