Unverified Commit 243253ff authored by random-thoughtss, committed by GitHub

Merge branch 'AUTOMATIC1111:master' into master

parents d9e4e4d7 20a860b5
function extensions_apply(_, _){
    disable = []
    update = []

    gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
        if(x.name.startsWith("enable_") && ! x.checked)
            disable.push(x.name.substr(7))

        if(x.name.startsWith("update_") && x.checked)
            update.push(x.name.substr(7))
    })

    restart_reload()

    return [JSON.stringify(disable), JSON.stringify(update)]
}

function extensions_check(){
    gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
        x.innerHTML = "Loading..."
    })

    return []
}

function install_extension_from_index(button, url){
    button.disabled = "disabled"
    button.value = "Installing..."

    textarea = gradioApp().querySelector('#extension_to_install textarea')
    textarea.value = url
    textarea.dispatchEvent(new Event("input", { bubbles: true }))

    gradioApp().querySelector('#install_extension_button').click()
}
@@ -3,8 +3,21 @@ global_progressbars = {}
galleries = {}
galleryObservers = {}

+// this tracks launches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running
+timeoutIds = {}

function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
-    var progressbar = gradioApp().getElementById(id_progressbar)
+    // gradio 3.8's enlightened approach allows them to create two nested div elements inside each other with same id
+    // every time you use gr.HTML(elem_id='xxx'), so we handle this here
+    var progressbar = gradioApp().querySelector("#"+id_progressbar+" #"+id_progressbar)
+    var progressbarParent
+    if(progressbar){
+        progressbarParent = gradioApp().querySelector("#"+id_progressbar)
+    } else{
+        progressbar = gradioApp().getElementById(id_progressbar)
+        progressbarParent = null
+    }
    var skip = id_skip ? gradioApp().getElementById(id_skip) : null
    var interrupt = gradioApp().getElementById(id_interrupt)

@@ -26,18 +39,26 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
        global_progressbars[id_progressbar] = progressbar

        var mutationObserver = new MutationObserver(function(m){
+            if(timeoutIds[id_part]) return;

            preview = gradioApp().getElementById(id_preview)
            gallery = gradioApp().getElementById(id_gallery)

            if(preview != null && gallery != null){
                preview.style.width = gallery.clientWidth + "px"
                preview.style.height = gallery.clientHeight + "px"
+                if(progressbarParent) progressbar.style.width = progressbarParent.clientWidth + "px"

                //only watch gallery if there is a generation process going on
                check_gallery(id_gallery);

                var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
-                if(!progressDiv){
+                if(progressDiv){
+                    timeoutIds[id_part] = window.setTimeout(function() {
+                        timeoutIds[id_part] = null
+                        requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt)
+                    }, 500)
+                } else{
                    if (skip) {
                        skip.style.display = "none"
                    }

@@ -49,11 +70,8 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
                    galleries[id_gallery] = null;
                }
            }

-            window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500)
        });
        mutationObserver.observe( progressbar, { childList:true, subtree:true })
    }
...
@@ -7,6 +7,7 @@ import shlex
import platform

dir_repos = "repositories"
+dir_extensions = "extensions"
python = sys.executable
git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")

@@ -16,11 +17,11 @@ def extract_arg(args, name):
    return [x for x in args if x != name], name in args

-def run(command, desc=None, errdesc=None):
+def run(command, desc=None, errdesc=None, custom_env=None):
    if desc is not None:
        print(desc)

-    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
+    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)

    if result.returncode != 0:

@@ -101,7 +102,25 @@ def version_check(commit):
        else:
            print("Not a git clone, can't perform version check.")
    except Exception as e:
-        print("versipm check failed",e)
+        print("version check failed", e)

+def run_extensions_installers():
+    if not os.path.isdir(dir_extensions):
+        return
+
+    for dirname_extension in os.listdir(dir_extensions):
+        path_installer = os.path.join(dir_extensions, dirname_extension, "install.py")
+        if not os.path.isfile(path_installer):
+            continue
+
+        try:
+            env = os.environ.copy()
+            env['PYTHONPATH'] = os.path.abspath(".")
+
+            print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {dirname_extension}", custom_env=env))
+        except Exception as e:
+            print(e, file=sys.stderr)

def prepare_enviroment():

@@ -189,6 +208,8 @@ def prepare_enviroment():
    run_pip(f"install -r {requirements_file}", "requirements for Web UI")

+    run_extensions_installers()
+
    if update_check:
        version_check(commit)
...
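
The new run_extensions_installers() above runs each extension's install.py with PYTHONPATH pointed at the webui root, so such a script can import launch and reuse its helpers. A minimal sketch of what an install.py might look like under this scheme (the dependency and extension names are made-up examples; launch.is_installed and launch.run_pip are assumed helpers from launch.py):

    # extensions/my-example-extension/install.py -- executed by run_extensions_installers() at startup
    import launch

    # "einops" is only an example dependency; any pip package name works here
    if not launch.is_installed("einops"):
        launch.run_pip("install einops", "einops for my-example-extension")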
@@ -70,7 +70,7 @@
    "None": "Nichts",
    "Prompt matrix": "Promptmatrix",
    "Prompts from file or textbox": "Prompts aus Datei oder Textfeld",
-    "X/Y plot": "X/Y Graf",
+    "X/Y plot": "X/Y Graph",
    "Put variable parts at start of prompt": "Variable teile am start des Prompt setzen",
    "Iterate seed every line": "Iterate seed every line",
    "List of prompt inputs": "List of prompt inputs",
...
+import base64
+import io
import time
import uvicorn
-from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
+from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, HTTPException
import modules.shared as shared
+from modules import devices
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
-from modules.sd_samplers import all_samplers
+from modules.sd_samplers import all_samplers, sample_to_image, samples_to_image_grid
from modules.extras import run_extras, run_pnginfo

@@ -29,6 +30,12 @@ def setUpscalers(req: dict):
    return reqDict

+def encode_pil_to_base64(image):
+    buffer = io.BytesIO()
+    image.save(buffer, format="png")
+    return base64.b64encode(buffer.getvalue())

class Api:
    def __init__(self, app, queue_lock):
        self.router = APIRouter()

@@ -40,6 +47,7 @@ class Api:
        self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
        self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
        self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
+        self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])

    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
        sampler_index = sampler_to_index(txt2imgreq.sampler_index)

@@ -170,12 +178,19 @@ class Api:
        progress = min(progress, 1)

+        shared.state.set_current_image()

        current_image = None
        if shared.state.current_image and not req.skip_current_image:
            current_image = encode_pil_to_base64(shared.state.current_image)

        return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)

+    def interruptapi(self):
+        shared.state.interrupt()
+
+        return {}

    def launch(self, server_name, port):
        self.app.include_router(self.router)
        uvicorn.run(self.app, host=server_name, port=port)
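
The new /sdapi/v1/interrupt route makes it possible to cancel an in-flight generation over HTTP. A minimal client-side sketch (assumes the webui API is serving on 127.0.0.1:7860; the requests usage is illustrative, not part of this commit):

    import requests

    # ask the running webui to interrupt the current generation job;
    # the route registered above returns an empty JSON object
    response = requests.post("http://127.0.0.1:7860/sdapi/v1/interrupt")
    print(response.status_code, response.json())  # expected: 200 {}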
@@ -50,6 +50,7 @@ def mod2normal(state_dict):
def resrgan2normal(state_dict, nb=23):
    # this code is copied from https://github.com/victorca25/iNNfer
    if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
+        re8x = 0
        crt_net = {}
        items = []
        for k, v in state_dict.items():

@@ -75,10 +76,18 @@ def resrgan2normal(state_dict, nb=23):
        crt_net['model.3.bias'] = state_dict['conv_up1.bias']
        crt_net['model.6.weight'] = state_dict['conv_up2.weight']
        crt_net['model.6.bias'] = state_dict['conv_up2.bias']
-        crt_net['model.8.weight'] = state_dict['conv_hr.weight']
-        crt_net['model.8.bias'] = state_dict['conv_hr.bias']
-        crt_net['model.10.weight'] = state_dict['conv_last.weight']
-        crt_net['model.10.bias'] = state_dict['conv_last.bias']
+
+        if 'conv_up3.weight' in state_dict:
+            # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
+            re8x = 3
+            crt_net['model.9.weight'] = state_dict['conv_up3.weight']
+            crt_net['model.9.bias'] = state_dict['conv_up3.bias']
+
+        crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
+        crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
+        crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
+        crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
        state_dict = crt_net
    return state_dict
...
import os
import sys
import traceback

import git

from modules import paths, shared

extensions = []
extensions_dir = os.path.join(paths.script_path, "extensions")

def active():
    return [x for x in extensions if x.enabled]

class Extension:
    def __init__(self, name, path, enabled=True):
        self.name = name
        self.path = path
        self.enabled = enabled
        self.status = ''
        self.can_update = False

        repo = None
        try:
            if os.path.exists(os.path.join(path, ".git")):
                repo = git.Repo(path)
        except Exception:
            print(f"Error reading github repository info from {path}:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

        if repo is None or repo.bare:
            self.remote = None
        else:
            self.remote = next(repo.remote().urls, None)
            self.status = 'unknown'

    def list_files(self, subdir, extension):
        from modules import scripts

        dirpath = os.path.join(self.path, subdir)
        if not os.path.isdir(dirpath):
            return []

        res = []
        for filename in sorted(os.listdir(dirpath)):
            res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))

        res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]

        return res

    def check_updates(self):
        repo = git.Repo(self.path)
        for fetch in repo.remote().fetch("--dry-run"):
            if fetch.flags != fetch.HEAD_UPTODATE:
                self.can_update = True
                self.status = "behind"
                return

        self.can_update = False
        self.status = "latest"

    def pull(self):
        repo = git.Repo(self.path)
        repo.remotes.origin.pull()

def list_extensions():
    extensions.clear()

    if not os.path.isdir(extensions_dir):
        return

    for dirname in sorted(os.listdir(extensions_dir)):
        path = os.path.join(extensions_dir, dirname)
        if not os.path.isdir(path):
            continue

        extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
        extensions.append(extension)
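
A brief usage sketch of the module above (hypothetical driver code, calling only names defined in this file):

    from modules import extensions

    extensions.list_extensions()         # scan the extensions/ directory
    for ext in extensions.active():      # only extensions not disabled in settings
        ext.check_updates()              # dry-run fetch to see if the remote is ahead
        print(ext.name, ext.status, ext.remote)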
@@ -141,7 +141,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
                upscaling_resize_w, upscaling_resize_h, upscaling_crop)
        cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()),
                                 info_hash=hash(info),
-                                 args_hash=hash(upscale_args))
+                                 args_hash=hash((upscale_args, upscale_first)))
        cached_entry = cached_images.get(cache_key)
        if cached_entry is None:
            res = upscale(image, *upscale_args)
...
@@ -17,6 +17,11 @@ paste_fields = {}
bind_list = []

+def reset():
+    paste_fields.clear()
+    bind_list.clear()

def quote(text):
    if ',' not in str(text):
        return text
...
@@ -510,6 +510,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
    if extension.lower() == '.png':
        pnginfo_data = PngImagePlugin.PngInfo()
+        if opts.enable_pnginfo:
            for k, v in params.pnginfo.items():
                pnginfo_data.add_text(k, str(v))
...
@@ -55,6 +55,7 @@ def process_batch(p, input_dir, output_dir, args):
            filename = f"{left}-{n}{right}"

        if not save_normally:
+            os.makedirs(output_dir, exist_ok=True)
            processed_image.save(os.path.join(output_dir, filename))

@@ -80,6 +81,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
        mask = None

    # Use the EXIF orientation of photos taken by smartphones.
+    if image is not None:
        image = ImageOps.exif_transpose(image)

    assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'

@@ -136,6 +138,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
    if processed is None:
        processed = process_images(p)

+    p.close()

    shared.total_tqdm.clear()

    generation_info_js = processed.js()
...
@@ -56,9 +56,9 @@ class InterrogateModels:
        import clip

        if self.running_on_cpu:
-            model, preprocess = clip.load(clip_model_name, device="cpu")
+            model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
        else:
-            model, preprocess = clip.load(clip_model_name)
+            model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)

        model.eval()
        model = model.to(devices.device_interrogate)
...
@@ -38,13 +38,18 @@ def setup_for_low_vram(sd_model, use_medvram):
    # see below for register_forward_pre_hook;
    # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
    # useless here, and we just replace those methods
-    def first_stage_model_encode_wrap(self, encoder, x):
-        send_me_to_gpu(self, None)
-        return encoder(x)
-
-    def first_stage_model_decode_wrap(self, decoder, z):
-        send_me_to_gpu(self, None)
-        return decoder(z)
+    first_stage_model = sd_model.first_stage_model
+    first_stage_model_encode = sd_model.first_stage_model.encode
+    first_stage_model_decode = sd_model.first_stage_model.decode
+
+    def first_stage_model_encode_wrap(x):
+        send_me_to_gpu(first_stage_model, None)
+        return first_stage_model_encode(x)
+
+    def first_stage_model_decode_wrap(z):
+        send_me_to_gpu(first_stage_model, None)
+        return first_stage_model_decode(z)

    # remove three big modules, cond, first_stage, and unet from the model and then
    # send the model to GPU. Then put modules back. the modules will be in CPU.

@@ -56,8 +61,8 @@ def setup_for_low_vram(sd_model, use_medvram):
    # register hooks for the first two models
    sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
    sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
-    sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
-    sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
+    sd_model.first_stage_model.encode = first_stage_model_encode_wrap
+    sd_model.first_stage_model.decode = first_stage_model_decode_wrap

    parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

    if use_medvram:
...
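
For background on the hook pattern used above, a self-contained sketch of lazily moving a module to the GPU via register_forward_pre_hook (generic PyTorch, not webui code; the real send_me_to_gpu also swaps other modules out of VRAM):

    import torch
    import torch.nn as nn

    def send_me_to_gpu(module, _inputs):
        # runs right before module.forward(); moves weights to the GPU on demand
        module.to("cuda")

    net = nn.Linear(4, 4).to("cpu")
    net.register_forward_pre_hook(send_me_to_gpu)

    if torch.cuda.is_available():
        x = torch.randn(1, 4, device="cuda")
        y = net(x)  # the hook moves `net` to cuda just in time for this call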
@@ -85,6 +85,9 @@ def cleanup_models():
    src_path = os.path.join(root_path, "ESRGAN")
    dest_path = os.path.join(models_path, "ESRGAN")
    move_files(src_path, dest_path)
+    src_path = os.path.join(models_path, "BSRGAN")
+    dest_path = os.path.join(models_path, "ESRGAN")
+    move_files(src_path, dest_path, ".pth")
    src_path = os.path.join(root_path, "gfpgan")
    dest_path = os.path.join(models_path, "GFPGAN")
    move_files(src_path, dest_path)
...
@@ -191,9 +191,13 @@ class StableDiffusionProcessing():
    def init(self, all_prompts, all_seeds, all_subseeds):
        pass

-    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        raise NotImplementedError()

+    def close(self):
+        self.sd_model = None
+        self.sampler = None

class Processed:
    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):

@@ -509,7 +513,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            with devices.autocast():
-                samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+                samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)

            samples_ddim = samples_ddim.to(devices.dtype_vae)
            x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)

@@ -637,7 +641,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
        self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
        self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f

-    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)

        if not self.enable_hr:

@@ -650,6 +654,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]

+        """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
+        def save_intermediate(image, index):
+            if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
+                return
+
+            if not isinstance(image, Image.Image):
+                image = sd_samplers.sample_to_image(image, index)
+
+            images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")

        if opts.use_scale_latent_for_hires_fix:
            samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")

@@ -660,6 +674,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
            else:
                image_conditioning = self.txt2img_image_conditioning(samples)

+            for i in range(samples.shape[0]):
+                save_intermediate(samples, i)
        else:
            decoded_samples = decode_first_stage(self.sd_model, samples)
            lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

@@ -669,6 +685,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)
                image = Image.fromarray(x_sample)
+
+                save_intermediate(image, i)

                image = images.resize_image(0, image, self.width, self.height)
                image = np.array(image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)

@@ -826,8 +845,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
        self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)

-    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
...
@@ -32,7 +32,7 @@ class RestrictedUnpickler(pickle.Unpickler):
            return getattr(collections, name)
        if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
            return getattr(torch._utils, name)
-        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']:
+        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage']:
            return getattr(torch, name)
        if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
            return getattr(torch.nn.modules.container, name)
...
@@ -2,7 +2,10 @@ import sys
import traceback
from collections import namedtuple
import inspect
+from typing import Optional

+from fastapi import FastAPI
+from gradio import Blocks

def report_exception(c, job):
    print(f"Error executing callback {job} for {c.script}", file=sys.stderr)

@@ -24,12 +27,32 @@ class ImageSaveParams:
        """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""

+class CFGDenoiserParams:
+    def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
+        self.x = x
+        """Latent image representation in the process of being denoised"""
+
+        self.image_cond = image_cond
+        """Conditioning image"""
+
+        self.sigma = sigma
+        """Current sigma noise step value"""
+
+        self.sampling_step = sampling_step
+        """Current sampling step number"""
+
+        self.total_sampling_steps = total_sampling_steps
+        """Total number of sampling steps planned"""

ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
+callbacks_app_started = []
callbacks_model_loaded = []
callbacks_ui_tabs = []
callbacks_ui_settings = []
callbacks_before_image_saved = []
callbacks_image_saved = []
+callbacks_cfg_denoiser = []

def clear_callbacks():

@@ -38,6 +61,14 @@ def clear_callbacks():
    callbacks_ui_settings.clear()
    callbacks_before_image_saved.clear()
    callbacks_image_saved.clear()
+    callbacks_cfg_denoiser.clear()

+def app_started_callback(demo: Optional[Blocks], app: FastAPI):
+    for c in callbacks_app_started:
+        try:
+            c.callback(demo, app)
+        except Exception:
+            report_exception(c, 'app_started_callback')

def model_loaded_callback(sd_model):

@@ -69,7 +100,7 @@ def ui_settings_callback():

def before_image_saved_callback(params: ImageSaveParams):
-    for c in callbacks_image_saved:
+    for c in callbacks_before_image_saved:
        try:
            c.callback(params)
        except Exception:

@@ -84,6 +115,14 @@ def image_saved_callback(params: ImageSaveParams):
            report_exception(c, 'image_saved_callback')

+def cfg_denoiser_callback(params: CFGDenoiserParams):
+    for c in callbacks_cfg_denoiser:
+        try:
+            c.callback(params)
+        except Exception:
+            report_exception(c, 'cfg_denoiser_callback')

def add_callback(callbacks, fun):
    stack = [x for x in inspect.stack() if x.filename != __file__]
    filename = stack[0].filename if len(stack) > 0 else 'unknown file'

@@ -91,6 +130,12 @@ def add_callback(callbacks, fun):
    callbacks.append(ScriptCallback(filename, fun))

+def on_app_started(callback):
+    """register a function to be called when the webui starts; the gradio `Blocks` component and
+    fastapi `FastAPI` object are passed as the arguments"""
+    add_callback(callbacks_app_started, callback)

def on_model_loaded(callback):
    """register a function to be called when the stable diffusion model is created; the model is
    passed as an argument"""

@@ -130,3 +175,12 @@ def on_image_saved(callback):
    - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
    """
    add_callback(callbacks_image_saved, callback)

+def on_cfg_denoiser(callback):
+    """register a function to be called in the k-diffusion cfg_denoiser method after building the inner model inputs.
+    The callback is called with one argument:
+    - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
+    """
+    add_callback(callbacks_cfg_denoiser, callback)
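
To illustrate the two new hooks, a minimal sketch of an extension script registering them (the callback bodies are placeholders; only the registration API shown above is from the commit):

    from modules import script_callbacks

    def app_started(demo, app):
        # demo is the gradio Blocks instance (or None), app is the FastAPI object
        print("webui is up, routes:", len(app.routes))

    def cfg_denoiser(params):
        # params is a CFGDenoiserParams; its fields can be modified to alter sampling inputs
        print(f"step {params.sampling_step}/{params.total_sampling_steps}, sigma {float(params.sigma[0]):.3f}")

    script_callbacks.on_app_started(app_started)
    script_callbacks.on_cfg_denoiser(cfg_denoiser)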
@@ -7,7 +7,7 @@ import modules.ui as ui
import gradio as gr

from modules.processing import StableDiffusionProcessing
-from modules import shared, paths, script_callbacks
+from modules import shared, paths, script_callbacks, extensions

AlwaysVisible = object()

@@ -107,17 +107,8 @@ def list_scripts(scriptdirname, extension):
        for filename in sorted(os.listdir(basedir)):
            scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))

-    extdir = os.path.join(paths.script_path, "extensions")
-    if os.path.exists(extdir):
-        for dirname in sorted(os.listdir(extdir)):
-            dirpath = os.path.join(extdir, dirname)
-            scriptdirpath = os.path.join(dirpath, scriptdirname)
-
-            if not os.path.isdir(scriptdirpath):
-                continue
-
-            for filename in sorted(os.listdir(scriptdirpath)):
-                scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename)))
+    for ext in extensions.active():
+        scripts_list += ext.list_files(scriptdirname, extension)

    scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]

@@ -127,11 +118,7 @@ def list_scripts(scriptdirname, extension):
def list_files_with_name(filename):
    res = []

-    dirs = [paths.script_path]
-    extdir = os.path.join(paths.script_path, "extensions")
-    if os.path.exists(extdir):
-        dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))]
+    dirs = [paths.script_path] + [ext.path for ext in extensions.active()]

    for dirpath in dirs:
        if not os.path.isdir(dirpath):
...
@@ -94,6 +94,10 @@ class StableDiffusionModelHijack:
        if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
            model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped

+        self.layers = None
+        self.circular_enabled = False
+        self.clip = None

    def apply_circular(self, enable):
        if self.circular_enabled == enable:
            return
...
import collections
import os.path
import sys
+import gc
from collections import namedtuple
import torch
import re

@@ -8,7 +9,7 @@ from omegaconf import OmegaConf

from ldm.util import instantiate_from_config

-from modules import shared, modelloader, devices, script_callbacks
+from modules import shared, modelloader, devices, script_callbacks, sd_vae
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting

@@ -158,14 +159,15 @@ def get_state_dict_from_checkpoint(pl_sd):
    return pl_sd

-vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
-
-def load_model_weights(model, checkpoint_info):
+def load_model_weights(model, checkpoint_info, vae_file="auto"):
    checkpoint_file = checkpoint_info.filename
    sd_model_hash = checkpoint_info.hash

-    if checkpoint_info not in checkpoints_loaded:
+    vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
+
+    checkpoint_key = checkpoint_info
+
+    if checkpoint_key not in checkpoints_loaded:
        print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")

        pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)

@@ -181,37 +183,38 @@ def load_model_weights(model, checkpoint_info):
            model.to(memory_format=torch.channels_last)

        if not shared.cmd_opts.no_half:
+            vae = model.first_stage_model
+
+            # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
+            if shared.cmd_opts.no_half_vae:
+                model.first_stage_model = None
+
            model.half()
+            model.first_stage_model = vae

        devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
        devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16

-        vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
-        if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
-            vae_file = shared.cmd_opts.vae_path
-        if os.path.exists(vae_file):
-            print(f"Loading VAE weights from: {vae_file}")
-            vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
-            vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
-            model.first_stage_model.load_state_dict(vae_dict)
-
        model.first_stage_model.to(devices.dtype_vae)

        if shared.opts.sd_checkpoint_cache > 0:
-            checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+            # if PR #4035 were to get merged, restore base VAE first before caching
+            checkpoints_loaded[checkpoint_key] = model.state_dict().copy()
            while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
                checkpoints_loaded.popitem(last=False)  # LRU
    else:
-        print(f"Loading weights [{sd_model_hash}] from cache")
-        checkpoints_loaded.move_to_end(checkpoint_info)
-        model.load_state_dict(checkpoints_loaded[checkpoint_info])
+        vae_name = sd_vae.get_filename(vae_file)
+        print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache")
+        checkpoints_loaded.move_to_end(checkpoint_key)
+        model.load_state_dict(checkpoints_loaded[checkpoint_key])

    model.sd_model_hash = sd_model_hash
    model.sd_model_checkpoint = checkpoint_file
    model.sd_checkpoint_info = checkpoint_info

+    sd_vae.load_vae(model, vae_file)

def load_model(checkpoint_info=None):
    from modules import lowvram, sd_hijack

@@ -220,6 +223,12 @@ def load_model(checkpoint_info=None):
    if checkpoint_info.config != shared.cmd_opts.config:
        print(f"Loading config from: {checkpoint_info.config}")

+    if shared.sd_model:
+        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
+        shared.sd_model = None
+        gc.collect()
+        devices.torch_gc()
+
    sd_config = OmegaConf.load(checkpoint_info.config)

    if should_hijack_inpainting(checkpoint_info):

@@ -233,6 +242,7 @@ def load_model(checkpoint_info=None):
        checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))

    do_inpainting_hijack()

    sd_model = instantiate_from_config(sd_config.model)
    load_model_weights(sd_model, checkpoint_info)

@@ -252,14 +262,18 @@ def load_model(checkpoint_info=None):
    return sd_model

-def reload_model_weights(sd_model, info=None):
+def reload_model_weights(sd_model=None, info=None):
    from modules import lowvram, devices, sd_hijack

    checkpoint_info = info or select_checkpoint()

+    if not sd_model:
+        sd_model = shared.sd_model
+
    if sd_model.sd_model_checkpoint == checkpoint_info.filename:
        return

    if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
+        del sd_model
        checkpoints_loaded.clear()
        load_model(checkpoint_info)
        return shared.sd_model
...
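
The checkpoint cache above uses an OrderedDict as an LRU: move_to_end marks an entry as recently used and popitem(last=False) evicts the oldest. A standalone sketch of the idiom (illustrative only, not webui code):

    from collections import OrderedDict

    cache = OrderedDict()
    CACHE_SIZE = 2

    def put(key, value):
        cache[key] = value
        cache.move_to_end(key)           # mark as most recently used
        while len(cache) > CACHE_SIZE:
            cache.popitem(last=False)    # evict the least recently used entry

    put("ckpt-a", 1); put("ckpt-b", 2); put("ckpt-c", 3)
    print(list(cache))  # ['ckpt-b', 'ckpt-c'] -- 'ckpt-a' was evicted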
from collections import namedtuple
import numpy as np
+from math import floor
import torch
import tqdm
from PIL import Image

@@ -11,6 +12,7 @@ from modules import prompt_parser, devices, processing, images
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
+from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback

SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])

@@ -91,8 +93,8 @@ def single_sample_to_image(sample):
    return Image.fromarray(x_sample)

-def sample_to_image(samples):
-    return single_sample_to_image(samples[0])
+def sample_to_image(samples, index=0):
+    return single_sample_to_image(samples[index])

def samples_to_image_grid(samples):

@@ -205,17 +207,22 @@ class VanillaStableDiffusionSampler:
        self.mask = p.mask if hasattr(p, 'mask') else None
        self.nmask = p.nmask if hasattr(p, 'nmask') else None

+    def adjust_steps_if_invalid(self, p, num_steps):
+        if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
+            valid_step = 999 / (1000 // num_steps)
+            if valid_step == floor(valid_step):
+                return int(valid_step) + 1
+
+        return num_steps

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        steps, t_enc = setup_img2img_steps(p, steps)
+        steps = self.adjust_steps_if_invalid(p, steps)
        self.initialize(p)

-        # existing code fails with certain step counts, like 9
-        try:
-            self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
-        except Exception:
-            self.sampler.make_schedule(ddim_num_steps=steps+1, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
+        self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)

        x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)

        self.init_latent = x

@@ -239,18 +246,14 @@ class VanillaStableDiffusionSampler:
        self.last_latent = x
        self.step = 0

-        steps = steps or p.steps
+        steps = self.adjust_steps_if_invalid(p, steps or p.steps)

        # Wrap the conditioning models with additional image conditioning for inpainting model
        if image_conditioning is not None:
            conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
            unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}

-        # existing code fails with certain step counts, like 9
-        try:
-            samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
-        except Exception:
-            samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
+        samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])

        return samples_ddim

@@ -278,6 +281,12 @@ class CFGDenoiser(torch.nn.Module):
        image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
        sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])

+        denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
+        cfg_denoiser_callback(denoiser_params)
+        x_in = denoiser_params.x
+        image_cond_in = denoiser_params.image_cond
+        sigma_in = denoiser_params.sigma

        if tensor.shape[1] == uncond.shape[1]:
            cond_in = torch.cat([tensor, uncond])
...
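
To see what adjust_steps_if_invalid catches: it replaces the old try/except retry around step counts the DDIM/PLMS schedule code failed on (like 9), bumping a count by one whenever 999 / (1000 // steps) comes out as a whole number. A small worked check mirroring the arithmetic above (plain Python, numbers chosen for illustration):

    from math import floor

    def adjust(num_steps):
        # mirrors adjust_steps_if_invalid for uniform DDIM / PLMS schedules
        valid_step = 999 / (1000 // num_steps)
        if valid_step == floor(valid_step):
            return int(valid_step) + 1
        return num_steps

    print(adjust(9))   # 1000 // 9 = 111, 999 / 111 = 9.0 exactly -> bumped to 10
    print(adjust(20))  # 1000 // 20 = 50, 999 / 50 = 19.98 -> unchanged, stays 20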
import torch
import os
from collections import namedtuple
from modules import shared, devices, script_callbacks
from modules.paths import models_path
import glob

model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
vae_dir = "VAE"
vae_path = os.path.abspath(os.path.join(models_path, vae_dir))

vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}

default_vae_dict = {"auto": "auto", "None": "None"}
default_vae_list = ["auto", "None"]

default_vae_values = [default_vae_dict[x] for x in default_vae_list]
vae_dict = dict(default_vae_dict)
vae_list = list(default_vae_list)
first_load = True

base_vae = None
loaded_vae_file = None
checkpoint_info = None

def get_base_vae(model):
    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
        return base_vae
    return None

def store_base_vae(model):
    global base_vae, checkpoint_info
    if checkpoint_info != model.sd_checkpoint_info:
        base_vae = model.first_stage_model.state_dict().copy()
        checkpoint_info = model.sd_checkpoint_info

def delete_base_vae():
    global base_vae, checkpoint_info
    base_vae = None
    checkpoint_info = None

def restore_base_vae(model):
    global base_vae, checkpoint_info
    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
        load_vae_dict(model, base_vae)
    delete_base_vae()

def get_filename(filepath):
    return os.path.splitext(os.path.basename(filepath))[0]

def refresh_vae_list(vae_path=vae_path, model_path=model_path):
    global vae_dict, vae_list
    res = {}
    candidates = [
        *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
        *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
        *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
        *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True)
    ]
    if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
        candidates.append(shared.cmd_opts.vae_path)
    for filepath in candidates:
        name = get_filename(filepath)
        res[name] = filepath
    vae_list.clear()
    vae_list.extend(default_vae_list)
    vae_list.extend(list(res.keys()))
    vae_dict.clear()
    vae_dict.update(res)
    vae_dict.update(default_vae_dict)
    return vae_list

def resolve_vae(checkpoint_file, vae_file="auto"):
    global first_load, vae_dict, vae_list
    # if the vae_file argument is provided, it takes priority, but is not saved
    if vae_file and vae_file not in default_vae_list:
        if not os.path.isfile(vae_file):
            vae_file = "auto"
            print("VAE provided as function argument doesn't exist")
    # for the first load, if vae-path is provided, it takes priority, is saved, and failure is reported
    if first_load and shared.cmd_opts.vae_path is not None:
        if os.path.isfile(shared.cmd_opts.vae_path):
            vae_file = shared.cmd_opts.vae_path
            shared.opts.data['sd_vae'] = get_filename(vae_file)
        else:
            print("VAE provided as command line argument doesn't exist")
    # otherwise, we load from settings
    if vae_file == "auto" and shared.opts.sd_vae is not None:
        # if the saved VAE setting isn't recognized, fall back to auto
        vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
        # if a VAE is selected but not found, fall back to auto
        if vae_file not in default_vae_values and not os.path.isfile(vae_file):
            vae_file = "auto"
            print("Selected VAE doesn't exist")
    # the vae-path cmd arg takes priority for auto
    if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
        if os.path.isfile(shared.cmd_opts.vae_path):
            vae_file = shared.cmd_opts.vae_path
            print("Using VAE provided as command line argument")
    # if still not found, try to look for ".vae.pt" beside the model
    model_path = os.path.splitext(checkpoint_file)[0]
    if vae_file == "auto":
        vae_file_try = model_path + ".vae.pt"
        if os.path.isfile(vae_file_try):
            vae_file = vae_file_try
            print("Using VAE found beside selected model")
    # if still not found, try to look for ".vae.ckpt" beside the model
    if vae_file == "auto":
        vae_file_try = model_path + ".vae.ckpt"
        if os.path.isfile(vae_file_try):
            vae_file = vae_file_try
            print("Using VAE found beside selected model")
    # no more fallbacks for auto
    if vae_file == "auto":
        vae_file = None
    # last check, just because
    if vae_file and not os.path.exists(vae_file):
        vae_file = None
    return vae_file

def load_vae(model, vae_file=None):
    global first_load, vae_dict, vae_list, loaded_vae_file
    # save_settings = False
    if vae_file:
        print(f"Loading VAE weights from: {vae_file}")
        vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
        vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
        load_vae_dict(model, vae_dict_1)

        # if the VAE used is not in the dict, update it
        # it will be removed on refresh though
        vae_opt = get_filename(vae_file)
        if vae_opt not in vae_dict:
            vae_dict[vae_opt] = vae_file
            vae_list.append(vae_opt)

    loaded_vae_file = vae_file

    """
    # Save current VAE to VAE settings, maybe? will it work?
    if save_settings:
        if vae_file is None:
            vae_opt = "None"
        # shared.opts.sd_vae = vae_opt
    """

    first_load = False

# don't call this from outside
def load_vae_dict(model, vae_dict_1=None):
    if vae_dict_1:
        store_base_vae(model)
        model.first_stage_model.load_state_dict(vae_dict_1)
    else:
        restore_base_vae(model)
    model.first_stage_model.to(devices.dtype_vae)

def reload_vae_weights(sd_model=None, vae_file="auto"):
    from modules import lowvram, devices, sd_hijack

    if not sd_model:
        sd_model = shared.sd_model

    checkpoint_info = sd_model.sd_checkpoint_info
    checkpoint_file = checkpoint_info.filename
    vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)

    if loaded_vae_file == vae_file:
        return

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.send_everything_to_cpu()
    else:
        sd_model.to(devices.cpu)

    sd_hijack.model_hijack.undo_hijack(sd_model)

    load_vae(sd_model, vae_file)

    sd_hijack.model_hijack.hijack(sd_model)
    script_callbacks.model_loaded_callback(sd_model)

    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
        sd_model.to(devices.device)

    print("VAE weights loaded.")
    return sd_model
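
A short usage sketch of the module above (a hypothetical caller using only names defined in this file; the checkpoint path is a made-up example):

    from modules import sd_vae

    sd_vae.refresh_vae_list()                      # populate vae_list / vae_dict from disk
    ckpt = "models/Stable-diffusion/example.ckpt"  # hypothetical checkpoint path
    vae_file = sd_vae.resolve_vae(ckpt, vae_file="auto")
    print(vae_file)  # a resolved .vae.pt / .vae.ckpt path, or None if nothing matched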
@@ -4,6 +4,7 @@ import json
import os
import sys
from collections import OrderedDict
+import time

import gradio as gr
import tqdm

@@ -14,7 +15,7 @@ import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
-from modules import sd_samplers, sd_models, localization
+from modules import sd_samplers, sd_models, localization, sd_vae
from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path

@@ -40,7 +41,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
-parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
+parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))

@@ -51,6 +52,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
+parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator") parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
...@@ -97,6 +99,8 @@ restricted_opts = { ...@@ -97,6 +99,8 @@ restricted_opts = {
"outdir_save", "outdir_save",
} }
cmd_opts.disable_extension_access = cmd_opts.share or cmd_opts.listen
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer']) (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
...@@ -132,6 +136,8 @@ class State: ...@@ -132,6 +136,8 @@ class State:
current_image = None current_image = None
current_image_sampling_step = 0 current_image_sampling_step = 0
textinfo = None textinfo = None
time_start = None
need_restart = False
def skip(self): def skip(self):
self.skipped = True self.skipped = True
...@@ -168,6 +174,7 @@ class State: ...@@ -168,6 +174,7 @@ class State:
self.skipped = False self.skipped = False
self.interrupted = False self.interrupted = False
self.textinfo = None self.textinfo = None
self.time_start = time.time()
devices.torch_gc() devices.torch_gc()
...@@ -177,6 +184,20 @@ class State: ...@@ -177,6 +184,20 @@ class State:
devices.torch_gc() devices.torch_gc()
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
def set_current_image(self):
if not parallel_processing_allowed:
return
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and self.current_latent is not None:
if opts.show_progress_grid:
self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
else:
self.current_image = sd_samplers.sample_to_image(self.current_latent)
self.current_image_sampling_step = self.sampling_step
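
# Worked example (illustrative): mirrors the throttle in set_current_image().
# With show_progress_every_n_steps = 5, a new preview is generated on steps
# 5, 10, 15, ... and the step that fired is recorded, so the next preview
# waits another full interval.
def _example_should_refresh_preview(sampling_step, last_preview_step, every_n=5):
    return every_n > 0 and sampling_step - last_preview_step >= every_n
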
state = State() state = State()
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv')) artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
...@@ -234,6 +255,8 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" ...@@ -234,6 +255,8 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"), "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."), "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."), "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}), "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"), "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
...@@ -285,21 +308,22 @@ options_templates.update(options_section(('system', "System"), { ...@@ -285,21 +308,22 @@ options_templates.update(options_section(('system', "System"), {
})) }))
options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."), "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"), "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
})) }))
options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
...@@ -354,6 +378,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" ...@@ -354,6 +378,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
})) }))
options_templates.update(options_section((None, "Hidden options"), {
    "disabled_extensions": OptionInfo([], "Disable these extensions"),
}))
options_templates.update()
class Options: class Options:
data = None data = None
...@@ -365,8 +395,9 @@ class Options: ...@@ -365,8 +395,9 @@ class Options:
def __setattr__(self, key, value): def __setattr__(self, key, value):
if self.data is not None: if self.data is not None:
if key in self.data: if key in self.data or key in self.data_labels:
self.data[key] = value self.data[key] = value
return
return super(Options, self).__setattr__(key, value) return super(Options, self).__setattr__(key, value)
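# Illustrative effect of the widened check above: a key that is declared in
# data_labels but has never been written (e.g. no entry yet in config.json)
# is now stored in the options dict instead of becoming a plain attribute:
#
#     opts.sd_vae = "auto"   # lands in opts.data even on the first write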
...@@ -407,10 +438,11 @@ class Options: ...@@ -407,10 +438,11 @@ class Options:
if bad_settings > 0: if bad_settings > 0:
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr) print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
def onchange(self, key, func): def onchange(self, key, func, call=True):
item = self.data_labels.get(key) item = self.data_labels.get(key)
item.onchange = func item.onchange = func
if call:
func() func()
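# Usage sketch (illustrative handler names): call=False registers a callback
# without firing it immediately, for handlers whose side effects should wait:
#
#     opts.onchange("sd_vae", reload_vae_handler, call=False)     # register only
#     opts.onchange("sd_model_checkpoint", reload_ckpt_handler)   # register and run once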
def dumpjson(self): def dumpjson(self):
......
...@@ -235,6 +235,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -235,6 +235,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name) log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
unload = shared.opts.unload_models_when_training
if save_embedding_every > 0: if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings") embedding_dir = os.path.join(log_directory, "embeddings")
...@@ -272,6 +273,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -272,6 +273,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"): with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
embedding.vec.requires_grad = True embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
...@@ -328,6 +331,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -328,6 +331,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if images_dir is not None and steps_done % create_image_every == 0: if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{embedding_name}-{steps_done}' forced_filename = f'{embedding_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename) last_saved_image = os.path.join(images_dir, forced_filename)
shared.sd_model.first_stage_model.to(devices.device)
p = processing.StableDiffusionProcessingTxt2Img( p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model, sd_model=shared.sd_model,
do_not_save_grid=True, do_not_save_grid=True,
...@@ -355,6 +361,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -355,6 +361,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
processed = processing.process_images(p) processed = processing.process_images(p)
image = processed.images[0] image = processed.images[0]
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
shared.state.current_image = image shared.state.current_image = image
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
...@@ -400,6 +409,7 @@ Last saved image: {html.escape(last_saved_image)}<br/> ...@@ -400,6 +409,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
shared.sd_model.first_stage_model.to(devices.device)
return embedding, filename return embedding, filename
......
...@@ -25,7 +25,9 @@ def train_embedding(*args): ...@@ -25,7 +25,9 @@ def train_embedding(*args):
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible' assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
apply_optimizations = shared.opts.training_xattention_optimizations
try: try:
if not apply_optimizations:
sd_hijack.undo_optimizations() sd_hijack.undo_optimizations()
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args) embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
...@@ -38,5 +40,6 @@ Embedding saved to {html.escape(filename)} ...@@ -38,5 +40,6 @@ Embedding saved to {html.escape(filename)}
except Exception: except Exception:
raise raise
finally: finally:
if not apply_optimizations:
sd_hijack.apply_optimizations() sd_hijack.apply_optimizations()
...@@ -47,6 +47,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: ...@@ -47,6 +47,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
if processed is None: if processed is None:
processed = process_images(p) processed = process_images(p)
p.close()
shared.total_tqdm.clear() shared.total_tqdm.clear()
generation_info_js = processed.js() generation_info_js = processed.js()
......
...@@ -19,7 +19,7 @@ import numpy as np ...@@ -19,7 +19,7 @@ import numpy as np
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from modules import sd_hijack, sd_models, localization, script_callbacks from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions
from modules.paths import script_path from modules.paths import script_path
from modules.shared import opts, cmd_opts, restricted_opts from modules.shared import opts, cmd_opts, restricted_opts
...@@ -277,15 +277,7 @@ def check_progress_call(id_part): ...@@ -277,15 +277,7 @@ def check_progress_call(id_part):
preview_visibility = gr_show(False) preview_visibility = gr_show(False)
if opts.show_progress_every_n_steps > 0: if opts.show_progress_every_n_steps > 0:
if shared.parallel_processing_allowed: shared.state.set_current_image()
if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None:
if opts.show_progress_grid:
shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent)
else:
shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
shared.state.current_image_sampling_step = shared.state.sampling_step
image = shared.state.current_image image = shared.state.current_image
if image is None: if image is None:
...@@ -671,6 +663,9 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -671,6 +663,9 @@ def create_ui(wrap_gradio_gpu_call):
import modules.img2img import modules.img2img
import modules.txt2img import modules.txt2img
reload_javascript()
parameters_copypaste.reset()
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
...@@ -1059,7 +1054,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1059,7 +1054,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tabs(elem_id="extras_resize_mode"): with gr.Tabs(elem_id="extras_resize_mode"):
with gr.TabItem('Scale by'): with gr.TabItem('Scale by'):
upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2) upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4)
with gr.TabItem('Scale to'): with gr.TabItem('Scale to'):
with gr.Group(): with gr.Group():
with gr.Row(): with gr.Row():
...@@ -1511,8 +1506,9 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1511,8 +1506,9 @@ def create_ui(wrap_gradio_gpu_call):
column = None column = None
with gr.Row(elem_id="settings").style(equal_height=False): with gr.Row(elem_id="settings").style(equal_height=False):
for i, (k, item) in enumerate(opts.data_labels.items()): for i, (k, item) in enumerate(opts.data_labels.items()):
section_must_be_skipped = item.section[0] is None
if previous_section != item.section: if previous_section != item.section and not section_must_be_skipped:
if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None): if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None):
if column is not None: if column is not None:
column.__exit__() column.__exit__()
...@@ -1531,6 +1527,8 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1531,6 +1527,8 @@ def create_ui(wrap_gradio_gpu_call):
if k in quicksettings_names and not shared.cmd_opts.freeze_settings: if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
quicksettings_list.append((i, k, item)) quicksettings_list.append((i, k, item))
components.append(dummy_component) components.append(dummy_component)
elif section_must_be_skipped:
components.append(dummy_component)
else: else:
component = create_setting_component(k) component = create_setting_component(k)
component_dict[k] = component component_dict[k] = component
...@@ -1566,19 +1564,19 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1566,19 +1564,19 @@ def create_ui(wrap_gradio_gpu_call):
reload_script_bodies.click( reload_script_bodies.click(
fn=reload_scripts, fn=reload_scripts,
inputs=[], inputs=[],
outputs=[], outputs=[]
_js='function(){}'
) )
def request_restart(): def request_restart():
shared.state.interrupt() shared.state.interrupt()
settings_interface.gradio_ref.do_restart = True shared.state.need_restart = True
restart_gradio.click( restart_gradio.click(
fn=request_restart, fn=request_restart,
inputs=[], inputs=[],
outputs=[], outputs=[],
_js='function(){restart_reload()}' _js='restart_reload'
) )
if column is not None: if column is not None:
...@@ -1612,14 +1610,15 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1612,14 +1610,15 @@ def create_ui(wrap_gradio_gpu_call):
interfaces += script_callbacks.ui_tabs_callback() interfaces += script_callbacks.ui_tabs_callback()
interfaces += [(settings_interface, "Settings", "settings")] interfaces += [(settings_interface, "Settings", "settings")]
extensions_interface = ui_extensions.create_ui()
interfaces += [(extensions_interface, "Extensions", "extensions")]
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings"): with gr.Row(elem_id="quicksettings"):
for i, k, item in quicksettings_list: for i, k, item in quicksettings_list:
component = create_setting_component(k, is_quicksettings=True) component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component component_dict[k] = component
settings_interface.gradio_ref = demo
parameters_copypaste.integrate_settings_paste_fields(component_dict) parameters_copypaste.integrate_settings_paste_fields(component_dict)
parameters_copypaste.run_bind() parameters_copypaste.run_bind()
...@@ -1776,4 +1775,3 @@ def load_javascript(raw_response): ...@@ -1776,4 +1775,3 @@ def load_javascript(raw_response):
reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse) reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
reload_javascript()
import json
import os.path
import shutil
import sys
import time
import traceback

import git

import gradio as gr
import html

from modules import extensions, shared, paths


available_extensions = {"extensions": []}


def check_access():
    assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags"


def apply_and_restart(disable_list, update_list):
    check_access()

    disabled = json.loads(disable_list)
    assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"

    update = json.loads(update_list)
    assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"

    update = set(update)

    for ext in extensions.extensions:
        if ext.name not in update:
            continue

        try:
            ext.pull()
        except Exception:
            print(f"Error pulling updates for {ext.name}:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

    shared.opts.disabled_extensions = disabled
    shared.opts.save(shared.config_filename)

    shared.state.interrupt()
    shared.state.need_restart = True
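
# Illustrative call (hypothetical extension names): the UI passes JSON-encoded
# arrays of names, so this disables one extension, marks another for a git
# pull, saves the setting, and flags the server loop for a restart.
#
#     apply_and_restart('["some-extension"]', '["another-extension"]')
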
def check_updates():
    check_access()

    for ext in extensions.extensions:
        if ext.remote is None:
            continue

        try:
            ext.check_updates()
        except Exception:
            print(f"Error checking updates for {ext.name}:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

    return extension_table()


def extension_table():
    code = f"""<!-- {time.time()} -->
<table id="extensions">
    <thead>
        <tr>
            <th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click the Apply button">Extension</abbr></th>
            <th>URL</th>
            <th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click the Apply button">Update</abbr></th>
        </tr>
    </thead>
    <tbody>
"""

    for ext in extensions.extensions:
        if ext.can_update:
            ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
        else:
            ext_status = ext.status

        code += f"""
        <tr>
            <td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
            <td><a href="{html.escape(ext.remote or '')}">{html.escape(ext.remote or '')}</a></td>
            <td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
        </tr>
"""

    code += """
    </tbody>
</table>
"""

    return code


def normalize_git_url(url):
    if url is None:
        return ""

    # strip only a trailing ".git", not occurrences elsewhere in the URL
    if url.endswith(".git"):
        url = url[: -len(".git")]
    return url
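
# Illustrative check of the normalization above (safe to run; the URLs are
# hypothetical):
def _example_normalize_git_url():
    assert normalize_git_url("https://github.com/user/repo.git") == "https://github.com/user/repo"
    assert normalize_git_url("https://github.com/user/repo") == "https://github.com/user/repo"
    assert normalize_git_url(None) == ""
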
def install_extension_from_url(dirname, url):
    check_access()

    assert url, 'No URL specified'

    if dirname is None or dirname == "":
        *parts, last_part = url.split('/')
        last_part = normalize_git_url(last_part)

        dirname = last_part

    target_dir = os.path.join(extensions.extensions_dir, dirname)
    assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'

    normalized_url = normalize_git_url(url)
    assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'

    tmpdir = os.path.join(paths.script_path, "tmp", dirname)

    try:
        shutil.rmtree(tmpdir, True)

        repo = git.Repo.clone_from(url, tmpdir)
        repo.remote().fetch()

        os.rename(tmpdir, target_dir)

        extensions.list_extensions()
        return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
    finally:
        shutil.rmtree(tmpdir, True)
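
# Usage sketch (illustrative; the URL is hypothetical). The clone lands in a
# temp dir and is os.rename()d into extensions/ only on success, so a failed
# clone never leaves a half-installed extension behind.
#
#     table_html, message = install_extension_from_url("", "https://github.com/user/some-extension")
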
def install_extension_from_index(url):
    ext_table, message = install_extension_from_url(None, url)

    return refresh_available_extensions_from_data(), ext_table, message


def refresh_available_extensions(url):
    global available_extensions

    import urllib.request
    with urllib.request.urlopen(url) as response:
        text = response.read()

    available_extensions = json.loads(text)

    return url, refresh_available_extensions_from_data(), ''


def refresh_available_extensions_from_data():
    extlist = available_extensions["extensions"]
    installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}

    code = f"""<!-- {time.time()} -->
<table id="available_extensions">
    <thead>
        <tr>
            <th>Extension</th>
            <th>Description</th>
            <th>Action</th>
        </tr>
    </thead>
    <tbody>
"""

    for ext in extlist:
        name = ext.get("name", "noname")
        url = ext.get("url", None)
        description = ext.get("description", "")

        if url is None:
            continue

        existing = installed_extension_urls.get(normalize_git_url(url), None)

        install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""

        code += f"""
        <tr>
            <td><a href="{html.escape(url)}">{html.escape(name)}</a></td>
            <td>{html.escape(description)}</td>
            <td>{install_code}</td>
        </tr>
"""

    code += """
    </tbody>
</table>
"""

    return code
def create_ui():
    import modules.ui

    with gr.Blocks(analytics_enabled=False) as ui:
        with gr.Tabs(elem_id="tabs_extensions") as tabs:
            with gr.TabItem("Installed"):
                with gr.Row():
                    apply = gr.Button(value="Apply and restart UI", variant="primary")
                    check = gr.Button(value="Check for updates")
                    extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
                    extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)

                extensions_table = gr.HTML(lambda: extension_table())

                apply.click(
                    fn=apply_and_restart,
                    _js="extensions_apply",
                    inputs=[extensions_disabled_list, extensions_update_list],
                    outputs=[],
                )

                check.click(
                    fn=check_updates,
                    _js="extensions_check",
                    inputs=[],
                    outputs=[extensions_table],
                )

            with gr.TabItem("Available"):
                with gr.Row():
                    refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
                    available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False)
                    extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
                    install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)

                install_result = gr.HTML()
                available_extensions_table = gr.HTML()

                refresh_available_extensions_button.click(
                    fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update()]),
                    inputs=[available_extensions_index],
                    outputs=[available_extensions_index, available_extensions_table, install_result],
                )

                install_extension_button.click(
                    fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
                    inputs=[extension_to_install],
                    outputs=[available_extensions_table, extensions_table, install_result],
                )

            with gr.TabItem("Install from URL"):
                install_url = gr.Text(label="URL for extension's git repository")
                install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
                install_button = gr.Button(value="Install", variant="primary")
                install_result = gr.HTML(elem_id="extension_install_result")

                install_button.click(
                    fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
                    inputs=[install_dirname, install_url],
                    outputs=[extensions_table, install_result],
                )

    return ui
...@@ -10,6 +10,7 @@ import modules.shared ...@@ -10,6 +10,7 @@ import modules.shared
from modules import modelloader, shared from modules import modelloader, shared
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
from modules.paths import models_path from modules.paths import models_path
...@@ -57,7 +58,7 @@ class Upscaler: ...@@ -57,7 +58,7 @@ class Upscaler:
dest_w = img.width * scale dest_w = img.width * scale
dest_h = img.height * scale dest_h = img.height * scale
for i in range(3): for i in range(3):
if img.width >= dest_w and img.height >= dest_h: if img.width > dest_w and img.height > dest_h:
break break
img = self.do_upscale(img, selected_model) img = self.do_upscale(img, selected_model)
if img.width != dest_w or img.height != dest_h: if img.width != dest_w or img.height != dest_h:
...@@ -120,3 +121,17 @@ class UpscalerLanczos(Upscaler): ...@@ -120,3 +121,17 @@ class UpscalerLanczos(Upscaler):
self.name = "Lanczos" self.name = "Lanczos"
self.scalers = [UpscalerData("Lanczos", None, self)] self.scalers = [UpscalerData("Lanczos", None, self)]
class UpscalerNearest(Upscaler):
    scalers = []

    def do_upscale(self, img, selected_model=None):
        return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)

    def load_model(self, _):
        pass

    def __init__(self, dirname=None):
        super().__init__(False)
        self.name = "Nearest"
        self.scalers = [UpscalerData("Nearest", None, self)]
\ No newline at end of file
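# Usage sketch (illustrative): nearest-neighbour keeps hard pixel edges,
# e.g. for pixel art. Assumes the base Upscaler sets self.scale before use.
#
#     from PIL import Image
#     up = UpscalerNearest()
#     up.scale = 2
#     big = up.do_upscale(Image.new("RGB", (64, 64)), None)  # -> 128x128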
...@@ -4,7 +4,7 @@ fairscale==0.4.4 ...@@ -4,7 +4,7 @@ fairscale==0.4.4
fonts fonts
font-roboto font-roboto
gfpgan gfpgan
gradio==3.5 gradio==3.8
invisible-watermark invisible-watermark
numpy numpy
omegaconf omegaconf
...@@ -12,7 +12,7 @@ opencv-python ...@@ -12,7 +12,7 @@ opencv-python
requests requests
piexif piexif
Pillow Pillow
pytorch_lightning pytorch_lightning==1.7.7
realesrgan realesrgan
scikit-image>=0.19 scikit-image>=0.19
timm==0.4.12 timm==0.4.12
...@@ -26,3 +26,4 @@ torchdiffeq ...@@ -26,3 +26,4 @@ torchdiffeq
kornia kornia
lark lark
inflection inflection
GitPython
...@@ -2,7 +2,7 @@ transformers==4.19.2 ...@@ -2,7 +2,7 @@ transformers==4.19.2
diffusers==0.3.0 diffusers==0.3.0
basicsr==1.4.2 basicsr==1.4.2
gfpgan==1.3.8 gfpgan==1.3.8
gradio==3.5 gradio==3.8
numpy==1.23.3 numpy==1.23.3
Pillow==9.2.0 Pillow==9.2.0
realesrgan==0.3.0 realesrgan==0.3.0
...@@ -23,3 +23,4 @@ torchdiffeq==0.2.3 ...@@ -23,3 +23,4 @@ torchdiffeq==0.2.3
kornia==0.6.7 kornia==0.6.7
lark==1.1.2 lark==1.1.2
inflection==0.5.1 inflection==0.5.1
GitPython==3.1.27
...@@ -166,8 +166,7 @@ class Script(scripts.Script): ...@@ -166,8 +166,7 @@ class Script(scripts.Script):
if override_strength: if override_strength:
p.denoising_strength = 1.0 p.denoising_strength = 1.0
def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
lat = (p.init_latent.cpu().numpy() * 10).astype(int) lat = (p.init_latent.cpu().numpy() * 10).astype(int)
same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st \ same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st \
......
...@@ -96,6 +96,7 @@ class Script(scripts.Script): ...@@ -96,6 +96,7 @@ class Script(scripts.Script):
def ui(self, is_img2img): def ui(self, is_img2img):
checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False) checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False)
checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False)
prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1) prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1)
file = gr.File(label="Upload prompt inputs", type='bytes') file = gr.File(label="Upload prompt inputs", type='bytes')
...@@ -106,9 +107,9 @@ class Script(scripts.Script): ...@@ -106,9 +107,9 @@ class Script(scripts.Script):
# We don't shrink back to 1, because that causes the control to ignore [enter], and it may # We don't shrink back to 1, because that causes the control to ignore [enter], and it may
# be unclear to the user that shift-enter is needed. # be unclear to the user that shift-enter is needed.
prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt]) prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt])
return [checkbox_iterate, file, prompt_txt] return [checkbox_iterate, checkbox_iterate_batch, file, prompt_txt]
def run(self, p, checkbox_iterate, file, prompt_txt: str): def run(self, p, checkbox_iterate, checkbox_iterate_batch, file, prompt_txt: str):
lines = [x.strip() for x in prompt_txt.splitlines()] lines = [x.strip() for x in prompt_txt.splitlines()]
lines = [x for x in lines if len(x) > 0] lines = [x for x in lines if len(x) > 0]
...@@ -137,7 +138,7 @@ class Script(scripts.Script): ...@@ -137,7 +138,7 @@ class Script(scripts.Script):
jobs.append(args) jobs.append(args)
print(f"Will process {len(lines)} lines in {job_count} jobs.") print(f"Will process {len(lines)} lines in {job_count} jobs.")
if (checkbox_iterate and p.seed == -1): if (checkbox_iterate or checkbox_iterate_batch) and p.seed == -1:
p.seed = int(random.randrange(4294967294)) p.seed = int(random.randrange(4294967294))
state.job_count = job_count state.job_count = job_count
...@@ -153,7 +154,7 @@ class Script(scripts.Script): ...@@ -153,7 +154,7 @@ class Script(scripts.Script):
proc = process_images(copy_p) proc = process_images(copy_p)
images += proc.images images += proc.images
if (checkbox_iterate): if checkbox_iterate:
p.seed = p.seed + (p.batch_size * p.n_iter) p.seed = p.seed + (p.batch_size * p.n_iter)
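# Worked example (illustrative numbers): with batch_size=2 and n_iter=1,
# "Iterate seed every line" advances the seed by 2 per prompt line, so a
# starting seed of 1000 produces 1000, 1002, 1004, ... and no line reuses
# the previous line's images.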
......
...@@ -260,6 +260,16 @@ input[type="range"]{ ...@@ -260,6 +260,16 @@ input[type="range"]{
#txt2img_negative_prompt, #img2img_negative_prompt{ #txt2img_negative_prompt, #img2img_negative_prompt{
} }
/* gradio 3.8 adds opacity to progressbar which makes it blink; disable it here */
.transition.opacity-20 {
opacity: 1 !important;
}
/* more of gradio's garbage cleanup */
.min-h-\[4rem\] {
min-height: unset !important;
}
#txt2img_progressbar, #img2img_progressbar, #ti_progressbar{ #txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
position: absolute; position: absolute;
z-index: 1000; z-index: 1000;
...@@ -491,7 +501,7 @@ input[type="range"]{ ...@@ -491,7 +501,7 @@ input[type="range"]{
padding: 0; padding: 0;
} }
#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{ #refresh_sd_model_checkpoint, #refresh_sd_vae, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
max-width: 2.5em; max-width: 2.5em;
min-width: 2.5em; min-width: 2.5em;
height: 2.4em; height: 2.4em;
...@@ -530,6 +540,29 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h ...@@ -530,6 +540,29 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h
min-height: 480px !important; min-height: 480px !important;
} }
/* Extensions */
#tab_extensions table{
border-collapse: collapse;
}
#tab_extensions table td, #tab_extensions table th{
border: 1px solid #ccc;
padding: 0.25em 0.5em;
}
#tab_extensions table input[type="checkbox"]{
margin-right: 0.5em;
}
#tab_extensions button{
max-width: 16em;
}
#tab_extensions input[disabled="disabled"]{
opacity: 0.5;
}
/* The following handles localization for right-to-left (RTL) languages like Arabic. /* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js. The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running If you change anything above, you need to make sure it is RTL compliant by just running
......
...@@ -9,7 +9,7 @@ from fastapi.middleware.gzip import GZipMiddleware ...@@ -9,7 +9,7 @@ from fastapi.middleware.gzip import GZipMiddleware
from modules.paths import script_path from modules.paths import script_path
from modules import devices, sd_samplers, upscaler from modules import devices, sd_samplers, upscaler, extensions
import modules.codeformer_model as codeformer import modules.codeformer_model as codeformer
import modules.extras import modules.extras
import modules.face_restoration import modules.face_restoration
...@@ -21,8 +21,10 @@ import modules.paths ...@@ -21,8 +21,10 @@ import modules.paths
import modules.scripts import modules.scripts
import modules.sd_hijack import modules.sd_hijack
import modules.sd_models import modules.sd_models
import modules.sd_vae
import modules.shared as shared import modules.shared as shared
import modules.txt2img import modules.txt2img
import modules.script_callbacks
import modules.ui import modules.ui
from modules import devices from modules import devices
...@@ -60,6 +62,8 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): ...@@ -60,6 +62,8 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
def initialize(): def initialize():
extensions.list_extensions()
if cmd_opts.ui_debug_mode: if cmd_opts.ui_debug_mode:
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
modules.scripts.load_scripts() modules.scripts.load_scripts()
...@@ -74,8 +78,10 @@ def initialize(): ...@@ -74,8 +78,10 @@ def initialize():
modules.scripts.load_scripts() modules.scripts.load_scripts()
modules.sd_vae.refresh_vae_list()
modules.sd_models.load_model() modules.sd_models.load_model()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
...@@ -92,15 +98,18 @@ def create_api(app): ...@@ -92,15 +98,18 @@ def create_api(app):
api = Api(app, queue_lock) api = Api(app, queue_lock)
return api return api
def wait_on_server(demo=None): def wait_on_server(demo=None):
while 1: while 1:
time.sleep(0.5) time.sleep(0.5)
if demo and getattr(demo, 'do_restart', False): if shared.state.need_restart:
shared.state.need_restart = False
time.sleep(0.5) time.sleep(0.5)
demo.close() demo.close()
time.sleep(0.5) time.sleep(0.5)
break break
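# Illustrative restart flow: the Settings and Extensions tabs set
# shared.state.need_restart = True; this loop notices the flag, resets it,
# closes the gradio demo, and webui() rebuilds the UI on the next iteration.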
def api_only(): def api_only():
initialize() initialize()
...@@ -108,6 +117,8 @@ def api_only(): ...@@ -108,6 +117,8 @@ def api_only():
app.add_middleware(GZipMiddleware, minimum_size=1000) app.add_middleware(GZipMiddleware, minimum_size=1000)
api = create_api(app) api = create_api(app)
modules.script_callbacks.app_started_callback(None, app)
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
...@@ -132,14 +143,18 @@ def webui(): ...@@ -132,14 +143,18 @@ def webui():
app.add_middleware(GZipMiddleware, minimum_size=1000) app.add_middleware(GZipMiddleware, minimum_size=1000)
if (launch_api): if launch_api:
create_api(app) create_api(app)
modules.script_callbacks.app_started_callback(demo, app)
wait_on_server(demo) wait_on_server(demo)
sd_samplers.set_samplers() sd_samplers.set_samplers()
print('Reloading Custom Scripts') print('Reloading extensions')
extensions.list_extensions()
print('Reloading custom scripts')
modules.scripts.reload_scripts() modules.scripts.reload_scripts()
print('Reloading modules: modules.ui') print('Reloading modules: modules.ui')
importlib.reload(modules.ui) importlib.reload(modules.ui)
...@@ -148,8 +163,6 @@ def webui(): ...@@ -148,8 +163,6 @@ def webui():
print('Restarting Gradio') print('Restarting Gradio')
task = []
if __name__ == "__main__": if __name__ == "__main__":
if cmd_opts.nowebui: if cmd_opts.nowebui:
api_only() api_only()
......