Commit 9597b265 authored by AUTOMATIC's avatar AUTOMATIC

implementation for attention using [] and ()

parent a51bedfb
...@@ -188,3 +188,9 @@ and put it into `embeddings` dir and use Usada Pekora in prompt. ...@@ -188,3 +188,9 @@ and put it into `embeddings` dir and use Usada Pekora in prompt.
A tab with settings, allowing you to use UI to edit more than half of parameters that previously A tab with settings, allowing you to use UI to edit more than half of parameters that previously
were commandline. Settings are saved to config.js file. Settings that remain as commandline were commandline. Settings are saved to config.js file. Settings that remain as commandline
options are ones that are required at startup. options are ones that are required at startup.
### Attention
Using `()` in prompt decreases model's attention to enclosed words, and `[]` increases it. You can combine
multiple modifiers:
![](images/attention-3.jpg)
...@@ -433,15 +433,15 @@ if os.path.exists(cmd_opts.gfpgan_dir): ...@@ -433,15 +433,15 @@ if os.path.exists(cmd_opts.gfpgan_dir):
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
class TextInversionEmbeddings: class StableDiffuionModelHijack:
ids_lookup = {} ids_lookup = {}
word_embeddings = {} word_embeddings = {}
word_embeddings_checksums = {} word_embeddings_checksums = {}
fixes = [] fixes = None
used_custom_terms = [] used_custom_terms = []
dir_mtime = None dir_mtime = None
def load(self, dir, model): def load_textual_inversion_embeddings(self, dir, model):
mt = os.path.getmtime(dir) mt = os.path.getmtime(dir)
if self.dir_mtime is not None and mt <= self.dir_mtime: if self.dir_mtime is not None and mt <= self.dir_mtime:
return return
...@@ -469,6 +469,7 @@ class TextInversionEmbeddings: ...@@ -469,6 +469,7 @@ class TextInversionEmbeddings:
self.word_embeddings_checksums[name] = f'{const_hash(emb)&0xffff:04x}' self.word_embeddings_checksums[name] = f'{const_hash(emb)&0xffff:04x}'
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0] ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
first_id = ids[0] first_id = ids[0]
if first_id not in self.ids_lookup: if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = [] self.ids_lookup[first_id] = []
...@@ -497,6 +498,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ...@@ -497,6 +498,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.embeddings = embeddings self.embeddings = embeddings
self.tokenizer = wrapped.tokenizer self.tokenizer = wrapped.tokenizer
self.max_length = wrapped.max_length self.max_length = wrapped.max_length
self.token_mults = {}
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
for c in text:
if c == '[':
mult /= 1.1
if c == ']':
mult *= 1.1
if c == '(':
mult *= 1.1
if c == ')':
mult /= 1.1
if mult != 1.0:
self.token_mults[ident] = mult
def forward(self, text): def forward(self, text):
self.embeddings.fixes = [] self.embeddings.fixes = []
...@@ -508,14 +526,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ...@@ -508,14 +526,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
cache = {} cache = {}
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"] batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
batch_multipliers = []
for tokens in batch_tokens: for tokens in batch_tokens:
tuple_tokens = tuple(tokens) tuple_tokens = tuple(tokens)
if tuple_tokens in cache: if tuple_tokens in cache:
remade_tokens, fixes = cache[tuple_tokens] remade_tokens, fixes, multipliers = cache[tuple_tokens]
else: else:
fixes = [] fixes = []
remade_tokens = [] remade_tokens = []
multipliers = []
mult = 1.0
i = 0 i = 0
while i < len(tokens): while i < len(tokens):
...@@ -523,14 +544,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ...@@ -523,14 +544,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
possible_matches = self.embeddings.ids_lookup.get(token, None) possible_matches = self.embeddings.ids_lookup.get(token, None)
if possible_matches is None: mult_change = self.token_mults.get(token)
if mult_change is not None:
mult *= mult_change
elif possible_matches is None:
remade_tokens.append(token) remade_tokens.append(token)
multipliers.append(mult)
else: else:
found = False found = False
for ids, word in possible_matches: for ids, word in possible_matches:
if tokens[i:i+len(ids)] == ids: if tokens[i:i+len(ids)] == ids:
fixes.append((len(remade_tokens), word)) fixes.append((len(remade_tokens), word))
remade_tokens.append(777) remade_tokens.append(777)
multipliers.append(mult)
i += len(ids) - 1 i += len(ids) - 1
found = True found = True
self.embeddings.used_custom_terms.append((word, self.embeddings.word_embeddings_checksums[word])) self.embeddings.used_custom_terms.append((word, self.embeddings.word_embeddings_checksums[word]))
...@@ -538,19 +564,32 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ...@@ -538,19 +564,32 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if not found: if not found:
remade_tokens.append(token) remade_tokens.append(token)
multipliers.append(mult)
i += 1 i += 1
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes) cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
remade_batch_tokens.append(remade_tokens) remade_batch_tokens.append(remade_tokens)
self.embeddings.fixes.append(fixes) self.embeddings.fixes.append(fixes)
batch_multipliers.append(multipliers)
tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device) tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device)
outputs = self.wrapped.transformer(input_ids=tokens) outputs = self.wrapped.transformer(input_ids=tokens)
z = outputs.last_hidden_state z = outputs.last_hidden_state
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers = torch.asarray(np.array(batch_multipliers)).to(device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
z *= original_mean / new_mean
return z return z
...@@ -562,22 +601,17 @@ class EmbeddingsWithFixes(nn.Module): ...@@ -562,22 +601,17 @@ class EmbeddingsWithFixes(nn.Module):
def forward(self, input_ids): def forward(self, input_ids):
batch_fixes = self.embeddings.fixes batch_fixes = self.embeddings.fixes
self.embeddings.fixes = [] self.embeddings.fixes = None
inputs_embeds = self.wrapped(input_ids) inputs_embeds = self.wrapped(input_ids)
for fixes, tensor in zip(batch_fixes, inputs_embeds): if batch_fixes is not None:
for offset, word in fixes: for fixes, tensor in zip(batch_fixes, inputs_embeds):
tensor[offset] = self.embeddings.word_embeddings[word] for offset, word in fixes:
tensor[offset] = self.embeddings.word_embeddings[word]
return inputs_embeds
def get_learned_conditioning_with_embeddings(model, prompts): return inputs_embeds
if os.path.exists(cmd_opts.embeddings_dir):
text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
return model.get_learned_conditioning(prompts)
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False, extra_generation_params=None): def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False, extra_generation_params=None):
...@@ -648,7 +682,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, ...@@ -648,7 +682,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
return f"{prompt}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments]) return f"{prompt}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])
if os.path.exists(cmd_opts.embeddings_dir): if os.path.exists(cmd_opts.embeddings_dir):
text_inversion_embeddings.load(cmd_opts.embeddings_dir, model) model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model)
output_images = [] output_images = []
with torch.no_grad(), autocast("cuda"), model.ema_scope(): with torch.no_grad(), autocast("cuda"), model.ema_scope():
...@@ -661,8 +695,8 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, ...@@ -661,8 +695,8 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
uc = model.get_learned_conditioning(len(prompts) * [""]) uc = model.get_learned_conditioning(len(prompts) * [""])
c = model.get_learned_conditioning(prompts) c = model.get_learned_conditioning(prompts)
if len(text_inversion_embeddings.used_custom_terms) > 0: if len(model_hijack.used_custom_terms) > 0:
comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in text_inversion_embeddings.used_custom_terms])) comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in model_hijack.used_custom_terms]))
# we manually generate all input noises because each one should have a specific seed # we manually generate all input noises because each one should have a specific seed
x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds) x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
...@@ -1060,10 +1094,9 @@ model = load_model_from_config(config, cmd_opts.ckpt) ...@@ -1060,10 +1094,9 @@ model = load_model_from_config(config, cmd_opts.ckpt)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = (model if cmd_opts.no_half else model.half()).to(device) model = (model if cmd_opts.no_half else model.half()).to(device)
text_inversion_embeddings = TextInversionEmbeddings()
if os.path.exists(cmd_opts.embeddings_dir): model_hijack = StableDiffuionModelHijack()
text_inversion_embeddings.hijack(model) model_hijack.hijack(model)
demo = gr.TabbedInterface( demo = gr.TabbedInterface(
interface_list=[x[0] for x in interfaces], interface_list=[x[0] for x in interfaces],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment