Commit 4df63d2d authored by AUTOMATIC's avatar AUTOMATIC

split samplers into one more files for k-diffusion

parent 27447410
This diff is collapsed.
from collections import namedtuple, deque
from collections import namedtuple
import numpy as np
import torch
from PIL import Image
......@@ -64,6 +64,7 @@ class InterruptedException(BaseException):
# MPS fix for randn in torchsde
# XXX move this to separate file for MPS
def torchsde_randn(size, dtype, device, seed):
if device.type == 'mps':
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
......
import math
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import numpy as np
import torch
......@@ -7,6 +9,12 @@ from modules.shared import state
from modules import sd_samplers_common, prompt_parser, shared
samplers_data_compvis = [
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
]
class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model)
......
......@@ -2,18 +2,12 @@ from collections import deque
import torch
import inspect
import k_diffusion.sampling
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_compvis
from modules.shared import opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
# imports for functions that previously were here and are used by other modules
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
samplers_k_diffusion = [
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
('Euler', 'sample_euler', ['k_euler'], {}),
......@@ -40,50 +34,6 @@ samplers_data_k_diffusion = [
if hasattr(k_diffusion.sampling, funcname)
]
all_samplers = [
*samplers_data_k_diffusion,
sd_samplers_common.SamplerData('DDIM', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
sd_samplers_common.SamplerData('PLMS', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
]
all_samplers_map = {x.name: x for x in all_samplers}
samplers = []
samplers_for_img2img = []
samplers_map = {}
def create_sampler(name, model):
if name is not None:
config = all_samplers_map.get(name, None)
else:
config = all_samplers[0]
assert config is not None, f'bad sampler name: {name}'
sampler = config.constructor(model)
sampler.config = config
return sampler
def set_samplers():
global samplers, samplers_for_img2img
hidden = set(opts.hide_samplers)
hidden_img2img = set(opts.hide_samplers + ['PLMS'])
samplers = [x for x in all_samplers if x.name not in hidden]
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
samplers_map.clear()
for sampler in all_samplers:
samplers_map[sampler.name.lower()] = sampler.name
for alias in sampler.aliases:
samplers_map[alias.lower()] = sampler.name
set_samplers()
sampler_extra_params = {
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
......@@ -92,6 +42,13 @@ sampler_extra_params = {
class CFGDenoiser(torch.nn.Module):
"""
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
instead of one. Originally, the second prompt is just an empty string, but we use non-empty
negative prompt.
"""
def __init__(self, model):
super().__init__()
self.inner_model = model
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment