Commit fb3b5648 authored by Muhammad Rizqi Nur's avatar Muhammad Rizqi Nur

Merge branch 'master' into fix-ckpt-cache

parents bf7a6998 172c4bc0
function extensions_apply(_, _){
disable = []
update = []
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
if(x.name.startsWith("enable_") && ! x.checked)
disable.push(x.name.substr(7))
if(x.name.startsWith("update_") && x.checked)
update.push(x.name.substr(7))
})
restart_reload()
return [JSON.stringify(disable), JSON.stringify(update)]
}
function extensions_check(){
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
x.innerHTML = "Loading..."
})
return []
}
function install_extension_from_index(button, url){
button.disabled = "disabled"
button.value = "Installing..."
textarea = gradioApp().querySelector('#extension_to_install textarea')
textarea.value = url
textarea.dispatchEvent(new Event("input", { bubbles: true }))
gradioApp().querySelector('#install_extension_button').click()
}
......@@ -3,8 +3,21 @@ global_progressbars = {}
galleries = {}
galleryObservers = {}
// this tracks laumnches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running
timeoutIds = {}
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
var progressbar = gradioApp().getElementById(id_progressbar)
// gradio 3.8's enlightened approach allows them to create two nested div elements inside each other with same id
// every time you use gr.HTML(elem_id='xxx'), so we handle this here
var progressbar = gradioApp().querySelector("#"+id_progressbar+" #"+id_progressbar)
var progressbarParent
if(progressbar){
progressbarParent = gradioApp().querySelector("#"+id_progressbar)
} else{
progressbar = gradioApp().getElementById(id_progressbar)
progressbarParent = null
}
var skip = id_skip ? gradioApp().getElementById(id_skip) : null
var interrupt = gradioApp().getElementById(id_interrupt)
......@@ -26,18 +39,26 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
global_progressbars[id_progressbar] = progressbar
var mutationObserver = new MutationObserver(function(m){
if(timeoutIds[id_part]) return;
preview = gradioApp().getElementById(id_preview)
gallery = gradioApp().getElementById(id_gallery)
if(preview != null && gallery != null){
preview.style.width = gallery.clientWidth + "px"
preview.style.height = gallery.clientHeight + "px"
if(progressbarParent) progressbar.style.width = progressbarParent.clientWidth + "px"
//only watch gallery if there is a generation process going on
check_gallery(id_gallery);
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
if(!progressDiv){
if(progressDiv){
timeoutIds[id_part] = window.setTimeout(function() {
timeoutIds[id_part] = null
requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt)
}, 500)
} else{
if (skip) {
skip.style.display = "none"
}
......@@ -49,11 +70,8 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
galleries[id_gallery] = null;
}
}
}
window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500)
});
mutationObserver.observe( progressbar, { childList:true, subtree:true })
}
......
......@@ -7,6 +7,7 @@ import shlex
import platform
dir_repos = "repositories"
dir_extensions = "extensions"
python = sys.executable
git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")
......@@ -16,11 +17,11 @@ def extract_arg(args, name):
return [x for x in args if x != name], name in args
def run(command, desc=None, errdesc=None):
def run(command, desc=None, errdesc=None, custom_env=None):
if desc is not None:
print(desc)
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
if result.returncode != 0:
......@@ -101,7 +102,25 @@ def version_check(commit):
else:
print("Not a git clone, can't perform version check.")
except Exception as e:
print("versipm check failed",e)
print("version check failed", e)
def run_extensions_installers():
if not os.path.isdir(dir_extensions):
return
for dirname_extension in os.listdir(dir_extensions):
path_installer = os.path.join(dir_extensions, dirname_extension, "install.py")
if not os.path.isfile(path_installer):
continue
try:
env = os.environ.copy()
env['PYTHONPATH'] = os.path.abspath(".")
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {dirname_extension}", custom_env=env))
except Exception as e:
print(e, file=sys.stderr)
def prepare_enviroment():
......@@ -189,6 +208,8 @@ def prepare_enviroment():
run_pip(f"install -r {requirements_file}", "requirements for Web UI")
run_extensions_installers()
if update_check:
version_check(commit)
......
......@@ -70,7 +70,7 @@
"None": "Nichts",
"Prompt matrix": "Promptmatrix",
"Prompts from file or textbox": "Prompts aus Datei oder Textfeld",
"X/Y plot": "X/Y Graf",
"X/Y plot": "X/Y Graph",
"Put variable parts at start of prompt": "Variable teile am start des Prompt setzen",
"Iterate seed every line": "Iterate seed every line",
"List of prompt inputs": "List of prompt inputs",
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import base64
import io
import time
import uvicorn
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, HTTPException
import modules.shared as shared
from modules import devices
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.sd_samplers import all_samplers, sample_to_image, samples_to_image_grid
from modules.extras import run_extras, run_pnginfo
......@@ -29,6 +30,12 @@ def setUpscalers(req: dict):
return reqDict
def encode_pil_to_base64(image):
buffer = io.BytesIO()
image.save(buffer, format="png")
return base64.b64encode(buffer.getvalue())
class Api:
def __init__(self, app, queue_lock):
self.router = APIRouter()
......@@ -40,6 +47,7 @@ class Api:
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
......@@ -170,12 +178,19 @@ class Api:
progress = min(progress, 1)
shared.state.set_current_image()
current_image = None
if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image)
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
def interruptapi(self):
shared.state.interrupt()
return {}
def launch(self, server_name, port):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port)
......@@ -50,6 +50,7 @@ def mod2normal(state_dict):
def resrgan2normal(state_dict, nb=23):
# this code is copied from https://github.com/victorca25/iNNfer
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
re8x = 0
crt_net = {}
items = []
for k, v in state_dict.items():
......@@ -75,10 +76,18 @@ def resrgan2normal(state_dict, nb=23):
crt_net['model.3.bias'] = state_dict['conv_up1.bias']
crt_net['model.6.weight'] = state_dict['conv_up2.weight']
crt_net['model.6.bias'] = state_dict['conv_up2.bias']
crt_net['model.8.weight'] = state_dict['conv_hr.weight']
crt_net['model.8.bias'] = state_dict['conv_hr.bias']
crt_net['model.10.weight'] = state_dict['conv_last.weight']
crt_net['model.10.bias'] = state_dict['conv_last.bias']
if 'conv_up3.weight' in state_dict:
# modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
re8x = 3
crt_net['model.9.weight'] = state_dict['conv_up3.weight']
crt_net['model.9.bias'] = state_dict['conv_up3.bias']
crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
state_dict = crt_net
return state_dict
......
import os
import sys
import traceback
import git
from modules import paths, shared
extensions = []
extensions_dir = os.path.join(paths.script_path, "extensions")
def active():
return [x for x in extensions if x.enabled]
class Extension:
def __init__(self, name, path, enabled=True):
self.name = name
self.path = path
self.enabled = enabled
self.status = ''
self.can_update = False
repo = None
try:
if os.path.exists(os.path.join(path, ".git")):
repo = git.Repo(path)
except Exception:
print(f"Error reading github repository info from {path}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
if repo is None or repo.bare:
self.remote = None
else:
self.remote = next(repo.remote().urls, None)
self.status = 'unknown'
def list_files(self, subdir, extension):
from modules import scripts
dirpath = os.path.join(self.path, subdir)
if not os.path.isdir(dirpath):
return []
res = []
for filename in sorted(os.listdir(dirpath)):
res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
return res
def check_updates(self):
repo = git.Repo(self.path)
for fetch in repo.remote().fetch("--dry-run"):
if fetch.flags != fetch.HEAD_UPTODATE:
self.can_update = True
self.status = "behind"
return
self.can_update = False
self.status = "latest"
def pull(self):
repo = git.Repo(self.path)
repo.remotes.origin.pull()
def list_extensions():
extensions.clear()
if not os.path.isdir(extensions_dir):
return
for dirname in sorted(os.listdir(extensions_dir)):
path = os.path.join(extensions_dir, dirname)
if not os.path.isdir(path):
continue
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
extensions.append(extension)
......@@ -141,7 +141,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
upscaling_resize_w, upscaling_resize_h, upscaling_crop)
cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()),
info_hash=hash(info),
args_hash=hash(upscale_args))
args_hash=hash((upscale_args, upscale_first)))
cached_entry = cached_images.get(cache_key)
if cached_entry is None:
res = upscale(image, *upscale_args)
......
......@@ -17,6 +17,11 @@ paste_fields = {}
bind_list = []
def reset():
paste_fields.clear()
bind_list.clear()
def quote(text):
if ',' not in str(text):
return text
......
......@@ -510,6 +510,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
if extension.lower() == '.png':
pnginfo_data = PngImagePlugin.PngInfo()
if opts.enable_pnginfo:
for k, v in params.pnginfo.items():
pnginfo_data.add_text(k, str(v))
......
......@@ -55,6 +55,7 @@ def process_batch(p, input_dir, output_dir, args):
filename = f"{left}-{n}{right}"
if not save_normally:
os.makedirs(output_dir, exist_ok=True)
processed_image.save(os.path.join(output_dir, filename))
......@@ -80,6 +81,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
mask = None
# Use the EXIF orientation of photos taken by smartphones.
if image is not None:
image = ImageOps.exif_transpose(image)
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
......@@ -136,6 +138,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
if processed is None:
processed = process_images(p)
p.close()
shared.total_tqdm.clear()
generation_info_js = processed.js()
......
......@@ -56,9 +56,9 @@ class InterrogateModels:
import clip
if self.running_on_cpu:
model, preprocess = clip.load(clip_model_name, device="cpu")
model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
else:
model, preprocess = clip.load(clip_model_name)
model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
model.eval()
model = model.to(devices.device_interrogate)
......
......@@ -38,13 +38,18 @@ def setup_for_low_vram(sd_model, use_medvram):
# see below for register_forward_pre_hook;
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
# useless here, and we just replace those methods
def first_stage_model_encode_wrap(self, encoder, x):
send_me_to_gpu(self, None)
return encoder(x)
def first_stage_model_decode_wrap(self, decoder, z):
send_me_to_gpu(self, None)
return decoder(z)
first_stage_model = sd_model.first_stage_model
first_stage_model_encode = sd_model.first_stage_model.encode
first_stage_model_decode = sd_model.first_stage_model.decode
def first_stage_model_encode_wrap(x):
send_me_to_gpu(first_stage_model, None)
return first_stage_model_encode(x)
def first_stage_model_decode_wrap(z):
send_me_to_gpu(first_stage_model, None)
return first_stage_model_decode(z)
# remove three big modules, cond, first_stage, and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU.
......@@ -56,8 +61,8 @@ def setup_for_low_vram(sd_model, use_medvram):
# register hooks for those the first two models
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if use_medvram:
......
......@@ -85,6 +85,9 @@ def cleanup_models():
src_path = os.path.join(root_path, "ESRGAN")
dest_path = os.path.join(models_path, "ESRGAN")
move_files(src_path, dest_path)
src_path = os.path.join(models_path, "BSRGAN")
dest_path = os.path.join(models_path, "ESRGAN")
move_files(src_path, dest_path, ".pth")
src_path = os.path.join(root_path, "gfpgan")
dest_path = os.path.join(models_path, "GFPGAN")
move_files(src_path, dest_path)
......
......@@ -199,9 +199,13 @@ class StableDiffusionProcessing():
def init(self, all_prompts, all_seeds, all_subseeds):
pass
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
raise NotImplementedError()
def close(self):
self.sd_model = None
self.sampler = None
class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
......@@ -517,7 +521,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
with devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
samples_ddim = samples_ddim.to(devices.dtype_vae)
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
......@@ -645,7 +649,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
......@@ -658,9 +662,21 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
"""saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images"""
def save_intermediate(image, index):
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
return
if not isinstance(image, Image.Image):
image = sd_samplers.sample_to_image(image, index)
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
if opts.use_scale_latent_for_hires_fix:
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
for i in range(samples.shape[0]):
save_intermediate(samples, i)
else:
decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
......@@ -670,6 +686,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
image = Image.fromarray(x_sample)
save_intermediate(image, i)
image = images.resize_image(0, image, self.width, self.height)
image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
......@@ -827,8 +846,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
......
......@@ -32,7 +32,7 @@ class RestrictedUnpickler(pickle.Unpickler):
return getattr(collections, name)
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
return getattr(torch._utils, name)
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']:
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage']:
return getattr(torch, name)
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
return getattr(torch.nn.modules.container, name)
......
......@@ -2,7 +2,10 @@ import sys
import traceback
from collections import namedtuple
import inspect
from typing import Optional
from fastapi import FastAPI
from gradio import Blocks
def report_exception(c, job):
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
......@@ -24,12 +27,32 @@ class ImageSaveParams:
"""dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
class CFGDenoiserParams:
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
self.x = x
"""Latent image representation in the process of being denoised"""
self.image_cond = image_cond
"""Conditioning image"""
self.sigma = sigma
"""Current sigma noise step value"""
self.sampling_step = sampling_step
"""Current Sampling step number"""
self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned"""
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callbacks_app_started = []
callbacks_model_loaded = []
callbacks_ui_tabs = []
callbacks_ui_settings = []
callbacks_before_image_saved = []
callbacks_image_saved = []
callbacks_cfg_denoiser = []
def clear_callbacks():
......@@ -38,6 +61,14 @@ def clear_callbacks():
callbacks_ui_settings.clear()
callbacks_before_image_saved.clear()
callbacks_image_saved.clear()
callbacks_cfg_denoiser.clear()
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
for c in callbacks_app_started:
try:
c.callback(demo, app)
except Exception:
report_exception(c, 'app_started_callback')
def model_loaded_callback(sd_model):
......@@ -69,7 +100,7 @@ def ui_settings_callback():
def before_image_saved_callback(params: ImageSaveParams):
for c in callbacks_image_saved:
for c in callbacks_before_image_saved:
try:
c.callback(params)
except Exception:
......@@ -84,6 +115,14 @@ def image_saved_callback(params: ImageSaveParams):
report_exception(c, 'image_saved_callback')
def cfg_denoiser_callback(params: CFGDenoiserParams):
for c in callbacks_cfg_denoiser:
try:
c.callback(params)
except Exception:
report_exception(c, 'cfg_denoiser_callback')
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
......@@ -91,6 +130,12 @@ def add_callback(callbacks, fun):
callbacks.append(ScriptCallback(filename, fun))
def on_app_started(callback):
"""register a function to be called when the webui started, the gradio `Block` component and
fastapi `FastAPI` object are passed as the arguments"""
add_callback(callbacks_app_started, callback)
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
......@@ -130,3 +175,12 @@ def on_image_saved(callback):
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
"""
add_callback(callbacks_image_saved, callback)
def on_cfg_denoiser(callback):
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
The callback is called with one argument:
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
"""
add_callback(callbacks_cfg_denoiser, callback)
......@@ -7,7 +7,7 @@ import modules.ui as ui
import gradio as gr
from modules.processing import StableDiffusionProcessing
from modules import shared, paths, script_callbacks
from modules import shared, paths, script_callbacks, extensions
AlwaysVisible = object()
......@@ -107,17 +107,8 @@ def list_scripts(scriptdirname, extension):
for filename in sorted(os.listdir(basedir)):
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
extdir = os.path.join(paths.script_path, "extensions")
if os.path.exists(extdir):
for dirname in sorted(os.listdir(extdir)):
dirpath = os.path.join(extdir, dirname)
scriptdirpath = os.path.join(dirpath, scriptdirname)
if not os.path.isdir(scriptdirpath):
continue
for filename in sorted(os.listdir(scriptdirpath)):
scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename)))
for ext in extensions.active():
scripts_list += ext.list_files(scriptdirname, extension)
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
......@@ -127,11 +118,7 @@ def list_scripts(scriptdirname, extension):
def list_files_with_name(filename):
res = []
dirs = [paths.script_path]
extdir = os.path.join(paths.script_path, "extensions")
if os.path.exists(extdir):
dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))]
dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
for dirpath in dirs:
if not os.path.isdir(dirpath):
......
......@@ -94,6 +94,10 @@ class StableDiffusionModelHijack:
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
self.layers = None
self.circular_enabled = False
self.clip = None
def apply_circular(self, enable):
if self.circular_enabled == enable:
return
......
import collections
import os.path
import sys
import gc
from collections import namedtuple
import torch
import re
......@@ -8,7 +9,7 @@ from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from modules import shared, modelloader, devices, script_callbacks
from modules import shared, modelloader, devices, script_callbacks, sd_vae
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
......@@ -158,14 +159,12 @@ def get_state_dict_from_checkpoint(pl_sd):
return pl_sd
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
def load_model_weights(model, checkpoint_info):
def load_model_weights(model, checkpoint_info, vae_file="auto"):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
if shared.opts.sd_checkpoint_cache > 0 and hasattr(model, "sd_checkpoint_info"):
sd_vae.restore_base_vae(model)
checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy()
if checkpoint_info not in checkpoints_loaded:
......@@ -184,25 +183,23 @@ def load_model_weights(model, checkpoint_info):
model.to(memory_format=torch.channels_last)
if not shared.cmd_opts.no_half:
vae = model.first_stage_model
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
if shared.cmd_opts.no_half_vae:
model.first_stage_model = None
model.half()
model.first_stage_model = vae
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
vae_file = shared.cmd_opts.vae_path
if os.path.exists(vae_file):
print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
model.first_stage_model.load_state_dict(vae_dict)
model.first_stage_model.to(devices.dtype_vae)
else:
print(f"Loading weights [{sd_model_hash}] from cache")
vae_name = sd_vae.get_filename(vae_file)
print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache")
model.load_state_dict(checkpoints_loaded[checkpoint_info])
if shared.opts.sd_checkpoint_cache > 0:
......@@ -213,6 +210,8 @@ def load_model_weights(model, checkpoint_info):
model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info
sd_vae.load_vae(model, vae_file)
def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
......@@ -221,6 +220,12 @@ def load_model(checkpoint_info=None):
if checkpoint_info.config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_info.config}")
if shared.sd_model:
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
shared.sd_model = None
gc.collect()
devices.torch_gc()
sd_config = OmegaConf.load(checkpoint_info.config)
if should_hijack_inpainting(checkpoint_info):
......@@ -234,6 +239,7 @@ def load_model(checkpoint_info=None):
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
do_inpainting_hijack()
sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info)
......@@ -253,14 +259,18 @@ def load_model(checkpoint_info=None):
return sd_model
def reload_model_weights(sd_model, info=None):
def reload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
if not sd_model:
sd_model = shared.sd_model
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
del sd_model
checkpoints_loaded.clear()
load_model(checkpoint_info)
return shared.sd_model
......
from collections import namedtuple
import numpy as np
from math import floor
import torch
import tqdm
from PIL import Image
......@@ -11,6 +12,7 @@ from modules import prompt_parser, devices, processing, images
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
......@@ -91,8 +93,8 @@ def single_sample_to_image(sample):
return Image.fromarray(x_sample)
def sample_to_image(samples):
return single_sample_to_image(samples[0])
def sample_to_image(samples, index=0):
return single_sample_to_image(samples[index])
def samples_to_image_grid(samples):
......@@ -205,17 +207,22 @@ class VanillaStableDiffusionSampler:
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
def adjust_steps_if_invalid(self, p, num_steps):
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
valid_step = 999 / (1000 // num_steps)
if valid_step == floor(valid_step):
return int(valid_step) + 1
return num_steps
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
steps = self.adjust_steps_if_invalid(p, steps)
self.initialize(p)
# existing code fails with certain step counts, like 9
try:
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
except Exception:
self.sampler.make_schedule(ddim_num_steps=steps+1, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
self.init_latent = x
......@@ -239,18 +246,14 @@ class VanillaStableDiffusionSampler:
self.last_latent = x
self.step = 0
steps = steps or p.steps
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
# Wrap the conditioning models with additional image conditioning for inpainting model
if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
# existing code fails with certain step counts, like 9
try:
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
except Exception:
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
return samples_ddim
......@@ -278,6 +281,12 @@ class CFGDenoiser(torch.nn.Module):
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
cfg_denoiser_callback(denoiser_params)
x_in = denoiser_params.x
image_cond_in = denoiser_params.image_cond
sigma_in = denoiser_params.sigma
if tensor.shape[1] == uncond.shape[1]:
cond_in = torch.cat([tensor, uncond])
......
import torch
import os
from collections import namedtuple
from modules import shared, devices, script_callbacks
from modules.paths import models_path
import glob
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
vae_dir = "VAE"
vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
default_vae_dict = {"auto": "auto", "None": "None"}
default_vae_list = ["auto", "None"]
default_vae_values = [default_vae_dict[x] for x in default_vae_list]
vae_dict = dict(default_vae_dict)
vae_list = list(default_vae_list)
first_load = True
base_vae = None
loaded_vae_file = None
checkpoint_info = None
def get_base_vae(model):
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
return base_vae
return None
def store_base_vae(model):
global base_vae, checkpoint_info
if checkpoint_info != model.sd_checkpoint_info:
base_vae = model.first_stage_model.state_dict().copy()
checkpoint_info = model.sd_checkpoint_info
def delete_base_vae():
global base_vae, checkpoint_info
base_vae = None
checkpoint_info = None
def restore_base_vae(model):
global base_vae, checkpoint_info
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
load_vae_dict(model, base_vae)
delete_base_vae()
def get_filename(filepath):
return os.path.splitext(os.path.basename(filepath))[0]
def refresh_vae_list(vae_path=vae_path, model_path=model_path):
global vae_dict, vae_list
res = {}
candidates = [
*glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
*glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
*glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
*glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True)
]
if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
candidates.append(shared.cmd_opts.vae_path)
for filepath in candidates:
name = get_filename(filepath)
res[name] = filepath
vae_list.clear()
vae_list.extend(default_vae_list)
vae_list.extend(list(res.keys()))
vae_dict.clear()
vae_dict.update(res)
vae_dict.update(default_vae_dict)
return vae_list
def resolve_vae(checkpoint_file, vae_file="auto"):
global first_load, vae_dict, vae_list
# if vae_file argument is provided, it takes priority, but not saved
if vae_file and vae_file not in default_vae_list:
if not os.path.isfile(vae_file):
vae_file = "auto"
print("VAE provided as function argument doesn't exist")
# for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
if first_load and shared.cmd_opts.vae_path is not None:
if os.path.isfile(shared.cmd_opts.vae_path):
vae_file = shared.cmd_opts.vae_path
shared.opts.data['sd_vae'] = get_filename(vae_file)
else:
print("VAE provided as command line argument doesn't exist")
# else, we load from settings
if vae_file == "auto" and shared.opts.sd_vae is not None:
# if saved VAE settings isn't recognized, fallback to auto
vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
# if VAE selected but not found, fallback to auto
if vae_file not in default_vae_values and not os.path.isfile(vae_file):
vae_file = "auto"
print("Selected VAE doesn't exist")
# vae-path cmd arg takes priority for auto
if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
if os.path.isfile(shared.cmd_opts.vae_path):
vae_file = shared.cmd_opts.vae_path
print("Using VAE provided as command line argument")
# if still not found, try look for ".vae.pt" beside model
model_path = os.path.splitext(checkpoint_file)[0]
if vae_file == "auto":
vae_file_try = model_path + ".vae.pt"
if os.path.isfile(vae_file_try):
vae_file = vae_file_try
print("Using VAE found beside selected model")
# if still not found, try look for ".vae.ckpt" beside model
if vae_file == "auto":
vae_file_try = model_path + ".vae.ckpt"
if os.path.isfile(vae_file_try):
vae_file = vae_file_try
print("Using VAE found beside selected model")
# No more fallbacks for auto
if vae_file == "auto":
vae_file = None
# Last check, just because
if vae_file and not os.path.exists(vae_file):
vae_file = None
return vae_file
def load_vae(model, vae_file=None):
global first_load, vae_dict, vae_list, loaded_vae_file
# save_settings = False
if vae_file:
print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
load_vae_dict(model, vae_dict_1)
# If vae used is not in dict, update it
# It will be removed on refresh though
vae_opt = get_filename(vae_file)
if vae_opt not in vae_dict:
vae_dict[vae_opt] = vae_file
vae_list.append(vae_opt)
loaded_vae_file = vae_file
"""
# Save current VAE to VAE settings, maybe? will it work?
if save_settings:
if vae_file is None:
vae_opt = "None"
# shared.opts.sd_vae = vae_opt
"""
first_load = False
# don't call this from outside
def load_vae_dict(model, vae_dict_1=None):
if vae_dict_1:
store_base_vae(model)
model.first_stage_model.load_state_dict(vae_dict_1)
else:
restore_base_vae()
model.first_stage_model.to(devices.dtype_vae)
def reload_vae_weights(sd_model=None, vae_file="auto"):
from modules import lowvram, devices, sd_hijack
if not sd_model:
sd_model = shared.sd_model
checkpoint_info = sd_model.sd_checkpoint_info
checkpoint_file = checkpoint_info.filename
vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
if loaded_vae_file == vae_file:
return
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:
sd_model.to(devices.cpu)
sd_hijack.model_hijack.undo_hijack(sd_model)
load_vae(sd_model, vae_file)
sd_hijack.model_hijack.hijack(sd_model)
script_callbacks.model_loaded_callback(sd_model)
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
print(f"VAE Weights loaded.")
return sd_model
......@@ -4,6 +4,7 @@ import json
import os
import sys
from collections import OrderedDict
import time
import gradio as gr
import tqdm
......@@ -14,7 +15,7 @@ import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
from modules import sd_samplers, sd_models, localization
from modules import sd_samplers, sd_models, localization, sd_vae
from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path
......@@ -40,7 +41,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
......@@ -51,6 +52,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
......@@ -97,6 +99,8 @@ restricted_opts = {
"outdir_save",
}
cmd_opts.disable_extension_access = cmd_opts.share or cmd_opts.listen
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
......@@ -132,6 +136,8 @@ class State:
current_image = None
current_image_sampling_step = 0
textinfo = None
time_start = None
need_restart = False
def skip(self):
self.skipped = True
......@@ -168,6 +174,7 @@ class State:
self.skipped = False
self.interrupted = False
self.textinfo = None
self.time_start = time.time()
devices.torch_gc()
......@@ -177,6 +184,20 @@ class State:
devices.torch_gc()
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
def set_current_image(self):
if not parallel_processing_allowed:
return
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and self.current_latent is not None:
if opts.show_progress_grid:
self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
else:
self.current_image = sd_samplers.sample_to_image(self.current_latent)
self.current_image_sampling_step = self.sampling_step
state = State()
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
......@@ -234,6 +255,8 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
......@@ -285,21 +308,22 @@ options_templates.update(options_section(('system', "System"), {
}))
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."),
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
......@@ -354,6 +378,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
}))
options_templates.update(options_section((None, "Hidden options"), {
"disabled_extensions": OptionInfo([], "Disable those extensions"),
}))
options_templates.update()
class Options:
data = None
......@@ -365,8 +395,9 @@ class Options:
def __setattr__(self, key, value):
if self.data is not None:
if key in self.data:
if key in self.data or key in self.data_labels:
self.data[key] = value
return
return super(Options, self).__setattr__(key, value)
......@@ -407,10 +438,11 @@ class Options:
if bad_settings > 0:
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
def onchange(self, key, func):
def onchange(self, key, func, call=True):
item = self.data_labels.get(key)
item.onchange = func
if call:
func()
def dumpjson(self):
......
......@@ -235,6 +235,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
unload = shared.opts.unload_models_when_training
if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings")
......@@ -272,6 +273,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
......@@ -328,6 +331,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{embedding_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
shared.sd_model.first_stage_model.to(devices.device)
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
do_not_save_grid=True,
......@@ -355,6 +361,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
processed = processing.process_images(p)
image = processed.images[0]
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
shared.state.current_image = image
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
......@@ -400,6 +409,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
shared.sd_model.first_stage_model.to(devices.device)
return embedding, filename
......
......@@ -25,7 +25,9 @@ def train_embedding(*args):
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
apply_optimizations = shared.opts.training_xattention_optimizations
try:
if not apply_optimizations:
sd_hijack.undo_optimizations()
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
......@@ -38,5 +40,6 @@ Embedding saved to {html.escape(filename)}
except Exception:
raise
finally:
if not apply_optimizations:
sd_hijack.apply_optimizations()
......@@ -47,6 +47,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
if processed is None:
processed = process_images(p)
p.close()
shared.total_tqdm.clear()
generation_info_js = processed.js()
......
......@@ -19,7 +19,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from modules import sd_hijack, sd_models, localization, script_callbacks
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions
from modules.paths import script_path
from modules.shared import opts, cmd_opts, restricted_opts
......@@ -277,15 +277,7 @@ def check_progress_call(id_part):
preview_visibility = gr_show(False)
if opts.show_progress_every_n_steps > 0:
if shared.parallel_processing_allowed:
if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None:
if opts.show_progress_grid:
shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent)
else:
shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
shared.state.current_image_sampling_step = shared.state.sampling_step
shared.state.set_current_image()
image = shared.state.current_image
if image is None:
......@@ -671,6 +663,9 @@ def create_ui(wrap_gradio_gpu_call):
import modules.img2img
import modules.txt2img
reload_javascript()
parameters_copypaste.reset()
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
......@@ -1059,7 +1054,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tabs(elem_id="extras_resize_mode"):
with gr.TabItem('Scale by'):
upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4)
with gr.TabItem('Scale to'):
with gr.Group():
with gr.Row():
......@@ -1511,8 +1506,9 @@ def create_ui(wrap_gradio_gpu_call):
column = None
with gr.Row(elem_id="settings").style(equal_height=False):
for i, (k, item) in enumerate(opts.data_labels.items()):
section_must_be_skipped = item.section[0] is None
if previous_section != item.section:
if previous_section != item.section and not section_must_be_skipped:
if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None):
if column is not None:
column.__exit__()
......@@ -1531,6 +1527,8 @@ def create_ui(wrap_gradio_gpu_call):
if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
quicksettings_list.append((i, k, item))
components.append(dummy_component)
elif section_must_be_skipped:
components.append(dummy_component)
else:
component = create_setting_component(k)
component_dict[k] = component
......@@ -1566,19 +1564,19 @@ def create_ui(wrap_gradio_gpu_call):
reload_script_bodies.click(
fn=reload_scripts,
inputs=[],
outputs=[],
_js='function(){}'
outputs=[]
)
def request_restart():
shared.state.interrupt()
settings_interface.gradio_ref.do_restart = True
shared.state.need_restart = True
restart_gradio.click(
fn=request_restart,
inputs=[],
outputs=[],
_js='function(){restart_reload()}'
_js='restart_reload'
)
if column is not None:
......@@ -1612,14 +1610,15 @@ def create_ui(wrap_gradio_gpu_call):
interfaces += script_callbacks.ui_tabs_callback()
interfaces += [(settings_interface, "Settings", "settings")]
extensions_interface = ui_extensions.create_ui()
interfaces += [(extensions_interface, "Extensions", "extensions")]
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings"):
for i, k, item in quicksettings_list:
component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component
settings_interface.gradio_ref = demo
parameters_copypaste.integrate_settings_paste_fields(component_dict)
parameters_copypaste.run_bind()
......@@ -1776,4 +1775,3 @@ def load_javascript(raw_response):
reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
reload_javascript()
import json
import os.path
import shutil
import sys
import time
import traceback
import git
import gradio as gr
import html
from modules import extensions, shared, paths
available_extensions = {"extensions": []}
def check_access():
assert not shared.cmd_opts.disable_extension_access, "extension access disabed because of commandline flags"
def apply_and_restart(disable_list, update_list):
check_access()
disabled = json.loads(disable_list)
assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
update = json.loads(update_list)
assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
update = set(update)
for ext in extensions.extensions:
if ext.name not in update:
continue
try:
ext.pull()
except Exception:
print(f"Error pulling updates for {ext.name}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
shared.opts.disabled_extensions = disabled
shared.opts.save(shared.config_filename)
shared.state.interrupt()
shared.state.need_restart = True
def check_updates():
check_access()
for ext in extensions.extensions:
if ext.remote is None:
continue
try:
ext.check_updates()
except Exception:
print(f"Error checking updates for {ext.name}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return extension_table()
def extension_table():
code = f"""<!-- {time.time()} -->
<table id="extensions">
<thead>
<tr>
<th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
<th>URL</th>
<th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click apply button">Update</abbr></th>
</tr>
</thead>
<tbody>
"""
for ext in extensions.extensions:
if ext.can_update:
ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
else:
ext_status = ext.status
code += f"""
<tr>
<td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
<td><a href="{html.escape(ext.remote or '')}">{html.escape(ext.remote or '')}</a></td>
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
</tr>
"""
code += """
</tbody>
</table>
"""
return code
def normalize_git_url(url):
if url is None:
return ""
url = url.replace(".git", "")
return url
def install_extension_from_url(dirname, url):
check_access()
assert url, 'No URL specified'
if dirname is None or dirname == "":
*parts, last_part = url.split('/')
last_part = normalize_git_url(last_part)
dirname = last_part
target_dir = os.path.join(extensions.extensions_dir, dirname)
assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
normalized_url = normalize_git_url(url)
assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'
tmpdir = os.path.join(paths.script_path, "tmp", dirname)
try:
shutil.rmtree(tmpdir, True)
repo = git.Repo.clone_from(url, tmpdir)
repo.remote().fetch()
os.rename(tmpdir, target_dir)
extensions.list_extensions()
return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
finally:
shutil.rmtree(tmpdir, True)
def install_extension_from_index(url):
ext_table, message = install_extension_from_url(None, url)
return refresh_available_extensions_from_data(), ext_table, message
def refresh_available_extensions(url):
global available_extensions
import urllib.request
with urllib.request.urlopen(url) as response:
text = response.read()
available_extensions = json.loads(text)
return url, refresh_available_extensions_from_data(), ''
def refresh_available_extensions_from_data():
extlist = available_extensions["extensions"]
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
code = f"""<!-- {time.time()} -->
<table id="available_extensions">
<thead>
<tr>
<th>Extension</th>
<th>Description</th>
<th>Action</th>
</tr>
</thead>
<tbody>
"""
for ext in extlist:
name = ext.get("name", "noname")
url = ext.get("url", None)
description = ext.get("description", "")
if url is None:
continue
existing = installed_extension_urls.get(normalize_git_url(url), None)
install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
code += f"""
<tr>
<td><a href="{html.escape(url)}">{html.escape(name)}</a></td>
<td>{html.escape(description)}</td>
<td>{install_code}</td>
</tr>
"""
code += """
</tbody>
</table>
"""
return code
def create_ui():
import modules.ui
with gr.Blocks(analytics_enabled=False) as ui:
with gr.Tabs(elem_id="tabs_extensions") as tabs:
with gr.TabItem("Installed"):
with gr.Row():
apply = gr.Button(value="Apply and restart UI", variant="primary")
check = gr.Button(value="Check for updates")
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
extensions_table = gr.HTML(lambda: extension_table())
apply.click(
fn=apply_and_restart,
_js="extensions_apply",
inputs=[extensions_disabled_list, extensions_update_list],
outputs=[],
)
check.click(
fn=check_updates,
_js="extensions_check",
inputs=[],
outputs=[extensions_table],
)
with gr.TabItem("Available"):
with gr.Row():
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False)
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
install_result = gr.HTML()
available_extensions_table = gr.HTML()
refresh_available_extensions_button.click(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update()]),
inputs=[available_extensions_index],
outputs=[available_extensions_index, available_extensions_table, install_result],
)
install_extension_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
inputs=[extension_to_install],
outputs=[available_extensions_table, extensions_table, install_result],
)
with gr.TabItem("Install from URL"):
install_url = gr.Text(label="URL for extension's git repository")
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
install_button = gr.Button(value="Install", variant="primary")
install_result = gr.HTML(elem_id="extension_install_result")
install_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
inputs=[install_dirname, install_url],
outputs=[extensions_table, install_result],
)
return ui
......@@ -10,6 +10,7 @@ import modules.shared
from modules import modelloader, shared
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
from modules.paths import models_path
......@@ -57,7 +58,7 @@ class Upscaler:
dest_w = img.width * scale
dest_h = img.height * scale
for i in range(3):
if img.width >= dest_w and img.height >= dest_h:
if img.width > dest_w and img.height > dest_h:
break
img = self.do_upscale(img, selected_model)
if img.width != dest_w or img.height != dest_h:
......@@ -120,3 +121,17 @@ class UpscalerLanczos(Upscaler):
self.name = "Lanczos"
self.scalers = [UpscalerData("Lanczos", None, self)]
class UpscalerNearest(Upscaler):
scalers = []
def do_upscale(self, img, selected_model=None):
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)
def load_model(self, _):
pass
def __init__(self, dirname=None):
super().__init__(False)
self.name = "Nearest"
self.scalers = [UpscalerData("Nearest", None, self)]
\ No newline at end of file
......@@ -4,7 +4,7 @@ fairscale==0.4.4
fonts
font-roboto
gfpgan
gradio==3.5
gradio==3.8
invisible-watermark
numpy
omegaconf
......@@ -12,7 +12,7 @@ opencv-python
requests
piexif
Pillow
pytorch_lightning
pytorch_lightning==1.7.7
realesrgan
scikit-image>=0.19
timm==0.4.12
......@@ -26,3 +26,4 @@ torchdiffeq
kornia
lark
inflection
GitPython
......@@ -2,7 +2,7 @@ transformers==4.19.2
diffusers==0.3.0
basicsr==1.4.2
gfpgan==1.3.8
gradio==3.5
gradio==3.8
numpy==1.23.3
Pillow==9.2.0
realesrgan==0.3.0
......@@ -23,3 +23,4 @@ torchdiffeq==0.2.3
kornia==0.6.7
lark==1.1.2
inflection==0.5.1
GitPython==3.1.27
......@@ -166,8 +166,7 @@ class Script(scripts.Script):
if override_strength:
p.denoising_strength = 1.0
def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
lat = (p.init_latent.cpu().numpy() * 10).astype(int)
same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st \
......
......@@ -96,6 +96,7 @@ class Script(scripts.Script):
def ui(self, is_img2img):
checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False)
checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False)
prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1)
file = gr.File(label="Upload prompt inputs", type='bytes')
......@@ -106,9 +107,9 @@ class Script(scripts.Script):
# We don't shrink back to 1, because that causes the control to ignore [enter], and it may
# be unclear to the user that shift-enter is needed.
prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt])
return [checkbox_iterate, file, prompt_txt]
return [checkbox_iterate, checkbox_iterate_batch, file, prompt_txt]
def run(self, p, checkbox_iterate, file, prompt_txt: str):
def run(self, p, checkbox_iterate, checkbox_iterate_batch, file, prompt_txt: str):
lines = [x.strip() for x in prompt_txt.splitlines()]
lines = [x for x in lines if len(x) > 0]
......@@ -137,7 +138,7 @@ class Script(scripts.Script):
jobs.append(args)
print(f"Will process {len(lines)} lines in {job_count} jobs.")
if (checkbox_iterate and p.seed == -1):
if (checkbox_iterate or checkbox_iterate_batch) and p.seed == -1:
p.seed = int(random.randrange(4294967294))
state.job_count = job_count
......@@ -153,7 +154,7 @@ class Script(scripts.Script):
proc = process_images(copy_p)
images += proc.images
if (checkbox_iterate):
if checkbox_iterate:
p.seed = p.seed + (p.batch_size * p.n_iter)
......
......@@ -260,6 +260,16 @@ input[type="range"]{
#txt2img_negative_prompt, #img2img_negative_prompt{
}
/* gradio 3.8 adds opacity to progressbar which makes it blink; disable it here */
.transition.opacity-20 {
opacity: 1 !important;
}
/* more gradio's garbage cleanup */
.min-h-\[4rem\] {
min-height: unset !important;
}
#txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
position: absolute;
z-index: 1000;
......@@ -491,7 +501,7 @@ input[type="range"]{
padding: 0;
}
#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
#refresh_sd_model_checkpoint, #refresh_sd_vae, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
max-width: 2.5em;
min-width: 2.5em;
height: 2.4em;
......@@ -530,6 +540,29 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h
min-height: 480px !important;
}
/* Extensions */
#tab_extensions table{
border-collapse: collapse;
}
#tab_extensions table td, #tab_extensions table th{
border: 1px solid #ccc;
padding: 0.25em 0.5em;
}
#tab_extensions table input[type="checkbox"]{
margin-right: 0.5em;
}
#tab_extensions button{
max-width: 16em;
}
#tab_extensions input[disabled="disabled"]{
opacity: 0.5;
}
/* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running
......
......@@ -9,7 +9,7 @@ from fastapi.middleware.gzip import GZipMiddleware
from modules.paths import script_path
from modules import devices, sd_samplers, upscaler
from modules import devices, sd_samplers, upscaler, extensions
import modules.codeformer_model as codeformer
import modules.extras
import modules.face_restoration
......@@ -21,8 +21,10 @@ import modules.paths
import modules.scripts
import modules.sd_hijack
import modules.sd_models
import modules.sd_vae
import modules.shared as shared
import modules.txt2img
import modules.script_callbacks
import modules.ui
from modules import devices
......@@ -60,6 +62,8 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
def initialize():
extensions.list_extensions()
if cmd_opts.ui_debug_mode:
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
modules.scripts.load_scripts()
......@@ -74,8 +78,10 @@ def initialize():
modules.scripts.load_scripts()
modules.sd_vae.refresh_vae_list()
modules.sd_models.load_model()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
......@@ -92,15 +98,18 @@ def create_api(app):
api = Api(app, queue_lock)
return api
def wait_on_server(demo=None):
while 1:
time.sleep(0.5)
if demo and getattr(demo, 'do_restart', False):
if shared.state.need_restart:
shared.state.need_restart = False
time.sleep(0.5)
demo.close()
time.sleep(0.5)
break
def api_only():
initialize()
......@@ -108,6 +117,8 @@ def api_only():
app.add_middleware(GZipMiddleware, minimum_size=1000)
api = create_api(app)
modules.script_callbacks.app_started_callback(None, app)
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
......@@ -132,14 +143,18 @@ def webui():
app.add_middleware(GZipMiddleware, minimum_size=1000)
if (launch_api):
if launch_api:
create_api(app)
modules.script_callbacks.app_started_callback(demo, app)
wait_on_server(demo)
sd_samplers.set_samplers()
print('Reloading Custom Scripts')
print('Reloading extensions')
extensions.list_extensions()
print('Reloading custom scripts')
modules.scripts.reload_scripts()
print('Reloading modules: modules.ui')
importlib.reload(modules.ui)
......@@ -148,8 +163,6 @@ def webui():
print('Restarting Gradio')
task = []
if __name__ == "__main__":
if cmd_opts.nowebui:
api_only()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment