Commit ce691115 authored by AUTOMATIC's avatar AUTOMATIC

Add support Stable Diffusion 2.0

parent 828438b4
...@@ -84,26 +84,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web ...@@ -84,26 +84,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- API - API
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML. - Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients)) - via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
- [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
## Where are Aesthetic Gradients?!?!
Aesthetic Gradients are now an extension. You can install it using git:
```commandline
git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients extensions/aesthetic-gradients
```
After running this command, make sure that you have `aesthetic-gradients` dir in webui's `extensions` directory and restart
the UI. The interface for Aesthetic Gradients should appear exactly the same as it was.
## Where is History/Image browser?!?!
Image browser is now an extension. You can install it using git:
```commandline
git clone https://github.com/yfszzx/stable-diffusion-webui-images-browser extensions/images-browser
```
After running this command, make sure that you have `images-browser` dir in webui's `extensions` directory and restart
the UI. The interface for Image browser should appear exactly the same as it was.
## Installation and Running ## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
......
...@@ -134,18 +134,19 @@ def prepare_enviroment(): ...@@ -134,18 +134,19 @@ def prepare_enviroment():
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379") gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1") clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl') xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl')
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/CompVis/stable-diffusion.git") stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git") taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git') codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc") stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "47b6b607fdd31875c9279cd2f4f16b92e4ea958e")
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6") taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "60e5042ca0da89c14d1dd59d73883280f8fce991") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
...@@ -179,6 +180,9 @@ def prepare_enviroment(): ...@@ -179,6 +180,9 @@ def prepare_enviroment():
if not is_installed("clip"): if not is_installed("clip"):
run_pip(f"install {clip_package}", "clip") run_pip(f"install {clip_package}", "clip")
if not is_installed("open_clip"):
run_pip(f"install {openclip_package}", "open_clip")
if (not is_installed("xformers") or reinstall_xformers) and xformers: if (not is_installed("xformers") or reinstall_xformers) and xformers:
if platform.system() == "Windows": if platform.system() == "Windows":
if platform.python_version().startswith("3.10"): if platform.python_version().startswith("3.10"):
...@@ -196,7 +200,7 @@ def prepare_enviroment(): ...@@ -196,7 +200,7 @@ def prepare_enviroment():
os.makedirs(dir_repos, exist_ok=True) os.makedirs(dir_repos, exist_ok=True)
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash) git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash) git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
......
...@@ -9,7 +9,7 @@ sys.path.insert(0, script_path) ...@@ -9,7 +9,7 @@ sys.path.insert(0, script_path)
# search for directory of stable diffusion in following places # search for directory of stable diffusion in following places
sd_path = None sd_path = None
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)] possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
for possible_sd_path in possible_sd_paths: for possible_sd_path in possible_sd_paths:
if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')): if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
sd_path = os.path.abspath(possible_sd_path) sd_path = os.path.abspath(possible_sd_path)
......
This diff is collapsed.
This diff is collapsed.
...@@ -199,8 +199,8 @@ def sample_plms(self, ...@@ -199,8 +199,8 @@ def sample_plms(self,
@torch.no_grad() @torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
def get_model_output(x, t): def get_model_output(x, t):
...@@ -249,6 +249,8 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F ...@@ -249,6 +249,8 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised: if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
# direction pointing to x_t # direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
...@@ -321,12 +323,16 @@ def should_hijack_inpainting(checkpoint_info): ...@@ -321,12 +323,16 @@ def should_hijack_inpainting(checkpoint_info):
def do_inpainting_hijack(): def do_inpainting_hijack():
ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning # most of this stuff seems to no longer be needed because it is already included into SD2.0
# LatentInpaintDiffusion remains because SD2.0's LatentInpaintDiffusion can't be loaded without specifying a checkpoint
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
# this file should be cleaned up later if weverything tuens out to work fine
# ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim # ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim # ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms # ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
import open_clip.tokenizer
import torch
from modules import sd_hijack_clip, devices
from modules.shared import opts
tokenizer = open_clip.tokenizer._tokenizer
class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
def __init__(self, wrapped, hijack):
super().__init__(wrapped, hijack)
self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
self.id_start = tokenizer.encoder["<start_of_text>"]
self.id_end = tokenizer.encoder["<end_of_text>"]
self.id_pad = 0
def tokenize(self, texts):
assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
tokenized = [tokenizer.encode(text) for text in texts]
return tokenized
def encode_with_transformers(self, tokens):
# set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
z = self.wrapped.encode_with_transformer(tokens)
return z
def encode_embedding_init_text(self, init_text, nvpt):
ids = tokenizer.encode(init_text)
ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
return embedded
...@@ -127,7 +127,8 @@ class InterruptedException(BaseException): ...@@ -127,7 +127,8 @@ class InterruptedException(BaseException):
class VanillaStableDiffusionSampler: class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model): def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model) self.sampler = constructor(sd_model)
self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms self.is_plms = hasattr(self.sampler, 'p_sample_plms')
self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
self.mask = None self.mask = None
self.nmask = None self.nmask = None
self.init_latent = None self.init_latent = None
...@@ -218,7 +219,6 @@ class VanillaStableDiffusionSampler: ...@@ -218,7 +219,6 @@ class VanillaStableDiffusionSampler:
self.mask = p.mask if hasattr(p, 'mask') else None self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None self.nmask = p.nmask if hasattr(p, 'nmask') else None
def adjust_steps_if_invalid(self, p, num_steps): def adjust_steps_if_invalid(self, p, num_steps):
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
valid_step = 999 / (1000 // num_steps) valid_step = 999 / (1000 // num_steps)
...@@ -227,7 +227,6 @@ class VanillaStableDiffusionSampler: ...@@ -227,7 +227,6 @@ class VanillaStableDiffusionSampler:
return num_steps return num_steps
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps) steps, t_enc = setup_img2img_steps(p, steps)
steps = self.adjust_steps_if_invalid(p, steps) steps = self.adjust_steps_if_invalid(p, steps)
...@@ -260,9 +259,10 @@ class VanillaStableDiffusionSampler: ...@@ -260,9 +259,10 @@ class VanillaStableDiffusionSampler:
steps = self.adjust_steps_if_invalid(p, steps or p.steps) steps = self.adjust_steps_if_invalid(p, steps or p.steps)
# Wrap the conditioning models with additional image conditioning for inpainting model # Wrap the conditioning models with additional image conditioning for inpainting model
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
if image_conditioning is not None: if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
...@@ -350,7 +350,9 @@ class TorchHijack: ...@@ -350,7 +350,9 @@ class TorchHijack:
class KDiffusionSampler: class KDiffusionSampler:
def __init__(self, funcname, sd_model): def __init__(self, funcname, sd_model):
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization) denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
self.funcname = funcname self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname) self.func = getattr(k_diffusion.sampling, self.funcname)
self.extra_params = sampler_extra_params.get(funcname, []) self.extra_params = sampler_extra_params.get(funcname, [])
......
...@@ -11,17 +11,15 @@ import tqdm ...@@ -11,17 +11,15 @@ import tqdm
import modules.artists import modules.artists
import modules.interrogate import modules.interrogate
import modules.memmon import modules.memmon
import modules.sd_models
import modules.styles import modules.styles
import modules.devices as devices import modules.devices as devices
from modules import sd_samplers, sd_models, localization, sd_vae, extensions, script_loading from modules import localization, sd_vae, extensions, script_loading
from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path from modules.paths import models_path, script_path, sd_path
sd_model_file = os.path.join(script_path, 'model.ckpt') sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file default_sd_model_file = sd_model_file
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--config", type=str, default=os.path.join(script_path, "v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
...@@ -121,10 +119,12 @@ xformers_available = False ...@@ -121,10 +119,12 @@ xformers_available = False
config_filename = cmd_opts.ui_settings_file config_filename = cmd_opts.ui_settings_file
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) hypernetworks = {}
loaded_hypernetwork = None loaded_hypernetwork = None
def reload_hypernetworks(): def reload_hypernetworks():
from modules.hypernetworks import hypernetwork
global hypernetworks global hypernetworks
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
...@@ -206,10 +206,11 @@ class State: ...@@ -206,10 +206,11 @@ class State:
if self.current_latent is None: if self.current_latent is None:
return return
import modules.sd_samplers
if opts.show_progress_grid: if opts.show_progress_grid:
self.current_image = sd_samplers.samples_to_image_grid(self.current_latent) self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
else: else:
self.current_image = sd_samplers.sample_to_image(self.current_latent) self.current_image = modules.sd_samplers.sample_to_image(self.current_latent)
self.current_image_sampling_step = self.sampling_step self.current_image_sampling_step = self.sampling_step
...@@ -248,6 +249,21 @@ def options_section(section_identifier, options_dict): ...@@ -248,6 +249,21 @@ def options_section(section_identifier, options_dict):
return options_dict return options_dict
def list_checkpoint_tiles():
import modules.sd_models
return modules.sd_models.checkpoint_tiles()
def refresh_checkpoints():
import modules.sd_models
return modules.sd_models.list_models()
def list_samplers():
import modules.sd_samplers
return modules.sd_samplers.all_samplers
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config} hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
options_templates = {} options_templates = {}
...@@ -333,7 +349,7 @@ options_templates.update(options_section(('training', "Training"), { ...@@ -333,7 +349,7 @@ options_templates.update(options_section(('training', "Training"), {
})) }))
options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list), "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
...@@ -385,7 +401,7 @@ options_templates.update(options_section(('ui', "User interface"), { ...@@ -385,7 +401,7 @@ options_templates.update(options_section(('ui', "User interface"), {
})) }))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), { options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}), "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
......
...@@ -64,7 +64,8 @@ class EmbeddingDatabase: ...@@ -64,7 +64,8 @@ class EmbeddingDatabase:
self.word_embeddings[embedding.name] = embedding self.word_embeddings[embedding.name] = embedding
ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0] # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
ids = model.cond_stage_model.tokenize([embedding.name])[0]
first_id = ids[0] first_id = ids[0]
if first_id not in self.ids_lookup: if first_id not in self.ids_lookup:
...@@ -155,13 +156,11 @@ class EmbeddingDatabase: ...@@ -155,13 +156,11 @@ class EmbeddingDatabase:
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
with devices.autocast(): with devices.autocast():
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"] embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
for i in range(num_vectors_per_token): for i in range(num_vectors_per_token):
......
...@@ -478,9 +478,7 @@ def create_toprow(is_img2img): ...@@ -478,9 +478,7 @@ def create_toprow(is_img2img):
if is_img2img: if is_img2img:
with gr.Column(scale=1, elem_id="interrogate_col"): with gr.Column(scale=1, elem_id="interrogate_col"):
button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
if cmd_opts.deepdanbooru:
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
with gr.Column(scale=1): with gr.Column(scale=1):
with gr.Row(): with gr.Row():
...@@ -1004,11 +1002,10 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1004,11 +1002,10 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[img2img_prompt], outputs=[img2img_prompt],
) )
if cmd_opts.deepdanbooru: img2img_deepbooru.click(
img2img_deepbooru.click( fn=interrogate_deepbooru,
fn=interrogate_deepbooru, inputs=[init_img],
inputs=[init_img], outputs=[img2img_prompt],
outputs=[img2img_prompt],
) )
......
...@@ -28,3 +28,4 @@ kornia ...@@ -28,3 +28,4 @@ kornia
lark lark
inflection inflection
GitPython GitPython
torchsde
...@@ -25,3 +25,4 @@ kornia==0.6.7 ...@@ -25,3 +25,4 @@ kornia==0.6.7
lark==1.1.2 lark==1.1.2
inflection==0.5.1 inflection==0.5.1
GitPython==3.1.27 GitPython==3.1.27
torchsde==0.2.5
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
...@@ -10,7 +10,7 @@ from fastapi.middleware.gzip import GZipMiddleware ...@@ -10,7 +10,7 @@ from fastapi.middleware.gzip import GZipMiddleware
from modules.paths import script_path from modules.paths import script_path
from modules import devices, sd_samplers, upscaler, extensions, localization from modules import shared, devices, sd_samplers, upscaler, extensions, localization
import modules.codeformer_model as codeformer import modules.codeformer_model as codeformer
import modules.extras import modules.extras
import modules.face_restoration import modules.face_restoration
...@@ -23,7 +23,6 @@ import modules.scripts ...@@ -23,7 +23,6 @@ import modules.scripts
import modules.sd_hijack import modules.sd_hijack
import modules.sd_models import modules.sd_models
import modules.sd_vae import modules.sd_vae
import modules.shared as shared
import modules.txt2img import modules.txt2img
import modules.script_callbacks import modules.script_callbacks
...@@ -86,7 +85,7 @@ def initialize(): ...@@ -86,7 +85,7 @@ def initialize():
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks()))
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment