Commit 59146621 authored by AUTOMATIC's avatar AUTOMATIC

better support for xformers flash attention on older versions of torch

parent 3fa48207
...@@ -24,6 +24,18 @@ See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable ...@@ -24,6 +24,18 @@ See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable
""") """)
already_displayed = {}
def display_once(e: Exception, task):
if task in already_displayed:
return
display(e, task)
already_displayed[task] = 1
def run(code, task): def run(code, task):
try: try:
code() code()
......
...@@ -9,7 +9,7 @@ from torch import einsum ...@@ -9,7 +9,7 @@ from torch import einsum
from ldm.util import default from ldm.util import default
from einops import rearrange from einops import rearrange
from modules import shared from modules import shared, errors
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from .sub_quadratic_attention import efficient_dot_product_attention from .sub_quadratic_attention import efficient_dot_product_attention
...@@ -279,6 +279,21 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_ ...@@ -279,6 +279,21 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
) )
def get_xformers_flash_attention_op(q, k, v):
if not shared.cmd_opts.xformers_flash_attention:
return None
try:
flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
fw, bw = flash_attention_op
if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
return flash_attention_op
except Exception as e:
errors.display_once(e, "enabling flash attention")
return None
def xformers_attention_forward(self, x, context=None, mask=None): def xformers_attention_forward(self, x, context=None, mask=None):
h = self.heads h = self.heads
q_in = self.to_q(x) q_in = self.to_q(x)
...@@ -291,18 +306,7 @@ def xformers_attention_forward(self, x, context=None, mask=None): ...@@ -291,18 +306,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in del q_in, k_in, v_in
if shared.cmd_opts.xformers_flash_attention: out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
fw, bw = op
if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
# print('xformers_attention_forward', q.shape, k.shape, v.shape)
# Flash Attention is not availabe for the input arguments.
# Fallback to default xFormers' backend.
op = None
else:
op = None
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=op)
out = rearrange(out, 'b n h d -> b n (h d)', h=h) out = rearrange(out, 'b n h d -> b n (h d)', h=h)
return self.to_out(out) return self.to_out(out)
...@@ -377,17 +381,7 @@ def xformers_attnblock_forward(self, x): ...@@ -377,17 +381,7 @@ def xformers_attnblock_forward(self, x):
q = q.contiguous() q = q.contiguous()
k = k.contiguous() k = k.contiguous()
v = v.contiguous() v = v.contiguous()
if shared.cmd_opts.xformers_flash_attention: out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
fw, bw = op
if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v)):
# print('xformers_attnblock_forward', q.shape, k.shape, v.shape)
# Flash Attention is not availabe for the input arguments.
# Fallback to default xFormers' backend.
op = None
else:
op = None
out = xformers.ops.memory_efficient_attention(q, k, v, op=op)
out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = rearrange(out, 'b (h w) c -> b c h w', h=h)
out = self.proj_out(out) out = self.proj_out(out)
return x + out return x + out
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment