Commit d7374179 authored by d8ahazard

Merge remote-tracking branch 'upstream/master' into ModelLoader

parents 0dce0df1 498515e7
@@ -8,6 +8,8 @@ __pycache__
 /tmp
 /model.ckpt
 /models/**/*
+/GFPGANv1.3.pth
+/gfpgan/weights/*.pth
 /ui-config.json
 /outputs
 /config.json
...
@@ -15,6 +15,7 @@ titles = {
     "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
     "\u{1f3a8}": "Add a random artist to the prompt.",
     "\u2199\ufe0f": "Read generation parameters from prompt into user interface.",
+    "\uD83D\uDCC2": "Open images output directory",
     "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
     "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
...
@@ -182,4 +182,23 @@ onUiUpdate(function(){
     });

     json_elem.parentElement.style.display="none"

+    if (!txt2img_textarea) {
+        txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
+        txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
+    }
+    if (!img2img_textarea) {
+        img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
+        img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
+    }
 })
+
+let txt2img_textarea, img2img_textarea = undefined;
+let wait_time = 800
+let token_timeout;
+
+function update_token_counter(button_id) {
+    if (token_timeout)
+        clearTimeout(token_timeout);
+    token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
+}
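The `update_token_counter` helper above is a standard debounce: every keystroke cancels the pending timer and schedules a new one, so the hidden token-count button is clicked only after the user pauses typing for `wait_time` (800 ms). A minimal Python sketch of the same pattern, using `threading.Timer` (all names here are illustrative, not part of the repository):

```python
import threading

WAIT_SECONDS = 0.8  # mirrors wait_time = 800 ms in the JavaScript above

_pending_timer = None

def debounced(callback):
    """Cancel any pending call and schedule a fresh one; only the last survives."""
    global _pending_timer
    if _pending_timer is not None:
        _pending_timer.cancel()
    _pending_timer = threading.Timer(WAIT_SECONDS, callback)
    _pending_timer.start()

# Usage: call debounced(...) on every input event; the callback fires once,
# 0.8 s after the final event in a burst.
debounced(lambda: print("update token counter"))
```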
@@ -15,11 +15,11 @@ torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113
 requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
 commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
-k_diffusion_package = os.environ.get('K_DIFFUSION_PACKAGE', "git+https://github.com/crowsonkb/k-diffusion.git@1a0703dfb7d24d8806267c3e7ccc4caf67fd1331")
 gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
 stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
 taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
+k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "a7ec1974d4ccb394c2dca275f42cd97490618924")
 codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
 blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
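Because every pinned package and commit hash above is read through `os.environ.get`, a build can override them without editing launch.py. A hedged sketch (the hash below is a placeholder, not a real commit):

```python
import os
import subprocess

# Pin k-diffusion to a different revision and skip the CUDA check,
# using the exact variable names launch.py reads above.
os.environ["K_DIFFUSION_COMMIT_HASH"] = "0123456789abcdef0123456789abcdef01234567"  # placeholder
os.environ["COMMANDLINE_ARGS"] = "--skip-torch-cuda-test"

subprocess.run(["python", "launch.py"], check=True)
```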
@@ -107,10 +107,7 @@ if not is_installed("torch") or not is_installed("torchvision"):
     run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")

 if not skip_torch_cuda_test:
-    run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDINE_ARGS variable to disable this check'")
+    run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")

-if not is_installed("k_diffusion.sampling"):
-    run_pip(f"install {k_diffusion_package}", "k-diffusion")
-
 if not is_installed("gfpgan"):
     run_pip(f"install {gfpgan_package}", "gfpgan")
@@ -119,6 +116,7 @@ os.makedirs(dir_repos, exist_ok=True)
 git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
 git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
+git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
 git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
 git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)

 if os.path.isdir(repo_dir('latent-diffusion')):
@@ -133,6 +131,9 @@ run_pip(f"install -r {requirements_file}", "requirements for Web UI")
 sys.argv += args

+if "--exit" in args:
+    print("Exiting because of --exit argument")
+    exit(0)
+
 def start_webui():
     print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}")
...
@@ -6,13 +6,14 @@ from PIL import Image
 import torch
 import tqdm

-from modules import processing, shared, images, devices
+from modules import processing, shared, images, devices, sd_models
 from modules.shared import opts
 import modules.gfpgan_model
 from modules.ui import plaintext_to_html
 import modules.codeformer_model
 import piexif
 import piexif.helper
+import gradio as gr

 cached_images = {}
@@ -141,7 +142,7 @@ def run_pnginfo(image):
     return '', geninfo, info


-def run_modelmerger(modelname_0, modelname_1, interp_method, interp_amount):
+def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name):
     # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
     def weighted_sum(theta0, theta1, alpha):
         return ((1 - alpha) * theta0) + (alpha * theta1)
@@ -151,45 +152,52 @@ def run_modelmerger(modelname_0, modelname_1, interp_method, interp_amount):
         alpha = alpha * alpha * (3 - (2 * alpha))
         return theta0 + ((theta1 - theta0) * alpha)

-    if os.path.exists(modelname_0):
-        model0_filename = modelname_0
-        modelname_0 = os.path.splitext(os.path.basename(modelname_0))[0]
-    else:
-        model0_filename = 'models/' + modelname_0 + '.ckpt'
-
-    if os.path.exists(modelname_1):
-        model1_filename = modelname_1
-        modelname_1 = os.path.splitext(os.path.basename(modelname_1))[0]
-    else:
-        model1_filename = 'models/' + modelname_1 + '.ckpt'
+    # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
+    def inv_sigmoid(theta0, theta1, alpha):
+        import math
+        alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
+        return theta0 + ((theta1 - theta0) * alpha)
+
+    primary_model_info = sd_models.checkpoints_list[primary_model_name]
+    secondary_model_info = sd_models.checkpoints_list[secondary_model_name]

-    print(f"Loading {model0_filename}...")
-    model_0 = torch.load(model0_filename, map_location='cpu')
-
-    print(f"Loading {model1_filename}...")
-    model_1 = torch.load(model1_filename, map_location='cpu')
+    print(f"Loading {primary_model_info.filename}...")
+    primary_model = torch.load(primary_model_info.filename, map_location='cpu')
+
+    print(f"Loading {secondary_model_info.filename}...")
+    secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')

-    theta_0 = model_0['state_dict']
-    theta_1 = model_1['state_dict']
+    theta_0 = primary_model['state_dict']
+    theta_1 = secondary_model['state_dict']

     theta_funcs = {
         "Weighted Sum": weighted_sum,
         "Sigmoid": sigmoid,
+        "Inverse Sigmoid": inv_sigmoid,
     }
     theta_func = theta_funcs[interp_method]

     print(f"Merging...")
     for key in tqdm.tqdm(theta_0.keys()):
         if 'model' in key and key in theta_1:
-            theta_0[key] = theta_func(theta_0[key], theta_1[key], interp_amount)
+            theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount))  # Need to reverse the interp_amount to match the desired mix ratio in the merged checkpoint
+            if save_as_half:
+                theta_0[key] = theta_0[key].half()

     for key in theta_1.keys():
         if 'model' in key and key not in theta_0:
             theta_0[key] = theta_1[key]
+            if save_as_half:
+                theta_0[key] = theta_0[key].half()
+
+    filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
+    filename = filename if custom_name == '' else (custom_name + '.ckpt')
+    output_modelname = os.path.join(shared.cmd_opts.ckpt_dir, filename)

-    output_modelname = 'models/' + modelname_0 + '-' + modelname_1 + '-merged.ckpt'
     print(f"Saving to {output_modelname}...")
-    torch.save(model_0, output_modelname)
+    torch.save(primary_model, output_modelname)
+
+    sd_models.list_models()

     print(f"Checkpoint saved.")
-    return "Checkpoint saved to " + output_modelname
+    return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)]
@@ -124,4 +124,4 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
     if opts.samples_log_stdout:
         print(generation_info_js)

-    return processed.images, generation_info_js, plaintext_to_html(processed.info)
\ No newline at end of file
+    return processed.images, generation_info_js, plaintext_to_html(processed.info)
@@ -21,6 +21,7 @@ path_dirs = [
     (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer'),
     (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP'),
     (os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR'),
+    (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion'),
 ]

 paths = {}
...
@@ -49,7 +49,7 @@ def apply_color_correction(correction, image):

 class StableDiffusionProcessing:
-    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
+    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
         self.sd_model = sd_model
         self.outpath_samples: str = outpath_samples
         self.outpath_grids: str = outpath_grids
@@ -75,15 +75,15 @@ class StableDiffusionProcessing:
         self.do_not_save_grid: bool = do_not_save_grid
         self.extra_generation_params: dict = extra_generation_params or {}
         self.overlay_images = overlay_images
+        self.eta = eta
         self.paste_to = None
         self.color_corrections = None
         self.denoising_strength: float = 0
-        self.ddim_eta = opts.ddim_eta
         self.ddim_discretize = opts.ddim_discretize
         self.s_churn = opts.s_churn
         self.s_tmin = opts.s_tmin
         self.s_tmax = float('inf')  # not representable as a standard ui option
         self.s_noise = opts.s_noise

         if not seed_enable_extras:
@@ -100,7 +100,7 @@ class StableDiffusionProcessing:

 class Processed:
-    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0):
+    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
         self.images = images_list
         self.prompt = p.prompt
         self.negative_prompt = p.negative_prompt
@@ -124,7 +124,7 @@ class Processed:
         self.extra_generation_params = p.extra_generation_params
         self.index_of_first_image = index_of_first_image
-        self.ddim_eta = p.ddim_eta
+        self.eta = p.eta
         self.ddim_discretize = p.ddim_discretize
         self.s_churn = p.s_churn
         self.s_tmin = p.s_tmin
@@ -139,6 +139,7 @@ class Processed:
         self.all_prompts = all_prompts or [self.prompt]
         self.all_seeds = all_seeds or [self.seed]
         self.all_subseeds = all_subseeds or [self.subseed]
+        self.infotexts = infotexts or [info]

     def js(self):
         obj = {
@@ -165,6 +166,7 @@ class Processed:
             "denoising_strength": self.denoising_strength,
             "extra_generation_params": self.extra_generation_params,
             "index_of_first_image": self.index_of_first_image,
+            "infotexts": self.infotexts,
         }

         return json.dumps(obj)
@@ -269,6 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
         "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
         "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
         "Denoising strength": getattr(p, 'denoising_strength', None),
+        "Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
     }

     generation_params.update(p.extra_generation_params)
@@ -277,7 +280,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
     negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""

-    return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])
+    return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()


 def process_images(p: StableDiffusionProcessing) -> Processed:
@@ -322,6 +325,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
         if os.path.exists(cmd_opts.embeddings_dir):
             model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)

+    infotexts = []
     output_images = []
     precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
     ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
@@ -404,6 +408,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
                 if opts.samples_save and not p.do_not_save_samples:
                     images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)

+                infotexts.append(infotext(n, i))
                 output_images.append(image)

             state.nextjob()
@@ -416,6 +421,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
             grid = images.image_grid(output_images, p.batch_size)

             if opts.return_grid:
+                infotexts.insert(0, infotext())
                 output_images.insert(0, grid)
                 index_of_first_image = 1

@@ -423,7 +429,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
                 images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)

     devices.torch_gc()
-    return Processed(p, output_images, all_seeds[0], infotext(), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image)
+    return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)


 class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
...
@@ -126,5 +126,93 @@ def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
     return res

-
-#get_learned_conditioning_prompt_schedules(["fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"], 100)
+
+re_attention = re.compile(r"""
+\\\(|
+\\\)|
+\\\[|
+\\]|
+\\\\|
+\\|
+\(|
+\[|
+:([+-]?[.\d]+)\)|
+\)|
+]|
+[^\\()\[\]:]+|
+:
+""", re.X)
+
+
+def parse_prompt_attention(text):
+    """
+    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
+    Accepted tokens are:
+      (abc) - increases attention to abc by a multiplier of 1.1
+      (abc:3.12) - increases attention to abc by a multiplier of 3.12
+      [abc] - decreases attention to abc by a multiplier of 1.1
+      \( - literal character '('
+      \[ - literal character '['
+      \) - literal character ')'
+      \] - literal character ']'
+      \\ - literal character '\'
+      anything else - just text
+
+    Example:
+        'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).'
+    produces:
+        [
+            ['a ', 1.0],
+            ['house', 1.5730000000000004],
+            [' ', 1.1],
+            ['on', 1.0],
+            [' a ', 1.1],
+            ['hill', 0.55],
+            [', sun, ', 1.1],
+            ['sky', 1.4641000000000006],
+            ['.', 1.1]
+        ]
+    """
+
+    res = []
+    round_brackets = []
+    square_brackets = []
+
+    round_bracket_multiplier = 1.1
+    square_bracket_multiplier = 1 / 1.1
+
+    def multiply_range(start_position, multiplier):
+        for p in range(start_position, len(res)):
+            res[p][1] *= multiplier
+
+    for m in re_attention.finditer(text):
+        text = m.group(0)
+        weight = m.group(1)
+
+        if text.startswith('\\'):
+            res.append([text[1:], 1.0])
+        elif text == '(':
+            round_brackets.append(len(res))
+        elif text == '[':
+            square_brackets.append(len(res))
+        elif weight is not None and len(round_brackets) > 0:
+            multiply_range(round_brackets.pop(), float(weight))
+        elif text == ')' and len(round_brackets) > 0:
+            multiply_range(round_brackets.pop(), round_bracket_multiplier)
+        elif text == ']' and len(square_brackets) > 0:
+            multiply_range(square_brackets.pop(), square_bracket_multiplier)
+        else:
+            res.append([text, 1.0])
+
+    for pos in round_brackets:
+        multiply_range(pos, round_bracket_multiplier)
+
+    for pos in square_brackets:
+        multiply_range(pos, square_bracket_multiplier)
+
+    if len(res) == 0:
+        res = [["", 1.0]]
+
+    return res
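The docstring's example can be checked directly; a short usage sketch (assumes the webui root is on sys.path so `modules.prompt_parser` imports):

```python
from modules import prompt_parser

pairs = prompt_parser.parse_prompt_attention("a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).")
for fragment, weight in pairs:
    print(f"{weight:.4f}  {fragment!r}")

# 'house' sits inside two plain '(' groups plus an explicit (:1.3) group,
# so its final weight is 1.1 * 1.1 * 1.3 = 1.573; '[on]' gets 1/1.1 from its
# square brackets but 1.1 back from the enclosing '(', netting 1.0.
```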
@@ -55,7 +55,7 @@ def load_scripts(basedir):
     if not os.path.exists(basedir):
         return

-    for filename in os.listdir(basedir):
+    for filename in sorted(os.listdir(basedir)):
         path = os.path.join(basedir, filename)

         if not os.path.isfile(path):
...
@@ -6,6 +6,7 @@ import torch
 import numpy as np
 from torch import einsum

+from modules import prompt_parser
 from modules.shared import opts, device, cmd_opts
 from ldm.util import default
@@ -180,6 +181,7 @@ class StableDiffusionModelHijack:
     dir_mtime = None
     layers = None
     circular_enabled = False
+    clip = None

     def load_textual_inversion_embeddings(self, dirname, model):
         mt = os.path.getmtime(dirname)
@@ -210,6 +212,7 @@ class StableDiffusionModelHijack:
                     param_dict = getattr(param_dict, '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
                     assert len(param_dict) == 1, 'embedding file has multiple terms in it'
                     emb = next(iter(param_dict.items()))[1]
+                # diffuser concepts
                 elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
                     assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
@@ -235,7 +238,7 @@ class StableDiffusionModelHijack:
                 print(traceback.format_exc(), file=sys.stderr)
                 continue

-        print(f"Loaded a total of {len(self.word_embeddings)} text inversion embeddings.")
+        print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")

     def hijack(self, m):
         model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
@@ -243,6 +246,8 @@ class StableDiffusionModelHijack:
         model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
         m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

+        self.clip = m.cond_stage_model
+
         if cmd_opts.opt_split_attention_v1:
             ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
         elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
@@ -259,6 +264,14 @@ class StableDiffusionModelHijack:
         self.layers = flatten(m)

+    def undo_hijack(self, m):
+        if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
+            m.cond_stage_model = m.cond_stage_model.wrapped
+
+        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
+        if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
+            model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
+
     def apply_circular(self, enable):
         if self.circular_enabled == enable:
             return
@@ -268,6 +281,11 @@ class StableDiffusionModelHijack:
         for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
             layer.padding_mode = 'circular' if enable else 'zeros'

+    def tokenize(self, text):
+        max_length = self.clip.max_length - 2
+        _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
+        return remade_batch_tokens[0], token_count, max_length
+

 class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
     def __init__(self, wrapped, hijack):
@@ -294,14 +312,101 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             if mult != 1.0:
                 self.token_mults[ident] = mult

-    def forward(self, text):
-        self.hijack.fixes = []
-        self.hijack.comments = []
+    def tokenize_line(self, line, used_custom_terms, hijack_comments):
+        id_start = self.wrapped.tokenizer.bos_token_id
+        id_end = self.wrapped.tokenizer.eos_token_id
+        maxlen = self.wrapped.max_length
+
+        if opts.enable_emphasis:
+            parsed = prompt_parser.parse_prompt_attention(line)
+        else:
+            parsed = [[line, 1.0]]
+
+        tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
+
+        fixes = []
+        remade_tokens = []
+        multipliers = []
+
+        for tokens, (text, weight) in zip(tokenized, parsed):
+            i = 0
+            while i < len(tokens):
+                token = tokens[i]
+
+                possible_matches = self.hijack.ids_lookup.get(token, None)
+
+                if possible_matches is None:
+                    remade_tokens.append(token)
+                    multipliers.append(weight)
+                else:
+                    found = False
+                    for ids, word in possible_matches:
+                        if tokens[i:i + len(ids)] == ids:
+                            emb_len = int(self.hijack.word_embeddings[word].shape[0])
+                            fixes.append((len(remade_tokens), word))
+                            remade_tokens += [0] * emb_len
+                            multipliers += [weight] * emb_len
+                            i += len(ids) - 1
+                            found = True
+                            used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
+                            break
+
+                    if not found:
+                        remade_tokens.append(token)
+                        multipliers.append(weight)
+
+                i += 1
+
+        if len(remade_tokens) > maxlen - 2:
+            vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
+            ovf = remade_tokens[maxlen - 2:]
+            overflowing_words = [vocab.get(int(x), "") for x in ovf]
+            overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+            hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
+        token_count = len(remade_tokens)
+        remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
+        remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
+
+        multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
+        multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
+
+        return remade_tokens, fixes, multipliers, token_count
+
+    def process_text(self, texts):
+        used_custom_terms = []
         remade_batch_tokens = []
+        hijack_comments = []
+        hijack_fixes = []
+        token_count = 0
+
+        cache = {}
+        batch_multipliers = []
+        for line in texts:
+            if line in cache:
+                remade_tokens, fixes, multipliers = cache[line]
+            else:
+                remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+
+                cache[line] = (remade_tokens, fixes, multipliers)
+
+            remade_batch_tokens.append(remade_tokens)
+            hijack_fixes.append(fixes)
+            batch_multipliers.append(multipliers)
+
+        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+
+    def process_text_old(self, text):
         id_start = self.wrapped.tokenizer.bos_token_id
         id_end = self.wrapped.tokenizer.eos_token_id
         maxlen = self.wrapped.max_length
         used_custom_terms = []
+        remade_batch_tokens = []
+        overflowing_words = []
+        hijack_comments = []
+        hijack_fixes = []
+        token_count = 0

         cache = {}
         batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
@@ -353,9 +458,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                     ovf = remade_tokens[maxlen - 2:]
                     overflowing_words = [vocab.get(int(x), "") for x in ovf]
                     overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+                    hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")

-                    self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-
+                token_count = len(remade_tokens)
                 remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
                 remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
                 cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
@@ -364,11 +468,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]

             remade_batch_tokens.append(remade_tokens)
-            self.hijack.fixes.append(fixes)
+            hijack_fixes.append(fixes)
             batch_multipliers.append(multipliers)

+        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+
+    def forward(self, text):
+        if opts.use_old_emphasis_implementation:
+            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
+        else:
+            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
+
+        self.hijack.fixes = hijack_fixes
+        self.hijack.comments = hijack_comments
+
         if len(used_custom_terms) > 0:
-            self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+            self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))

         tokens = torch.asarray(remade_batch_tokens).to(device)
         outputs = self.wrapped.transformer(input_ids=tokens)
...
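Both `tokenize_line` and `process_text_old` end by padding the prompt into CLIP's fixed context window: `maxlen` is 77, so up to 75 prompt tokens are framed by BOS and EOS, with the EOS id doubling as padding. A standalone sketch of that padding arithmetic (the prompt token ids are made up for illustration; 49406/49407 are CLIP's usual BOS/EOS ids):

```python
id_start, id_end, maxlen = 49406, 49407, 77  # CLIP BOS id, EOS id, context length

remade_tokens = [320, 1125, 539]   # made-up prompt token ids
multipliers = [1.0, 1.1, 1.1]      # per-token attention weights

token_count = len(remade_tokens)
# Pad with EOS up to 75 entries, then frame with BOS ... EOS (same lines as above).
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]

multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]

assert len(remade_tokens) == len(multipliers) == maxlen  # always 77
```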
@@ -15,8 +15,9 @@ model_dir = "Stable-diffusion"
 model_path = os.path.join(models_path, model_dir)
 model_name = "sd-v1-4.ckpt"
 model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1"
+user_dir = None

-CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash'])
+CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
 checkpoints_list = {}

 try:
@@ -47,23 +48,56 @@ def setup_model(dirname):
     global model_path
     global model_name
     global model_url
+    global user_dir
+    global model_list
+    user_dir = dirname
     if not os.path.exists(model_path):
         os.makedirs(model_path)

     checkpoints_list.clear()
-    model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=dirname, download_name=model_name, ext_filter=".ckpt")
+    list_models()
+
+
+def checkpoint_tiles():
+    return sorted([x.title for x in checkpoints_list.values()])
+
+
+def list_models():
+    global model_path
+    global model_url
+    global model_name
+    global user_dir
+    checkpoints_list.clear()
+    model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=user_dir,
+                                         ext_filter=[".ckpt"], download_name=model_name)
+    print(f"Model list: {model_list}")
+    model_dir = os.path.abspath(model_path)
+
+    def modeltitle(path, h):
+        abspath = os.path.abspath(path)
+
+        if abspath.startswith(model_dir):
+            name = abspath.replace(model_dir, '')
+        else:
+            name = os.path.basename(path)
+
+        if name.startswith("\\") or name.startswith("/"):
+            name = name[1:]
+
+        shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
+
+        return f'{name} [{h}]', shortname

     cmd_ckpt = shared.cmd_opts.ckpt
     if os.path.exists(cmd_ckpt):
         h = model_hash(cmd_ckpt)
-        title = modeltitle(cmd_ckpt, h)
-        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h)
+        title, model_name = modeltitle(cmd_ckpt, h)
+        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name)
     elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
         print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}): {cmd_ckpt}", file=sys.stderr)
     for filename in model_list:
         h = model_hash(filename)
         title = modeltitle(filename, h)
-        checkpoints_list[title] = CheckpointInfo(filename, title, h)
+        checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name)


 def model_hash(filename):
@@ -89,7 +123,7 @@ def select_checkpoint():
     if len(checkpoints_list) == 0:
         print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
         print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
-        print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
+        print(f" - directory {os.path.abspath(shared.cmd_opts.stablediffusion_models_path)}", file=sys.stderr)
         print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
         exit(1)
@@ -142,7 +176,7 @@ def load_model():

 def reload_model_weights(sd_model, info=None):
-    from modules import lowvram, devices
+    from modules import lowvram, devices, sd_hijack

     checkpoint_info = info or select_checkpoint()

     if sd_model.sd_model_checkpint == checkpoint_info.filename:
@@ -153,8 +187,12 @@ def reload_model_weights(sd_model, info=None):
     else:
         sd_model.to(devices.cpu)

+    sd_hijack.model_hijack.undo_hijack(sd_model)
+
     load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)

+    sd_hijack.model_hijack.hijack(sd_model)
+
     if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
         sd_model.to(devices.device)
...
@@ -155,6 +155,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
     "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
     "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
+    "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
 }))

 options_templates.update(options_section(('saving-paths', "Paths for saving"), {
@@ -182,7 +183,6 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
     "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
     "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
     "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
-    "ldsr_pre_down": OptionInfo(1, "LDSR Pre-process downssample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}),
     "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}),
 }))

@@ -190,7 +190,6 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
     "face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
     "code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
     "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
-    "save_selected_only": OptionInfo(False, "When using 'Save' button, only save a single selected image"),
 }))

 options_templates.update(options_section(('system', "System"), {
@@ -200,12 +199,13 @@ options_templates.update(options_section(('system', "System"), {
 }))

 options_templates.update(options_section(('sd', "Stable Diffusion"), {
-    "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": [x.title for x in modules.sd_models.checkpoints_list.values()]}),
+    "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": modules.sd_models.checkpoint_tiles()}),
     "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
     "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
     "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
     "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
-    "enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text and [text] to make it pay less attention"),
+    "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
+    "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
     "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
     "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
     "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
@@ -231,8 +231,9 @@ options_templates.update(options_section(('ui', "User interface"), {
 }))

 options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
-    "ddim_eta": OptionInfo(0.0, "DDIM eta", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
-    "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform','quad']}),
+    "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+    "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+    "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
     's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
     's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
     's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
...
@@ -4,9 +4,8 @@ fairscale==0.4.4
 fonts
 font-roboto
 gfpgan
-gradio
+gradio==3.4b3
 invisible-watermark
-git+https://github.com/crowsonkb/k-diffusion.git
 numpy
 omegaconf
 piexif
@@ -16,5 +15,12 @@ realesrgan
 scikit-image>=0.19
 git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379
 timm==0.4.12
-transformers
+transformers==4.19.2
 torch
+einops
+jsonmerge
+clean-fid
+git+https://github.com/openai/CLIP@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
+resize-right
+torchdiffeq
+kornia
@@ -14,4 +14,11 @@ fonts
 font-roboto
 timm==0.6.7
 fairscale==0.4.9
-piexif==1.1.3
\ No newline at end of file
+piexif==1.1.3
+einops==0.4.1
+jsonmerge==1.8.0
+clean-fid==0.1.29
+git+https://github.com/openai/CLIP@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
+resize-right==0.0.2
+torchdiffeq==0.2.3
+kornia==0.6.7
@@ -87,12 +87,12 @@ axis_options = [
     AxisOption("Prompt S/R", str, apply_prompt, format_value),
     AxisOption("Sampler", str, apply_sampler, format_value),
     AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
     AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label),
     AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label),
     AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
     AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
-    AxisOption("DDIM Eta", float, apply_field("ddim_eta"), format_value_add_label),
-    AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label),# as it is now all AxisOptionImg2Img items must go after AxisOption ones
+    AxisOption("Eta", float, apply_field("eta"), format_value_add_label),
+    AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label),  # as it is now all AxisOptionImg2Img items must go after AxisOption ones
 ]
@@ -159,6 +159,9 @@ class Script(scripts.Script):
             p.batch_size = 1

         def process_axis(opt, vals):
+            if opt.label == 'Nothing':
+                return [0]
+
             valslist = [x.strip() for x in vals.split(",")]

             if opt.type == int:
...
 .output-html p {margin: 0 0.5em;}

+.row > *,
+.row > .gr-form > * {
+    min-width: min(120px, 100%);
+    flex: 1 1 0%;
+}
+
 .performance {
     font-size: 0.85em;
     color: #444;
@@ -43,13 +49,17 @@
     margin-right: auto;
 }

-#random_seed, #random_subseed, #reuse_seed, #reuse_subseed{
+#random_seed, #random_subseed, #reuse_seed, #reuse_subseed, #open_folder{
     min-width: auto;
     flex-grow: 0;
     padding-left: 0.25em;
     padding-right: 0.25em;
 }

+#hidden_element{
+    display: none;
+}
+
 #seed_row, #subseed_row{
     gap: 0.5rem;
 }
@@ -389,3 +399,7 @@ input[type="range"]{
     border-radius: 8px;
     display: none;
 }
+
+.red {
+    color: red;
+}
 import os
+import threading
+
+from modules import devices
+from modules.paths import script_path
+
 import signal
 import threading

 import modules.paths
@@ -44,6 +48,8 @@ def wrap_queued_call(func):

 def wrap_gradio_gpu_call(func):
     def f(*args, **kwargs):
+        devices.torch_gc()
+
         shared.state.sampling_step = 0
         shared.state.job_count = -1
         shared.state.job_no = 0
@@ -59,6 +65,8 @@ def wrap_gradio_gpu_call(func):
         shared.state.job = ""
         shared.state.job_count = 0

+        devices.torch_gc()
+
         return res

     return modules.ui.wrap_gradio_call(f)
...
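The `devices.torch_gc()` calls added above flush CUDA's cached memory before and after every queued GPU job, so a failed or finished generation does not keep VRAM pinned. A rough sketch of the pattern (`torch_gc` here is a stand-in for the real helper in modules/devices.py):

```python
import torch

def torch_gc():
    # Release cached CUDA memory; a stand-in for modules.devices.torch_gc.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

def wrap_gpu_call(func):
    def f(*args, **kwargs):
        torch_gc()                 # flush before the job
        try:
            return func(*args, **kwargs)
        finally:
            torch_gc()             # and after, even on failure
    return f
```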