Commit 7ce7fb01 authored by AUTOMATIC

fix for live progress breaking lowvram and medvram optimizations

parent 0bfa0d43
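What the fix does, in brief: the live preview used to decode the in-progress latent from the UI's polling thread, but under --lowvram/--medvram the model's components are shuffled between CPU and GPU during sampling, so touching the model from a second thread broke those optimizations. The commit adds a shared.parallel_processing_allowed flag: when it is False, the sampling thread itself decodes the preview inside the new store_latent; when it is True, the UI thread keeps decoding in check_progress_call, now throttled by sampling step rather than by poll count. A minimal standalone sketch of that split, with stand-in names rather than the repository's actual modules:

# Minimal sketch of the scheduling idea in this commit (hypothetical
# stand-in code, not the repository's actual modules).

class Options:
    show_progress_every_n_steps = 4     # 0 disables live previews

class State:
    sampling_step = 0
    current_latent = None
    current_image = None
    current_image_sampling_step = 0     # step at which the last preview was made

opts = Options()
state = State()

# With --lowvram/--medvram the model is shuffled between CPU and GPU during
# sampling, so only the sampling thread may safely touch it.
parallel_processing_allowed = True      # would be False under --lowvram/--medvram

def decode_preview(latent):
    """Stand-in for decoding a latent into a displayable image."""
    return f"preview(step={state.sampling_step})"

def store_latent(latent):
    """Runs on the sampling thread at every step (as in sd_samplers)."""
    state.current_latent = latent
    n = opts.show_progress_every_n_steps
    if n > 0 and state.sampling_step % n == 0 and not parallel_processing_allowed:
        # lowvram/medvram path: decode here, on the thread that owns the model
        state.current_image = decode_preview(latent)

def check_progress_call():
    """Runs periodically on the UI thread (as in the progress poller)."""
    n = opts.show_progress_every_n_steps
    if n > 0 and parallel_processing_allowed:
        if (state.sampling_step - state.current_image_sampling_step >= n
                and state.current_latent is not None):
            # normal path: the UI thread decodes, throttled by step count
            state.current_image = decode_preview(state.current_latent)
            state.current_image_sampling_step = state.sampling_step
    return state.current_image

for step in range(1, 9):                # tiny demo: previews at steps 4 and 8
    state.sampling_step = step
    store_latent(f"latent-{step}")
    print(step, check_progress_call())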
README.md
@@ -33,6 +33,9 @@ A browser interface based on Gradio library for Stable Diffusion.
 - Running custom code from UI
 - Mouseover hints fo most UI elements
 - Possible to change defaults/mix/max/step values for UI elements via text config
+- Random artist button
+- Tiling support: UI checkbox to create images that can be tiled like textures
+- Progress bar and live image generation preview

 ## Installing and running
...
modules/sd_samplers.py
 from collections import namedtuple
+import numpy as np
 import torch
 import tqdm
+from PIL import Image

 import k_diffusion.sampling
 import ldm.models.diffusion.ddim
@@ -37,12 +37,28 @@ samplers = [
 samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']


+def sample_to_image(samples):
+    x_sample = shared.sd_model.decode_first_stage(samples[0:1].type(shared.sd_model.dtype))[0]
+    x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
+    x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+    x_sample = x_sample.astype(np.uint8)
+    return Image.fromarray(x_sample)
+
+
+def store_latent(decoded):
+    state.current_latent = decoded
+
+    if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
+        if not shared.parallel_processing_allowed:
+            shared.state.current_image = sample_to_image(decoded)
+
+
 def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
     if sampler_wrapper.mask is not None:
         img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts)
         x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec

-    state.current_latent = x_dec
+    store_latent(x_dec)
+
     return sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
@@ -144,7 +160,7 @@ class KDiffusionSampler:
         self.model_wrap_cfg = CFGDenoiser(self.model_wrap)

     def callback_state(self, d):
-        state.current_latent = d["denoised"]
+        store_latent(d["denoised"])

     def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
         t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
...
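The new sample_to_image helper is plain tensor plumbing: the first-stage decoder returns a CHW float image in roughly [-1, 1], which is rescaled to [0, 1], scaled to 0..255, transposed to HWC, and handed to Pillow. A self-contained sketch of the same steps, with a random tensor standing in for a real decoded latent:

# Sketch of the tensor plumbing in the new sample_to_image helper, with a
# random tensor standing in for a real decoded latent (requires torch,
# numpy and Pillow).
import numpy as np
import torch
from PIL import Image

decoded = torch.rand(3, 64, 64) * 2.0 - 1.0                # decoder output: CHW, roughly [-1, 1]

x = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)   # rescale to [0, 1]
x = 255. * np.moveaxis(x.cpu().numpy(), 0, 2)              # CHW -> HWC, scale to 0..255
img = Image.fromarray(x.astype(np.uint8))                  # 8-bit RGB PIL image

print(img.size, img.mode)                                  # (64, 64) RGB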
modules/shared.py
@@ -38,7 +38,7 @@ cpu = torch.device("cpu")
 gpu = torch.device("cuda")
 device = gpu if torch.cuda.is_available() else cpu

 batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
+parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram


 class State:
     interrupted = False
@@ -49,7 +49,8 @@ class State:
     sampling_steps = 0
     current_latent = None
     current_image = None
-    current_progress_index = 0
+    current_image_sampling_step = 0

     def interrupt(self):
         self.interrupted = True
@@ -57,6 +58,7 @@ class State:
     def nextjob(self):
         self.job_no += 1
         self.sampling_step = 0
+        self.current_image_sampling_step = 0


 state = State()
@@ -103,7 +105,7 @@ class Options:
     "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
     "upscale_at_full_resolution_padding": OptionInfo(16, "Inpainting at full resolution: padding, in pixels, for the masked region.", gr.Slider, {"minimum": 0, "maximum": 128, "step": 4}),
     "show_progressbar": OptionInfo(True, "Show progressbar"),
-    "show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N progress pudates. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
+    "show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
 }

     def __init__(self):
...
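The renamed State field also changes what gets counted. current_progress_index was incremented once per UI poll, so previews fired on a schedule unrelated to actual sampling progress; current_image_sampling_step records the sampling step at which the last preview was decoded, so a new one is produced only after N further steps. A toy illustration of that arithmetic, with hypothetical values:

# Toy illustration: a preview is regenerated only after N further sampling
# steps, no matter how often the UI polls (hypothetical values).
show_progress_every_n_steps = 5
current_image_sampling_step = 0

for sampling_step in range(1, 21):
    if sampling_step - current_image_sampling_step >= show_progress_every_n_steps:
        print(f"step {sampling_step}: decode a new preview")
        current_image_sampling_step = sampling_step
# fires at steps 5, 10, 15 and 20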
modules/ui.py
@@ -160,13 +160,11 @@ def check_progress_call():
     preview_visibility = gr_show(False)

     if opts.show_progress_every_n_steps > 0:
-        if shared.state.current_progress_index % opts.show_progress_every_n_steps == 0 and shared.state.current_latent is not None:
-            x_sample = shared.sd_model.decode_first_stage(shared.state.current_latent[0:1].type(shared.sd_model.dtype))[0]
-            x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
-            x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
-            x_sample = x_sample.astype(np.uint8)
-            shared.state.current_image = Image.fromarray(x_sample)
+        if shared.parallel_processing_allowed:
+            if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None:
+                shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
+                shared.state.current_image_sampling_step = shared.state.sampling_step

         image = shared.state.current_image
@@ -175,8 +173,6 @@ def check_progress_call():
         else:
             preview_visibility = gr_show(True)

-    shared.state.current_progress_index += 1
-
     return f"<span style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image
...
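On the UI side, the preview component stays hidden unless live previews are enabled and an image exists; gr_show is this codebase's small helper for toggling visibility, reimplemented below as a stand-in returning a Gradio-style update dict. A reduced sketch of the show/hide decision in isolation, not the full check_progress_call:

# Reduced sketch of the preview show/hide decision (gr_show reimplemented
# here as a stand-in that returns a Gradio-style visibility update).
def gr_show(visible=True):
    return {"visible": visible, "__type__": "update"}

def preview_update(show_progress_every_n_steps, current_image):
    preview_visibility = gr_show(False)      # hidden unless previews are enabled
    image = None
    if show_progress_every_n_steps > 0 and current_image is not None:
        preview_visibility = gr_show(True)   # a preview exists: make it visible
        image = current_image
    return preview_visibility, image

print(preview_update(0, "img"))   # ({'visible': False, ...}, None)
print(preview_update(4, "img"))   # ({'visible': True, ...}, 'img')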
webui.py
@@ -127,7 +127,7 @@ def wrap_gradio_gpu_call(func):
         shared.state.job_no = 0
         shared.state.current_latent = None
         shared.state.current_image = None
-        shared.state.current_progress_index = 0
+        shared.state.current_image_sampling_step = 0

         with queue_lock:
             res = func(*args, **kwargs)
...
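Finally, wrap_gradio_gpu_call clears the per-job progress fields, including the new current_image_sampling_step, before taking the GPU queue lock, so a stale preview from a previous job is never shown for the next one. A simplified sketch of that pattern, with stand-ins for the real shared state and lock:

# Simplified sketch of the per-job reset in wrap_gradio_gpu_call
# (stand-ins for the real shared state and queue lock).
import threading

queue_lock = threading.Lock()

class State:
    job_no = 0
    current_latent = None
    current_image = None
    current_image_sampling_step = 0

state = State()

def wrap_gradio_gpu_call(func):
    def f(*args, **kwargs):
        # clear progress fields before the GPU work starts, so a stale
        # preview from the previous job is never shown for the new one
        state.job_no = 0
        state.current_latent = None
        state.current_image = None
        state.current_image_sampling_step = 0
        with queue_lock:                      # one GPU job at a time
            return func(*args, **kwargs)
    return f

run = wrap_gradio_gpu_call(lambda prompt: f"generated: {prompt}")
print(run("a tiled texture"))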