Commit 9597b265 authored by AUTOMATIC's avatar AUTOMATIC

implementation for attention using [] and ()

parent a51bedfb
......@@ -188,3 +188,9 @@ and put it into `embeddings` dir and use Usada Pekora in prompt.
A tab with settings, allowing you to use UI to edit more than half of parameters that previously
were commandline. Settings are saved to config.js file. Settings that remain as commandline
options are ones that are required at startup.
### Attention
Using `()` in prompt decreases model's attention to enclosed words, and `[]` increases it. You can combine
multiple modifiers:
![](images/attention-3.jpg)
......@@ -433,15 +433,15 @@ if os.path.exists(cmd_opts.gfpgan_dir):
print(traceback.format_exc(), file=sys.stderr)
class TextInversionEmbeddings:
class StableDiffuionModelHijack:
ids_lookup = {}
word_embeddings = {}
word_embeddings_checksums = {}
fixes = []
fixes = None
used_custom_terms = []
dir_mtime = None
def load(self, dir, model):
def load_textual_inversion_embeddings(self, dir, model):
mt = os.path.getmtime(dir)
if self.dir_mtime is not None and mt <= self.dir_mtime:
return
......@@ -469,6 +469,7 @@ class TextInversionEmbeddings:
self.word_embeddings_checksums[name] = f'{const_hash(emb)&0xffff:04x}'
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
......@@ -497,6 +498,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.embeddings = embeddings
self.tokenizer = wrapped.tokenizer
self.max_length = wrapped.max_length
self.token_mults = {}
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
for c in text:
if c == '[':
mult /= 1.1
if c == ']':
mult *= 1.1
if c == '(':
mult *= 1.1
if c == ')':
mult /= 1.1
if mult != 1.0:
self.token_mults[ident] = mult
def forward(self, text):
self.embeddings.fixes = []
......@@ -508,14 +526,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
cache = {}
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
batch_multipliers = []
for tokens in batch_tokens:
tuple_tokens = tuple(tokens)
if tuple_tokens in cache:
remade_tokens, fixes = cache[tuple_tokens]
remade_tokens, fixes, multipliers = cache[tuple_tokens]
else:
fixes = []
remade_tokens = []
multipliers = []
mult = 1.0
i = 0
while i < len(tokens):
......@@ -523,14 +544,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
possible_matches = self.embeddings.ids_lookup.get(token, None)
if possible_matches is None:
mult_change = self.token_mults.get(token)
if mult_change is not None:
mult *= mult_change
elif possible_matches is None:
remade_tokens.append(token)
multipliers.append(mult)
else:
found = False
for ids, word in possible_matches:
if tokens[i:i+len(ids)] == ids:
fixes.append((len(remade_tokens), word))
remade_tokens.append(777)
multipliers.append(mult)
i += len(ids) - 1
found = True
self.embeddings.used_custom_terms.append((word, self.embeddings.word_embeddings_checksums[word]))
......@@ -538,19 +564,32 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if not found:
remade_tokens.append(token)
multipliers.append(mult)
i += 1
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes)
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
remade_batch_tokens.append(remade_tokens)
self.embeddings.fixes.append(fixes)
batch_multipliers.append(multipliers)
tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device)
outputs = self.wrapped.transformer(input_ids=tokens)
z = outputs.last_hidden_state
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers = torch.asarray(np.array(batch_multipliers)).to(device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
z *= original_mean / new_mean
return z
......@@ -562,22 +601,17 @@ class EmbeddingsWithFixes(nn.Module):
def forward(self, input_ids):
batch_fixes = self.embeddings.fixes
self.embeddings.fixes = []
self.embeddings.fixes = None
inputs_embeds = self.wrapped(input_ids)
if batch_fixes is not None:
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, word in fixes:
tensor[offset] = self.embeddings.word_embeddings[word]
return inputs_embeds
def get_learned_conditioning_with_embeddings(model, prompts):
if os.path.exists(cmd_opts.embeddings_dir):
text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
return model.get_learned_conditioning(prompts)
return inputs_embeds
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False, extra_generation_params=None):
......@@ -648,7 +682,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
return f"{prompt}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])
if os.path.exists(cmd_opts.embeddings_dir):
text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model)
output_images = []
with torch.no_grad(), autocast("cuda"), model.ema_scope():
......@@ -661,8 +695,8 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
uc = model.get_learned_conditioning(len(prompts) * [""])
c = model.get_learned_conditioning(prompts)
if len(text_inversion_embeddings.used_custom_terms) > 0:
comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in text_inversion_embeddings.used_custom_terms]))
if len(model_hijack.used_custom_terms) > 0:
comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in model_hijack.used_custom_terms]))
# we manually generate all input noises because each one should have a specific seed
x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
......@@ -1060,10 +1094,9 @@ model = load_model_from_config(config, cmd_opts.ckpt)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = (model if cmd_opts.no_half else model.half()).to(device)
text_inversion_embeddings = TextInversionEmbeddings()
if os.path.exists(cmd_opts.embeddings_dir):
text_inversion_embeddings.hijack(model)
model_hijack = StableDiffuionModelHijack()
model_hijack.hijack(model)
demo = gr.TabbedInterface(
interface_list=[x[0] for x in interfaces],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment