Commit 5bb126bd authored by AUTOMATIC's avatar AUTOMATIC

add split attention layer optimization from https://github.com/basujindal/stable-diffusion/pull/117

parent 407fc1fe
...@@ -3,8 +3,43 @@ import sys ...@@ -3,8 +3,43 @@ import sys
import traceback import traceback
import torch import torch
import numpy as np import numpy as np
from torch import einsum
from modules.shared import opts, device from modules.shared import opts, device, cmd_opts
from ldm.util import default
from einops import rearrange
import ldm.modules.attention
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
for i in range(0, q.shape[0], 2):
end = i + 2
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
s2 = s1.softmax(dim=-1)
del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
class StableDiffusionModelHijack: class StableDiffusionModelHijack:
...@@ -67,6 +102,9 @@ class StableDiffusionModelHijack: ...@@ -67,6 +102,9 @@ class StableDiffusionModelHijack:
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
if cmd_opts.opt_split_attention:
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack): def __init__(self, wrapped, hijack):
...@@ -205,4 +243,8 @@ class EmbeddingsWithFixes(torch.nn.Module): ...@@ -205,4 +243,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
return inputs_embeds return inputs_embeds
model_hijack = StableDiffusionModelHijack() model_hijack = StableDiffusionModelHijack()
...@@ -29,6 +29,7 @@ parser.add_argument("--unload-gfpgan", action='store_true', help="unload GFPGAN ...@@ -29,6 +29,7 @@ parser.add_argument("--unload-gfpgan", action='store_true', help="unload GFPGAN
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN')) parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN'))
parser.add_argument("--opt-split-attention", type=str, help="enable optimization that reduced vram usage by a lot for about 10% decrease in performance", default=os.path.join(script_path, 'ESRGAN'))
cmd_opts = parser.parse_args() cmd_opts = parser.parse_args()
cpu = torch.device("cpu") cpu = torch.device("cpu")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment