Commit d51847c1 authored by AUTOMATIC's avatar AUTOMATIC

fix caching for img2imgalt

parent 91c56c51
from collections import namedtuple
import numpy as np import numpy as np
from tqdm import trange from tqdm import trange
...@@ -56,9 +58,14 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps): ...@@ -56,9 +58,14 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
return x / x.std() return x / x.std()
cache = [None, None, None, None, None]
Cached = namedtuple("Cached", ["noise", "cfg_scale", "steps", "latent", "original_prompt"])
class Script(scripts.Script): class Script(scripts.Script):
def __init__(self):
self.cache = None
def title(self): def title(self):
return "img2img alternative test" return "img2img alternative test"
...@@ -67,7 +74,7 @@ class Script(scripts.Script): ...@@ -67,7 +74,7 @@ class Script(scripts.Script):
def ui(self, is_img2img): def ui(self, is_img2img):
original_prompt = gr.Textbox(label="Original prompt", lines=1) original_prompt = gr.Textbox(label="Original prompt", lines=1)
cfg = gr.Slider(label="Decode CFG scale", minimum=0.1, maximum=3.0, step=0.1, value=1.0) cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0)
st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50) st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50)
return [original_prompt, cfg, st] return [original_prompt, cfg, st]
...@@ -77,19 +84,18 @@ class Script(scripts.Script): ...@@ -77,19 +84,18 @@ class Script(scripts.Script):
p.batch_count = 1 p.batch_count = 1
def sample_extra(x, conditioning, unconditional_conditioning): def sample_extra(x, conditioning, unconditional_conditioning):
lat = tuple([int(x*10) for x in p.init_latent.cpu().numpy().flatten().tolist()]) lat = (p.init_latent.cpu().numpy() * 10).astype(int)
same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st and self.cache.original_prompt == original_prompt
same_everything = same_params and self.cache.latent.shape == lat.shape and np.abs(self.cache.latent-lat).sum() < 100
if cache[0] is not None and cache[1] == cfg and cache[2] == st and len(cache[3]) == len(lat) and sum(np.array(cache[3])-np.array(lat)) < 100 and cache[4] == original_prompt: if same_everything:
noise = cache[0] noise = self.cache.noise
else: else:
shared.state.job_count += 1 shared.state.job_count += 1
cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt]) cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt])
noise = find_noise_for_image(p, cond, unconditional_conditioning, cfg, st) noise = find_noise_for_image(p, cond, unconditional_conditioning, cfg, st)
cache[0] = noise self.cache = Cached(noise, cfg, st, lat, original_prompt)
cache[1] = cfg
cache[2] = st
cache[3] = lat
cache[4] = original_prompt
sampler = samplers[p.sampler_index].constructor(p.sd_model) sampler = samplers[p.sampler_index].constructor(p.sd_model)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment