Commit 94853395 authored by AUTOMATIC's avatar AUTOMATIC

replace duplicate code with a function

parent 5e2627a1
...@@ -64,21 +64,26 @@ def load_hypernetwork(filename): ...@@ -64,21 +64,26 @@ def load_hypernetwork(filename):
shared.loaded_hypernetwork = None shared.loaded_hypernetwork = None
def apply_hypernetwork(hypernetwork, context):
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
if hypernetwork_layers is None:
return context, context
context_k = hypernetwork_layers[0](context)
context_v = hypernetwork_layers[1](context)
return context_k, context_v
def attention_CrossAttention_forward(self, x, context=None, mask=None): def attention_CrossAttention_forward(self, x, context=None, mask=None):
h = self.heads h = self.heads
q = self.to_q(x) q = self.to_q(x)
context = default(context, x) context = default(context, x)
hypernetwork = shared.loaded_hypernetwork context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context)
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) k = self.to_k(context_k)
v = self.to_v(context_v)
if hypernetwork_layers is not None:
k = self.to_k(hypernetwork_layers[0](context))
v = self.to_v(hypernetwork_layers[1](context))
else:
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
......
...@@ -8,7 +8,8 @@ from torch import einsum ...@@ -8,7 +8,8 @@ from torch import einsum
from ldm.util import default from ldm.util import default
from einops import rearrange from einops import rearrange
from modules import shared from modules import shared, hypernetwork
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
try: try:
...@@ -26,16 +27,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): ...@@ -26,16 +27,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
q_in = self.to_q(x) q_in = self.to_q(x)
context = default(context, x) context = default(context, x)
hypernetwork = shared.loaded_hypernetwork context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
if hypernetwork_layers is not None: del context, context_k, context_v, x
k_in = self.to_k(hypernetwork_layers[0](context))
v_in = self.to_v(hypernetwork_layers[1](context))
else:
k_in = self.to_k(context)
v_in = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in del q_in, k_in, v_in
...@@ -59,22 +54,16 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): ...@@ -59,22 +54,16 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
return self.to_out(r2) return self.to_out(r2)
# taken from https://github.com/Doggettx/stable-diffusion # taken from https://github.com/Doggettx/stable-diffusion and modified
def split_cross_attention_forward(self, x, context=None, mask=None): def split_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads h = self.heads
q_in = self.to_q(x) q_in = self.to_q(x)
context = default(context, x) context = default(context, x)
hypernetwork = shared.loaded_hypernetwork context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
if hypernetwork_layers is not None:
k_in = self.to_k(hypernetwork_layers[0](context))
v_in = self.to_v(hypernetwork_layers[1](context))
else:
k_in = self.to_k(context)
v_in = self.to_v(context)
k_in *= self.scale k_in *= self.scale
...@@ -130,14 +119,11 @@ def xformers_attention_forward(self, x, context=None, mask=None): ...@@ -130,14 +119,11 @@ def xformers_attention_forward(self, x, context=None, mask=None):
h = self.heads h = self.heads
q_in = self.to_q(x) q_in = self.to_q(x)
context = default(context, x) context = default(context, x)
hypernetwork = shared.loaded_hypernetwork
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
if hypernetwork_layers is not None: k_in = self.to_k(context_k)
k_in = self.to_k(hypernetwork_layers[0](context)) v_in = self.to_v(context_v)
v_in = self.to_v(hypernetwork_layers[1](context))
else:
k_in = self.to_k(context)
v_in = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in del q_in, k_in, v_in
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment