Commit 924e2220 authored by AUTOMATIC's avatar AUTOMATIC

add option to show/hide warnings

removed hiding warnings from LDSR
fixed/reworked few places that produced warnings
parent 889b851a
import os import os
import gc import gc
import time import time
import warnings
import numpy as np import numpy as np
import torch import torch
...@@ -15,8 +14,6 @@ from ldm.models.diffusion.ddim import DDIMSampler ...@@ -15,8 +14,6 @@ from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap from ldm.util import instantiate_from_config, ismap
from modules import shared, sd_hijack from modules import shared, sd_hijack
warnings.filterwarnings("ignore", category=UserWarning)
cached_ldsr_model: torch.nn.Module = None cached_ldsr_model: torch.nn.Module = None
......
...@@ -11,7 +11,7 @@ ignore_ids_for_localization={ ...@@ -11,7 +11,7 @@ ignore_ids_for_localization={
train_embedding: 'OPTION', train_embedding: 'OPTION',
train_hypernetwork: 'OPTION', train_hypernetwork: 'OPTION',
txt2img_styles: 'OPTION', txt2img_styles: 'OPTION',
img2img_styles 'OPTION', img2img_styles: 'OPTION',
setting_random_artist_categories: 'SPAN', setting_random_artist_categories: 'SPAN',
setting_face_restoration_model: 'SPAN', setting_face_restoration_model: 'SPAN',
setting_realesrgan_enabled_models: 'SPAN', setting_realesrgan_enabled_models: 'SPAN',
......
...@@ -12,7 +12,7 @@ import torch ...@@ -12,7 +12,7 @@ import torch
import tqdm import tqdm
from einops import rearrange, repeat from einops import rearrange, repeat
from ldm.util import default from ldm.util import default
from modules import devices, processing, sd_models, shared, sd_samplers, hashes from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum from torch import einsum
...@@ -575,6 +575,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi ...@@ -575,6 +575,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
pbar = tqdm.tqdm(total=steps - initial_step) pbar = tqdm.tqdm(total=steps - initial_step)
try: try:
sd_hijack_checkpoint.add()
for i in range((steps-initial_step) * gradient_step): for i in range((steps-initial_step) * gradient_step):
if scheduler.finished: if scheduler.finished:
break break
...@@ -724,6 +726,9 @@ Last saved image: {html.escape(last_saved_image)}<br/> ...@@ -724,6 +726,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
pbar.close() pbar.close()
hypernetwork.eval() hypernetwork.eval()
#report_statistics(loss_dict) #report_statistics(loss_dict)
sd_hijack_checkpoint.remove()
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.optimizer_name = optimizer_name hypernetwork.optimizer_name = optimizer_name
......
...@@ -69,12 +69,6 @@ def undo_optimizations(): ...@@ -69,12 +69,6 @@ def undo_optimizations():
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
def fix_checkpoint():
ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
class StableDiffusionModelHijack: class StableDiffusionModelHijack:
fixes = None fixes = None
comments = [] comments = []
...@@ -106,8 +100,6 @@ class StableDiffusionModelHijack: ...@@ -106,8 +100,6 @@ class StableDiffusionModelHijack:
self.optimization_method = apply_optimizations() self.optimization_method = apply_optimizations()
self.clip = m.cond_stage_model self.clip = m.cond_stage_model
fix_checkpoint()
def flatten(el): def flatten(el):
flattened = [flatten(children) for children in el.children()] flattened = [flatten(children) for children in el.children()]
......
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
import ldm.modules.attention
import ldm.modules.diffusionmodules.openaimodel
def BasicTransformerBlock_forward(self, x, context=None): def BasicTransformerBlock_forward(self, x, context=None):
return checkpoint(self._forward, x, context) return checkpoint(self._forward, x, context)
def AttentionBlock_forward(self, x): def AttentionBlock_forward(self, x):
return checkpoint(self._forward, x) return checkpoint(self._forward, x)
def ResBlock_forward(self, x, emb): def ResBlock_forward(self, x, emb):
return checkpoint(self._forward, x, emb) return checkpoint(self._forward, x, emb)
\ No newline at end of file
stored = []
def add():
if len(stored) != 0:
return
stored.extend([
ldm.modules.attention.BasicTransformerBlock.forward,
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
])
ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
def remove():
if len(stored) == 0:
return
ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]
stored.clear()
...@@ -369,6 +369,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration" ...@@ -369,6 +369,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
})) }))
options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('system', "System"), {
"show_warnings": OptionInfo(False, "Show warnings in console."),
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}), "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
......
...@@ -15,7 +15,7 @@ import numpy as np ...@@ -15,7 +15,7 @@ import numpy as np
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.learn_schedule import LearnRateScheduler
...@@ -452,6 +452,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st ...@@ -452,6 +452,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
pbar = tqdm.tqdm(total=steps - initial_step) pbar = tqdm.tqdm(total=steps - initial_step)
try: try:
sd_hijack_checkpoint.add()
for i in range((steps-initial_step) * gradient_step): for i in range((steps-initial_step) * gradient_step):
if scheduler.finished: if scheduler.finished:
break break
...@@ -617,9 +619,11 @@ Last saved image: {html.escape(last_saved_image)}<br/> ...@@ -617,9 +619,11 @@ Last saved image: {html.escape(last_saved_image)}<br/>
pbar.close() pbar.close()
shared.sd_model.first_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device)
shared.parallel_processing_allowed = old_parallel_processing_allowed shared.parallel_processing_allowed = old_parallel_processing_allowed
sd_hijack_checkpoint.remove()
return embedding, filename return embedding, filename
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True): def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
old_embedding_name = embedding.name old_embedding_name = embedding.name
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
......
...@@ -11,6 +11,7 @@ import tempfile ...@@ -11,6 +11,7 @@ import tempfile
import time import time
import traceback import traceback
from functools import partial, reduce from functools import partial, reduce
import warnings
import gradio as gr import gradio as gr
import gradio.routes import gradio.routes
...@@ -41,6 +42,8 @@ from modules.textual_inversion import textual_inversion ...@@ -41,6 +42,8 @@ from modules.textual_inversion import textual_inversion
import modules.hypernetworks.ui import modules.hypernetworks.ui
from modules.generation_parameters_copypaste import image_from_url_text from modules.generation_parameters_copypaste import image_from_url_text
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init() mimetypes.init()
mimetypes.add_type('application/javascript', '.js') mimetypes.add_type('application/javascript', '.js')
...@@ -417,17 +420,16 @@ def apply_setting(key, value): ...@@ -417,17 +420,16 @@ def apply_setting(key, value):
return value return value
def update_generation_info(args): def update_generation_info(generation_info, html_info, img_index):
generation_info, html_info, img_index = args
try: try:
generation_info = json.loads(generation_info) generation_info = json.loads(generation_info)
if img_index < 0 or img_index >= len(generation_info["infotexts"]): if img_index < 0 or img_index >= len(generation_info["infotexts"]):
return html_info return html_info, gr.update()
return plaintext_to_html(generation_info["infotexts"][img_index]) return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update()
except Exception: except Exception:
pass pass
# if the json parse or anything else fails, just return the old html_info # if the json parse or anything else fails, just return the old html_info
return html_info return html_info, gr.update()
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
...@@ -508,10 +510,9 @@ Requested path was: {f} ...@@ -508,10 +510,9 @@ Requested path was: {f}
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
generation_info_button.click( generation_info_button.click(
fn=update_generation_info, fn=update_generation_info,
_js="(x, y) => [x, y, selected_gallery_index()]", _js="function(x, y, z){ console.log(x, y, z); return [x, y, selected_gallery_index()] }",
inputs=[generation_info, html_info], inputs=[generation_info, html_info, html_info],
outputs=[html_info], outputs=[html_info, html_info],
preprocess=False
) )
save.click( save.click(
...@@ -526,7 +527,8 @@ Requested path was: {f} ...@@ -526,7 +527,8 @@ Requested path was: {f}
outputs=[ outputs=[
download_files, download_files,
html_log, html_log,
] ],
show_progress=False,
) )
save_zip.click( save_zip.click(
...@@ -588,7 +590,7 @@ def create_ui(): ...@@ -588,7 +590,7 @@ def create_ui():
txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False) dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"): with gr.Column(variant='compact', elem_id="txt2img_settings"):
...@@ -768,7 +770,7 @@ def create_ui(): ...@@ -768,7 +770,7 @@ def create_ui():
with gr.Blocks(analytics_enabled=False) as img2img_interface: with gr.Blocks(analytics_enabled=False) as img2img_interface:
img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
with FormRow().style(equal_height=False): with FormRow().style(equal_height=False):
with gr.Column(variant='compact', elem_id="img2img_settings"): with gr.Column(variant='compact', elem_id="img2img_settings"):
...@@ -1768,7 +1770,10 @@ def create_ui(): ...@@ -1768,7 +1770,10 @@ def create_ui():
if saved_value is None: if saved_value is None:
ui_settings[key] = getattr(obj, field) ui_settings[key] = getattr(obj, field)
elif condition and not condition(saved_value): elif condition and not condition(saved_value):
print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') pass
# this warning is generally not useful;
# print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
else: else:
setattr(obj, field, saved_value) setattr(obj, field, saved_value)
if init_field is not None: if init_field is not None:
......
...@@ -116,7 +116,7 @@ class Script(scripts.Script): ...@@ -116,7 +116,7 @@ class Script(scripts.Script):
checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch")) checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch"))
prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt")) prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt"))
file = gr.File(label="Upload prompt inputs", type='bytes', elem_id=self.elem_id("file")) file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file"))
file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt]) file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt])
......
...@@ -299,9 +299,8 @@ input[type="range"]{ ...@@ -299,9 +299,8 @@ input[type="range"]{
} }
/* more gradio's garbage cleanup */ /* more gradio's garbage cleanup */
.min-h-\[4rem\] { .min-h-\[4rem\] { min-height: unset !important; }
min-height: unset !important; .min-h-\[6rem\] { min-height: unset !important; }
}
.progressDiv{ .progressDiv{
position: absolute; position: absolute;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment