Unverified commit f174fb29 authored by C43H66N12O12S2, committed by GitHub

add xformers attention

parent 2995107f
import math
import torch
from torch import einsum

import xformers.ops
import functorch
xformers._is_functorch_available = True

from ldm.util import default
from einops import rearrange
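The new imports pull in xformers and functorch unconditionally, so the whole module fails to import when either package is missing. Below is a minimal sketch of a guarded variant; the module-level xformers_available flag is hypothetical and not part of this commit.

# Sketch only: optional-dependency guard. `xformers_available` is a
# hypothetical flag a caller could check before switching attention paths.
try:
    import xformers.ops
    import functorch
    xformers._is_functorch_available = True
    xformers_available = True
except ImportError:
    xformers_available = False   # fall back to the split-attention path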
...
@@ -92,6 +94,41 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
    return self.to_out(r2)
def _maybe_init(self, x):
    """
    Initialize the attention operator, if required.
    The head dimension is expected to be folded into the batch here,
    i.e. x has shape (B * heads, M, K) with K the per-head dimension.
    """
    if self.attention_op is not None:
        return

    _, M, K = x.shape
    try:
        # Let xformers pick the best available kernel (flash attention / cutlass)
        # for this dtype, device and head size, and cache it on the module.
        self.attention_op = xformers.ops.AttentionOpDispatch(
            dtype=x.dtype,
            device=x.device,
            k=K,
            attn_bias_type=type(None),
            has_dropout=False,
            kv_len=M,
            q_len=M,
        ).op
    except NotImplementedError as err:
        raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}")
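_maybe_init caches the dispatched kernel on self.attention_op, so the module needs that attribute before its first forward pass. The diff shown here does not include that setup; the lines below are only a sketch of one way to provide it, and attaching the slot and helper to ldm.modules.attention.CrossAttention is an assumption mirroring the monkey-patching style implied by split_cross_attention_forward above.

# Sketch only (not part of this diff): give CrossAttention the attributes the
# helper above expects before the patched forward is installed.
import ldm.modules.attention as ldm_attention

ldm_attention.CrossAttention.attention_op = None      # lazy slot, filled by _maybe_init
ldm_attention.CrossAttention._maybe_init = _maybe_init
# xformers_attention_forward (defined below) would be attached the same way:
# ldm_attention.CrossAttention.forward = xformers_attention_forward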
def xformers_attention_forward(self, x, context=None, mask=None):
    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)
    k_in = self.to_k(context)
    v_in = self.to_v(context)

    # Fold the heads into the batch dimension: (b, n, h*d) -> (b*h, n, d)
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
    del q_in, k_in, v_in

    self._maybe_init(q)
    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

    # Unfold the heads again before the output projection.
    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
    return self.to_out(out)
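For reference, memory_efficient_attention computes ordinary scaled dot-product attention, softmax(QK^T / sqrt(d)) V, over the same (b*h, n, d) layout used above, just without materializing the full attention matrix. A small self-contained check in plain PyTorch; the shapes are illustrative and not taken from the commit.

# Sketch only: the reference computation the xformers kernel is expected to
# match, using the same folded-head layout as xformers_attention_forward.
import torch

def reference_attention(q, k, v):
    scale = q.shape[-1] ** -0.5
    sim = torch.einsum('b i d, b j d -> b i j', q, k) * scale   # (b*h, n, n) scores
    attn = sim.softmax(dim=-1)                                   # normalize over keys
    return torch.einsum('b i j, b j d -> b i d', attn, v)        # weighted sum of values

q, k, v = (torch.randn(2 * 8, 77, 40) for _ in range(3))
print(reference_attention(q, k, v).shape)                        # torch.Size([16, 77, 40])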
def cross_attention_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
...