Unverified Commit af6fba24 authored by KyuSeok Jung's avatar KyuSeok Jung Committed by GitHub

Merge branch 'master' into master

parents 467cae16 95c6308c
...@@ -3,8 +3,21 @@ global_progressbars = {} ...@@ -3,8 +3,21 @@ global_progressbars = {}
galleries = {} galleries = {}
galleryObservers = {} galleryObservers = {}
// this tracks laumnches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running
timeoutIds = {}
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
var progressbar = gradioApp().getElementById(id_progressbar) // gradio 3.8's enlightened approach allows them to create two nested div elements inside each other with same id
// every time you use gr.HTML(elem_id='xxx'), so we handle this here
var progressbar = gradioApp().querySelector("#"+id_progressbar+" #"+id_progressbar)
var progressbarParent
if(progressbar){
progressbarParent = gradioApp().querySelector("#"+id_progressbar)
} else{
progressbar = gradioApp().getElementById(id_progressbar)
progressbarParent = null
}
var skip = id_skip ? gradioApp().getElementById(id_skip) : null var skip = id_skip ? gradioApp().getElementById(id_skip) : null
var interrupt = gradioApp().getElementById(id_interrupt) var interrupt = gradioApp().getElementById(id_interrupt)
...@@ -26,18 +39,26 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip ...@@ -26,18 +39,26 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
global_progressbars[id_progressbar] = progressbar global_progressbars[id_progressbar] = progressbar
var mutationObserver = new MutationObserver(function(m){ var mutationObserver = new MutationObserver(function(m){
if(timeoutIds[id_part]) return;
preview = gradioApp().getElementById(id_preview) preview = gradioApp().getElementById(id_preview)
gallery = gradioApp().getElementById(id_gallery) gallery = gradioApp().getElementById(id_gallery)
if(preview != null && gallery != null){ if(preview != null && gallery != null){
preview.style.width = gallery.clientWidth + "px" preview.style.width = gallery.clientWidth + "px"
preview.style.height = gallery.clientHeight + "px" preview.style.height = gallery.clientHeight + "px"
if(progressbarParent) progressbar.style.width = progressbarParent.clientWidth + "px"
//only watch gallery if there is a generation process going on //only watch gallery if there is a generation process going on
check_gallery(id_gallery); check_gallery(id_gallery);
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
if(!progressDiv){ if(progressDiv){
timeoutIds[id_part] = window.setTimeout(function() {
timeoutIds[id_part] = null
requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt)
}, 500)
} else{
if (skip) { if (skip) {
skip.style.display = "none" skip.style.display = "none"
} }
...@@ -47,13 +68,10 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip ...@@ -47,13 +68,10 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
if (galleryObservers[id_gallery]) { if (galleryObservers[id_gallery]) {
galleryObservers[id_gallery].disconnect(); galleryObservers[id_gallery].disconnect();
galleries[id_gallery] = null; galleries[id_gallery] = null;
} }
} }
} }
window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500)
}); });
mutationObserver.observe( progressbar, { childList:true, subtree:true }) mutationObserver.observe( progressbar, { childList:true, subtree:true })
} }
......
...@@ -70,7 +70,7 @@ ...@@ -70,7 +70,7 @@
"None": "Nichts", "None": "Nichts",
"Prompt matrix": "Promptmatrix", "Prompt matrix": "Promptmatrix",
"Prompts from file or textbox": "Prompts aus Datei oder Textfeld", "Prompts from file or textbox": "Prompts aus Datei oder Textfeld",
"X/Y plot": "X/Y Graf", "X/Y plot": "X/Y Graph",
"Put variable parts at start of prompt": "Variable teile am start des Prompt setzen", "Put variable parts at start of prompt": "Variable teile am start des Prompt setzen",
"Iterate seed every line": "Iterate seed every line", "Iterate seed every line": "Iterate seed every line",
"List of prompt inputs": "List of prompt inputs", "List of prompt inputs": "List of prompt inputs",
...@@ -455,4 +455,4 @@ ...@@ -455,4 +455,4 @@
"Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.": "Gilt nur für Inpainting-Modelle. Legt fest, wie stark das Originalbild für Inpainting und img2img maskiert werden soll. 1.0 bedeutet vollständig maskiert, was das Standardverhalten ist. 0.0 bedeutet eine vollständig unmaskierte Konditionierung. Niedrigere Werte tragen dazu bei, die Gesamtkomposition des Bildes zu erhalten, sind aber bei großen Änderungen problematisch.", "Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.": "Gilt nur für Inpainting-Modelle. Legt fest, wie stark das Originalbild für Inpainting und img2img maskiert werden soll. 1.0 bedeutet vollständig maskiert, was das Standardverhalten ist. 0.0 bedeutet eine vollständig unmaskierte Konditionierung. Niedrigere Werte tragen dazu bei, die Gesamtkomposition des Bildes zu erhalten, sind aber bei großen Änderungen problematisch.",
"List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply.": "Liste von Einstellungsnamen, getrennt durch Kommas, für Einstellungen, die in der Schnellzugriffsleiste oben erscheinen sollen, anstatt in dem üblichen Einstellungs-Tab. Siehe modules/shared.py für Einstellungsnamen. Erfordert einen Neustart zur Anwendung.", "List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply.": "Liste von Einstellungsnamen, getrennt durch Kommas, für Einstellungen, die in der Schnellzugriffsleiste oben erscheinen sollen, anstatt in dem üblichen Einstellungs-Tab. Siehe modules/shared.py für Einstellungsnamen. Erfordert einen Neustart zur Anwendung.",
"If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.": "Wenn dieser Wert ungleich Null ist, wird er zum Seed addiert und zur Initialisierung des RNG für Noise bei der Verwendung von Samplern mit Eta verwendet. Dies kann verwendet werden, um noch mehr Variationen von Bildern zu erzeugen, oder um Bilder von anderer Software zu erzeugen, wenn Sie wissen, was Sie tun." "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.": "Wenn dieser Wert ungleich Null ist, wird er zum Seed addiert und zur Initialisierung des RNG für Noise bei der Verwendung von Samplern mit Eta verwendet. Dies kann verwendet werden, um noch mehr Variationen von Bildern zu erzeugen, oder um Bilder von anderer Software zu erzeugen, wenn Sie wissen, was Sie tun."
} }
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import base64
import io
import time import time
import uvicorn import uvicorn
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
import modules.shared as shared import modules.shared as shared
from modules import devices from modules import devices
...@@ -29,6 +31,12 @@ def setUpscalers(req: dict): ...@@ -29,6 +31,12 @@ def setUpscalers(req: dict):
return reqDict return reqDict
def encode_pil_to_base64(image):
buffer = io.BytesIO()
image.save(buffer, format="png")
return base64.b64encode(buffer.getvalue())
class Api: class Api:
def __init__(self, app, queue_lock): def __init__(self, app, queue_lock):
self.router = APIRouter() self.router = APIRouter()
...@@ -40,6 +48,7 @@ class Api: ...@@ -40,6 +48,7 @@ class Api:
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index) sampler_index = sampler_to_index(txt2imgreq.sampler_index)
...@@ -176,6 +185,11 @@ class Api: ...@@ -176,6 +185,11 @@ class Api:
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image) return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
def interruptapi(self):
shared.state.interrupt()
return {}
def launch(self, server_name, port): def launch(self, server_name, port):
self.app.include_router(self.router) self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port) uvicorn.run(self.app, host=server_name, port=port)
...@@ -141,7 +141,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -141,7 +141,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
upscaling_resize_w, upscaling_resize_h, upscaling_crop) upscaling_resize_w, upscaling_resize_h, upscaling_crop)
cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()), cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()),
info_hash=hash(info), info_hash=hash(info),
args_hash=hash(upscale_args)) args_hash=hash((upscale_args, upscale_first)))
cached_entry = cached_images.get(cache_key) cached_entry = cached_images.get(cache_key)
if cached_entry is None: if cached_entry is None:
res = upscale(image, *upscale_args) res = upscale(image, *upscale_args)
......
...@@ -510,8 +510,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i ...@@ -510,8 +510,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
if extension.lower() == '.png': if extension.lower() == '.png':
pnginfo_data = PngImagePlugin.PngInfo() pnginfo_data = PngImagePlugin.PngInfo()
for k, v in params.pnginfo.items(): if opts.enable_pnginfo:
pnginfo_data.add_text(k, str(v)) for k, v in params.pnginfo.items():
pnginfo_data.add_text(k, str(v))
image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo_data) image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
......
...@@ -137,6 +137,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro ...@@ -137,6 +137,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
if processed is None: if processed is None:
processed = process_images(p) processed = process_images(p)
p.close()
shared.total_tqdm.clear() shared.total_tqdm.clear()
generation_info_js = processed.js() generation_info_js = processed.js()
......
...@@ -56,9 +56,9 @@ class InterrogateModels: ...@@ -56,9 +56,9 @@ class InterrogateModels:
import clip import clip
if self.running_on_cpu: if self.running_on_cpu:
model, preprocess = clip.load(clip_model_name, device="cpu") model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
else: else:
model, preprocess = clip.load(clip_model_name) model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
model.eval() model.eval()
model = model.to(devices.device_interrogate) model = model.to(devices.device_interrogate)
......
...@@ -202,6 +202,10 @@ class StableDiffusionProcessing(): ...@@ -202,6 +202,10 @@ class StableDiffusionProcessing():
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
raise NotImplementedError() raise NotImplementedError()
def close(self):
self.sd_model = None
self.sampler = None
class Processed: class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None): def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
...@@ -597,9 +601,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -597,9 +601,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None: if p.scripts is not None:
p.scripts.postprocess(p, res) p.scripts.postprocess(p, res)
p.sd_model = None
p.sampler = None
return res return res
......
...@@ -26,6 +26,24 @@ class ImageSaveParams: ...@@ -26,6 +26,24 @@ class ImageSaveParams:
"""dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'""" """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
class CFGDenoiserParams:
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
self.x = x
"""Latent image representation in the process of being denoised"""
self.image_cond = image_cond
"""Conditioning image"""
self.sigma = sigma
"""Current sigma noise step value"""
self.sampling_step = sampling_step
"""Current Sampling step number"""
self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned"""
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callbacks_app_started = [] callbacks_app_started = []
callbacks_model_loaded = [] callbacks_model_loaded = []
...@@ -33,6 +51,7 @@ callbacks_ui_tabs = [] ...@@ -33,6 +51,7 @@ callbacks_ui_tabs = []
callbacks_ui_settings = [] callbacks_ui_settings = []
callbacks_before_image_saved = [] callbacks_before_image_saved = []
callbacks_image_saved = [] callbacks_image_saved = []
callbacks_cfg_denoiser = []
def clear_callbacks(): def clear_callbacks():
...@@ -41,7 +60,7 @@ def clear_callbacks(): ...@@ -41,7 +60,7 @@ def clear_callbacks():
callbacks_ui_settings.clear() callbacks_ui_settings.clear()
callbacks_before_image_saved.clear() callbacks_before_image_saved.clear()
callbacks_image_saved.clear() callbacks_image_saved.clear()
callbacks_cfg_denoiser.clear()
def app_started_callback(demo: Blocks, app: FastAPI): def app_started_callback(demo: Blocks, app: FastAPI):
for c in callbacks_app_started: for c in callbacks_app_started:
...@@ -95,6 +114,14 @@ def image_saved_callback(params: ImageSaveParams): ...@@ -95,6 +114,14 @@ def image_saved_callback(params: ImageSaveParams):
report_exception(c, 'image_saved_callback') report_exception(c, 'image_saved_callback')
def cfg_denoiser_callback(params: CFGDenoiserParams):
for c in callbacks_cfg_denoiser:
try:
c.callback(params)
except Exception:
report_exception(c, 'cfg_denoiser_callback')
def add_callback(callbacks, fun): def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__] stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file' filename = stack[0].filename if len(stack) > 0 else 'unknown file'
...@@ -147,3 +174,12 @@ def on_image_saved(callback): ...@@ -147,3 +174,12 @@ def on_image_saved(callback):
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing. - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
""" """
add_callback(callbacks_image_saved, callback) add_callback(callbacks_image_saved, callback)
def on_cfg_denoiser(callback):
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
The callback is called with one argument:
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
"""
add_callback(callbacks_cfg_denoiser, callback)
...@@ -12,6 +12,7 @@ from modules import prompt_parser, devices, processing, images ...@@ -12,6 +12,7 @@ from modules import prompt_parser, devices, processing, images
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
...@@ -280,6 +281,12 @@ class CFGDenoiser(torch.nn.Module): ...@@ -280,6 +281,12 @@ class CFGDenoiser(torch.nn.Module):
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
cfg_denoiser_callback(denoiser_params)
x_in = denoiser_params.x
image_cond_in = denoiser_params.image_cond
sigma_in = denoiser_params.sigma
if tensor.shape[1] == uncond.shape[1]: if tensor.shape[1] == uncond.shape[1]:
cond_in = torch.cat([tensor, uncond]) cond_in = torch.cat([tensor, uncond])
......
...@@ -51,6 +51,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director ...@@ -51,6 +51,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET')) parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR')) parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR')) parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator") parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
...@@ -288,11 +289,12 @@ options_templates.update(options_section(('system', "System"), { ...@@ -288,11 +289,12 @@ options_templates.update(options_section(('system', "System"), {
})) }))
options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."), "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"), "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
})) }))
options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {
......
...@@ -235,6 +235,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -235,6 +235,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name) log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
unload = shared.opts.unload_models_when_training
if save_embedding_every > 0: if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings") embedding_dir = os.path.join(log_directory, "embeddings")
...@@ -272,6 +273,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -272,6 +273,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"): with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, shuffle_tags=shuffle_tags, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, shuffle_tags=shuffle_tags, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
embedding.vec.requires_grad = True embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
...@@ -328,6 +331,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -328,6 +331,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if images_dir is not None and steps_done % create_image_every == 0: if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{embedding_name}-{steps_done}' forced_filename = f'{embedding_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename) last_saved_image = os.path.join(images_dir, forced_filename)
shared.sd_model.first_stage_model.to(devices.device)
p = processing.StableDiffusionProcessingTxt2Img( p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model, sd_model=shared.sd_model,
do_not_save_grid=True, do_not_save_grid=True,
...@@ -355,6 +361,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -355,6 +361,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
processed = processing.process_images(p) processed = processing.process_images(p)
image = processed.images[0] image = processed.images[0]
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
shared.state.current_image = image shared.state.current_image = image
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
...@@ -400,6 +409,7 @@ Last saved image: {html.escape(last_saved_image)}<br/> ...@@ -400,6 +409,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
shared.sd_model.first_stage_model.to(devices.device)
return embedding, filename return embedding, filename
......
...@@ -25,8 +25,10 @@ def train_embedding(*args): ...@@ -25,8 +25,10 @@ def train_embedding(*args):
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible' assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
apply_optimizations = shared.opts.training_xattention_optimizations
try: try:
sd_hijack.undo_optimizations() if not apply_optimizations:
sd_hijack.undo_optimizations()
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args) embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
...@@ -38,5 +40,6 @@ Embedding saved to {html.escape(filename)} ...@@ -38,5 +40,6 @@ Embedding saved to {html.escape(filename)}
except Exception: except Exception:
raise raise
finally: finally:
sd_hijack.apply_optimizations() if not apply_optimizations:
sd_hijack.apply_optimizations()
...@@ -47,6 +47,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: ...@@ -47,6 +47,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
if processed is None: if processed is None:
processed = process_images(p) processed = process_images(p)
p.close()
shared.total_tqdm.clear() shared.total_tqdm.clear()
generation_info_js = processed.js() generation_info_js = processed.js()
......
...@@ -671,6 +671,8 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -671,6 +671,8 @@ def create_ui(wrap_gradio_gpu_call):
import modules.img2img import modules.img2img
import modules.txt2img import modules.txt2img
reload_javascript()
parameters_copypaste.reset() parameters_copypaste.reset()
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
...@@ -1573,8 +1575,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1573,8 +1575,7 @@ def create_ui(wrap_gradio_gpu_call):
reload_script_bodies.click( reload_script_bodies.click(
fn=reload_scripts, fn=reload_scripts,
inputs=[], inputs=[],
outputs=[], outputs=[]
_js='function(){}'
) )
def request_restart(): def request_restart():
...@@ -1586,7 +1587,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1586,7 +1587,7 @@ def create_ui(wrap_gradio_gpu_call):
fn=request_restart, fn=request_restart,
inputs=[], inputs=[],
outputs=[], outputs=[],
_js='function(){restart_reload()}' _js='restart_reload'
) )
if column is not None: if column is not None:
...@@ -1785,4 +1786,3 @@ def load_javascript(raw_response): ...@@ -1785,4 +1786,3 @@ def load_javascript(raw_response):
reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse) reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
reload_javascript()
...@@ -4,7 +4,7 @@ fairscale==0.4.4 ...@@ -4,7 +4,7 @@ fairscale==0.4.4
fonts fonts
font-roboto font-roboto
gfpgan gfpgan
gradio==3.5 gradio==3.8
invisible-watermark invisible-watermark
numpy numpy
omegaconf omegaconf
...@@ -12,7 +12,7 @@ opencv-python ...@@ -12,7 +12,7 @@ opencv-python
requests requests
piexif piexif
Pillow Pillow
pytorch_lightning pytorch_lightning==1.7.7
realesrgan realesrgan
scikit-image>=0.19 scikit-image>=0.19
timm==0.4.12 timm==0.4.12
......
...@@ -2,7 +2,7 @@ transformers==4.19.2 ...@@ -2,7 +2,7 @@ transformers==4.19.2
diffusers==0.3.0 diffusers==0.3.0
basicsr==1.4.2 basicsr==1.4.2
gfpgan==1.3.8 gfpgan==1.3.8
gradio==3.5 gradio==3.8
numpy==1.23.3 numpy==1.23.3
Pillow==9.2.0 Pillow==9.2.0
realesrgan==0.3.0 realesrgan==0.3.0
......
...@@ -96,6 +96,7 @@ class Script(scripts.Script): ...@@ -96,6 +96,7 @@ class Script(scripts.Script):
def ui(self, is_img2img): def ui(self, is_img2img):
checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False) checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False)
checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False)
prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1) prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1)
file = gr.File(label="Upload prompt inputs", type='bytes') file = gr.File(label="Upload prompt inputs", type='bytes')
...@@ -106,9 +107,9 @@ class Script(scripts.Script): ...@@ -106,9 +107,9 @@ class Script(scripts.Script):
# We don't shrink back to 1, because that causes the control to ignore [enter], and it may # We don't shrink back to 1, because that causes the control to ignore [enter], and it may
# be unclear to the user that shift-enter is needed. # be unclear to the user that shift-enter is needed.
prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt]) prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt])
return [checkbox_iterate, file, prompt_txt] return [checkbox_iterate, checkbox_iterate_batch, file, prompt_txt]
def run(self, p, checkbox_iterate, file, prompt_txt: str): def run(self, p, checkbox_iterate, checkbox_iterate_batch, file, prompt_txt: str):
lines = [x.strip() for x in prompt_txt.splitlines()] lines = [x.strip() for x in prompt_txt.splitlines()]
lines = [x for x in lines if len(x) > 0] lines = [x for x in lines if len(x) > 0]
...@@ -137,7 +138,7 @@ class Script(scripts.Script): ...@@ -137,7 +138,7 @@ class Script(scripts.Script):
jobs.append(args) jobs.append(args)
print(f"Will process {len(lines)} lines in {job_count} jobs.") print(f"Will process {len(lines)} lines in {job_count} jobs.")
if (checkbox_iterate and p.seed == -1): if (checkbox_iterate or checkbox_iterate_batch) and p.seed == -1:
p.seed = int(random.randrange(4294967294)) p.seed = int(random.randrange(4294967294))
state.job_count = job_count state.job_count = job_count
...@@ -153,7 +154,7 @@ class Script(scripts.Script): ...@@ -153,7 +154,7 @@ class Script(scripts.Script):
proc = process_images(copy_p) proc = process_images(copy_p)
images += proc.images images += proc.images
if (checkbox_iterate): if checkbox_iterate:
p.seed = p.seed + (p.batch_size * p.n_iter) p.seed = p.seed + (p.batch_size * p.n_iter)
......
...@@ -260,6 +260,16 @@ input[type="range"]{ ...@@ -260,6 +260,16 @@ input[type="range"]{
#txt2img_negative_prompt, #img2img_negative_prompt{ #txt2img_negative_prompt, #img2img_negative_prompt{
} }
/* gradio 3.8 adds opacity to progressbar which makes it blink; disable it here */
.transition.opacity-20 {
opacity: 1 !important;
}
/* more gradio's garbage cleanup */
.min-h-\[4rem\] {
min-height: unset !important;
}
#txt2img_progressbar, #img2img_progressbar, #ti_progressbar{ #txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
position: absolute; position: absolute;
z-index: 1000; z-index: 1000;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment