Commit d7374179 authored by d8ahazard's avatar d8ahazard

Merge remote-tracking branch 'upstream/master' into ModelLoader

parents 0dce0df1 498515e7
......@@ -8,6 +8,8 @@ __pycache__
/tmp
/model.ckpt
/models/**/*
/GFPGANv1.3.pth
/gfpgan/weights/*.pth
/ui-config.json
/outputs
/config.json
......
......@@ -15,6 +15,7 @@ titles = {
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
"\u{1f3a8}": "Add a random artist to the prompt.",
"\u2199\ufe0f": "Read generation parameters from prompt into user interface.",
"\uD83D\uDCC2": "Open images output directory",
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
......
......@@ -182,4 +182,23 @@ onUiUpdate(function(){
});
json_elem.parentElement.style.display="none"
if (!txt2img_textarea) {
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
}
if (!img2img_textarea) {
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
}
})
let txt2img_textarea, img2img_textarea = undefined;
let wait_time = 800
let token_timeout;
function update_token_counter(button_id) {
if (token_timeout)
clearTimeout(token_timeout);
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
}
......@@ -15,11 +15,11 @@ torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
k_diffusion_package = os.environ.get('K_DIFFUSION_PACKAGE', "git+https://github.com/crowsonkb/k-diffusion.git@1a0703dfb7d24d8806267c3e7ccc4caf67fd1331")
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "a7ec1974d4ccb394c2dca275f42cd97490618924")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
......@@ -107,10 +107,7 @@ if not is_installed("torch") or not is_installed("torchvision"):
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")
if not skip_torch_cuda_test:
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDINE_ARGS variable to disable this check'")
if not is_installed("k_diffusion.sampling"):
run_pip(f"install {k_diffusion_package}", "k-diffusion")
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
if not is_installed("gfpgan"):
run_pip(f"install {gfpgan_package}", "gfpgan")
......@@ -119,6 +116,7 @@ os.makedirs(dir_repos, exist_ok=True)
git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
if os.path.isdir(repo_dir('latent-diffusion')):
......@@ -133,6 +131,9 @@ run_pip(f"install -r {requirements_file}", "requirements for Web UI")
sys.argv += args
if "--exit" in args:
print("Exiting because of --exit argument")
exit(0)
def start_webui():
print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}")
......
......@@ -6,13 +6,14 @@ from PIL import Image
import torch
import tqdm
from modules import processing, shared, images, devices
from modules import processing, shared, images, devices, sd_models
from modules.shared import opts
import modules.gfpgan_model
from modules.ui import plaintext_to_html
import modules.codeformer_model
import piexif
import piexif.helper
import gradio as gr
cached_images = {}
......@@ -141,7 +142,7 @@ def run_pnginfo(image):
return '', geninfo, info
def run_modelmerger(modelname_0, modelname_1, interp_method, interp_amount):
def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name):
# Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
......@@ -151,45 +152,52 @@ def run_modelmerger(modelname_0, modelname_1, interp_method, interp_amount):
alpha = alpha * alpha * (3 - (2 * alpha))
return theta0 + ((theta1 - theta0) * alpha)
if os.path.exists(modelname_0):
model0_filename = modelname_0
modelname_0 = os.path.splitext(os.path.basename(modelname_0))[0]
else:
model0_filename = 'models/' + modelname_0 + '.ckpt'
# Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
def inv_sigmoid(theta0, theta1, alpha):
import math
alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
return theta0 + ((theta1 - theta0) * alpha)
if os.path.exists(modelname_1):
model1_filename = modelname_1
modelname_1 = os.path.splitext(os.path.basename(modelname_1))[0]
else:
model1_filename = 'models/' + modelname_1 + '.ckpt'
primary_model_info = sd_models.checkpoints_list[primary_model_name]
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
print(f"Loading {model0_filename}...")
model_0 = torch.load(model0_filename, map_location='cpu')
print(f"Loading {primary_model_info.filename}...")
primary_model = torch.load(primary_model_info.filename, map_location='cpu')
print(f"Loading {model1_filename}...")
model_1 = torch.load(model1_filename, map_location='cpu')
print(f"Loading {secondary_model_info.filename}...")
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
theta_0 = model_0['state_dict']
theta_1 = model_1['state_dict']
theta_0 = primary_model['state_dict']
theta_1 = secondary_model['state_dict']
theta_funcs = {
"Weighted Sum": weighted_sum,
"Sigmoid": sigmoid,
"Inverse Sigmoid": inv_sigmoid,
}
theta_func = theta_funcs[interp_method]
print(f"Merging...")
for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1:
theta_0[key] = theta_func(theta_0[key], theta_1[key], interp_amount)
theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint
if save_as_half:
theta_0[key] = theta_0[key].half()
for key in theta_1.keys():
if 'model' in key and key not in theta_0:
theta_0[key] = theta_1[key]
if save_as_half:
theta_0[key] = theta_0[key].half()
filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
filename = filename if custom_name == '' else (custom_name + '.ckpt')
output_modelname = os.path.join(shared.cmd_opts.ckpt_dir, filename)
output_modelname = 'models/' + modelname_0 + '-' + modelname_1 + '-merged.ckpt'
print(f"Saving to {output_modelname}...")
torch.save(model_0, output_modelname)
torch.save(primary_model, output_modelname)
sd_models.list_models()
print(f"Checkpoint saved.")
return "Checkpoint saved to " + output_modelname
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)]
......@@ -21,6 +21,7 @@ path_dirs = [
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer'),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP'),
(os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR'),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion'),
]
paths = {}
......
......@@ -49,7 +49,7 @@ def apply_color_correction(correction, image):
class StableDiffusionProcessing:
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
......@@ -75,11 +75,11 @@ class StableDiffusionProcessing:
self.do_not_save_grid: bool = do_not_save_grid
self.extra_generation_params: dict = extra_generation_params or {}
self.overlay_images = overlay_images
self.eta = eta
self.paste_to = None
self.color_corrections = None
self.denoising_strength: float = 0
self.ddim_eta = opts.ddim_eta
self.ddim_discretize = opts.ddim_discretize
self.s_churn = opts.s_churn
self.s_tmin = opts.s_tmin
......@@ -100,7 +100,7 @@ class StableDiffusionProcessing:
class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0):
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
self.images = images_list
self.prompt = p.prompt
self.negative_prompt = p.negative_prompt
......@@ -124,7 +124,7 @@ class Processed:
self.extra_generation_params = p.extra_generation_params
self.index_of_first_image = index_of_first_image
self.ddim_eta = p.ddim_eta
self.eta = p.eta
self.ddim_discretize = p.ddim_discretize
self.s_churn = p.s_churn
self.s_tmin = p.s_tmin
......@@ -139,6 +139,7 @@ class Processed:
self.all_prompts = all_prompts or [self.prompt]
self.all_seeds = all_seeds or [self.seed]
self.all_subseeds = all_subseeds or [self.subseed]
self.infotexts = infotexts or [info]
def js(self):
obj = {
......@@ -165,6 +166,7 @@ class Processed:
"denoising_strength": self.denoising_strength,
"extra_generation_params": self.extra_generation_params,
"index_of_first_image": self.index_of_first_image,
"infotexts": self.infotexts,
}
return json.dumps(obj)
......@@ -269,6 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
"Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
}
generation_params.update(p.extra_generation_params)
......@@ -277,7 +280,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
def process_images(p: StableDiffusionProcessing) -> Processed:
......@@ -322,6 +325,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if os.path.exists(cmd_opts.embeddings_dir):
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
infotexts = []
output_images = []
precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
......@@ -404,6 +408,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if opts.samples_save and not p.do_not_save_samples:
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
infotexts.append(infotext(n, i))
output_images.append(image)
state.nextjob()
......@@ -416,6 +421,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
grid = images.image_grid(output_images, p.batch_size)
if opts.return_grid:
infotexts.insert(0, infotext())
output_images.insert(0, grid)
index_of_first_image = 1
......@@ -423,7 +429,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc()
return Processed(p, output_images, all_seeds[0], infotext(), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image)
return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
......
......@@ -126,5 +126,93 @@ def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
return res
re_attention = re.compile(r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""", re.X)
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs: text and its assoicated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
Example:
'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).'
produces:
[
['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]
]
"""
#get_learned_conditioning_prompt_schedules(["fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"], 100)
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith('\\'):
res.append([text[1:], 1.0])
elif text == '(':
round_brackets.append(len(res))
elif text == '[':
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), float(weight))
elif text == ')' and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == ']' and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
res.append([text, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
return res
......@@ -55,7 +55,7 @@ def load_scripts(basedir):
if not os.path.exists(basedir):
return
for filename in os.listdir(basedir):
for filename in sorted(os.listdir(basedir)):
path = os.path.join(basedir, filename)
if not os.path.isfile(path):
......
......@@ -6,6 +6,7 @@ import torch
import numpy as np
from torch import einsum
from modules import prompt_parser
from modules.shared import opts, device, cmd_opts
from ldm.util import default
......@@ -180,6 +181,7 @@ class StableDiffusionModelHijack:
dir_mtime = None
layers = None
circular_enabled = False
clip = None
def load_textual_inversion_embeddings(self, dirname, model):
mt = os.path.getmtime(dirname)
......@@ -210,6 +212,7 @@ class StableDiffusionModelHijack:
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
......@@ -235,7 +238,7 @@ class StableDiffusionModelHijack:
print(traceback.format_exc(), file=sys.stderr)
continue
print(f"Loaded a total of {len(self.word_embeddings)} text inversion embeddings.")
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
def hijack(self, m):
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
......@@ -243,6 +246,8 @@ class StableDiffusionModelHijack:
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
self.clip = m.cond_stage_model
if cmd_opts.opt_split_attention_v1:
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
......@@ -259,6 +264,14 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
def undo_hijack(self, m):
if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
def apply_circular(self, enable):
if self.circular_enabled == enable:
return
......@@ -268,6 +281,11 @@ class StableDiffusionModelHijack:
for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
layer.padding_mode = 'circular' if enable else 'zeros'
def tokenize(self, text):
max_length = self.clip.max_length - 2
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
return remade_batch_tokens[0], token_count, max_length
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
......@@ -294,14 +312,101 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0:
self.token_mults[ident] = mult
def forward(self, text):
self.hijack.fixes = []
self.hijack.comments = []
def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
maxlen = self.wrapped.max_length
if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(line)
else:
parsed = [[line, 1.0]]
tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
fixes = []
remade_tokens = []
multipliers = []
for tokens, (text, weight) in zip(tokenized, parsed):
i = 0
while i < len(tokens):
token = tokens[i]
possible_matches = self.hijack.ids_lookup.get(token, None)
if possible_matches is None:
remade_tokens.append(token)
multipliers.append(weight)
else:
found = False
for ids, word in possible_matches:
if tokens[i:i + len(ids)] == ids:
emb_len = int(self.hijack.word_embeddings[word].shape[0])
fixes.append((len(remade_tokens), word))
remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len
i += len(ids) - 1
found = True
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
break
if not found:
remade_tokens.append(token)
multipliers.append(weight)
i += 1
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
ovf = remade_tokens[maxlen - 2:]
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
return remade_tokens, fixes, multipliers, token_count
def process_text(self, texts):
used_custom_terms = []
remade_batch_tokens = []
hijack_comments = []
hijack_fixes = []
token_count = 0
cache = {}
batch_multipliers = []
for line in texts:
if line in cache:
remade_tokens, fixes, multipliers = cache[line]
else:
remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
cache[line] = (remade_tokens, fixes, multipliers)
remade_batch_tokens.append(remade_tokens)
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
maxlen = self.wrapped.max_length
used_custom_terms = []
remade_batch_tokens = []
overflowing_words = []
hijack_comments = []
hijack_fixes = []
token_count = 0
cache = {}
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
......@@ -353,9 +458,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
ovf = remade_tokens[maxlen - 2:]
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
......@@ -364,11 +468,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
remade_batch_tokens.append(remade_tokens)
self.hijack.fixes.append(fixes)
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def forward(self, text):
if opts.use_old_emphasis_implementation:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
self.hijack.fixes = hijack_fixes
self.hijack.comments = hijack_comments
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
tokens = torch.asarray(remade_batch_tokens).to(device)
outputs = self.wrapped.transformer(input_ids=tokens)
......
......@@ -15,8 +15,9 @@ model_dir = "Stable-diffusion"
model_path = os.path.join(models_path, model_dir)
model_name = "sd-v1-4.ckpt"
model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1"
user_dir = None
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash'])
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
try:
......@@ -47,23 +48,56 @@ def setup_model(dirname):
global model_path
global model_name
global model_url
global user_dir
global model_list
user_dir = dirname
if not os.path.exists(model_path):
os.makedirs(model_path)
checkpoints_list.clear()
model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=dirname, download_name=model_name, ext_filter=".ckpt")
list_models()
def checkpoint_tiles():
return sorted([x.title for x in checkpoints_list.values()])
def list_models():
global model_path
global model_url
global model_name
global user_dir
checkpoints_list.clear()
model_list = modelloader.load_models(model_path=model_path,model_url=model_url,command_path= user_dir,
ext_filter=[".ckpt"], download_name=model_name)
print(f"Model list: {model_list}")
model_dir = os.path.abspath(model_path)
def modeltitle(path, h):
abspath = os.path.abspath(path)
if abspath.startswith(model_dir):
name = abspath.replace(model_dir, '')
else:
name = os.path.basename(path)
if name.startswith("\\") or name.startswith("/"):
name = name[1:]
shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
return f'{name} [{h}]', shortname
cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
title = modeltitle(cmd_ckpt, h)
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h)
title, model_name = modeltitle(cmd_ckpt, h)
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name)
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
for filename in model_list:
h = model_hash(filename)
title = modeltitle(filename, h)
checkpoints_list[title] = CheckpointInfo(filename, title, h)
checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name)
def model_hash(filename):
......@@ -89,7 +123,7 @@ def select_checkpoint():
if len(checkpoints_list) == 0:
print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
print(f" - directory {os.path.abspath(shared.cmd_opts.stablediffusion_models_path)}", file=sys.stderr)
print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
exit(1)
......@@ -142,7 +176,7 @@ def load_model():
def reload_model_weights(sd_model, info=None):
from modules import lowvram, devices
from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
if sd_model.sd_model_checkpint == checkpoint_info.filename:
......@@ -153,8 +187,12 @@ def reload_model_weights(sd_model, info=None):
else:
sd_model.to(devices.cpu)
sd_hijack.model_hijack.undo_hijack(sd_model)
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
sd_hijack.model_hijack.hijack(sd_model)
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
......
This diff is collapsed.
......@@ -155,6 +155,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
}))
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
......@@ -182,7 +183,6 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
"ldsr_pre_down": OptionInfo(1, "LDSR Pre-process downssample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}),
}))
......@@ -190,7 +190,6 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
"face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
"save_selected_only": OptionInfo(False, "When using 'Save' button, only save a single selected image"),
}))
options_templates.update(options_section(('system', "System"), {
......@@ -200,12 +199,13 @@ options_templates.update(options_section(('system', "System"), {
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": [x.title for x in modules.sd_models.checkpoints_list.values()]}),
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": modules.sd_models.checkpoint_tiles()}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"enable_emphasis": OptionInfo(True, "Eemphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
......@@ -231,8 +231,9 @@ options_templates.update(options_section(('ui', "User interface"), {
}))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"ddim_eta": OptionInfo(0.0, "DDIM eta", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform','quad']}),
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
......
This diff is collapsed.
......@@ -4,9 +4,8 @@ fairscale==0.4.4
fonts
font-roboto
gfpgan
gradio
gradio==3.4b3
invisible-watermark
git+https://github.com/crowsonkb/k-diffusion.git
numpy
omegaconf
piexif
......@@ -16,5 +15,12 @@ realesrgan
scikit-image>=0.19
git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379
timm==0.4.12
transformers
transformers==4.19.2
torch
einops
jsonmerge
clean-fid
git+https://github.com/openai/CLIP@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
resize-right
torchdiffeq
kornia
......@@ -15,3 +15,10 @@ font-roboto
timm==0.6.7
fairscale==0.4.9
piexif==1.1.3
einops==0.4.1
jsonmerge==1.8.0
clean-fid==0.1.29
git+https://github.com/openai/CLIP@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
resize-right==0.0.2
torchdiffeq==0.2.3
kornia==0.6.7
......@@ -91,8 +91,8 @@ axis_options = [
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label),
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
AxisOption("DDIM Eta", float, apply_field("ddim_eta"), format_value_add_label),
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label),# as it is now all AxisOptionImg2Img items must go after AxisOption ones
AxisOption("Eta", float, apply_field("eta"), format_value_add_label),
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
]
......@@ -159,6 +159,9 @@ class Script(scripts.Script):
p.batch_size = 1
def process_axis(opt, vals):
if opt.label == 'Nothing':
return [0]
valslist = [x.strip() for x in vals.split(",")]
if opt.type == int:
......
.output-html p {margin: 0 0.5em;}
.row > *,
.row > .gr-form > * {
min-width: min(120px, 100%);
flex: 1 1 0%;
}
.performance {
font-size: 0.85em;
color: #444;
......@@ -43,13 +49,17 @@
margin-right: auto;
}
#random_seed, #random_subseed, #reuse_seed, #reuse_subseed{
#random_seed, #random_subseed, #reuse_seed, #reuse_subseed, #open_folder{
min-width: auto;
flex-grow: 0;
padding-left: 0.25em;
padding-right: 0.25em;
}
#hidden_element{
display: none;
}
#seed_row, #subseed_row{
gap: 0.5rem;
}
......@@ -389,3 +399,7 @@ input[type="range"]{
border-radius: 8px;
display: none;
}
.red {
color: red;
}
import os
import threading
from modules import devices
from modules.paths import script_path
import signal
import threading
import modules.paths
......@@ -44,6 +48,8 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func):
def f(*args, **kwargs):
devices.torch_gc()
shared.state.sampling_step = 0
shared.state.job_count = -1
shared.state.job_no = 0
......@@ -59,6 +65,8 @@ def wrap_gradio_gpu_call(func):
shared.state.job = ""
shared.state.job_count = 0
devices.torch_gc()
return res
return modules.ui.wrap_gradio_call(f)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment