Commit 4df63d2d authored by AUTOMATIC's avatar AUTOMATIC

split samplers into one more files for k-diffusion

parent 27447410
This diff is collapsed.
from collections import namedtuple, deque from collections import namedtuple
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
...@@ -64,6 +64,7 @@ class InterruptedException(BaseException): ...@@ -64,6 +64,7 @@ class InterruptedException(BaseException):
# MPS fix for randn in torchsde # MPS fix for randn in torchsde
# XXX move this to separate file for MPS
def torchsde_randn(size, dtype, device, seed): def torchsde_randn(size, dtype, device, seed):
if device.type == 'mps': if device.type == 'mps':
generator = torch.Generator(devices.cpu).manual_seed(int(seed)) generator = torch.Generator(devices.cpu).manual_seed(int(seed))
......
import math import math
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import numpy as np import numpy as np
import torch import torch
...@@ -7,6 +9,12 @@ from modules.shared import state ...@@ -7,6 +9,12 @@ from modules.shared import state
from modules import sd_samplers_common, prompt_parser, shared from modules import sd_samplers_common, prompt_parser, shared
samplers_data_compvis = [
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
]
class VanillaStableDiffusionSampler: class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model): def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model) self.sampler = constructor(sd_model)
......
...@@ -2,18 +2,12 @@ from collections import deque ...@@ -2,18 +2,12 @@ from collections import deque
import torch import torch
import inspect import inspect
import k_diffusion.sampling import k_diffusion.sampling
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_compvis from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_compvis
from modules.shared import opts, state from modules.shared import opts, state
import modules.shared as shared import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
# imports for functions that previously were here and are used by other modules
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
samplers_k_diffusion = [ samplers_k_diffusion = [
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
('Euler', 'sample_euler', ['k_euler'], {}), ('Euler', 'sample_euler', ['k_euler'], {}),
...@@ -40,50 +34,6 @@ samplers_data_k_diffusion = [ ...@@ -40,50 +34,6 @@ samplers_data_k_diffusion = [
if hasattr(k_diffusion.sampling, funcname) if hasattr(k_diffusion.sampling, funcname)
] ]
all_samplers = [
*samplers_data_k_diffusion,
sd_samplers_common.SamplerData('DDIM', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
sd_samplers_common.SamplerData('PLMS', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
]
all_samplers_map = {x.name: x for x in all_samplers}
samplers = []
samplers_for_img2img = []
samplers_map = {}
def create_sampler(name, model):
if name is not None:
config = all_samplers_map.get(name, None)
else:
config = all_samplers[0]
assert config is not None, f'bad sampler name: {name}'
sampler = config.constructor(model)
sampler.config = config
return sampler
def set_samplers():
global samplers, samplers_for_img2img
hidden = set(opts.hide_samplers)
hidden_img2img = set(opts.hide_samplers + ['PLMS'])
samplers = [x for x in all_samplers if x.name not in hidden]
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
samplers_map.clear()
for sampler in all_samplers:
samplers_map[sampler.name.lower()] = sampler.name
for alias in sampler.aliases:
samplers_map[alias.lower()] = sampler.name
set_samplers()
sampler_extra_params = { sampler_extra_params = {
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
...@@ -92,6 +42,13 @@ sampler_extra_params = { ...@@ -92,6 +42,13 @@ sampler_extra_params = {
class CFGDenoiser(torch.nn.Module): class CFGDenoiser(torch.nn.Module):
"""
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
instead of one. Originally, the second prompt is just an empty string, but we use non-empty
negative prompt.
"""
def __init__(self, model): def __init__(self, model):
super().__init__() super().__init__()
self.inner_model = model self.inner_model = model
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment