Commit 5756d517 authored by d8ahazard's avatar d8ahazard

Merge remote-tracking branch 'upstream/master' into ModelLoader

parents 11875f58 ada901ed
...@@ -21,3 +21,5 @@ __pycache__ ...@@ -21,3 +21,5 @@ __pycache__
/interrogate /interrogate
/user.css /user.css
/.idea /.idea
notification.mp3
/SwinIR
...@@ -68,13 +68,19 @@ window.addEventListener('paste', e => { ...@@ -68,13 +68,19 @@ window.addEventListener('paste', e => {
if ( ! isValidImageList( files ) ) { if ( ! isValidImageList( files ) ) {
return; return;
} }
[...gradioApp().querySelectorAll('input[type=file][accept="image/x-png,image/gif,image/jpeg"]')]
.filter(input => !input.matches('.\\!hidden input[type=file]')) const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')]
.forEach(input => { .filter(el => uiElementIsVisible(el));
input.files = files; if ( ! visibleImageFields.length ) {
input.dispatchEvent(new Event('change')) return;
}); }
[...gradioApp().querySelectorAll('[data-testid="image"]')]
.filter(imgWrap => !imgWrap.closest('.\\!hidden')) const firstFreeImageField = visibleImageFields
.forEach(imgWrap => dropReplaceImage( imgWrap, files )); .filter(el => el.querySelector('input[type=file]'))?.[0];
dropReplaceImage(
firstFreeImageField ?
firstFreeImageField :
visibleImageFields[visibleImageFields.length - 1]
, files );
}); });
...@@ -25,6 +25,9 @@ onUiUpdate(function(){ ...@@ -25,6 +25,9 @@ onUiUpdate(function(){
lastHeadImg = headImg; lastHeadImg = headImg;
// play notification sound if available
gradioApp().querySelector('#audio_notification audio')?.play();
if (document.hasFocus()) return; if (document.hasFocus()) return;
// Multiple copies of the images are in the DOM when one is selected. Dedup with a Set to get the real number generated. // Multiple copies of the images are in the DOM when one is selected. Dedup with a Set to get the real number generated.
......
// various functions for interation with ui.py not large enough to warrant putting them in separate files // various functions for interation with ui.py not large enough to warrant putting them in separate files
function selected_gallery_index(){ function selected_gallery_index(){
var gr = gradioApp() var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem .gallery-item')
var buttons = gradioApp().querySelectorAll(".gallery-item") var button = gradioApp().querySelector('[style="display: block;"].tabitem .gallery-item.\\!ring-2')
var button = gr.querySelector(".gallery-item.\\!ring-2")
var result = -1 var result = -1
buttons.forEach(function(v, i){ if(v==button) { result = i } }) buttons.forEach(function(v, i){ if(v==button) { result = i } })
......
...@@ -3,6 +3,9 @@ import os ...@@ -3,6 +3,9 @@ import os
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import torch
import tqdm
from modules import processing, shared, images, devices from modules import processing, shared, images, devices
from modules.shared import opts from modules.shared import opts
import modules.gfpgan_model import modules.gfpgan_model
...@@ -137,3 +140,57 @@ def run_pnginfo(image): ...@@ -137,3 +140,57 @@ def run_pnginfo(image):
info = f"<div><p>{message}<p></div>" info = f"<div><p>{message}<p></div>"
return '', geninfo, info return '', geninfo, info
def run_modelmerger(modelname_0, modelname_1, interp_method, interp_amount):
# Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
# Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
def sigmoid(theta0, theta1, alpha):
alpha = alpha * alpha * (3 - (2 * alpha))
return theta0 + ((theta1 - theta0) * alpha)
if os.path.exists(modelname_0):
model0_filename = modelname_0
modelname_0 = os.path.splitext(os.path.basename(modelname_0))[0]
else:
model0_filename = 'models/' + modelname_0 + '.ckpt'
if os.path.exists(modelname_1):
model1_filename = modelname_1
modelname_1 = os.path.splitext(os.path.basename(modelname_1))[0]
else:
model1_filename = 'models/' + modelname_1 + '.ckpt'
print(f"Loading {model0_filename}...")
model_0 = torch.load(model0_filename, map_location='cpu')
print(f"Loading {model1_filename}...")
model_1 = torch.load(model1_filename, map_location='cpu')
theta_0 = model_0['state_dict']
theta_1 = model_1['state_dict']
theta_funcs = {
"Weighted Sum": weighted_sum,
"Sigmoid": sigmoid,
}
theta_func = theta_funcs[interp_method]
print(f"Merging...")
for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1:
theta_0[key] = theta_func(theta_0[key], theta_1[key], interp_amount)
for key in theta_1.keys():
if 'model' in key and key not in theta_0:
theta_0[key] = theta_1[key]
output_modelname = 'models/' + modelname_0 + '-' + modelname_1 + '-merged.ckpt'
print(f"Saving to {output_modelname}...")
torch.save(model_0, output_modelname)
print(f"Checkpoint saved.")
return "Checkpoint saved to " + output_modelname
...@@ -79,6 +79,13 @@ class StableDiffusionProcessing: ...@@ -79,6 +79,13 @@ class StableDiffusionProcessing:
self.color_corrections = None self.color_corrections = None
self.denoising_strength: float = 0 self.denoising_strength: float = 0
self.ddim_eta = opts.ddim_eta
self.ddim_discretize = opts.ddim_discretize
self.s_churn = opts.s_churn
self.s_tmin = opts.s_tmin
self.s_tmax = float('inf') # not representable as a standard ui option
self.s_noise = opts.s_noise
if not seed_enable_extras: if not seed_enable_extras:
self.subseed = -1 self.subseed = -1
self.subseed_strength = 0 self.subseed_strength = 0
...@@ -117,6 +124,13 @@ class Processed: ...@@ -117,6 +124,13 @@ class Processed:
self.extra_generation_params = p.extra_generation_params self.extra_generation_params = p.extra_generation_params
self.index_of_first_image = index_of_first_image self.index_of_first_image = index_of_first_image
self.ddim_eta = p.ddim_eta
self.ddim_discretize = p.ddim_discretize
self.s_churn = p.s_churn
self.s_tmin = p.s_tmin
self.s_tmax = p.s_tmax
self.s_noise = p.s_noise
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0] self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0] self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
...@@ -406,7 +420,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -406,7 +420,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
index_of_first_image = 1 index_of_first_image = 1
if opts.grid_save: if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p) images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc() devices.torch_gc()
return Processed(p, output_images, all_seeds[0], infotext(), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image) return Processed(p, output_images, all_seeds[0], infotext(), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image)
......
...@@ -37,6 +37,11 @@ samplers = [ ...@@ -37,6 +37,11 @@ samplers = [
] ]
samplers_for_img2img = [x for x in samplers if x.name != 'PLMS'] samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
sampler_extra_params = {
'sample_euler':['s_churn','s_tmin','s_tmax','s_noise'],
'sample_heun' :['s_churn','s_tmin','s_tmax','s_noise'],
'sample_dpm_2':['s_churn','s_tmin','s_tmax','s_noise'],
}
def setup_img2img_steps(p, steps=None): def setup_img2img_steps(p, steps=None):
if opts.img2img_fix_steps or steps is not None: if opts.img2img_fix_steps or steps is not None:
...@@ -120,9 +125,9 @@ class VanillaStableDiffusionSampler: ...@@ -120,9 +125,9 @@ class VanillaStableDiffusionSampler:
# existing code fails with cetain step counts, like 9 # existing code fails with cetain step counts, like 9
try: try:
self.sampler.make_schedule(ddim_num_steps=steps, verbose=False) self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=p.ddim_eta, ddim_discretize=p.ddim_discretize, verbose=False)
except Exception: except Exception:
self.sampler.make_schedule(ddim_num_steps=steps+1, verbose=False) self.sampler.make_schedule(ddim_num_steps=steps+1,ddim_eta=p.ddim_eta, ddim_discretize=p.ddim_discretize, verbose=False)
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise) x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
...@@ -149,9 +154,9 @@ class VanillaStableDiffusionSampler: ...@@ -149,9 +154,9 @@ class VanillaStableDiffusionSampler:
# existing code fails with cetin step counts, like 9 # existing code fails with cetin step counts, like 9
try: try:
samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x) samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.ddim_eta)
except Exception: except Exception:
samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x) samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.ddim_eta)
return samples_ddim return samples_ddim
...@@ -224,6 +229,7 @@ class KDiffusionSampler: ...@@ -224,6 +229,7 @@ class KDiffusionSampler:
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization) self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
self.funcname = funcname self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname) self.func = getattr(k_diffusion.sampling, self.funcname)
self.extra_params = sampler_extra_params.get(funcname,[])
self.model_wrap_cfg = CFGDenoiser(self.model_wrap) self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
self.sampler_noises = None self.sampler_noises = None
self.sampler_noise_index = 0 self.sampler_noise_index = 0
...@@ -269,7 +275,12 @@ class KDiffusionSampler: ...@@ -269,7 +275,12 @@ class KDiffusionSampler:
if self.sampler_noises is not None: if self.sampler_noises is not None:
k_diffusion.sampling.torch = TorchHijack(self) k_diffusion.sampling.torch = TorchHijack(self)
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state) extra_params_kwargs = {}
for val in self.extra_params:
if hasattr(p,val):
extra_params_kwargs[val] = getattr(p,val)
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None): def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
steps = steps or p.steps steps = steps or p.steps
...@@ -286,7 +297,12 @@ class KDiffusionSampler: ...@@ -286,7 +297,12 @@ class KDiffusionSampler:
if self.sampler_noises is not None: if self.sampler_noises is not None:
k_diffusion.sampling.torch = TorchHijack(self) k_diffusion.sampling.torch = TorchHijack(self)
samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state) extra_params_kwargs = {}
for val in self.extra_params:
if hasattr(p,val):
extra_params_kwargs[val] = getattr(p,val)
samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
return samples return samples
...@@ -76,7 +76,7 @@ class State: ...@@ -76,7 +76,7 @@ class State:
job = "" job = ""
job_no = 0 job_no = 0
job_count = 0 job_count = 0
job_timestamp = 0 job_timestamp = '0'
sampling_step = 0 sampling_step = 0
sampling_steps = 0 sampling_steps = 0
current_latent = None current_latent = None
...@@ -90,6 +90,7 @@ class State: ...@@ -90,6 +90,7 @@ class State:
self.job_no += 1 self.job_no += 1
self.sampling_step = 0 self.sampling_step = 0
self.current_image_sampling_step = 0 self.current_image_sampling_step = 0
def get_job_timestamp(self): def get_job_timestamp(self):
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") return datetime.datetime.now().strftime("%Y%m%d%H%M%S")
...@@ -229,6 +230,13 @@ options_templates.update(options_section(('ui', "User interface"), { ...@@ -229,6 +230,13 @@ options_templates.update(options_section(('ui', "User interface"), {
"js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
})) }))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"ddim_eta": OptionInfo(0.0, "DDIM eta", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform','quad']}),
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
}))
class Options: class Options:
data = None data = None
......
...@@ -49,6 +49,7 @@ sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None ...@@ -49,6 +49,7 @@ sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
css_hide_progressbar = """ css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; } .wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." }
.progress-bar { display:none!important; } .progress-bar { display:none!important; }
.meta-text { display:none!important; } .meta-text { display:none!important; }
""" """
...@@ -393,7 +394,7 @@ def setup_progressbar(progressbar, preview, id_part): ...@@ -393,7 +394,7 @@ def setup_progressbar(progressbar, preview, id_part):
) )
def create_ui(txt2img, img2img, run_extras, run_pnginfo): def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False) txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False) dummy_component = gr.Label(visible=False)
...@@ -564,13 +565,13 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): ...@@ -564,13 +565,13 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
with gr.TabItem('Inpaint', id='inpaint'): with gr.TabItem('Inpaint', id='inpaint'):
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA") init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA")
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False) init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False) init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4)
with gr.Row(): with gr.Row():
mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask") mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index") inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index")
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index") inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index")
...@@ -853,6 +854,33 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): ...@@ -853,6 +854,33 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
outputs=[html, generation_info, html2], outputs=[html, generation_info, html2],
) )
with gr.Blocks() as modelmerger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>/models</b> directory.</p>")
modelname_0 = gr.Textbox(elem_id="modelmerger_modelname_0", label="Model Name (to)")
modelname_1 = gr.Textbox(elem_id="modelmerger_modelname_1", label="Model Name (from)")
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid"], value="Weighted Sum", label="Interpolation Method")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3)
submit = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
submit.click(
fn=run_modelmerger,
inputs=[
modelname_0,
modelname_1,
interp_method,
interp_amount
],
outputs=[
submit_result,
]
)
def create_setting_component(key): def create_setting_component(key):
def fun(): def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default return opts.data[key] if key in opts.data else opts.data_labels[key].default
...@@ -950,6 +978,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): ...@@ -950,6 +978,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
(img2img_interface, "img2img", "img2img"), (img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"), (extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"), (pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(settings_interface, "Settings", "settings"), (settings_interface, "Settings", "settings"),
] ]
...@@ -971,6 +1000,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): ...@@ -971,6 +1000,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
with gr.TabItem(label, id=ifid): with gr.TabItem(label, id=ifid):
interface.render() interface.render()
if os.path.exists(os.path.join(script_path, "notification.mp3")):
audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
settings_submit.click( settings_submit.click(
fn=run_settings, fn=run_settings,
......
...@@ -39,3 +39,24 @@ document.addEventListener("DOMContentLoaded", function() { ...@@ -39,3 +39,24 @@ document.addEventListener("DOMContentLoaded", function() {
}); });
mutationObserver.observe( gradioApp(), { childList:true, subtree:true }) mutationObserver.observe( gradioApp(), { childList:true, subtree:true })
}); });
/**
* checks that a UI element is not in another hidden element or tab content
*/
function uiElementIsVisible(el) {
let isVisible = !el.closest('.\\!hidden');
if ( ! isVisible ) {
return false;
}
while( isVisible = el.closest('.tabitem')?.style.display !== 'none' ) {
if ( ! isVisible ) {
return false;
} else if ( el.parentElement ) {
el = el.parentElement
} else {
break;
}
}
return isVisible;
}
\ No newline at end of file
...@@ -59,7 +59,55 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps): ...@@ -59,7 +59,55 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
return x / x.std() return x / x.std()
Cached = namedtuple("Cached", ["noise", "cfg_scale", "steps", "latent", "original_prompt", "original_negative_prompt"]) Cached = namedtuple("Cached", ["noise", "cfg_scale", "steps", "latent", "original_prompt", "original_negative_prompt", "sigma_adjustment"])
# Based on changes suggested by briansemrau in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736
def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
x = p.init_latent
s_in = x.new_ones([x.shape[0]])
dnw = K.external.CompVisDenoiser(shared.sd_model)
sigmas = dnw.get_sigmas(steps).flip(0)
shared.state.sampling_steps = steps
for i in trange(1, len(sigmas)):
shared.state.sampling_step += 1
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)
cond_in = torch.cat([uncond, cond])
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
if i == 1:
t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))
else:
t = dnw.sigma_to_t(sigma_in)
eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale
if i == 1:
d = (x - denoised) / (2 * sigmas[i])
else:
d = (x - denoised) / sigmas[i - 1]
dt = sigmas[i] - sigmas[i - 1]
x = x + d * dt
sd_samplers.store_latent(x)
# This shouldn't be necessary, but solved some VRAM issues
del x_in, sigma_in, cond_in, c_out, c_in, t,
del eps, denoised_uncond, denoised_cond, denoised, d, dt
shared.state.nextjob()
return x / sigmas[-1]
class Script(scripts.Script): class Script(scripts.Script):
...@@ -78,9 +126,10 @@ class Script(scripts.Script): ...@@ -78,9 +126,10 @@ class Script(scripts.Script):
cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0) cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0)
st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50) st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50)
randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0) randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
return [original_prompt, original_negative_prompt, cfg, st, randomness] sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False)
return [original_prompt, original_negative_prompt, cfg, st, randomness, sigma_adjustment]
def run(self, p, original_prompt, original_negative_prompt, cfg, st, randomness): def run(self, p, original_prompt, original_negative_prompt, cfg, st, randomness, sigma_adjustment):
p.batch_size = 1 p.batch_size = 1
p.batch_count = 1 p.batch_count = 1
...@@ -88,7 +137,10 @@ class Script(scripts.Script): ...@@ -88,7 +137,10 @@ class Script(scripts.Script):
def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
lat = (p.init_latent.cpu().numpy() * 10).astype(int) lat = (p.init_latent.cpu().numpy() * 10).astype(int)
same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st and self.cache.original_prompt == original_prompt and self.cache.original_negative_prompt == original_negative_prompt same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st \
and self.cache.original_prompt == original_prompt \
and self.cache.original_negative_prompt == original_negative_prompt \
and self.cache.sigma_adjustment == sigma_adjustment
same_everything = same_params and self.cache.latent.shape == lat.shape and np.abs(self.cache.latent-lat).sum() < 100 same_everything = same_params and self.cache.latent.shape == lat.shape and np.abs(self.cache.latent-lat).sum() < 100
if same_everything: if same_everything:
...@@ -97,8 +149,11 @@ class Script(scripts.Script): ...@@ -97,8 +149,11 @@ class Script(scripts.Script):
shared.state.job_count += 1 shared.state.job_count += 1
cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt]) cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt])
uncond = p.sd_model.get_learned_conditioning(p.batch_size * [original_negative_prompt]) uncond = p.sd_model.get_learned_conditioning(p.batch_size * [original_negative_prompt])
if sigma_adjustment:
rec_noise = find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg, st)
else:
rec_noise = find_noise_for_image(p, cond, uncond, cfg, st) rec_noise = find_noise_for_image(p, cond, uncond, cfg, st)
self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt) self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment)
rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], [p.seed + x + 1 for x in range(p.init_latent.shape[0])]) rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], [p.seed + x + 1 for x in range(p.init_latent.shape[0])])
...@@ -121,6 +176,7 @@ class Script(scripts.Script): ...@@ -121,6 +176,7 @@ class Script(scripts.Script):
p.extra_generation_params["Decode CFG scale"] = cfg p.extra_generation_params["Decode CFG scale"] = cfg
p.extra_generation_params["Decode steps"] = st p.extra_generation_params["Decode steps"] = st
p.extra_generation_params["Randomness"] = randomness p.extra_generation_params["Randomness"] = randomness
p.extra_generation_params["Sigma Adjustment"] = sigma_adjustment
processed = processing.process_images(p) processed = processing.process_images(p)
......
...@@ -2,6 +2,7 @@ from collections import namedtuple ...@@ -2,6 +2,7 @@ from collections import namedtuple
from copy import copy from copy import copy
import random import random
from PIL import Image
import numpy as np import numpy as np
import modules.scripts as scripts import modules.scripts as scripts
...@@ -86,7 +87,12 @@ axis_options = [ ...@@ -86,7 +87,12 @@ axis_options = [
AxisOption("Prompt S/R", str, apply_prompt, format_value), AxisOption("Prompt S/R", str, apply_prompt, format_value),
AxisOption("Sampler", str, apply_sampler, format_value), AxisOption("Sampler", str, apply_sampler, format_value),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value), AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label),
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label),
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
AxisOption("DDIM Eta", float, apply_field("ddim_eta"), format_value_add_label),
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label),# as it is now all AxisOptionImg2Img items must go after AxisOption ones
] ]
...@@ -108,7 +114,10 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): ...@@ -108,7 +114,10 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend):
if first_pocessed is None: if first_pocessed is None:
first_pocessed = processed first_pocessed = processed
try:
res.append(processed.images[0]) res.append(processed.images[0])
except:
res.append(Image.new(res[0].mode, res[0].size))
grid = images.image_grid(res, rows=len(ys)) grid = images.image_grid(res, rows=len(ys))
if draw_legend: if draw_legend:
......
...@@ -85,7 +85,8 @@ def webui(): ...@@ -85,7 +85,8 @@ def webui():
txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img), txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img),
img2img=wrap_gradio_gpu_call(modules.img2img.img2img), img2img=wrap_gradio_gpu_call(modules.img2img.img2img),
run_extras=wrap_gradio_gpu_call(modules.extras.run_extras), run_extras=wrap_gradio_gpu_call(modules.extras.run_extras),
run_pnginfo=modules.extras.run_pnginfo run_pnginfo=modules.extras.run_pnginfo,
run_modelmerger=modules.extras.run_modelmerger
) )
demo.launch( demo.launch(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment