Commit 3bfe2bb5 authored by brkirch

Merge remote-tracking branch 'upstream/master' into sub-quad_attn_opt

parents f6ab5a39 5f4fa942
@@ -9,11 +9,19 @@ function dropReplaceImage( imgWrap, files ) {
         return;
     }
+    const tmpFile = files[0];

     imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
     const callback = () => {
         const fileInput = imgWrap.querySelector('input[type="file"]');
         if ( fileInput ) {
-            fileInput.files = files;
+            if ( files.length === 0 ) {
+                files = new DataTransfer();
+                files.items.add(tmpFile);
+                fileInput.files = files.files;
+            } else {
+                fileInput.files = files;
+            }
             fileInput.dispatchEvent(new Event('change'));
         }
     };
...
@@ -81,9 +81,6 @@ titles = {
     "vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",

-    "Highres. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition",
-    "Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.",
-
     "Eta noise seed delta": "If this value is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
     "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",
@@ -100,7 +97,13 @@ titles = {
     "Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",

     "Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resolution and lower quality.",
-    "Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resolution and extremely low quality."
+    "Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resolution and extremely low quality.",
+
+    "Hires. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition",
+    "Hires steps": "Number of sampling steps for upscaled picture. If 0, uses same as for original.",
+    "Upscale by": "Adjusts the size of the image by multiplying the original width and height by the selected value. Ignored if either Resize width to or Resize height to are non-zero.",
+    "Resize width to": "Resizes image to this width. If 0, width is inferred from either of two nearby sliders.",
+    "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders."
 }
...
@@ -148,8 +148,8 @@ function showGalleryImage() {
     if(e && e.parentElement.tagName == 'DIV'){
         e.style.cursor='pointer'
         e.style.userSelect='none'
-        e.addEventListener('click', function (evt) {
-            if(!opts.js_modal_lightbox) return;
+        e.addEventListener('mousedown', function (evt) {
+            if(!opts.js_modal_lightbox || evt.button != 0) return;
             modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
             showModal(evt)
         }, true);
...
-// various functions for interation with ui.py not large enough to warrant putting them in separate files
+// various functions for interaction with ui.py not large enough to warrant putting them in separate files

 function set_theme(theme){
     gradioURL = window.location.href
...
 import base64
 import io
 import time
+import datetime
 import uvicorn
 from threading import Lock
 from io import BytesIO
 from gradio.processing_utils import decode_base64_to_file
-from fastapi import APIRouter, Depends, FastAPI, HTTPException
+from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
 from fastapi.security import HTTPBasic, HTTPBasicCredentials
 from secrets import compare_digest
@@ -18,7 +19,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
 from modules.textual_inversion.preprocess import preprocess
 from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
 from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list
+from modules.sd_models import checkpoints_list, find_checkpoint_config
 from modules.realesrgan_model import get_realesrgan_models
 from modules import devices
 from typing import List
@@ -67,6 +68,27 @@ def encode_pil_to_base64(image):
     bytes_data = output_bytes.getvalue()
     return base64.b64encode(bytes_data)

+def api_middleware(app: FastAPI):
+    @app.middleware("http")
+    async def log_and_time(req: Request, call_next):
+        ts = time.time()
+        res: Response = await call_next(req)
+        duration = str(round(time.time() - ts, 4))
+        res.headers["X-Process-Time"] = duration
+        endpoint = req.scope.get('path', 'err')
+        if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
+            print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
+                t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
+                code = res.status_code,
+                ver = req.scope.get('http_version', '0.0'),
+                cli = req.scope.get('client', ('0:0.0.0', 0))[0],
+                prot = req.scope.get('scheme', 'err'),
+                method = req.scope.get('method', 'err'),
+                endpoint = endpoint,
+                duration = duration,
+            ))
+        return res

 class Api:
     def __init__(self, app: FastAPI, queue_lock: Lock):
@@ -79,6 +101,7 @@ class Api:
         self.router = APIRouter()
         self.app = app
         self.queue_lock = queue_lock
+        api_middleware(self.app)
         self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
         self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
         self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
@@ -303,7 +326,7 @@ class Api:
         return upscalers

     def get_sd_models(self):
-        return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
+        return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]

     def get_hypernetworks(self):
         return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
...
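Note on the middleware hunk above: FastAPI's @app.middleware("http") wraps every request, so timing a request is a matter of reading the clock before and after awaiting call_next. A minimal, self-contained sketch of the same pattern (the /ping route is illustrative, not from the repository):

    import time
    from fastapi import FastAPI, Request

    app = FastAPI()

    @app.middleware("http")
    async def add_process_time_header(request: Request, call_next):
        start = time.time()                  # clock in before the handler runs
        response = await call_next(request)  # run the actual endpoint
        response.headers["X-Process-Time"] = str(round(time.time() - start, 4))
        return response

    @app.get("/ping")
    async def ping():
        return {"ok": True}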
@@ -2,9 +2,30 @@ import sys
 import traceback

+def print_error_explanation(message):
+    lines = message.strip().split("\n")
+    max_len = max([len(x) for x in lines])
+
+    print('=' * max_len, file=sys.stderr)
+    for line in lines:
+        print(line, file=sys.stderr)
+    print('=' * max_len, file=sys.stderr)
+
+def display(e: Exception, task):
+    print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
+    print(traceback.format_exc(), file=sys.stderr)
+
+    message = str(e)
+    if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
+        print_error_explanation("""
+The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
+See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
+        """)

 def run(code, task):
     try:
         code()
     except Exception as e:
-        print(f"{task}: {type(e).__name__}", file=sys.stderr)
-        print(traceback.format_exc(), file=sys.stderr)
+        display(e, task)
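Note: display takes the exception first and the task name second, so the wrapper can be exercised in isolation. A self-contained sketch of the same run/display shape:

    import sys
    import traceback

    def display(e: Exception, task):
        print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)

    def run(code, task):
        # run a zero-argument callable, reporting any exception instead of raising
        try:
            code()
        except Exception as e:
            display(e, task)

    run(lambda: 1 / 0, "dividing by zero")  # logs ZeroDivisionError, does not crash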
@@ -19,8 +19,6 @@ from modules.shared import opts
 import modules.gfpgan_model
 from modules.ui import plaintext_to_html
 import modules.codeformer_model
-import piexif
-import piexif.helper
 import gradio as gr
 import safetensors.torch
@@ -58,6 +56,9 @@ cached_images: LruCache = LruCache(max_size=5)
 def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
     devices.torch_gc()

+    shared.state.begin()
+    shared.state.job = 'extras'
+
     imageArr = []
     # Also keep track of original file names
     imageNameArr = []
@@ -94,6 +95,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
     # Extra operation definitions

     def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+        shared.state.job = 'extras-gfpgan'
         restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
         res = Image.fromarray(restored_img)
@@ -104,6 +106,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
         return (res, info)

     def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+        shared.state.job = 'extras-codeformer'
         restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
         res = Image.fromarray(restored_img)
@@ -114,6 +117,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
         return (res, info)

     def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
+        shared.state.job = 'extras-upscale'
         upscaler = shared.sd_upscalers[scaler_index]
         res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
         if mode == 1 and crop:
@@ -180,6 +184,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
     for image, image_name in zip(imageArr, imageNameArr):
         if image is None:
             return outputs, "Please select an input image.", ''
+
+        shared.state.textinfo = f'Processing image {image_name}'
+
         existing_pnginfo = image.info or {}
         image = image.convert("RGB")
@@ -193,6 +200,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
         else:
             basename = ''

+        if opts.enable_pnginfo: # append info before save
+            image.info = existing_pnginfo
+            image.info["extras"] = info
+
         if save_output:
             # Add upscaler name as a suffix.
             suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
@@ -203,10 +214,6 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
             images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
                               no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)

-        if opts.enable_pnginfo:
-            image.info = existing_pnginfo
-            image.info["extras"] = info
-
         if extras_mode != 2 or show_extras_results:
             outputs.append(image)
@@ -242,6 +249,9 @@ def run_pnginfo(image):

 def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
+    shared.state.begin()
+    shared.state.job = 'model-merge'
+
     def weighted_sum(theta0, theta1, alpha):
         return ((1 - alpha) * theta0) + (alpha * theta1)
@@ -263,8 +273,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
     theta_func1, theta_func2 = theta_funcs[interp_method]

     if theta_func1 and not tertiary_model_info:
+        shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
+        shared.state.end()
         return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]

+    shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
     print(f"Loading {secondary_model_info.filename}...")
     theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
@@ -281,6 +294,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
             theta_1[key] = torch.zeros_like(theta_1[key])
     del theta_2

+    shared.state.textinfo = f"Loading {primary_model_info.filename}..."
     print(f"Loading {primary_model_info.filename}...")
     theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
@@ -291,6 +305,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
             a = theta_0[key]
             b = theta_1[key]

+            shared.state.textinfo = f'Merging layer {key}'
             # this enables merging an inpainting model (A) with another one (B);
             # where normal model would have 4 channels, for latent space, inpainting model would
             # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
@@ -303,8 +318,6 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
                 theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
                 result_is_inpainting_model = True
             else:
-                assert a.shape == b.shape, f'Incompatible shapes for layer {key}: A is {a.shape}, and B is {b.shape}'
-
                 theta_0[key] = theta_func2(a, b, multiplier)

             if save_as_half:
@@ -332,6 +345,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
     output_modelname = os.path.join(ckpt_dir, filename)

+    shared.state.textinfo = f"Saving to {output_modelname}..."
     print(f"Saving to {output_modelname}...")

     _, extension = os.path.splitext(output_modelname)
@@ -343,4 +357,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
     sd_models.list_models()

     print("Checkpoint saved.")
+    shared.state.textinfo = "Checkpoint saved to " + output_modelname
+    shared.state.end()

     return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
@@ -212,11 +212,10 @@ def restore_old_hires_fix_params(res):
     firstpass_width = math.ceil(scale * width / 64) * 64
     firstpass_height = math.ceil(scale * height / 64) * 64

-    hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height
-
     res['Size-1'] = firstpass_width
     res['Size-2'] = firstpass_height
-    res['Hires upscale'] = hr_scale
+    res['Hires resize-1'] = width
+    res['Hires resize-2'] = height

 def parse_generation_parameters(x: str):
@@ -276,6 +275,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
     hypernet_hash = res.get("Hypernet hash", None)
     res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)

+    if "Hires resize-1" not in res:
+        res["Hires resize-1"] = 0
+        res["Hires resize-2"] = 0
+
     restore_old_hires_fix_params(res)

     return res
...
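Note on restore_old_hires_fix_params: old infotexts stored a first-pass scale rather than explicit resize targets, so the first-pass size is reconstructed by scaling and rounding up to a multiple of 64, as in the hunk above. A sketch of just that arithmetic:

    import math

    def firstpass_size(width, height, scale):
        # round the scaled size up to the next multiple of 64, as the hunk does
        firstpass_width = math.ceil(scale * width / 64) * 64
        firstpass_height = math.ceil(scale * height / 64) * 64
        return firstpass_width, firstpass_height

    print(firstpass_size(1024, 1024, 0.5))  # (512, 512)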
@@ -402,10 +402,8 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
     shared.reload_hypernetworks()

-    return fn
-
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
     # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
     from modules import images
@@ -417,6 +415,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
     shared.loaded_hypernetwork = Hypernetwork()
     shared.loaded_hypernetwork.load(path)

+    shared.state.job = "train-hypernetwork"
     shared.state.textinfo = "Initializing hypernetwork training..."
     shared.state.job_count = steps
@@ -447,6 +446,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
         return hypernetwork, filename

     scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
+
+    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
+    if clip_grad:
+        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)

     # dataset loading may take a while, so input validations and early returns should be done before this
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
@@ -465,7 +468,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
         shared.parallel_processing_allowed = False
         shared.sd_model.cond_stage_model.to(devices.cpu)
         shared.sd_model.first_stage_model.to(devices.cpu)

     weights = hypernetwork.weights()
     hypernetwork.train_mode()
@@ -524,6 +527,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
             if shared.state.interrupted:
                 break

+            if clip_grad:
+                clip_grad_sched.step(hypernetwork.step)
+
             with devices.autocast():
                 x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                 if tag_drop_out != 0 or shuffle_tags:
@@ -538,14 +544,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
                 _loss_step += loss.item()

             scaler.scale(loss).backward()

             # go back until we reach gradient accumulation steps
             if (j + 1) % gradient_step != 0:
                 continue
-            # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
-            # scaler.unscale_(optimizer)
-            # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
-            # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
-            # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
+
+            if clip_grad:
+                clip_grad(weights, clip_grad_sched.learn_rate)
+
             scaler.step(optimizer)
             scaler.update()
             hypernetwork.step += 1
...
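Note on the clip_grad changes: the mode string selects between torch.nn.utils.clip_grad_value_ and clip_grad_norm_, and the chosen function is applied to the weights after backward() but before the optimizer step. A self-contained sketch of that flow on a toy model:

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    clip_grad_mode = "norm"  # or "value"
    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None

    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    if clip_grad:
        clip_grad(model.parameters(), 1.0)  # clip gradients in place before stepping
    optimizer.step()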
@@ -136,7 +136,8 @@ class InterrogateModels:
     def interrogate(self, pil_image):
         res = ""
-
+        shared.state.begin()
+        shared.state.job = 'interrogate'
         try:
             if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -177,5 +178,6 @@ class InterrogateModels:
             res += "<error>"

         self.unload()
+        shared.state.end()

         return res
(diff for one file is collapsed and not shown here)
@@ -33,25 +33,34 @@ def apply_optimizations():
     ldm.modules.diffusionmodules.model.nonlinearity = silu
     ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

+    optimization_method = None
+
     if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
         print("Applying xformers cross attention optimization.")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
         ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
+        optimization_method = 'xformers'
     elif cmd_opts.opt_sub_quad_attention:
         print("Applying sub-quadratic cross attention optimization.")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
         ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
+        optimization_method = 'sub-quadratic'
     elif cmd_opts.opt_split_attention_v1:
         print("Applying v1 cross attention optimization.")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+        optimization_method = 'V1'
     elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
         print("Applying cross attention optimization (InvokeAI).")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
+        optimization_method = 'InvokeAI'
     elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
         print("Applying cross attention optimization (Doggettx).")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
         ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
+        optimization_method = 'Doggettx'
+
+    return optimization_method

 def undo_optimizations():
@@ -72,6 +81,7 @@ class StableDiffusionModelHijack:
     layers = None
     circular_enabled = False
     clip = None
+    optimization_method = None

     embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
@@ -91,7 +101,7 @@ class StableDiffusionModelHijack:
             m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
             m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

-        apply_optimizations()
+        self.optimization_method = apply_optimizations()

         self.clip = m.cond_stage_model
...
@@ -97,8 +97,11 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F

 def should_hijack_inpainting(checkpoint_info):
+    from modules import sd_models
+
     ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
-    cfg_basename = os.path.basename(checkpoint_info.config).lower()
+    cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()

     return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
...
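Note: should_hijack_inpainting now resolves the config through sd_models.find_checkpoint_config (added in the modules/sd_models.py hunk below) instead of reading a field on CheckpointInfo. The lookup itself is a simple sidecar-file convention; a standalone sketch (the default_config parameter is illustrative, the real code falls back to shared.cmd_opts.config):

    import os

    def find_checkpoint_config(checkpoint_filename, default_config):
        # prefer a .yaml next to the checkpoint; otherwise fall back to the default
        config = os.path.splitext(checkpoint_filename)[0] + ".yaml"
        if os.path.exists(config):
            return config
        return default_config

    print(find_checkpoint_config("models/sd21.safetensors", "configs/v1-inference.yaml"))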
@@ -20,7 +20,7 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp
 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(models_path, model_dir))

-CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
+CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
 checkpoints_list = {}
 checkpoints_loaded = collections.OrderedDict()
@@ -48,6 +48,14 @@ def checkpoint_tiles():
     return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)

+def find_checkpoint_config(info):
+    config = os.path.splitext(info.filename)[0] + ".yaml"
+    if os.path.exists(config):
+        return config
+
+    return shared.cmd_opts.config

 def list_models():
     checkpoints_list.clear()
     model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
@@ -73,7 +81,7 @@ def list_models():
     if os.path.exists(cmd_ckpt):
         h = model_hash(cmd_ckpt)
         title, short_model_name = modeltitle(cmd_ckpt, h)
-        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
+        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
         shared.opts.data['sd_model_checkpoint'] = title
     elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
         print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
@@ -81,12 +89,7 @@ def list_models():
         h = model_hash(filename)
         title, short_model_name = modeltitle(filename, h)

-        basename, _ = os.path.splitext(filename)
-        config = basename + ".yaml"
-        if not os.path.exists(config):
-            config = shared.cmd_opts.config
-
-        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
+        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)

 def get_closet_checkpoint_match(searchString):
@@ -168,7 +171,10 @@ def get_state_dict_from_checkpoint(pl_sd):
 def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
     _, extension = os.path.splitext(checkpoint_file)
     if extension.lower() == ".safetensors":
-        pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location)
+        device = map_location or shared.weight_load_location
+        if device is None:
+            device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
+        pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
     else:
         pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
@@ -278,12 +284,14 @@ def enable_midas_autodownload():
     midas.api.load_model = load_model_wrapper

 def load_model(checkpoint_info=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
+    checkpoint_config = find_checkpoint_config(checkpoint_info)

-    if checkpoint_info.config != shared.cmd_opts.config:
-        print(f"Loading config from: {checkpoint_info.config}")
+    if checkpoint_config != shared.cmd_opts.config:
+        print(f"Loading config from: {checkpoint_config}")

     if shared.sd_model:
         sd_hijack.model_hijack.undo_hijack(shared.sd_model)
@@ -291,7 +299,7 @@ def load_model(checkpoint_info=None):
     gc.collect()
     devices.torch_gc()

-    sd_config = OmegaConf.load(checkpoint_info.config)
+    sd_config = OmegaConf.load(checkpoint_config)

     if should_hijack_inpainting(checkpoint_info):
         # Hardcoded config for now...
@@ -300,9 +308,6 @@ def load_model(checkpoint_info=None):
         sd_config.model.params.unet_config.params.in_channels = 9
         sd_config.model.params.finetune_keys = None

-        # Create a "fake" config with a different name so that we know to unload it when switching models.
-        checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
-
     if not hasattr(sd_config.model.params, "use_ema"):
         sd_config.model.params.use_ema = False
@@ -312,6 +317,7 @@ def load_model(checkpoint_info=None):
         sd_config.model.params.unet_config.params.use_fp16 = False

     sd_model = instantiate_from_config(sd_config.model)
+
     load_model_weights(sd_model, checkpoint_info)

     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -336,14 +342,17 @@ def load_model(checkpoint_info=None):

 def reload_model_weights(sd_model=None, info=None):
     from modules import lowvram, devices, sd_hijack
     checkpoint_info = info or select_checkpoint()

     if not sd_model:
         sd_model = shared.sd_model

+    current_checkpoint_info = sd_model.sd_checkpoint_info
+    checkpoint_config = find_checkpoint_config(current_checkpoint_info)
+
     if sd_model.sd_model_checkpoint == checkpoint_info.filename:
         return

-    if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
+    if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
         del sd_model
         checkpoints_loaded.clear()
         load_model(checkpoint_info)
@@ -356,13 +365,19 @@ def reload_model_weights(sd_model=None, info=None):
         sd_hijack.model_hijack.undo_hijack(sd_model)

-    load_model_weights(sd_model, checkpoint_info)
-
-    sd_hijack.model_hijack.hijack(sd_model)
-    script_callbacks.model_loaded_callback(sd_model)
-
-    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
-        sd_model.to(devices.device)
+    try:
+        load_model_weights(sd_model, checkpoint_info)
+    except Exception as e:
+        print("Failed to load checkpoint, restoring previous")
+        load_model_weights(sd_model, current_checkpoint_info)
+        raise
+    finally:
+        sd_hijack.model_hijack.hijack(sd_model)
+        script_callbacks.model_loaded_callback(sd_model)
+
+        if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+            sd_model.to(devices.device)

     print("Weights loaded.")

     return sd_model
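Note on the reload_model_weights change: loading the new weights is wrapped in try/except/finally so that a failed load restores the previous checkpoint, re-raises, and still runs the post-load steps (hijack, callbacks, device move) either way. A toy sketch of that control flow (the dict "model" and its validation are illustrative, not the real loader):

    def load_weights(model, weights):
        if "w" not in weights:
            raise ValueError("checkpoint is missing key 'w'")
        model.update(weights)

    def reload(model, new_weights, current_weights):
        try:
            load_weights(model, new_weights)
        except Exception:
            print("Failed to load checkpoint, restoring previous")
            load_weights(model, current_weights)
            raise
        finally:
            model["post_load_done"] = True  # stands in for hijack + callbacks

    m = {"w": 1, "post_load_done": False}
    reload(m, {"w": 2}, {"w": 1})
    print(m)  # {'w': 2, 'post_load_done': True}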
@@ -97,8 +97,9 @@ sampler_extra_params = {
 def setup_img2img_steps(p, steps=None):
     if opts.img2img_fix_steps or steps is not None:
-        steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
-        t_enc = p.steps - 1
+        requested_steps = (steps or p.steps)
+        steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
+        t_enc = requested_steps - 1
     else:
         steps = p.steps
         t_enc = int(min(p.denoising_strength, 0.999) * steps)
...
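Note on setup_img2img_steps: the fix derives t_enc from the step count that was actually requested, rather than from p.steps, so overriding steps no longer desynchronizes the two values. The arithmetic in isolation:

    def setup_img2img_steps(requested_steps, denoising_strength):
        # run enough total steps that the denoised portion equals the request
        steps = int(requested_steps / min(denoising_strength, 0.999)) if denoising_strength > 0 else 0
        t_enc = requested_steps - 1
        return steps, t_enc

    print(setup_img2img_steps(20, 0.5))  # (40, 19): 40 total steps, encode at step 19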
@@ -14,7 +14,7 @@ import modules.interrogate
 import modules.memmon
 import modules.styles
 import modules.devices as devices
-from modules import localization, sd_vae, extensions, script_loading
+from modules import localization, sd_vae, extensions, script_loading, errors
 from modules.paths import models_path, script_path, sd_path
@@ -86,6 +86,7 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode
 parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
 parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
 parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
+parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
 parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
 parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
 parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
@@ -156,6 +157,7 @@ class State:
     job = ""
     job_no = 0
     job_count = 0
+    processing_has_refined_job_count = False
     job_timestamp = '0'
     sampling_step = 0
     sampling_steps = 0
@@ -186,6 +188,7 @@ class State:
         "interrupted": self.interrupted,
         "job": self.job,
         "job_count": self.job_count,
+        "job_timestamp": self.job_timestamp,
         "job_no": self.job_no,
         "sampling_step": self.sampling_step,
         "sampling_steps": self.sampling_steps,
@@ -196,6 +199,7 @@ class State:
     def begin(self):
         self.sampling_step = 0
         self.job_count = -1
+        self.processing_has_refined_job_count = False
         self.job_no = 0
         self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
         self.current_latent = None
@@ -216,12 +220,13 @@ class State:
     """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
     def set_current_image(self):
-        if not parallel_processing_allowed:
-            return
-
         if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
             self.do_set_current_image()

     def do_set_current_image(self):
+        if not parallel_processing_allowed:
+            return
+
         if self.current_latent is None:
             return
@@ -233,6 +238,7 @@ class State:
         self.current_image_sampling_step = self.sampling_step

 state = State()

 artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
@@ -359,7 +365,7 @@ options_templates.update(options_section(('system', "System"), {
 options_templates.update(options_section(('training', "Training"), {
     "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
     "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
-    "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
+    "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
     "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
     "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
     "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
@@ -498,7 +504,12 @@ class Options:
             return False

         if self.data_labels[key].onchange is not None:
-            self.data_labels[key].onchange()
+            try:
+                self.data_labels[key].onchange()
+            except Exception as e:
+                errors.display(e, f"changing setting {key} to {value}")
+                setattr(self, key, oldval)
+                return False

         return True
@@ -563,8 +574,11 @@ if os.path.exists(config_filename):
 latent_upscale_default_mode = "Latent"
 latent_upscale_modes = {
-    "Latent": "bilinear",
-    "Latent (nearest)": "nearest",
+    "Latent": {"mode": "bilinear", "antialias": False},
+    "Latent (antialiased)": {"mode": "bilinear", "antialias": True},
+    "Latent (bicubic)": {"mode": "bicubic", "antialias": False},
+    "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True},
+    "Latent (nearest)": {"mode": "nearest", "antialias": False},
 }

 sd_upscalers = []
@@ -600,7 +614,7 @@ class TotalTQDM:
             return
         if self._tqdm is None:
             self.reset()
-        self._tqdm.total=new_total
+        self._tqdm.total = new_total

     def clear(self):
         if self._tqdm is not None:
...
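Note on latent_upscale_modes: each entry now carries both an interpolation mode and an antialias flag, which map onto torch.nn.functional.interpolate keyword arguments. A sketch of how one entry would be consumed (the random latent is illustrative):

    import torch
    import torch.nn.functional as F

    latent = torch.randn(1, 4, 64, 64)  # NCHW latent tensor
    opts = {"mode": "bicubic", "antialias": False}  # one latent_upscale_modes entry
    upscaled = F.interpolate(latent, scale_factor=2, **opts)
    print(upscaled.shape)  # torch.Size([1, 4, 128, 128])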
@@ -58,14 +58,19 @@ class LearnRateScheduler:

         self.finished = False

-    def apply(self, optimizer, step_number):
+    def step(self, step_number):
         if step_number < self.end_step:
-            return
+            return False

         try:
             (self.learn_rate, self.end_step) = next(self.schedules)
-        except Exception:
+        except StopIteration:
             self.finished = True
+            return False
+
+        return True
+
+    def apply(self, optimizer, step_number):
+        if not self.step(step_number):
             return

         if self.verbose:
...
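Note on the LearnRateScheduler refactor: apply() is split so that step() advances the schedule and reports whether the rate changed, letting callers without an optimizer (such as the gradient-clip schedule above) reuse it. A minimal sketch of that shape:

    class Scheduler:
        def __init__(self, schedules):
            self.schedules = iter(schedules)  # [(learn_rate, end_step), ...]
            (self.learn_rate, self.end_step) = next(self.schedules)
            self.finished = False

        def step(self, step_number):
            # returns True only when the schedule advances to a new rate
            if step_number < self.end_step:
                return False
            try:
                (self.learn_rate, self.end_step) = next(self.schedules)
            except StopIteration:
                self.finished = True
                return False
            return True

    sched = Scheduler([(0.1, 100), (0.01, 200)])
    print(sched.step(50), sched.step(100), sched.learn_rate)  # False True 0.01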
@@ -124,6 +124,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
     files = listfiles(src)

+    shared.state.job = "preprocess"
     shared.state.textinfo = "Preprocessing..."
     shared.state.job_count = len(files)
...
...@@ -28,6 +28,7 @@ class Embedding: ...@@ -28,6 +28,7 @@ class Embedding:
self.cached_checksum = None self.cached_checksum = None
self.sd_checkpoint = None self.sd_checkpoint = None
self.sd_checkpoint_name = None self.sd_checkpoint_name = None
self.optimizer_state_dict = None
def save(self, filename): def save(self, filename):
embedding_data = { embedding_data = {
...@@ -41,6 +42,13 @@ class Embedding: ...@@ -41,6 +42,13 @@ class Embedding:
torch.save(embedding_data, filename) torch.save(embedding_data, filename)
if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
optimizer_saved_dict = {
'hash': self.checksum(),
'optimizer_state_dict': self.optimizer_state_dict,
}
torch.save(optimizer_saved_dict, filename + '.optim')
def checksum(self): def checksum(self):
if self.cached_checksum is not None: if self.cached_checksum is not None:
return self.cached_checksum return self.cached_checksum
...@@ -95,9 +103,10 @@ class EmbeddingDatabase: ...@@ -95,9 +103,10 @@ class EmbeddingDatabase:
self.expected_shape = self.get_expected_shape() self.expected_shape = self.get_expected_shape()
def process_file(path, filename): def process_file(path, filename):
name = os.path.splitext(filename)[0] name, ext = os.path.splitext(filename)
ext = ext.upper()
if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']: if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
embed_image = Image.open(path) embed_image = Image.open(path)
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
data = embedding_from_b64(embed_image.text['sd-ti-embedding']) data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
...@@ -105,8 +114,10 @@ class EmbeddingDatabase: ...@@ -105,8 +114,10 @@ class EmbeddingDatabase:
else: else:
data = extract_image_data_embed(embed_image) data = extract_image_data_embed(embed_image)
name = data.get('name', name) name = data.get('name', name)
else: elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu") data = torch.load(path, map_location="cpu")
else:
return
# textual inversion embeddings # textual inversion embeddings
if 'string_to_param' in data: if 'string_to_param' in data:
...@@ -240,11 +251,12 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat ...@@ -240,11 +251,12 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
if save_model_every or create_image_every: if save_model_every or create_image_every:
assert log_directory, "Log directory is empty" assert log_directory, "Log directory is empty"
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0 save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0 create_image_every = create_image_every or 0
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
shared.state.job = "train-embedding"
shared.state.textinfo = "Initializing textual inversion training..." shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps shared.state.job_count = steps
...@@ -282,6 +294,11 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ ...@@ -282,6 +294,11 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
return embedding, filename return embedding, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step) scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
None
if clip_grad:
                clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
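PLACEHOLDER_NOT_USED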
# dataset loading may take a while, so input validations and early returns should be done before this # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed old_parallel_processing_allowed = shared.parallel_processing_allowed
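Editor's note: the three-way conditional above works because both PyTorch clippers share a call shape, and both operate in place on .grad. A self-contained demonstration with made-up values:

import torch

vec = torch.randn(8, requires_grad=True)
vec.grad = torch.randn(8) * 10        # stand-in for gradients from backward()

mode, threshold = "norm", 1.0
clip = (torch.nn.utils.clip_grad_value_ if mode == "value"
        else torch.nn.utils.clip_grad_norm_ if mode == "norm"
        else None)
if clip is not None:
    clip(vec, threshold)              # mutates vec.grad in place
print(vec.grad.norm())                # <= 1.0 in "norm" mode, up to fp rounding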
...@@ -300,6 +317,19 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ ...@@ -300,6 +317,19 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
embedding.vec.requires_grad = True embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
if shared.opts.save_optimizer_state:
optimizer_state_dict = None
if os.path.exists(filename + '.optim'):
optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
if embedding.checksum() == optimizer_saved_dict.get('hash', None):
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
if optimizer_state_dict is not None:
optimizer.load_state_dict(optimizer_state_dict)
print("Loaded existing optimizer from checkpoint")
else:
print("No saved optimizer exists in checkpoint")
scaler = torch.cuda.amp.GradScaler() scaler = torch.cuda.amp.GradScaler()
batch_size = ds.batch_size batch_size = ds.batch_size
...@@ -315,6 +345,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ ...@@ -315,6 +345,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
forced_filename = "<none>" forced_filename = "<none>"
embedding_yet_to_be_embedded = False embedding_yet_to_be_embedded = False
is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
img_c = None
pbar = tqdm.tqdm(total=steps - initial_step) pbar = tqdm.tqdm(total=steps - initial_step)
try: try:
for i in range((steps-initial_step) * gradient_step): for i in range((steps-initial_step) * gradient_step):
...@@ -332,14 +365,22 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ ...@@ -332,14 +365,22 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
if shared.state.interrupted: if shared.state.interrupted:
break break
if clip_grad:
clip_grad_sched.step(embedding.step)
with devices.autocast(): with devices.autocast():
# c = stack_conds(batch.cond).to(devices.device)
# mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
# print(mask)
# c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
c = shared.sd_model.cond_stage_model(batch.cond_text) c = shared.sd_model.cond_stage_model(batch.cond_text)
loss = shared.sd_model(x, c)[0] / gradient_step
if is_training_inpainting_model:
if img_c is None:
img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
cond = {"c_concat": [img_c], "c_crossattn": [c]}
else:
cond = c
loss = shared.sd_model(x, cond)[0] / gradient_step
del x del x
_loss_step += loss.item() _loss_step += loss.item()
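Editor's note: the new branch makes training work on inpainting checkpoints, whose UNet (conditioning_key 'hybrid' or 'concat') expects a dict cond; img_c depends only on the training resolution, so it is built once and cached across steps. The branch factored into a sketch (build_cond is an illustrative name, not repository code):

def build_cond(c, img_c, is_inpainting):
    # inpainting models take image conditioning concatenated channel-wise
    # plus the usual text conditioning for cross-attention; plain
    # checkpoints take the text cond directly
    if is_inpainting:
        return {"c_concat": [img_c], "c_crossattn": [c]}
    return c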
...@@ -348,6 +389,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ ...@@ -348,6 +389,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
# go back until we reach gradient accumulation steps # go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0: if (j + 1) % gradient_step != 0:
continue continue
if clip_grad:
clip_grad(embedding.vec, clip_grad_sched.learn_rate)
scaler.step(optimizer) scaler.step(optimizer)
scaler.update() scaler.update()
embedding.step += 1 embedding.step += 1
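Editor's note, hedged: the clip above runs on gradients that are still scaled by the GradScaler, so the configured threshold does not map to true gradient magnitudes. The ordering the PyTorch AMP docs recommend, sketched as a helper (amp_step is not repository code):

import torch

def amp_step(scaler, optimizer, params, max_norm=1.0):
    scaler.unscale_(optimizer)                    # make .grad hold true magnitudes
    torch.nn.utils.clip_grad_norm_(params, max_norm)
    scaler.step(optimizer)                        # skipped internally on inf/nan grads
    scaler.update()
    optimizer.zero_grad(set_to_none=True)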
...@@ -366,9 +411,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ ...@@ -366,9 +411,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
# Before saving, change name to match current checkpoint. # Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}' embedding_name_every = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt') last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
#if shared.opts.save_optimizer_state: save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
#embedding.optimizer_state_dict = optimizer.state_dict()
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
embedding_yet_to_be_embedded = True embedding_yet_to_be_embedded = True
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, { write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
...@@ -458,7 +501,7 @@ Last saved image: {html.escape(last_saved_image)}<br/> ...@@ -458,7 +501,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>
""" """
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
except Exception: except Exception:
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
pass pass
...@@ -470,7 +513,7 @@ Last saved image: {html.escape(last_saved_image)}<br/> ...@@ -470,7 +513,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
return embedding, filename return embedding, filename
def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True): def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
old_embedding_name = embedding.name old_embedding_name = embedding.name
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
...@@ -481,6 +524,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache ...@@ -481,6 +524,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache
if remove_cached_checksum: if remove_cached_checksum:
embedding.cached_checksum = None embedding.cached_checksum = None
embedding.name = embedding_name embedding.name = embedding_name
embedding.optimizer_state_dict = optimizer.state_dict()
embedding.save(filename) embedding.save(filename)
except: except:
embedding.sd_checkpoint = old_sd_checkpoint embedding.sd_checkpoint = old_sd_checkpoint
......
...@@ -8,7 +8,7 @@ import modules.processing as processing ...@@ -8,7 +8,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html from modules.ui import plaintext_to_html
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, *args): def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
p = StableDiffusionProcessingTxt2Img( p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model, sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
...@@ -35,6 +35,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: ...@@ -35,6 +35,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
denoising_strength=denoising_strength if enable_hr else None, denoising_strength=denoising_strength if enable_hr else None,
hr_scale=hr_scale, hr_scale=hr_scale,
hr_upscaler=hr_upscaler, hr_upscaler=hr_upscaler,
hr_second_pass_steps=hr_second_pass_steps,
hr_resize_x=hr_resize_x,
hr_resize_y=hr_resize_y,
) )
p.scripts = modules.scripts.scripts_txt2img p.scripts = modules.scripts.scripts_txt2img
......
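Editor's note: txt2img now threads three more highres-fix knobs through to the processing object. Under the assumed semantics that explicit pixel targets take precedence over the scale multiplier (the authoritative logic lives in StableDiffusionProcessingTxt2Img), the second-pass resolution pick looks roughly like:

def second_pass_size(width, height, hr_scale, hr_resize_x, hr_resize_y):
    # assumption: when both explicit targets are given they win;
    # otherwise the multiplier applies to the first-pass size
    if hr_resize_x and hr_resize_y:
        return hr_resize_x, hr_resize_y
    return int(width * hr_scale), int(height * hr_scale)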
...@@ -5,7 +5,7 @@ basicsr==1.4.2 ...@@ -5,7 +5,7 @@ basicsr==1.4.2
gfpgan==1.3.8 gfpgan==1.3.8
gradio==3.15.0 gradio==3.15.0
numpy==1.23.3 numpy==1.23.3
Pillow==9.2.0 Pillow==9.4.0
realesrgan==0.3.0 realesrgan==0.3.0
torch torch
omegaconf==2.2.3 omegaconf==2.2.3
...@@ -26,5 +26,5 @@ lark==1.1.2 ...@@ -26,5 +26,5 @@ lark==1.1.2
inflection==0.5.1 inflection==0.5.1
GitPython==3.1.27 GitPython==3.1.27
torchsde==0.2.5 torchsde==0.2.5
safetensors==0.2.5 safetensors==0.2.7
httpcore<=0.15 httpcore<=0.15
...@@ -4,7 +4,7 @@ function gradioApp() { ...@@ -4,7 +4,7 @@ function gradioApp() {
} }
function get_uiCurrentTab() { function get_uiCurrentTab() {
return gradioApp().querySelector('.tabs button:not(.border-transparent)') return gradioApp().querySelector('#tabs button:not(.border-transparent)')
} }
function get_uiCurrentTabContent() { function get_uiCurrentTabContent() {
......
...@@ -10,7 +10,7 @@ import numpy as np ...@@ -10,7 +10,7 @@ import numpy as np
import modules.scripts as scripts import modules.scripts as scripts
import gradio as gr import gradio as gr
from modules import images, paths, sd_samplers from modules import images, paths, sd_samplers, processing
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
...@@ -285,6 +285,7 @@ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d ...@@ -285,6 +285,7 @@ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d
re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\[(\d+(?:\.\d*)?)\s*\])?\s*") re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\[(\d+(?:\.\d*)?)\s*\])?\s*")
class Script(scripts.Script): class Script(scripts.Script):
def title(self): def title(self):
return "X/Y plot" return "X/Y plot"
...@@ -381,7 +382,7 @@ class Script(scripts.Script): ...@@ -381,7 +382,7 @@ class Script(scripts.Script):
ys = process_axis(y_opt, y_values) ys = process_axis(y_opt, y_values)
def fix_axis_seeds(axis_opt, axis_list): def fix_axis_seeds(axis_opt, axis_list):
if axis_opt.label in ['Seed','Var. seed']: if axis_opt.label in ['Seed', 'Var. seed']:
return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
else: else:
return axis_list return axis_list
...@@ -403,12 +404,33 @@ class Script(scripts.Script): ...@@ -403,12 +404,33 @@ class Script(scripts.Script):
print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})") print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})")
shared.total_tqdm.updateTotal(total_steps * p.n_iter) shared.total_tqdm.updateTotal(total_steps * p.n_iter)
grid_infotext = [None]
def cell(x, y): def cell(x, y):
pc = copy(p) pc = copy(p)
x_opt.apply(pc, x, xs) x_opt.apply(pc, x, xs)
y_opt.apply(pc, y, ys) y_opt.apply(pc, y, ys)
return process_images(pc) res = process_images(pc)
if grid_infotext[0] is None:
pc.extra_generation_params = copy(pc.extra_generation_params)
if x_opt.label != 'Nothing':
pc.extra_generation_params["X Type"] = x_opt.label
pc.extra_generation_params["X Values"] = x_values
if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs])
if y_opt.label != 'Nothing':
pc.extra_generation_params["Y Type"] = y_opt.label
pc.extra_generation_params["Y Values"] = y_values
if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys])
grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)
return res
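Editor's note: grid_infotext = [None] above is the classic closure-cell pattern, here letting the first cell() call publish the grid's infotext to the enclosing scope. A nested function cannot rebind an outer local without nonlocal, but it can mutate a list the outer scope owns:

def outer():
    cell = [None]
    def inner(value):
        cell[0] = value      # mutates the list; no rebinding needed
    inner("first")
    return cell[0]

print(outer())  # -> "first"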
with SharedSettingsStackHelper(): with SharedSettingsStackHelper():
processed = draw_xy_grid( processed = draw_xy_grid(
...@@ -423,6 +445,6 @@ class Script(scripts.Script): ...@@ -423,6 +445,6 @@ class Script(scripts.Script):
) )
if opts.grid_save: if opts.grid_save:
images.save_image(processed.images[0], p.outpath_grids, "xy_grid", extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p) images.save_image(processed.images[0], p.outpath_grids, "xy_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
return processed return processed
...@@ -611,7 +611,7 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h ...@@ -611,7 +611,7 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h
padding-top: 0.9em; padding-top: 0.9em;
} }
#img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form{ #img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form, #train_tabs div.gr-form .gr-form{
border: none; border: none;
padding-bottom: 0.5em; padding-bottom: 0.5em;
} }
......
...@@ -9,7 +9,7 @@ from fastapi import FastAPI ...@@ -9,7 +9,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
from modules import import_hook from modules import import_hook, errors
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
from modules.paths import script_path from modules.paths import script_path
...@@ -61,7 +61,15 @@ def initialize(): ...@@ -61,7 +61,15 @@ def initialize():
modelloader.load_upscalers() modelloader.load_upscalers()
modules.sd_vae.refresh_vae_list() modules.sd_vae.refresh_vae_list()
modules.sd_models.load_model()
try:
modules.sd_models.load_model()
except Exception as e:
errors.display(e, "loading stable diffusion model")
print("", file=sys.stderr)
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
exit(1)
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
...@@ -92,11 +100,11 @@ def initialize(): ...@@ -92,11 +100,11 @@ def initialize():
def setup_cors(app): def setup_cors(app):
if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex: if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*']) app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
elif cmd_opts.cors_allow_origins: elif cmd_opts.cors_allow_origins:
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*']) app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
elif cmd_opts.cors_allow_origins_regex: elif cmd_opts.cors_allow_origins_regex:
app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*']) app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
def create_api(app): def create_api(app):
......
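Editor's note: all three CORS branches gain allow_credentials=True and allow_headers=['*']; without them, browsers reject cross-origin requests that carry cookies or non-simple headers at the preflight stage. A standalone sketch with an assumed frontend origin:

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],   # illustrative origin
    allow_methods=["*"],
    allow_headers=["*"],                       # needed for non-simple request headers
    allow_credentials=True,                    # allow cookies/Authorization cross-origin
)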