Unverified Commit eeb1de43 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge branch 'master' into gradient-clipping

parents d85c2cb2 b7deea47
name: Run basic features tests on CPU with empty SD model
on:
- push
- pull_request
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout Code
uses: actions/checkout@v3
- name: Set up Python 3.10
uses: actions/setup-python@v4
with:
python-version: 3.10.6
- uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
restore-keys: ${{ runner.os }}-pip-
- name: Run tests
run: python launch.py --tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
- name: Upload main app stdout-stderr
uses: actions/upload-artifact@v3
if: always()
with:
name: stdout-stderr
path: |
test/stdout.txt
test/stderr.txt
__pycache__ __pycache__
*.ckpt *.ckpt
*.safetensors
*.pth *.pth
/ESRGAN/* /ESRGAN/*
/SwinIR/* /SwinIR/*
......
...@@ -70,7 +70,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web ...@@ -70,7 +70,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- separate prompts using uppercase `AND` - separate prompts using uppercase `AND`
- also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2` - also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2`
- No token limit for prompts (original stable diffusion lets you use up to 75 tokens) - No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
- DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args) - DeepDanbooru integration, creates danbooru style tags for anime prompts
- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args) - [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args)
- via extension: [History tab](https://github.com/yfszzx/stable-diffusion-webui-images-browser): view, direct and delete images conveniently within the UI - via extension: [History tab](https://github.com/yfszzx/stable-diffusion-webui-images-browser): view, direct and delete images conveniently within the UI
- Generate forever option - Generate forever option
...@@ -83,27 +83,8 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web ...@@ -83,27 +83,8 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- Estimated completion time in progress bar - Estimated completion time in progress bar
- API - API
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML. - Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients)) - via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
- [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
## Where are Aesthetic Gradients?!?!
Aesthetic Gradients are now an extension. You can install it using git:
```commandline
git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients extensions/aesthetic-gradients
```
After running this command, make sure that you have `aesthetic-gradients` dir in webui's `extensions` directory and restart
the UI. The interface for Aesthetic Gradients should appear exactly the same as it was.
## Where is History/Image browser?!?!
Image browser is now an extension. You can install it using git:
```commandline
git clone https://github.com/yfszzx/stable-diffusion-webui-images-browser extensions/images-browser
```
After running this command, make sure that you have `images-browser` dir in webui's `extensions` directory and restart
the UI. The interface for Image browser should appear exactly the same as it was.
## Installation and Running ## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
...@@ -146,6 +127,8 @@ Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC ...@@ -146,6 +127,8 @@ Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC
The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki). The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).
## Credits ## Credits
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
- Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers - Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git - k-diffusion - https://github.com/crowsonkb/k-diffusion.git
- GFPGAN - https://github.com/TencentARC/GFPGAN.git - GFPGAN - https://github.com/TencentARC/GFPGAN.git
...@@ -154,6 +137,7 @@ The documentation was moved from this README over to the project's [wiki](https: ...@@ -154,6 +137,7 @@ The documentation was moved from this README over to the project's [wiki](https:
- SwinIR - https://github.com/JingyunLiang/SwinIR - SwinIR - https://github.com/JingyunLiang/SwinIR
- Swin2SR - https://github.com/mv-lab/swin2sr - Swin2SR - https://github.com/mv-lab/swin2sr
- LDSR - https://github.com/Hafiidz/latent-diffusion - LDSR - https://github.com/Hafiidz/latent-diffusion
- MiDaS - https://github.com/isl-org/MiDaS
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion - Ideas for optimizations - https://github.com/basujindal/stable-diffusion
- Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing. - Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
- Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion) - Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
......
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: modules.xlmr.BertSeriesModelWithTransformation
params:
name: "XLMR-Large"
\ No newline at end of file
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
import os
import gc import gc
import time import time
import warnings import warnings
...@@ -8,27 +9,49 @@ import torchvision ...@@ -8,27 +9,49 @@ import torchvision
from PIL import Image from PIL import Image
from einops import rearrange, repeat from einops import rearrange, repeat
from omegaconf import OmegaConf from omegaconf import OmegaConf
import safetensors.torch
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap from ldm.util import instantiate_from_config, ismap
from modules import shared, sd_hijack
warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=UserWarning)
cached_ldsr_model: torch.nn.Module = None
# Create LDSR Class # Create LDSR Class
class LDSR: class LDSR:
def load_model_from_config(self, half_attention): def load_model_from_config(self, half_attention):
global cached_ldsr_model
if shared.opts.ldsr_cached and cached_ldsr_model is not None:
print("Loading model from cache")
model: torch.nn.Module = cached_ldsr_model
else:
print(f"Loading model from {self.modelPath}") print(f"Loading model from {self.modelPath}")
_, extension = os.path.splitext(self.modelPath)
if extension.lower() == ".safetensors":
pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
else:
pl_sd = torch.load(self.modelPath, map_location="cpu") pl_sd = torch.load(self.modelPath, map_location="cpu")
sd = pl_sd["state_dict"] sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
config = OmegaConf.load(self.yamlPath) config = OmegaConf.load(self.yamlPath)
model = instantiate_from_config(config.model) config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
model: torch.nn.Module = instantiate_from_config(config.model)
model.load_state_dict(sd, strict=False) model.load_state_dict(sd, strict=False)
model.cuda() model = model.to(shared.device)
if half_attention: if half_attention:
model = model.half() model = model.half()
if shared.cmd_opts.opt_channelslast:
model = model.to(memory_format=torch.channels_last)
sd_hijack.model_hijack.hijack(model) # apply optimization
model.eval() model.eval()
if shared.opts.ldsr_cached:
cached_ldsr_model = model
return {"model": model} return {"model": model}
def __init__(self, model_path, yaml_path): def __init__(self, model_path, yaml_path):
...@@ -93,6 +116,7 @@ class LDSR: ...@@ -93,6 +116,7 @@ class LDSR:
down_sample_method = 'Lanczos' down_sample_method = 'Lanczos'
gc.collect() gc.collect()
if torch.cuda.is_available:
torch.cuda.empty_cache() torch.cuda.empty_cache()
im_og = image im_og = image
...@@ -130,7 +154,9 @@ class LDSR: ...@@ -130,7 +154,9 @@ class LDSR:
del model del model
gc.collect() gc.collect()
if torch.cuda.is_available:
torch.cuda.empty_cache() torch.cuda.empty_cache()
return a return a
...@@ -145,7 +171,7 @@ def get_cond(selected_path): ...@@ -145,7 +171,7 @@ def get_cond(selected_path):
c = rearrange(c, '1 c h w -> 1 h w c') c = rearrange(c, '1 c h w -> 1 h w c')
c = 2. * c - 1. c = 2. * c - 1.
c = c.to(torch.device("cuda")) c = c.to(shared.device)
example["LR_image"] = c example["LR_image"] = c
example["image"] = c_up example["image"] = c_up
......
import os
from modules import paths
def preload(parser):
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))
...@@ -5,8 +5,9 @@ import traceback ...@@ -5,8 +5,9 @@ import traceback
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
from modules.ldsr_model_arch import LDSR from ldsr_model_arch import LDSR
from modules import shared from modules import shared, script_callbacks
import sd_hijack_autoencoder, sd_hijack_ddpm_v1
class UpscalerLDSR(Upscaler): class UpscalerLDSR(Upscaler):
...@@ -24,6 +25,7 @@ class UpscalerLDSR(Upscaler): ...@@ -24,6 +25,7 @@ class UpscalerLDSR(Upscaler):
yaml_path = os.path.join(self.model_path, "project.yaml") yaml_path = os.path.join(self.model_path, "project.yaml")
old_model_path = os.path.join(self.model_path, "model.pth") old_model_path = os.path.join(self.model_path, "model.pth")
new_model_path = os.path.join(self.model_path, "model.ckpt") new_model_path = os.path.join(self.model_path, "model.ckpt")
safetensors_model_path = os.path.join(self.model_path, "model.safetensors")
if os.path.exists(yaml_path): if os.path.exists(yaml_path):
statinfo = os.stat(yaml_path) statinfo = os.stat(yaml_path)
if statinfo.st_size >= 10485760: if statinfo.st_size >= 10485760:
...@@ -32,6 +34,9 @@ class UpscalerLDSR(Upscaler): ...@@ -32,6 +34,9 @@ class UpscalerLDSR(Upscaler):
if os.path.exists(old_model_path): if os.path.exists(old_model_path):
print("Renaming model from model.pth to model.ckpt") print("Renaming model from model.pth to model.ckpt")
os.rename(old_model_path, new_model_path) os.rename(old_model_path, new_model_path)
if os.path.exists(safetensors_model_path):
model = safetensors_model_path
else:
model = load_file_from_url(url=self.model_url, model_dir=self.model_path, model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
file_name="model.ckpt", progress=True) file_name="model.ckpt", progress=True)
yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path, yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
...@@ -52,3 +57,13 @@ class UpscalerLDSR(Upscaler): ...@@ -52,3 +57,13 @@ class UpscalerLDSR(Upscaler):
return img return img
ddim_steps = shared.opts.ldsr_steps ddim_steps = shared.opts.ldsr_steps
return ldsr.super_resolution(img, ddim_steps, self.scale) return ldsr.super_resolution(img, ddim_steps, self.scale)
def on_ui_settings():
import gradio as gr
shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")))
script_callbacks.on_ui_settings(on_ui_settings)
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.util import instantiate_from_config
import ldm.models.autoencoder
class VQModel(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
batch_resize_range=None,
scheduler_config=None,
lr_g_factor=1.0,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
use_ema=False
):
super().__init__()
self.embed_dim = embed_dim
self.n_embed = n_embed
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
remap=remap,
sane_index_shape=sane_index_shape)
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.batch_resize_range = batch_resize_range
if self.batch_resize_range is not None:
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.scheduler_config = scheduler_config
self.lr_g_factor = lr_g_factor
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
quant, emb_loss, info = self.quantize(h)
return quant, emb_loss, info
def encode_to_prequant(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode(self, quant):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
def decode_code(self, code_b):
quant_b = self.quantize.embed_code(code_b)
dec = self.decode(quant_b)
return dec
def forward(self, input, return_pred_indices=False):
quant, diff, (_,_,ind) = self.encode(input)
dec = self.decode(quant)
if return_pred_indices:
return dec, diff, ind
return dec, diff
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
if self.batch_resize_range is not None:
lower_size = self.batch_resize_range[0]
upper_size = self.batch_resize_range[1]
if self.global_step <= 4:
# do the first few batches with max size to avoid later oom
new_resize = upper_size
else:
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
if new_resize != x.shape[2]:
x = F.interpolate(x, size=new_resize, mode="bicubic")
x = x.detach()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
# https://github.com/pytorch/pytorch/issues/37142
# try not to fool the heuristics
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train",
predicted_indices=ind)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return aeloss
if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, suffix=""):
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
self.global_step,
last_layer=self.get_last_layer(),
split="val"+suffix,
predicted_indices=ind
)
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
self.global_step,
last_layer=self.get_last_layer(),
split="val"+suffix,
predicted_indices=ind
)
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
self.log(f"val{suffix}/rec_loss", rec_loss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
self.log(f"val{suffix}/aeloss", aeloss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
if version.parse(pl.__version__) >= version.parse('1.4.0'):
del log_dict_ae[f"val{suffix}/rec_loss"]
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr_d = self.learning_rate
lr_g = self.lr_g_factor*self.learning_rate
print("lr_d", lr_d)
print("lr_g", lr_g)
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quantize.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr_g, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr_d, betas=(0.5, 0.9))
if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
},
{
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
},
]
return [opt_ae, opt_disc], scheduler
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if only_inputs:
log["inputs"] = x
return log
xrec, _ = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["inputs"] = x
log["reconstructions"] = xrec
if plot_ema:
with self.ema_scope():
xrec_ema, _ = self(x)
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
log["reconstructions_ema"] = xrec_ema
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
return x
class VQModelInterface(VQModel):
def __init__(self, embed_dim, *args, **kwargs):
super().__init__(embed_dim=embed_dim, *args, **kwargs)
self.embed_dim = embed_dim
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode(self, h, force_not_quantize=False):
# also go through quantization layer
if not force_not_quantize:
quant, emb_loss, info = self.quantize(h)
else:
quant = h
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
setattr(ldm.models.autoencoder, "VQModel", VQModel)
setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
# This script is copied from the compvis/stable-diffusion repo (aka the SD V1 repo)
# Original filename: ldm/models/diffusion/ddpm.py
# The purpose to reinstate the old DDPM logic which works with VQ, whereas the V2 one doesn't
# Some models such as LDSR require VQ to work correctly
# The classes are suffixed with "V1" and added back to the "ldm.models.diffusion.ddpm" module
import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR
from einops import rearrange, repeat
from contextlib import contextmanager
from functools import partial
from tqdm import tqdm
from torchvision.utils import make_grid
from pytorch_lightning.utilities.distributed import rank_zero_only
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from ldm.modules.ema import LitEma
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
from ldm.models.diffusion.ddim import DDIMSampler
import ldm.models.diffusion.ddpm
__conditioning_keys__ = {'concat': 'c_concat',
'crossattn': 'c_crossattn',
'adm': 'y'}
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
def uniform_on_device(r1, r2, shape, device):
return (r1 - r2) * torch.rand(*shape, device=device) + r2
class DDPMV1(pl.LightningModule):
# classic DDPM with Gaussian diffusion, in image space
def __init__(self,
unet_config,
timesteps=1000,
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
ignore_keys=[],
load_only_unet=False,
monitor="val/loss",
use_ema=True,
first_stage_key="image",
image_size=256,
channels=3,
log_every_t=100,
clip_denoised=True,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
given_betas=None,
original_elbo_weight=0.,
v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
l_simple_weight=1.,
conditioning_key=None,
parameterization="eps", # all assuming fixed variance schedules
scheduler_config=None,
use_positional_encodings=False,
learn_logvar=False,
logvar_init=0.,
):
super().__init__()
assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
self.parameterization = parameterization
print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
self.cond_stage_model = None
self.clip_denoised = clip_denoised
self.log_every_t = log_every_t
self.first_stage_key = first_stage_key
self.image_size = image_size # try conv?
self.channels = channels
self.use_positional_encodings = use_positional_encodings
self.model = DiffusionWrapperV1(unet_config, conditioning_key)
count_params(self.model, verbose=True)
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self.model)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
self.use_scheduler = scheduler_config is not None
if self.use_scheduler:
self.scheduler_config = scheduler_config
self.v_posterior = v_posterior
self.original_elbo_weight = original_elbo_weight
self.l_simple_weight = l_simple_weight
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
self.loss_type = loss_type
self.learn_logvar = learn_logvar
self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
if self.learn_logvar:
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if exists(given_betas):
betas = given_betas
else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer('betas', to_torch(betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
1. - alphas_cumprod) + self.v_posterior * betas
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer('posterior_variance', to_torch(posterior_variance))
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
self.register_buffer('posterior_mean_coef1', to_torch(
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
self.register_buffer('posterior_mean_coef2', to_torch(
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
if self.parameterization == "eps":
lvlb_weights = self.betas ** 2 / (
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
elif self.parameterization == "x0":
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
else:
raise NotImplementedError("mu not supported")
# TODO how to choose this term
lvlb_weights[0] = lvlb_weights[1]
self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
assert not torch.isnan(self.lvlb_weights).all()
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.model.parameters())
self.model_ema.copy_to(self.model)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.model.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
def q_mean_variance(self, x_start, t):
"""
Get the distribution q(x_t | x_0).
:param x_start: the [N x C x ...] tensor of noiseless inputs.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
"""
mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
def predict_start_from_noise(self, x_t, t, noise):
return (
extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, clip_denoised: bool):
model_out = self.model(x, t)
if self.parameterization == "eps":
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
elif self.parameterization == "x0":
x_recon = model_out
if clip_denoised:
x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, shape, return_intermediates=False):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
intermediates = [img]
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
clip_denoised=self.clip_denoised)
if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
intermediates.append(img)
if return_intermediates:
return img, intermediates
return img
@torch.no_grad()
def sample(self, batch_size=16, return_intermediates=False):
image_size = self.image_size
channels = self.channels
return self.p_sample_loop((batch_size, channels, image_size, image_size),
return_intermediates=return_intermediates)
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
def get_loss(self, pred, target, mean=True):
if self.loss_type == 'l1':
loss = (target - pred).abs()
if mean:
loss = loss.mean()
elif self.loss_type == 'l2':
if mean:
loss = torch.nn.functional.mse_loss(target, pred)
else:
loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
else:
raise NotImplementedError("unknown loss type '{loss_type}'")
return loss
def p_losses(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
model_out = self.model(x_noisy, t)
loss_dict = {}
if self.parameterization == "eps":
target = noise
elif self.parameterization == "x0":
target = x_start
else:
raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
log_prefix = 'train' if self.training else 'val'
loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
loss_simple = loss.mean() * self.l_simple_weight
loss_vlb = (self.lvlb_weights[t] * loss).mean()
loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
loss = loss_simple + self.original_elbo_weight * loss_vlb
loss_dict.update({f'{log_prefix}/loss': loss})
return loss, loss_dict
def forward(self, x, *args, **kwargs):
# b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
# assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
return self.p_losses(x, t, *args, **kwargs)
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = rearrange(x, 'b h w c -> b c h w')
x = x.to(memory_format=torch.contiguous_format).float()
return x
def shared_step(self, batch):
x = self.get_input(batch, self.first_stage_key)
loss, loss_dict = self(x)
return loss, loss_dict
def training_step(self, batch, batch_idx):
loss, loss_dict = self.shared_step(batch)
self.log_dict(loss_dict, prog_bar=True,
logger=True, on_step=True, on_epoch=True)
self.log("global_step", self.global_step,
prog_bar=True, logger=True, on_step=True, on_epoch=False)
if self.use_scheduler:
lr = self.optimizers().param_groups[0]['lr']
self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
return loss
@torch.no_grad()
def validation_step(self, batch, batch_idx):
_, loss_dict_no_ema = self.shared_step(batch)
with self.ema_scope():
_, loss_dict_ema = self.shared_step(batch)
loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self.model)
def _get_rows_from_list(self, samples):
n_imgs_per_row = len(samples)
denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
return denoise_grid
@torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
log = dict()
x = self.get_input(batch, self.first_stage_key)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
x = x.to(self.device)[:N]
log["inputs"] = x
# get diffusion row
diffusion_row = list()
x_start = x[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(x_start)
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
diffusion_row.append(x_noisy)
log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
if sample:
# get denoise row
with self.ema_scope("Plotting"):
samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
log["samples"] = samples
log["denoise_row"] = self._get_rows_from_list(denoise_row)
if return_keys:
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
return log
else:
return {key: log[key] for key in return_keys}
return log
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.model.parameters())
if self.learn_logvar:
params = params + [self.logvar]
opt = torch.optim.AdamW(params, lr=lr)
return opt
class LatentDiffusionV1(DDPMV1):
"""main class"""
def __init__(self,
first_stage_config,
cond_stage_config,
num_timesteps_cond=None,
cond_stage_key="image",
cond_stage_trainable=False,
concat_mode=True,
cond_stage_forward=None,
conditioning_key=None,
scale_factor=1.0,
scale_by_std=False,
*args, **kwargs):
self.num_timesteps_cond = default(num_timesteps_cond, 1)
self.scale_by_std = scale_by_std
assert self.num_timesteps_cond <= kwargs['timesteps']
# for backwards compatibility after implementation of DiffusionWrapper
if conditioning_key is None:
conditioning_key = 'concat' if concat_mode else 'crossattn'
if cond_stage_config == '__is_unconditional__':
conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", [])
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
except:
self.num_downs = 0
if not scale_by_std:
self.scale_factor = scale_factor
else:
self.register_buffer('scale_factor', torch.tensor(scale_factor))
self.instantiate_first_stage(first_stage_config)
self.instantiate_cond_stage(cond_stage_config)
self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False
self.bbox_tokenizer = None
self.restarted_from_ckpt = False
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys)
self.restarted_from_ckpt = True
def make_cond_schedule(self, ):
self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
self.cond_ids[:self.num_timesteps_cond] = ids
@rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
# only for very first batch
if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
# set rescale weight to 1./std of encodings
print("### USING STD-RESCALING ###")
x = super().get_input(batch, self.first_stage_key)
x = x.to(self.device)
encoder_posterior = self.encode_first_stage(x)
z = self.get_first_stage_encoding(encoder_posterior).detach()
del self.scale_factor
self.register_buffer('scale_factor', 1. / z.flatten().std())
print(f"setting self.scale_factor to {self.scale_factor}")
print("### USING STD-RESCALING ###")
def register_schedule(self,
given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
self.shorten_cond_schedule = self.num_timesteps_cond > 1
if self.shorten_cond_schedule:
self.make_cond_schedule()
def instantiate_first_stage(self, config):
model = instantiate_from_config(config)
self.first_stage_model = model.eval()
self.first_stage_model.train = disabled_train
for param in self.first_stage_model.parameters():
param.requires_grad = False
def instantiate_cond_stage(self, config):
if not self.cond_stage_trainable:
if config == "__is_first_stage__":
print("Using first stage also as cond stage.")
self.cond_stage_model = self.first_stage_model
elif config == "__is_unconditional__":
print(f"Training {self.__class__.__name__} as an unconditional model.")
self.cond_stage_model = None
# self.be_unconditional = True
else:
model = instantiate_from_config(config)
self.cond_stage_model = model.eval()
self.cond_stage_model.train = disabled_train
for param in self.cond_stage_model.parameters():
param.requires_grad = False
else:
assert config != '__is_first_stage__'
assert config != '__is_unconditional__'
model = instantiate_from_config(config)
self.cond_stage_model = model
def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
denoise_row = []
for zd in tqdm(samples, desc=desc):
denoise_row.append(self.decode_first_stage(zd.to(self.device),
force_not_quantize=force_no_decoder_quantization))
n_imgs_per_row = len(denoise_row)
denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
return denoise_grid
def get_first_stage_encoding(self, encoder_posterior):
if isinstance(encoder_posterior, DiagonalGaussianDistribution):
z = encoder_posterior.sample()
elif isinstance(encoder_posterior, torch.Tensor):
z = encoder_posterior
else:
raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
return self.scale_factor * z
def get_learned_conditioning(self, c):
if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
c = self.cond_stage_model.encode(c)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
else:
c = self.cond_stage_model(c)
else:
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
return c
def meshgrid(self, h, w):
y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
arr = torch.cat([y, x], dim=-1)
return arr
def delta_border(self, h, w):
"""
:param h: height
:param w: width
:return: normalized distance to image border,
wtith min distance = 0 at border and max dist = 0.5 at image center
"""
lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
arr = self.meshgrid(h, w) / lower_right_corner
dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
return edge_dist
def get_weighting(self, h, w, Ly, Lx, device):
weighting = self.delta_border(h, w)
weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
self.split_input_params["clip_max_weight"], )
weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
if self.split_input_params["tie_braker"]:
L_weighting = self.delta_border(Ly, Lx)
L_weighting = torch.clip(L_weighting,
self.split_input_params["clip_min_tie_weight"],
self.split_input_params["clip_max_tie_weight"])
L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
weighting = weighting * L_weighting
return weighting
def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
"""
:param x: img of size (bs, c, h, w)
:return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
"""
bs, nc, h, w = x.shape
# number of crops in image
Ly = (h - kernel_size[0]) // stride[0] + 1
Lx = (w - kernel_size[1]) // stride[1] + 1
if uf == 1 and df == 1:
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
unfold = torch.nn.Unfold(**fold_params)
fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
elif uf > 1 and df == 1:
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
unfold = torch.nn.Unfold(**fold_params)
fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
dilation=1, padding=0,
stride=(stride[0] * uf, stride[1] * uf))
fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
elif df > 1 and uf == 1:
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
unfold = torch.nn.Unfold(**fold_params)
fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
dilation=1, padding=0,
stride=(stride[0] // df, stride[1] // df))
fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
else:
raise NotImplementedError
return fold, unfold, normalization, weighting
@torch.no_grad()
def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
cond_key=None, return_original_cond=False, bs=None):
x = super().get_input(batch, k)
if bs is not None:
x = x[:bs]
x = x.to(self.device)
encoder_posterior = self.encode_first_stage(x)
z = self.get_first_stage_encoding(encoder_posterior).detach()
if self.model.conditioning_key is not None:
if cond_key is None:
cond_key = self.cond_stage_key
if cond_key != self.first_stage_key:
if cond_key in ['caption', 'coordinates_bbox']:
xc = batch[cond_key]
elif cond_key == 'class_label':
xc = batch
else:
xc = super().get_input(batch, cond_key).to(self.device)
else:
xc = x
if not self.cond_stage_trainable or force_c_encode:
if isinstance(xc, dict) or isinstance(xc, list):
# import pudb; pudb.set_trace()
c = self.get_learned_conditioning(xc)
else:
c = self.get_learned_conditioning(xc.to(self.device))
else:
c = xc
if bs is not None:
c = c[:bs]
if self.use_positional_encodings:
pos_x, pos_y = self.compute_latent_shifts(batch)
ckey = __conditioning_keys__[self.model.conditioning_key]
c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
else:
c = None
xc = None
if self.use_positional_encodings:
pos_x, pos_y = self.compute_latent_shifts(batch)
c = {'pos_x': pos_x, 'pos_y': pos_y}
out = [z, c]
if return_first_stage_outputs:
xrec = self.decode_first_stage(z)
out.extend([x, xrec])
if return_original_cond:
out.append(xc)
return out
@torch.no_grad()
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
if predict_cids:
if z.dim() == 4:
z = torch.argmax(z.exp(), dim=1).long()
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
z = rearrange(z, 'b h w c -> b c h w').contiguous()
z = 1. / self.scale_factor * z
if hasattr(self, "split_input_params"):
if self.split_input_params["patch_distributed_vq"]:
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
uf = self.split_input_params["vqf"]
bs, nc, h, w = z.shape
if ks[0] > h or ks[1] > w:
ks = (min(ks[0], h), min(ks[1], w))
print("reducing Kernel")
if stride[0] > h or stride[1] > w:
stride = (min(stride[0], h), min(stride[1], w))
print("reducing stride")
fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
z = unfold(z) # (bn, nc * prod(**ks), L)
# 1. Reshape to img shape
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
# 2. apply model loop over last dim
if isinstance(self.first_stage_model, VQModelInterface):
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
force_not_quantize=predict_cids or force_not_quantize)
for i in range(z.shape[-1])]
else:
output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
for i in range(z.shape[-1])]
o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
o = o * weighting
# Reverse 1. reshape to img shape
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
# stitch crops together
decoded = fold(o)
decoded = decoded / normalization # norm is shape (1, 1, h, w)
return decoded
else:
if isinstance(self.first_stage_model, VQModelInterface):
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
else:
return self.first_stage_model.decode(z)
else:
if isinstance(self.first_stage_model, VQModelInterface):
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
else:
return self.first_stage_model.decode(z)
# same as above but without decorator
def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
if predict_cids:
if z.dim() == 4:
z = torch.argmax(z.exp(), dim=1).long()
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
z = rearrange(z, 'b h w c -> b c h w').contiguous()
z = 1. / self.scale_factor * z
if hasattr(self, "split_input_params"):
if self.split_input_params["patch_distributed_vq"]:
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
uf = self.split_input_params["vqf"]
bs, nc, h, w = z.shape
if ks[0] > h or ks[1] > w:
ks = (min(ks[0], h), min(ks[1], w))
print("reducing Kernel")
if stride[0] > h or stride[1] > w:
stride = (min(stride[0], h), min(stride[1], w))
print("reducing stride")
fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
z = unfold(z) # (bn, nc * prod(**ks), L)
# 1. Reshape to img shape
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
# 2. apply model loop over last dim
if isinstance(self.first_stage_model, VQModelInterface):
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
force_not_quantize=predict_cids or force_not_quantize)
for i in range(z.shape[-1])]
else:
output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
for i in range(z.shape[-1])]
o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
o = o * weighting
# Reverse 1. reshape to img shape
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
# stitch crops together
decoded = fold(o)
decoded = decoded / normalization # norm is shape (1, 1, h, w)
return decoded
else:
if isinstance(self.first_stage_model, VQModelInterface):
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
else:
return self.first_stage_model.decode(z)
else:
if isinstance(self.first_stage_model, VQModelInterface):
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
else:
return self.first_stage_model.decode(z)
@torch.no_grad()
def encode_first_stage(self, x):
if hasattr(self, "split_input_params"):
if self.split_input_params["patch_distributed_vq"]:
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
df = self.split_input_params["vqf"]
self.split_input_params['original_image_size'] = x.shape[-2:]
bs, nc, h, w = x.shape
if ks[0] > h or ks[1] > w:
ks = (min(ks[0], h), min(ks[1], w))
print("reducing Kernel")
if stride[0] > h or stride[1] > w:
stride = (min(stride[0], h), min(stride[1], w))
print("reducing stride")
fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
z = unfold(x) # (bn, nc * prod(**ks), L)
# Reshape to img shape
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
for i in range(z.shape[-1])]
o = torch.stack(output_list, axis=-1)
o = o * weighting
# Reverse reshape to img shape
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
# stitch crops together
decoded = fold(o)
decoded = decoded / normalization
return decoded
else:
return self.first_stage_model.encode(x)
else:
return self.first_stage_model.encode(x)
def shared_step(self, batch, **kwargs):
x, c = self.get_input(batch, self.first_stage_key)
loss = self(x, c)
return loss
def forward(self, x, c, *args, **kwargs):
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
if self.model.conditioning_key is not None:
assert c is not None
if self.cond_stage_trainable:
c = self.get_learned_conditioning(c)
if self.shorten_cond_schedule: # TODO: drop this option
tc = self.cond_ids[t].to(self.device)
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs)
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
def rescale_bbox(bbox):
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
return x0, y0, w, h
return [rescale_bbox(b) for b in bboxes]
def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict):
# hybrid case, cond is exptected to be a dict
pass
else:
if not isinstance(cond, list):
cond = [cond]
key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
cond = {key: cond}
if hasattr(self, "split_input_params"):
assert len(cond) == 1 # todo can only deal with one conditioning atm
assert not return_ids
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
h, w = x_noisy.shape[-2:]
fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
# Reshape to img shape
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
if self.cond_stage_key in ["image", "LR_image", "segmentation",
'bbox_img'] and self.model.conditioning_key: # todo check for completeness
c_key = next(iter(cond.keys())) # get key
c = next(iter(cond.values())) # get value
assert (len(c) == 1) # todo extend to list with more than one elem
c = c[0] # get element
c = unfold(c)
c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
elif self.cond_stage_key == 'coordinates_bbox':
assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
# assuming padding of unfold is always 0 and its dilation is always 1
n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
full_img_h, full_img_w = self.split_input_params['original_image_size']
# as we are operating on latents, we need the factor from the original image size to the
# spatial latent size to properly rescale the crops for regenerating the bbox annotations
num_downs = self.first_stage_model.encoder.num_resolutions - 1
rescale_latent = 2 ** (num_downs)
# get top left postions of patches as conforming for the bbbox tokenizer, therefore we
# need to rescale the tl patch coordinates to be in between (0,1)
tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
for patch_nr in range(z.shape[-1])]
# patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
patch_limits = [(x_tl, y_tl,
rescale_latent * ks[0] / full_img_w,
rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
# patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
# tokenize crop coordinates for the bounding boxes of the respective patches
patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
print(patch_limits_tknzd[0].shape)
# cut tknzd crop position from conditioning
assert isinstance(cond, dict), 'cond must be dict to be fed into model'
cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
print(cut_cond.shape)
adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
print(adapted_cond.shape)
adapted_cond = self.get_learned_conditioning(adapted_cond)
print(adapted_cond.shape)
adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
print(adapted_cond.shape)
cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
else:
cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
# apply model by loop over crops
output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
assert not isinstance(output_list[0],
tuple) # todo cant deal with multiple model outputs check this never happens
o = torch.stack(output_list, axis=-1)
o = o * weighting
# Reverse reshape to img shape
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
# stitch crops together
x_recon = fold(o) / normalization
else:
x_recon = self.model(x_noisy, t, **cond)
if isinstance(x_recon, tuple) and not return_ids:
return x_recon[0]
else:
return x_recon
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
def _prior_bpd(self, x_start):
"""
Get the prior KL term for the variational lower-bound, measured in
bits-per-dim.
This term can't be optimized, as it only depends on the encoder.
:param x_start: the [N x C x ...] tensor of inputs.
:return: a batch of [N] KL values (in bits), one per batch element.
"""
batch_size = x_start.shape[0]
t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
return mean_flat(kl_prior) / np.log(2.0)
def p_losses(self, x_start, cond, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
model_output = self.apply_model(x_noisy, t, cond)
loss_dict = {}
prefix = 'train' if self.training else 'val'
if self.parameterization == "x0":
target = x_start
elif self.parameterization == "eps":
target = noise
else:
raise NotImplementedError()
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
logvar_t = self.logvar[t].to(self.device)
loss = loss_simple / torch.exp(logvar_t) + logvar_t
# loss = loss_simple / torch.exp(self.logvar) + self.logvar
if self.learn_logvar:
loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
loss_dict.update({'logvar': self.logvar.data.mean()})
loss = self.l_simple_weight * loss.mean()
loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
loss += (self.original_elbo_weight * loss_vlb)
loss_dict.update({f'{prefix}/loss': loss})
return loss, loss_dict
def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
return_x0=False, score_corrector=None, corrector_kwargs=None):
t_in = t
model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
if score_corrector is not None:
assert self.parameterization == "eps"
model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
if return_codebook_ids:
model_out, logits = model_out
if self.parameterization == "eps":
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
elif self.parameterization == "x0":
x_recon = model_out
else:
raise NotImplementedError()
if clip_denoised:
x_recon.clamp_(-1., 1.)
if quantize_denoised:
x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
if return_codebook_ids:
return model_mean, posterior_variance, posterior_log_variance, logits
elif return_x0:
return model_mean, posterior_variance, posterior_log_variance, x_recon
else:
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
return_codebook_ids=False, quantize_denoised=False, return_x0=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
b, *_, device = *x.shape, x.device
outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
return_codebook_ids=return_codebook_ids,
quantize_denoised=quantize_denoised,
return_x0=return_x0,
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
if return_codebook_ids:
raise DeprecationWarning("Support dropped.")
model_mean, _, model_log_variance, logits = outputs
elif return_x0:
model_mean, _, model_log_variance, x0 = outputs
else:
model_mean, _, model_log_variance = outputs
noise = noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
if return_codebook_ids:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
if return_x0:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
else:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
log_every_t=None):
if not log_every_t:
log_every_t = self.log_every_t
timesteps = self.num_timesteps
if batch_size is not None:
b = batch_size if batch_size is not None else shape[0]
shape = [batch_size] + list(shape)
else:
b = batch_size = shape[0]
if x_T is None:
img = torch.randn(shape, device=self.device)
else:
img = x_T
intermediates = []
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
if start_T is not None:
timesteps = min(timesteps, start_T)
iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
total=timesteps) if verbose else reversed(
range(0, timesteps))
if type(temperature) == float:
temperature = [temperature] * timesteps
for i in iterator:
ts = torch.full((b,), i, device=self.device, dtype=torch.long)
if self.shorten_cond_schedule:
assert self.model.conditioning_key != 'hybrid'
tc = self.cond_ids[ts].to(cond.device)
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
img, x0_partial = self.p_sample(img, cond, ts,
clip_denoised=self.clip_denoised,
quantize_denoised=quantize_denoised, return_x0=True,
temperature=temperature[i], noise_dropout=noise_dropout,
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
if mask is not None:
assert x0 is not None
img_orig = self.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(x0_partial)
if callback: callback(i)
if img_callback: img_callback(img, i)
return img, intermediates
@torch.no_grad()
def p_sample_loop(self, cond, shape, return_intermediates=False,
x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, start_T=None,
log_every_t=None):
if not log_every_t:
log_every_t = self.log_every_t
device = self.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
intermediates = [img]
if timesteps is None:
timesteps = self.num_timesteps
if start_T is not None:
timesteps = min(timesteps, start_T)
iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
range(0, timesteps))
if mask is not None:
assert x0 is not None
assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
for i in iterator:
ts = torch.full((b,), i, device=device, dtype=torch.long)
if self.shorten_cond_schedule:
assert self.model.conditioning_key != 'hybrid'
tc = self.cond_ids[ts].to(cond.device)
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
img = self.p_sample(img, cond, ts,
clip_denoised=self.clip_denoised,
quantize_denoised=quantize_denoised)
if mask is not None:
img_orig = self.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(img)
if callback: callback(i)
if img_callback: img_callback(img, i)
if return_intermediates:
return img, intermediates
return img
@torch.no_grad()
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
verbose=True, timesteps=None, quantize_denoised=False,
mask=None, x0=None, shape=None,**kwargs):
if shape is None:
shape = (batch_size, self.channels, self.image_size, self.image_size)
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
return self.p_sample_loop(cond,
shape,
return_intermediates=return_intermediates, x_T=x_T,
verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
mask=mask, x0=x0)
@torch.no_grad()
def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
if ddim:
ddim_sampler = DDIMSampler(self)
shape = (self.channels, self.image_size, self.image_size)
samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
shape,cond,verbose=False,**kwargs)
else:
samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
return_intermediates=True,**kwargs)
return samples, intermediates
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
plot_diffusion_rows=True, **kwargs):
use_ddim = ddim_steps is not None
log = dict()
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=True,
return_original_cond=True,
bs=N)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
log["inputs"] = x
log["reconstruction"] = xrec
if self.model.conditioning_key is not None:
if hasattr(self.cond_stage_model, "decode"):
xc = self.cond_stage_model.decode(c)
log["conditioning"] = xc
elif self.cond_stage_key in ["caption"]:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
log["conditioning"] = xc
elif self.cond_stage_key == 'class_label':
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
log['conditioning'] = xc
elif isimage(xc):
log["conditioning"] = xc
if ismap(xc):
log["original_conditioning"] = self.to_rgb(xc)
if plot_diffusion_rows:
# get diffusion row
diffusion_row = list()
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(z_start)
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
log["diffusion_row"] = diffusion_grid
if sample:
# get denoise row
with self.ema_scope("Plotting"):
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
ddim_steps=ddim_steps,eta=ddim_eta)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid
if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
self.first_stage_model, IdentityFirstStage):
# also display when quantizing x0 while sampling
with self.ema_scope("Plotting Quantized Denoised"):
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
ddim_steps=ddim_steps,eta=ddim_eta,
quantize_denoised=True)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
# quantize_denoised=True)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_x0_quantized"] = x_samples
if inpaint:
# make a simple center square
b, h, w = z.shape[0], z.shape[2], z.shape[3]
mask = torch.ones(N, h, w).to(self.device)
# zeros will be filled in
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
mask = mask[:, None, ...]
with self.ema_scope("Plotting Inpaint"):
samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_inpainting"] = x_samples
log["mask"] = mask
# outpaint
with self.ema_scope("Plotting Outpaint"):
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_outpainting"] = x_samples
if plot_progressive_rows:
with self.ema_scope("Plotting Progressives"):
img, progressives = self.progressive_denoising(c,
shape=(self.channels, self.image_size, self.image_size),
batch_size=N)
prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
log["progressive_row"] = prog_row
if return_keys:
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
return log
else:
return {key: log[key] for key in return_keys}
return log
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.model.parameters())
if self.cond_stage_trainable:
print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
params = params + list(self.cond_stage_model.parameters())
if self.learn_logvar:
print('Diffusion model optimizing logvar')
params.append(self.logvar)
opt = torch.optim.AdamW(params, lr=lr)
if self.use_scheduler:
assert 'target' in self.scheduler_config
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{
'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
}]
return [opt], scheduler
return opt
@torch.no_grad()
def to_rgb(self, x):
x = x.float()
if not hasattr(self, "colorize"):
self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
x = nn.functional.conv2d(x, weight=self.colorize)
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
return x
class DiffusionWrapperV1(pl.LightningModule):
def __init__(self, diff_model_config, conditioning_key):
super().__init__()
self.diffusion_model = instantiate_from_config(diff_model_config)
self.conditioning_key = conditioning_key
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
if self.conditioning_key is None:
out = self.diffusion_model(x, t)
elif self.conditioning_key == 'concat':
xc = torch.cat([x] + c_concat, dim=1)
out = self.diffusion_model(xc, t)
elif self.conditioning_key == 'crossattn':
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(x, t, context=cc)
elif self.conditioning_key == 'hybrid':
xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc)
elif self.conditioning_key == 'adm':
cc = c_crossattn[0]
out = self.diffusion_model(x, t, y=cc)
else:
raise NotImplementedError()
return out
class Layout2ImgDiffusionV1(LatentDiffusionV1):
# TODO: move all layout-specific hacks to this class
def __init__(self, cond_stage_key, *args, **kwargs):
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
def log_images(self, batch, N=8, *args, **kwargs):
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
key = 'train' if self.training else 'validation'
dset = self.trainer.datamodule.datasets[key]
mapper = dset.conditional_builders[self.cond_stage_key]
bbox_imgs = []
map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
for tknzd_bbox in batch[self.cond_stage_key][:N]:
bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
bbox_imgs.append(bboximg)
cond_img = torch.stack(bbox_imgs, dim=0)
logs['bbox_image'] = cond_img
return logs
setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1)
setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1)
setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1)
setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1)
import os
from modules import paths
def preload(parser):
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))
...@@ -9,7 +9,7 @@ from basicsr.utils.download_util import load_file_from_url ...@@ -9,7 +9,7 @@ from basicsr.utils.download_util import load_file_from_url
import modules.upscaler import modules.upscaler
from modules import devices, modelloader from modules import devices, modelloader
from modules.scunet_model_arch import SCUNet as net from scunet_model_arch import SCUNet as net
class UpscalerScuNET(modules.upscaler.Upscaler): class UpscalerScuNET(modules.upscaler.Upscaler):
...@@ -49,12 +49,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler): ...@@ -49,12 +49,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
if model is None: if model is None:
return img return img
device = devices.device_scunet device = devices.get_device_for('scunet')
img = np.array(img) img = np.array(img)
img = img[:, :, ::-1] img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255 img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
img = devices.mps_contiguous_to(img.unsqueeze(0), device) img = img.unsqueeze(0).to(device)
with torch.no_grad(): with torch.no_grad():
output = model(img) output = model(img)
...@@ -66,7 +66,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): ...@@ -66,7 +66,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
return PIL.Image.fromarray(output, 'RGB') return PIL.Image.fromarray(output, 'RGB')
def load_model(self, path: str): def load_model(self, path: str):
device = devices.device_scunet device = devices.get_device_for('scunet')
if "http" in path: if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
progress=True) progress=True)
......
import os
from modules import paths
def preload(parser):
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))
...@@ -7,15 +7,14 @@ from PIL import Image ...@@ -7,15 +7,14 @@ from PIL import Image
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm from tqdm import tqdm
from modules import modelloader, devices from modules import modelloader, devices, script_callbacks, shared
from modules.shared import cmd_opts, opts from modules.shared import cmd_opts, opts
from modules.swinir_model_arch import SwinIR as net from swinir_model_arch import SwinIR as net
from modules.swinir_model_arch_v2 import Swin2SR as net2 from swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
precision_scope = (
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext device_swinir = devices.get_device_for('swinir')
)
class UpscalerSwinIR(Upscaler): class UpscalerSwinIR(Upscaler):
...@@ -42,7 +41,7 @@ class UpscalerSwinIR(Upscaler): ...@@ -42,7 +41,7 @@ class UpscalerSwinIR(Upscaler):
model = self.load_model(model_file) model = self.load_model(model_file)
if model is None: if model is None:
return img return img
model = model.to(devices.device_swinir) model = model.to(device_swinir, dtype=devices.dtype)
img = upscale(img, model) img = upscale(img, model)
try: try:
torch.cuda.empty_cache() torch.cuda.empty_cache()
...@@ -94,25 +93,27 @@ class UpscalerSwinIR(Upscaler): ...@@ -94,25 +93,27 @@ class UpscalerSwinIR(Upscaler):
model.load_state_dict(pretrained_model[params], strict=True) model.load_state_dict(pretrained_model[params], strict=True)
else: else:
model.load_state_dict(pretrained_model, strict=True) model.load_state_dict(pretrained_model, strict=True)
if not cmd_opts.no_half:
model = model.half()
return model return model
def upscale( def upscale(
img, img,
model, model,
tile=opts.SWIN_tile, tile=None,
tile_overlap=opts.SWIN_tile_overlap, tile_overlap=None,
window_size=8, window_size=8,
scale=4, scale=4,
): ):
tile = tile or opts.SWIN_tile
tile_overlap = tile_overlap or opts.SWIN_tile_overlap
img = np.array(img) img = np.array(img)
img = img[:, :, ::-1] img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255 img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir) img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
with torch.no_grad(), precision_scope("cuda"): with torch.no_grad(), devices.autocast():
_, _, h_old, w_old = img.size() _, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old h_pad = (h_old // window_size + 1) * window_size - h_old
w_pad = (w_old // window_size + 1) * window_size - w_old w_pad = (w_old // window_size + 1) * window_size - w_old
...@@ -139,8 +140,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale): ...@@ -139,8 +140,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
stride = tile - tile_overlap stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile] h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile] w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=devices.device_swinir).type_as(img) E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
W = torch.zeros_like(E, dtype=torch.half, device=devices.device_swinir) W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list: for h_idx in h_idx_list:
...@@ -159,3 +160,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale): ...@@ -159,3 +160,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
output = E.div_(W) output = E.div_(W)
return output return output
def on_ui_settings():
import gradio as gr
shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
script_callbacks.on_ui_settings(on_ui_settings)
// Stable Diffusion WebUI - Bracket checker
// Version 1.0
// By Hingashi no Florin/Bwin4L
// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
function checkBrackets(evt) {
textArea = evt.target;
tabName = evt.target.parentElement.parentElement.id.split("_")[0];
counterElt = document.querySelector('gradio-app').shadowRoot.querySelector('#' + tabName + '_token_counter');
promptName = evt.target.parentElement.parentElement.id.includes('neg') ? ' negative' : '';
errorStringParen = '(' + tabName + promptName + ' prompt) - Different number of opening and closing parentheses detected.\n';
errorStringSquare = '[' + tabName + promptName + ' prompt] - Different number of opening and closing square brackets detected.\n';
errorStringCurly = '{' + tabName + promptName + ' prompt} - Different number of opening and closing curly brackets detected.\n';
openBracketRegExp = /\(/g;
closeBracketRegExp = /\)/g;
openSquareBracketRegExp = /\[/g;
closeSquareBracketRegExp = /\]/g;
openCurlyBracketRegExp = /\{/g;
closeCurlyBracketRegExp = /\}/g;
totalOpenBracketMatches = 0;
totalCloseBracketMatches = 0;
totalOpenSquareBracketMatches = 0;
totalCloseSquareBracketMatches = 0;
totalOpenCurlyBracketMatches = 0;
totalCloseCurlyBracketMatches = 0;
openBracketMatches = textArea.value.match(openBracketRegExp);
if(openBracketMatches) {
totalOpenBracketMatches = openBracketMatches.length;
}
closeBracketMatches = textArea.value.match(closeBracketRegExp);
if(closeBracketMatches) {
totalCloseBracketMatches = closeBracketMatches.length;
}
openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp);
if(openSquareBracketMatches) {
totalOpenSquareBracketMatches = openSquareBracketMatches.length;
}
closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp);
if(closeSquareBracketMatches) {
totalCloseSquareBracketMatches = closeSquareBracketMatches.length;
}
openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp);
if(openCurlyBracketMatches) {
totalOpenCurlyBracketMatches = openCurlyBracketMatches.length;
}
closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp);
if(closeCurlyBracketMatches) {
totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length;
}
if(totalOpenBracketMatches != totalCloseBracketMatches) {
if(!counterElt.title.includes(errorStringParen)) {
counterElt.title += errorStringParen;
}
} else {
counterElt.title = counterElt.title.replace(errorStringParen, '');
}
if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) {
if(!counterElt.title.includes(errorStringSquare)) {
counterElt.title += errorStringSquare;
}
} else {
counterElt.title = counterElt.title.replace(errorStringSquare, '');
}
if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) {
if(!counterElt.title.includes(errorStringCurly)) {
counterElt.title += errorStringCurly;
}
} else {
counterElt.title = counterElt.title.replace(errorStringCurly, '');
}
if(counterElt.title != '') {
counterElt.style = 'color: #FF5555;';
} else {
counterElt.style = '';
}
}
var shadowRootLoaded = setInterval(function() {
var shadowTextArea = document.querySelector('gradio-app').shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
if(shadowTextArea.length < 1) {
return false;
}
clearInterval(shadowRootLoaded);
document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_prompt').onkeyup = checkBrackets;
document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_neg_prompt').onkeyup = checkBrackets;
document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_prompt').onkeyup = checkBrackets;
document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_neg_prompt').onkeyup = checkBrackets;
}, 1000);
import random
from modules import script_callbacks, shared
import gradio as gr
art_symbol = '\U0001f3a8' # 🎨
global_prompt = None
related_ids = {"txt2img_prompt", "txt2img_clear_prompt", "img2img_prompt", "img2img_clear_prompt" }
def roll_artist(prompt):
allowed_cats = set([x for x in shared.artist_db.categories() if len(shared.opts.random_artist_categories)==0 or x in shared.opts.random_artist_categories])
artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
return prompt + ", " + artist.name if prompt != '' else artist.name
def add_roll_button(prompt):
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
roll.click(
fn=roll_artist,
_js="update_txt2img_tokens",
inputs=[
prompt,
],
outputs=[
prompt,
]
)
def after_component(component, **kwargs):
global global_prompt
elem_id = kwargs.get('elem_id', None)
if elem_id not in related_ids:
return
if elem_id == "txt2img_prompt":
global_prompt = component
elif elem_id == "txt2img_clear_prompt":
add_roll_button(global_prompt)
elif elem_id == "img2img_prompt":
global_prompt = component
elif elem_id == "img2img_clear_prompt":
add_roll_button(global_prompt)
script_callbacks.on_after_component(after_component)
<div>
<a href="/docs">API</a>
 • 
<a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a>
 • 
<a href="https://gradio.app">Gradio</a>
 • 
<a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false">Reload UI</a>
</div>
<style>
#licenses h2 {font-size: 1.2em; font-weight: bold; margin-bottom: 0.2em;}
#licenses small {font-size: 0.95em; opacity: 0.85;}
#licenses pre { margin: 1em 0 2em 0;}
</style>
<h2><a href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">CodeFormer</a></h2>
<small>Parts of CodeFormer code had to be copied to be compatible with GFPGAN.</small>
<pre>
S-Lab License 1.0
Copyright 2022 S-Lab
Redistribution and use for non-commercial purpose in source and
binary forms, with or without modification, are permitted provided
that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
In the event that redistribution and/or use for commercial purpose in
source or binary forms, with or without modification is required,
please contact the contributor(s) of the work.
</pre>
<h2><a href="https://github.com/victorca25/iNNfer/blob/main/LICENSE">ESRGAN</a></h2>
<small>Code for architecture and reading models copied.</small>
<pre>
MIT License
Copyright (c) 2021 victorca25
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>
<h2><a href="https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE">Real-ESRGAN</a></h2>
<small>Some code is copied to support ESRGAN models.</small>
<pre>
BSD 3-Clause License
Copyright (c) 2021, Xintao Wang
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
</pre>
<h2><a href="https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE">InvokeAI</a></h2>
<small>Some code for compatibility with OSX is taken from lstein's repository.</small>
<pre>
MIT License
Copyright (c) 2022 InvokeAI Team
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>
<h2><a href="https://github.com/Hafiidz/latent-diffusion/blob/main/LICENSE">LDSR</a></h2>
<small>Code added by contirubtors, most likely copied from this repository.</small>
<pre>
MIT License
Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>
<h2><a href="https://github.com/pharmapsychotic/clip-interrogator/blob/main/LICENSE">CLIP Interrogator</a></h2>
<small>Some small amounts of code borrowed and reworked.</small>
<pre>
MIT License
Copyright (c) 2022 pharmapsychotic
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>
<h2><a href="https://github.com/JingyunLiang/SwinIR/blob/main/LICENSE">SwinIR</a></h2>
<small>Code added by contirubtors, most likely copied from this repository.</small>
<pre>
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [2021] [SwinIR Authors]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
</pre>
...@@ -61,15 +61,15 @@ contextMenuInit = function(){ ...@@ -61,15 +61,15 @@ contextMenuInit = function(){
} }
function appendContextMenuOption(targetEmementSelector,entryName,entryFunction){ function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
currentItems = menuSpecs.get(targetEmementSelector) currentItems = menuSpecs.get(targetElementSelector)
if(!currentItems){ if(!currentItems){
currentItems = [] currentItems = []
menuSpecs.set(targetEmementSelector,currentItems); menuSpecs.set(targetElementSelector,currentItems);
} }
let newItem = {'id':targetEmementSelector+'_'+uid(), let newItem = {'id':targetElementSelector+'_'+uid(),
'name':entryName, 'name':entryName,
'func':entryFunction, 'func':entryFunction,
'isNew':true} 'isNew':true}
......
...@@ -9,11 +9,19 @@ function dropReplaceImage( imgWrap, files ) { ...@@ -9,11 +9,19 @@ function dropReplaceImage( imgWrap, files ) {
return; return;
} }
const tmpFile = files[0];
imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click(); imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
const callback = () => { const callback = () => {
const fileInput = imgWrap.querySelector('input[type="file"]'); const fileInput = imgWrap.querySelector('input[type="file"]');
if ( fileInput ) { if ( fileInput ) {
if ( files.length === 0 ) {
files = new DataTransfer();
files.items.add(tmpFile);
fileInput.files = files.files;
} else {
fileInput.files = files; fileInput.files = files;
}
fileInput.dispatchEvent(new Event('change')); fileInput.dispatchEvent(new Event('change'));
} }
}; };
......
// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
let txt2img_gallery, img2img_gallery, modal = undefined;
onUiUpdate(function(){
if (!txt2img_gallery) {
txt2img_gallery = attachGalleryListeners("txt2img")
}
if (!img2img_gallery) {
img2img_gallery = attachGalleryListeners("img2img")
}
if (!modal) {
modal = gradioApp().getElementById('lightboxModal')
modalObserver.observe(modal, { attributes : true, attributeFilter : ['style'] });
}
});
let modalObserver = new MutationObserver(function(mutations) {
mutations.forEach(function(mutationRecord) {
let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText
if (mutationRecord.target.style.display === 'none' && selectedTab === 'txt2img' || selectedTab === 'img2img')
gradioApp().getElementById(selectedTab+"_generation_info_button").click()
});
});
function attachGalleryListeners(tab_name) {
gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
gallery?.addEventListener('keydown', (e) => {
if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
gradioApp().getElementById(tab_name+"_generation_info_button").click()
});
return gallery;
}
...@@ -6,6 +6,7 @@ titles = { ...@@ -6,6 +6,7 @@ titles = {
"GFPGAN": "Restore low quality faces using GFPGAN neural network", "GFPGAN": "Restore low quality faces using GFPGAN neural network",
"Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps to higher than 30-40 does not help", "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps to higher than 30-40 does not help",
"DDIM": "Denoising Diffusion Implicit Models - best at inpainting", "DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
"DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
"Batch count": "How many batches of images to create", "Batch count": "How many batches of images to create",
"Batch size": "How many image to create in a single batch", "Batch size": "How many image to create in a single batch",
...@@ -17,6 +18,7 @@ titles = { ...@@ -17,6 +18,7 @@ titles = {
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.", "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
"\u{1f4c2}": "Open images output directory", "\u{1f4c2}": "Open images output directory",
"\u{1f4be}": "Save style", "\u{1f4be}": "Save style",
"\U0001F5D1": "Clear prompt",
"\u{1f4cb}": "Apply selected styles to current prompt", "\u{1f4cb}": "Apply selected styles to current prompt",
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
...@@ -62,8 +64,8 @@ titles = { ...@@ -62,8 +64,8 @@ titles = {
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.", "Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.", "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.", "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle", "Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
"Loopback": "Process an image, use it as an input, repeat.", "Loopback": "Process an image, use it as an input, repeat.",
...@@ -94,6 +96,11 @@ titles = { ...@@ -94,6 +96,11 @@ titles = {
"Add difference": "Result = A + (B - C) * M", "Add difference": "Result = A + (B - C) * M",
"Learning rate": "how fast should the training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.", "Learning rate": "how fast should the training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
"Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resoluton and lower quality.",
"Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality."
} }
......
...@@ -15,7 +15,7 @@ onUiUpdate(function(){ ...@@ -15,7 +15,7 @@ onUiUpdate(function(){
} }
} }
const galleryPreviews = gradioApp().querySelectorAll('img.h-full.w-full.overflow-hidden'); const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] img.h-full.w-full.overflow-hidden');
if (galleryPreviews == null) return; if (galleryPreviews == null) return;
......
...@@ -3,7 +3,7 @@ global_progressbars = {} ...@@ -3,7 +3,7 @@ global_progressbars = {}
galleries = {} galleries = {}
galleryObservers = {} galleryObservers = {}
// this tracks laumnches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running // this tracks launches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running
timeoutIds = {} timeoutIds = {}
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
...@@ -23,7 +23,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip ...@@ -23,7 +23,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){ if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
if(progressbar.innerText){ if(progressbar.innerText){
let newtitle = 'Stable Diffusion - ' + progressbar.innerText let newtitle = '[' + progressbar.innerText.trim() + '] Stable Diffusion';
if(document.title != newtitle){ if(document.title != newtitle){
document.title = newtitle; document.title = newtitle;
} }
...@@ -92,14 +92,26 @@ function check_gallery(id_gallery){ ...@@ -92,14 +92,26 @@ function check_gallery(id_gallery){
if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) { if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
// automatically re-open previously selected index (if exists) // automatically re-open previously selected index (if exists)
activeElement = gradioApp().activeElement; activeElement = gradioApp().activeElement;
let scrollX = window.scrollX;
let scrollY = window.scrollY;
galleryButtons[prevSelectedIndex].click(); galleryButtons[prevSelectedIndex].click();
showGalleryImage(); showGalleryImage();
// When the gallery button is clicked, it gains focus and scrolls itself into view
// We need to scroll back to the previous position
setTimeout(function (){
window.scrollTo(scrollX, scrollY);
}, 50);
if(activeElement){ if(activeElement){
// i fought this for about an hour; i don't know why the focus is lost or why this helps recover it // i fought this for about an hour; i don't know why the focus is lost or why this helps recover it
// if somenoe has a better solution please by all means // if someone has a better solution please by all means
setTimeout(function() { activeElement.focus() }, 1); setTimeout(function (){
activeElement.focus({
preventScroll: true // Refocus the element that was focused before the gallery was opened without scrolling to it
})
}, 1);
} }
} }
}) })
......
// various functions for interation with ui.py not large enough to warrant putting them in separate files // various functions for interaction with ui.py not large enough to warrant putting them in separate files
function set_theme(theme){ function set_theme(theme){
gradioURL = window.location.href gradioURL = window.location.href
...@@ -8,8 +8,8 @@ function set_theme(theme){ ...@@ -8,8 +8,8 @@ function set_theme(theme){
} }
function selected_gallery_index(){ function selected_gallery_index(){
var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem .gallery-item') var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery] .gallery-item')
var button = gradioApp().querySelector('[style="display: block;"].tabitem .gallery-item.\\!ring-2') var button = gradioApp().querySelector('[style="display: block;"].tabitem div[id$=_gallery] .gallery-item.\\!ring-2')
var result = -1 var result = -1
buttons.forEach(function(v, i){ if(v==button) { result = i } }) buttons.forEach(function(v, i){ if(v==button) { result = i } })
...@@ -19,7 +19,7 @@ function selected_gallery_index(){ ...@@ -19,7 +19,7 @@ function selected_gallery_index(){
function extract_image_from_gallery(gallery){ function extract_image_from_gallery(gallery){
if(gallery.length == 1){ if(gallery.length == 1){
return gallery[0] return [gallery[0]]
} }
index = selected_gallery_index() index = selected_gallery_index()
...@@ -28,7 +28,7 @@ function extract_image_from_gallery(gallery){ ...@@ -28,7 +28,7 @@ function extract_image_from_gallery(gallery){
return [null] return [null]
} }
return gallery[index]; return [gallery[index]];
} }
function args_to_array(args){ function args_to_array(args){
...@@ -100,7 +100,7 @@ function create_submit_args(args){ ...@@ -100,7 +100,7 @@ function create_submit_args(args){
// As it is currently, txt2img and img2img send back the previous output args (txt2img_gallery, generation_info, html_info) whenever you generate a new image. // As it is currently, txt2img and img2img send back the previous output args (txt2img_gallery, generation_info, html_info) whenever you generate a new image.
// This can lead to uploading a huge gallery of previously generated images, which leads to an unnecessary delay between submitting and beginning to generate. // This can lead to uploading a huge gallery of previously generated images, which leads to an unnecessary delay between submitting and beginning to generate.
// I don't know why gradio is seding outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some. // I don't know why gradio is sending outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some.
// If gradio at some point stops sending outputs, this may break something // If gradio at some point stops sending outputs, this may break something
if(Array.isArray(res[res.length - 3])){ if(Array.isArray(res[res.length - 3])){
res[res.length - 3] = null res[res.length - 3] = null
...@@ -131,6 +131,15 @@ function ask_for_style_name(_, prompt_text, negative_prompt_text) { ...@@ -131,6 +131,15 @@ function ask_for_style_name(_, prompt_text, negative_prompt_text) {
return [name_, prompt_text, negative_prompt_text] return [name_, prompt_text, negative_prompt_text]
} }
function confirm_clear_prompt(prompt, negative_prompt) {
if(confirm("Delete prompt?")) {
prompt = ""
negative_prompt = ""
}
return [prompt, negative_prompt]
}
opts = {} opts = {}
...@@ -179,6 +188,17 @@ onUiUpdate(function(){ ...@@ -179,6 +188,17 @@ onUiUpdate(function(){
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea"); img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button")); img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
} }
show_all_pages = gradioApp().getElementById('settings_show_all_pages')
settings_tabs = gradioApp().querySelector('#settings div')
if(show_all_pages && settings_tabs){
settings_tabs.appendChild(show_all_pages)
show_all_pages.onclick = function(){
gradioApp().querySelectorAll('#settings > div').forEach(function(elem){
elem.style.display = "block";
})
}
}
}) })
let txt2img_textarea, img2img_textarea = undefined; let txt2img_textarea, img2img_textarea = undefined;
......
...@@ -5,6 +5,8 @@ import sys ...@@ -5,6 +5,8 @@ import sys
import importlib.util import importlib.util
import shlex import shlex
import platform import platform
import argparse
import json
dir_repos = "repositories" dir_repos = "repositories"
dir_extensions = "extensions" dir_extensions = "extensions"
...@@ -17,6 +19,19 @@ def extract_arg(args, name): ...@@ -17,6 +19,19 @@ def extract_arg(args, name):
return [x for x in args if x != name], name in args return [x for x in args if x != name], name in args
def extract_opt(args, name):
opt = None
is_present = False
if name in args:
is_present = True
idx = args.index(name)
del args[idx]
if idx < len(args) and args[idx][0] != "-":
opt = args[idx]
del args[idx]
return args, is_present, opt
def run(command, desc=None, errdesc=None, custom_env=None): def run(command, desc=None, errdesc=None, custom_env=None):
if desc is not None: if desc is not None:
print(desc) print(desc)
...@@ -105,56 +120,78 @@ def version_check(commit): ...@@ -105,56 +120,78 @@ def version_check(commit):
print("version check failed", e) print("version check failed", e)
def run_extensions_installers(): def run_extension_installer(extension_dir):
if not os.path.isdir(dir_extensions): path_installer = os.path.join(extension_dir, "install.py")
return
for dirname_extension in os.listdir(dir_extensions):
path_installer = os.path.join(dir_extensions, dirname_extension, "install.py")
if not os.path.isfile(path_installer): if not os.path.isfile(path_installer):
continue return
try: try:
env = os.environ.copy() env = os.environ.copy()
env['PYTHONPATH'] = os.path.abspath(".") env['PYTHONPATH'] = os.path.abspath(".")
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {dirname_extension}", custom_env=env)) print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
except Exception as e:
print(e, file=sys.stderr)
def list_extensions(settings_file):
settings = {}
try:
if os.path.isfile(settings_file):
with open(settings_file, "r", encoding="utf8") as file:
settings = json.load(file)
except Exception as e: except Exception as e:
print(e, file=sys.stderr) print(e, file=sys.stderr)
disabled_extensions = set(settings.get('disabled_extensions', []))
def prepare_enviroment(): return [x for x in os.listdir(dir_extensions) if x not in disabled_extensions]
def run_extensions_installers(settings_file):
if not os.path.isdir(dir_extensions):
return
for dirname_extension in list_extensions(settings_file):
run_extension_installer(os.path.join(dir_extensions, dirname_extension))
def prepare_environment():
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "") commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379") gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1") clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
deepdanbooru_package = os.environ.get('DEEPDANBOORU_PACKAGE', "git+https://github.com/KichangKim/DeepDanbooru.git@d91a2963bf87c6a770d74894667e9ffa9f6de7ff") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl') xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl')
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/CompVis/stable-diffusion.git") stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
taming_transformers_repo = os.environ.get('TAMING_REANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git") taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
codeformer_repo = os.environ.get('CODEFORMET_REPO', 'https://github.com/sczhou/CodeFormer.git') codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc") stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "47b6b607fdd31875c9279cd2f4f16b92e4ea958e")
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6") taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "60e5042ca0da89c14d1dd59d73883280f8fce991") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
sys.argv += shlex.split(commandline_args) sys.argv += shlex.split(commandline_args)
test_argv = [x for x in sys.argv if x != '--tests']
parser = argparse.ArgumentParser()
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default='config.json')
args, _ = parser.parse_known_args(sys.argv)
sys.argv, _ = extract_arg(sys.argv, '-f')
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test') sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers') sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
sys.argv, update_check = extract_arg(sys.argv, '--update-check') sys.argv, update_check = extract_arg(sys.argv, '--update-check')
sys.argv, run_tests = extract_arg(sys.argv, '--tests') sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
xformers = '--xformers' in sys.argv xformers = '--xformers' in sys.argv
deepdanbooru = '--deepdanbooru' in sys.argv
ngrok = '--ngrok' in sys.argv ngrok = '--ngrok' in sys.argv
try: try:
...@@ -177,6 +214,9 @@ def prepare_enviroment(): ...@@ -177,6 +214,9 @@ def prepare_enviroment():
if not is_installed("clip"): if not is_installed("clip"):
run_pip(f"install {clip_package}", "clip") run_pip(f"install {clip_package}", "clip")
if not is_installed("open_clip"):
run_pip(f"install {openclip_package}", "open_clip")
if (not is_installed("xformers") or reinstall_xformers) and xformers: if (not is_installed("xformers") or reinstall_xformers) and xformers:
if platform.system() == "Windows": if platform.system() == "Windows":
if platform.python_version().startswith("3.10"): if platform.python_version().startswith("3.10"):
...@@ -189,15 +229,12 @@ def prepare_enviroment(): ...@@ -189,15 +229,12 @@ def prepare_enviroment():
elif platform.system() == "Linux": elif platform.system() == "Linux":
run_pip("install xformers", "xformers") run_pip("install xformers", "xformers")
if not is_installed("deepdanbooru") and deepdanbooru:
run_pip(f"install {deepdanbooru_package}#egg=deepdanbooru[tensorflow] tensorflow==2.10.0 tensorflow-io==0.27.0", "deepdanbooru")
if not is_installed("pyngrok") and ngrok: if not is_installed("pyngrok") and ngrok:
run_pip("install pyngrok", "ngrok") run_pip("install pyngrok", "ngrok")
os.makedirs(dir_repos, exist_ok=True) os.makedirs(dir_repos, exist_ok=True)
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash) git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash) git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
...@@ -208,7 +245,7 @@ def prepare_enviroment(): ...@@ -208,7 +245,7 @@ def prepare_enviroment():
run_pip(f"install -r {requirements_file}", "requirements for Web UI") run_pip(f"install -r {requirements_file}", "requirements for Web UI")
run_extensions_installers() run_extensions_installers(settings_file=args.ui_settings_file)
if update_check: if update_check:
version_check(commit) version_check(commit)
...@@ -218,24 +255,30 @@ def prepare_enviroment(): ...@@ -218,24 +255,30 @@ def prepare_enviroment():
exit(0) exit(0)
if run_tests: if run_tests:
tests(test_argv) exitcode = tests(test_dir)
exit(0) exit(exitcode)
def tests(argv): def tests(test_dir):
if "--api" not in argv: if "--api" not in sys.argv:
argv.append("--api") sys.argv.append("--api")
if "--ckpt" not in sys.argv:
sys.argv.append("--ckpt")
sys.argv.append("./test/test_files/empty.pt")
if "--skip-torch-cuda-test" not in sys.argv:
sys.argv.append("--skip-torch-cuda-test")
print(f"Launching Web UI in another process for testing with arguments: {' '.join(argv[1:])}") print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr: with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
proc = subprocess.Popen([sys.executable, *argv], stdout=stdout, stderr=stderr) proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)
import test.server_poll import test.server_poll
test.server_poll.run_tests() exitcode = test.server_poll.run_tests(proc, test_dir)
print(f"Stopping Web UI process with id {proc.pid}") print(f"Stopping Web UI process with id {proc.pid}")
proc.kill() proc.kill()
return exitcode
def start(): def start():
...@@ -248,5 +291,5 @@ def start(): ...@@ -248,5 +291,5 @@ def start():
if __name__ == "__main__": if __name__ == "__main__":
prepare_enviroment() prepare_environment()
start() start()
import base64 import base64
import io import io
import time import time
import datetime
import uvicorn import uvicorn
from threading import Lock from threading import Lock
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from io import BytesIO
from fastapi import APIRouter, Depends, FastAPI, HTTPException from gradio.processing_utils import decode_base64_to_file
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest
import modules.shared as shared import modules.shared as shared
from modules import sd_samplers, deepbooru, sd_hijack
from modules.api.models import * from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo from modules.extras import run_extras, run_pnginfo
from PIL import PngImagePlugin from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.sd_models import checkpoints_list from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
from modules.sd_models import checkpoints_list, find_checkpoint_config
from modules.realesrgan_model import get_realesrgan_models from modules.realesrgan_model import get_realesrgan_models
from modules import devices
from typing import List from typing import List
def upscaler_to_index(name: str): def upscaler_to_index(name: str):
...@@ -22,8 +31,12 @@ def upscaler_to_index(name: str): ...@@ -22,8 +31,12 @@ def upscaler_to_index(name: str):
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}") raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) def validate_sampler_name(name):
config = sd_samplers.all_samplers_map.get(name, None)
if config is None:
raise HTTPException(status_code=404, detail="Sampler not found")
return name
def setUpscalers(req: dict): def setUpscalers(req: dict):
reqDict = vars(req) reqDict = vars(req)
...@@ -33,6 +46,10 @@ def setUpscalers(req: dict): ...@@ -33,6 +46,10 @@ def setUpscalers(req: dict):
reqDict.pop('upscaler_2') reqDict.pop('upscaler_2')
return reqDict return reqDict
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
return Image.open(BytesIO(base64.b64decode(encoding)))
def encode_pil_to_base64(image): def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes: with io.BytesIO() as output_bytes:
...@@ -51,67 +68,104 @@ def encode_pil_to_base64(image): ...@@ -51,67 +68,104 @@ def encode_pil_to_base64(image):
bytes_data = output_bytes.getvalue() bytes_data = output_bytes.getvalue()
return base64.b64encode(bytes_data) return base64.b64encode(bytes_data)
def api_middleware(app: FastAPI):
@app.middleware("http")
async def log_and_time(req: Request, call_next):
ts = time.time()
res: Response = await call_next(req)
duration = str(round(time.time() - ts, 4))
res.headers["X-Process-Time"] = duration
endpoint = req.scope.get('path', 'err')
if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
code = res.status_code,
ver = req.scope.get('http_version', '0.0'),
cli = req.scope.get('client', ('0:0.0.0', 0))[0],
prot = req.scope.get('scheme', 'err'),
method = req.scope.get('method', 'err'),
endpoint = endpoint,
duration = duration,
))
return res
class Api: class Api:
def __init__(self, app: FastAPI, queue_lock: Lock): def __init__(self, app: FastAPI, queue_lock: Lock):
if shared.cmd_opts.api_auth:
self.credentials = dict()
for auth in shared.cmd_opts.api_auth.split(","):
user, password = auth.split(":")
self.credentials[user] = password
self.router = APIRouter() self.router = APIRouter()
self.app = app self.app = app
self.queue_lock = queue_lock self.queue_lock = queue_lock
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) api_middleware(self.app)
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel) self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel) self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem]) self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem]) self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem]) self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem]) self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem]) self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem]) self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
self.app.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem]) self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
self.app.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str]) self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
def add_api_route(self, path: str, endpoint, **kwargs):
if shared.cmd_opts.api_auth:
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
return self.app.add_api_route(path, endpoint, **kwargs)
def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
if credentials.username in self.credentials:
if compare_digest(credentials.password, self.credentials[credentials.username]):
return True
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found")
populate = txt2imgreq.copy(update={ # Override __init__ params populate = txt2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model, "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
"sampler_index": sampler_index[0],
"do_not_save_samples": True, "do_not_save_samples": True,
"do_not_save_grid": True "do_not_save_grid": True
} }
) )
p = StableDiffusionProcessingTxt2Img(**vars(populate)) if populate.sampler_name:
# Override object param populate.sampler_index = None # prevent a warning later on
shared.state.begin()
with self.queue_lock: with self.queue_lock:
processed = process_images(p) p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate))
shared.state.begin()
processed = process_images(p)
shared.state.end() shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images)) b64images = list(map(encode_pil_to_base64, processed.images))
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index)
if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found")
init_images = img2imgreq.init_images init_images = img2imgreq.init_images
if init_images is None: if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found") raise HTTPException(status_code=404, detail="Init image not found")
...@@ -120,34 +174,30 @@ class Api: ...@@ -120,34 +174,30 @@ class Api:
if mask: if mask:
mask = decode_base64_to_image(mask) mask = decode_base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model, "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
"sampler_index": sampler_index[0],
"do_not_save_samples": True, "do_not_save_samples": True,
"do_not_save_grid": True, "do_not_save_grid": True,
"mask": mask "mask": mask
} }
) )
p = StableDiffusionProcessingImg2Img(**vars(populate)) if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
imgs = [] args = vars(populate)
for img in init_images: args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
img = decode_base64_to_image(img)
imgs = [img] * p.batch_size
p.init_images = imgs with self.queue_lock:
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
p.init_images = [decode_base64_to_image(x) for x in init_images]
shared.state.begin() shared.state.begin()
with self.queue_lock:
processed = process_images(p) processed = process_images(p)
shared.state.end() shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images)) b64images = list(map(encode_pil_to_base64, processed.images))
if (not img2imgreq.include_init_images): if not img2imgreq.include_init_images:
img2imgreq.init_images = None img2imgreq.init_images = None
img2imgreq.mask = None img2imgreq.mask = None
...@@ -159,7 +209,7 @@ class Api: ...@@ -159,7 +209,7 @@ class Api:
reqDict['image'] = decode_base64_to_image(reqDict['image']) reqDict['image'] = decode_base64_to_image(reqDict['image'])
with self.queue_lock: with self.queue_lock:
result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict) result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1]) return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
...@@ -175,7 +225,7 @@ class Api: ...@@ -175,7 +225,7 @@ class Api:
reqDict.pop('imageList') reqDict.pop('imageList')
with self.queue_lock: with self.queue_lock:
result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict) result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
...@@ -220,11 +270,17 @@ class Api: ...@@ -220,11 +270,17 @@ class Api:
if image_b64 is None: if image_b64 is None:
raise HTTPException(status_code=404, detail="Image not found") raise HTTPException(status_code=404, detail="Image not found")
img = self.__base64_to_image(image_b64) img = decode_base64_to_image(image_b64)
img = img.convert('RGB')
# Override object param # Override object param
with self.queue_lock: with self.queue_lock:
if interrogatereq.model == "clip":
processed = shared.interrogator.interrogate(img) processed = shared.interrogator.interrogate(img)
elif interrogatereq.model == "deepdanbooru":
processed = deepbooru.model.tag(img)
else:
raise HTTPException(status_code=404, detail="Model not found")
return InterrogateResponse(caption=processed) return InterrogateResponse(caption=processed)
...@@ -233,6 +289,9 @@ class Api: ...@@ -233,6 +289,9 @@ class Api:
return {} return {}
def skip(self):
shared.state.skip()
def get_config(self): def get_config(self):
options = {} options = {}
for key in shared.opts.data.keys(): for key in shared.opts.data.keys():
...@@ -244,14 +303,9 @@ class Api: ...@@ -244,14 +303,9 @@ class Api:
return options return options
def set_config(self, req: OptionsModel): def set_config(self, req: Dict[str, Any]):
# currently req has all options fields even if you send a dict like { "send_seed": false }, which means it will for k, v in req.items():
# overwrite all options with default values. shared.opts.set(k, v)
raise RuntimeError('Setting options via API is not supported')
reqDict = vars(req)
for o in reqDict:
setattr(shared.opts, o, reqDict[o])
shared.opts.save(shared.config_filename) shared.opts.save(shared.config_filename)
return return
...@@ -260,7 +314,7 @@ class Api: ...@@ -260,7 +314,7 @@ class Api:
return vars(shared.cmd_opts) return vars(shared.cmd_opts)
def get_samplers(self): def get_samplers(self):
return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers] return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
def get_upscalers(self): def get_upscalers(self):
upscalers = [] upscalers = []
...@@ -272,7 +326,7 @@ class Api: ...@@ -272,7 +326,7 @@ class Api:
return upscalers return upscalers
def get_sd_models(self): def get_sd_models(self):
return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()] return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
def get_hypernetworks(self): def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
...@@ -283,11 +337,11 @@ class Api: ...@@ -283,11 +337,11 @@ class Api:
def get_realesrgan_models(self): def get_realesrgan_models(self):
return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)] return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
def get_promp_styles(self): def get_prompt_styles(self):
styleList = [] styleList = []
for k in shared.prompt_styles.styles: for k in shared.prompt_styles.styles:
style = shared.prompt_styles.styles[k] style = shared.prompt_styles.styles[k]
styleList.append({"name":style[0], "prompt": style[1], "negative_prompr": style[2]}) styleList.append({"name":style[0], "prompt": style[1], "negative_prompt": style[2]})
return styleList return styleList
...@@ -297,6 +351,112 @@ class Api: ...@@ -297,6 +351,112 @@ class Api:
def get_artists(self): def get_artists(self):
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists] return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
def get_embeddings(self):
db = sd_hijack.model_hijack.embedding_db
def convert_embedding(embedding):
return {
"step": embedding.step,
"sd_checkpoint": embedding.sd_checkpoint,
"sd_checkpoint_name": embedding.sd_checkpoint_name,
"shape": embedding.shape,
"vectors": embedding.vectors,
}
def convert_embeddings(embeddings):
return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
return {
"loaded": convert_embeddings(db.word_embeddings),
"skipped": convert_embeddings(db.skipped_embeddings),
}
def refresh_checkpoints(self):
shared.refresh_checkpoints()
def create_embedding(self, args: dict):
try:
shared.state.begin()
filename = create_embedding(**args) # create empty embedding
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
shared.state.end()
return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
except AssertionError as e:
shared.state.end()
return TrainResponse(info = "create embedding error: {error}".format(error = e))
def create_hypernetwork(self, args: dict):
try:
shared.state.begin()
filename = create_hypernetwork(**args) # create empty embedding
shared.state.end()
return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
except AssertionError as e:
shared.state.end()
return TrainResponse(info = "create hypernetwork error: {error}".format(error = e))
def preprocess(self, args: dict):
try:
shared.state.begin()
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
shared.state.end()
return PreprocessResponse(info = 'preprocess complete')
except KeyError as e:
shared.state.end()
return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e))
except AssertionError as e:
shared.state.end()
return PreprocessResponse(info = "preprocess error: {error}".format(error = e))
except FileNotFoundError as e:
shared.state.end()
return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e))
def train_embedding(self, args: dict):
try:
shared.state.begin()
apply_optimizations = shared.opts.training_xattention_optimizations
error = None
filename = ''
if not apply_optimizations:
sd_hijack.undo_optimizations()
try:
embedding, filename = train_embedding(**args) # can take a long time to complete
except Exception as e:
error = e
finally:
if not apply_optimizations:
sd_hijack.apply_optimizations()
shared.state.end()
return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
except AssertionError as msg:
shared.state.end()
return TrainResponse(info = "train embedding error: {msg}".format(msg = msg))
def train_hypernetwork(self, args: dict):
try:
shared.state.begin()
initial_hypernetwork = shared.loaded_hypernetwork
apply_optimizations = shared.opts.training_xattention_optimizations
error = None
filename = ''
if not apply_optimizations:
sd_hijack.undo_optimizations()
try:
hypernetwork, filename = train_hypernetwork(*args)
except Exception as e:
error = e
finally:
shared.loaded_hypernetwork = initial_hypernetwork
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
if not apply_optimizations:
sd_hijack.apply_optimizations()
shared.state.end()
return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
except AssertionError as msg:
shared.state.end()
return TrainResponse(info = "train embedding error: {error}".format(error = error))
def launch(self, server_name, port): def launch(self, server_name, port):
self.app.include_router(self.router) self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port) uvicorn.run(self.app, host=server_name, port=port)
...@@ -128,7 +128,7 @@ class ExtrasBaseRequest(BaseModel): ...@@ -128,7 +128,7 @@ class ExtrasBaseRequest(BaseModel):
upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.") upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.") upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.") upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the choosen size?") upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}") upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}") upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.") extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
...@@ -170,14 +170,24 @@ class ProgressResponse(BaseModel): ...@@ -170,14 +170,24 @@ class ProgressResponse(BaseModel):
class InterrogateRequest(BaseModel): class InterrogateRequest(BaseModel):
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
model: str = Field(default="clip", title="Model", description="The interrogate model used.")
class InterrogateResponse(BaseModel): class InterrogateResponse(BaseModel):
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.") caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
class TrainResponse(BaseModel):
info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.")
class CreateResponse(BaseModel):
info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
class PreprocessResponse(BaseModel):
info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
fields = {} fields = {}
for key, value in opts.data.items(): for key, metadata in opts.data_labels.items():
metadata = opts.data_labels.get(key) value = opts.data.get(key)
optType = opts.typemap.get(type(value), type(value)) optType = opts.typemap.get(type(metadata.default), type(value))
if (metadata is not None): if (metadata is not None):
fields.update({key: (Optional[optType], Field( fields.update({key: (Optional[optType], Field(
...@@ -239,3 +249,13 @@ class ArtistItem(BaseModel): ...@@ -239,3 +249,13 @@ class ArtistItem(BaseModel):
score: float = Field(title="Score") score: float = Field(title="Score")
category: str = Field(title="Category") category: str = Field(title="Category")
class EmbeddingItem(BaseModel):
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
class EmbeddingsResponse(BaseModel):
loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
\ No newline at end of file
import html
import sys
import threading
import traceback
import time
from modules import shared
queue_lock = threading.Lock()
def wrap_queued_call(func):
def f(*args, **kwargs):
with queue_lock:
res = func(*args, **kwargs)
return res
return f
def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs):
shared.state.begin()
with queue_lock:
res = func(*args, **kwargs)
shared.state.end()
return res
return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
if run_memmon:
shared.mem_mon.monitor()
t = time.perf_counter()
try:
res = list(func(*args, **kwargs))
except Exception as e:
# When printing out our debug argument list, do not print out more than a MB of text
max_debug_str_len = 131072 # (1024*1024)/8
print("Error completing request", file=sys.stderr)
argStr = f"Arguments: {str(args)} {str(kwargs)}"
print(argStr[:max_debug_str_len], file=sys.stderr)
if len(argStr) > max_debug_str_len:
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
shared.state.job = ""
shared.state.job_count = 0
if extra_outputs_array is None:
extra_outputs_array = [None, '']
res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
shared.state.skipped = False
shared.state.interrupted = False
shared.state.job_count = 0
if not add_stats:
return tuple(res)
elapsed = time.perf_counter() - t
elapsed_m = int(elapsed // 60)
elapsed_s = elapsed % 60
elapsed_text = f"{elapsed_s:.2f}s"
if elapsed_m > 0:
elapsed_text = f"{elapsed_m}m "+elapsed_text
if run_memmon:
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
active_peak = mem_stats['active_peak']
reserved_peak = mem_stats['reserved_peak']
sys_peak = mem_stats['system_peak']
sys_total = mem_stats['total']
sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
else:
vram_html = ''
# last item is always HTML
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
return tuple(res)
return f
...@@ -382,7 +382,7 @@ class VQAutoEncoder(nn.Module): ...@@ -382,7 +382,7 @@ class VQAutoEncoder(nn.Module):
self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
logger.info(f'vqgan is loaded from: {model_path} [params]') logger.info(f'vqgan is loaded from: {model_path} [params]')
else: else:
raise ValueError(f'Wrong params!') raise ValueError('Wrong params!')
def forward(self, x): def forward(self, x):
...@@ -431,7 +431,7 @@ class VQGANDiscriminator(nn.Module): ...@@ -431,7 +431,7 @@ class VQGANDiscriminator(nn.Module):
elif 'params' in chkpt: elif 'params' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
else: else:
raise ValueError(f'Wrong params!') raise ValueError('Wrong params!')
def forward(self, x): def forward(self, x):
return self.main(x) return self.main(x)
\ No newline at end of file
...@@ -36,6 +36,7 @@ def setup_model(dirname): ...@@ -36,6 +36,7 @@ def setup_model(dirname):
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
from basicsr.utils import imwrite, img2tensor, tensor2img from basicsr.utils import imwrite, img2tensor, tensor2img
from facelib.utils.face_restoration_helper import FaceRestoreHelper from facelib.utils.face_restoration_helper import FaceRestoreHelper
from facelib.detection.retinaface import retinaface
from modules.shared import cmd_opts from modules.shared import cmd_opts
net_class = CodeFormer net_class = CodeFormer
...@@ -65,6 +66,8 @@ def setup_model(dirname): ...@@ -65,6 +66,8 @@ def setup_model(dirname):
net.load_state_dict(checkpoint) net.load_state_dict(checkpoint)
net.eval() net.eval()
if hasattr(retinaface, 'device'):
retinaface.device = devices.device_codeformer
face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer) face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
self.net = net self.net = net
......
import os.path import os
from concurrent.futures import ProcessPoolExecutor
import multiprocessing
import time
import re import re
import torch
from PIL import Image
import numpy as np
from modules import modelloader, paths, deepbooru_model, devices, images, shared
re_special = re.compile(r'([\\()])') re_special = re.compile(r'([\\()])')
def get_deepbooru_tags(pil_image):
""" class DeepDanbooru:
This method is for running only one image at a time for simple use. Used to the img2img interrogate. def __init__(self):
""" self.model = None
from modules import shared # prevents circular reference
def load(self):
try: if self.model is not None:
create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts()) return
return get_tags_from_process(pil_image)
finally: files = modelloader.load_models(
release_process() model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
ext_filter=[".pt"],
OPT_INCLUDE_RANKS = "include_ranks" download_name='model-resnet_custom_v3.pt',
def create_deepbooru_opts():
from modules import shared
return {
"use_spaces": shared.opts.deepbooru_use_spaces,
"use_escape": shared.opts.deepbooru_escape,
"alpha_sort": shared.opts.deepbooru_sort_alpha,
OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks,
}
def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts):
model, tags = get_deepbooru_tags_model()
while True: # while process is running, keep monitoring queue for new image
pil_image = queue.get()
if pil_image == "QUIT":
break
else:
deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts)
def create_deepbooru_process(threshold, deepbooru_opts):
"""
Creates deepbooru process. A queue is created to send images into the process. This enables multiple images
to be processed in a row without reloading the model or creating a new process. To return the data, a shared
dictionary is created to hold the tags created. To wait for tags to be returned, a value of -1 is assigned
to the dictionary and the method adding the image to the queue should wait for this value to be updated with
the tags.
"""
from modules import shared # prevents circular reference
context = multiprocessing.get_context("spawn")
shared.deepbooru_process_manager = context.Manager()
shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
shared.deepbooru_process_return["value"] = -1
shared.deepbooru_process = context.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
shared.deepbooru_process.start()
def get_tags_from_process(image):
from modules import shared
shared.deepbooru_process_return["value"] = -1
shared.deepbooru_process_queue.put(image)
while shared.deepbooru_process_return["value"] == -1:
time.sleep(0.2)
caption = shared.deepbooru_process_return["value"]
shared.deepbooru_process_return["value"] = -1
return caption
def release_process():
"""
Stops the deepbooru process to return used memory
"""
from modules import shared # prevents circular reference
shared.deepbooru_process_queue.put("QUIT")
shared.deepbooru_process.join()
shared.deepbooru_process_queue = None
shared.deepbooru_process = None
shared.deepbooru_process_return = None
shared.deepbooru_process_manager = None
def get_deepbooru_tags_model():
import deepdanbooru as dd
import tensorflow as tf
import numpy as np
this_folder = os.path.dirname(__file__)
model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
if not os.path.exists(os.path.join(model_path, 'project.json')):
# there is no point importing these every time
import zipfile
from basicsr.utils.download_util import load_file_from_url
load_file_from_url(
r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
model_path)
with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
zip_ref.extractall(model_path)
os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))
tags = dd.project.load_tags_from_project(model_path)
model = dd.project.load_model_from_project(
model_path, compile_model=False
)
return model, tags
def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts):
import deepdanbooru as dd
import tensorflow as tf
import numpy as np
alpha_sort = deepbooru_opts['alpha_sort']
use_spaces = deepbooru_opts['use_spaces']
use_escape = deepbooru_opts['use_escape']
include_ranks = deepbooru_opts['include_ranks']
width = model.input_shape[2]
height = model.input_shape[1]
image = np.array(pil_image)
image = tf.image.resize(
image,
size=(height, width),
method=tf.image.ResizeMethod.AREA,
preserve_aspect_ratio=True,
) )
image = image.numpy() # EagerTensor to np.array
image = dd.image.transform_and_pad_image(image, width, height)
image = image / 255.0
image_shape = image.shape
image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))
y = model.predict(image)[0] self.model = deepbooru_model.DeepDanbooruModel()
self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
self.model.eval()
self.model.to(devices.cpu, devices.dtype)
def start(self):
self.load()
self.model.to(devices.device)
def stop(self):
if not shared.opts.interrogate_keep_models_in_memory:
self.model.to(devices.cpu)
devices.torch_gc()
def tag(self, pil_image):
self.start()
res = self.tag_multi(pil_image)
self.stop()
result_dict = {} return res
for i, tag in enumerate(tags): def tag_multi(self, pil_image, force_disable_ranks=False):
result_dict[tag] = y[i] threshold = shared.opts.interrogate_deepbooru_score_threshold
use_spaces = shared.opts.deepbooru_use_spaces
use_escape = shared.opts.deepbooru_escape
alpha_sort = shared.opts.deepbooru_sort_alpha
include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks
pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
with torch.no_grad(), devices.autocast():
x = torch.from_numpy(a).to(devices.device)
y = self.model(x)[0].detach().cpu().numpy()
probability_dict = {}
for tag, probability in zip(self.model.tags, y):
if probability < threshold:
continue
unsorted_tags_in_theshold = []
result_tags_print = []
for tag in tags:
if result_dict[tag] >= threshold:
if tag.startswith("rating:"): if tag.startswith("rating:"):
continue continue
unsorted_tags_in_theshold.append((result_dict[tag], tag))
result_tags_print.append(f'{result_dict[tag]} {tag}')
# sort tags probability_dict[tag] = probability
result_tags_out = []
sort_ndx = 0
if alpha_sort: if alpha_sort:
sort_ndx = 1 tags = sorted(probability_dict)
else:
tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
res = []
# sort by reverse by likelihood and normal for alpha, and format tag text as requested filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
for weight, tag in unsorted_tags_in_theshold: for tag in [x for x in tags if x not in filtertags]:
probability = probability_dict[tag]
tag_outformat = tag tag_outformat = tag
if use_spaces: if use_spaces:
tag_outformat = tag_outformat.replace('_', ' ') tag_outformat = tag_outformat.replace('_', ' ')
if use_escape: if use_escape:
tag_outformat = re.sub(re_special, r'\\\1', tag_outformat) tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
if include_ranks: if include_ranks:
tag_outformat = f"({tag_outformat}:{weight:.3f})" tag_outformat = f"({tag_outformat}:{probability:.3f})"
res.append(tag_outformat)
result_tags_out.append(tag_outformat) return ", ".join(res)
print('\n'.join(sorted(result_tags_print, reverse=True)))
return ', '.join(result_tags_out) model = DeepDanbooru()
import torch
import torch.nn as nn
import torch.nn.functional as F
# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
class DeepDanbooruModel(nn.Module):
def __init__(self):
super(DeepDanbooruModel, self).__init__()
self.tags = []
self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
def forward(self, *inputs):
t_358, = inputs
t_359 = t_358.permute(*[0, 3, 1, 2])
t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
t_360 = self.n_Conv_0(t_359_padded)
t_361 = F.relu(t_360)
t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
t_362 = self.n_MaxPool_0(t_361)
t_363 = self.n_Conv_1(t_362)
t_364 = self.n_Conv_2(t_362)
t_365 = F.relu(t_364)
t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
t_366 = self.n_Conv_3(t_365_padded)
t_367 = F.relu(t_366)
t_368 = self.n_Conv_4(t_367)
t_369 = torch.add(t_368, t_363)
t_370 = F.relu(t_369)
t_371 = self.n_Conv_5(t_370)
t_372 = F.relu(t_371)
t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
t_373 = self.n_Conv_6(t_372_padded)
t_374 = F.relu(t_373)
t_375 = self.n_Conv_7(t_374)
t_376 = torch.add(t_375, t_370)
t_377 = F.relu(t_376)
t_378 = self.n_Conv_8(t_377)
t_379 = F.relu(t_378)
t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
t_380 = self.n_Conv_9(t_379_padded)
t_381 = F.relu(t_380)
t_382 = self.n_Conv_10(t_381)
t_383 = torch.add(t_382, t_377)
t_384 = F.relu(t_383)
t_385 = self.n_Conv_11(t_384)
t_386 = self.n_Conv_12(t_384)
t_387 = F.relu(t_386)
t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
t_388 = self.n_Conv_13(t_387_padded)
t_389 = F.relu(t_388)
t_390 = self.n_Conv_14(t_389)
t_391 = torch.add(t_390, t_385)
t_392 = F.relu(t_391)
t_393 = self.n_Conv_15(t_392)
t_394 = F.relu(t_393)
t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
t_395 = self.n_Conv_16(t_394_padded)
t_396 = F.relu(t_395)
t_397 = self.n_Conv_17(t_396)
t_398 = torch.add(t_397, t_392)
t_399 = F.relu(t_398)
t_400 = self.n_Conv_18(t_399)
t_401 = F.relu(t_400)
t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
t_402 = self.n_Conv_19(t_401_padded)
t_403 = F.relu(t_402)
t_404 = self.n_Conv_20(t_403)
t_405 = torch.add(t_404, t_399)
t_406 = F.relu(t_405)
t_407 = self.n_Conv_21(t_406)
t_408 = F.relu(t_407)
t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
t_409 = self.n_Conv_22(t_408_padded)
t_410 = F.relu(t_409)
t_411 = self.n_Conv_23(t_410)
t_412 = torch.add(t_411, t_406)
t_413 = F.relu(t_412)
t_414 = self.n_Conv_24(t_413)
t_415 = F.relu(t_414)
t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
t_416 = self.n_Conv_25(t_415_padded)
t_417 = F.relu(t_416)
t_418 = self.n_Conv_26(t_417)
t_419 = torch.add(t_418, t_413)
t_420 = F.relu(t_419)
t_421 = self.n_Conv_27(t_420)
t_422 = F.relu(t_421)
t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
t_423 = self.n_Conv_28(t_422_padded)
t_424 = F.relu(t_423)
t_425 = self.n_Conv_29(t_424)
t_426 = torch.add(t_425, t_420)
t_427 = F.relu(t_426)
t_428 = self.n_Conv_30(t_427)
t_429 = F.relu(t_428)
t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
t_430 = self.n_Conv_31(t_429_padded)
t_431 = F.relu(t_430)
t_432 = self.n_Conv_32(t_431)
t_433 = torch.add(t_432, t_427)
t_434 = F.relu(t_433)
t_435 = self.n_Conv_33(t_434)
t_436 = F.relu(t_435)
t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
t_437 = self.n_Conv_34(t_436_padded)
t_438 = F.relu(t_437)
t_439 = self.n_Conv_35(t_438)
t_440 = torch.add(t_439, t_434)
t_441 = F.relu(t_440)
t_442 = self.n_Conv_36(t_441)
t_443 = self.n_Conv_37(t_441)
t_444 = F.relu(t_443)
t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
t_445 = self.n_Conv_38(t_444_padded)
t_446 = F.relu(t_445)
t_447 = self.n_Conv_39(t_446)
t_448 = torch.add(t_447, t_442)
t_449 = F.relu(t_448)
t_450 = self.n_Conv_40(t_449)
t_451 = F.relu(t_450)
t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
t_452 = self.n_Conv_41(t_451_padded)
t_453 = F.relu(t_452)
t_454 = self.n_Conv_42(t_453)
t_455 = torch.add(t_454, t_449)
t_456 = F.relu(t_455)
t_457 = self.n_Conv_43(t_456)
t_458 = F.relu(t_457)
t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
t_459 = self.n_Conv_44(t_458_padded)
t_460 = F.relu(t_459)
t_461 = self.n_Conv_45(t_460)
t_462 = torch.add(t_461, t_456)
t_463 = F.relu(t_462)
t_464 = self.n_Conv_46(t_463)
t_465 = F.relu(t_464)
t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
t_466 = self.n_Conv_47(t_465_padded)
t_467 = F.relu(t_466)
t_468 = self.n_Conv_48(t_467)
t_469 = torch.add(t_468, t_463)
t_470 = F.relu(t_469)
t_471 = self.n_Conv_49(t_470)
t_472 = F.relu(t_471)
t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
t_473 = self.n_Conv_50(t_472_padded)
t_474 = F.relu(t_473)
t_475 = self.n_Conv_51(t_474)
t_476 = torch.add(t_475, t_470)
t_477 = F.relu(t_476)
t_478 = self.n_Conv_52(t_477)
t_479 = F.relu(t_478)
t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
t_480 = self.n_Conv_53(t_479_padded)
t_481 = F.relu(t_480)
t_482 = self.n_Conv_54(t_481)
t_483 = torch.add(t_482, t_477)
t_484 = F.relu(t_483)
t_485 = self.n_Conv_55(t_484)
t_486 = F.relu(t_485)
t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
t_487 = self.n_Conv_56(t_486_padded)
t_488 = F.relu(t_487)
t_489 = self.n_Conv_57(t_488)
t_490 = torch.add(t_489, t_484)
t_491 = F.relu(t_490)
t_492 = self.n_Conv_58(t_491)
t_493 = F.relu(t_492)
t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
t_494 = self.n_Conv_59(t_493_padded)
t_495 = F.relu(t_494)
t_496 = self.n_Conv_60(t_495)
t_497 = torch.add(t_496, t_491)
t_498 = F.relu(t_497)
t_499 = self.n_Conv_61(t_498)
t_500 = F.relu(t_499)
t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
t_501 = self.n_Conv_62(t_500_padded)
t_502 = F.relu(t_501)
t_503 = self.n_Conv_63(t_502)
t_504 = torch.add(t_503, t_498)
t_505 = F.relu(t_504)
t_506 = self.n_Conv_64(t_505)
t_507 = F.relu(t_506)
t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
t_508 = self.n_Conv_65(t_507_padded)
t_509 = F.relu(t_508)
t_510 = self.n_Conv_66(t_509)
t_511 = torch.add(t_510, t_505)
t_512 = F.relu(t_511)
t_513 = self.n_Conv_67(t_512)
t_514 = F.relu(t_513)
t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
t_515 = self.n_Conv_68(t_514_padded)
t_516 = F.relu(t_515)
t_517 = self.n_Conv_69(t_516)
t_518 = torch.add(t_517, t_512)
t_519 = F.relu(t_518)
t_520 = self.n_Conv_70(t_519)
t_521 = F.relu(t_520)
t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
t_522 = self.n_Conv_71(t_521_padded)
t_523 = F.relu(t_522)
t_524 = self.n_Conv_72(t_523)
t_525 = torch.add(t_524, t_519)
t_526 = F.relu(t_525)
t_527 = self.n_Conv_73(t_526)
t_528 = F.relu(t_527)
t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
t_529 = self.n_Conv_74(t_528_padded)
t_530 = F.relu(t_529)
t_531 = self.n_Conv_75(t_530)
t_532 = torch.add(t_531, t_526)
t_533 = F.relu(t_532)
t_534 = self.n_Conv_76(t_533)
t_535 = F.relu(t_534)
t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
t_536 = self.n_Conv_77(t_535_padded)
t_537 = F.relu(t_536)
t_538 = self.n_Conv_78(t_537)
t_539 = torch.add(t_538, t_533)
t_540 = F.relu(t_539)
t_541 = self.n_Conv_79(t_540)
t_542 = F.relu(t_541)
t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
t_543 = self.n_Conv_80(t_542_padded)
t_544 = F.relu(t_543)
t_545 = self.n_Conv_81(t_544)
t_546 = torch.add(t_545, t_540)
t_547 = F.relu(t_546)
t_548 = self.n_Conv_82(t_547)
t_549 = F.relu(t_548)
t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
t_550 = self.n_Conv_83(t_549_padded)
t_551 = F.relu(t_550)
t_552 = self.n_Conv_84(t_551)
t_553 = torch.add(t_552, t_547)
t_554 = F.relu(t_553)
t_555 = self.n_Conv_85(t_554)
t_556 = F.relu(t_555)
t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
t_557 = self.n_Conv_86(t_556_padded)
t_558 = F.relu(t_557)
t_559 = self.n_Conv_87(t_558)
t_560 = torch.add(t_559, t_554)
t_561 = F.relu(t_560)
t_562 = self.n_Conv_88(t_561)
t_563 = F.relu(t_562)
t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
t_564 = self.n_Conv_89(t_563_padded)
t_565 = F.relu(t_564)
t_566 = self.n_Conv_90(t_565)
t_567 = torch.add(t_566, t_561)
t_568 = F.relu(t_567)
t_569 = self.n_Conv_91(t_568)
t_570 = F.relu(t_569)
t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
t_571 = self.n_Conv_92(t_570_padded)
t_572 = F.relu(t_571)
t_573 = self.n_Conv_93(t_572)
t_574 = torch.add(t_573, t_568)
t_575 = F.relu(t_574)
t_576 = self.n_Conv_94(t_575)
t_577 = F.relu(t_576)
t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
t_578 = self.n_Conv_95(t_577_padded)
t_579 = F.relu(t_578)
t_580 = self.n_Conv_96(t_579)
t_581 = torch.add(t_580, t_575)
t_582 = F.relu(t_581)
t_583 = self.n_Conv_97(t_582)
t_584 = F.relu(t_583)
t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
t_585 = self.n_Conv_98(t_584_padded)
t_586 = F.relu(t_585)
t_587 = self.n_Conv_99(t_586)
t_588 = self.n_Conv_100(t_582)
t_589 = torch.add(t_587, t_588)
t_590 = F.relu(t_589)
t_591 = self.n_Conv_101(t_590)
t_592 = F.relu(t_591)
t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
t_593 = self.n_Conv_102(t_592_padded)
t_594 = F.relu(t_593)
t_595 = self.n_Conv_103(t_594)
t_596 = torch.add(t_595, t_590)
t_597 = F.relu(t_596)
t_598 = self.n_Conv_104(t_597)
t_599 = F.relu(t_598)
t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
t_600 = self.n_Conv_105(t_599_padded)
t_601 = F.relu(t_600)
t_602 = self.n_Conv_106(t_601)
t_603 = torch.add(t_602, t_597)
t_604 = F.relu(t_603)
t_605 = self.n_Conv_107(t_604)
t_606 = F.relu(t_605)
t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
t_607 = self.n_Conv_108(t_606_padded)
t_608 = F.relu(t_607)
t_609 = self.n_Conv_109(t_608)
t_610 = torch.add(t_609, t_604)
t_611 = F.relu(t_610)
t_612 = self.n_Conv_110(t_611)
t_613 = F.relu(t_612)
t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
t_614 = self.n_Conv_111(t_613_padded)
t_615 = F.relu(t_614)
t_616 = self.n_Conv_112(t_615)
t_617 = torch.add(t_616, t_611)
t_618 = F.relu(t_617)
t_619 = self.n_Conv_113(t_618)
t_620 = F.relu(t_619)
t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
t_621 = self.n_Conv_114(t_620_padded)
t_622 = F.relu(t_621)
t_623 = self.n_Conv_115(t_622)
t_624 = torch.add(t_623, t_618)
t_625 = F.relu(t_624)
t_626 = self.n_Conv_116(t_625)
t_627 = F.relu(t_626)
t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
t_628 = self.n_Conv_117(t_627_padded)
t_629 = F.relu(t_628)
t_630 = self.n_Conv_118(t_629)
t_631 = torch.add(t_630, t_625)
t_632 = F.relu(t_631)
t_633 = self.n_Conv_119(t_632)
t_634 = F.relu(t_633)
t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
t_635 = self.n_Conv_120(t_634_padded)
t_636 = F.relu(t_635)
t_637 = self.n_Conv_121(t_636)
t_638 = torch.add(t_637, t_632)
t_639 = F.relu(t_638)
t_640 = self.n_Conv_122(t_639)
t_641 = F.relu(t_640)
t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
t_642 = self.n_Conv_123(t_641_padded)
t_643 = F.relu(t_642)
t_644 = self.n_Conv_124(t_643)
t_645 = torch.add(t_644, t_639)
t_646 = F.relu(t_645)
t_647 = self.n_Conv_125(t_646)
t_648 = F.relu(t_647)
t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
t_649 = self.n_Conv_126(t_648_padded)
t_650 = F.relu(t_649)
t_651 = self.n_Conv_127(t_650)
t_652 = torch.add(t_651, t_646)
t_653 = F.relu(t_652)
t_654 = self.n_Conv_128(t_653)
t_655 = F.relu(t_654)
t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
t_656 = self.n_Conv_129(t_655_padded)
t_657 = F.relu(t_656)
t_658 = self.n_Conv_130(t_657)
t_659 = torch.add(t_658, t_653)
t_660 = F.relu(t_659)
t_661 = self.n_Conv_131(t_660)
t_662 = F.relu(t_661)
t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
t_663 = self.n_Conv_132(t_662_padded)
t_664 = F.relu(t_663)
t_665 = self.n_Conv_133(t_664)
t_666 = torch.add(t_665, t_660)
t_667 = F.relu(t_666)
t_668 = self.n_Conv_134(t_667)
t_669 = F.relu(t_668)
t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
t_670 = self.n_Conv_135(t_669_padded)
t_671 = F.relu(t_670)
t_672 = self.n_Conv_136(t_671)
t_673 = torch.add(t_672, t_667)
t_674 = F.relu(t_673)
t_675 = self.n_Conv_137(t_674)
t_676 = F.relu(t_675)
t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
t_677 = self.n_Conv_138(t_676_padded)
t_678 = F.relu(t_677)
t_679 = self.n_Conv_139(t_678)
t_680 = torch.add(t_679, t_674)
t_681 = F.relu(t_680)
t_682 = self.n_Conv_140(t_681)
t_683 = F.relu(t_682)
t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
t_684 = self.n_Conv_141(t_683_padded)
t_685 = F.relu(t_684)
t_686 = self.n_Conv_142(t_685)
t_687 = torch.add(t_686, t_681)
t_688 = F.relu(t_687)
t_689 = self.n_Conv_143(t_688)
t_690 = F.relu(t_689)
t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
t_691 = self.n_Conv_144(t_690_padded)
t_692 = F.relu(t_691)
t_693 = self.n_Conv_145(t_692)
t_694 = torch.add(t_693, t_688)
t_695 = F.relu(t_694)
t_696 = self.n_Conv_146(t_695)
t_697 = F.relu(t_696)
t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
t_698 = self.n_Conv_147(t_697_padded)
t_699 = F.relu(t_698)
t_700 = self.n_Conv_148(t_699)
t_701 = torch.add(t_700, t_695)
t_702 = F.relu(t_701)
t_703 = self.n_Conv_149(t_702)
t_704 = F.relu(t_703)
t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
t_705 = self.n_Conv_150(t_704_padded)
t_706 = F.relu(t_705)
t_707 = self.n_Conv_151(t_706)
t_708 = torch.add(t_707, t_702)
t_709 = F.relu(t_708)
t_710 = self.n_Conv_152(t_709)
t_711 = F.relu(t_710)
t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
t_712 = self.n_Conv_153(t_711_padded)
t_713 = F.relu(t_712)
t_714 = self.n_Conv_154(t_713)
t_715 = torch.add(t_714, t_709)
t_716 = F.relu(t_715)
t_717 = self.n_Conv_155(t_716)
t_718 = F.relu(t_717)
t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
t_719 = self.n_Conv_156(t_718_padded)
t_720 = F.relu(t_719)
t_721 = self.n_Conv_157(t_720)
t_722 = torch.add(t_721, t_716)
t_723 = F.relu(t_722)
t_724 = self.n_Conv_158(t_723)
t_725 = self.n_Conv_159(t_723)
t_726 = F.relu(t_725)
t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
t_727 = self.n_Conv_160(t_726_padded)
t_728 = F.relu(t_727)
t_729 = self.n_Conv_161(t_728)
t_730 = torch.add(t_729, t_724)
t_731 = F.relu(t_730)
t_732 = self.n_Conv_162(t_731)
t_733 = F.relu(t_732)
t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
t_734 = self.n_Conv_163(t_733_padded)
t_735 = F.relu(t_734)
t_736 = self.n_Conv_164(t_735)
t_737 = torch.add(t_736, t_731)
t_738 = F.relu(t_737)
t_739 = self.n_Conv_165(t_738)
t_740 = F.relu(t_739)
t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
t_741 = self.n_Conv_166(t_740_padded)
t_742 = F.relu(t_741)
t_743 = self.n_Conv_167(t_742)
t_744 = torch.add(t_743, t_738)
t_745 = F.relu(t_744)
t_746 = self.n_Conv_168(t_745)
t_747 = self.n_Conv_169(t_745)
t_748 = F.relu(t_747)
t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
t_749 = self.n_Conv_170(t_748_padded)
t_750 = F.relu(t_749)
t_751 = self.n_Conv_171(t_750)
t_752 = torch.add(t_751, t_746)
t_753 = F.relu(t_752)
t_754 = self.n_Conv_172(t_753)
t_755 = F.relu(t_754)
t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
t_756 = self.n_Conv_173(t_755_padded)
t_757 = F.relu(t_756)
t_758 = self.n_Conv_174(t_757)
t_759 = torch.add(t_758, t_753)
t_760 = F.relu(t_759)
t_761 = self.n_Conv_175(t_760)
t_762 = F.relu(t_761)
t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
t_763 = self.n_Conv_176(t_762_padded)
t_764 = F.relu(t_763)
t_765 = self.n_Conv_177(t_764)
t_766 = torch.add(t_765, t_760)
t_767 = F.relu(t_766)
t_768 = self.n_Conv_178(t_767)
t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
t_770 = torch.squeeze(t_769, 3)
t_770 = torch.squeeze(t_770, 2)
t_771 = torch.sigmoid(t_770)
return t_771
def load_state_dict(self, state_dict, **kwargs):
self.tags = state_dict.get('tags', [])
super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
...@@ -2,72 +2,95 @@ import sys, os, shlex ...@@ -2,72 +2,95 @@ import sys, os, shlex
import contextlib import contextlib
import torch import torch
from modules import errors from modules import errors
from packaging import version
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
has_mps = getattr(torch, 'has_mps', False)
cpu = torch.device("cpu") # has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
# check `getattr` and try it for compatibility
def has_mps() -> bool:
if not getattr(torch, 'has_mps', False):
return False
try:
torch.zeros(1).to(torch.device("mps"))
return True
except Exception:
return False
def extract_device_id(args, name): def extract_device_id(args, name):
for x in range(len(args)): for x in range(len(args)):
if name in args[x]: return args[x+1] if name in args[x]:
return args[x + 1]
return None return None
def get_optimal_device():
if torch.cuda.is_available(): def get_cuda_device_string():
from modules import shared from modules import shared
device_id = shared.cmd_opts.device_id if shared.cmd_opts.device_id is not None:
return f"cuda:{shared.cmd_opts.device_id}"
return "cuda"
if device_id is not None:
cuda_device = f"cuda:{device_id}"
return torch.device(cuda_device)
else:
return torch.device("cuda")
if has_mps: def get_optimal_device():
if torch.cuda.is_available():
return torch.device(get_cuda_device_string())
if has_mps():
return torch.device("mps") return torch.device("mps")
return cpu return cpu
def get_device_for(task):
from modules import shared
if task in shared.cmd_opts.use_cpu:
return cpu
return get_optimal_device()
def torch_gc(): def torch_gc():
if torch.cuda.is_available(): if torch.cuda.is_available():
with torch.cuda.device(get_cuda_device_string()):
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.ipc_collect() torch.cuda.ipc_collect()
def enable_tf32(): def enable_tf32():
if torch.cuda.is_available(): if torch.cuda.is_available():
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True
errors.run(enable_tf32, "Enabling TF32") errors.run(enable_tf32, "Enabling TF32")
device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None cpu = torch.device("cpu")
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
dtype = torch.float16 dtype = torch.float16
dtype_vae = torch.float16 dtype_vae = torch.float16
def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
if device.type == 'mps':
generator = torch.Generator(device=cpu)
generator.manual_seed(seed)
noise = torch.randn(shape, generator=generator, device=cpu).to(device)
return noise
def randn(seed, shape):
torch.manual_seed(seed) torch.manual_seed(seed)
if device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device) return torch.randn(shape, device=device)
def randn_without_seed(shape): def randn_without_seed(shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
if device.type == 'mps': if device.type == 'mps':
generator = torch.Generator(device=cpu) return torch.randn(shape, device=cpu).to(device)
noise = torch.randn(shape, generator=generator, device=cpu).to(device)
return noise
return torch.randn(shape, device=device) return torch.randn(shape, device=device)
...@@ -82,6 +105,36 @@ def autocast(disable=False): ...@@ -82,6 +105,36 @@ def autocast(disable=False):
return torch.autocast("cuda") return torch.autocast("cuda")
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383 # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor orig_tensor_to = torch.Tensor.to
def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device) def tensor_to_fix(self, *args, **kwargs):
if self.device.type != 'mps' and \
((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \
(isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')):
self = self.contiguous()
return orig_tensor_to(self, *args, **kwargs)
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
orig_layer_norm = torch.nn.functional.layer_norm
def layer_norm_fix(*args, **kwargs):
if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps':
args = list(args)
args[0] = args[0].contiguous()
return orig_layer_norm(*args, **kwargs)
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
orig_tensor_numpy = torch.Tensor.numpy
def numpy_fix(self, *args, **kwargs):
if self.requires_grad:
self = self.detach()
return orig_tensor_numpy(self, *args, **kwargs)
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
if has_mps() and version.parse(torch.__version__) < version.parse("1.13"):
torch.Tensor.to = tensor_to_fix
torch.nn.functional.layer_norm = layer_norm_fix
torch.Tensor.numpy = numpy_fix
...@@ -2,9 +2,30 @@ import sys ...@@ -2,9 +2,30 @@ import sys
import traceback import traceback
def print_error_explanation(message):
lines = message.strip().split("\n")
max_len = max([len(x) for x in lines])
print('=' * max_len, file=sys.stderr)
for line in lines:
print(line, file=sys.stderr)
print('=' * max_len, file=sys.stderr)
def display(e: Exception, task):
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
message = str(e)
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
print_error_explanation("""
The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file.
See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
""")
def run(code, task): def run(code, task):
try: try:
code() code()
except Exception as e: except Exception as e:
print(f"{task}: {type(e).__name__}", file=sys.stderr) display(task, e)
print(traceback.format_exc(), file=sys.stderr)
...@@ -199,7 +199,7 @@ def upscale_without_tiling(model, img): ...@@ -199,7 +199,7 @@ def upscale_without_tiling(model, img):
img = img[:, :, ::-1] img = img[:, :, ::-1]
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_esrgan) img = img.unsqueeze(0).to(devices.device_esrgan)
with torch.no_grad(): with torch.no_grad():
output = model(img) output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy() output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
......
...@@ -6,9 +6,9 @@ import git ...@@ -6,9 +6,9 @@ import git
from modules import paths, shared from modules import paths, shared
extensions = [] extensions = []
extensions_dir = os.path.join(paths.script_path, "extensions") extensions_dir = os.path.join(paths.script_path, "extensions")
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
def active(): def active():
...@@ -16,12 +16,13 @@ def active(): ...@@ -16,12 +16,13 @@ def active():
class Extension: class Extension:
def __init__(self, name, path, enabled=True): def __init__(self, name, path, enabled=True, is_builtin=False):
self.name = name self.name = name
self.path = path self.path = path
self.enabled = enabled self.enabled = enabled
self.status = '' self.status = ''
self.can_update = False self.can_update = False
self.is_builtin = is_builtin
repo = None repo = None
try: try:
...@@ -66,9 +67,12 @@ class Extension: ...@@ -66,9 +67,12 @@ class Extension:
self.can_update = False self.can_update = False
self.status = "latest" self.status = "latest"
def pull(self): def fetch_and_reset_hard(self):
repo = git.Repo(self.path) repo = git.Repo(self.path)
repo.remotes.origin.pull() # Fix: `error: Your local changes to the following files would be overwritten by merge`,
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
repo.git.fetch('--all')
repo.git.reset('--hard', 'origin')
def list_extensions(): def list_extensions():
...@@ -77,10 +81,19 @@ def list_extensions(): ...@@ -77,10 +81,19 @@ def list_extensions():
if not os.path.isdir(extensions_dir): if not os.path.isdir(extensions_dir):
return return
for dirname in sorted(os.listdir(extensions_dir)): paths = []
path = os.path.join(extensions_dir, dirname) for dirname in [extensions_dir, extensions_builtin_dir]:
if not os.path.isdir(dirname):
return
for extension_dirname in sorted(os.listdir(dirname)):
path = os.path.join(dirname, extension_dirname)
if not os.path.isdir(path): if not os.path.isdir(path):
continue continue
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions) paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
for dirname, path, is_builtin in paths:
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
extensions.append(extension) extensions.append(extension)
from __future__ import annotations from __future__ import annotations
import math import math
import os import os
import sys
import traceback
import numpy as np import numpy as np
from PIL import Image from PIL import Image
...@@ -12,15 +14,13 @@ from typing import Callable, List, OrderedDict, Tuple ...@@ -12,15 +14,13 @@ from typing import Callable, List, OrderedDict, Tuple
from functools import partial from functools import partial
from dataclasses import dataclass from dataclasses import dataclass
from modules import processing, shared, images, devices, sd_models from modules import processing, shared, images, devices, sd_models, sd_samplers
from modules.shared import opts from modules.shared import opts
import modules.gfpgan_model import modules.gfpgan_model
from modules.ui import plaintext_to_html from modules.ui import plaintext_to_html
import modules.codeformer_model import modules.codeformer_model
import piexif
import piexif.helper
import gradio as gr import gradio as gr
import safetensors.torch
class LruCache(OrderedDict): class LruCache(OrderedDict):
@dataclass(frozen=True) @dataclass(frozen=True)
...@@ -53,9 +53,12 @@ class LruCache(OrderedDict): ...@@ -53,9 +53,12 @@ class LruCache(OrderedDict):
cached_images: LruCache = LruCache(max_size=5) cached_images: LruCache = LruCache(max_size=5)
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool): def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
devices.torch_gc() devices.torch_gc()
shared.state.begin()
shared.state.job = 'extras'
imageArr = [] imageArr = []
# Also keep track of original file names # Also keep track of original file names
imageNameArr = [] imageNameArr = []
...@@ -92,6 +95,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -92,6 +95,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
# Extra operation definitions # Extra operation definitions
def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
shared.state.job = 'extras-gfpgan'
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
res = Image.fromarray(restored_img) res = Image.fromarray(restored_img)
...@@ -102,6 +106,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -102,6 +106,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
return (res, info) return (res, info)
def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
shared.state.job = 'extras-codeformer'
restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
res = Image.fromarray(restored_img) res = Image.fromarray(restored_img)
...@@ -112,6 +117,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -112,6 +117,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
return (res, info) return (res, info)
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
shared.state.job = 'extras-upscale'
upscaler = shared.sd_upscalers[scaler_index] upscaler = shared.sd_upscalers[scaler_index]
res = upscaler.scaler.upscale(image, resize, upscaler.data_path) res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
if mode == 1 and crop: if mode == 1 and crop:
...@@ -178,6 +184,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -178,6 +184,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
for image, image_name in zip(imageArr, imageNameArr): for image, image_name in zip(imageArr, imageNameArr):
if image is None: if image is None:
return outputs, "Please select an input image.", '' return outputs, "Please select an input image.", ''
shared.state.textinfo = f'Processing image {image_name}'
existing_pnginfo = image.info or {} existing_pnginfo = image.info or {}
image = image.convert("RGB") image = image.convert("RGB")
...@@ -186,18 +195,25 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -186,18 +195,25 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
for op in extras_ops: for op in extras_ops:
image, info = op(image, info) image, info = op(image, info)
if opts.use_original_name_batch and image_name != None: if opts.use_original_name_batch and image_name is not None:
basename = os.path.splitext(os.path.basename(image_name))[0] basename = os.path.splitext(os.path.basename(image_name))[0]
else: else:
basename = '' basename = ''
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, if opts.enable_pnginfo: # append info before save
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
if opts.enable_pnginfo:
image.info = existing_pnginfo image.info = existing_pnginfo
image.info["extras"] = info image.info["extras"] = info
if save_output:
# Add upscaler name as a suffix.
suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
# Add second upscaler if applicable.
if suffix and extras_upscaler_2 and extras_upscaler_2_visibility:
suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}"
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
if extras_mode != 2 or show_extras_results : if extras_mode != 2 or show_extras_results :
outputs.append(image) outputs.append(image)
...@@ -213,25 +229,8 @@ def run_pnginfo(image): ...@@ -213,25 +229,8 @@ def run_pnginfo(image):
if image is None: if image is None:
return '', '', '' return '', '', ''
items = image.info geninfo, items = images.read_info_from_image(image)
geninfo = '' items = {**{'parameters': geninfo}, **items}
if "exif" in image.info:
exif = piexif.load(image.info["exif"])
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
try:
exif_comment = piexif.helper.UserComment.load(exif_comment)
except ValueError:
exif_comment = exif_comment.decode('utf8', errors="ignore")
items['exif comment'] = exif_comment
geninfo = exif_comment
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
'loop', 'background', 'timestamp', 'duration']:
items.pop(field, None)
geninfo = items.get('parameters', geninfo)
info = '' info = ''
for key, text in items.items(): for key, text in items.items():
...@@ -249,7 +248,10 @@ def run_pnginfo(image): ...@@ -249,7 +248,10 @@ def run_pnginfo(image):
return '', geninfo, info return '', geninfo, info
def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name): def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
shared.state.begin()
shared.state.job = 'model-merge'
def weighted_sum(theta0, theta1, alpha): def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1) return ((1 - alpha) * theta0) + (alpha * theta1)
...@@ -261,23 +263,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam ...@@ -261,23 +263,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
primary_model_info = sd_models.checkpoints_list[primary_model_name] primary_model_info = sd_models.checkpoints_list[primary_model_name]
secondary_model_info = sd_models.checkpoints_list[secondary_model_name] secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None) tertiary_model_info = sd_models.checkpoints_list.get(tertiary_model_name, None)
result_is_inpainting_model = False
print(f"Loading {primary_model_info.filename}...")
primary_model = torch.load(primary_model_info.filename, map_location='cpu')
theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
print(f"Loading {secondary_model_info.filename}...")
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
if teritary_model_info is not None:
print(f"Loading {teritary_model_info.filename}...")
teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
else:
teritary_model = None
theta_2 = None
theta_funcs = { theta_funcs = {
"Weighted sum": (None, weighted_sum), "Weighted sum": (None, weighted_sum),
...@@ -285,9 +272,19 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam ...@@ -285,9 +272,19 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
} }
theta_func1, theta_func2 = theta_funcs[interp_method] theta_func1, theta_func2 = theta_funcs[interp_method]
print(f"Merging...") if theta_func1 and not tertiary_model_info:
shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
shared.state.end()
return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
print(f"Loading {secondary_model_info.filename}...")
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
if theta_func1: if theta_func1:
print(f"Loading {tertiary_model_info.filename}...")
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
for key in tqdm.tqdm(theta_1.keys()): for key in tqdm.tqdm(theta_1.keys()):
if 'model' in key: if 'model' in key:
if key in theta_2: if key in theta_2:
...@@ -295,12 +292,33 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam ...@@ -295,12 +292,33 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
theta_1[key] = theta_func1(theta_1[key], t2) theta_1[key] = theta_func1(theta_1[key], t2)
else: else:
theta_1[key] = torch.zeros_like(theta_1[key]) theta_1[key] = torch.zeros_like(theta_1[key])
del theta_2, teritary_model del theta_2
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
print(f"Loading {primary_model_info.filename}...")
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
print("Merging...")
for key in tqdm.tqdm(theta_0.keys()): for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1: if 'model' in key and key in theta_1:
a = theta_0[key]
b = theta_1[key]
theta_0[key] = theta_func2(theta_0[key], theta_1[key], multiplier) shared.state.textinfo = f'Merging layer {key}'
# this enables merging an inpainting model (A) with another one (B);
# where normal model would have 4 channels, for latenst space, inpainting model would
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
if a.shape[1] == 4 and b.shape[1] == 9:
raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
result_is_inpainting_model = True
else:
theta_0[key] = theta_func2(a, b, multiplier)
if save_as_half: if save_as_half:
theta_0[key] = theta_0[key].half() theta_0[key] = theta_0[key].half()
...@@ -311,17 +329,35 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam ...@@ -311,17 +329,35 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
theta_0[key] = theta_1[key] theta_0[key] = theta_1[key]
if save_as_half: if save_as_half:
theta_0[key] = theta_0[key].half() theta_0[key] = theta_0[key].half()
del theta_1
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt' filename = \
filename = filename if custom_name == '' else (custom_name + '.ckpt') primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \
secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \
interp_method.replace(" ", "_") + \
'-merged.' + \
("inpainting." if result_is_inpainting_model else "") + \
checkpoint_format
filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
output_modelname = os.path.join(ckpt_dir, filename) output_modelname = os.path.join(ckpt_dir, filename)
shared.state.textinfo = f"Saving to {output_modelname}..."
print(f"Saving to {output_modelname}...") print(f"Saving to {output_modelname}...")
torch.save(primary_model, output_modelname)
_, extension = os.path.splitext(output_modelname)
if extension.lower() == ".safetensors":
safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
else:
torch.save(theta_0, output_modelname)
sd_models.list_models() sd_models.list_models()
print(f"Checkpoint saved.") print("Checkpoint saved.")
shared.state.textinfo = "Checkpoint saved to " + output_modelname
shared.state.end()
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
import base64 import base64
import io import io
import math
import os import os
import re import re
from pathlib import Path
import gradio as gr import gradio as gr
from modules.shared import script_path from modules.shared import script_path
from modules import shared from modules import shared, ui_tempdir
import tempfile import tempfile
from PIL import Image from PIL import Image
...@@ -12,6 +15,7 @@ re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)' ...@@ -12,6 +15,7 @@ re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code) re_param = re.compile(re_param_code)
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$") re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$") re_imagesize = re.compile(r"^(\d+)x(\d+)$")
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
type_of_gr_update = type(gr.update()) type_of_gr_update = type(gr.update())
paste_fields = {} paste_fields = {}
bind_list = [] bind_list = []
...@@ -33,11 +37,13 @@ def quote(text): ...@@ -33,11 +37,13 @@ def quote(text):
def image_from_url_text(filedata): def image_from_url_text(filedata):
if type(filedata) == dict and filedata["is_file"]: if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
filedata = filedata[0]
if type(filedata) == dict and filedata.get("is_file", False):
filename = filedata["name"] filename = filedata["name"]
tempdir = os.path.normpath(tempfile.gettempdir()) is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
normfn = os.path.normpath(filename) assert is_in_right_dir, 'trying to open image file outside of allowed directories'
assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
return Image.open(filename) return Image.open(filename)
...@@ -73,7 +79,10 @@ def integrate_settings_paste_fields(component_dict): ...@@ -73,7 +79,10 @@ def integrate_settings_paste_fields(component_dict):
'sd_hypernetwork': 'Hypernet', 'sd_hypernetwork': 'Hypernet',
'sd_hypernetwork_strength': 'Hypernet strength', 'sd_hypernetwork_strength': 'Hypernet strength',
'CLIP_stop_at_last_layers': 'Clip skip', 'CLIP_stop_at_last_layers': 'Clip skip',
'inpainting_mask_weight': 'Conditional mask weight',
'sd_model_checkpoint': 'Model hash', 'sd_model_checkpoint': 'Model hash',
'eta_noise_seed_delta': 'ENSD',
'initial_noise_multiplier': 'Noise multiplier',
} }
settings_paste_fields = [ settings_paste_fields = [
(component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None))) (component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
...@@ -88,7 +97,7 @@ def integrate_settings_paste_fields(component_dict): ...@@ -88,7 +97,7 @@ def integrate_settings_paste_fields(component_dict):
def create_buttons(tabs_list): def create_buttons(tabs_list):
buttons = {} buttons = {}
for tab in tabs_list: for tab in tabs_list:
buttons[tab] = gr.Button(f"Send to {tab}") buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
return buttons return buttons
...@@ -97,36 +106,57 @@ def bind_buttons(buttons, send_image, send_generate_info): ...@@ -97,36 +106,57 @@ def bind_buttons(buttons, send_image, send_generate_info):
bind_list.append([buttons, send_image, send_generate_info]) bind_list.append([buttons, send_image, send_generate_info])
def send_image_and_dimensions(x):
if isinstance(x, Image.Image):
img = x
else:
img = image_from_url_text(x)
if shared.opts.send_size and isinstance(img, Image.Image):
w = img.width
h = img.height
else:
w = gr.update()
h = gr.update()
return img, w, h
def run_bind(): def run_bind():
for buttons, send_image, send_generate_info in bind_list: for buttons, source_image_component, send_generate_info in bind_list:
for tab in buttons: for tab in buttons:
button = buttons[tab] button = buttons[tab]
if send_image and paste_fields[tab]["init_img"]: destination_image_component = paste_fields[tab]["init_img"]
if type(send_image) == gr.Gallery: fields = paste_fields[tab]["fields"]
button.click(
fn=lambda x: image_from_url_text(x), destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
_js="extract_image_from_gallery", destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
inputs=[send_image],
outputs=[paste_fields[tab]["init_img"]], if source_image_component and destination_image_component:
) if isinstance(source_image_component, gr.Gallery):
func = send_image_and_dimensions if destination_width_component else image_from_url_text
jsfunc = "extract_image_from_gallery"
else: else:
func = send_image_and_dimensions if destination_width_component else lambda x: x
jsfunc = None
button.click( button.click(
fn=lambda x: x, fn=func,
inputs=[send_image], _js=jsfunc,
outputs=[paste_fields[tab]["init_img"]], inputs=[source_image_component],
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
) )
if send_generate_info and paste_fields[tab]["fields"] is not None: if send_generate_info and fields is not None:
if send_generate_info in paste_fields: if send_generate_info in paste_fields:
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Size-1', 'Size-2'] + (["Seed"] if shared.opts.send_seed else []) paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
button.click( button.click(
fn=lambda *x: x, fn=lambda *x: x,
inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names], inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
outputs=[field for field, name in paste_fields[tab]["fields"] if name in paste_field_names], outputs=[field for field, name in fields if name in paste_field_names],
) )
else: else:
connect_paste(button, paste_fields[tab]["fields"], send_generate_info) connect_paste(button, fields, send_generate_info)
button.click( button.click(
fn=None, fn=None,
...@@ -136,6 +166,59 @@ def run_bind(): ...@@ -136,6 +166,59 @@ def run_bind():
) )
def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
"""Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
If the infotext has no hash, then a hypernet with the same name will be selected instead.
"""
hypernet_name = hypernet_name.lower()
if hypernet_hash is not None:
# Try to match the hash in the name
for hypernet_key in shared.hypernetworks.keys():
result = re_hypernet_hash.search(hypernet_key)
if result is not None and result[1] == hypernet_hash:
return hypernet_key
else:
# Fall back to a hypernet with the same name
for hypernet_key in shared.hypernetworks.keys():
if hypernet_key.lower().startswith(hypernet_name):
return hypernet_key
return None
def restore_old_hires_fix_params(res):
"""for infotexts that specify old First pass size parameter, convert it into
width, height, and hr scale"""
firstpass_width = res.get('First pass size-1', None)
firstpass_height = res.get('First pass size-2', None)
if firstpass_width is None or firstpass_height is None:
return
firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
width = int(res.get("Size-1", 512))
height = int(res.get("Size-2", 512))
if firstpass_width == 0 or firstpass_height == 0:
# old algorithm for auto-calculating first pass size
desired_pixel_count = 512 * 512
actual_pixel_count = width * height
scale = math.sqrt(desired_pixel_count / actual_pixel_count)
firstpass_width = math.ceil(scale * width / 64) * 64
firstpass_height = math.ceil(scale * height / 64) * 64
hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height
res['Size-1'] = firstpass_width
res['Size-2'] = firstpass_height
res['Hires upscale'] = hr_scale
def parse_generation_parameters(x: str): def parse_generation_parameters(x: str):
"""parses generation parameters string, the one you see in text field under the picture in UI: """parses generation parameters string, the one you see in text field under the picture in UI:
``` ```
...@@ -181,6 +264,20 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model ...@@ -181,6 +264,20 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
else: else:
res[k] = v res[k] = v
# Missing CLIP skip means it was set to 1 (the default)
if "Clip skip" not in res:
res["Clip skip"] = "1"
if "Hypernet strength" not in res:
res["Hypernet strength"] = "1"
if "Hypernet" in res:
hypernet_name = res["Hypernet"]
hypernet_hash = res.get("Hypernet hash", None)
res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
restore_old_hires_fix_params(res)
return res return res
......
...@@ -36,7 +36,9 @@ def gfpgann(): ...@@ -36,7 +36,9 @@ def gfpgann():
else: else:
print("Unable to load gfpgan model!") print("Unable to load gfpgan model!")
return None return None
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) if hasattr(facexlib.detection.retinaface, 'device'):
facexlib.detection.retinaface.device = devices.device_gfpgan
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
loaded_gfpgan_model = model loaded_gfpgan_model = model
return model return model
......
...@@ -12,7 +12,7 @@ import torch ...@@ -12,7 +12,7 @@ import torch
import tqdm import tqdm
from einops import rearrange, repeat from einops import rearrange, repeat
from ldm.util import default from ldm.util import default
from modules import devices, processing, sd_models, shared from modules import devices, processing, sd_models, shared, sd_samplers
from modules.textual_inversion import textual_inversion from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum from torch import einsum
...@@ -38,7 +38,7 @@ class HypernetworkModule(torch.nn.Module): ...@@ -38,7 +38,7 @@ class HypernetworkModule(torch.nn.Module):
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'}) activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=True): add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
super().__init__() super().__init__()
assert layer_structure is not None, "layer_structure must not be None" assert layer_structure is not None, "layer_structure must not be None"
...@@ -154,16 +154,28 @@ class Hypernetwork: ...@@ -154,16 +154,28 @@ class Hypernetwork:
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout), self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
) )
self.eval_mode()
def weights(self): def weights(self):
res = [] res = []
for k, layers in self.layers.items():
for layer in layers:
res += layer.parameters()
return res
def train_mode(self):
for k, layers in self.layers.items(): for k, layers in self.layers.items():
for layer in layers: for layer in layers:
layer.train() layer.train()
res += layer.trainables() for param in layer.parameters():
param.requires_grad = True
return res def eval_mode(self):
for k, layers in self.layers.items():
for layer in layers:
layer.eval()
for param in layer.parameters():
param.requires_grad = False
def save(self, filename): def save(self, filename):
state_dict = {} state_dict = {}
...@@ -265,7 +277,7 @@ def load_hypernetwork(filename): ...@@ -265,7 +277,7 @@ def load_hypernetwork(filename):
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
else: else:
if shared.loaded_hypernetwork is not None: if shared.loaded_hypernetwork is not None:
print(f"Unloading hypernetwork") print("Unloading hypernetwork")
shared.loaded_hypernetwork = None shared.loaded_hypernetwork = None
...@@ -366,19 +378,44 @@ def report_statistics(loss_info:dict): ...@@ -366,19 +378,44 @@ def report_statistics(loss_info:dict):
print(e) print(e)
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
# Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
if not overwrite_old:
assert not os.path.exists(fn), f"file {fn} already exists"
if type(layer_structure) == str:
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
name=name,
enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure,
activation_func=activation_func,
weight_init=weight_init,
add_layer_norm=add_layer_norm,
use_dropout=use_dropout,
)
hypernet.save(fn)
shared.reload_hypernetworks()
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem. # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images from modules import images
save_hypernetwork_every = save_hypernetwork_every or 0 save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0 create_image_every = create_image_every or 0
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
path = shared.hypernetworks.get(hypernetwork_name, None) path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork() shared.loaded_hypernetwork = Hypernetwork()
shared.loaded_hypernetwork.load(path) shared.loaded_hypernetwork.load(path)
shared.state.job = "train-hypernetwork"
shared.state.textinfo = "Initializing hypernetwork training..." shared.state.textinfo = "Initializing hypernetwork training..."
shared.state.job_count = steps shared.state.job_count = steps
...@@ -403,38 +440,37 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log ...@@ -403,38 +440,37 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
hypernetwork = shared.loaded_hypernetwork hypernetwork = shared.loaded_hypernetwork
checkpoint = sd_models.select_checkpoint() checkpoint = sd_models.select_checkpoint()
ititial_step = hypernetwork.step or 0 initial_step = hypernetwork.step or 0
if ititial_step >= steps: if initial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps" shared.state.textinfo = "Model has already been trained beyond specified max steps"
return hypernetwork, filename return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \ clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
None
if clip_grad: if clip_grad:
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False) clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
# dataset loading may take a while, so input validations and early returns should be done before this # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) pin_memory = shared.opts.pin_memory
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
latent_sampling_method = ds.latent_sampling_method
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
old_parallel_processing_allowed = shared.parallel_processing_allowed
if unload: if unload:
shared.parallel_processing_allowed = False
shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu)
size = len(ds.indexes)
loss_dict = defaultdict(lambda : deque(maxlen = 1024))
losses = torch.zeros((size,))
previous_mean_losses = [0]
previous_mean_loss = 0
print("Mean loss of {} elements".format(size))
weights = hypernetwork.weights() weights = hypernetwork.weights()
for weight in weights: hypernetwork.train_mode()
weight.requires_grad = True
# Here we use optimizer from saved HN, or we can specify as UI option. # Here we use optimizer from saved HN, or we can specify as UI option.
if hypernetwork.optimizer_name in optimizer_dict: if hypernetwork.optimizer_name in optimizer_dict:
...@@ -452,68 +488,84 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log ...@@ -452,68 +488,84 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
print("Cannot resume from saved optimizer!") print("Cannot resume from saved optimizer!")
print(e) print(e)
scaler = torch.cuda.amp.GradScaler()
batch_size = ds.batch_size
gradient_step = ds.gradient_step
# n steps = batch_size * gradient_step * n image processed
steps_per_epoch = len(ds) // batch_size // gradient_step
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
loss_step = 0
_loss_step = 0 #internal
# size = len(ds.indexes)
# loss_dict = defaultdict(lambda : deque(maxlen = 1024))
# losses = torch.zeros((size,))
# previous_mean_losses = [0]
# previous_mean_loss = 0
# print("Mean loss of {} elements".format(size))
steps_without_grad = 0 steps_without_grad = 0
last_saved_file = "<none>" last_saved_file = "<none>"
last_saved_image = "<none>" last_saved_image = "<none>"
forced_filename = "<none>" forced_filename = "<none>"
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) pbar = tqdm.tqdm(total=steps - initial_step)
for i, entries in pbar: try:
hypernetwork.step = i + ititial_step for i in range((steps-initial_step) * gradient_step):
if len(loss_dict) > 0: if scheduler.finished:
previous_mean_losses = [i[-1] for i in loss_dict.values()] break
previous_mean_loss = mean(previous_mean_losses) if shared.state.interrupted:
break
for j, batch in enumerate(dl):
# works as a drop_last=True for gradient accumulation
if j == max_steps_per_epoch:
break
scheduler.apply(optimizer, hypernetwork.step) scheduler.apply(optimizer, hypernetwork.step)
if scheduler.finished: if scheduler.finished:
break break
if shared.state.interrupted: if shared.state.interrupted:
break break
if clip_grad: if clip_grad:
clip_grad_sched.step(hypernetwork.step) clip_grad_sched.step(hypernetwork.step)
with torch.autocast("cuda"): with devices.autocast():
c = stack_conds([entry.cond for entry in entries]).to(devices.device) x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device) if tag_drop_out != 0 or shuffle_tags:
x = torch.stack([entry.latent for entry in entries]).to(devices.device) shared.sd_model.cond_stage_model.to(devices.device)
loss = shared.sd_model(x, c)[0] c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
shared.sd_model.cond_stage_model.to(devices.cpu)
else:
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
loss = shared.sd_model(x, c)[0] / gradient_step
del x del x
del c del c
losses[hypernetwork.step % losses.shape[0]] = loss.item() _loss_step += loss.item()
for entry in entries: scaler.scale(loss).backward()
loss_dict[entry.filename].append(loss.item())
optimizer.zero_grad()
weights[0].grad = None
loss.backward()
if weights[0].grad is None: # go back until we reach gradient accumulation steps
steps_without_grad += 1 if (j + 1) % gradient_step != 0:
else: continue
steps_without_grad = 0
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
if clip_grad: if clip_grad:
clip_grad(weights, clip_grad_sched.learn_rate) clip_grad(weights, clip_grad_sched.learn_rate)
optimizer.step() scaler.step(optimizer)
scaler.update()
hypernetwork.step += 1
pbar.update()
optimizer.zero_grad(set_to_none=True)
loss_step = _loss_step
_loss_step = 0
steps_done = hypernetwork.step + 1 steps_done = hypernetwork.step + 1
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): epoch_num = hypernetwork.step // steps_per_epoch
raise RuntimeError("Loss diverged.") epoch_step = hypernetwork.step % steps_per_epoch
if len(previous_mean_losses) > 1:
std = stdev(previous_mean_losses)
else:
std = 0
dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
pbar.set_description(dataset_loss_info)
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint. # Before saving, change name to match current checkpoint.
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
...@@ -524,16 +576,15 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log ...@@ -524,16 +576,15 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
"loss": f"{previous_mean_loss:.7f}", "loss": f"{loss_step:.7f}",
"learn_rate": scheduler.learn_rate "learn_rate": scheduler.learn_rate
}) })
if images_dir is not None and steps_done % create_image_every == 0: if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{hypernetwork_name}-{steps_done}' forced_filename = f'{hypernetwork_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename) last_saved_image = os.path.join(images_dir, forced_filename)
hypernetwork.eval_mode()
optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device)
...@@ -547,24 +598,26 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log ...@@ -547,24 +598,26 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
p.prompt = preview_prompt p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt p.negative_prompt = preview_negative_prompt
p.steps = preview_steps p.steps = preview_steps
p.sampler_index = preview_sampler_index p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
p.cfg_scale = preview_cfg_scale p.cfg_scale = preview_cfg_scale
p.seed = preview_seed p.seed = preview_seed
p.width = preview_width p.width = preview_width
p.height = preview_height p.height = preview_height
else: else:
p.prompt = entries[0].cond_text p.prompt = batch.cond_text[0]
p.steps = 20 p.steps = 20
p.width = training_width
p.height = training_height
preview_text = p.prompt preview_text = p.prompt
processed = processing.process_images(p) processed = processing.process_images(p)
image = processed.images[0] if len(processed.images)>0 else None image = processed.images[0] if len(processed.images) > 0 else None
if unload: if unload:
shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu)
hypernetwork.train_mode()
if image is not None: if image is not None:
shared.state.current_image = image shared.state.current_image = image
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
...@@ -574,23 +627,33 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log ...@@ -574,23 +627,33 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.state.textinfo = f""" shared.state.textinfo = f"""
<p> <p>
Loss: {previous_mean_loss:.7f}<br/> Loss: {loss_step:.7f}<br/>
Step: {hypernetwork.step}<br/> Step: {steps_done}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/> Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved hypernetwork: {html.escape(last_saved_file)}<br/> Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/> Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>
""" """
except Exception:
report_statistics(loss_dict) print(traceback.format_exc(), file=sys.stderr)
finally:
pbar.leave = False
pbar.close()
hypernetwork.eval_mode()
#report_statistics(loss_dict)
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.optimizer_name = optimizer_name hypernetwork.optimizer_name = optimizer_name
if shared.opts.save_optimizer_state: if shared.opts.save_optimizer_state:
hypernetwork.optimizer_state_dict = optimizer.state_dict() hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename) save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
del optimizer del optimizer
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
shared.parallel_processing_allowed = old_parallel_processing_allowed
return hypernetwork, filename return hypernetwork, filename
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename): def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
......
...@@ -3,39 +3,16 @@ import os ...@@ -3,39 +3,16 @@ import os
import re import re
import gradio as gr import gradio as gr
import modules.textual_inversion.preprocess import modules.hypernetworks.hypernetwork
import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack, shared from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
not_available = ["hardswish", "multiheadattention"] not_available = ["hardswish", "multiheadattention"]
keys = list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available) keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False): def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
# Remove illegal characters from name. filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout)
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
if not overwrite_old:
assert not os.path.exists(fn), f"file {fn} already exists"
if type(layer_structure) == str:
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
name=name,
enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure,
activation_func=activation_func,
weight_init=weight_init,
add_layer_norm=add_layer_norm,
use_dropout=use_dropout,
)
hypernet.save(fn)
shared.reload_hypernetworks()
return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
def train_hypernetwork(*args): def train_hypernetwork(*args):
......
...@@ -15,6 +15,7 @@ import piexif.helper ...@@ -15,6 +15,7 @@ import piexif.helper
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from fonts.ttf import Roboto from fonts.ttf import Roboto
import string import string
import json
from modules import sd_samplers, shared, script_callbacks from modules import sd_samplers, shared, script_callbacks
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts
...@@ -38,11 +39,14 @@ def image_grid(imgs, batch_size=1, rows=None): ...@@ -38,11 +39,14 @@ def image_grid(imgs, batch_size=1, rows=None):
cols = math.ceil(len(imgs) / rows) cols = math.ceil(len(imgs) / rows)
params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
script_callbacks.image_grid_callback(params)
w, h = imgs[0].size w, h = imgs[0].size
grid = Image.new('RGB', size=(cols * w, rows * h), color='black') grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')
for i, img in enumerate(imgs): for i, img in enumerate(params.imgs):
grid.paste(img, box=(i % cols * w, i // cols * h)) grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
return grid return grid
...@@ -135,8 +139,19 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts): ...@@ -135,8 +139,19 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
lines.append(word) lines.append(word)
return lines return lines
def draw_texts(drawing, draw_x, draw_y, lines): def get_font(fontsize):
try:
return ImageFont.truetype(opts.font or Roboto, fontsize)
except Exception:
return ImageFont.truetype(Roboto, fontsize)
def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
for i, line in enumerate(lines): for i, line in enumerate(lines):
fnt = initial_fnt
fontsize = initial_fontsize
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
fontsize -= 1
fnt = get_font(fontsize)
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center") drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
if not line.is_active: if not line.is_active:
...@@ -147,10 +162,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts): ...@@ -147,10 +162,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
fontsize = (width + height) // 25 fontsize = (width + height) // 25
line_spacing = fontsize // 2 line_spacing = fontsize // 2
try: fnt = get_font(fontsize)
fnt = ImageFont.truetype(opts.font or Roboto, fontsize)
except Exception:
fnt = ImageFont.truetype(Roboto, fontsize)
color_active = (0, 0, 0) color_active = (0, 0, 0)
color_inactive = (153, 153, 153) color_inactive = (153, 153, 153)
...@@ -177,6 +189,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts): ...@@ -177,6 +189,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
for line in texts: for line in texts:
bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt) bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1]) line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
line.allowed_width = allowed_width
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts] hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
...@@ -193,13 +206,13 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts): ...@@ -193,13 +206,13 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
x = pad_left + width * col + width / 2 x = pad_left + width * col + width / 2
y = pad_top / 2 - hor_text_heights[col] / 2 y = pad_top / 2 - hor_text_heights[col] / 2
draw_texts(d, x, y, hor_texts[col]) draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
for row in range(rows): for row in range(rows):
x = pad_left / 2 x = pad_left / 2
y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2 y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2
draw_texts(d, x, y, ver_texts[row]) draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
return result return result
...@@ -217,16 +230,32 @@ def draw_prompt_matrix(im, width, height, all_prompts): ...@@ -217,16 +230,32 @@ def draw_prompt_matrix(im, width, height, all_prompts):
return draw_grid_annotations(im, width, height, hor_texts, ver_texts) return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
def resize_image(resize_mode, im, width, height): def resize_image(resize_mode, im, width, height, upscaler_name=None):
"""
Resizes an image with the specified resize_mode, width, and height.
Args:
resize_mode: The mode to use when resizing the image.
0: Resize the image to the specified width and height.
1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
im: The image to resize.
width: The width to resize the image to.
height: The height to resize the image to.
upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
"""
upscaler_name = upscaler_name or opts.upscaler_for_img2img
def resize(im, w, h): def resize(im, w, h):
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L': if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
return im.resize((w, h), resample=LANCZOS) return im.resize((w, h), resample=LANCZOS)
scale = max(w / im.width, h / im.height) scale = max(w / im.width, h / im.height)
if scale > 1.0: if scale > 1.0:
upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img] upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}" assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"
upscaler = upscalers[0] upscaler = upscalers[0]
im = upscaler.scaler.upscale(im, scale, upscaler.data_path) im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
...@@ -303,8 +332,9 @@ class FilenameGenerator: ...@@ -303,8 +332,9 @@ class FilenameGenerator:
'width': lambda self: self.image.width, 'width': lambda self: self.image.width,
'height': lambda self: self.image.height, 'height': lambda self: self.image.height,
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False), 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False), 'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash), 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.model_name, replace_spaces=False),
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'), 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>] 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp), 'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
...@@ -499,14 +529,11 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i ...@@ -499,14 +529,11 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
image = params.image image = params.image
fullfn = params.filename fullfn = params.filename
info = params.pnginfo.get(pnginfo_section_name, None) info = params.pnginfo.get(pnginfo_section_name, None)
fullfn_without_extension, extension = os.path.splitext(params.filename)
def exif_bytes(): def _atomically_save_image(image_to_save, filename_without_extension, extension):
return piexif.dump({ # save image with .tmp extension to avoid race condition when another process detects new image in the directory
"Exif": { temp_file_path = filename_without_extension + ".tmp"
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode") image_format = Image.registered_extensions()[extension]
},
})
if extension.lower() == '.png': if extension.lower() == '.png':
pnginfo_data = PngImagePlugin.PngInfo() pnginfo_data = PngImagePlugin.PngInfo()
...@@ -514,15 +541,32 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i ...@@ -514,15 +541,32 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
for k, v in params.pnginfo.items(): for k, v in params.pnginfo.items():
pnginfo_data.add_text(k, str(v)) pnginfo_data.add_text(k, str(v))
image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo_data) image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
elif extension.lower() in (".jpg", ".jpeg", ".webp"): elif extension.lower() in (".jpg", ".jpeg", ".webp"):
image.save(fullfn, quality=opts.jpeg_quality) if image_to_save.mode == 'RGBA':
image_to_save = image_to_save.convert("RGB")
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
if opts.enable_pnginfo and info is not None: if opts.enable_pnginfo and info is not None:
piexif.insert(exif_bytes(), fullfn) exif_bytes = piexif.dump({
"Exif": {
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
},
})
piexif.insert(exif_bytes, temp_file_path)
else: else:
image.save(fullfn, quality=opts.jpeg_quality) image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
# atomically rename the file with correct extension
os.replace(temp_file_path, filename_without_extension + extension)
fullfn_without_extension, extension = os.path.splitext(params.filename)
_atomically_save_image(image, fullfn_without_extension, extension)
image.already_saved_as = fullfn
target_side_length = 4000 target_side_length = 4000
oversize = image.width > target_side_length or image.height > target_side_length oversize = image.width > target_side_length or image.height > target_side_length
...@@ -534,9 +578,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i ...@@ -534,9 +578,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
elif oversize: elif oversize:
image = image.resize((image.width * target_side_length // image.height, target_side_length), LANCZOS) image = image.resize((image.width * target_side_length // image.height, target_side_length), LANCZOS)
image.save(fullfn_without_extension + ".jpg", quality=opts.jpeg_quality) _atomically_save_image(image, fullfn_without_extension, ".jpg")
if opts.enable_pnginfo and info is not None:
piexif.insert(exif_bytes(), fullfn_without_extension + ".jpg")
if opts.save_txt and info is not None: if opts.save_txt and info is not None:
txt_fullfn = f"{fullfn_without_extension}.txt" txt_fullfn = f"{fullfn_without_extension}.txt"
...@@ -550,10 +592,45 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i ...@@ -550,10 +592,45 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
return fullfn, txt_fullfn return fullfn, txt_fullfn
def read_info_from_image(image):
items = image.info or {}
geninfo = items.pop('parameters', None)
if "exif" in items:
exif = piexif.load(items["exif"])
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
try:
exif_comment = piexif.helper.UserComment.load(exif_comment)
except ValueError:
exif_comment = exif_comment.decode('utf8', errors="ignore")
items['exif comment'] = exif_comment
geninfo = exif_comment
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
'loop', 'background', 'timestamp', 'duration']:
items.pop(field, None)
if items.get("Software", None) == "NovelAI":
try:
json_info = json.loads(items["Comment"])
sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
geninfo = f"""{items["Description"]}
Negative prompt: {json_info["uc"]}
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
except Exception:
print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return geninfo, items
def image_data(data): def image_data(data):
try: try:
image = Image.open(io.BytesIO(data)) image = Image.open(io.BytesIO(data))
textinfo = image.text["parameters"] textinfo, _ = read_info_from_image(image)
return textinfo, None return textinfo, None
except Exception: except Exception:
pass pass
...@@ -567,3 +644,14 @@ def image_data(data): ...@@ -567,3 +644,14 @@ def image_data(data):
pass pass
return '', None return '', None
def flatten(img, bgcolor):
"""replaces transparency with bgcolor (example: "#ffffff"), returning an RGB mode image with no transparency"""
if img.mode == "RGBA":
background = Image.new('RGBA', img.size, bgcolor)
background.paste(img, mask=img)
img = background
return img.convert('RGB')
...@@ -4,9 +4,9 @@ import sys ...@@ -4,9 +4,9 @@ import sys
import traceback import traceback
import numpy as np import numpy as np
from PIL import Image, ImageOps, ImageChops from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
from modules import devices from modules import devices, sd_samplers
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state from modules.shared import opts, state
import modules.shared as shared import modules.shared as shared
...@@ -59,18 +59,31 @@ def process_batch(p, input_dir, output_dir, args): ...@@ -59,18 +59,31 @@ def process_batch(p, input_dir, output_dir, args):
processed_image.save(os.path.join(output_dir, filename)) processed_image.save(os.path.join(output_dir, filename))
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
is_inpaint = mode == 1 is_inpaint = mode == 1
is_batch = mode == 2 is_batch = mode == 2
if is_inpaint: if is_inpaint:
# Drawn mask # Drawn mask
if mask_mode == 0: if mask_mode == 0:
image = init_img_with_mask['image'] is_mask_sketch = isinstance(init_img_with_mask, dict)
mask = init_img_with_mask['mask'] is_mask_paint = not is_mask_sketch
if is_mask_sketch:
# Sketch: mask iff. not transparent
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1') alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L') mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
image = image.convert('RGB') else:
# Color-sketch: mask iff. painted over
image = init_img_with_mask
orig = init_img_with_mask_orig or init_img_with_mask
pred = np.any(np.array(image) != np.array(orig), axis=-1)
mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
blur = ImageFilter.GaussianBlur(mask_blur)
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
image = image.convert("RGB")
# Uploaded mask # Uploaded mask
else: else:
image = init_img_inpaint image = init_img_inpaint
...@@ -99,7 +112,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro ...@@ -99,7 +112,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
seed_resize_from_h=seed_resize_from_h, seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w, seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras, seed_enable_extras=seed_enable_extras,
sampler_index=sampler_index, sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
batch_size=batch_size, batch_size=batch_size,
n_iter=n_iter, n_iter=n_iter,
steps=steps, steps=steps,
...@@ -149,4 +162,4 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro ...@@ -149,4 +162,4 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
if opts.do_not_show_images: if opts.do_not_show_images:
processed.images = [] processed.images = []
return processed.images, generation_info_js, plaintext_to_html(processed.info) return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
import sys
# this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it
if "--xformers" not in "".join(sys.argv):
sys.modules["xformers"] = None
import contextlib
import os import os
import sys import sys
import traceback import traceback
...@@ -11,10 +10,9 @@ from torchvision import transforms ...@@ -11,10 +10,9 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
import modules.shared as shared import modules.shared as shared
from modules import devices, paths, lowvram from modules import devices, paths, lowvram, modelloader
blip_image_eval_size = 384 blip_image_eval_size = 384
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
clip_model_name = 'ViT-L/14' clip_model_name = 'ViT-L/14'
Category = namedtuple("Category", ["name", "topn", "items"]) Category = namedtuple("Category", ["name", "topn", "items"])
...@@ -47,7 +45,14 @@ class InterrogateModels: ...@@ -47,7 +45,14 @@ class InterrogateModels:
def load_blip_model(self): def load_blip_model(self):
import models.blip import models.blip
blip_model = models.blip.blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json")) files = modelloader.load_models(
model_path=os.path.join(paths.models_path, "BLIP"),
model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
ext_filter=[".pth"],
download_name='model_base_caption_capfilt_large.pth',
)
blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
blip_model.eval() blip_model.eval()
return blip_model return blip_model
...@@ -130,8 +135,9 @@ class InterrogateModels: ...@@ -130,8 +135,9 @@ class InterrogateModels:
return caption[0] return caption[0]
def interrogate(self, pil_image): def interrogate(self, pil_image):
res = None res = ""
shared.state.begin()
shared.state.job = 'interrogate'
try: try:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
...@@ -148,8 +154,7 @@ class InterrogateModels: ...@@ -148,8 +154,7 @@ class InterrogateModels:
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate) clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext with torch.no_grad(), devices.autocast():
with torch.no_grad(), precision_scope("cuda"):
image_features = self.clip_model.encode_image(clip_image).type(self.dtype) image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
image_features /= image_features.norm(dim=-1, keepdim=True) image_features /= image_features.norm(dim=-1, keepdim=True)
...@@ -168,10 +173,11 @@ class InterrogateModels: ...@@ -168,10 +173,11 @@ class InterrogateModels:
res += ", " + match res += ", " + match
except Exception: except Exception:
print(f"Error interrogating", file=sys.stderr) print("Error interrogating", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
res += "<error>" res += "<error>"
self.unload() self.unload()
shared.state.end()
return res return res
...@@ -51,20 +51,30 @@ def setup_for_low_vram(sd_model, use_medvram): ...@@ -51,20 +51,30 @@ def setup_for_low_vram(sd_model, use_medvram):
send_me_to_gpu(first_stage_model, None) send_me_to_gpu(first_stage_model, None)
return first_stage_model_decode(z) return first_stage_model_decode(z)
# remove three big modules, cond, first_stage, and unet from the model and then # for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field
if hasattr(sd_model.cond_stage_model, 'model'):
sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
# remove four big modules, cond, first_stage, depth (if applicable), and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU. # send the model to GPU. Then put modules back. the modules will be in CPU.
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), sd_model.model
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = None, None, None, None
sd_model.to(devices.device) sd_model.to(devices.device)
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = stored
# register hooks for those the first two models # register hooks for those the first three models
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu) sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu) sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.encode = first_stage_model_encode_wrap sd_model.first_stage_model.encode = first_stage_model_encode_wrap
sd_model.first_stage_model.decode = first_stage_model_decode_wrap sd_model.first_stage_model.decode = first_stage_model_decode_wrap
if sd_model.depth_model:
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if hasattr(sd_model.cond_stage_model, 'model'):
sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
del sd_model.cond_stage_model.transformer
if use_medvram: if use_medvram:
sd_model.model.register_forward_pre_hook(send_me_to_gpu) sd_model.model.register_forward_pre_hook(send_me_to_gpu)
else: else:
......
...@@ -71,10 +71,13 @@ class MemUsageMonitor(threading.Thread): ...@@ -71,10 +71,13 @@ class MemUsageMonitor(threading.Thread):
def read(self): def read(self):
if not self.disabled: if not self.disabled:
free, total = torch.cuda.mem_get_info() free, total = torch.cuda.mem_get_info()
self.data["free"] = free
self.data["total"] = total self.data["total"] = total
torch_stats = torch.cuda.memory_stats(self.device) torch_stats = torch.cuda.memory_stats(self.device)
self.data["active"] = torch_stats["active.all.current"]
self.data["active_peak"] = torch_stats["active_bytes.all.peak"] self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"] self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
self.data["system_peak"] = total - self.data["min_free"] self.data["system_peak"] = total - self.data["min_free"]
......
...@@ -82,6 +82,7 @@ def cleanup_models(): ...@@ -82,6 +82,7 @@ def cleanup_models():
src_path = models_path src_path = models_path
dest_path = os.path.join(models_path, "Stable-diffusion") dest_path = os.path.join(models_path, "Stable-diffusion")
move_files(src_path, dest_path, ".ckpt") move_files(src_path, dest_path, ".ckpt")
move_files(src_path, dest_path, ".safetensors")
src_path = os.path.join(root_path, "ESRGAN") src_path = os.path.join(root_path, "ESRGAN")
dest_path = os.path.join(models_path, "ESRGAN") dest_path = os.path.join(models_path, "ESRGAN")
move_files(src_path, dest_path) move_files(src_path, dest_path)
...@@ -122,11 +123,27 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None): ...@@ -122,11 +123,27 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
pass pass
builtin_upscaler_classes = []
forbidden_upscaler_classes = set()
def list_builtin_upscalers():
load_upscalers()
builtin_upscaler_classes.clear()
builtin_upscaler_classes.extend(Upscaler.__subclasses__())
def forbid_loaded_nonbuiltin_upscalers():
for cls in Upscaler.__subclasses__():
if cls not in builtin_upscaler_classes:
forbidden_upscaler_classes.add(cls)
def load_upscalers(): def load_upscalers():
sd = shared.script_path
# We can only do this 'magic' method to dynamically load upscalers if they are referenced, # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__ # so we'll try to import any _model.py files before looking in __subclasses__
modules_dir = os.path.join(sd, "modules") modules_dir = os.path.join(shared.script_path, "modules")
for file in os.listdir(modules_dir): for file in os.listdir(modules_dir):
if "_model.py" in file: if "_model.py" in file:
model_name = file.replace("_model.py", "") model_name = file.replace("_model.py", "")
...@@ -135,22 +152,16 @@ def load_upscalers(): ...@@ -135,22 +152,16 @@ def load_upscalers():
importlib.import_module(full_model) importlib.import_module(full_model)
except: except:
pass pass
datas = [] datas = []
c_o = vars(shared.cmd_opts) commandline_options = vars(shared.cmd_opts)
for cls in Upscaler.__subclasses__(): for cls in Upscaler.__subclasses__():
if cls in forbidden_upscaler_classes:
continue
name = cls.__name__ name = cls.__name__
module_name = cls.__module__
module = importlib.import_module(module_name)
class_ = getattr(module, name)
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path" cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
opt_string = None scaler = cls(commandline_options.get(cmd_name, None))
try: datas += scaler.scalers
if cmd_name in c_o:
opt_string = c_o[cmd_name]
except:
pass
scaler = class_(opt_string)
for child in scaler.scalers:
datas.append(child)
shared.sd_upscalers = datas shared.sd_upscalers = datas
from pyngrok import ngrok, conf, exception from pyngrok import ngrok, conf, exception
def connect(token, port, region): def connect(token, port, region):
if token == None: account = None
if token is None:
token = 'None' token = 'None'
else:
if ':' in token:
# token = authtoken:username:password
account = token.split(':')[1] + ':' + token.split(':')[-1]
token = token.split(':')[0]
config = conf.PyngrokConfig( config = conf.PyngrokConfig(
auth_token=token, region=region auth_token=token, region=region
) )
try: try:
public_url = ngrok.connect(port, pyngrok_config=config).public_url if account is None:
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
else:
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True, auth=account).public_url
except exception.PyngrokNgrokError: except exception.PyngrokNgrokError:
print(f'Invalid ngrok authtoken, ngrok connection aborted.\n' print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken') f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
......
...@@ -9,7 +9,7 @@ sys.path.insert(0, script_path) ...@@ -9,7 +9,7 @@ sys.path.insert(0, script_path)
# search for directory of stable diffusion in following places # search for directory of stable diffusion in following places
sd_path = None sd_path = None
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)] possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
for possible_sd_path in possible_sd_paths: for possible_sd_path in possible_sd_paths:
if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')): if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
sd_path = os.path.abspath(possible_sd_path) sd_path = os.path.abspath(possible_sd_path)
......
...@@ -2,6 +2,7 @@ import json ...@@ -2,6 +2,7 @@ import json
import math import math
import os import os
import sys import sys
import warnings
import torch import torch
import numpy as np import numpy as np
...@@ -12,15 +13,21 @@ from skimage import exposure ...@@ -12,15 +13,21 @@ from skimage import exposure
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import modules.sd_hijack import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks
from modules.sd_hijack import model_hijack from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
import modules.face_restoration import modules.face_restoration
import modules.images as images import modules.images as images
import modules.styles import modules.styles
import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
import logging import logging
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType
# some of those options should not be changed at all because they would break the model, so I removed them from options. # some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4 opt_C = 4
...@@ -33,17 +40,19 @@ def setup_color_correction(image): ...@@ -33,17 +40,19 @@ def setup_color_correction(image):
return correction_target return correction_target
def apply_color_correction(correction, image): def apply_color_correction(correction, original_image):
logging.info("Applying color correction.") logging.info("Applying color correction.")
image = Image.fromarray(cv2.cvtColor(exposure.match_histograms( image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
cv2.cvtColor( cv2.cvtColor(
np.asarray(image), np.asarray(original_image),
cv2.COLOR_RGB2LAB cv2.COLOR_RGB2LAB
), ),
correction, correction,
channel_axis=2 channel_axis=2
), cv2.COLOR_LAB2RGB).astype("uint8")) ), cv2.COLOR_LAB2RGB).astype("uint8"))
image = blendLayers(image, original_image, BlendType.LUMINOSITY)
return image return image
...@@ -66,19 +75,33 @@ def apply_overlay(image, paste_loc, index, overlays): ...@@ -66,19 +75,33 @@ def apply_overlay(image, paste_loc, index, overlays):
return image return image
def get_correct_sampler(p):
if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img): def txt2img_image_conditioning(sd_model, x, width, height):
return sd_samplers.samplers if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img): # Dummy zero conditioning if we're not using inpainting model.
return sd_samplers.samplers_for_img2img # Still takes up a bit of memory, but no encoder call.
elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI): # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
return sd_samplers.samplers return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype)
return image_conditioning
class StableDiffusionProcessing(): class StableDiffusionProcessing():
""" """
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
""" """
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None): def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None):
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
self.sd_model = sd_model self.sd_model = sd_model
self.outpath_samples: str = outpath_samples self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids self.outpath_grids: str = outpath_grids
...@@ -91,7 +114,7 @@ class StableDiffusionProcessing(): ...@@ -91,7 +114,7 @@ class StableDiffusionProcessing():
self.subseed_strength: float = subseed_strength self.subseed_strength: float = subseed_strength
self.seed_resize_from_h: int = seed_resize_from_h self.seed_resize_from_h: int = seed_resize_from_h
self.seed_resize_from_w: int = seed_resize_from_w self.seed_resize_from_w: int = seed_resize_from_w
self.sampler_index: int = sampler_index self.sampler_name: str = sampler_name
self.batch_size: int = batch_size self.batch_size: int = batch_size
self.n_iter: int = n_iter self.n_iter: int = n_iter
self.steps: int = steps self.steps: int = steps
...@@ -116,6 +139,8 @@ class StableDiffusionProcessing(): ...@@ -116,6 +139,8 @@ class StableDiffusionProcessing():
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
self.s_noise = s_noise or opts.s_noise self.s_noise = s_noise or opts.s_noise
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
self.override_settings_restore_afterwards = override_settings_restore_afterwards
self.is_using_inpainting_conditioning = False
if not seed_enable_extras: if not seed_enable_extras:
self.subseed = -1 self.subseed = -1
...@@ -126,33 +151,37 @@ class StableDiffusionProcessing(): ...@@ -126,33 +151,37 @@ class StableDiffusionProcessing():
self.scripts = None self.scripts = None
self.script_args = None self.script_args = None
self.all_prompts = None self.all_prompts = None
self.all_negative_prompts = None
self.all_seeds = None self.all_seeds = None
self.all_subseeds = None self.all_subseeds = None
self.iteration = 0
def txt2img_image_conditioning(self, x, width=None, height=None): def txt2img_image_conditioning(self, x, width=None, height=None):
if self.sampler.conditioning_key not in {'hybrid', 'concat'}: self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
# Dummy zero conditioning if we're not using inpainting model.
# Still takes up a bit of memory, but no encoder call. return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
return x.new_zeros(x.shape[0], 5, 1, 1) def depth2img_image_conditioning(self, source_image):
# Use the AddMiDaS helper to Format our source image to suit the MiDaS model
height = height or self.height transformer = AddMiDaS(model_type="dpt_hybrid")
width = width or self.width transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
# The "masked-image" in this case will just be all zeros since the entire image is masked. midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning)) conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
conditioning = torch.nn.functional.interpolate(
# Add the fake full 1s mask to the first dimension. self.sd_model.depth_model(midas_in),
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) size=conditioning_image.shape[2:],
image_conditioning = image_conditioning.to(x.dtype) mode="bicubic",
align_corners=False,
)
return image_conditioning (depth_min, depth_max) = torch.aminmax(conditioning)
conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
return conditioning
def img2img_image_conditioning(self, source_image, latent_image, image_mask = None): def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None):
if self.sampler.conditioning_key not in {'hybrid', 'concat'}: self.is_using_inpainting_conditioning = True
# Dummy zero conditioning if we're not using inpainting model.
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
# Handle the different mask inputs # Handle the different mask inputs
if image_mask is not None: if image_mask is not None:
...@@ -188,6 +217,18 @@ class StableDiffusionProcessing(): ...@@ -188,6 +217,18 @@ class StableDiffusionProcessing():
return image_conditioning return image_conditioning
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
# identify itself with a field common to all models. The conditioning_key is also hybrid.
if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
return self.depth2img_image_conditioning(source_image)
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
# Dummy zero conditioning if we're not using inpainting or depth model.
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
def init(self, all_prompts, all_seeds, all_subseeds): def init(self, all_prompts, all_seeds, all_subseeds):
pass pass
...@@ -200,7 +241,7 @@ class StableDiffusionProcessing(): ...@@ -200,7 +241,7 @@ class StableDiffusionProcessing():
class Processed: class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None): def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
self.images = images_list self.images = images_list
self.prompt = p.prompt self.prompt = p.prompt
self.negative_prompt = p.negative_prompt self.negative_prompt = p.negative_prompt
...@@ -208,10 +249,10 @@ class Processed: ...@@ -208,10 +249,10 @@ class Processed:
self.subseed = subseed self.subseed = subseed
self.subseed_strength = p.subseed_strength self.subseed_strength = p.subseed_strength
self.info = info self.info = info
self.comments = comments
self.width = p.width self.width = p.width
self.height = p.height self.height = p.height
self.sampler_index = p.sampler_index self.sampler_name = p.sampler_name
self.sampler = sd_samplers.samplers[p.sampler_index].name
self.cfg_scale = p.cfg_scale self.cfg_scale = p.cfg_scale
self.steps = p.steps self.steps = p.steps
self.batch_size = p.batch_size self.batch_size = p.batch_size
...@@ -238,17 +279,20 @@ class Processed: ...@@ -238,17 +279,20 @@ class Processed:
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0] self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1 self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
self.all_prompts = all_prompts or [self.prompt] self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
self.all_seeds = all_seeds or [self.seed] self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
self.all_subseeds = all_subseeds or [self.subseed] self.all_seeds = all_seeds or p.all_seeds or [self.seed]
self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
self.infotexts = infotexts or [info] self.infotexts = infotexts or [info]
def js(self): def js(self):
obj = { obj = {
"prompt": self.prompt, "prompt": self.all_prompts[0],
"all_prompts": self.all_prompts, "all_prompts": self.all_prompts,
"negative_prompt": self.negative_prompt, "negative_prompt": self.all_negative_prompts[0],
"all_negative_prompts": self.all_negative_prompts,
"seed": self.seed, "seed": self.seed,
"all_seeds": self.all_seeds, "all_seeds": self.all_seeds,
"subseed": self.subseed, "subseed": self.subseed,
...@@ -256,8 +300,7 @@ class Processed: ...@@ -256,8 +300,7 @@ class Processed:
"subseed_strength": self.subseed_strength, "subseed_strength": self.subseed_strength,
"width": self.width, "width": self.width,
"height": self.height, "height": self.height,
"sampler_index": self.sampler_index, "sampler_name": self.sampler_name,
"sampler": self.sampler,
"cfg_scale": self.cfg_scale, "cfg_scale": self.cfg_scale,
"steps": self.steps, "steps": self.steps,
"batch_size": self.batch_size, "batch_size": self.batch_size,
...@@ -273,6 +316,7 @@ class Processed: ...@@ -273,6 +316,7 @@ class Processed:
"styles": self.styles, "styles": self.styles,
"job_timestamp": self.job_timestamp, "job_timestamp": self.job_timestamp,
"clip_skip": self.clip_skip, "clip_skip": self.clip_skip,
"is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
} }
return json.dumps(obj) return json.dumps(obj)
...@@ -297,13 +341,14 @@ def slerp(val, low, high): ...@@ -297,13 +341,14 @@ def slerp(val, low, high):
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None): def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
eta_noise_seed_delta = opts.eta_noise_seed_delta or 0
xs = [] xs = []
# if we have multiple seeds, this means we are working with batch size>1; this then # if we have multiple seeds, this means we are working with batch size>1; this then
# enables the generation of additional tensors with noise that the sampler will use during its processing. # enables the generation of additional tensors with noise that the sampler will use during its processing.
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
# produce the same images as with two batches [100], [101]. # produce the same images as with two batches [100], [101].
if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0): if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))] sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
else: else:
sampler_noises = None sampler_noises = None
...@@ -343,8 +388,8 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see ...@@ -343,8 +388,8 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
if sampler_noises is not None: if sampler_noises is not None:
cnt = p.sampler.number_of_needed_noises(p) cnt = p.sampler.number_of_needed_noises(p)
if opts.eta_noise_seed_delta > 0: if eta_noise_seed_delta > 0:
torch.manual_seed(seed + opts.eta_noise_seed_delta) torch.manual_seed(seed + eta_noise_seed_delta)
for j in range(cnt): for j in range(cnt):
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape))) sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
...@@ -377,14 +422,14 @@ def fix_seed(p): ...@@ -377,14 +422,14 @@ def fix_seed(p):
p.subseed = get_fixed_seed(p.subseed) p.subseed = get_fixed_seed(p.subseed)
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0): def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
index = position_in_batch + iteration * p.batch_size index = position_in_batch + iteration * p.batch_size
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers) clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
generation_params = { generation_params = {
"Steps": p.steps, "Steps": p.steps,
"Sampler": get_correct_sampler(p)[p.sampler_index].name, "Sampler": p.sampler_name,
"CFG scale": p.cfg_scale, "CFG scale": p.cfg_scale,
"Seed": all_seeds[index], "Seed": all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
...@@ -392,6 +437,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration ...@@ -392,6 +437,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name), "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
"Hypernet hash": (None if shared.loaded_hypernetwork is None else sd_models.model_hash(shared.loaded_hypernetwork.filename)),
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength), "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch), "Batch pos": (None if p.batch_size < 2 else position_in_batch),
...@@ -399,6 +445,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration ...@@ -399,6 +445,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None), "Denoising strength": getattr(p, 'denoising_strength', None),
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
"Clip skip": None if clip_skip <= 1 else clip_skip, "Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta, "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
...@@ -408,7 +455,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration ...@@ -408,7 +455,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else "" negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
...@@ -418,13 +465,21 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -418,13 +465,21 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
try: try:
for k, v in p.override_settings.items(): for k, v in p.override_settings.items():
setattr(opts, k, v) # we don't call onchange for simplicity which makes changing model, hypernet impossible setattr(opts, k, v)
if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet
if k == 'sd_model_checkpoint': sd_models.reload_model_weights() # make onchange call for changing SD model
if k == 'sd_vae': sd_vae.reload_vae_weights() # make onchange call for changing VAE
res = process_images_inner(p) res = process_images_inner(p)
finally: finally:
# restore opts to original state
if p.override_settings_restore_afterwards:
for k, v in stored_opts.items(): for k, v in stored_opts.items():
setattr(opts, k, v) setattr(opts, k, v)
if k == 'sd_hypernetwork': shared.reload_hypernetworks()
if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
if k == 'sd_vae': sd_vae.reload_vae_weights()
return res return res
...@@ -437,10 +492,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -437,10 +492,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
else: else:
assert p.prompt is not None assert p.prompt is not None
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
devices.torch_gc() devices.torch_gc()
seed = get_fixed_seed(p.seed) seed = get_fixed_seed(p.seed)
...@@ -451,12 +502,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -451,12 +502,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
comments = {} comments = {}
shared.prompt_styles.apply_styles(p)
if type(p.prompt) == list: if type(p.prompt) == list:
p.all_prompts = p.prompt p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
else: else:
p.all_prompts = p.batch_size * p.n_iter * [p.prompt] p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]
if type(p.negative_prompt) == list:
p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
else:
p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
if type(seed) == list: if type(seed) == list:
p.all_seeds = seed p.all_seeds = seed
...@@ -471,6 +525,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -471,6 +525,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
def infotext(iteration=0, position_in_batch=0): def infotext(iteration=0, position_in_batch=0):
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch) return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings() model_hijack.embedding_db.load_textual_inversion_embeddings()
...@@ -488,6 +546,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -488,6 +546,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
state.job_count = p.n_iter state.job_count = p.n_iter
for n in range(p.n_iter): for n in range(p.n_iter):
p.iteration = n
if state.skipped: if state.skipped:
state.skipped = False state.skipped = False
...@@ -495,6 +555,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -495,6 +555,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
break break
prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size] prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
...@@ -505,7 +566,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -505,7 +566,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds) p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
with devices.autocast(): with devices.autocast():
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
if len(model_hijack.comments) > 0: if len(model_hijack.comments) > 0:
...@@ -518,8 +579,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -518,8 +579,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
with devices.autocast(): with devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
samples_ddim = samples_ddim.to(devices.dtype_vae) x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim) x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
del samples_ddim del samples_ddim
...@@ -529,9 +590,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -529,9 +590,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc() devices.torch_gc()
if opts.filter_nsfw: if p.scripts is not None:
import modules.safety as safety p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
for i, x_sample in enumerate(x_samples_ddim): for i, x_sample in enumerate(x_samples_ddim):
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
...@@ -591,7 +651,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -591,7 +651,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc() devices.torch_gc()
res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts) res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
if p.scripts is not None: if p.scripts is not None:
p.scripts.postprocess(p, res) p.scripts.postprocess(p, res)
...@@ -602,14 +662,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -602,14 +662,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None sampler = None
def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs): def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.enable_hr = enable_hr self.enable_hr = enable_hr
self.denoising_strength = denoising_strength self.denoising_strength = denoising_strength
self.firstphase_width = firstphase_width self.hr_scale = hr_scale
self.firstphase_height = firstphase_height self.hr_upscaler = hr_upscaler
self.truncate_x = 0
self.truncate_y = 0 if firstphase_width != 0 or firstphase_height != 0:
print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
self.hr_scale = self.width / firstphase_width
self.width = firstphase_width
self.height = firstphase_height
def init(self, all_prompts, all_seeds, all_subseeds): def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr: if self.enable_hr:
...@@ -618,60 +682,43 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -618,60 +682,43 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else: else:
state.job_count = state.job_count * 2 state.job_count = state.job_count * 2
self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}" self.extra_generation_params["Hires upscale"] = self.hr_scale
if self.hr_upscaler is not None:
if self.firstphase_width == 0 or self.firstphase_height == 0: self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
desired_pixel_count = 512 * 512
actual_pixel_count = self.width * self.height
scale = math.sqrt(desired_pixel_count / actual_pixel_count)
self.firstphase_width = math.ceil(scale * self.width / 64) * 64
self.firstphase_height = math.ceil(scale * self.height / 64) * 64
firstphase_width_truncated = int(scale * self.width)
firstphase_height_truncated = int(scale * self.height)
else:
width_ratio = self.width / self.firstphase_width
height_ratio = self.height / self.firstphase_height
if width_ratio > height_ratio:
firstphase_width_truncated = self.firstphase_width
firstphase_height_truncated = self.firstphase_width * self.height / self.width
else:
firstphase_width_truncated = self.firstphase_height * self.width / self.height
firstphase_height_truncated = self.firstphase_height
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
if self.enable_hr and latent_scale_mode is None:
assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"
if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) if not self.enable_hr:
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height)) return samples
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] target_width = int(self.width * self.hr_scale)
target_height = int(self.height * self.hr_scale)
"""saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images"""
def save_intermediate(image, index): def save_intermediate(image, index):
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix: if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
return return
if not isinstance(image, Image.Image): if not isinstance(image, Image.Image):
image = sd_samplers.sample_to_image(image, index) image = sd_samplers.sample_to_image(image, index, approximation=0)
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix") info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")
if opts.use_scale_latent_for_hires_fix: if latent_scale_mode is not None:
for i in range(samples.shape[0]): for i in range(samples.shape[0]):
save_intermediate(samples, i) save_intermediate(samples, i)
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
# Avoid making the inpainting conditioning unless necessary as # Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again. # this does need some extra compute to decode / encode the image again.
...@@ -691,7 +738,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -691,7 +738,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
save_intermediate(image, i) save_intermediate(image, i)
image = images.resize_image(0, image, self.width, self.height) image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0) image = np.moveaxis(image, 2, 0)
batch_images.append(image) batch_images.append(image)
...@@ -706,9 +753,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -706,9 +753,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
shared.state.nextjob() shared.state.nextjob()
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
# GC now before running the next img2img to prevent running out of memory # GC now before running the next img2img to prevent running out of memory
x = None x = None
...@@ -722,7 +769,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -722,7 +769,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None sampler = None
def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strength: float=0.75, mask: Any=None, mask_blur: int=4, inpainting_fill: int=0, inpaint_full_res: bool=True, inpaint_full_res_padding: int=0, inpainting_mask_invert: int=0, **kwargs): def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.init_images = init_images self.init_images = init_images
...@@ -730,7 +777,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): ...@@ -730,7 +777,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.denoising_strength: float = denoising_strength self.denoising_strength: float = denoising_strength
self.init_latent = None self.init_latent = None
self.image_mask = mask self.image_mask = mask
#self.image_unblurred_mask = None
self.latent_mask = None self.latent_mask = None
self.mask_for_overlay = None self.mask_for_overlay = None
self.mask_blur = mask_blur self.mask_blur = mask_blur
...@@ -738,66 +784,68 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): ...@@ -738,66 +784,68 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.inpaint_full_res = inpaint_full_res self.inpaint_full_res = inpaint_full_res
self.inpaint_full_res_padding = inpaint_full_res_padding self.inpaint_full_res_padding = inpaint_full_res_padding
self.inpainting_mask_invert = inpainting_mask_invert self.inpainting_mask_invert = inpainting_mask_invert
self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
self.mask = None self.mask = None
self.nmask = None self.nmask = None
self.image_conditioning = None self.image_conditioning = None
def init(self, all_prompts, all_seeds, all_subseeds): def init(self, all_prompts, all_seeds, all_subseeds):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model) self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
crop_region = None crop_region = None
if self.image_mask is not None: image_mask = self.image_mask
self.image_mask = self.image_mask.convert('L')
if self.inpainting_mask_invert: if image_mask is not None:
self.image_mask = ImageOps.invert(self.image_mask) image_mask = image_mask.convert('L')
#self.image_unblurred_mask = self.image_mask if self.inpainting_mask_invert:
image_mask = ImageOps.invert(image_mask)
if self.mask_blur > 0: if self.mask_blur > 0:
self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)) image_mask = image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
if self.inpaint_full_res: if self.inpaint_full_res:
self.mask_for_overlay = self.image_mask self.mask_for_overlay = image_mask
mask = self.image_mask.convert('L') mask = image_mask.convert('L')
crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
x1, y1, x2, y2 = crop_region x1, y1, x2, y2 = crop_region
mask = mask.crop(crop_region) mask = mask.crop(crop_region)
self.image_mask = images.resize_image(2, mask, self.width, self.height) image_mask = images.resize_image(2, mask, self.width, self.height)
self.paste_to = (x1, y1, x2-x1, y2-y1) self.paste_to = (x1, y1, x2-x1, y2-y1)
else: else:
self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height) image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
np_mask = np.array(self.image_mask) np_mask = np.array(image_mask)
np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
self.mask_for_overlay = Image.fromarray(np_mask) self.mask_for_overlay = Image.fromarray(np_mask)
self.overlay_images = [] self.overlay_images = []
latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
add_color_corrections = opts.img2img_color_correction and self.color_corrections is None add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
if add_color_corrections: if add_color_corrections:
self.color_corrections = [] self.color_corrections = []
imgs = [] imgs = []
for img in self.init_images: for img in self.init_images:
image = img.convert("RGB") image = images.flatten(img, opts.img2img_background_color)
if crop_region is None: if crop_region is None and self.resize_mode != 3:
image = images.resize_image(self.resize_mode, image, self.width, self.height) image = images.resize_image(self.resize_mode, image, self.width, self.height)
if self.image_mask is not None: if image_mask is not None:
image_masked = Image.new('RGBa', (image.width, image.height)) image_masked = Image.new('RGBa', (image.width, image.height))
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
self.overlay_images.append(image_masked.convert('RGBA')) self.overlay_images.append(image_masked.convert('RGBA'))
# crop_region is not None if we are doing inpaint full res
if crop_region is not None: if crop_region is not None:
image = image.crop(crop_region) image = image.crop(crop_region)
image = images.resize_image(2, image, self.width, self.height) image = images.resize_image(2, image, self.width, self.height)
if self.image_mask is not None: if image_mask is not None:
if self.inpainting_fill != 1: if self.inpainting_fill != 1:
image = masking.fill(image, latent_mask) image = masking.fill(image, latent_mask)
...@@ -829,7 +877,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): ...@@ -829,7 +877,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
if self.image_mask is not None: if self.resize_mode == 3:
self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
if image_mask is not None:
init_mask = latent_mask init_mask = latent_mask
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255 latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
...@@ -846,11 +897,15 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): ...@@ -846,11 +897,15 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3: elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask self.init_latent = self.init_latent * self.mask
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask) self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
if self.initial_noise_multiplier != 1.0:
self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
x *= self.initial_noise_multiplier
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
if self.mask is not None: if self.mask is not None:
......
...@@ -37,16 +37,16 @@ class RestrictedUnpickler(pickle.Unpickler): ...@@ -37,16 +37,16 @@ class RestrictedUnpickler(pickle.Unpickler):
if module == 'collections' and name == 'OrderedDict': if module == 'collections' and name == 'OrderedDict':
return getattr(collections, name) return getattr(collections, name)
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']: if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
return getattr(torch._utils, name) return getattr(torch._utils, name)
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage']: if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
return getattr(torch, name) return getattr(torch, name)
if module == 'torch.nn.modules.container' and name in ['ParameterDict']: if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
return getattr(torch.nn.modules.container, name) return getattr(torch.nn.modules.container, name)
if module == 'numpy.core.multiarray' and name == 'scalar': if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
return numpy.core.multiarray.scalar return getattr(numpy.core.multiarray, name)
if module == 'numpy' and name == 'dtype': if module == 'numpy' and name in ['dtype', 'ndarray']:
return numpy.dtype return getattr(numpy, name)
if module == '_codecs' and name == 'encode': if module == '_codecs' and name == 'encode':
return encode return encode
if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint': if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
...@@ -62,14 +62,12 @@ class RestrictedUnpickler(pickle.Unpickler): ...@@ -62,14 +62,12 @@ class RestrictedUnpickler(pickle.Unpickler):
raise Exception(f"global '{module}/{name}' is forbidden") raise Exception(f"global '{module}/{name}' is forbidden")
allowed_zip_names = ["archive/data.pkl", "archive/version"] # Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
allowed_zip_names_re = re.compile(r"^archive/data/\d+$") allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
def check_zip_filenames(filename, names): def check_zip_filenames(filename, names):
for name in names: for name in names:
if name in allowed_zip_names:
continue
if allowed_zip_names_re.match(name): if allowed_zip_names_re.match(name):
continue continue
...@@ -83,7 +81,13 @@ def check_pt(filename, extra_handler): ...@@ -83,7 +81,13 @@ def check_pt(filename, extra_handler):
with zipfile.ZipFile(filename) as z: with zipfile.ZipFile(filename) as z:
check_zip_filenames(filename, z.namelist()) check_zip_filenames(filename, z.namelist())
with z.open('archive/data.pkl') as file: # find filename of data.pkl in zip file: '<directory name>/data.pkl'
data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
if len(data_pkl_filenames) == 0:
raise Exception(f"data.pkl not found in {filename}")
if len(data_pkl_filenames) > 1:
raise Exception(f"Multiple data.pkl found in {filename}")
with z.open(data_pkl_filenames[0]) as file:
unpickler = RestrictedUnpickler(file) unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler unpickler.extra_handler = extra_handler
unpickler.load() unpickler.load()
...@@ -99,12 +103,12 @@ def check_pt(filename, extra_handler): ...@@ -99,12 +103,12 @@ def check_pt(filename, extra_handler):
def load(filename, *args, **kwargs): def load(filename, *args, **kwargs):
return load_with_extra(filename, *args, **kwargs) return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
def load_with_extra(filename, extra_handler=None, *args, **kwargs): def load_with_extra(filename, extra_handler=None, *args, **kwargs):
""" """
this functon is intended to be used by extensions that want to load models with this function is intended to be used by extensions that want to load models with
some extra classes in them that the usual unpickler would find suspicious. some extra classes in them that the usual unpickler would find suspicious.
Use the extra_handler argument to specify a function that takes module and field name as text, Use the extra_handler argument to specify a function that takes module and field name as text,
...@@ -133,19 +137,56 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs): ...@@ -133,19 +137,56 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
except pickle.UnpicklingError: except pickle.UnpicklingError:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr) print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr) print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr) print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
return None return None
except Exception: except Exception:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr) print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr) print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
return None return None
return unsafe_torch_load(filename, *args, **kwargs) return unsafe_torch_load(filename, *args, **kwargs)
class Extra:
"""
A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
(because it's not your code making the torch.load call). The intended use is like this:
```
import torch
from modules import safe
def handler(module, name):
if module == 'torch' and name in ['float64', 'float16']:
return getattr(torch, name)
return None
with safe.Extra(handler):
x = torch.load('model.pt')
```
"""
def __init__(self, handler):
self.handler = handler
def __enter__(self):
global global_extra_handler
assert global_extra_handler is None, 'already inside an Extra() block'
global_extra_handler = self.handler
def __exit__(self, exc_type, exc_val, exc_tb):
global global_extra_handler
global_extra_handler = None
unsafe_torch_load = torch.load unsafe_torch_load = torch.load
torch.load = load torch.load = load
global_extra_handler = None
import torch
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
from PIL import Image
import modules.shared as shared
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = None
safety_checker = None
def numpy_to_pil(images):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
return pil_images
# check and replace nsfw content
def check_safety(x_image):
global safety_feature_extractor, safety_checker
if safety_feature_extractor is None:
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
return x_checked_image, has_nsfw_concept
def censor_batch(x):
x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
return x
...@@ -51,6 +51,13 @@ class UiTrainTabParams: ...@@ -51,6 +51,13 @@ class UiTrainTabParams:
self.txt2img_preview_params = txt2img_preview_params self.txt2img_preview_params = txt2img_preview_params
class ImageGridLoopParams:
def __init__(self, imgs, cols, rows):
self.imgs = imgs
self.cols = cols
self.rows = rows
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callback_map = dict( callback_map = dict(
callbacks_app_started=[], callbacks_app_started=[],
...@@ -61,6 +68,9 @@ callback_map = dict( ...@@ -61,6 +68,9 @@ callback_map = dict(
callbacks_before_image_saved=[], callbacks_before_image_saved=[],
callbacks_image_saved=[], callbacks_image_saved=[],
callbacks_cfg_denoiser=[], callbacks_cfg_denoiser=[],
callbacks_before_component=[],
callbacks_after_component=[],
callbacks_image_grid=[],
) )
...@@ -137,6 +147,30 @@ def cfg_denoiser_callback(params: CFGDenoiserParams): ...@@ -137,6 +147,30 @@ def cfg_denoiser_callback(params: CFGDenoiserParams):
report_exception(c, 'cfg_denoiser_callback') report_exception(c, 'cfg_denoiser_callback')
def before_component_callback(component, **kwargs):
for c in callback_map['callbacks_before_component']:
try:
c.callback(component, **kwargs)
except Exception:
report_exception(c, 'before_component_callback')
def after_component_callback(component, **kwargs):
for c in callback_map['callbacks_after_component']:
try:
c.callback(component, **kwargs)
except Exception:
report_exception(c, 'after_component_callback')
def image_grid_callback(params: ImageGridLoopParams):
for c in callback_map['callbacks_image_grid']:
try:
c.callback(params)
except Exception:
report_exception(c, 'image_grid')
def add_callback(callbacks, fun): def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__] stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file' filename = stack[0].filename if len(stack) > 0 else 'unknown file'
...@@ -220,3 +254,28 @@ def on_cfg_denoiser(callback): ...@@ -220,3 +254,28 @@ def on_cfg_denoiser(callback):
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details. - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
""" """
add_callback(callback_map['callbacks_cfg_denoiser'], callback) add_callback(callback_map['callbacks_cfg_denoiser'], callback)
def on_before_component(callback):
"""register a function to be called before a component is created.
The callback is called with arguments:
- component - gradio component that is about to be created.
- **kwargs - args to gradio.components.IOComponent.__init__ function
Use elem_id/label fields of kwargs to figure out which component it is.
This can be useful to inject your own components somewhere in the middle of vanilla UI.
"""
add_callback(callback_map['callbacks_before_component'], callback)
def on_after_component(callback):
"""register a function to be called after a component is created. See on_before_component for more."""
add_callback(callback_map['callbacks_after_component'], callback)
def on_image_grid(callback):
"""register a function to be called before making an image grid.
The callback is called with one argument:
- params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
"""
add_callback(callback_map['callbacks_image_grid'], callback)
import os
import sys
import traceback
from types import ModuleType
def load_module(path):
with open(path, "r", encoding="utf8") as file:
text = file.read()
compiled = compile(text, path, 'exec')
module = ModuleType(os.path.basename(path))
exec(compiled, module.__dict__)
return module
def preload_extensions(extensions_dir, parser):
if not os.path.isdir(extensions_dir):
return
for dirname in sorted(os.listdir(extensions_dir)):
preload_script = os.path.join(extensions_dir, dirname, "preload.py")
if not os.path.isfile(preload_script):
continue
try:
module = load_module(preload_script)
if hasattr(module, 'preload'):
module.preload(parser)
except Exception:
print(f"Error running preload() for {preload_script}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
...@@ -6,7 +6,7 @@ from collections import namedtuple ...@@ -6,7 +6,7 @@ from collections import namedtuple
import gradio as gr import gradio as gr
from modules.processing import StableDiffusionProcessing from modules.processing import StableDiffusionProcessing
from modules import shared, paths, script_callbacks, extensions from modules import shared, paths, script_callbacks, extensions, script_loading
AlwaysVisible = object() AlwaysVisible = object()
...@@ -17,6 +17,9 @@ class Script: ...@@ -17,6 +17,9 @@ class Script:
args_to = None args_to = None
alwayson = False alwayson = False
is_txt2img = False
is_img2img = False
"""A gr.Group component that has all script's UI inside it""" """A gr.Group component that has all script's UI inside it"""
group = None group = None
...@@ -33,7 +36,7 @@ class Script: ...@@ -33,7 +36,7 @@ class Script:
def ui(self, is_img2img): def ui(self, is_img2img):
"""this function should create gradio UI elements. See https://gradio.app/docs/#components """this function should create gradio UI elements. See https://gradio.app/docs/#components
The return value should be an array of all components that are used in processing. The return value should be an array of all components that are used in processing.
Values of those returned componenbts will be passed to run() and process() functions. Values of those returned components will be passed to run() and process() functions.
""" """
pass pass
...@@ -44,7 +47,7 @@ class Script: ...@@ -44,7 +47,7 @@ class Script:
This function should return: This function should return:
- False if the script should not be shown in UI at all - False if the script should not be shown in UI at all
- True if the script should be shown in UI if it's scelected in the scripts drowpdown - True if the script should be shown in UI if it's selected in the scripts dropdown
- script.AlwaysVisible if the script should be shown in UI at all times - script.AlwaysVisible if the script should be shown in UI at all times
""" """
...@@ -85,6 +88,17 @@ class Script: ...@@ -85,6 +88,17 @@ class Script:
pass pass
def postprocess_batch(self, p, *args, **kwargs):
"""
Same as process_batch(), but called for every batch after it has been generated.
**kwargs will have same items as process_batch, and also:
- batch_number - index of current batch, from 0 to number of batches-1
- images - torch tensor with all generated images, with values ranging from 0 to 1;
"""
pass
def postprocess(self, p, processed, *args): def postprocess(self, p, processed, *args):
""" """
This function is called after processing ends for AlwaysVisible scripts. This function is called after processing ends for AlwaysVisible scripts.
...@@ -93,6 +107,23 @@ class Script: ...@@ -93,6 +107,23 @@ class Script:
pass pass
def before_component(self, component, **kwargs):
"""
Called before a component is created.
Use elem_id/label fields of kwargs to figure out which component it is.
This can be useful to inject your own components somewhere in the middle of vanilla UI.
You can return created components in the ui() function to add them to the list of arguments for your processing functions
"""
pass
def after_component(self, component, **kwargs):
"""
Called after a component is created. Same as above.
"""
pass
def describe(self): def describe(self):
"""unused""" """unused"""
return "" return ""
...@@ -140,7 +171,7 @@ def list_files_with_name(filename): ...@@ -140,7 +171,7 @@ def list_files_with_name(filename):
continue continue
path = os.path.join(dirpath, filename) path = os.path.join(dirpath, filename)
if os.path.isfile(filename): if os.path.isfile(path):
res.append(path) res.append(path)
return res return res
...@@ -161,13 +192,7 @@ def load_scripts(): ...@@ -161,13 +192,7 @@ def load_scripts():
sys.path = [scriptfile.basedir] + sys.path sys.path = [scriptfile.basedir] + sys.path
current_basedir = scriptfile.basedir current_basedir = scriptfile.basedir
with open(scriptfile.path, "r", encoding="utf8") as file: module = script_loading.load_module(scriptfile.path)
text = file.read()
from types import ModuleType
compiled = compile(text, scriptfile.path, 'exec')
module = ModuleType(scriptfile.filename)
exec(compiled, module.__dict__)
for key, script_class in module.__dict__.items(): for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script): if type(script_class) == type and issubclass(script_class, Script):
...@@ -201,12 +226,18 @@ class ScriptRunner: ...@@ -201,12 +226,18 @@ class ScriptRunner:
self.titles = [] self.titles = []
self.infotext_fields = [] self.infotext_fields = []
def setup_ui(self, is_img2img): def initialize_scripts(self, is_img2img):
self.scripts.clear()
self.alwayson_scripts.clear()
self.selectable_scripts.clear()
for script_class, path, basedir in scripts_data: for script_class, path, basedir in scripts_data:
script = script_class() script = script_class()
script.filename = path script.filename = path
script.is_txt2img = not is_img2img
script.is_img2img = is_img2img
visibility = script.show(is_img2img) visibility = script.show(script.is_img2img)
if visibility == AlwaysVisible: if visibility == AlwaysVisible:
self.scripts.append(script) self.scripts.append(script)
...@@ -217,6 +248,7 @@ class ScriptRunner: ...@@ -217,6 +248,7 @@ class ScriptRunner:
self.scripts.append(script) self.scripts.append(script)
self.selectable_scripts.append(script) self.selectable_scripts.append(script)
def setup_ui(self):
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts] self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
inputs = [None] inputs = [None]
...@@ -226,7 +258,7 @@ class ScriptRunner: ...@@ -226,7 +258,7 @@ class ScriptRunner:
script.args_from = len(inputs) script.args_from = len(inputs)
script.args_to = len(inputs) script.args_to = len(inputs)
controls = wrap_call(script.ui, script.filename, "ui", is_img2img) controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
if controls is None: if controls is None:
return return
...@@ -326,21 +358,40 @@ class ScriptRunner: ...@@ -326,21 +358,40 @@ class ScriptRunner:
print(f"Error running postprocess: {script.filename}", file=sys.stderr) print(f"Error running postprocess: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
def postprocess_batch(self, p, images, **kwargs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_batch(p, *script_args, images=images, **kwargs)
except Exception:
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def before_component(self, component, **kwargs):
for script in self.scripts:
try:
script.before_component(component, **kwargs)
except Exception:
print(f"Error running before_component: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def after_component(self, component, **kwargs):
for script in self.scripts:
try:
script.after_component(component, **kwargs)
except Exception:
print(f"Error running after_component: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def reload_sources(self, cache): def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)): for si, script in list(enumerate(self.scripts)):
with open(script.filename, "r", encoding="utf8") as file:
args_from = script.args_from args_from = script.args_from
args_to = script.args_to args_to = script.args_to
filename = script.filename filename = script.filename
text = file.read()
from types import ModuleType
module = cache.get(filename, None) module = cache.get(filename, None)
if module is None: if module is None:
compiled = compile(text, filename, 'exec') module = script_loading.load_module(script.filename)
module = ModuleType(script.filename)
exec(compiled, module.__dict__)
cache[filename] = module cache[filename] = module
for key, script_class in module.__dict__.items(): for key, script_class in module.__dict__.items():
...@@ -353,6 +404,7 @@ class ScriptRunner: ...@@ -353,6 +404,7 @@ class ScriptRunner:
scripts_txt2img = ScriptRunner() scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner() scripts_img2img = ScriptRunner()
scripts_current: ScriptRunner = None
def reload_script_body_only(): def reload_script_body_only():
...@@ -369,3 +421,22 @@ def reload_scripts(): ...@@ -369,3 +421,22 @@ def reload_scripts():
scripts_txt2img = ScriptRunner() scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner() scripts_img2img = ScriptRunner()
def IOComponent_init(self, *args, **kwargs):
if scripts_current is not None:
scripts_current.before_component(self, **kwargs)
script_callbacks.before_component_callback(self, **kwargs)
res = original_IOComponent_init(self, *args, **kwargs)
script_callbacks.after_component_callback(self, **kwargs)
if scripts_current is not None:
scripts_current.after_component(self, **kwargs)
return res
original_IOComponent_init = gr.components.IOComponent.__init__
gr.components.IOComponent.__init__ = IOComponent_init
import math
import os
import sys
import traceback
import torch import torch
import numpy as np
from torch import einsum
from torch.nn.functional import silu from torch.nn.functional import silu
import modules.textual_inversion.textual_inversion import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
from modules.shared import opts, device, cmd_opts from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
from modules.sd_hijack_optimizations import invokeAI_mps_available from modules.sd_hijack_optimizations import invokeAI_mps_available
import ldm.modules.attention import ldm.modules.attention
import ldm.modules.diffusionmodules.model import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
# new memory efficient cross attention blocks do not support hypernets and we already
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
# silence new console spam from SD2
ldm.modules.attention.print = lambda *args: None
ldm.modules.diffusionmodules.model.print = lambda *args: None
def apply_optimizations(): def apply_optimizations():
undo_optimizations() undo_optimizations()
ldm.modules.diffusionmodules.model.nonlinearity = silu ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
optimization_method = None
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.") print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
optimization_method = 'xformers'
elif cmd_opts.opt_split_attention_v1: elif cmd_opts.opt_split_attention_v1:
print("Applying v1 cross attention optimization.") print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
optimization_method = 'V1'
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
if not invokeAI_mps_available and shared.device.type == 'mps': if not invokeAI_mps_available and shared.device.type == 'mps':
print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.") print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
print("Applying v1 cross attention optimization.") print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
optimization_method = 'V1'
else: else:
print("Applying cross attention optimization (InvokeAI).") print("Applying cross attention optimization (InvokeAI).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
optimization_method = 'InvokeAI'
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
print("Applying cross attention optimization (Doggettx).") print("Applying cross attention optimization (Doggettx).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
optimization_method = 'Doggettx'
return optimization_method
def undo_optimizations():
from modules.hypernetworks import hypernetwork
def undo_optimizations():
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
def get_target_prompt_token_count(token_count): def fix_checkpoint():
return math.ceil(max(token_count, 1) / 75) * 75 ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
class StableDiffusionModelHijack: class StableDiffusionModelHijack:
...@@ -64,18 +84,31 @@ class StableDiffusionModelHijack: ...@@ -64,18 +84,31 @@ class StableDiffusionModelHijack:
layers = None layers = None
circular_enabled = False circular_enabled = False
clip = None clip = None
optimization_method = None
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir) embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
def hijack(self, m): def hijack(self, m):
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
self.optimization_method = apply_optimizations()
self.clip = m.cond_stage_model self.clip = m.cond_stage_model
apply_optimizations() fix_checkpoint()
def flatten(el): def flatten(el):
flattened = [flatten(children) for children in el.children()] flattened = [flatten(children) for children in el.children()]
...@@ -87,15 +120,22 @@ class StableDiffusionModelHijack: ...@@ -87,15 +120,22 @@ class StableDiffusionModelHijack:
self.layers = flatten(m) self.layers = flatten(m)
def undo_hijack(self, m): def undo_hijack(self, m):
if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
m.cond_stage_model = m.cond_stage_model.wrapped
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped m.cond_stage_model = m.cond_stage_model.wrapped
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes: if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
m.cond_stage_model = m.cond_stage_model.wrapped
self.apply_circular(False)
self.layers = None self.layers = None
self.circular_enabled = False
self.clip = None self.clip = None
def apply_circular(self, enable): def apply_circular(self, enable):
...@@ -112,261 +152,8 @@ class StableDiffusionModelHijack: ...@@ -112,261 +152,8 @@ class StableDiffusionModelHijack:
def tokenize(self, text): def tokenize(self, text):
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer
self.token_mults = {}
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
for c in text:
if c == '[':
mult /= 1.1
if c == ']':
mult *= 1.1
if c == '(':
mult *= 1.1
if c == ')':
mult /= 1.1
if mult != 1.0:
self.token_mults[ident] = mult
def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_end = self.wrapped.tokenizer.eos_token_id
if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(line)
else:
parsed = [[line, 1.0]]
tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
fixes = []
remade_tokens = []
multipliers = []
last_comma = -1
for tokens, (text, weight) in zip(tokenized, parsed):
i = 0
while i < len(tokens):
token = tokens[i]
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
if token == self.comma_token: return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
last_comma = len(remade_tokens)
elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
last_comma += 1
reloc_tokens = remade_tokens[last_comma:]
reloc_mults = multipliers[last_comma:]
remade_tokens = remade_tokens[:last_comma]
length = len(remade_tokens)
rem = int(math.ceil(length / 75)) * 75 - length
remade_tokens += [id_end] * rem + reloc_tokens
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
i += 1
else:
emb_len = int(embedding.vec.shape[0])
iteration = len(remade_tokens) // 75
if (len(remade_tokens) + emb_len) // 75 != iteration:
rem = (75 * (iteration + 1) - len(remade_tokens))
remade_tokens += [id_end] * rem
multipliers += [1.0] * rem
iteration += 1
fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
i += embedding_length_in_tokens
token_count = len(remade_tokens)
prompt_target_length = get_target_prompt_token_count(token_count)
tokens_to_add = prompt_target_length - len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * tokens_to_add
multipliers = multipliers + [1.0] * tokens_to_add
return remade_tokens, fixes, multipliers, token_count
def process_text(self, texts):
used_custom_terms = []
remade_batch_tokens = []
hijack_comments = []
hijack_fixes = []
token_count = 0
cache = {}
batch_multipliers = []
for line in texts:
if line in cache:
remade_tokens, fixes, multipliers = cache[line]
else:
remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
token_count = max(current_token_count, token_count)
cache[line] = (remade_tokens, fixes, multipliers)
remade_batch_tokens.append(remade_tokens)
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
maxlen = self.wrapped.max_length # you get to stay at 77
used_custom_terms = []
remade_batch_tokens = []
overflowing_words = []
hijack_comments = []
hijack_fixes = []
token_count = 0
cache = {}
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
batch_multipliers = []
for tokens in batch_tokens:
tuple_tokens = tuple(tokens)
if tuple_tokens in cache:
remade_tokens, fixes, multipliers = cache[tuple_tokens]
else:
fixes = []
remade_tokens = []
multipliers = []
mult = 1.0
i = 0
while i < len(tokens):
token = tokens[i]
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None:
mult *= mult_change
i += 1
elif embedding is None:
remade_tokens.append(token)
multipliers.append(mult)
i += 1
else:
emb_len = int(embedding.vec.shape[0])
fixes.append((len(remade_tokens), embedding))
remade_tokens += [0] * emb_len
multipliers += [mult] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
i += embedding_length_in_tokens
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
ovf = remade_tokens[maxlen - 2:]
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
remade_batch_tokens.append(remade_tokens)
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def forward(self, text):
use_old = opts.use_old_emphasis_implementation
if use_old:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
if use_old:
self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers)
z = None
i = 0
while max(map(len, remade_batch_tokens)) != 0:
rem_tokens = [x[75:] for x in remade_batch_tokens]
rem_multipliers = [x[75:] for x in batch_multipliers]
self.hijack.fixes = []
for unfiltered in hijack_fixes:
fixes = []
for fix in unfiltered:
if fix[0] == i:
fixes.append(fix[1])
self.hijack.fixes.append(fixes)
tokens = []
multipliers = []
for j in range(len(remade_batch_tokens)):
if len(remade_batch_tokens[j]) > 0:
tokens.append(remade_batch_tokens[j][:75])
multipliers.append(batch_multipliers[j][:75])
else:
tokens.append([self.wrapped.tokenizer.eos_token_id] * 75)
multipliers.append([1.0] * 75)
z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1
return z
def process_tokens(self, remade_batch_tokens, batch_multipliers):
if not opts.use_old_emphasis_implementation:
remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
tokens = torch.asarray(remade_batch_tokens).to(device)
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
if opts.CLIP_stop_at_last_layers > 1:
z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
z = self.wrapped.transformer.text_model.final_layer_norm(z)
else:
z = outputs.last_hidden_state
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
z *= original_mean / new_mean
return z
class EmbeddingsWithFixes(torch.nn.Module): class EmbeddingsWithFixes(torch.nn.Module):
...@@ -406,3 +193,19 @@ def add_circular_option_to_conv_2d(): ...@@ -406,3 +193,19 @@ def add_circular_option_to_conv_2d():
model_hijack = StableDiffusionModelHijack() model_hijack = StableDiffusionModelHijack()
def register_buffer(self, name, attr):
"""
Fix register buffer bug for Mac OS.
"""
if type(attr) == torch.Tensor:
if attr.device != devices.device:
attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
setattr(self, name, attr)
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
from torch.utils.checkpoint import checkpoint
def BasicTransformerBlock_forward(self, x, context=None):
return checkpoint(self._forward, x, context)
def AttentionBlock_forward(self, x):
return checkpoint(self._forward, x)
def ResBlock_forward(self, x, emb):
return checkpoint(self._forward, x, emb)
\ No newline at end of file
import math
import torch
from modules import prompt_parser, devices
from modules.shared import opts
def get_target_prompt_token_count(token_count):
return math.ceil(max(token_count, 1) / 75) * 75
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
self.hijack = hijack
def tokenize(self, texts):
raise NotImplementedError
def encode_with_transformers(self, tokens):
raise NotImplementedError
def encode_embedding_init_text(self, init_text, nvpt):
raise NotImplementedError
def tokenize_line(self, line, used_custom_terms, hijack_comments):
if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(line)
else:
parsed = [[line, 1.0]]
tokenized = self.tokenize([text for text, _ in parsed])
fixes = []
remade_tokens = []
multipliers = []
last_comma = -1
for tokens, (text, weight) in zip(tokenized, parsed):
i = 0
while i < len(tokens):
token = tokens[i]
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
if token == self.comma_token:
last_comma = len(remade_tokens)
elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
last_comma += 1
reloc_tokens = remade_tokens[last_comma:]
reloc_mults = multipliers[last_comma:]
remade_tokens = remade_tokens[:last_comma]
length = len(remade_tokens)
rem = int(math.ceil(length / 75)) * 75 - length
remade_tokens += [self.id_end] * rem + reloc_tokens
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
i += 1
else:
emb_len = int(embedding.vec.shape[0])
iteration = len(remade_tokens) // 75
if (len(remade_tokens) + emb_len) // 75 != iteration:
rem = (75 * (iteration + 1) - len(remade_tokens))
remade_tokens += [self.id_end] * rem
multipliers += [1.0] * rem
iteration += 1
fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
i += embedding_length_in_tokens
token_count = len(remade_tokens)
prompt_target_length = get_target_prompt_token_count(token_count)
tokens_to_add = prompt_target_length - len(remade_tokens)
remade_tokens = remade_tokens + [self.id_end] * tokens_to_add
multipliers = multipliers + [1.0] * tokens_to_add
return remade_tokens, fixes, multipliers, token_count
def process_text(self, texts):
used_custom_terms = []
remade_batch_tokens = []
hijack_comments = []
hijack_fixes = []
token_count = 0
cache = {}
batch_multipliers = []
for line in texts:
if line in cache:
remade_tokens, fixes, multipliers = cache[line]
else:
remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
token_count = max(current_token_count, token_count)
cache[line] = (remade_tokens, fixes, multipliers)
remade_batch_tokens.append(remade_tokens)
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def process_text_old(self, texts):
id_start = self.id_start
id_end = self.id_end
maxlen = self.wrapped.max_length # you get to stay at 77
used_custom_terms = []
remade_batch_tokens = []
hijack_comments = []
hijack_fixes = []
token_count = 0
cache = {}
batch_tokens = self.tokenize(texts)
batch_multipliers = []
for tokens in batch_tokens:
tuple_tokens = tuple(tokens)
if tuple_tokens in cache:
remade_tokens, fixes, multipliers = cache[tuple_tokens]
else:
fixes = []
remade_tokens = []
multipliers = []
mult = 1.0
i = 0
while i < len(tokens):
token = tokens[i]
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None:
mult *= mult_change
i += 1
elif embedding is None:
remade_tokens.append(token)
multipliers.append(mult)
i += 1
else:
emb_len = int(embedding.vec.shape[0])
fixes.append((len(remade_tokens), embedding))
remade_tokens += [0] * emb_len
multipliers += [mult] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
i += embedding_length_in_tokens
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
ovf = remade_tokens[maxlen - 2:]
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
remade_batch_tokens.append(remade_tokens)
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def forward(self, text):
use_old = opts.use_old_emphasis_implementation
if use_old:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
if use_old:
self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers)
z = None
i = 0
while max(map(len, remade_batch_tokens)) != 0:
rem_tokens = [x[75:] for x in remade_batch_tokens]
rem_multipliers = [x[75:] for x in batch_multipliers]
self.hijack.fixes = []
for unfiltered in hijack_fixes:
fixes = []
for fix in unfiltered:
if fix[0] == i:
fixes.append(fix[1])
self.hijack.fixes.append(fixes)
tokens = []
multipliers = []
for j in range(len(remade_batch_tokens)):
if len(remade_batch_tokens[j]) > 0:
tokens.append(remade_batch_tokens[j][:75])
multipliers.append(batch_multipliers[j][:75])
else:
tokens.append([self.id_end] * 75)
multipliers.append([1.0] * 75)
z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1
return z
def process_tokens(self, remade_batch_tokens, batch_multipliers):
if not opts.use_old_emphasis_implementation:
remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens]
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
tokens = torch.asarray(remade_batch_tokens).to(devices.device)
if self.id_end != self.id_pad:
for batch_pos in range(len(remade_batch_tokens)):
index = remade_batch_tokens[batch_pos].index(self.id_end)
tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad
z = self.encode_with_transformers(tokens)
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
z *= original_mean / new_mean
return z
class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
def __init__(self, wrapped, hijack):
super().__init__(wrapped, hijack)
self.tokenizer = wrapped.tokenizer
vocab = self.tokenizer.get_vocab()
self.comma_token = vocab.get(',</w>', None)
self.token_mults = {}
tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
for c in text:
if c == '[':
mult /= 1.1
if c == ']':
mult *= 1.1
if c == '(':
mult *= 1.1
if c == ')':
mult /= 1.1
if mult != 1.0:
self.token_mults[ident] = mult
self.id_start = self.wrapped.tokenizer.bos_token_id
self.id_end = self.wrapped.tokenizer.eos_token_id
self.id_pad = self.id_end
def tokenize(self, texts):
tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
return tokenized
def encode_with_transformers(self, tokens):
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
if opts.CLIP_stop_at_last_layers > 1:
z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
z = self.wrapped.transformer.text_model.final_layer_norm(z)
else:
z = outputs.last_hidden_state
return z
def encode_embedding_init_text(self, init_text, nvpt):
embedding_layer = self.wrapped.transformer.text_model.embeddings
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
return embedded
import os
import torch import torch
from einops import repeat from einops import repeat
...@@ -11,196 +12,11 @@ from ldm.models.diffusion.ddpm import LatentDiffusion ...@@ -11,196 +12,11 @@ from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler, noise_like from ldm.models.diffusion.ddim import DDIMSampler, noise_like
# =================================================================================================
# Monkey patch DDIMSampler methods from RunwayML repo directly.
# Adapted from:
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
# =================================================================================================
@torch.no_grad()
def sample_ddim(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [
torch.cat([unconditional_conditioning[k][i], c[k][i]])
for i in range(len(c[k]))
]
else:
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
else:
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
# =================================================================================================
# Monkey patch PLMSSampler methods.
# This one was not actually patched correctly in the RunwayML repo, but we can replicate the changes.
# Adapted from:
# https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/plms.py
# =================================================================================================
@torch.no_grad()
def sample_plms(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for PLMS sampling is {size}')
samples, intermediates = self.plms_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates
@torch.no_grad() @torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
def get_model_output(x, t): def get_model_output(x, t):
...@@ -249,6 +65,8 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F ...@@ -249,6 +65,8 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised: if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
# direction pointing to x_t # direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
...@@ -277,55 +95,17 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F ...@@ -277,55 +95,17 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
return x_prev, pred_x0, e_t return x_prev, pred_x0, e_t
# =================================================================================================
# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
# Adapted from:
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py
# =================================================================================================
@torch.no_grad() def should_hijack_inpainting(checkpoint_info):
def get_unconditional_conditioning(self, batch_size, null_label=None): from modules import sd_models
if null_label is not None:
xc = null_label
if isinstance(xc, ListConfig):
xc = list(xc)
if isinstance(xc, dict) or isinstance(xc, list):
c = self.get_learned_conditioning(xc)
else:
if hasattr(xc, "to"):
xc = xc.to(self.device)
c = self.get_learned_conditioning(xc)
else:
# todo: get null label from cond_stage_model
raise NotImplementedError()
c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
return c
class LatentInpaintDiffusion(LatentDiffusion):
def __init__(
self,
concat_keys=("mask", "masked_image"),
masked_image_key="masked_image",
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.masked_image_key = masked_image_key
assert self.masked_image_key in concat_keys
self.concat_keys = concat_keys
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
def should_hijack_inpainting(checkpoint_info): return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml")
def do_inpainting_hijack(): def do_inpainting_hijack():
ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning # p_sample_plms is needed because PLMS can't work with dicts as conditionings
ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
\ No newline at end of file
import open_clip.tokenizer
import torch
from modules import sd_hijack_clip, devices
from modules.shared import opts
tokenizer = open_clip.tokenizer._tokenizer
class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
def __init__(self, wrapped, hijack):
super().__init__(wrapped, hijack)
self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
self.id_start = tokenizer.encoder["<start_of_text>"]
self.id_end = tokenizer.encoder["<end_of_text>"]
self.id_pad = 0
def tokenize(self, texts):
assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
tokenized = [tokenizer.encode(text) for text in texts]
return tokenized
def encode_with_transformers(self, tokens):
# set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
z = self.wrapped.encode_with_transformer(tokens)
return z
def encode_embedding_init_text(self, init_text, nvpt):
ids = tokenizer.encode(init_text)
ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
return embedded
...@@ -127,7 +127,7 @@ def check_for_psutil(): ...@@ -127,7 +127,7 @@ def check_for_psutil():
invokeAI_mps_available = check_for_psutil() invokeAI_mps_available = check_for_psutil()
# -- Taken from https://github.com/invoke-ai/InvokeAI -- # -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
if invokeAI_mps_available: if invokeAI_mps_available:
import psutil import psutil
mem_total_gb = psutil.virtual_memory().total // (1 << 30) mem_total_gb = psutil.virtual_memory().total // (1 << 30)
...@@ -152,14 +152,16 @@ def einsum_op_slice_1(q, k, v, slice_size): ...@@ -152,14 +152,16 @@ def einsum_op_slice_1(q, k, v, slice_size):
return r return r
def einsum_op_mps_v1(q, k, v): def einsum_op_mps_v1(q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
return einsum_op_compvis(q, k, v) return einsum_op_compvis(q, k, v)
else: else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
if slice_size % 4096 == 0:
slice_size -= 1
return einsum_op_slice_1(q, k, v, slice_size) return einsum_op_slice_1(q, k, v, slice_size)
def einsum_op_mps_v2(q, k, v): def einsum_op_mps_v2(q, k, v):
if mem_total_gb > 8 and q.shape[1] <= 4096: if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
return einsum_op_compvis(q, k, v) return einsum_op_compvis(q, k, v)
else: else:
return einsum_op_slice_0(q, k, v, 1) return einsum_op_slice_0(q, k, v, 1)
...@@ -188,7 +190,7 @@ def einsum_op(q, k, v): ...@@ -188,7 +190,7 @@ def einsum_op(q, k, v):
return einsum_op_cuda(q, k, v) return einsum_op_cuda(q, k, v)
if q.device.type == 'mps': if q.device.type == 'mps':
if mem_total_gb >= 32: if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
return einsum_op_mps_v1(q, k, v) return einsum_op_mps_v1(q, k, v)
return einsum_op_mps_v2(q, k, v) return einsum_op_mps_v2(q, k, v)
......
import torch
class TorchHijackForUnet:
"""
This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
"""
def __getattr__(self, item):
if item == 'cat':
return self.cat
if hasattr(torch, item):
return getattr(torch, item)
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
def cat(self, tensors, *args, **kwargs):
if len(tensors) == 2:
a, b = tensors
if a.shape[-2:] != b.shape[-2:]:
a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
tensors = (a, b)
return torch.cat(tensors, *args, **kwargs)
th = TorchHijackForUnet()
import open_clip.tokenizer
import torch
from modules import sd_hijack_clip, devices
from modules.shared import opts
class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
def __init__(self, wrapped, hijack):
super().__init__(wrapped, hijack)
self.id_start = wrapped.config.bos_token_id
self.id_end = wrapped.config.eos_token_id
self.id_pad = wrapped.config.pad_token_id
self.comma_token = self.tokenizer.get_vocab().get(',', None) # alt diffusion doesn't have </w> bits for comma
def encode_with_transformers(self, tokens):
# there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
# trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
# layer to work with - you have to use the last
attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
z = features['projection_state']
return z
def encode_embedding_init_text(self, init_text, nvpt):
embedding_layer = self.wrapped.roberta.embeddings
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
return embedded
...@@ -5,7 +5,11 @@ import gc ...@@ -5,7 +5,11 @@ import gc
from collections import namedtuple from collections import namedtuple
import torch import torch
import re import re
import safetensors.torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
from os import mkdir
from urllib import request
import ldm.modules.midas as midas
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
...@@ -16,7 +20,7 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp ...@@ -16,7 +20,7 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp
model_dir = "Stable-diffusion" model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir)) model_path = os.path.abspath(os.path.join(models_path, model_dir))
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config']) CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {} checkpoints_list = {}
checkpoints_loaded = collections.OrderedDict() checkpoints_loaded = collections.OrderedDict()
...@@ -35,6 +39,7 @@ def setup_model(): ...@@ -35,6 +39,7 @@ def setup_model():
os.makedirs(model_path) os.makedirs(model_path)
list_models() list_models()
enable_midas_autodownload()
def checkpoint_tiles(): def checkpoint_tiles():
...@@ -43,9 +48,17 @@ def checkpoint_tiles(): ...@@ -43,9 +48,17 @@ def checkpoint_tiles():
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key) return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
def find_checkpoint_config(info):
config = os.path.splitext(info.filename)[0] + ".yaml"
if os.path.exists(config):
return config
return shared.cmd_opts.config
def list_models(): def list_models():
checkpoints_list.clear() checkpoints_list.clear()
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"]) model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
def modeltitle(path, shorthash): def modeltitle(path, shorthash):
abspath = os.path.abspath(path) abspath = os.path.abspath(path)
...@@ -68,7 +81,7 @@ def list_models(): ...@@ -68,7 +81,7 @@ def list_models():
if os.path.exists(cmd_ckpt): if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt) h = model_hash(cmd_ckpt)
title, short_model_name = modeltitle(cmd_ckpt, h) title, short_model_name = modeltitle(cmd_ckpt, h)
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config) checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
shared.opts.data['sd_model_checkpoint'] = title shared.opts.data['sd_model_checkpoint'] = title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
...@@ -76,12 +89,7 @@ def list_models(): ...@@ -76,12 +89,7 @@ def list_models():
h = model_hash(filename) h = model_hash(filename)
title, short_model_name = modeltitle(filename, h) title, short_model_name = modeltitle(filename, h)
basename, _ = os.path.splitext(filename) checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
config = basename + ".yaml"
if not os.path.exists(config):
config = shared.cmd_opts.config
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
def get_closet_checkpoint_match(searchString): def get_closet_checkpoint_match(searchString):
...@@ -106,18 +114,19 @@ def model_hash(filename): ...@@ -106,18 +114,19 @@ def model_hash(filename):
def select_checkpoint(): def select_checkpoint():
model_checkpoint = shared.opts.sd_model_checkpoint model_checkpoint = shared.opts.sd_model_checkpoint
checkpoint_info = checkpoints_list.get(model_checkpoint, None) checkpoint_info = checkpoints_list.get(model_checkpoint, None)
if checkpoint_info is not None: if checkpoint_info is not None:
return checkpoint_info return checkpoint_info
if len(checkpoints_list) == 0: if len(checkpoints_list) == 0:
print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr) print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
if shared.cmd_opts.ckpt is not None: if shared.cmd_opts.ckpt is not None:
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr) print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
print(f" - directory {model_path}", file=sys.stderr) print(f" - directory {model_path}", file=sys.stderr)
if shared.cmd_opts.ckpt_dir is not None: if shared.cmd_opts.ckpt_dir is not None:
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr) print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr) print("Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
exit(1) exit(1)
checkpoint_info = next(iter(checkpoints_list.values())) checkpoint_info = next(iter(checkpoints_list.values()))
...@@ -143,8 +152,8 @@ def transform_checkpoint_dict_key(k): ...@@ -143,8 +152,8 @@ def transform_checkpoint_dict_key(k):
def get_state_dict_from_checkpoint(pl_sd): def get_state_dict_from_checkpoint(pl_sd):
if "state_dict" in pl_sd: pl_sd = pl_sd.pop("state_dict", pl_sd)
pl_sd = pl_sd["state_dict"] pl_sd.pop("state_dict", None)
sd = {} sd = {}
for k, v in pl_sd.items(): for k, v in pl_sd.items():
...@@ -159,28 +168,45 @@ def get_state_dict_from_checkpoint(pl_sd): ...@@ -159,28 +168,45 @@ def get_state_dict_from_checkpoint(pl_sd):
return pl_sd return pl_sd
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
device = map_location or shared.weight_load_location
if device is None:
device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
if print_global_state and "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = get_state_dict_from_checkpoint(pl_sd)
return sd
def load_model_weights(model, checkpoint_info, vae_file="auto"): def load_model_weights(model, checkpoint_info, vae_file="auto"):
checkpoint_file = checkpoint_info.filename checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash sd_model_hash = checkpoint_info.hash
if shared.opts.sd_checkpoint_cache > 0 and hasattr(model, "sd_checkpoint_info"): cache_enabled = shared.opts.sd_checkpoint_cache > 0
sd_vae.restore_base_vae(model)
checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy()
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) if cache_enabled and checkpoint_info in checkpoints_loaded:
# use checkpoint cache
if checkpoint_info not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from cache")
model.load_state_dict(checkpoints_loaded[checkpoint_info])
else:
# load from file
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) sd = read_state_dict(checkpoint_file)
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = get_state_dict_from_checkpoint(pl_sd)
del pl_sd
model.load_state_dict(sd, strict=False) model.load_state_dict(sd, strict=False)
del sd del sd
if cache_enabled:
# cache newly loaded model
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
if shared.cmd_opts.opt_channelslast: if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last) model.to(memory_format=torch.channels_last)
...@@ -199,29 +225,73 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): ...@@ -199,29 +225,73 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
model.first_stage_model.to(devices.dtype_vae) model.first_stage_model.to(devices.dtype_vae)
else: # clean up cache if limit is reached
vae_name = sd_vae.get_filename(vae_file) if vae_file else None if cache_enabled:
vae_message = f" with {vae_name} VAE" if vae_name else "" while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
model.load_state_dict(checkpoints_loaded[checkpoint_info])
if shared.opts.sd_checkpoint_cache > 0:
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False) # LRU checkpoints_loaded.popitem(last=False) # LRU
model.sd_model_hash = sd_model_hash model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_file model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info model.sd_checkpoint_info = checkpoint_info
model.logvar = model.logvar.to(devices.device) # fix for training
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
sd_vae.load_vae(model, vae_file) sd_vae.load_vae(model, vae_file)
def enable_midas_autodownload():
"""
Gives the ldm.modules.midas.api.load_model function automatic downloading.
When the 512-depth-ema model, and other future models like it, is loaded,
it calls midas.api.load_model to load the associated midas depth model.
This function applies a wrapper to download the model to the correct
location automatically.
"""
midas_path = os.path.join(models_path, 'midas')
# stable-diffusion-stability-ai hard-codes the midas model path to
# a location that differs from where other scripts using this model look.
# HACK: Overriding the path here.
for k, v in midas.api.ISL_PATHS.items():
file_name = os.path.basename(v)
midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)
midas_urls = {
"dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
"dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
"midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
"midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
}
midas.api.load_model_inner = midas.api.load_model
def load_model_wrapper(model_type):
path = midas.api.ISL_PATHS[model_type]
if not os.path.exists(path):
if not os.path.exists(midas_path):
mkdir(midas_path)
print(f"Downloading midas model weights for {model_type} to {path}")
request.urlretrieve(midas_urls[model_type], path)
print(f"{model_type} downloaded")
return midas.api.load_model_inner(model_type)
midas.api.load_model = load_model_wrapper
def load_model(checkpoint_info=None): def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint() checkpoint_info = checkpoint_info or select_checkpoint()
checkpoint_config = find_checkpoint_config(checkpoint_info)
if checkpoint_info.config != shared.cmd_opts.config: if checkpoint_config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_info.config}") print(f"Loading config from: {checkpoint_config}")
if shared.sd_model: if shared.sd_model:
sd_hijack.model_hijack.undo_hijack(shared.sd_model) sd_hijack.model_hijack.undo_hijack(shared.sd_model)
...@@ -229,21 +299,25 @@ def load_model(checkpoint_info=None): ...@@ -229,21 +299,25 @@ def load_model(checkpoint_info=None):
gc.collect() gc.collect()
devices.torch_gc() devices.torch_gc()
sd_config = OmegaConf.load(checkpoint_info.config) sd_config = OmegaConf.load(checkpoint_config)
if should_hijack_inpainting(checkpoint_info): if should_hijack_inpainting(checkpoint_info):
# Hardcoded config for now... # Hardcoded config for now...
sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
sd_config.model.params.use_ema = False
sd_config.model.params.conditioning_key = "hybrid" sd_config.model.params.conditioning_key = "hybrid"
sd_config.model.params.unet_config.params.in_channels = 9 sd_config.model.params.unet_config.params.in_channels = 9
sd_config.model.params.finetune_keys = None
# Create a "fake" config with a different name so that we know to unload it when switching models. if not hasattr(sd_config.model.params, "use_ema"):
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) sd_config.model.params.use_ema = False
do_inpainting_hijack() do_inpainting_hijack()
if shared.cmd_opts.no_half:
sd_config.model.params.unet_config.params.use_fp16 = False
sd_model = instantiate_from_config(sd_config.model) sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info) load_model_weights(sd_model, checkpoint_info)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
...@@ -256,9 +330,12 @@ def load_model(checkpoint_info=None): ...@@ -256,9 +330,12 @@ def load_model(checkpoint_info=None):
sd_model.eval() sd_model.eval()
shared.sd_model = sd_model shared.sd_model = sd_model
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
script_callbacks.model_loaded_callback(sd_model) script_callbacks.model_loaded_callback(sd_model)
print(f"Model loaded.") print("Model loaded.")
return sd_model return sd_model
...@@ -269,10 +346,13 @@ def reload_model_weights(sd_model=None, info=None): ...@@ -269,10 +346,13 @@ def reload_model_weights(sd_model=None, info=None):
if not sd_model: if not sd_model:
sd_model = shared.sd_model sd_model = shared.sd_model
current_checkpoint_info = sd_model.sd_checkpoint_info
checkpoint_config = find_checkpoint_config(current_checkpoint_info)
if sd_model.sd_model_checkpoint == checkpoint_info.filename: if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return return
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
del sd_model del sd_model
checkpoints_loaded.clear() checkpoints_loaded.clear()
load_model(checkpoint_info) load_model(checkpoint_info)
...@@ -285,13 +365,19 @@ def reload_model_weights(sd_model=None, info=None): ...@@ -285,13 +365,19 @@ def reload_model_weights(sd_model=None, info=None):
sd_hijack.model_hijack.undo_hijack(sd_model) sd_hijack.model_hijack.undo_hijack(sd_model)
try:
load_model_weights(sd_model, checkpoint_info) load_model_weights(sd_model, checkpoint_info)
except Exception as e:
print("Failed to load checkpoint, restoring previous")
load_model_weights(sd_model, current_checkpoint_info)
raise
finally:
sd_hijack.model_hijack.hijack(sd_model) sd_hijack.model_hijack.hijack(sd_model)
script_callbacks.model_loaded_callback(sd_model) script_callbacks.model_loaded_callback(sd_model)
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device) sd_model.to(devices.device)
print(f"Weights loaded.") print("Weights loaded.")
return sd_model return sd_model
from collections import namedtuple from collections import namedtuple, deque
import numpy as np import numpy as np
from math import floor from math import floor
import torch import torch
...@@ -6,9 +6,10 @@ import tqdm ...@@ -6,9 +6,10 @@ import tqdm
from PIL import Image from PIL import Image
import inspect import inspect
import k_diffusion.sampling import k_diffusion.sampling
import torchsde._brownian.brownian_interval
import ldm.models.diffusion.ddim import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms import ldm.models.diffusion.plms
from modules import prompt_parser, devices, processing, images from modules import prompt_parser, devices, processing, images, sd_vae_approx
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
...@@ -18,21 +19,23 @@ from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback ...@@ -18,21 +19,23 @@ from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
samplers_k_diffusion = [ samplers_k_diffusion = [
('Euler a', 'sample_euler_ancestral', ['k_euler_a'], {}), ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
('Euler', 'sample_euler', ['k_euler'], {}), ('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms'], {}), ('LMS', 'sample_lms', ['k_lms'], {}),
('Heun', 'sample_heun', ['k_heun'], {}), ('Heun', 'sample_heun', ['k_heun'], {}),
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}), ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}), ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}), ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}), ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}), ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}), ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}), ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}), ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
] ]
samplers_data_k_diffusion = [ samplers_data_k_diffusion = [
...@@ -46,13 +49,21 @@ all_samplers = [ ...@@ -46,13 +49,21 @@ all_samplers = [
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
] ]
all_samplers_map = {x.name: x for x in all_samplers}
samplers = [] samplers = []
samplers_for_img2img = [] samplers_for_img2img = []
samplers_map = {}
def create_sampler_with_index(list_of_configs, index, model): def create_sampler(name, model):
config = list_of_configs[index] if name is not None:
config = all_samplers_map.get(name, None)
else:
config = all_samplers[0]
assert config is not None, f'bad sampler name: {name}'
sampler = config.constructor(model) sampler = config.constructor(model)
sampler.config = config sampler.config = config
...@@ -68,6 +79,12 @@ def set_samplers(): ...@@ -68,6 +79,12 @@ def set_samplers():
samplers = [x for x in all_samplers if x.name not in hidden] samplers = [x for x in all_samplers if x.name not in hidden]
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
samplers_map.clear()
for sampler in all_samplers:
samplers_map[sampler.name.lower()] = sampler.name
for alias in sampler.aliases:
samplers_map[alias.lower()] = sampler.name
set_samplers() set_samplers()
...@@ -89,20 +106,32 @@ def setup_img2img_steps(p, steps=None): ...@@ -89,20 +106,32 @@ def setup_img2img_steps(p, steps=None):
return steps, t_enc return steps, t_enc
def single_sample_to_image(sample): approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
def single_sample_to_image(sample, approximation=None):
if approximation is None:
approximation = approximation_indexes.get(opts.show_progress_type, 0)
if approximation == 2:
x_sample = sd_vae_approx.cheap_approximation(sample)
elif approximation == 1:
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
else:
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample) return Image.fromarray(x_sample)
def sample_to_image(samples, index=0): def sample_to_image(samples, index=0, approximation=None):
return single_sample_to_image(samples[index]) return single_sample_to_image(samples[index], approximation)
def samples_to_image_grid(samples): def samples_to_image_grid(samples, approximation=None):
return images.image_grid([single_sample_to_image(sample) for sample in samples]) return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
def store_latent(decoded): def store_latent(decoded):
...@@ -120,7 +149,8 @@ class InterruptedException(BaseException): ...@@ -120,7 +149,8 @@ class InterruptedException(BaseException):
class VanillaStableDiffusionSampler: class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model): def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model) self.sampler = constructor(sd_model)
self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms self.is_plms = hasattr(self.sampler, 'p_sample_plms')
self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
self.mask = None self.mask = None
self.nmask = None self.nmask = None
self.init_latent = None self.init_latent = None
...@@ -211,7 +241,6 @@ class VanillaStableDiffusionSampler: ...@@ -211,7 +241,6 @@ class VanillaStableDiffusionSampler:
self.mask = p.mask if hasattr(p, 'mask') else None self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None self.nmask = p.nmask if hasattr(p, 'nmask') else None
def adjust_steps_if_invalid(self, p, num_steps): def adjust_steps_if_invalid(self, p, num_steps):
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
valid_step = 999 / (1000 // num_steps) valid_step = 999 / (1000 // num_steps)
...@@ -220,7 +249,6 @@ class VanillaStableDiffusionSampler: ...@@ -220,7 +249,6 @@ class VanillaStableDiffusionSampler:
return num_steps return num_steps
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps) steps, t_enc = setup_img2img_steps(p, steps)
steps = self.adjust_steps_if_invalid(p, steps) steps = self.adjust_steps_if_invalid(p, steps)
...@@ -253,9 +281,10 @@ class VanillaStableDiffusionSampler: ...@@ -253,9 +281,10 @@ class VanillaStableDiffusionSampler:
steps = self.adjust_steps_if_invalid(p, steps or p.steps) steps = self.adjust_steps_if_invalid(p, steps or p.steps)
# Wrap the conditioning models with additional image conditioning for inpainting model # Wrap the conditioning models with additional image conditioning for inpainting model
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
if image_conditioning is not None: if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
...@@ -271,6 +300,16 @@ class CFGDenoiser(torch.nn.Module): ...@@ -271,6 +300,16 @@ class CFGDenoiser(torch.nn.Module):
self.init_latent = None self.init_latent = None
self.step = 0 self.step = 0
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
for i, conds in enumerate(conds_list):
for cond_index, weight in conds:
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
return denoised
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond): def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
if state.interrupted or state.skipped: if state.interrupted or state.skipped:
raise InterruptedException raise InterruptedException
...@@ -312,12 +351,7 @@ class CFGDenoiser(torch.nn.Module): ...@@ -312,12 +351,7 @@ class CFGDenoiser(torch.nn.Module):
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
denoised_uncond = x_out[-uncond.shape[0]:] denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
denoised = torch.clone(denoised_uncond)
for i, conds in enumerate(conds_list):
for cond_index, weight in conds:
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
if self.mask is not None: if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised denoised = self.init_latent * self.mask + self.nmask * denoised
...@@ -328,28 +362,55 @@ class CFGDenoiser(torch.nn.Module): ...@@ -328,28 +362,55 @@ class CFGDenoiser(torch.nn.Module):
class TorchHijack: class TorchHijack:
def __init__(self, kdiff_sampler): def __init__(self, sampler_noises):
self.kdiff_sampler = kdiff_sampler # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
# implementation.
self.sampler_noises = deque(sampler_noises)
def __getattr__(self, item): def __getattr__(self, item):
if item == 'randn_like': if item == 'randn_like':
return self.kdiff_sampler.randn_like return self.randn_like
if hasattr(torch, item): if hasattr(torch, item):
return getattr(torch, item) return getattr(torch, item)
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
def randn_like(self, x):
if self.sampler_noises:
noise = self.sampler_noises.popleft()
if noise.shape == x.shape:
return noise
if x.device.type == 'mps':
return torch.randn_like(x, device=devices.cpu).to(x.device)
else:
return torch.randn_like(x)
# MPS fix for randn in torchsde
def torchsde_randn(size, dtype, device, seed):
if device.type == 'mps':
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
else:
generator = torch.Generator(device).manual_seed(int(seed))
return torch.randn(size, dtype=dtype, device=device, generator=generator)
torchsde._brownian.brownian_interval._randn = torchsde_randn
class KDiffusionSampler: class KDiffusionSampler:
def __init__(self, funcname, sd_model): def __init__(self, funcname, sd_model):
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization) denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
self.funcname = funcname self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname) self.func = getattr(k_diffusion.sampling, self.funcname)
self.extra_params = sampler_extra_params.get(funcname, []) self.extra_params = sampler_extra_params.get(funcname, [])
self.model_wrap_cfg = CFGDenoiser(self.model_wrap) self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
self.sampler_noises = None self.sampler_noises = None
self.sampler_noise_index = 0
self.stop_at = None self.stop_at = None
self.eta = None self.eta = None
self.default_eta = 1.0 self.default_eta = 1.0
...@@ -382,26 +443,13 @@ class KDiffusionSampler: ...@@ -382,26 +443,13 @@ class KDiffusionSampler:
def number_of_needed_noises(self, p): def number_of_needed_noises(self, p):
return p.steps return p.steps
def randn_like(self, x):
noise = self.sampler_noises[self.sampler_noise_index] if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises) else None
if noise is not None and x.shape == noise.shape:
res = noise
else:
res = torch.randn_like(x)
self.sampler_noise_index += 1
return res
def initialize(self, p): def initialize(self, p):
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
self.model_wrap.step = 0 self.model_wrap.step = 0
self.sampler_noise_index = 0
self.eta = p.eta or opts.eta_ancestral self.eta = p.eta or opts.eta_ancestral
if self.sampler_noises is not None: k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
k_diffusion.sampling.torch = TorchHijack(self)
extra_params_kwargs = {} extra_params_kwargs = {}
for param_name in self.extra_params: for param_name in self.extra_params:
...@@ -413,16 +461,26 @@ class KDiffusionSampler: ...@@ -413,16 +461,26 @@ class KDiffusionSampler:
return extra_params_kwargs return extra_params_kwargs
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): def get_sigmas(self, p, steps):
steps, t_enc = setup_img2img_steps(p, steps)
if p.sampler_noise_scheduler_override: if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps) sigmas = p.sampler_noise_scheduler_override(steps)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device) sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
else: else:
sigmas = self.model_wrap.get_sigmas(steps) sigmas = self.model_wrap.get_sigmas(steps)
if self.config is not None and self.config.options.get('discard_next_to_last_sigma', False):
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
return sigmas
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
sigmas = self.get_sigmas(p, steps)
sigma_sched = sigmas[steps - t_enc - 1:] sigma_sched = sigmas[steps - t_enc - 1:]
xi = x + noise * sigma_sched[0] xi = x + noise * sigma_sched[0]
...@@ -454,12 +512,7 @@ class KDiffusionSampler: ...@@ -454,12 +512,7 @@ class KDiffusionSampler:
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None): def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
steps = steps or p.steps steps = steps or p.steps
if p.sampler_noise_scheduler_override: sigmas = self.get_sigmas(p, steps)
sigmas = p.sampler_noise_scheduler_override(steps)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
else:
sigmas = self.model_wrap.get_sigmas(steps)
x = x * sigmas[0] x = x * sigmas[0]
......
import torch import torch
import os import os
import collections
from collections import namedtuple from collections import namedtuple
from modules import shared, devices, script_callbacks from modules import shared, devices, script_callbacks
from modules.paths import models_path from modules.paths import models_path
import glob import glob
from copy import deepcopy
model_dir = "Stable-diffusion" model_dir = "Stable-diffusion"
...@@ -15,7 +17,7 @@ vae_path = os.path.abspath(os.path.join(models_path, vae_dir)) ...@@ -15,7 +17,7 @@ vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
default_vae_dict = {"auto": "auto", "None": "None"} default_vae_dict = {"auto": "auto", "None": None, None: None}
default_vae_list = ["auto", "None"] default_vae_list = ["auto", "None"]
...@@ -29,6 +31,7 @@ base_vae = None ...@@ -29,6 +31,7 @@ base_vae = None
loaded_vae_file = None loaded_vae_file = None
checkpoint_info = None checkpoint_info = None
checkpoints_loaded = collections.OrderedDict()
def get_base_vae(model): def get_base_vae(model):
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
...@@ -39,7 +42,8 @@ def get_base_vae(model): ...@@ -39,7 +42,8 @@ def get_base_vae(model):
def store_base_vae(model): def store_base_vae(model):
global base_vae, checkpoint_info global base_vae, checkpoint_info
if checkpoint_info != model.sd_checkpoint_info: if checkpoint_info != model.sd_checkpoint_info:
base_vae = model.first_stage_model.state_dict().copy() assert not loaded_vae_file, "Trying to store non-base VAE!"
base_vae = deepcopy(model.first_stage_model.state_dict())
checkpoint_info = model.sd_checkpoint_info checkpoint_info = model.sd_checkpoint_info
...@@ -50,9 +54,11 @@ def delete_base_vae(): ...@@ -50,9 +54,11 @@ def delete_base_vae():
def restore_base_vae(model): def restore_base_vae(model):
global base_vae, checkpoint_info global loaded_vae_file
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info: if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
load_vae_dict(model, base_vae) print("Restoring base VAE")
_load_vae_dict(model, base_vae)
loaded_vae_file = None
delete_base_vae() delete_base_vae()
...@@ -83,47 +89,54 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): ...@@ -83,47 +89,54 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path):
return vae_list return vae_list
def resolve_vae(checkpoint_file, vae_file="auto"): def get_vae_from_settings(vae_file="auto"):
# else, we load from settings, if not set to be default
if vae_file == "auto" and shared.opts.sd_vae is not None:
# if saved VAE settings isn't recognized, fallback to auto
vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
# if VAE selected but not found, fallback to auto
if vae_file not in default_vae_values and not os.path.isfile(vae_file):
vae_file = "auto"
print(f"Selected VAE doesn't exist: {vae_file}")
return vae_file
def resolve_vae(checkpoint_file=None, vae_file="auto"):
global first_load, vae_dict, vae_list global first_load, vae_dict, vae_list
# if vae_file argument is provided, it takes priority, but not saved # if vae_file argument is provided, it takes priority, but not saved
if vae_file and vae_file not in default_vae_list: if vae_file and vae_file not in default_vae_list:
if not os.path.isfile(vae_file): if not os.path.isfile(vae_file):
print(f"VAE provided as function argument doesn't exist: {vae_file}")
vae_file = "auto" vae_file = "auto"
print("VAE provided as function argument doesn't exist")
# for the first load, if vae-path is provided, it takes priority, saved, and failure is reported # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
if first_load and shared.cmd_opts.vae_path is not None: if first_load and shared.cmd_opts.vae_path is not None:
if os.path.isfile(shared.cmd_opts.vae_path): if os.path.isfile(shared.cmd_opts.vae_path):
vae_file = shared.cmd_opts.vae_path vae_file = shared.cmd_opts.vae_path
shared.opts.data['sd_vae'] = get_filename(vae_file) shared.opts.data['sd_vae'] = get_filename(vae_file)
else: else:
print("VAE provided as command line argument doesn't exist") print(f"VAE provided as command line argument doesn't exist: {vae_file}")
# else, we load from settings # fallback to selector in settings, if vae selector not set to act as default fallback
if vae_file == "auto" and shared.opts.sd_vae is not None: if not shared.opts.sd_vae_as_default:
# if saved VAE settings isn't recognized, fallback to auto vae_file = get_vae_from_settings(vae_file)
vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
# if VAE selected but not found, fallback to auto
if vae_file not in default_vae_values and not os.path.isfile(vae_file):
vae_file = "auto"
print("Selected VAE doesn't exist")
# vae-path cmd arg takes priority for auto # vae-path cmd arg takes priority for auto
if vae_file == "auto" and shared.cmd_opts.vae_path is not None: if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
if os.path.isfile(shared.cmd_opts.vae_path): if os.path.isfile(shared.cmd_opts.vae_path):
vae_file = shared.cmd_opts.vae_path vae_file = shared.cmd_opts.vae_path
print("Using VAE provided as command line argument") print(f"Using VAE provided as command line argument: {vae_file}")
# if still not found, try look for ".vae.pt" beside model # if still not found, try look for ".vae.pt" beside model
model_path = os.path.splitext(checkpoint_file)[0] model_path = os.path.splitext(checkpoint_file)[0]
if vae_file == "auto": if vae_file == "auto":
vae_file_try = model_path + ".vae.pt" vae_file_try = model_path + ".vae.pt"
if os.path.isfile(vae_file_try): if os.path.isfile(vae_file_try):
vae_file = vae_file_try vae_file = vae_file_try
print("Using VAE found beside selected model") print(f"Using VAE found similar to selected model: {vae_file}")
# if still not found, try look for ".vae.ckpt" beside model # if still not found, try look for ".vae.ckpt" beside model
if vae_file == "auto": if vae_file == "auto":
vae_file_try = model_path + ".vae.ckpt" vae_file_try = model_path + ".vae.ckpt"
if os.path.isfile(vae_file_try): if os.path.isfile(vae_file_try):
vae_file = vae_file_try vae_file = vae_file_try
print("Using VAE found beside selected model") print(f"Using VAE found similar to selected model: {vae_file}")
# No more fallbacks for auto # No more fallbacks for auto
if vae_file == "auto": if vae_file == "auto":
vae_file = None vae_file = None
...@@ -138,11 +151,30 @@ def load_vae(model, vae_file=None): ...@@ -138,11 +151,30 @@ def load_vae(model, vae_file=None):
global first_load, vae_dict, vae_list, loaded_vae_file global first_load, vae_dict, vae_list, loaded_vae_file
# save_settings = False # save_settings = False
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
if vae_file: if vae_file:
if cache_enabled and vae_file in checkpoints_loaded:
# use vae checkpoint cache
print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
store_base_vae(model)
_load_vae_dict(model, checkpoints_loaded[vae_file])
else:
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
print(f"Loading VAE weights from: {vae_file}") print(f"Loading VAE weights from: {vae_file}")
store_base_vae(model)
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
load_vae_dict(model, vae_dict_1) _load_vae_dict(model, vae_dict_1)
if cache_enabled:
# cache newly loaded vae
checkpoints_loaded[vae_file] = vae_dict_1.copy()
# clean up cache if limit is reached
if cache_enabled:
while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
checkpoints_loaded.popitem(last=False) # LRU
# If vae used is not in dict, update it # If vae used is not in dict, update it
# It will be removed on refresh though # It will be removed on refresh though
...@@ -150,30 +182,22 @@ def load_vae(model, vae_file=None): ...@@ -150,30 +182,22 @@ def load_vae(model, vae_file=None):
if vae_opt not in vae_dict: if vae_opt not in vae_dict:
vae_dict[vae_opt] = vae_file vae_dict[vae_opt] = vae_file
vae_list.append(vae_opt) vae_list.append(vae_opt)
elif loaded_vae_file:
restore_base_vae(model)
loaded_vae_file = vae_file loaded_vae_file = vae_file
"""
# Save current VAE to VAE settings, maybe? will it work?
if save_settings:
if vae_file is None:
vae_opt = "None"
# shared.opts.sd_vae = vae_opt
"""
first_load = False first_load = False
# don't call this from outside # don't call this from outside
def load_vae_dict(model, vae_dict_1=None): def _load_vae_dict(model, vae_dict_1):
if vae_dict_1:
store_base_vae(model)
model.first_stage_model.load_state_dict(vae_dict_1) model.first_stage_model.load_state_dict(vae_dict_1)
else:
restore_base_vae()
model.first_stage_model.to(devices.dtype_vae) model.first_stage_model.to(devices.dtype_vae)
def clear_loaded_vae():
global loaded_vae_file
loaded_vae_file = None
def reload_vae_weights(sd_model=None, vae_file="auto"): def reload_vae_weights(sd_model=None, vae_file="auto"):
from modules import lowvram, devices, sd_hijack from modules import lowvram, devices, sd_hijack
...@@ -203,5 +227,5 @@ def reload_vae_weights(sd_model=None, vae_file="auto"): ...@@ -203,5 +227,5 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device) sd_model.to(devices.device)
print(f"VAE Weights loaded.") print("VAE Weights loaded.")
return sd_model return sd_model
import os
import torch
from torch import nn
from modules import devices, paths
sd_vae_approx_model = None
class VAEApprox(nn.Module):
def __init__(self):
super(VAEApprox, self).__init__()
self.conv1 = nn.Conv2d(4, 8, (7, 7))
self.conv2 = nn.Conv2d(8, 16, (5, 5))
self.conv3 = nn.Conv2d(16, 32, (3, 3))
self.conv4 = nn.Conv2d(32, 64, (3, 3))
self.conv5 = nn.Conv2d(64, 32, (3, 3))
self.conv6 = nn.Conv2d(32, 16, (3, 3))
self.conv7 = nn.Conv2d(16, 8, (3, 3))
self.conv8 = nn.Conv2d(8, 3, (3, 3))
def forward(self, x):
extra = 11
x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
x = nn.functional.pad(x, (extra, extra, extra, extra))
for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:
x = layer(x)
x = nn.functional.leaky_relu(x, 0.1)
return x
def model():
global sd_vae_approx_model
if sd_vae_approx_model is None:
sd_vae_approx_model = VAEApprox()
sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt")))
sd_vae_approx_model.eval()
sd_vae_approx_model.to(devices.device, devices.dtype)
return sd_vae_approx_model
def cheap_approximation(sample):
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
coefs = torch.tensor([
[0.298, 0.207, 0.208],
[0.187, 0.286, 0.173],
[-0.158, 0.189, 0.264],
[-0.184, -0.271, -0.473],
]).to(sample.device)
x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
return x_sample
...@@ -3,26 +3,27 @@ import datetime ...@@ -3,26 +3,27 @@ import datetime
import json import json
import os import os
import sys import sys
from collections import OrderedDict
import time import time
from PIL import Image
import gradio as gr import gradio as gr
import tqdm import tqdm
import modules.artists import modules.artists
import modules.interrogate import modules.interrogate
import modules.memmon import modules.memmon
import modules.sd_models
import modules.styles import modules.styles
import modules.devices as devices import modules.devices as devices
from modules import sd_samplers, sd_models, localization, sd_vae from modules import localization, sd_vae, extensions, script_loading, errors
from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path from modules.paths import models_path, script_path, sd_path
demo = None
sd_model_file = os.path.join(script_path, 'model.ckpt') sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file default_sd_model_file = sd_model_file
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
...@@ -50,18 +51,15 @@ parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory wi ...@@ -50,18 +51,15 @@ parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory wi
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN')) parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN')) parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN')) parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None) parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator") parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower) parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
...@@ -72,6 +70,7 @@ parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui ...@@ -72,6 +70,7 @@ parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor") parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor")
parser.add_argument("--gradio-inpaint-tool", type=str, choices=["sketch", "color-sketch"], default="sketch", help="gradio inpainting editor: can be either sketch to only blur/noise the input, or color-sketch to paint over it")
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
...@@ -81,17 +80,24 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help= ...@@ -81,17 +80,24 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help=
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui") parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False) parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origins", default=None) parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None) parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
script_loading.preload_extensions(extensions.extensions_dir, parser)
script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
cmd_opts = parser.parse_args() cmd_opts = parser.parse_args()
restricted_opts = { restricted_opts = {
"samples_filename_pattern", "samples_filename_pattern",
"directories_filename_pattern", "directories_filename_pattern",
...@@ -104,10 +110,21 @@ restricted_opts = { ...@@ -104,10 +110,21 @@ restricted_opts = {
"outdir_save", "outdir_save",
} }
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen) and not cmd_opts.enable_insecure_extension_access ui_reorder_categories = [
"sampler",
"dimensions",
"cfg",
"seed",
"checkboxes",
"hires_fix",
"batch",
"scripts",
]
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
device = devices.device device = devices.device
weight_load_location = None if cmd_opts.lowram else "cpu" weight_load_location = None if cmd_opts.lowram else "cpu"
...@@ -118,10 +135,12 @@ xformers_available = False ...@@ -118,10 +135,12 @@ xformers_available = False
config_filename = cmd_opts.ui_settings_file config_filename = cmd_opts.ui_settings_file
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) hypernetworks = {}
loaded_hypernetwork = None loaded_hypernetwork = None
def reload_hypernetworks(): def reload_hypernetworks():
from modules.hypernetworks import hypernetwork
global hypernetworks global hypernetworks
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
...@@ -161,9 +180,10 @@ class State: ...@@ -161,9 +180,10 @@ class State:
def dict(self): def dict(self):
obj = { obj = {
"skipped": self.skipped, "skipped": self.skipped,
"interrupted": self.skipped, "interrupted": self.interrupted,
"job": self.job, "job": self.job,
"job_count": self.job_count, "job_count": self.job_count,
"job_timestamp": self.job_timestamp,
"job_no": self.job_no, "job_no": self.job_no,
"sampling_step": self.sampling_step, "sampling_step": self.sampling_step,
"sampling_steps": self.sampling_steps, "sampling_steps": self.sampling_steps,
...@@ -194,22 +214,25 @@ class State: ...@@ -194,22 +214,25 @@ class State:
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this""" """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
def set_current_image(self): def set_current_image(self):
if not parallel_processing_allowed:
return
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0: if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
self.do_set_current_image() self.do_set_current_image()
def do_set_current_image(self): def do_set_current_image(self):
if not parallel_processing_allowed:
return
if self.current_latent is None: if self.current_latent is None:
return return
import modules.sd_samplers
if opts.show_progress_grid: if opts.show_progress_grid:
self.current_image = sd_samplers.samples_to_image_grid(self.current_latent) self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
else: else:
self.current_image = sd_samplers.sample_to_image(self.current_latent) self.current_image = modules.sd_samplers.sample_to_image(self.current_latent)
self.current_image_sampling_step = self.sampling_step self.current_image_sampling_step = self.sampling_step
state = State() state = State()
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv')) artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
...@@ -245,6 +268,21 @@ def options_section(section_identifier, options_dict): ...@@ -245,6 +268,21 @@ def options_section(section_identifier, options_dict):
return options_dict return options_dict
def list_checkpoint_tiles():
import modules.sd_models
return modules.sd_models.checkpoint_tiles()
def refresh_checkpoints():
import modules.sd_models
return modules.sd_models.list_models()
def list_samplers():
import modules.sd_samplers
return modules.sd_samplers.all_samplers
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config} hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
options_templates = {} options_templates = {}
...@@ -271,8 +309,13 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" ...@@ -271,8 +309,13 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"), "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"), "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"), "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"), "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
})) }))
options_templates.update(options_section(('saving-paths', "Paths for saving"), { options_templates.update(options_section(('saving-paths', "Paths for saving"), {
...@@ -297,12 +340,8 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo ...@@ -297,12 +340,8 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
options_templates.update(options_section(('upscaling', "Upscaling"), { options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
"use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"),
})) }))
options_templates.update(options_section(('face-restoration', "Face restoration"), { options_templates.update(options_section(('face-restoration', "Face restoration"), {
...@@ -319,7 +358,8 @@ options_templates.update(options_section(('system', "System"), { ...@@ -319,7 +358,8 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
...@@ -328,24 +368,31 @@ options_templates.update(options_section(('training', "Training"), { ...@@ -328,24 +368,31 @@ options_templates.update(options_section(('training', "Training"), {
})) }))
options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list), "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", gr.ColorPicker, {}),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
"filter_nsfw": OptionInfo(False, "Filter NSFW content"), 'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
})) }))
options_templates.update(options_section(('compatibility', "Compatibility"), {
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), { options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"), "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"), "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
...@@ -358,11 +405,13 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), ...@@ -358,11 +405,13 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"), "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
"deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"), "deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
"deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"),
})) }))
options_templates.update(options_section(('ui', "User interface"), { options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"), "show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
"show_progress_type": OptionInfo("Full", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"return_grid": OptionInfo(True, "Show grid in results for web"), "return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
...@@ -370,16 +419,20 @@ options_templates.update(options_section(('ui', "User interface"), { ...@@ -370,16 +419,20 @@ options_templates.update(options_section(('ui', "User interface"), {
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"), "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."), "disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
"font": OptionInfo("", "Font for image grids that have text"), "font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
"dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"),
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/ing2img UI item order"),
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
})) }))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), { options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}), "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
...@@ -432,6 +485,28 @@ class Options: ...@@ -432,6 +485,28 @@ class Options:
return super(Options, self).__getattribute__(item) return super(Options, self).__getattribute__(item)
def set(self, key, value):
"""sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
oldval = self.data.get(key, None)
if oldval == value:
return False
try:
setattr(self, key, value)
except RuntimeError:
return False
if self.data_labels[key].onchange is not None:
try:
self.data_labels[key].onchange()
except Exception as e:
errors.display(e, f"changing setting {key} to {value}")
setattr(self, key, oldval)
return False
return True
def save(self, filename): def save(self, filename):
assert not cmd_opts.freeze_settings, "saving settings is disabled" assert not cmd_opts.freeze_settings, "saving settings is disabled"
...@@ -491,6 +566,15 @@ opts = Options() ...@@ -491,6 +566,15 @@ opts = Options()
if os.path.exists(config_filename): if os.path.exists(config_filename):
opts.load(config_filename) opts.load(config_filename)
latent_upscale_default_mode = "Latent"
latent_upscale_modes = {
"Latent": {"mode": "bilinear", "antialias": False},
"Latent (antialiased)": {"mode": "bilinear", "antialias": True},
"Latent (bicubic)": {"mode": "bicubic", "antialias": False},
"Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True},
"Latent (nearest)": {"mode": "nearest", "antialias": False},
}
sd_upscalers = [] sd_upscalers = []
sd_model = None sd_model = None
......
...@@ -65,17 +65,6 @@ class StyleDatabase: ...@@ -65,17 +65,6 @@ class StyleDatabase:
def apply_negative_styles_to_prompt(self, prompt, styles): def apply_negative_styles_to_prompt(self, prompt, styles):
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]) return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
def apply_styles(self, p: StableDiffusionProcessing) -> None:
if isinstance(p.prompt, list):
p.prompt = [self.apply_styles_to_prompt(prompt, p.styles) for prompt in p.prompt]
else:
p.prompt = self.apply_styles_to_prompt(p.prompt, p.styles)
if isinstance(p.negative_prompt, list):
p.negative_prompt = [self.apply_negative_styles_to_prompt(prompt, p.styles) for prompt in p.negative_prompt]
else:
p.negative_prompt = self.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
def save_styles(self, path: str) -> None: def save_styles(self, path: str) -> None:
# Write to temporary file first, so we don't nuke the file if something goes wrong # Write to temporary file first, so we don't nuke the file if something goes wrong
fd, temp_path = tempfile.mkstemp(".csv") fd, temp_path = tempfile.mkstemp(".csv")
......
...@@ -276,8 +276,8 @@ def poi_average(pois, settings): ...@@ -276,8 +276,8 @@ def poi_average(pois, settings):
weight += poi.weight weight += poi.weight
x += poi.x * poi.weight x += poi.x * poi.weight
y += poi.y * poi.weight y += poi.y * poi.weight
avg_x = round(x / weight) avg_x = round(weight and x / weight)
avg_y = round(y / weight) avg_y = round(weight and y / weight)
return PointOfInterest(avg_x, avg_y) return PointOfInterest(avg_x, avg_y)
......
...@@ -3,7 +3,7 @@ import numpy as np ...@@ -3,7 +3,7 @@ import numpy as np
import PIL import PIL
import torch import torch
from PIL import Image from PIL import Image
from torch.utils.data import Dataset from torch.utils.data import Dataset, DataLoader
from torchvision import transforms from torchvision import transforms
import random import random
...@@ -11,25 +11,28 @@ import tqdm ...@@ -11,25 +11,28 @@ import tqdm
from modules import devices, shared from modules import devices, shared
import re import re
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
re_numbers_at_start = re.compile(r"^[-\d]+\s*") re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class DatasetEntry: class DatasetEntry:
def __init__(self, filename=None, latent=None, filename_text=None): def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
self.filename = filename self.filename = filename
self.latent = latent
self.filename_text = filename_text self.filename_text = filename_text
self.cond = None self.latent_dist = latent_dist
self.cond_text = None self.latent_sample = latent_sample
self.cond = cond
self.cond_text = cond_text
self.pixel_values = pixel_values
class PersonalizedBase(Dataset): class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1): def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token self.placeholder_token = placeholder_token
self.batch_size = batch_size
self.width = width self.width = width
self.height = height self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.flip = transforms.RandomHorizontalFlip(p=flip_p)
...@@ -45,11 +48,16 @@ class PersonalizedBase(Dataset): ...@@ -45,11 +48,16 @@ class PersonalizedBase(Dataset):
assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty" assert os.listdir(data_root), "Dataset directory is empty"
cond_model = shared.sd_model.cond_stage_model
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
self.shuffle_tags = shuffle_tags
self.tag_drop_out = tag_drop_out
print("Preparing dataset...") print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths): for path in tqdm.tqdm(self.image_paths):
if shared.state.interrupted:
raise Exception("interrupted")
try: try:
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
except Exception: except Exception:
...@@ -71,53 +79,94 @@ class PersonalizedBase(Dataset): ...@@ -71,53 +79,94 @@ class PersonalizedBase(Dataset):
npimage = np.array(image).astype(np.uint8) npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32) npimage = (npimage / 127.5 - 1.0).astype(np.float32)
torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32) torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
torchdata = torch.moveaxis(torchdata, 2, 0) latent_sample = None
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() with devices.autocast():
init_latent = init_latent.to(devices.cpu) latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent) if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
if include_cond: latent_sampling_method = "once"
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
elif latent_sampling_method == "deterministic":
# Works only for DiagonalGaussianDistribution
latent_dist.std = 0
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
elif latent_sampling_method == "random":
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
if not (self.tag_drop_out != 0 or self.shuffle_tags):
entry.cond_text = self.create_text(filename_text) entry.cond_text = self.create_text(filename_text)
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
with devices.autocast():
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
self.dataset.append(entry) self.dataset.append(entry)
del torchdata
del latent_dist
del latent_sample
assert len(self.dataset) > 0, "No images have been found in the dataset." self.length = len(self.dataset)
self.length = len(self.dataset) * repeats // batch_size assert self.length > 0, "No images have been found in the dataset."
self.batch_size = min(batch_size, self.length)
self.dataset_length = len(self.dataset) self.gradient_step = min(gradient_step, self.length // self.batch_size)
self.indexes = None self.latent_sampling_method = latent_sampling_method
self.shuffle()
def shuffle(self):
self.indexes = np.random.permutation(self.dataset_length)
def create_text(self, filename_text): def create_text(self, filename_text):
text = random.choice(self.lines) text = random.choice(self.lines)
tags = filename_text.split(',')
if self.tag_drop_out != 0:
tags = [t for t in tags if random.random() > self.tag_drop_out]
if self.shuffle_tags:
random.shuffle(tags)
text = text.replace("[filewords]", ','.join(tags))
text = text.replace("[name]", self.placeholder_token) text = text.replace("[name]", self.placeholder_token)
text = text.replace("[filewords]", filename_text)
return text return text
def __len__(self): def __len__(self):
return self.length return self.length
def __getitem__(self, i): def __getitem__(self, i):
res = [] entry = self.dataset[i]
if self.tag_drop_out != 0 or self.shuffle_tags:
entry.cond_text = self.create_text(entry.filename_text)
if self.latent_sampling_method == "random":
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
return entry
class PersonalizedDataLoader(DataLoader):
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
if latent_sampling_method == "random":
self.collate_fn = collate_wrapper_random
else:
self.collate_fn = collate_wrapper
for j in range(self.batch_size):
position = i * self.batch_size + j
if position % len(self.indexes) == 0:
self.shuffle()
index = self.indexes[position % len(self.indexes)] class BatchLoader:
entry = self.dataset[index] def __init__(self, data):
self.cond_text = [entry.cond_text for entry in data]
self.cond = [entry.cond for entry in data]
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
#self.emb_index = [entry.emb_index for entry in data]
#print(self.latent_sample.device)
if entry.cond is None: def pin_memory(self):
entry.cond_text = self.create_text(entry.filename_text) self.latent_sample = self.latent_sample.pin_memory()
return self
def collate_wrapper(batch):
return BatchLoader(batch)
class BatchLoaderRandom(BatchLoader):
def __init__(self, data):
super().__init__(data)
res.append(entry) def pin_memory(self):
return self
return res def collate_wrapper_random(batch):
return BatchLoaderRandom(batch)
\ No newline at end of file
...@@ -6,12 +6,10 @@ import sys ...@@ -6,12 +6,10 @@ import sys
import tqdm import tqdm
import time import time
from modules import shared, images from modules import shared, images, deepbooru
from modules.paths import models_path from modules.paths import models_path
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts
from modules.textual_inversion import autocrop from modules.textual_inversion import autocrop
if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False): def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
...@@ -20,9 +18,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce ...@@ -20,9 +18,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
shared.interrogator.load() shared.interrogator.load()
if process_caption_deepbooru: if process_caption_deepbooru:
db_opts = deepbooru.create_deepbooru_opts() deepbooru.model.start()
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug) preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug)
...@@ -32,7 +28,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce ...@@ -32,7 +28,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
shared.interrogator.send_blip_to_ram() shared.interrogator.send_blip_to_ram()
if process_caption_deepbooru: if process_caption_deepbooru:
deepbooru.release_process() deepbooru.model.stop()
def listfiles(dirname): def listfiles(dirname):
...@@ -58,7 +54,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti ...@@ -58,7 +54,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti
if params.process_caption_deepbooru: if params.process_caption_deepbooru:
if len(caption) > 0: if len(caption) > 0:
caption += ", " caption += ", "
caption += deepbooru.get_tags_from_process(image) caption += deepbooru.model.tag_multi(image)
filename_part = params.src filename_part = params.src
filename_part = os.path.splitext(filename_part)[0] filename_part = os.path.splitext(filename_part)[0]
...@@ -128,6 +124,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre ...@@ -128,6 +124,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
files = listfiles(src) files = listfiles(src)
shared.state.job = "preprocess"
shared.state.textinfo = "Preprocessing..." shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files) shared.state.job_count = len(files)
......
...@@ -10,7 +10,7 @@ import csv ...@@ -10,7 +10,7 @@ import csv
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from modules import shared, devices, sd_hijack, processing, sd_models, images from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.learn_schedule import LearnRateScheduler
...@@ -23,9 +23,12 @@ class Embedding: ...@@ -23,9 +23,12 @@ class Embedding:
self.vec = vec self.vec = vec
self.name = name self.name = name
self.step = step self.step = step
self.shape = None
self.vectors = 0
self.cached_checksum = None self.cached_checksum = None
self.sd_checkpoint = None self.sd_checkpoint = None
self.sd_checkpoint_name = None self.sd_checkpoint_name = None
self.optimizer_state_dict = None
def save(self, filename): def save(self, filename):
embedding_data = { embedding_data = {
...@@ -39,6 +42,13 @@ class Embedding: ...@@ -39,6 +42,13 @@ class Embedding:
torch.save(embedding_data, filename) torch.save(embedding_data, filename)
if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
optimizer_saved_dict = {
'hash': self.checksum(),
'optimizer_state_dict': self.optimizer_state_dict,
}
torch.save(optimizer_saved_dict, filename + '.optim')
def checksum(self): def checksum(self):
if self.cached_checksum is not None: if self.cached_checksum is not None:
return self.cached_checksum return self.cached_checksum
...@@ -57,14 +67,17 @@ class EmbeddingDatabase: ...@@ -57,14 +67,17 @@ class EmbeddingDatabase:
def __init__(self, embeddings_dir): def __init__(self, embeddings_dir):
self.ids_lookup = {} self.ids_lookup = {}
self.word_embeddings = {} self.word_embeddings = {}
self.skipped_embeddings = {}
self.dir_mtime = None self.dir_mtime = None
self.embeddings_dir = embeddings_dir self.embeddings_dir = embeddings_dir
self.expected_shape = -1
def register_embedding(self, embedding, model): def register_embedding(self, embedding, model):
self.word_embeddings[embedding.name] = embedding self.word_embeddings[embedding.name] = embedding
ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0] # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
ids = model.cond_stage_model.tokenize([embedding.name])[0]
first_id = ids[0] first_id = ids[0]
if first_id not in self.ids_lookup: if first_id not in self.ids_lookup:
...@@ -74,21 +87,26 @@ class EmbeddingDatabase: ...@@ -74,21 +87,26 @@ class EmbeddingDatabase:
return embedding return embedding
def load_textual_inversion_embeddings(self): def get_expected_shape(self):
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
return vec.shape[1]
def load_textual_inversion_embeddings(self, force_reload = False):
mt = os.path.getmtime(self.embeddings_dir) mt = os.path.getmtime(self.embeddings_dir)
if self.dir_mtime is not None and mt <= self.dir_mtime: if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
return return
self.dir_mtime = mt self.dir_mtime = mt
self.ids_lookup.clear() self.ids_lookup.clear()
self.word_embeddings.clear() self.word_embeddings.clear()
self.skipped_embeddings.clear()
self.expected_shape = self.get_expected_shape()
def process_file(path, filename): def process_file(path, filename):
name = os.path.splitext(filename)[0] name, ext = os.path.splitext(filename)
ext = ext.upper()
data = [] if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
embed_image = Image.open(path) embed_image = Image.open(path)
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
data = embedding_from_b64(embed_image.text['sd-ti-embedding']) data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
...@@ -96,8 +114,10 @@ class EmbeddingDatabase: ...@@ -96,8 +114,10 @@ class EmbeddingDatabase:
else: else:
data = extract_image_data_embed(embed_image) data = extract_image_data_embed(embed_image)
name = data.get('name', name) name = data.get('name', name)
else: elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu") data = torch.load(path, map_location="cpu")
else:
return
# textual inversion embeddings # textual inversion embeddings
if 'string_to_param' in data: if 'string_to_param' in data:
...@@ -121,7 +141,13 @@ class EmbeddingDatabase: ...@@ -121,7 +141,13 @@ class EmbeddingDatabase:
embedding.step = data.get('step', None) embedding.step = data.get('step', None)
embedding.sd_checkpoint = data.get('sd_checkpoint', None) embedding.sd_checkpoint = data.get('sd_checkpoint', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
embedding.vectors = vec.shape[0]
embedding.shape = vec.shape[-1]
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model) self.register_embedding(embedding, shared.sd_model)
else:
self.skipped_embeddings[name] = embedding
for fn in os.listdir(self.embeddings_dir): for fn in os.listdir(self.embeddings_dir):
try: try:
...@@ -132,12 +158,13 @@ class EmbeddingDatabase: ...@@ -132,12 +158,13 @@ class EmbeddingDatabase:
process_file(fullfn, fn) process_file(fullfn, fn)
except Exception: except Exception:
print(f"Error loading emedding {fn}:", file=sys.stderr) print(f"Error loading embedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
continue continue
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.") print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
print("Embeddings:", ', '.join(self.word_embeddings.keys())) if len(self.skipped_embeddings) > 0:
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
def find_embedding_at_position(self, tokens, offset): def find_embedding_at_position(self, tokens, offset):
token = tokens[offset] token = tokens[offset]
...@@ -155,13 +182,11 @@ class EmbeddingDatabase: ...@@ -155,13 +182,11 @@ class EmbeddingDatabase:
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
with devices.autocast(): with devices.autocast():
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"] embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
for i in range(num_vectors_per_token): for i in range(num_vectors_per_token):
...@@ -184,7 +209,7 @@ def write_loss(log_directory, filename, step, epoch_len, values): ...@@ -184,7 +209,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
if shared.opts.training_write_csv_every == 0: if shared.opts.training_write_csv_every == 0:
return return
if (step + 1) % shared.opts.training_write_csv_every != 0: if step % shared.opts.training_write_csv_every != 0:
return return
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
...@@ -194,21 +219,23 @@ def write_loss(log_directory, filename, step, epoch_len, values): ...@@ -194,21 +219,23 @@ def write_loss(log_directory, filename, step, epoch_len, values):
if write_csv_header: if write_csv_header:
csv_writer.writeheader() csv_writer.writeheader()
epoch = step // epoch_len epoch = (step - 1) // epoch_len
epoch_step = step % epoch_len epoch_step = (step - 1) % epoch_len
csv_writer.writerow({ csv_writer.writerow({
"step": step + 1, "step": step,
"epoch": epoch, "epoch": epoch,
"epoch_step": epoch_step + 1, "epoch_step": epoch_step,
**values, **values,
}) })
def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
assert model_name, f"{name} not selected" assert model_name, f"{name} not selected"
assert learn_rate, "Learning rate is empty or 0" assert learn_rate, "Learning rate is empty or 0"
assert isinstance(batch_size, int), "Batch size must be integer" assert isinstance(batch_size, int), "Batch size must be integer"
assert batch_size > 0, "Batch size must be positive" assert batch_size > 0, "Batch size must be positive"
assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
assert gradient_step > 0, "Gradient accumulation step must be positive"
assert data_root, "Dataset directory is empty" assert data_root, "Dataset directory is empty"
assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty" assert os.listdir(data_root), "Dataset directory is empty"
...@@ -224,11 +251,12 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat ...@@ -224,11 +251,12 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat
if save_model_every or create_image_every: if save_model_every or create_image_every:
assert log_directory, "Log directory is empty" assert log_directory, "Log directory is empty"
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0 save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0 create_image_every = create_image_every or 0
validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
shared.state.job = "train-embedding"
shared.state.textinfo = "Initializing textual inversion training..." shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps shared.state.job_count = steps
...@@ -255,19 +283,16 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -255,19 +283,16 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
else: else:
images_embeds_dir = None images_embeds_dir = None
cond_model = shared.sd_model.cond_stage_model
hijack = sd_hijack.model_hijack hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name] embedding = hijack.embedding_db.word_embeddings[embedding_name]
checkpoint = sd_models.select_checkpoint() checkpoint = sd_models.select_checkpoint()
ititial_step = embedding.step or 0 initial_step = embedding.step or 0
if ititial_step >= steps: if initial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps" shared.state.textinfo = "Model has already been trained beyond specified max steps"
return embedding, filename return embedding, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \ clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
...@@ -276,67 +301,121 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -276,67 +301,121 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False) clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
# dataset loading may take a while, so input validations and early returns should be done before this # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"): old_parallel_processing_allowed = shared.parallel_processing_allowed
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
pin_memory = shared.opts.pin_memory
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
latent_sampling_method = ds.latent_sampling_method
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
if unload: if unload:
shared.parallel_processing_allowed = False
shared.sd_model.first_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu)
embedding.vec.requires_grad = True embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
if shared.opts.save_optimizer_state:
optimizer_state_dict = None
if os.path.exists(filename + '.optim'):
optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
if embedding.checksum() == optimizer_saved_dict.get('hash', None):
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
if optimizer_state_dict is not None:
optimizer.load_state_dict(optimizer_state_dict)
print("Loaded existing optimizer from checkpoint")
else:
print("No saved optimizer exists in checkpoint")
scaler = torch.cuda.amp.GradScaler()
losses = torch.zeros((32,)) batch_size = ds.batch_size
gradient_step = ds.gradient_step
# n steps = batch_size * gradient_step * n image processed
steps_per_epoch = len(ds) // batch_size // gradient_step
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
loss_step = 0
_loss_step = 0 #internal
last_saved_file = "<none>" last_saved_file = "<none>"
last_saved_image = "<none>" last_saved_image = "<none>"
forced_filename = "<none>" forced_filename = "<none>"
embedding_yet_to_be_embedded = False embedding_yet_to_be_embedded = False
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
for i, entries in pbar: img_c = None
embedding.step = i + ititial_step
pbar = tqdm.tqdm(total=steps - initial_step)
try:
for i in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
if shared.state.interrupted:
break
for j, batch in enumerate(dl):
# works as a drop_last=True for gradient accumulation
if j == max_steps_per_epoch:
break
scheduler.apply(optimizer, embedding.step) scheduler.apply(optimizer, embedding.step)
if scheduler.finished: if scheduler.finished:
break break
if shared.state.interrupted: if shared.state.interrupted:
break break
if clip_grad: if clip_grad:
clip_grad_sched.step(embedding.step) clip_grad_sched.step(embedding.step)
with torch.autocast("cuda"): with devices.autocast():
c = cond_model([entry.cond_text for entry in entries]) x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
x = torch.stack([entry.latent for entry in entries]).to(devices.device) c = shared.sd_model.cond_stage_model(batch.cond_text)
loss = shared.sd_model(x, c)[0]
if is_training_inpainting_model:
if img_c is None:
img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
cond = {"c_concat": [img_c], "c_crossattn": [c]}
else:
cond = c
loss = shared.sd_model(x, cond)[0] / gradient_step
del x del x
losses[embedding.step % losses.shape[0]] = loss.item() _loss_step += loss.item()
scaler.scale(loss).backward()
optimizer.zero_grad() # go back until we reach gradient accumulation steps
loss.backward() if (j + 1) % gradient_step != 0:
continue
if clip_grad: if clip_grad:
clip_grad(embedding.vec, clip_grad_sched.learn_rate) clip_grad(embedding.vec, clip_grad_sched.learn_rate)
optimizer.step() scaler.step(optimizer)
scaler.update()
embedding.step += 1
pbar.update()
optimizer.zero_grad(set_to_none=True)
loss_step = _loss_step
_loss_step = 0
steps_done = embedding.step + 1 steps_done = embedding.step + 1
epoch_num = embedding.step // len(ds) epoch_num = embedding.step // steps_per_epoch
epoch_step = embedding.step % len(ds) epoch_step = embedding.step % steps_per_epoch
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
if embedding_dir is not None and steps_done % save_embedding_every == 0: if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint. # Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}' embedding_name_every = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt') last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True) save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
embedding_yet_to_be_embedded = True embedding_yet_to_be_embedded = True
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
"loss": f"{losses.mean():.7f}", "loss": f"{loss_step:.7f}",
"learn_rate": scheduler.learn_rate "learn_rate": scheduler.learn_rate
}) })
...@@ -357,13 +436,13 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -357,13 +436,13 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
p.prompt = preview_prompt p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt p.negative_prompt = preview_negative_prompt
p.steps = preview_steps p.steps = preview_steps
p.sampler_index = preview_sampler_index p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
p.cfg_scale = preview_cfg_scale p.cfg_scale = preview_cfg_scale
p.seed = preview_seed p.seed = preview_seed
p.width = preview_width p.width = preview_width
p.height = preview_height p.height = preview_height
else: else:
p.prompt = entries[0].cond_text p.prompt = batch.cond_text[0]
p.steps = 20 p.steps = 20
p.width = training_width p.width = training_width
p.height = training_height p.height = training_height
...@@ -371,12 +450,15 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -371,12 +450,15 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
preview_text = p.prompt preview_text = p.prompt
processed = processing.process_images(p) processed = processing.process_images(p)
image = processed.images[0] image = processed.images[0] if len(processed.images) > 0 else None
if unload: if unload:
shared.sd_model.first_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu)
if image is not None:
shared.state.current_image = image shared.state.current_image = image
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
...@@ -411,21 +493,27 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -411,21 +493,27 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
shared.state.textinfo = f""" shared.state.textinfo = f"""
<p> <p>
Loss: {losses.mean():.7f}<br/> Loss: {loss_step:.7f}<br/>
Step: {embedding.step}<br/> Step: {steps_done}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/> Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/> Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/> Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>
""" """
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
except Exception:
print(traceback.format_exc(), file=sys.stderr)
pass
finally:
pbar.leave = False
pbar.close()
shared.sd_model.first_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device)
shared.parallel_processing_allowed = old_parallel_processing_allowed
return embedding, filename return embedding, filename
def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True): def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
old_embedding_name = embedding.name old_embedding_name = embedding.name
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
...@@ -436,6 +524,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache ...@@ -436,6 +524,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache
if remove_cached_checksum: if remove_cached_checksum:
embedding.cached_checksum = None embedding.cached_checksum = None
embedding.name = embedding_name embedding.name = embedding_name
embedding.optimizer_state_dict = optimizer.state_dict()
embedding.save(filename) embedding.save(filename)
except: except:
embedding.sd_checkpoint = old_sd_checkpoint embedding.sd_checkpoint = old_sd_checkpoint
......
...@@ -18,7 +18,7 @@ def create_embedding(name, initialization_text, nvpt, overwrite_old): ...@@ -18,7 +18,7 @@ def create_embedding(name, initialization_text, nvpt, overwrite_old):
def preprocess(*args): def preprocess(*args):
modules.textual_inversion.preprocess.preprocess(*args) modules.textual_inversion.preprocess.preprocess(*args)
return "Preprocessing finished.", "" return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", ""
def train_embedding(*args): def train_embedding(*args):
......
import modules.scripts import modules.scripts
from modules import sd_samplers
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \ from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
StableDiffusionProcessingImg2Img, process_images StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts
...@@ -7,7 +8,7 @@ import modules.processing as processing ...@@ -7,7 +8,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html from modules.ui import plaintext_to_html
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args): def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, *args):
p = StableDiffusionProcessingTxt2Img( p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model, sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
...@@ -21,7 +22,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: ...@@ -21,7 +22,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
seed_resize_from_h=seed_resize_from_h, seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w, seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras, seed_enable_extras=seed_enable_extras,
sampler_index=sampler_index, sampler_name=sd_samplers.samplers[sampler_index].name,
batch_size=batch_size, batch_size=batch_size,
n_iter=n_iter, n_iter=n_iter,
steps=steps, steps=steps,
...@@ -32,8 +33,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: ...@@ -32,8 +33,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
tiling=tiling, tiling=tiling,
enable_hr=enable_hr, enable_hr=enable_hr,
denoising_strength=denoising_strength if enable_hr else None, denoising_strength=denoising_strength if enable_hr else None,
firstphase_width=firstphase_width if enable_hr else None, hr_scale=hr_scale,
firstphase_height=firstphase_height if enable_hr else None, hr_upscaler=hr_upscaler,
) )
p.scripts = modules.scripts.scripts_txt2img p.scripts = modules.scripts.scripts_txt2img
...@@ -58,4 +59,4 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: ...@@ -58,4 +59,4 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
if opts.do_not_show_images: if opts.do_not_show_images:
processed.images = [] processed.images = []
return processed.images, generation_info_js, plaintext_to_html(processed.info) return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
...@@ -17,21 +17,18 @@ import gradio.routes ...@@ -17,21 +17,18 @@ import gradio.routes
import gradio.utils import gradio.utils
import numpy as np import numpy as np
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions from modules.ui_components import FormRow, FormGroup, ToolButton
from modules.paths import script_path from modules.paths import script_path
from modules.shared import opts, cmd_opts, restricted_opts from modules.shared import opts, cmd_opts, restricted_opts
if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
import modules.codeformer_model import modules.codeformer_model
import modules.generation_parameters_copypaste as parameters_copypaste import modules.generation_parameters_copypaste as parameters_copypaste
import modules.gfpgan_model import modules.gfpgan_model
import modules.hypernetworks.ui import modules.hypernetworks.ui
import modules.ldsr_model
import modules.scripts import modules.scripts
import modules.shared as shared import modules.shared as shared
import modules.styles import modules.styles
...@@ -53,10 +50,14 @@ if not cmd_opts.share and not cmd_opts.listen: ...@@ -53,10 +50,14 @@ if not cmd_opts.share and not cmd_opts.listen:
gradio.utils.version_check = lambda: None gradio.utils.version_check = lambda: None
gradio.utils.get_local_ip_address = lambda: '127.0.0.1' gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
if cmd_opts.ngrok != None: if cmd_opts.ngrok is not None:
import modules.ngrok as ngrok import modules.ngrok as ngrok
print('ngrok authtoken detected, trying to connect...') print('ngrok authtoken detected, trying to connect...')
ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860, cmd_opts.ngrok_region) ngrok.connect(
cmd_opts.ngrok,
cmd_opts.port if cmd_opts.port is not None else 7860,
cmd_opts.ngrok_region
)
def gr_show(visible=True): def gr_show(visible=True):
...@@ -69,20 +70,23 @@ sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None ...@@ -69,20 +70,23 @@ sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
css_hide_progressbar = """ css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; } .wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." } .wrap .m-12::before { content:"Loading..." }
.wrap .z-20 svg { display:none!important; }
.wrap .z-20::before { content:"Loading..." }
.progress-bar { display:none!important; } .progress-bar { display:none!important; }
.meta-text { display:none!important; } .meta-text { display:none!important; }
.meta-text-center { display:none!important; }
""" """
# Using constants for these since the variation selector isn't visible. # Using constants for these since the variation selector isn't visible.
# Important that they exactly match script.js for tooltip to work. # Important that they exactly match script.js for tooltip to work.
random_symbol = '\U0001f3b2\ufe0f' # 🎲️ random_symbol = '\U0001f3b2\ufe0f' # 🎲️
reuse_symbol = '\u267b\ufe0f' # ♻️ reuse_symbol = '\u267b\ufe0f' # ♻️
art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙ paste_symbol = '\u2199\ufe0f' # ↙
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄 refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾 save_style_symbol = '\U0001f4be' # 💾
apply_style_symbol = '\U0001f4cb' # 📋 apply_style_symbol = '\U0001f4cb' # 📋
clear_prompt_symbol = '\U0001F5D1' # 🗑️
def plaintext_to_html(text): def plaintext_to_html(text):
...@@ -142,7 +146,7 @@ def save_files(js_data, images, do_make_zip, index): ...@@ -142,7 +146,7 @@ def save_files(js_data, images, do_make_zip, index):
filenames.append(os.path.basename(txt_fullfn)) filenames.append(os.path.basename(txt_fullfn))
fullfns.append(txt_fullfn) fullfns.append(txt_fullfn)
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
# Make Zip # Make Zip
if do_make_zip: if do_make_zip:
...@@ -155,96 +159,17 @@ def save_files(js_data, images, do_make_zip, index): ...@@ -155,96 +159,17 @@ def save_files(js_data, images, do_make_zip, index):
zip_file.writestr(filenames[i], f.read()) zip_file.writestr(filenames[i], f.read())
fullfns.insert(0, zip_filepath) fullfns.insert(0, zip_filepath)
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}") return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
def save_pil_to_file(pil_image, dir=None):
use_metadata = False
metadata = PngImagePlugin.PngInfo()
for key, value in pil_image.info.items():
if isinstance(key, str) and isinstance(value, str):
metadata.add_text(key, value)
use_metadata = True
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
return file_obj
# override save to file function so that it also writes PNG info
gr.processing_utils.save_pil_to_file = save_pil_to_file
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
if run_memmon:
shared.mem_mon.monitor()
t = time.perf_counter()
try:
res = list(func(*args, **kwargs))
except Exception as e:
# When printing out our debug argument list, do not print out more than a MB of text
max_debug_str_len = 131072 # (1024*1024)/8
print("Error completing request", file=sys.stderr)
argStr = f"Arguments: {str(args)} {str(kwargs)}"
print(argStr[:max_debug_str_len], file=sys.stderr)
if len(argStr) > max_debug_str_len:
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
shared.state.job = ""
shared.state.job_count = 0
if extra_outputs_array is None:
extra_outputs_array = [None, '']
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
shared.state.skipped = False
shared.state.interrupted = False
shared.state.job_count = 0
if not add_stats:
return tuple(res)
elapsed = time.perf_counter() - t
elapsed_m = int(elapsed // 60)
elapsed_s = elapsed % 60
elapsed_text = f"{elapsed_s:.2f}s"
if elapsed_m > 0:
elapsed_text = f"{elapsed_m}m "+elapsed_text
if run_memmon:
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
active_peak = mem_stats['active_peak']
reserved_peak = mem_stats['reserved_peak']
sys_peak = mem_stats['system_peak']
sys_total = mem_stats['total']
sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
else:
vram_html = ''
# last item is always HTML def calc_time_left(progress, threshold, label, force_display, show_eta):
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
return tuple(res)
return f
def calc_time_left(progress, threshold, label, force_display):
if progress == 0: if progress == 0:
return "" return ""
else: else:
time_since_start = time.time() - shared.state.time_start time_since_start = time.time() - shared.state.time_start
eta = (time_since_start/progress) eta = (time_since_start/progress)
eta_relative = eta-time_since_start eta_relative = eta-time_since_start
if (eta_relative > threshold and progress > 0.02) or force_display: if (eta_relative > threshold and show_eta) or force_display:
if eta_relative > 3600: if eta_relative > 3600:
return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
elif eta_relative > 60: elif eta_relative > 60:
...@@ -266,7 +191,10 @@ def check_progress_call(id_part): ...@@ -266,7 +191,10 @@ def check_progress_call(id_part):
if shared.state.sampling_steps > 0: if shared.state.sampling_steps > 0:
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display ) # Show progress percentage and time left at the same moment, and base it also on steps done
show_eta = progress >= 0.01 or shared.state.sampling_step >= 10
time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta)
if time_left != "": if time_left != "":
shared.state.time_left_force_display = True shared.state.time_left_force_display = True
...@@ -274,7 +202,7 @@ def check_progress_call(id_part): ...@@ -274,7 +202,7 @@ def check_progress_call(id_part):
progressbar = "" progressbar = ""
if opts.show_progressbar: if opts.show_progressbar:
progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{"&nbsp;" * 2 + str(int(progress*100))+"%" + time_left if progress > 0.01 else ""}</div></div>""" progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{"&nbsp;" * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}</div></div>"""
image = gr_show(False) image = gr_show(False)
preview_visibility = gr_show(False) preview_visibility = gr_show(False)
...@@ -307,13 +235,6 @@ def check_progress_call_initial(id_part): ...@@ -307,13 +235,6 @@ def check_progress_call_initial(id_part):
return check_progress_call(id_part) return check_progress_call(id_part)
def roll_artist(prompt):
allowed_cats = set([x for x in shared.artist_db.categories() if len(opts.random_artist_categories)==0 or x in opts.random_artist_categories])
artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
return prompt + ", " + artist.name if prompt != '' else artist.name
def visit(x, func, path=""): def visit(x, func, path=""):
if hasattr(x, 'children'): if hasattr(x, 'children'):
for c in x.children: for c in x.children:
...@@ -343,45 +264,41 @@ def apply_styles(prompt, prompt_neg, style1_name, style2_name): ...@@ -343,45 +264,41 @@ def apply_styles(prompt, prompt_neg, style1_name, style2_name):
def interrogate(image): def interrogate(image):
prompt = shared.interrogator.interrogate(image) prompt = shared.interrogator.interrogate(image.convert("RGB"))
return gr_show(True) if prompt is None else prompt return gr_show(True) if prompt is None else prompt
def interrogate_deepbooru(image): def interrogate_deepbooru(image):
prompt = get_deepbooru_tags(image) prompt = deepbooru.model.tag(image)
return gr_show(True) if prompt is None else prompt return gr_show(True) if prompt is None else prompt
def create_seed_inputs(): def create_seed_inputs(target_interface):
with gr.Row(): with FormRow(elem_id=target_interface + '_seed_row'):
with gr.Box(): seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
with gr.Row(elem_id='seed_row'):
seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1)
seed.style(container=False) seed.style(container=False)
random_seed = gr.Button(random_symbol, elem_id='random_seed') random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
reuse_seed = gr.Button(reuse_symbol, elem_id='reuse_seed') reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
with gr.Box(elem_id='subseed_show_box'): with gr.Group(elem_id=target_interface + '_subseed_show_box'):
seed_checkbox = gr.Checkbox(label='Extra', elem_id='subseed_show', value=False) seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
# Components to show/hide based on the 'Extra' checkbox # Components to show/hide based on the 'Extra' checkbox
seed_extras = [] seed_extras = []
with gr.Row(visible=False) as seed_extra_row_1: with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
seed_extras.append(seed_extra_row_1) seed_extras.append(seed_extra_row_1)
with gr.Box(): subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
with gr.Row(elem_id='subseed_row'):
subseed = gr.Number(label='Variation seed', value=-1)
subseed.style(container=False) subseed.style(container=False)
random_subseed = gr.Button(random_symbol, elem_id='random_subseed') random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
reuse_subseed = gr.Button(reuse_symbol, elem_id='reuse_subseed') reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01) subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
with gr.Row(visible=False) as seed_extra_row_2: with FormRow(visible=False) as seed_extra_row_2:
seed_extras.append(seed_extra_row_2) seed_extras.append(seed_extra_row_2)
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from width", value=0) seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from height", value=0) seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')
random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
...@@ -394,6 +311,17 @@ def create_seed_inputs(): ...@@ -394,6 +311,17 @@ def create_seed_inputs():
return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
def connect_clear_prompt(button):
"""Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
button.click(
_js="clear_prompt",
fn=None,
inputs=[],
outputs=[],
)
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
""" Connects a 'reuse (sub)seed' button's click event so that it copies last used """ Connects a 'reuse (sub)seed' button's click event so that it copies last used
(sub)seed value from generation info the to the seed field. If copying subseed and subseed strength (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
...@@ -465,21 +393,25 @@ def create_toprow(is_img2img): ...@@ -465,21 +393,25 @@ def create_toprow(is_img2img):
) )
with gr.Column(scale=1, elem_id="roll_col"): with gr.Column(scale=1, elem_id="roll_col"):
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
paste = gr.Button(value=paste_symbol, elem_id="paste") paste = gr.Button(value=paste_symbol, elem_id="paste")
save_style = gr.Button(value=save_style_symbol, elem_id="style_create") save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter") token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
clear_prompt_button.click(
fn=lambda *x: x,
_js="confirm_clear_prompt",
inputs=[prompt, negative_prompt],
outputs=[prompt, negative_prompt],
)
button_interrogate = None button_interrogate = None
button_deepbooru = None button_deepbooru = None
if is_img2img: if is_img2img:
with gr.Column(scale=1, elem_id="interrogate_col"): with gr.Column(scale=1, elem_id="interrogate_col"):
button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
if cmd_opts.deepdanbooru:
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
with gr.Column(scale=1): with gr.Column(scale=1):
...@@ -509,7 +441,7 @@ def create_toprow(is_img2img): ...@@ -509,7 +441,7 @@ def create_toprow(is_img2img):
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
prompt_style2.save_to_config = True prompt_style2.save_to_config = True
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
def setup_progressbar(progressbar, preview, id_part, textinfo=None): def setup_progressbar(progressbar, preview, id_part, textinfo=None):
...@@ -557,7 +489,7 @@ def apply_setting(key, value): ...@@ -557,7 +489,7 @@ def apply_setting(key, value):
return return
valtype = type(opts.data_labels[key].default) valtype = type(opts.data_labels[key].default)
oldval = opts.data[key] oldval = opts.data.get(key, None)
opts.data[key] = valtype(value) if valtype != type(None) else value opts.data[key] = valtype(value) if valtype != type(None) else value
if oldval != value and opts.data_labels[key].onchange is not None: if oldval != value and opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange() opts.data_labels[key].onchange()
...@@ -566,6 +498,19 @@ def apply_setting(key, value): ...@@ -566,6 +498,19 @@ def apply_setting(key, value):
return value return value
def update_generation_info(args):
generation_info, html_info, img_index = args
try:
generation_info = json.loads(generation_info)
if img_index < 0 or img_index >= len(generation_info["infotexts"]):
return html_info
return plaintext_to_html(generation_info["infotexts"][img_index])
except Exception:
pass
# if the json parse or anything else fails, just return the old html_info
return html_info
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh(): def refresh():
refresh_method() refresh_method()
...@@ -576,7 +521,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele ...@@ -576,7 +521,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele
return gr.update(**(args or {})) return gr.update(**(args or {}))
refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id) refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
refresh_button.click( refresh_button.click(
fn=refresh, fn=refresh,
inputs=[], inputs=[],
...@@ -614,13 +559,14 @@ Requested path was: {f} ...@@ -614,13 +559,14 @@ Requested path was: {f}
generation_info = None generation_info = None
with gr.Column(): with gr.Column():
with gr.Row(): with gr.Row(elem_id=f"image_buttons_{tabname}"):
open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder')
if tabname != "extras": if tabname != "extras":
save = gr.Button('Save', elem_id=f'save_{tabname}') save = gr.Button('Save', elem_id=f'save_{tabname}')
save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
open_folder_button = gr.Button(folder_symbol, elem_id=button_id)
open_folder_button.click( open_folder_button.click(
fn=lambda: open_folder(opts.outdir_samples or outdir), fn=lambda: open_folder(opts.outdir_samples or outdir),
...@@ -629,40 +575,85 @@ Requested path was: {f} ...@@ -629,40 +575,85 @@ Requested path was: {f}
) )
if tabname != "extras": if tabname != "extras":
with gr.Row():
do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
with gr.Row(): with gr.Row():
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
with gr.Group(): with gr.Group():
html_info = gr.HTML() html_info = gr.HTML()
html_log = gr.HTML()
generation_info = gr.Textbox(visible=False) generation_info = gr.Textbox(visible=False)
if tabname == 'txt2img' or tabname == 'img2img':
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
generation_info_button.click(
fn=update_generation_info,
_js="(x, y) => [x, y, selected_gallery_index()]",
inputs=[generation_info, html_info],
outputs=[html_info],
preprocess=False
)
save.click( save.click(
fn=wrap_gradio_call(save_files), fn=wrap_gradio_call(save_files),
_js="(x, y, z, w) => [x, y, z, selected_gallery_index()]", _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
inputs=[ inputs=[
generation_info, generation_info,
result_gallery, result_gallery,
do_make_zip, html_info,
html_info, html_info,
], ],
outputs=[ outputs=[
download_files, download_files,
html_log,
]
)
save_zip.click(
fn=wrap_gradio_call(save_files),
_js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
inputs=[
generation_info,
result_gallery,
html_info, html_info,
html_info, html_info,
html_info, ],
outputs=[
download_files,
html_log,
] ]
) )
else: else:
html_info_x = gr.HTML() html_info_x = gr.HTML()
html_info = gr.HTML() html_info = gr.HTML()
html_log = gr.HTML()
parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
def create_sampler_and_steps_selection(choices, tabname):
if opts.samplers_in_dropdown:
with FormRow(elem_id=f"sampler_selection_{tabname}"):
sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
sampler_index.save_to_config = True
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
else:
with FormGroup(elem_id=f"sampler_selection_{tabname}"):
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
return steps, sampler_index
def create_ui(wrap_gradio_gpu_call): def ordered_ui_categories():
user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))}
for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)):
yield category
def create_ui():
import modules.img2img import modules.img2img
import modules.txt2img import modules.txt2img
...@@ -670,8 +661,12 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -670,8 +661,12 @@ def create_ui(wrap_gradio_gpu_call):
parameters_copypaste.reset() parameters_copypaste.reset()
modules.scripts.scripts_current = modules.scripts.scripts_txt2img
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False) dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
...@@ -685,43 +680,58 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -685,43 +680,58 @@ def create_ui(wrap_gradio_gpu_call):
setup_progressbar(progressbar, txt2img_preview, 'txt2img') setup_progressbar(progressbar, txt2img_preview, 'txt2img')
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'): with gr.Column(variant='panel', elem_id="txt2img_settings"):
steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20) for category in ordered_ui_categories():
sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index") if category == "sampler":
steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
with gr.Group():
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) elif category == "dimensions":
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) with FormRow():
with gr.Column(elem_id="txt2img_column_size", scale=4):
with gr.Row(): width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
tiling = gr.Checkbox(label='Tiling', value=False)
enable_hr = gr.Checkbox(label='Highres. fix', value=False) if opts.dimensions_and_batch_together:
with gr.Column(elem_id="txt2img_column_batch"):
with gr.Row(visible=False) as hr_options: batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass width", value=0) batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass height", value=0)
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7) elif category == "cfg":
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
with gr.Row(equal_height=True):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) elif category == "seed":
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0) elif category == "checkboxes":
with FormRow(elem_id="txt2img_checkboxes"):
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
with gr.Group(): enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
elif category == "hires_fix":
txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples) with FormRow(visible=False, elem_id="txt2img_hires_fix") as hr_options:
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
elif category == "batch":
if not opts.dimensions_and_batch_together:
with FormRow(elem_id="txt2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
elif category == "scripts":
with FormGroup(elem_id="txt2img_script_container"):
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
txt2img_args = dict( txt2img_args = dict(
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
_js="submit", _js="submit",
inputs=[ inputs=[
txt2img_prompt, txt2img_prompt,
...@@ -741,14 +751,15 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -741,14 +751,15 @@ def create_ui(wrap_gradio_gpu_call):
width, width,
enable_hr, enable_hr,
denoising_strength, denoising_strength,
firstphase_width, hr_scale,
firstphase_height, hr_upscaler,
] + custom_inputs, ] + custom_inputs,
outputs=[ outputs=[
txt2img_gallery, txt2img_gallery,
generation_info, generation_info,
html_info html_info,
html_log,
], ],
show_progress=False, show_progress=False,
) )
...@@ -773,17 +784,6 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -773,17 +784,6 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[hr_options], outputs=[hr_options],
) )
roll.click(
fn=roll_artist,
_js="update_txt2img_tokens",
inputs=[
txt2img_prompt,
],
outputs=[
txt2img_prompt,
]
)
txt2img_paste_fields = [ txt2img_paste_fields = [
(txt2img_prompt, "Prompt"), (txt2img_prompt, "Prompt"),
(txt2img_negative_prompt, "Negative prompt"), (txt2img_negative_prompt, "Negative prompt"),
...@@ -802,8 +802,8 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -802,8 +802,8 @@ def create_ui(wrap_gradio_gpu_call):
(denoising_strength, "Denoising strength"), (denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d), (enable_hr, lambda d: "Denoising strength" in d),
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
(firstphase_width, "First pass size-1"), (hr_scale, "Hires upscale"),
(firstphase_height, "First pass size-2"), (hr_upscaler, "Hires upscaler"),
*modules.scripts.scripts_txt2img.infotext_fields *modules.scripts.scripts_txt2img.infotext_fields
] ]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
...@@ -819,10 +819,13 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -819,10 +819,13 @@ def create_ui(wrap_gradio_gpu_call):
height, height,
] ]
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
modules.scripts.scripts_current = modules.scripts.scripts_img2img
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface: with gr.Blocks(analytics_enabled=False) as img2img_interface:
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True) img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
with gr.Row(elem_id='img2img_progress_row'): with gr.Row(elem_id='img2img_progress_row'):
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
...@@ -835,65 +838,97 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -835,65 +838,97 @@ def create_ui(wrap_gradio_gpu_call):
img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
setup_progressbar(progressbar, img2img_preview, 'img2img') setup_progressbar(progressbar, img2img_preview, 'img2img')
with gr.Row().style(equal_height=False): with FormRow().style(equal_height=False):
with gr.Column(variant='panel'): with gr.Column(variant='panel', elem_id="img2img_settings"):
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
with gr.TabItem('img2img', id='img2img'): with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"):
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480) init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480)
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"):
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
init_img_with_mask_orig = gr.State(None)
with gr.TabItem('Inpaint', id='inpaint'): use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch"
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480) if use_color_sketch:
def update_orig(image, state):
if image is not None:
same_size = state is not None and state.size == image.size
has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
edited = same_size and has_exact_match
return image if not edited or state is None else state
init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig)
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4) with FormRow():
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha")
with gr.Row(): with FormRow():
mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index") inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index") with FormRow():
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
with gr.Row(): with FormRow():
inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False) with gr.Column():
inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32) inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
with gr.Column(scale=4):
inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
with gr.TabItem('Batch img2img', id='batch'): with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"):
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>") gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs) img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs) img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
with gr.Row(): with FormRow():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize") resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20) for category in ordered_ui_categories():
sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index") if category == "sampler":
steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
with gr.Group():
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width") elif category == "dimensions":
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height") with FormRow():
with gr.Column(elem_id="img2img_column_size", scale=4):
with gr.Row(): width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
tiling = gr.Checkbox(label='Tiling', value=False)
if opts.dimensions_and_batch_together:
with gr.Row(): with gr.Column(elem_id="img2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
with gr.Group(): elif category == "cfg":
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0) with FormGroup():
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75) cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
elif category == "seed":
with gr.Group(): seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
elif category == "checkboxes":
img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples) with FormRow(elem_id="img2img_checkboxes"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
elif category == "batch":
if not opts.dimensions_and_batch_together:
with FormRow(elem_id="img2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
elif category == "scripts":
with FormGroup(elem_id="img2img_script_container"):
custom_inputs = modules.scripts.scripts_img2img.setup_ui()
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
...@@ -925,7 +960,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -925,7 +960,7 @@ def create_ui(wrap_gradio_gpu_call):
) )
img2img_args = dict( img2img_args = dict(
fn=wrap_gradio_gpu_call(modules.img2img.img2img), fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
_js="submit_img2img", _js="submit_img2img",
inputs=[ inputs=[
dummy_component, dummy_component,
...@@ -935,12 +970,14 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -935,12 +970,14 @@ def create_ui(wrap_gradio_gpu_call):
img2img_prompt_style2, img2img_prompt_style2,
init_img, init_img,
init_img_with_mask, init_img_with_mask,
init_img_with_mask_orig,
init_img_inpaint, init_img_inpaint,
init_mask_inpaint, init_mask_inpaint,
mask_mode, mask_mode,
steps, steps,
sampler_index, sampler_index,
mask_blur, mask_blur,
mask_alpha,
inpainting_fill, inpainting_fill,
restore_faces, restore_faces,
tiling, tiling,
...@@ -962,7 +999,8 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -962,7 +999,8 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[ outputs=[
img2img_gallery, img2img_gallery,
generation_info, generation_info,
html_info html_info,
html_log,
], ],
show_progress=False, show_progress=False,
) )
...@@ -976,25 +1014,12 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -976,25 +1014,12 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[img2img_prompt], outputs=[img2img_prompt],
) )
if cmd_opts.deepdanbooru:
img2img_deepbooru.click( img2img_deepbooru.click(
fn=interrogate_deepbooru, fn=interrogate_deepbooru,
inputs=[init_img], inputs=[init_img],
outputs=[img2img_prompt], outputs=[img2img_prompt],
) )
roll.click(
fn=roll_artist,
_js="update_img2img_tokens",
inputs=[
img2img_prompt,
],
outputs=[
img2img_prompt,
]
)
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
...@@ -1035,59 +1060,62 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1035,59 +1060,62 @@ def create_ui(wrap_gradio_gpu_call):
(seed_resize_from_w, "Seed resize from-1"), (seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"), (seed_resize_from_h, "Seed resize from-2"),
(denoising_strength, "Denoising strength"), (denoising_strength, "Denoising strength"),
(mask_blur, "Mask blur"),
*modules.scripts.scripts_img2img.infotext_fields *modules.scripts.scripts_img2img.infotext_fields
] ]
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
modules.scripts.scripts_current = None
with gr.Blocks(analytics_enabled=False) as extras_interface: with gr.Blocks(analytics_enabled=False) as extras_interface:
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
with gr.Tabs(elem_id="mode_extras"): with gr.Tabs(elem_id="mode_extras"):
with gr.TabItem('Single Image'): with gr.TabItem('Single Image', elem_id="extras_single_tab"):
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil") extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
with gr.TabItem('Batch Process'): with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"):
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file") image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
with gr.TabItem('Batch from Directory'): with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"):
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.") extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.") extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
show_extras_results = gr.Checkbox(label='Show result images', value=True) show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
with gr.Tabs(elem_id="extras_resize_mode"): with gr.Tabs(elem_id="extras_resize_mode"):
with gr.TabItem('Scale by'): with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"):
upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4) upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
with gr.TabItem('Scale to'): with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"):
with gr.Group(): with gr.Group():
with gr.Row(): with gr.Row():
upscaling_resize_w = gr.Number(label="Width", value=512, precision=0) upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
upscaling_resize_h = gr.Number(label="Height", value=512, precision=0) upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True) upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
with gr.Group(): with gr.Group():
extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
with gr.Group(): with gr.Group():
extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1) extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility")
with gr.Group(): with gr.Group():
gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan) gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility")
with gr.Group(): with gr.Group():
codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer) codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility")
codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer) codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight")
with gr.Group(): with gr.Group():
upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False) upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix")
result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples) result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples)
submit.click( submit.click(
fn=wrap_gradio_gpu_call(modules.extras.run_extras), fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']),
_js="get_extras_tab_index", _js="get_extras_tab_index",
inputs=[ inputs=[
dummy_component, dummy_component,
...@@ -1129,7 +1157,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1129,7 +1157,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
html = gr.HTML() html = gr.HTML()
generation_info = gr.Textbox(visible=False) generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info")
html2 = gr.HTML() html2 = gr.HTML()
with gr.Row(): with gr.Row():
buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
...@@ -1148,19 +1176,27 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1148,19 +1176,27 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row(): with gr.Row():
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
custom_name = gr.Textbox(label="Custom Name (Optional)") create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method") custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
save_as_half = gr.Checkbox(value=False, label="Save as float16") interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
with gr.Row():
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
with gr.Blocks(analytics_enabled=False) as train_interface: with gr.Blocks(analytics_enabled=False) as train_interface:
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>") gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
...@@ -1169,65 +1205,67 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1169,65 +1205,67 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tabs(elem_id="train_tabs"): with gr.Tabs(elem_id="train_tabs"):
with gr.Tab(label="Create embedding"): with gr.Tab(label="Create embedding"):
new_embedding_name = gr.Textbox(label="Name") new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name")
initialization_text = gr.Textbox(label="Initialization text", value="*") initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt")
overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding") overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding")
with gr.Row(): with gr.Row():
with gr.Column(scale=3): with gr.Column(scale=3):
gr.HTML(value="") gr.HTML(value="")
with gr.Column(): with gr.Column():
create_embedding = gr.Button(value="Create embedding", variant='primary') create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding")
with gr.Tab(label="Create hypernetwork"): with gr.Tab(label="Create hypernetwork"):
new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys) new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func")
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"]) new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout") new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork")
with gr.Row(): with gr.Row():
with gr.Column(scale=3): with gr.Column(scale=3):
gr.HTML(value="") gr.HTML(value="")
with gr.Column(): with gr.Column():
create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary') create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
with gr.Tab(label="Preprocess images"): with gr.Tab(label="Preprocess images"):
process_src = gr.Textbox(label='Source directory') process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
process_dst = gr.Textbox(label='Destination directory') process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height")
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"]) preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
with gr.Row(): with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies') process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
process_split = gr.Checkbox(label='Split oversized images') process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
process_focal_crop = gr.Checkbox(label='Auto focal point crop') process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
process_caption = gr.Checkbox(label='Use BLIP for caption') process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption")
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False) process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru")
with gr.Row(visible=False) as process_split_extra_row: with gr.Row(visible=False) as process_split_extra_row:
process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold")
process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05) process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio")
with gr.Row(visible=False) as process_focal_crop_row: with gr.Row(visible=False) as process_focal_crop_row:
process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05) process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight")
process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05) process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05) process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
process_focal_crop_debug = gr.Checkbox(label='Create debug image') process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
with gr.Row(): with gr.Row():
with gr.Column(scale=3): with gr.Column(scale=3):
gr.HTML(value="") gr.HTML(value="")
with gr.Column(): with gr.Column():
run_preprocess = gr.Button(value="Preprocess", variant='primary') with gr.Row():
interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing")
run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess")
process_split.change( process_split.change(
fn=lambda show: gr_show(show), fn=lambda show: gr_show(show),
...@@ -1250,27 +1288,35 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1250,27 +1288,35 @@ def create_ui(wrap_gradio_gpu_call):
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
with gr.Row(): with gr.Row():
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005") embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
with gr.Row(): with gr.Row():
clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
batch_size = gr.Number(label='Batch size', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file")
training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
with gr.Row(): with gr.Row():
interrupt_training = gr.Button(value="Interrupt") shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary') tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
train_embedding = gr.Button(value="Train Embedding", variant='primary') with gr.Row():
latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
with gr.Row():
interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training")
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork")
train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
params = script_callbacks.UiTrainTabParams(txt2img_preview_params) params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
...@@ -1354,6 +1400,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1354,6 +1400,7 @@ def create_ui(wrap_gradio_gpu_call):
train_embedding_name, train_embedding_name,
embedding_learn_rate, embedding_learn_rate,
batch_size, batch_size,
gradient_step,
dataset_directory, dataset_directory,
log_directory, log_directory,
training_width, training_width,
...@@ -1361,6 +1408,9 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1361,6 +1408,9 @@ def create_ui(wrap_gradio_gpu_call):
steps, steps,
clip_grad_mode, clip_grad_mode,
clip_grad_value, clip_grad_value,
shuffle_tags,
tag_drop_out,
latent_sampling_method,
create_image_every, create_image_every,
save_embedding_every, save_embedding_every,
template_file, template_file,
...@@ -1381,6 +1431,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1381,6 +1431,7 @@ def create_ui(wrap_gradio_gpu_call):
train_hypernetwork_name, train_hypernetwork_name,
hypernetwork_learn_rate, hypernetwork_learn_rate,
batch_size, batch_size,
gradient_step,
dataset_directory, dataset_directory,
log_directory, log_directory,
training_width, training_width,
...@@ -1388,6 +1439,9 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1388,6 +1439,9 @@ def create_ui(wrap_gradio_gpu_call):
steps, steps,
clip_grad_mode, clip_grad_mode,
clip_grad_value, clip_grad_value,
shuffle_tags,
tag_drop_out,
latent_sampling_method,
create_image_every, create_image_every,
save_embedding_every, save_embedding_every,
template_file, template_file,
...@@ -1406,6 +1460,12 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1406,6 +1460,12 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[], outputs=[],
) )
interrupt_preprocessing.click(
fn=lambda: shared.state.interrupt(),
inputs=[],
outputs=[],
)
def create_setting_component(key, is_quicksettings=False): def create_setting_component(key, is_quicksettings=False):
def fun(): def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default return opts.data[key] if key in opts.data else opts.data_labels[key].default
...@@ -1433,7 +1493,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1433,7 +1493,7 @@ def create_ui(wrap_gradio_gpu_call):
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else: else:
with gr.Row(variant="compact"): with FormRow():
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else: else:
...@@ -1457,76 +1517,57 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1457,76 +1517,57 @@ def create_ui(wrap_gradio_gpu_call):
if comp == dummy_component: if comp == dummy_component:
continue continue
oldval = opts.data.get(key, None) if opts.set(key, value):
try:
setattr(opts, key, value)
except RuntimeError:
continue
if oldval != value:
if opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange()
changed.append(key) changed.append(key)
try: try:
opts.save(shared.config_filename) opts.save(shared.config_filename)
except RuntimeError: except RuntimeError:
return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
return opts.dumpjson(), f'{len(changed)} settings changed: {", ".join(changed)}.' return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.'
def run_settings_single(value, key): def run_settings_single(value, key):
if not opts.same_type(value, opts.data_labels[key].default): if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson() return gr.update(visible=True), opts.dumpjson()
oldval = opts.data.get(key, None) if not opts.set(key, value):
try: return gr.update(value=getattr(opts, key)), opts.dumpjson()
setattr(opts, key, value)
except Exception:
return gr.update(value=oldval), opts.dumpjson()
if oldval != value:
if opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange()
opts.save(shared.config_filename) opts.save(shared.config_filename)
return gr.update(value=value), opts.dumpjson() return gr.update(value=value), opts.dumpjson()
with gr.Blocks(analytics_enabled=False) as settings_interface: with gr.Blocks(analytics_enabled=False) as settings_interface:
settings_submit = gr.Button(value="Apply settings", variant='primary') with gr.Row():
result = gr.HTML() with gr.Column(scale=6):
settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
with gr.Column():
restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
settings_cols = 3 result = gr.HTML(elem_id="settings_result")
items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols)
quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings') quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}
quicksettings_list = [] quicksettings_list = []
cols_displayed = 0
items_displayed = 0
previous_section = None previous_section = None
column = None current_tab = None
with gr.Row(elem_id="settings").style(equal_height=False): with gr.Tabs(elem_id="settings"):
for i, (k, item) in enumerate(opts.data_labels.items()): for i, (k, item) in enumerate(opts.data_labels.items()):
section_must_be_skipped = item.section[0] is None section_must_be_skipped = item.section[0] is None
if previous_section != item.section and not section_must_be_skipped: if previous_section != item.section and not section_must_be_skipped:
if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None): elem_id, text = item.section
if column is not None:
column.__exit__()
column = gr.Column(variant='panel') if current_tab is not None:
column.__enter__() current_tab.__exit__()
items_displayed = 0 current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text)
cols_displayed += 1 current_tab.__enter__()
previous_section = item.section previous_section = item.section
elem_id, text = item.section
gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='<h1 class="gr-button-lg">{}</h1>'.format(text))
if k in quicksettings_names and not shared.cmd_opts.freeze_settings: if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
quicksettings_list.append((i, k, item)) quicksettings_list.append((i, k, item))
components.append(dummy_component) components.append(dummy_component)
...@@ -1536,15 +1577,21 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1536,15 +1577,21 @@ def create_ui(wrap_gradio_gpu_call):
component = create_setting_component(k) component = create_setting_component(k)
component_dict[k] = component component_dict[k] = component
components.append(component) components.append(component)
items_displayed += 1
with gr.Row(): if current_tab is not None:
current_tab.__exit__()
with gr.TabItem("Actions"):
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
download_localization = gr.Button(value='Download localization template', elem_id="download_localization") download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
with gr.Row(): if os.path.exists("html/licenses.html"):
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') with open("html/licenses.html", encoding="utf8") as file:
restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') with gr.TabItem("Licenses"):
gr.HTML(file.read(), elem_id="licenses")
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
request_notifications.click( request_notifications.click(
fn=lambda: None, fn=lambda: None,
...@@ -1581,9 +1628,6 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1581,9 +1628,6 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[], outputs=[],
) )
if column is not None:
column.__exit__()
interfaces = [ interfaces = [
(txt2img_interface, "txt2img", "txt2img"), (txt2img_interface, "txt2img", "txt2img"),
(img2img_interface, "img2img", "img2img"), (img2img_interface, "img2img", "img2img"),
...@@ -1617,7 +1661,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1617,7 +1661,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings"): with gr.Row(elem_id="quicksettings"):
for i, k, item in quicksettings_list: for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
component = create_setting_component(k, is_quicksettings=True) component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component component_dict[k] = component
...@@ -1632,6 +1676,10 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1632,6 +1676,10 @@ def create_ui(wrap_gradio_gpu_call):
if os.path.exists(os.path.join(script_path, "notification.mp3")): if os.path.exists(os.path.join(script_path, "notification.mp3")):
audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
if os.path.exists("html/footer.html"):
with open("html/footer.html", encoding="utf8") as file:
gr.HTML(file.read(), elem_id="footer")
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
settings_submit.click( settings_submit.click(
fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
...@@ -1666,7 +1714,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1666,7 +1714,7 @@ def create_ui(wrap_gradio_gpu_call):
print("Error loading/saving model file:", file=sys.stderr) print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
modules.sd_models.list_models() # to remove the potentially missing models from the list modules.sd_models.list_models() # to remove the potentially missing models from the list
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)] return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)]
return results return results
modelmerger_merge.click( modelmerger_merge.click(
...@@ -1679,6 +1727,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1679,6 +1727,7 @@ def create_ui(wrap_gradio_gpu_call):
interp_amount, interp_amount,
save_as_half, save_as_half,
custom_name, custom_name,
checkpoint_format,
], ],
outputs=[ outputs=[
submit_result, submit_result,
......
import gradio as gr
class ToolButton(gr.Button, gr.components.FormComponent):
"""Small button with single emoji as text, fits inside gradio forms"""
def __init__(self, **kwargs):
super().__init__(variant="tool", **kwargs)
def get_block_name(self):
return "button"
class FormRow(gr.Row, gr.components.FormComponent):
"""Same as gr.Row but fits inside gradio forms"""
def get_block_name(self):
return "row"
class FormGroup(gr.Group, gr.components.FormComponent):
"""Same as gr.Row but fits inside gradio forms"""
def get_block_name(self):
return "group"
...@@ -9,6 +9,8 @@ import git ...@@ -9,6 +9,8 @@ import git
import gradio as gr import gradio as gr
import html import html
import shutil
import errno
from modules import extensions, shared, paths from modules import extensions, shared, paths
...@@ -17,7 +19,7 @@ available_extensions = {"extensions": []} ...@@ -17,7 +19,7 @@ available_extensions = {"extensions": []}
def check_access(): def check_access():
assert not shared.cmd_opts.disable_extension_access, "extension access disabed because of commandline flags" assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags"
def apply_and_restart(disable_list, update_list): def apply_and_restart(disable_list, update_list):
...@@ -36,9 +38,9 @@ def apply_and_restart(disable_list, update_list): ...@@ -36,9 +38,9 @@ def apply_and_restart(disable_list, update_list):
continue continue
try: try:
ext.pull() ext.fetch_and_reset_hard()
except Exception: except Exception:
print(f"Error pulling updates for {ext.name}:", file=sys.stderr) print(f"Error getting updates for {ext.name}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
shared.opts.disabled_extensions = disabled shared.opts.disabled_extensions = disabled
...@@ -78,6 +80,12 @@ def extension_table(): ...@@ -78,6 +80,12 @@ def extension_table():
""" """
for ext in extensions.extensions: for ext in extensions.extensions:
remote = ""
if ext.is_builtin:
remote = "built-in"
elif ext.remote:
remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
if ext.can_update: if ext.can_update:
ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>""" ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
else: else:
...@@ -86,7 +94,7 @@ def extension_table(): ...@@ -86,7 +94,7 @@ def extension_table():
code += f""" code += f"""
<tr> <tr>
<td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td> <td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
<td><a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape(ext.remote or '')}</a></td> <td>{remote}</td>
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td> <td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
</tr> </tr>
""" """
...@@ -132,7 +140,21 @@ def install_extension_from_url(dirname, url): ...@@ -132,7 +140,21 @@ def install_extension_from_url(dirname, url):
repo = git.Repo.clone_from(url, tmpdir) repo = git.Repo.clone_from(url, tmpdir)
repo.remote().fetch() repo.remote().fetch()
try:
os.rename(tmpdir, target_dir) os.rename(tmpdir, target_dir)
except OSError as err:
# TODO what does this do on windows? I think it'll be a different error code but I don't have a system to check it
# Shouldn't cause any new issues at least but we probably want to handle it there too.
if err.errno == errno.EXDEV:
# Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems
# Since we can't use a rename, do the slower but more versitile shutil.move()
shutil.move(tmpdir, target_dir)
else:
# Something else, not enough free space, permissions, etc. rethrow it so that it gets handled.
raise(err)
import launch
launch.run_extension_installer(target_dir)
extensions.list_extensions() extensions.list_extensions()
return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")] return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
...@@ -197,12 +219,13 @@ def refresh_available_extensions_from_data(hide_tags): ...@@ -197,12 +219,13 @@ def refresh_available_extensions_from_data(hide_tags):
if url is None: if url is None:
continue continue
existing = installed_extension_urls.get(normalize_git_url(url), None)
extension_tags = extension_tags + ["installed"] if existing else extension_tags
if len([x for x in extension_tags if x in tags_to_hide]) > 0: if len([x for x in extension_tags if x in tags_to_hide]) > 0:
hidden += 1 hidden += 1
continue continue
existing = installed_extension_urls.get(normalize_git_url(url), None)
install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">""" install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags]) tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
...@@ -213,8 +236,12 @@ def refresh_available_extensions_from_data(hide_tags): ...@@ -213,8 +236,12 @@ def refresh_available_extensions_from_data(hide_tags):
<td>{html.escape(description)}</td> <td>{html.escape(description)}</td>
<td>{install_code}</td> <td>{install_code}</td>
</tr> </tr>
""" """
for tag in [x for x in extension_tags if x not in tags]:
tags[tag] = tag
code += """ code += """
</tbody> </tbody>
</table> </table>
...@@ -263,7 +290,7 @@ def create_ui(): ...@@ -263,7 +290,7 @@ def create_ui():
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False) install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
with gr.Row(): with gr.Row():
hide_tags = gr.CheckboxGroup(value=["ads", "localization"], label="Hide extensions with tags", choices=["script", "ads", "localization"]) hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
install_result = gr.HTML() install_result = gr.HTML()
available_extensions_table = gr.HTML() available_extensions_table = gr.HTML()
......
import os
import tempfile
from collections import namedtuple
from pathlib import Path
import gradio as gr
from PIL import PngImagePlugin
from modules import shared
Savedfile = namedtuple("Savedfile", ["name"])
def register_tmp_file(gradio, filename):
if hasattr(gradio, 'temp_file_sets'): # gradio 3.15
gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
if hasattr(gradio, 'temp_dirs'): # gradio 3.9
gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}
def check_tmp_file(gradio, filename):
if hasattr(gradio, 'temp_file_sets'):
return any([filename in fileset for fileset in gradio.temp_file_sets])
if hasattr(gradio, 'temp_dirs'):
return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
return False
def save_pil_to_file(pil_image, dir=None):
already_saved_as = getattr(pil_image, 'already_saved_as', None)
if already_saved_as and os.path.isfile(already_saved_as):
register_tmp_file(shared.demo, already_saved_as)
file_obj = Savedfile(already_saved_as)
return file_obj
if shared.opts.temp_dir != "":
dir = shared.opts.temp_dir
use_metadata = False
metadata = PngImagePlugin.PngInfo()
for key, value in pil_image.info.items():
if isinstance(key, str) and isinstance(value, str):
metadata.add_text(key, value)
use_metadata = True
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
return file_obj
# override save to file function so that it also writes PNG info
gr.processing_utils.save_pil_to_file = save_pil_to_file
def on_tmpdir_changed():
if shared.opts.temp_dir == "" or shared.demo is None:
return
os.makedirs(shared.opts.temp_dir, exist_ok=True)
register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
def cleanup_tmpdr():
temp_dir = shared.opts.temp_dir
if temp_dir == "" or not os.path.isdir(temp_dir):
return
for root, dirs, files in os.walk(temp_dir, topdown=False):
for name in files:
_, extension = os.path.splitext(name)
if extension != ".png":
continue
filename = os.path.join(root, name)
os.remove(filename)
...@@ -53,10 +53,10 @@ class Upscaler: ...@@ -53,10 +53,10 @@ class Upscaler:
def do_upscale(self, img: PIL.Image, selected_model: str): def do_upscale(self, img: PIL.Image, selected_model: str):
return img return img
def upscale(self, img: PIL.Image, scale: int, selected_model: str = None): def upscale(self, img: PIL.Image, scale, selected_model: str = None):
self.scale = scale self.scale = scale
dest_w = img.width * scale dest_w = int(img.width * scale)
dest_h = img.height * scale dest_h = int(img.height * scale)
for i in range(3): for i in range(3):
shape = (img.width, img.height) shape = (img.width, img.height)
......
from transformers import BertPreTrainedModel,BertModel,BertConfig
import torch.nn as nn
import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional
class BertSeriesConfig(BertConfig):
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
self.project_dim = project_dim
self.pooler_fn = pooler_fn
self.learn_encoder = learn_encoder
class RobertaSeriesConfig(XLMRobertaConfig):
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.project_dim = project_dim
self.pooler_fn = pooler_fn
self.learn_encoder = learn_encoder
class BertSeriesModelWithTransformation(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
config_class = BertSeriesConfig
def __init__(self, config=None, **kargs):
# modify initialization for autoloading
if config is None:
config = XLMRobertaConfig()
config.attention_probs_dropout_prob= 0.1
config.bos_token_id=0
config.eos_token_id=2
config.hidden_act='gelu'
config.hidden_dropout_prob=0.1
config.hidden_size=1024
config.initializer_range=0.02
config.intermediate_size=4096
config.layer_norm_eps=1e-05
config.max_position_embeddings=514
config.num_attention_heads=16
config.num_hidden_layers=24
config.output_past=True
config.pad_token_id=1
config.position_embedding_type= "absolute"
config.type_vocab_size= 1
config.use_cache=True
config.vocab_size= 250002
config.project_dim = 768
config.learn_encoder = False
super().__init__(config)
self.roberta = XLMRobertaModel(config)
self.transformation = nn.Linear(config.hidden_size,config.project_dim)
self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
self.pooler = lambda x: x[:,0]
self.post_init()
def encode(self,c):
device = next(self.parameters()).device
text = self.tokenizer(c,
truncation=True,
max_length=77,
return_length=False,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt")
text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
text["attention_mask"] = torch.tensor(
text['attention_mask']).to(device)
features = self(**text)
return features['projection_state']
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) :
r"""
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.roberta(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=True,
return_dict=return_dict,
)
# last module outputs
sequence_output = outputs[0]
# project every module
sequence_output_ln = self.pre_LN(sequence_output)
# pooler
pooler_output = self.pooler(sequence_output_ln)
pooler_output = self.transformation(pooler_output)
projection_state = self.transformation(outputs.last_hidden_state)
return {
'pooler_output':pooler_output,
'last_hidden_state':outputs.last_hidden_state,
'hidden_states':outputs.hidden_states,
'attentions':outputs.attentions,
'projection_state':projection_state,
'sequence_out': sequence_output
}
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
base_model_prefix = 'roberta'
config_class= RobertaSeriesConfig
\ No newline at end of file
blendmodes==2022
transformers==4.19.2 transformers==4.19.2
diffusers==0.3.0 accelerate==0.12.0
basicsr==1.4.2 basicsr==1.4.2
gfpgan==1.3.8 gfpgan==1.3.8
gradio==3.9 gradio==3.15.0
numpy==1.23.3 numpy==1.23.3
Pillow==9.2.0 Pillow==9.4.0
realesrgan==0.3.0 realesrgan==0.3.0
torch torch
omegaconf==2.2.3 omegaconf==2.2.3
...@@ -24,3 +25,6 @@ kornia==0.6.7 ...@@ -24,3 +25,6 @@ kornia==0.6.7
lark==1.1.2 lark==1.1.2
inflection==0.5.1 inflection==0.5.1
GitPython==3.1.27 GitPython==3.1.27
torchsde==0.2.5
safetensors==0.2.7
httpcore<=0.15
function gradioApp(){ function gradioApp() {
return document.getElementsByTagName('gradio-app')[0].shadowRoot; const gradioShadowRoot = document.getElementsByTagName('gradio-app')[0].shadowRoot
return !!gradioShadowRoot ? gradioShadowRoot : document;
} }
function get_uiCurrentTab() { function get_uiCurrentTab() {
return gradioApp().querySelector('.tabs button:not(.border-transparent)') return gradioApp().querySelector('#tabs button:not(.border-transparent)')
} }
function get_uiCurrentTabContent() { function get_uiCurrentTabContent() {
......
...@@ -157,7 +157,7 @@ class Script(scripts.Script): ...@@ -157,7 +157,7 @@ class Script(scripts.Script):
def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment): def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):
# Override # Override
if override_sampler: if override_sampler:
p.sampler_index = [sampler.name for sampler in sd_samplers.samplers].index("Euler") p.sampler_name = "Euler"
if override_prompt: if override_prompt:
p.prompt = original_prompt p.prompt = original_prompt
p.negative_prompt = original_negative_prompt p.negative_prompt = original_negative_prompt
...@@ -191,7 +191,7 @@ class Script(scripts.Script): ...@@ -191,7 +191,7 @@ class Script(scripts.Script):
combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5) combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, p.sampler_index, p.sd_model) sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
sigmas = sampler.model_wrap.get_sigmas(p.steps) sigmas = sampler.model_wrap.get_sigmas(p.steps)
......
...@@ -18,7 +18,7 @@ def draw_xy_grid(xs, ys, x_label, y_label, cell): ...@@ -18,7 +18,7 @@ def draw_xy_grid(xs, ys, x_label, y_label, cell):
ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys] ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys]
hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs] hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs]
first_pocessed = None first_processed = None
state.job_count = len(xs) * len(ys) state.job_count = len(xs) * len(ys)
...@@ -27,17 +27,17 @@ def draw_xy_grid(xs, ys, x_label, y_label, cell): ...@@ -27,17 +27,17 @@ def draw_xy_grid(xs, ys, x_label, y_label, cell):
state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
processed = cell(x, y) processed = cell(x, y)
if first_pocessed is None: if first_processed is None:
first_pocessed = processed first_processed = processed
res.append(processed.images[0]) res.append(processed.images[0])
grid = images.image_grid(res, rows=len(ys)) grid = images.image_grid(res, rows=len(ys))
grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts) grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
first_pocessed.images = [grid] first_processed.images = [grid]
return first_pocessed return first_processed
class Script(scripts.Script): class Script(scripts.Script):
...@@ -46,10 +46,11 @@ class Script(scripts.Script): ...@@ -46,10 +46,11 @@ class Script(scripts.Script):
def ui(self, is_img2img): def ui(self, is_img2img):
put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False) put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False)
different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False)
return [put_at_start] return [put_at_start, different_seeds]
def run(self, p, put_at_start): def run(self, p, put_at_start, different_seeds):
modules.processing.fix_seed(p) modules.processing.fix_seed(p)
original_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt original_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt
...@@ -73,15 +74,17 @@ class Script(scripts.Script): ...@@ -73,15 +74,17 @@ class Script(scripts.Script):
print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.") print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
p.prompt = all_prompts p.prompt = all_prompts
p.seed = [p.seed for _ in all_prompts] p.seed = [p.seed + (i if different_seeds else 0) for i in range(len(all_prompts))]
p.prompt_for_display = original_prompt p.prompt_for_display = original_prompt
processed = process_images(p) processed = process_images(p)
grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2)) grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts) grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
processed.images.insert(0, grid) processed.images.insert(0, grid)
processed.index_of_first_image = 1
processed.infotexts.insert(0, processed.infotexts[0])
if opts.grid_save: if opts.grid_save:
images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", prompt=original_prompt, seed=processed.seed, grid=True, p=p) images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", extension=opts.grid_format, prompt=original_prompt, seed=processed.seed, grid=True, p=p)
return processed return processed
...@@ -9,6 +9,7 @@ import shlex ...@@ -9,6 +9,7 @@ import shlex
import modules.scripts as scripts import modules.scripts as scripts
import gradio as gr import gradio as gr
from modules import sd_samplers
from modules.processing import Processed, process_images from modules.processing import Processed, process_images
from PIL import Image from PIL import Image
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
...@@ -44,6 +45,7 @@ prompt_tags = { ...@@ -44,6 +45,7 @@ prompt_tags = {
"seed_resize_from_h": process_int_tag, "seed_resize_from_h": process_int_tag,
"seed_resize_from_w": process_int_tag, "seed_resize_from_w": process_int_tag,
"sampler_index": process_int_tag, "sampler_index": process_int_tag,
"sampler_name": process_string_tag,
"batch_size": process_int_tag, "batch_size": process_int_tag,
"n_iter": process_int_tag, "n_iter": process_int_tag,
"steps": process_int_tag, "steps": process_int_tag,
...@@ -66,14 +68,28 @@ def cmdargs(line): ...@@ -66,14 +68,28 @@ def cmdargs(line):
arg = args[pos] arg = args[pos]
assert arg.startswith("--"), f'must start with "--": {arg}' assert arg.startswith("--"), f'must start with "--": {arg}'
assert pos+1 < len(args), f'missing argument for command line option {arg}'
tag = arg[2:] tag = arg[2:]
if tag == "prompt" or tag == "negative_prompt":
pos += 1
prompt = args[pos]
pos += 1
while pos < len(args) and not args[pos].startswith("--"):
prompt += " "
prompt += args[pos]
pos += 1
res[tag] = prompt
continue
func = prompt_tags.get(tag, None) func = prompt_tags.get(tag, None)
assert func, f'unknown commandline option: {arg}' assert func, f'unknown commandline option: {arg}'
assert pos+1 < len(args), f'missing argument for command line option {arg}'
val = args[pos+1] val = args[pos+1]
if tag == "sampler_name":
val = sd_samplers.samplers_map.get(val.lower(), None)
res[tag] = func(val) res[tag] = func(val)
...@@ -124,7 +140,7 @@ class Script(scripts.Script): ...@@ -124,7 +140,7 @@ class Script(scripts.Script):
try: try:
args = cmdargs(line) args = cmdargs(line)
except Exception: except Exception:
print(f"Error parsing line [line] as commandline:", file=sys.stderr) print(f"Error parsing line {line} as commandline:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
args = {"prompt": line} args = {"prompt": line}
else: else:
...@@ -145,6 +161,8 @@ class Script(scripts.Script): ...@@ -145,6 +161,8 @@ class Script(scripts.Script):
state.job_count = job_count state.job_count = job_count
images = [] images = []
all_prompts = []
infotexts = []
for n, args in enumerate(jobs): for n, args in enumerate(jobs):
state.job = f"{state.job_no + 1} out of {state.job_count}" state.job = f"{state.job_no + 1} out of {state.job_count}"
...@@ -157,5 +175,7 @@ class Script(scripts.Script): ...@@ -157,5 +175,7 @@ class Script(scripts.Script):
if checkbox_iterate: if checkbox_iterate:
p.seed = p.seed + (p.batch_size * p.n_iter) p.seed = p.seed + (p.batch_size * p.n_iter)
all_prompts += proc.all_prompts
infotexts += proc.infotexts
return Processed(p, images, p.seed, "") return Processed(p, images, p.seed, "", all_prompts=all_prompts, infotexts=infotexts)
...@@ -17,13 +17,14 @@ class Script(scripts.Script): ...@@ -17,13 +17,14 @@ class Script(scripts.Script):
return is_img2img return is_img2img
def ui(self, is_img2img): def ui(self, is_img2img):
info = gr.HTML("<p style=\"margin-bottom:0.75em\">Will upscale the image to twice the dimensions; use width and height sliders to set tile size</p>") info = gr.HTML("<p style=\"margin-bottom:0.75em\">Will upscale the image by the selected scale factor; use width and height sliders to set tile size</p>")
overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64) overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64)
scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0)
upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
return [info, overlap, upscaler_index] return [info, overlap, upscaler_index, scale_factor]
def run(self, p, _, overlap, upscaler_index): def run(self, p, _, overlap, upscaler_index, scale_factor):
processing.fix_seed(p) processing.fix_seed(p)
upscaler = shared.sd_upscalers[upscaler_index] upscaler = shared.sd_upscalers[upscaler_index]
...@@ -34,9 +35,10 @@ class Script(scripts.Script): ...@@ -34,9 +35,10 @@ class Script(scripts.Script):
seed = p.seed seed = p.seed
init_img = p.init_images[0] init_img = p.init_images[0]
init_img = images.flatten(init_img, opts.img2img_background_color)
if(upscaler.name != "None"): if upscaler.name != "None":
img = upscaler.scaler.upscale(init_img, 2, upscaler.data_path) img = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path)
else: else:
img = init_img img = init_img
...@@ -69,7 +71,7 @@ class Script(scripts.Script): ...@@ -69,7 +71,7 @@ class Script(scripts.Script):
work_results = [] work_results = []
for i in range(batch_count): for i in range(batch_count):
p.batch_size = batch_size p.batch_size = batch_size
p.init_images = work[i*batch_size:(i+1)*batch_size] p.init_images = work[i * batch_size:(i + 1) * batch_size]
state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}" state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}"
processed = processing.process_images(p) processed = processing.process_images(p)
......
...@@ -10,13 +10,16 @@ import numpy as np ...@@ -10,13 +10,16 @@ import numpy as np
import modules.scripts as scripts import modules.scripts as scripts
import gradio as gr import gradio as gr
from modules import images from modules import images, paths, sd_samplers, processing
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from modules.processing import process_images, Processed, get_correct_sampler, StableDiffusionProcessingTxt2Img from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
import modules.sd_samplers import modules.sd_samplers
import modules.sd_models import modules.sd_models
import modules.sd_vae
import glob
import os
import re import re
...@@ -60,27 +63,17 @@ def apply_order(p, x, xs): ...@@ -60,27 +63,17 @@ def apply_order(p, x, xs):
p.prompt = prompt_tmp + p.prompt p.prompt = prompt_tmp + p.prompt
def build_samplers_dict(p):
samplers_dict = {}
for i, sampler in enumerate(get_correct_sampler(p)):
samplers_dict[sampler.name.lower()] = i
for alias in sampler.aliases:
samplers_dict[alias.lower()] = i
return samplers_dict
def apply_sampler(p, x, xs): def apply_sampler(p, x, xs):
sampler_index = build_samplers_dict(p).get(x.lower(), None) sampler_name = sd_samplers.samplers_map.get(x.lower(), None)
if sampler_index is None: if sampler_name is None:
raise RuntimeError(f"Unknown sampler: {x}") raise RuntimeError(f"Unknown sampler: {x}")
p.sampler_index = sampler_index p.sampler_name = sampler_name
def confirm_samplers(p, xs): def confirm_samplers(p, xs):
samplers_dict = build_samplers_dict(p)
for x in xs: for x in xs:
if x.lower() not in samplers_dict.keys(): if x.lower() not in sd_samplers.samplers_map:
raise RuntimeError(f"Unknown sampler: {x}") raise RuntimeError(f"Unknown sampler: {x}")
...@@ -124,6 +117,38 @@ def apply_clip_skip(p, x, xs): ...@@ -124,6 +117,38 @@ def apply_clip_skip(p, x, xs):
opts.data["CLIP_stop_at_last_layers"] = x opts.data["CLIP_stop_at_last_layers"] = x
def apply_upscale_latent_space(p, x, xs):
if x.lower().strip() != '0':
opts.data["use_scale_latent_for_hires_fix"] = True
else:
opts.data["use_scale_latent_for_hires_fix"] = False
def find_vae(name: str):
if name.lower() in ['auto', 'none']:
return name
else:
vae_path = os.path.abspath(os.path.join(paths.models_path, 'VAE'))
found = glob.glob(os.path.join(vae_path, f'**/{name}.*pt'), recursive=True)
if found:
return found[0]
else:
return 'auto'
def apply_vae(p, x, xs):
if x.lower().strip() == 'none':
modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file='None')
else:
found = find_vae(x)
if found:
v = modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=found)
def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
p.styles = x.split(',')
def format_value_add_label(p, opt, x): def format_value_add_label(p, opt, x):
if type(x) == float: if type(x) == float:
x = round(x, 8) x = round(x, 8)
...@@ -177,7 +202,10 @@ axis_options = [ ...@@ -177,7 +202,10 @@ axis_options = [
AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None), AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None), AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None),
AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None),
AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None), AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None),
AxisOption("VAE", str, apply_vae, format_value_add_label, None),
AxisOption("Styles", str, apply_styles, format_value_add_label, None),
] ]
...@@ -239,9 +267,11 @@ class SharedSettingsStackHelper(object): ...@@ -239,9 +267,11 @@ class SharedSettingsStackHelper(object):
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
self.hypernetwork = opts.sd_hypernetwork self.hypernetwork = opts.sd_hypernetwork
self.model = shared.sd_model self.model = shared.sd_model
self.vae = opts.sd_vae
def __exit__(self, exc_type, exc_value, tb): def __exit__(self, exc_type, exc_value, tb):
modules.sd_models.reload_model_weights(self.model) modules.sd_models.reload_model_weights(self.model)
modules.sd_vae.reload_vae_weights(self.model, vae_file=find_vae(self.vae))
hypernetwork.load_hypernetwork(self.hypernetwork) hypernetwork.load_hypernetwork(self.hypernetwork)
hypernetwork.apply_strength() hypernetwork.apply_strength()
...@@ -255,6 +285,7 @@ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d ...@@ -255,6 +285,7 @@ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d
re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*") re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*")
class Script(scripts.Script): class Script(scripts.Script):
def title(self): def title(self):
return "X/Y plot" return "X/Y plot"
...@@ -351,7 +382,7 @@ class Script(scripts.Script): ...@@ -351,7 +382,7 @@ class Script(scripts.Script):
ys = process_axis(y_opt, y_values) ys = process_axis(y_opt, y_values)
def fix_axis_seeds(axis_opt, axis_list): def fix_axis_seeds(axis_opt, axis_list):
if axis_opt.label in ['Seed','Var. seed']: if axis_opt.label in ['Seed', 'Var. seed']:
return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
else: else:
return axis_list return axis_list
...@@ -373,12 +404,33 @@ class Script(scripts.Script): ...@@ -373,12 +404,33 @@ class Script(scripts.Script):
print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})") print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})")
shared.total_tqdm.updateTotal(total_steps * p.n_iter) shared.total_tqdm.updateTotal(total_steps * p.n_iter)
grid_infotext = [None]
def cell(x, y): def cell(x, y):
pc = copy(p) pc = copy(p)
x_opt.apply(pc, x, xs) x_opt.apply(pc, x, xs)
y_opt.apply(pc, y, ys) y_opt.apply(pc, y, ys)
return process_images(pc) res = process_images(pc)
if grid_infotext[0] is None:
pc.extra_generation_params = copy(pc.extra_generation_params)
if x_opt.label != 'Nothing':
pc.extra_generation_params["X Type"] = x_opt.label
pc.extra_generation_params["X Values"] = x_values
if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs])
if y_opt.label != 'Nothing':
pc.extra_generation_params["Y Type"] = y_opt.label
pc.extra_generation_params["Y Values"] = y_values
if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys])
grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)
return res
with SharedSettingsStackHelper(): with SharedSettingsStackHelper():
processed = draw_xy_grid( processed = draw_xy_grid(
...@@ -393,6 +445,6 @@ class Script(scripts.Script): ...@@ -393,6 +445,6 @@ class Script(scripts.Script):
) )
if opts.grid_save: if opts.grid_save:
images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p) images.save_image(processed.images[0], p.outpath_grids, "xy_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
return processed return processed
...@@ -73,8 +73,9 @@ ...@@ -73,8 +73,9 @@
margin-right: auto; margin-right: auto;
} }
#random_seed, #random_subseed, #reuse_seed, #reuse_subseed, #open_folder{ [id$=_random_seed], [id$=_random_subseed], [id$=_reuse_seed], [id$=_reuse_subseed], #open_folder{
min-width: auto; min-width: 2.3em;
height: 2.5em;
flex-grow: 0; flex-grow: 0;
padding-left: 0.25em; padding-left: 0.25em;
padding-right: 0.25em; padding-right: 0.25em;
...@@ -84,27 +85,28 @@ ...@@ -84,27 +85,28 @@
display: none; display: none;
} }
#seed_row, #subseed_row{ [id$=_seed_row], [id$=_subseed_row]{
gap: 0.5rem; gap: 0.5rem;
padding: 0.6em;
} }
#subseed_show_box{ [id$=_subseed_show_box]{
min-width: auto; min-width: auto;
flex-grow: 0; flex-grow: 0;
} }
#subseed_show_box > div{ [id$=_subseed_show_box] > div{
border: 0; border: 0;
height: 100%; height: 100%;
} }
#subseed_show{ [id$=_subseed_show]{
min-width: auto; min-width: auto;
flex-grow: 0; flex-grow: 0;
padding: 0; padding: 0;
} }
#subseed_show label{ [id$=_subseed_show] label{
height: 100%; height: 100%;
} }
...@@ -114,7 +116,7 @@ ...@@ -114,7 +116,7 @@
padding: 0.4em 0; padding: 0.4em 0;
} }
#roll, #paste, #style_create, #style_apply{ #roll_col > button {
min-width: 2em; min-width: 2em;
min-height: 2em; min-height: 2em;
max-width: 2em; max-width: 2em;
...@@ -206,24 +208,24 @@ button{ ...@@ -206,24 +208,24 @@ button{
fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block span{ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block span{
position: absolute; position: absolute;
top: -0.6em; top: -0.7em;
line-height: 1.2em; line-height: 1.2em;
padding: 0 0.5em; padding: 0;
margin: 0; margin: 0 0.5em;
background-color: white; background-color: white;
border-top: 1px solid #eee; box-shadow: 6px 0 6px 0px white, -6px 0 6px 0px white;
border-left: 1px solid #eee;
border-right: 1px solid #eee;
z-index: 300; z-index: 300;
} }
.dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{ .dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{
background-color: rgb(31, 41, 55); background-color: rgb(31, 41, 55);
border-top: 1px solid rgb(55 65 81); box-shadow: 6px 0 6px 0px rgb(31, 41, 55), -6px 0 6px 0px rgb(31, 41, 55);
border-left: 1px solid rgb(55 65 81); }
border-right: 1px solid rgb(55 65 81);
#txt2img_column_batch, #img2img_column_batch{
min-width: min(13.5em, 100%) !important;
} }
#settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{ #settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{
...@@ -232,22 +234,40 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s ...@@ -232,22 +234,40 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s
margin-right: 8em; margin-right: 8em;
} }
.gr-panel div.flex-col div.justify-between label span{
margin: 0;
}
#settings .gr-panel div.flex-col div.justify-between div{ #settings .gr-panel div.flex-col div.justify-between div{
position: relative; position: relative;
z-index: 200; z-index: 200;
} }
input[type="range"]{ #settings{
margin: 0.5em 0 -0.3em 0; display: block;
}
#settings > div{
border: none;
margin-left: 10em;
}
#settings > div.flex-wrap{
float: left;
display: block;
margin-left: 0;
width: 10em;
}
#settings > div.flex-wrap button{
display: block;
border: none;
text-align: left;
} }
#txt2img_sampling label{ #settings_result{
padding-left: 0.6em; height: 1.4em;
padding-right: 0.6em; margin: 0 1.2em;
}
input[type="range"]{
margin: 0.5em 0 -0.3em 0;
} }
#mask_bug_info { #mask_bug_info {
...@@ -501,13 +521,6 @@ input[type="range"]{ ...@@ -501,13 +521,6 @@ input[type="range"]{
padding: 0; padding: 0;
} }
#refresh_sd_model_checkpoint, #refresh_sd_vae, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
max-width: 2.5em;
min-width: 2.5em;
height: 2.4em;
}
canvas[key="mask"] { canvas[key="mask"] {
z-index: 12 !important; z-index: 12 !important;
filter: invert(); filter: invert();
...@@ -521,7 +534,7 @@ canvas[key="mask"] { ...@@ -521,7 +534,7 @@ canvas[key="mask"] {
position: absolute; position: absolute;
right: 0.5em; right: 0.5em;
top: -0.6em; top: -0.6em;
z-index: 200; z-index: 400;
width: 8em; width: 8em;
} }
#quicksettings .gr-box > div > div > input.gr-text-input { #quicksettings .gr-box > div > div > input.gr-text-input {
...@@ -568,6 +581,53 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h ...@@ -568,6 +581,53 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h
font-size: 95%; font-size: 95%;
} }
#image_buttons_txt2img button, #image_buttons_img2img button, #image_buttons_extras button{
min-width: auto;
padding-left: 0.5em;
padding-right: 0.5em;
}
.gr-form{
background-color: white;
}
.dark .gr-form{
background-color: rgb(31 41 55 / var(--tw-bg-opacity));
}
.gr-button-tool{
max-width: 2.5em;
min-width: 2.5em !important;
height: 2.4em;
margin: 0.55em 0;
}
#quicksettings .gr-button-tool{
margin: 0;
}
#img2img_settings > div.gr-form, #txt2img_settings > div.gr-form {
padding-top: 0.9em;
}
#img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form{
border: none;
padding-bottom: 0.5em;
}
footer {
display: none !important;
}
#footer{
text-align: center;
}
#footer div{
display: inline-block;
}
/* The following handles localization for right-to-left (RTL) languages like Arabic. /* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js. The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running If you change anything above, you need to make sure it is RTL compliant by just running
......
...@@ -11,8 +11,8 @@ class TestExtrasWorking(unittest.TestCase): ...@@ -11,8 +11,8 @@ class TestExtrasWorking(unittest.TestCase):
"codeformer_visibility": 0, "codeformer_visibility": 0,
"codeformer_weight": 0, "codeformer_weight": 0,
"upscaling_resize": 2, "upscaling_resize": 2,
"upscaling_resize_w": 512, "upscaling_resize_w": 128,
"upscaling_resize_h": 512, "upscaling_resize_h": 128,
"upscaling_crop": True, "upscaling_crop": True,
"upscaler_1": "None", "upscaler_1": "None",
"upscaler_2": "None", "upscaler_2": "None",
......
import unittest
import requests
class TestTxt2ImgWorking(unittest.TestCase):
def setUp(self):
self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img"
self.simple_txt2img = {
"enable_hr": False,
"denoising_strength": 0,
"firstphase_width": 0,
"firstphase_height": 0,
"prompt": "example prompt",
"styles": [],
"seed": -1,
"subseed": -1,
"subseed_strength": 0,
"seed_resize_from_h": -1,
"seed_resize_from_w": -1,
"batch_size": 1,
"n_iter": 1,
"steps": 3,
"cfg_scale": 7,
"width": 64,
"height": 64,
"restore_faces": False,
"tiling": False,
"negative_prompt": "",
"eta": 0,
"s_churn": 0,
"s_tmax": 0,
"s_tmin": 0,
"s_noise": 1,
"sampler_index": "Euler a"
}
def test_txt2img_with_restore_faces_performed(self):
self.simple_txt2img["restore_faces"] = True
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
class TestTxt2ImgCorrectness(unittest.TestCase):
pass
if __name__ == "__main__":
unittest.main()
...@@ -51,9 +51,5 @@ class TestImg2ImgWorking(unittest.TestCase): ...@@ -51,9 +51,5 @@ class TestImg2ImgWorking(unittest.TestCase):
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
class TestImg2ImgCorrectness(unittest.TestCase):
pass
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -49,26 +49,20 @@ class TestTxt2ImgWorking(unittest.TestCase): ...@@ -49,26 +49,20 @@ class TestTxt2ImgWorking(unittest.TestCase):
self.simple_txt2img["enable_hr"] = True self.simple_txt2img["enable_hr"] = True
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
def test_txt2img_with_restore_faces_performed(self): def test_txt2img_with_tiling_performed(self):
self.simple_txt2img["restore_faces"] = True
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
def test_txt2img_with_tiling_faces_performed(self):
self.simple_txt2img["tiling"] = True self.simple_txt2img["tiling"] = True
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
def test_txt2img_with_vanilla_sampler_performed(self): def test_txt2img_with_vanilla_sampler_performed(self):
self.simple_txt2img["sampler_index"] = "PLMS" self.simple_txt2img["sampler_index"] = "PLMS"
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
self.simple_txt2img["sampler_index"] = "DDIM"
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
def test_txt2img_multiple_batches_performed(self): def test_txt2img_multiple_batches_performed(self):
self.simple_txt2img["n_iter"] = 2 self.simple_txt2img["n_iter"] = 2
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
class TestTxt2ImgCorrectness(unittest.TestCase):
pass
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -18,20 +18,6 @@ class UtilsTests(unittest.TestCase): ...@@ -18,20 +18,6 @@ class UtilsTests(unittest.TestCase):
def test_options_get(self): def test_options_get(self):
self.assertEqual(requests.get(self.url_options).status_code, 200) self.assertEqual(requests.get(self.url_options).status_code, 200)
def test_options_write(self):
response = requests.get(self.url_options)
self.assertEqual(response.status_code, 200)
pre_value = response.json()["send_seed"]
self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200)
response = requests.get(self.url_options)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json()["send_seed"], not pre_value)
requests.post(self.url_options, json={"send_seed": pre_value})
def test_cmd_flags(self): def test_cmd_flags(self):
self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200) self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200)
...@@ -61,3 +47,7 @@ class UtilsTests(unittest.TestCase): ...@@ -61,3 +47,7 @@ class UtilsTests(unittest.TestCase):
def test_artists(self): def test_artists(self):
self.assertEqual(requests.get(self.url_artists).status_code, 200) self.assertEqual(requests.get(self.url_artists).status_code, 200)
if __name__ == "__main__":
unittest.main()
...@@ -3,7 +3,7 @@ import requests ...@@ -3,7 +3,7 @@ import requests
import time import time
def run_tests(): def run_tests(proc, test_dir):
timeout_threshold = 240 timeout_threshold = 240
start_time = time.time() start_time = time.time()
while time.time()-start_time < timeout_threshold: while time.time()-start_time < timeout_threshold:
...@@ -11,9 +11,14 @@ def run_tests(): ...@@ -11,9 +11,14 @@ def run_tests():
requests.head("http://localhost:7860/") requests.head("http://localhost:7860/")
break break
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
pass if proc.poll() is not None:
if time.time()-start_time < timeout_threshold: break
suite = unittest.TestLoader().discover('', pattern='*_test.py') if proc.poll() is None:
if test_dir is None:
test_dir = ""
suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir="test")
result = unittest.TextTestRunner(verbosity=2).run(suite) result = unittest.TextTestRunner(verbosity=2).run(suite)
return len(result.failures) + len(result.errors)
else: else:
print("Launch unsuccessful") print("Launch unsuccessful")
return 1
model:
base_learning_rate: 1.0e-4
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
parameterization: "v"
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False # we set this to false because this is an inference only config
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: True
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"
\ No newline at end of file
#!/bin/bash
####################################################################
# macOS defaults #
# Please modify webui-user.sh to change these instead of this file #
####################################################################
if [[ -x "$(command -v python3.10)" ]]
then
python_cmd="python3.10"
fi
export install_dir="$HOME"
export COMMANDLINE_ARGS="--skip-torch-cuda-test --no-half --use-cpu interrogate"
export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1"
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
export PYTORCH_ENABLE_MPS_FALLBACK=1
####################################################################
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#clone_dir="stable-diffusion-webui" #clone_dir="stable-diffusion-webui"
# Commandline arguments for webui.py, for example: export COMMANDLINE_ARGS="--medvram --opt-split-attention" # Commandline arguments for webui.py, for example: export COMMANDLINE_ARGS="--medvram --opt-split-attention"
export COMMANDLINE_ARGS="" #export COMMANDLINE_ARGS=""
# python3 executable # python3 executable
#python_cmd="python3" #python_cmd="python3"
...@@ -40,4 +40,7 @@ export COMMANDLINE_ARGS="" ...@@ -40,4 +40,7 @@ export COMMANDLINE_ARGS=""
#export CODEFORMER_COMMIT_HASH="" #export CODEFORMER_COMMIT_HASH=""
#export BLIP_COMMIT_HASH="" #export BLIP_COMMIT_HASH=""
# Uncomment to enable accelerated launch
#export ACCELERATE="True"
########################################### ###########################################
...@@ -28,15 +28,27 @@ goto :show_stdout_stderr ...@@ -28,15 +28,27 @@ goto :show_stdout_stderr
:activate_venv :activate_venv
set PYTHON="%~dp0%VENV_DIR%\Scripts\Python.exe" set PYTHON="%~dp0%VENV_DIR%\Scripts\Python.exe"
echo venv %PYTHON% echo venv %PYTHON%
if [%ACCELERATE%] == ["True"] goto :accelerate
goto :launch goto :launch
:skip_venv :skip_venv
:accelerate
echo "Checking for accelerate"
set ACCELERATE="%~dp0%VENV_DIR%\Scripts\accelerate.exe"
if EXIST %ACCELERATE% goto :accelerate_launch
:launch :launch
%PYTHON% launch.py %* %PYTHON% launch.py %*
pause pause
exit /b exit /b
:accelerate_launch
echo "Accelerating"
%ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py
pause
exit /b
:show_stdout_stderr :show_stdout_stderr
echo. echo.
......
import os import os
import sys
import threading import threading
import time import time
import importlib import importlib
...@@ -8,9 +9,11 @@ from fastapi import FastAPI ...@@ -8,9 +9,11 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
from modules import import_hook, errors
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
from modules.paths import script_path from modules.paths import script_path
from modules import devices, sd_samplers, upscaler, extensions, localization from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
import modules.codeformer_model as codeformer import modules.codeformer_model as codeformer
import modules.extras import modules.extras
import modules.face_restoration import modules.face_restoration
...@@ -23,7 +26,6 @@ import modules.scripts ...@@ -23,7 +26,6 @@ import modules.scripts
import modules.sd_hijack import modules.sd_hijack
import modules.sd_models import modules.sd_models
import modules.sd_vae import modules.sd_vae
import modules.shared as shared
import modules.txt2img import modules.txt2img
import modules.script_callbacks import modules.script_callbacks
...@@ -32,32 +34,11 @@ from modules import modelloader ...@@ -32,32 +34,11 @@ from modules import modelloader
from modules.shared import cmd_opts from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork import modules.hypernetworks.hypernetwork
queue_lock = threading.Lock()
server_name = "0.0.0.0" if cmd_opts.listen else cmd_opts.server_name
def wrap_queued_call(func): if cmd_opts.server_name:
def f(*args, **kwargs): server_name = cmd_opts.server_name
with queue_lock: else:
res = func(*args, **kwargs) server_name = "0.0.0.0" if cmd_opts.listen else None
return res
return f
def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs):
shared.state.begin()
with queue_lock:
res = func(*args, **kwargs)
shared.state.end()
return res
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
def initialize(): def initialize():
...@@ -74,16 +55,27 @@ def initialize(): ...@@ -74,16 +55,27 @@ def initialize():
codeformer.setup_model(cmd_opts.codeformer_models_path) codeformer.setup_model(cmd_opts.codeformer_models_path)
gfpgan.setup_model(cmd_opts.gfpgan_models_path) gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration()) shared.face_restorers.append(modules.face_restoration.FaceRestoration())
modelloader.load_upscalers()
modelloader.list_builtin_upscalers()
modules.scripts.load_scripts() modules.scripts.load_scripts()
modelloader.load_upscalers()
modules.sd_vae.refresh_vae_list() modules.sd_vae.refresh_vae_list()
try:
modules.sd_models.load_model() modules.sd_models.load_model()
except Exception as e:
errors.display(e, "loading stable diffusion model")
print("", file=sys.stderr)
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
exit(1)
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks()))
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
...@@ -107,8 +99,12 @@ def initialize(): ...@@ -107,8 +99,12 @@ def initialize():
def setup_cors(app): def setup_cors(app):
if cmd_opts.cors_allow_origins: if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*']) app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
elif cmd_opts.cors_allow_origins:
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
elif cmd_opts.cors_allow_origins_regex:
app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
def create_api(app): def create_api(app):
...@@ -146,9 +142,12 @@ def webui(): ...@@ -146,9 +142,12 @@ def webui():
initialize() initialize()
while 1: while 1:
demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) if shared.opts.clean_temp_dir_at_start:
ui_tempdir.cleanup_tmpdr()
app, local_url, share_url = demo.launch( shared.demo = modules.ui.create_ui()
app, local_url, share_url = shared.demo.queue(default_enabled=False).launch(
share=cmd_opts.share, share=cmd_opts.share,
server_name=server_name, server_name=server_name,
server_port=cmd_opts.port, server_port=cmd_opts.port,
...@@ -164,8 +163,8 @@ def webui(): ...@@ -164,8 +163,8 @@ def webui():
# gradio uses a very open CORS policy via app.user_middleware, which makes it possible for # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
# an attacker to trick the user into opening a malicious HTML page, which makes a request to the # an attacker to trick the user into opening a malicious HTML page, which makes a request to the
# running web ui and do whatever the attcker wants, including installing an extension and # running web ui and do whatever the attacker wants, including installing an extension and
# runnnig its code. We disable this here. Suggested by RyotaK. # running its code. We disable this here. Suggested by RyotaK.
app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware'] app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
setup_cors(app) setup_cors(app)
...@@ -175,24 +174,26 @@ def webui(): ...@@ -175,24 +174,26 @@ def webui():
if launch_api: if launch_api:
create_api(app) create_api(app)
modules.script_callbacks.app_started_callback(demo, app) modules.script_callbacks.app_started_callback(shared.demo, app)
modules.script_callbacks.app_started_callback(shared.demo, app)
wait_on_server(demo) wait_on_server(shared.demo)
print('Restarting UI...')
sd_samplers.set_samplers() sd_samplers.set_samplers()
print('Reloading extensions')
extensions.list_extensions() extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir) localization.list_localizations(cmd_opts.localizations_dir)
print('Reloading custom scripts') modelloader.forbid_loaded_nonbuiltin_upscalers()
modules.scripts.reload_scripts() modules.scripts.reload_scripts()
print('Reloading modules: modules.ui') modelloader.load_upscalers()
importlib.reload(modules.ui)
print('Refreshing Model List') for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
importlib.reload(module)
modules.sd_models.list_models() modules.sd_models.list_models()
print('Restarting Gradio')
if __name__ == "__main__": if __name__ == "__main__":
......
#!/bin/bash #!/usr/bin/env bash
################################################# #################################################
# Please do not make any changes to this file, # # Please do not make any changes to this file, #
# change the variables in webui-user.sh instead # # change the variables in webui-user.sh instead #
################################################# #################################################
# If run from macOS, load defaults from webui-macos-env.sh
if [[ "$OSTYPE" == "darwin"* ]]; then
if [[ -f webui-macos-env.sh ]]
then
source ./webui-macos-env.sh
fi
fi
# Read variables from webui-user.sh # Read variables from webui-user.sh
# shellcheck source=/dev/null # shellcheck source=/dev/null
if [[ -f webui-user.sh ]] if [[ -f webui-user.sh ]]
...@@ -46,6 +55,18 @@ then ...@@ -46,6 +55,18 @@ then
LAUNCH_SCRIPT="launch.py" LAUNCH_SCRIPT="launch.py"
fi fi
# this script cannot be run as root by default
can_run_as_root=0
# read any command line flags to the webui.sh script
while getopts "f" flag > /dev/null 2>&1
do
case ${flag} in
f) can_run_as_root=1;;
*) break;;
esac
done
# Disable sentry logging # Disable sentry logging
export ERROR_REPORTING=FALSE export ERROR_REPORTING=FALSE
...@@ -61,7 +82,7 @@ printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m" ...@@ -61,7 +82,7 @@ printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m"
printf "\n%s\n" "${delimiter}" printf "\n%s\n" "${delimiter}"
# Do not run as root # Do not run as root
if [[ $(id -u) -eq 0 ]] if [[ $(id -u) -eq 0 && can_run_as_root -eq 0 ]]
then then
printf "\n%s\n" "${delimiter}" printf "\n%s\n" "${delimiter}"
printf "\e[1m\e[31mERROR: This script must not be launched as root, aborting...\e[0m" printf "\e[1m\e[31mERROR: This script must not be launched as root, aborting...\e[0m"
...@@ -134,7 +155,15 @@ else ...@@ -134,7 +155,15 @@ else
exit 1 exit 1
fi fi
printf "\n%s\n" "${delimiter}" if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ]
printf "Launching launch.py..." then
printf "\n%s\n" "${delimiter}" printf "\n%s\n" "${delimiter}"
"${python_cmd}" "${LAUNCH_SCRIPT}" "$@" printf "Accelerating launch.py..."
printf "\n%s\n" "${delimiter}"
accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@"
else
printf "\n%s\n" "${delimiter}"
printf "Launching launch.py..."
printf "\n%s\n" "${delimiter}"
"${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
fi
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment