Unverified commit 263b323d authored by xucj98, committed by GitHub

Merge branch 'AUTOMATIC1111:master' into draft

parents d20dbe47 828438b4
@@ -70,7 +70,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
 - separate prompts using uppercase `AND`
 - also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2`
 - No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
-- DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args)
+- DeepDanbooru integration, creates danbooru style tags for anime prompts
 - [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args)
 - via extension: [History tab](https://github.com/yfszzx/stable-diffusion-webui-images-browser): view, direct and delete images conveniently within the UI
 - Generate forever option
......
@@ -134,14 +134,13 @@ def prepare_enviroment():
     gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
     clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
-    deepdanbooru_package = os.environ.get('DEEPDANBOORU_PACKAGE', "git+https://github.com/KichangKim/DeepDanbooru.git@d91a2963bf87c6a770d74894667e9ffa9f6de7ff")

     xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl')

     stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/CompVis/stable-diffusion.git")
-    taming_transformers_repo = os.environ.get('TAMING_REANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
+    taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
     k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
-    codeformer_repo = os.environ.get('CODEFORMET_REPO', 'https://github.com/sczhou/CodeFormer.git')
+    codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
     blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')

     stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")

@@ -158,7 +157,6 @@ def prepare_enviroment():
     sys.argv, update_check = extract_arg(sys.argv, '--update-check')
     sys.argv, run_tests = extract_arg(sys.argv, '--tests')
     xformers = '--xformers' in sys.argv
-    deepdanbooru = '--deepdanbooru' in sys.argv
     ngrok = '--ngrok' in sys.argv

     try:

@@ -193,9 +191,6 @@ def prepare_enviroment():
         elif platform.system() == "Linux":
             run_pip("install xformers", "xformers")

-    if not is_installed("deepdanbooru") and deepdanbooru:
-        run_pip(f"install {deepdanbooru_package}#egg=deepdanbooru[tensorflow] tensorflow==2.10.0 tensorflow-io==0.27.0", "deepdanbooru")
-
     if not is_installed("pyngrok") and ngrok:
         run_pip("install pyngrok", "ngrok")
......
@@ -176,9 +176,9 @@ class InterrogateResponse(BaseModel):
     caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")

 fields = {}
-for key, value in opts.data.items():
-    metadata = opts.data_labels.get(key)
+for key, metadata in opts.data_labels.items():
+    value = opts.data.get(key)

-    optType = opts.typemap.get(type(value), type(value))
+    optType = opts.typemap.get(type(metadata.default), type(value))

     if (metadata is not None):
         fields.update({key: (Optional[optType], Field(
......
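The change above iterates opts.data_labels (the full option registry) instead of opts.data (only the options the user has actually touched), so the generated API model covers every option and derives each field's type from the metadata default rather than from a possibly missing runtime value. A minimal, hypothetical sketch of that pattern with pydantic — the OptionsModel name and the toy registry below are illustrative, not part of the codebase:

```python
from typing import Optional
from pydantic import Field, create_model


class OptionInfo:
    """Toy stand-in for the webui's option metadata: only the default matters here."""
    def __init__(self, default):
        self.default = default


data_labels = {"sd_vae": OptionInfo("auto"), "sd_checkpoint_cache": OptionInfo(0)}  # full registry
data = {"sd_vae": "something.vae.pt"}  # only options the user actually changed

fields = {}
for key, metadata in data_labels.items():
    value = data.get(key)              # may be None for untouched options
    opt_type = type(metadata.default)  # type comes from the metadata, not the missing value
    fields[key] = (Optional[opt_type], Field(default=metadata.default, title=key))

OptionsModel = create_model("OptionsModel", **fields)
print(OptionsModel().dict())
```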
-import os.path
-from concurrent.futures import ProcessPoolExecutor
-import multiprocessing
-import time
-import re
-
-re_special = re.compile(r'([\\()])')
-
-
-def get_deepbooru_tags(pil_image):
-    """
-    This method is for running only one image at a time for simple use. Used to the img2img interrogate.
-    """
-    from modules import shared  # prevents circular reference
-
-    try:
-        create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts())
-        return get_tags_from_process(pil_image)
-    finally:
-        release_process()
-
-
-OPT_INCLUDE_RANKS = "include_ranks"
-def create_deepbooru_opts():
-    from modules import shared
-
-    return {
-        "use_spaces": shared.opts.deepbooru_use_spaces,
-        "use_escape": shared.opts.deepbooru_escape,
-        "alpha_sort": shared.opts.deepbooru_sort_alpha,
-        OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks,
-    }
-
-
-def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts):
-    model, tags = get_deepbooru_tags_model()
-    while True:  # while process is running, keep monitoring queue for new image
-        pil_image = queue.get()
-        if pil_image == "QUIT":
-            break
-        else:
-            deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts)
-
-
-def create_deepbooru_process(threshold, deepbooru_opts):
-    """
-    Creates deepbooru process. A queue is created to send images into the process. This enables multiple images
-    to be processed in a row without reloading the model or creating a new process. To return the data, a shared
-    dictionary is created to hold the tags created. To wait for tags to be returned, a value of -1 is assigned
-    to the dictionary and the method adding the image to the queue should wait for this value to be updated with
-    the tags.
-    """
-    from modules import shared  # prevents circular reference
-    context = multiprocessing.get_context("spawn")
-    shared.deepbooru_process_manager = context.Manager()
-    shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
-    shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
-    shared.deepbooru_process_return["value"] = -1
-    shared.deepbooru_process = context.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
-    shared.deepbooru_process.start()
-
-
-def get_tags_from_process(image):
-    from modules import shared
-
-    shared.deepbooru_process_return["value"] = -1
-    shared.deepbooru_process_queue.put(image)
-    while shared.deepbooru_process_return["value"] == -1:
-        time.sleep(0.2)
-    caption = shared.deepbooru_process_return["value"]
-    shared.deepbooru_process_return["value"] = -1
-
-    return caption
-
-
-def release_process():
-    """
-    Stops the deepbooru process to return used memory
-    """
-    from modules import shared  # prevents circular reference
-    shared.deepbooru_process_queue.put("QUIT")
-    shared.deepbooru_process.join()
-    shared.deepbooru_process_queue = None
-    shared.deepbooru_process = None
-    shared.deepbooru_process_return = None
-    shared.deepbooru_process_manager = None
-
-
-def get_deepbooru_tags_model():
-    import deepdanbooru as dd
-    import tensorflow as tf
-    import numpy as np
-
-    this_folder = os.path.dirname(__file__)
-    model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
-    if not os.path.exists(os.path.join(model_path, 'project.json')):
-        # there is no point importing these every time
-        import zipfile
-        from basicsr.utils.download_util import load_file_from_url
-        load_file_from_url(
-            r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
-            model_path)
-        with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
-            zip_ref.extractall(model_path)
-        os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))
-
-    tags = dd.project.load_tags_from_project(model_path)
-    model = dd.project.load_model_from_project(
-        model_path, compile_model=False
-    )
-    return model, tags
-
-
-def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts):
-    import deepdanbooru as dd
-    import tensorflow as tf
-    import numpy as np
-
-    alpha_sort = deepbooru_opts['alpha_sort']
-    use_spaces = deepbooru_opts['use_spaces']
-    use_escape = deepbooru_opts['use_escape']
-    include_ranks = deepbooru_opts['include_ranks']
-
-    width = model.input_shape[2]
-    height = model.input_shape[1]
-    image = np.array(pil_image)
-    image = tf.image.resize(
-        image,
-        size=(height, width),
-        method=tf.image.ResizeMethod.AREA,
-        preserve_aspect_ratio=True,
-    )
-    image = image.numpy()  # EagerTensor to np.array
-    image = dd.image.transform_and_pad_image(image, width, height)
-    image = image / 255.0
-    image_shape = image.shape
-    image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))
-
-    y = model.predict(image)[0]
-
-    result_dict = {}
-    for i, tag in enumerate(tags):
-        result_dict[tag] = y[i]
-
-    unsorted_tags_in_theshold = []
-    result_tags_print = []
-    for tag in tags:
-        if result_dict[tag] >= threshold:
-            if tag.startswith("rating:"):
-                continue
-            unsorted_tags_in_theshold.append((result_dict[tag], tag))
-            result_tags_print.append(f'{result_dict[tag]} {tag}')
-
-    # sort tags
-    result_tags_out = []
-    sort_ndx = 0
-    if alpha_sort:
-        sort_ndx = 1
-
-    # sort by reverse by likelihood and normal for alpha, and format tag text as requested
-    unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
-    for weight, tag in unsorted_tags_in_theshold:
-        tag_outformat = tag
-        if use_spaces:
-            tag_outformat = tag_outformat.replace('_', ' ')
-        if use_escape:
-            tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
-        if include_ranks:
-            tag_outformat = f"({tag_outformat}:{weight:.3f})"
-
-        result_tags_out.append(tag_outformat)
-
-    print('\n'.join(sorted(result_tags_print, reverse=True)))
-
-    return ', '.join(result_tags_out)
+import os
+import re
+
+import torch
+from PIL import Image
+import numpy as np
+
+from modules import modelloader, paths, deepbooru_model, devices, images, shared
+
+re_special = re.compile(r'([\\()])')
+
+
+class DeepDanbooru:
+    def __init__(self):
+        self.model = None
+
+    def load(self):
+        if self.model is not None:
+            return
+
+        files = modelloader.load_models(
+            model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
+            model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
+            ext_filter=".pt",
+            download_name='model-resnet_custom_v3.pt',
+        )
+
+        self.model = deepbooru_model.DeepDanbooruModel()
+        self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
+
+        self.model.eval()
+        self.model.to(devices.cpu, devices.dtype)
+
+    def start(self):
+        self.load()
+        self.model.to(devices.device)
+
+    def stop(self):
+        if not shared.opts.interrogate_keep_models_in_memory:
+            self.model.to(devices.cpu)
+            devices.torch_gc()
+
+    def tag(self, pil_image):
+        self.start()
+        res = self.tag_multi(pil_image)
+        self.stop()
+
+        return res
+
+    def tag_multi(self, pil_image, force_disable_ranks=False):
+        threshold = shared.opts.interrogate_deepbooru_score_threshold
+        use_spaces = shared.opts.deepbooru_use_spaces
+        use_escape = shared.opts.deepbooru_escape
+        alpha_sort = shared.opts.deepbooru_sort_alpha
+        include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks
+
+        pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
+        a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
+
+        with torch.no_grad(), devices.autocast():
+            x = torch.from_numpy(a).cuda()
+            y = self.model(x)[0].detach().cpu().numpy()
+
+        probability_dict = {}
+
+        for tag, probability in zip(self.model.tags, y):
+            if probability < threshold:
+                continue
+
+            if tag.startswith("rating:"):
+                continue
+
+            probability_dict[tag] = probability
+
+        if alpha_sort:
+            tags = sorted(probability_dict)
+        else:
+            tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
+
+        res = []
+
+        for tag in tags:
+            probability = probability_dict[tag]
+            tag_outformat = tag
+            if use_spaces:
+                tag_outformat = tag_outformat.replace('_', ' ')
+            if use_escape:
+                tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
+            if include_ranks:
+                tag_outformat = f"({tag_outformat}:{probability:.3f})"
+
+            res.append(tag_outformat)
+
+        return ", ".join(res)
+
+
+model = DeepDanbooru()
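With the rewrite above, DeepDanbooru tagging no longer needs a TensorFlow install or a spawned worker process; the module-level model object is the whole interface. A hypothetical interactive use (the image paths are illustrative; inside the webui this is what img2img interrogate and the preprocessing tab call):

```python
from PIL import Image

from modules import deepbooru

# Single image: tag() wraps start()/tag_multi()/stop(), so the model is loaded,
# moved to the configured device, used once, and released again unless
# interrogate_keep_models_in_memory is enabled.
print(deepbooru.model.tag(Image.open("example.png")))

# Several images: keep the model resident between calls.
deepbooru.model.start()
try:
    for path in ("a.png", "b.png"):
        print(path, deepbooru.model.tag_multi(Image.open(path)))
finally:
    deepbooru.model.stop()
```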
@@ -65,9 +65,12 @@ class Extension:
         self.can_update = False
         self.status = "latest"

-    def pull(self):
+    def fetch_and_reset_hard(self):
         repo = git.Repo(self.path)
-        repo.remotes.origin.pull()
+        # Fix: `error: Your local changes to the following files would be overwritten by merge`,
+        # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
+        repo.git.fetch('--all')
+        repo.git.reset('--hard', 'origin')


 def list_extensions():
......
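For reference, a standalone sketch of the same GitPython sequence (the extension directory path is hypothetical): fetching and then hard-resetting sidesteps the merge that a plain pull would attempt, so local permission-bit changes under WSL2/Docker cannot block the update.

```python
import git

# Equivalent of Extension.fetch_and_reset_hard() for one extension checkout.
repo = git.Repo("extensions/stable-diffusion-webui-images-browser")  # hypothetical path
repo.git.fetch('--all')             # update all remote-tracking branches
repo.git.reset('--hard', 'origin')  # discard local changes, match origin exactly
```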
@@ -73,6 +73,7 @@ def integrate_settings_paste_fields(component_dict):
         'sd_hypernetwork': 'Hypernet',
         'sd_hypernetwork_strength': 'Hypernet strength',
         'CLIP_stop_at_last_layers': 'Clip skip',
+        'inpainting_mask_weight': 'Conditional mask weight',
         'sd_model_checkpoint': 'Model hash',
     }
     settings_paste_fields = [
......
@@ -12,7 +12,7 @@ import torch
 import tqdm
 from einops import rearrange, repeat
 from ldm.util import default
-from modules import devices, processing, sd_models, shared
+from modules import devices, processing, sd_models, shared, sd_samplers
 from modules.textual_inversion import textual_inversion
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
 from torch import einsum

@@ -535,7 +535,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
                 p.prompt = preview_prompt
                 p.negative_prompt = preview_negative_prompt
                 p.steps = preview_steps
-                p.sampler_index = preview_sampler_index
+                p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
                 p.cfg_scale = preview_cfg_scale
                 p.seed = preview_seed
                 p.width = preview_width
......
@@ -303,7 +303,7 @@ class FilenameGenerator:
         'width': lambda self: self.image.width,
         'height': lambda self: self.image.height,
         'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
-        'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False),
+        'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
         'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
         'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
         'datetime': lambda self, *args: self.datetime(*args),  # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
......
@@ -6,7 +6,7 @@ import traceback
 import numpy as np
 from PIL import Image, ImageOps, ImageChops

-from modules import devices
+from modules import devices, sd_samplers
 from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, state
 import modules.shared as shared

@@ -99,7 +99,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
         seed_resize_from_h=seed_resize_from_h,
         seed_resize_from_w=seed_resize_from_w,
         seed_enable_extras=seed_enable_extras,
-        sampler_index=sampler_index,
+        sampler_index=sd_samplers.samplers_for_img2img[sampler_index].name,
         batch_size=batch_size,
         n_iter=n_iter,
         steps=steps,
......
@@ -61,6 +61,8 @@ callback_map = dict(
     callbacks_before_image_saved=[],
     callbacks_image_saved=[],
     callbacks_cfg_denoiser=[],
+    callbacks_before_component=[],
+    callbacks_after_component=[],
 )

@@ -137,6 +139,22 @@ def cfg_denoiser_callback(params: CFGDenoiserParams):
             report_exception(c, 'cfg_denoiser_callback')


+def before_component_callback(component, **kwargs):
+    for c in callback_map['callbacks_before_component']:
+        try:
+            c.callback(component, **kwargs)
+        except Exception:
+            report_exception(c, 'before_component_callback')
+
+
+def after_component_callback(component, **kwargs):
+    for c in callback_map['callbacks_after_component']:
+        try:
+            c.callback(component, **kwargs)
+        except Exception:
+            report_exception(c, 'after_component_callback')
+
+
 def add_callback(callbacks, fun):
     stack = [x for x in inspect.stack() if x.filename != __file__]
     filename = stack[0].filename if len(stack) > 0 else 'unknown file'

@@ -220,3 +238,20 @@ def on_cfg_denoiser(callback):
     - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
     """
     add_callback(callback_map['callbacks_cfg_denoiser'], callback)
+
+
+def on_before_component(callback):
+    """register a function to be called before a component is created.
+    The callback is called with arguments:
+        - component - gradio component that is about to be created.
+        - **kwargs - args to gradio.components.IOComponent.__init__ function
+
+    Use elem_id/label fields of kwargs to figure out which component it is.
+    This can be useful to inject your own components somewhere in the middle of vanilla UI.
+    """
+    add_callback(callback_map['callbacks_before_component'], callback)
+
+
+def on_after_component(callback):
+    """register a function to be called after a component is created. See on_before_component for more."""
+    add_callback(callback_map['callbacks_after_component'], callback)
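Taken together, the two additions above let an extension watch every gradio component as the UI is built. A hypothetical registration (the elem_id value is an assumption about the UI, not something this diff guarantees):

```python
from modules import script_callbacks


def report_component(component, **kwargs):
    # elem_id/label arrive through **kwargs, as the docstring above describes.
    elem_id = kwargs.get("elem_id")
    if elem_id == "txt2img_prompt":  # hypothetical elem_id of the txt2img prompt box
        print("about to create:", type(component).__name__, elem_id)


script_callbacks.on_before_component(report_component)
```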
@@ -17,6 +17,9 @@ class Script:
     args_to = None
     alwayson = False

+    is_txt2img = False
+    is_img2img = False
+
     """A gr.Group component that has all script's UI inside it"""
     group = None

@@ -93,6 +96,23 @@ class Script:
         pass

+    def before_component(self, component, **kwargs):
+        """
+        Called before a component is created.
+        Use elem_id/label fields of kwargs to figure out which component it is.
+        This can be useful to inject your own components somewhere in the middle of vanilla UI.
+        You can return created components in the ui() function to add them to the list of arguments for your processing functions
+        """
+
+        pass
+
+    def after_component(self, component, **kwargs):
+        """
+        Called after a component is created. Same as above.
+        """
+
+        pass
+
     def describe(self):
         """unused"""
         return ""

@@ -195,12 +215,18 @@ class ScriptRunner:
         self.titles = []
         self.infotext_fields = []

-    def setup_ui(self, is_img2img):
+    def initialize_scripts(self, is_img2img):
+        self.scripts.clear()
+        self.alwayson_scripts.clear()
+        self.selectable_scripts.clear()
+
         for script_class, path, basedir in scripts_data:
             script = script_class()
             script.filename = path
+            script.is_txt2img = not is_img2img
+            script.is_img2img = is_img2img

-            visibility = script.show(is_img2img)
+            visibility = script.show(script.is_img2img)

             if visibility == AlwaysVisible:
                 self.scripts.append(script)

@@ -211,6 +237,7 @@ class ScriptRunner:
                 self.scripts.append(script)
                 self.selectable_scripts.append(script)

+    def setup_ui(self):
         self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]

         inputs = [None]

@@ -220,7 +247,7 @@ class ScriptRunner:
             script.args_from = len(inputs)
             script.args_to = len(inputs)

-            controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
+            controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)

             if controls is None:
                 return

@@ -320,6 +347,22 @@ class ScriptRunner:
                 print(f"Error running postprocess: {script.filename}", file=sys.stderr)
                 print(traceback.format_exc(), file=sys.stderr)

+    def before_component(self, component, **kwargs):
+        for script in self.scripts:
+            try:
+                script.before_component(component, **kwargs)
+            except Exception:
+                print(f"Error running before_component: {script.filename}", file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
+
+    def after_component(self, component, **kwargs):
+        for script in self.scripts:
+            try:
+                script.after_component(component, **kwargs)
+            except Exception:
+                print(f"Error running after_component: {script.filename}", file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
+
     def reload_sources(self, cache):
         for si, script in list(enumerate(self.scripts)):
             args_from = script.args_from

@@ -341,6 +384,7 @@ class ScriptRunner:

 scripts_txt2img = ScriptRunner()
 scripts_img2img = ScriptRunner()
+scripts_current: ScriptRunner = None


 def reload_script_body_only():

@@ -357,3 +401,22 @@ def reload_scripts():
     scripts_txt2img = ScriptRunner()
     scripts_img2img = ScriptRunner()
+
+
+def IOComponent_init(self, *args, **kwargs):
+    if scripts_current is not None:
+        scripts_current.before_component(self, **kwargs)
+
+    script_callbacks.before_component_callback(self, **kwargs)
+
+    res = original_IOComponent_init(self, *args, **kwargs)
+
+    script_callbacks.after_component_callback(self, **kwargs)
+
+    if scripts_current is not None:
+        scripts_current.after_component(self, **kwargs)
+
+    return res
+
+
+original_IOComponent_init = gr.components.IOComponent.__init__
+gr.components.IOComponent.__init__ = IOComponent_init
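Because IOComponent.__init__ is now wrapped, a Script subclass can implement before_component/after_component and will be called for every component built while its tab is the current one. A hypothetical always-on script that grabs a reference to one component (the elem_id is an assumption; the class is illustrative only):

```python
from modules import scripts


class ComponentGrabber(scripts.Script):
    gallery = None

    def title(self):
        return "Component grabber (example)"

    def show(self, is_img2img):
        return scripts.AlwaysVisible  # run on both tabs, no dropdown entry

    def ui(self, is_img2img):
        return []

    def after_component(self, component, **kwargs):
        # kwargs are the arguments passed to gradio.components.IOComponent.__init__
        if kwargs.get("elem_id") == "txt2img_gallery":  # hypothetical elem_id
            self.gallery = component
```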
@@ -96,8 +96,8 @@ class StableDiffusionModelHijack:
         if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
             model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped

+        self.apply_circular(False)
         self.layers = None
-        self.circular_enabled = False
         self.clip = None

     def apply_circular(self, enable):
......
@@ -165,16 +165,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
     cache_enabled = shared.opts.sd_checkpoint_cache > 0

-    if cache_enabled:
-        sd_vae.restore_base_vae(model)
-
-    vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
-
     if cache_enabled and checkpoint_info in checkpoints_loaded:
         # use checkpoint cache
-        vae_name = sd_vae.get_filename(vae_file) if vae_file else None
-        vae_message = f" with {vae_name} VAE" if vae_name else ""
-        print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
+        print(f"Loading weights [{sd_model_hash}] from cache")
         model.load_state_dict(checkpoints_loaded[checkpoint_info])
     else:
         # load from file

@@ -220,6 +213,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
     model.sd_model_checkpoint = checkpoint_file
     model.sd_checkpoint_info = checkpoint_info

+    vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
     sd_vae.load_vae(model, vae_file)
......
@@ -46,13 +46,20 @@ all_samplers = [
     SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
     SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
 ]
+all_samplers_map = {x.name: x for x in all_samplers}

 samplers = []
 samplers_for_img2img = []

-def create_sampler_with_index(list_of_configs, index, model):
-    config = list_of_configs[index]
+def create_sampler(name, model):
+    if name is not None:
+        config = all_samplers_map.get(name, None)
+    else:
+        config = all_samplers[0]
+
+    assert config is not None, f'bad sampler name: {name}'
+
     sampler = config.constructor(model)
     sampler.config = config
......
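Call sites therefore switch from passing a list index to passing the sampler's name. A hypothetical call, assuming a loaded checkpoint in shared.sd_model and the 'DDIM' entry registered in all_samplers above:

```python
from modules import sd_samplers, shared

# Old: sd_samplers.create_sampler_with_index(sd_samplers.samplers, index, shared.sd_model)
# New: look the configuration up by name; None falls back to the first registered sampler.
sampler = sd_samplers.create_sampler("DDIM", shared.sd_model)
```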
@@ -83,47 +83,54 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path):
     return vae_list


-def resolve_vae(checkpoint_file, vae_file="auto"):
+def get_vae_from_settings(vae_file="auto"):
+    # else, we load from settings, if not set to be default
+    if vae_file == "auto" and shared.opts.sd_vae is not None:
+        # if saved VAE settings isn't recognized, fallback to auto
+        vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
+        # if VAE selected but not found, fallback to auto
+        if vae_file not in default_vae_values and not os.path.isfile(vae_file):
+            vae_file = "auto"
+            print(f"Selected VAE doesn't exist: {vae_file}")
+    return vae_file
+
+
+def resolve_vae(checkpoint_file=None, vae_file="auto"):
     global first_load, vae_dict, vae_list

     # if vae_file argument is provided, it takes priority, but not saved
     if vae_file and vae_file not in default_vae_list:
         if not os.path.isfile(vae_file):
+            print(f"VAE provided as function argument doesn't exist: {vae_file}")
             vae_file = "auto"
-            print("VAE provided as function argument doesn't exist")
     # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
     if first_load and shared.cmd_opts.vae_path is not None:
         if os.path.isfile(shared.cmd_opts.vae_path):
             vae_file = shared.cmd_opts.vae_path
             shared.opts.data['sd_vae'] = get_filename(vae_file)
         else:
-            print("VAE provided as command line argument doesn't exist")
-    # else, we load from settings
-    if vae_file == "auto" and shared.opts.sd_vae is not None:
-        # if saved VAE settings isn't recognized, fallback to auto
-        vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
-        # if VAE selected but not found, fallback to auto
-        if vae_file not in default_vae_values and not os.path.isfile(vae_file):
-            vae_file = "auto"
-            print("Selected VAE doesn't exist")
+            print(f"VAE provided as command line argument doesn't exist: {vae_file}")
+    # fallback to selector in settings, if vae selector not set to act as default fallback
+    if not shared.opts.sd_vae_as_default:
+        vae_file = get_vae_from_settings(vae_file)
     # vae-path cmd arg takes priority for auto
     if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
         if os.path.isfile(shared.cmd_opts.vae_path):
             vae_file = shared.cmd_opts.vae_path
-            print("Using VAE provided as command line argument")
+            print(f"Using VAE provided as command line argument: {vae_file}")
     # if still not found, try look for ".vae.pt" beside model
     model_path = os.path.splitext(checkpoint_file)[0]
     if vae_file == "auto":
         vae_file_try = model_path + ".vae.pt"
         if os.path.isfile(vae_file_try):
             vae_file = vae_file_try
-            print("Using VAE found beside selected model")
+            print(f"Using VAE found similar to selected model: {vae_file}")
     # if still not found, try look for ".vae.ckpt" beside model
     if vae_file == "auto":
         vae_file_try = model_path + ".vae.ckpt"
         if os.path.isfile(vae_file_try):
             vae_file = vae_file_try
-            print("Using VAE found beside selected model")
+            print(f"Using VAE found similar to selected model: {vae_file}")
     # No more fallbacks for auto
     if vae_file == "auto":
         vae_file = None

@@ -139,6 +146,7 @@ def load_vae(model, vae_file=None):
     # save_settings = False

     if vae_file:
+        assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
         print(f"Loading VAE weights from: {vae_file}")
         vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
         vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
......
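The net effect is a defined priority order when picking a VAE. A hypothetical lookup (the checkpoint path is illustrative):

```python
from modules import sd_vae

# Priority, per resolve_vae() above: explicit vae_file argument, then --vae-path on the
# first load, then the "SD VAE" setting (skipped when sd_vae_as_default is enabled),
# then a .vae.pt / .vae.ckpt file next to the checkpoint, otherwise None.
vae_file = sd_vae.resolve_vae("models/Stable-diffusion/example.ckpt")
print(vae_file)
```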
@@ -55,7 +55,7 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with
 parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
 parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
 parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
-parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
+parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
 parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
 parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
 parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")

@@ -81,6 +81,7 @@ parser.add_argument("--enable-console-prompts", action='store_true', help="print
 parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
 parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
 parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
+parser.add_argument("--api-auth", type=str, help='Set authentication for api like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
 parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
 parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
 parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)

@@ -106,7 +107,7 @@ restricted_opts = {
     "outdir_save",
 }

-cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen) and not cmd_opts.enable_insecure_extension_access
+cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access

 devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
     (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])

@@ -334,7 +335,8 @@ options_templates.update(options_section(('training', "Training"), {
 options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
     "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
-    "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list),
+    "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
+    "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
     "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
     "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
     "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),

@@ -436,6 +438,23 @@ class Options:
         return super(Options, self).__getattribute__(item)

+    def set(self, key, value):
+        """sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
+
+        oldval = self.data.get(key, None)
+        if oldval == value:
+            return False
+
+        try:
+            setattr(self, key, value)
+        except RuntimeError:
+            return False
+
+        if self.data_labels[key].onchange is not None:
+            self.data_labels[key].onchange()
+
+        return True
+
     def save(self, filename):
         assert not cmd_opts.freeze_settings, "saving settings is disabled"
......
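The UI code later in this commit switches to this helper; outside the UI it can be used the same way. A hypothetical call, assuming the sd_vae option registered above:

```python
from modules import shared

# set() returns True only when the stored value actually changed, and it runs the
# option's onchange callback for you; a RuntimeError from setattr (e.g. frozen
# settings) is swallowed and reported as False.
if shared.opts.set("sd_vae", "auto"):
    shared.opts.save(shared.config_filename)
```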
@@ -65,17 +65,6 @@ class StyleDatabase:
     def apply_negative_styles_to_prompt(self, prompt, styles):
         return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])

-    def apply_styles(self, p: StableDiffusionProcessing) -> None:
-        if isinstance(p.prompt, list):
-            p.prompt = [self.apply_styles_to_prompt(prompt, p.styles) for prompt in p.prompt]
-        else:
-            p.prompt = self.apply_styles_to_prompt(p.prompt, p.styles)
-
-        if isinstance(p.negative_prompt, list):
-            p.negative_prompt = [self.apply_negative_styles_to_prompt(prompt, p.styles) for prompt in p.negative_prompt]
-        else:
-            p.negative_prompt = self.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
-
     def save_styles(self, path: str) -> None:
         # Write to temporary file first, so we don't nuke the file if something goes wrong
         fd, temp_path = tempfile.mkstemp(".csv")
......
@@ -6,12 +6,10 @@ import sys
 import tqdm
 import time

-from modules import shared, images
+from modules import shared, images, deepbooru
 from modules.paths import models_path
 from modules.shared import opts, cmd_opts
 from modules.textual_inversion import autocrop

-if cmd_opts.deepdanbooru:
-    import modules.deepbooru as deepbooru
-

 def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):

@@ -20,9 +18,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
         shared.interrogator.load()

     if process_caption_deepbooru:
-        db_opts = deepbooru.create_deepbooru_opts()
-        db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
-        deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
+        deepbooru.model.start()

     preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug)

@@ -32,7 +28,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
         shared.interrogator.send_blip_to_ram()

     if process_caption_deepbooru:
-        deepbooru.release_process()
+        deepbooru.model.stop()


 def listfiles(dirname):

@@ -58,7 +54,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti
     if params.process_caption_deepbooru:
         if len(caption) > 0:
             caption += ", "
-        caption += deepbooru.get_tags_from_process(image)
+        caption += deepbooru.model.tag_multi(image)

     filename_part = params.src
     filename_part = os.path.splitext(filename_part)[0]
......
@@ -10,7 +10,7 @@ import csv

 from PIL import Image, PngImagePlugin

-from modules import shared, devices, sd_hijack, processing, sd_models, images
+from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
 import modules.textual_inversion.dataset
 from modules.textual_inversion.learn_schedule import LearnRateScheduler

@@ -345,7 +345,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
                 p.prompt = preview_prompt
                 p.negative_prompt = preview_negative_prompt
                 p.steps = preview_steps
-                p.sampler_index = preview_sampler_index
+                p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
                 p.cfg_scale = preview_cfg_scale
                 p.seed = preview_seed
                 p.width = preview_width
......
@@ -18,7 +18,7 @@ def create_embedding(name, initialization_text, nvpt, overwrite_old):

 def preprocess(*args):
     modules.textual_inversion.preprocess.preprocess(*args)
-    return "Preprocessing finished.", ""
+    return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", ""


 def train_embedding(*args):
......
 import modules.scripts
+from modules import sd_samplers
 from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
     StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, cmd_opts

@@ -21,7 +22,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
         seed_resize_from_h=seed_resize_from_h,
         seed_resize_from_w=seed_resize_from_w,
         seed_enable_extras=seed_enable_extras,
-        sampler_index=sampler_index,
+        sampler_name=sd_samplers.samplers[sampler_index].name,
         batch_size=batch_size,
         n_iter=n_iter,
         steps=steps,
......
...@@ -19,14 +19,11 @@ import numpy as np ...@@ -19,14 +19,11 @@ import numpy as np
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
from modules.paths import script_path from modules.paths import script_path
from modules.shared import opts, cmd_opts, restricted_opts from modules.shared import opts, cmd_opts, restricted_opts
if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
import modules.codeformer_model import modules.codeformer_model
import modules.generation_parameters_copypaste as parameters_copypaste import modules.generation_parameters_copypaste as parameters_copypaste
import modules.gfpgan_model import modules.gfpgan_model
...@@ -69,8 +66,11 @@ sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None ...@@ -69,8 +66,11 @@ sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
css_hide_progressbar = """ css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; } .wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." } .wrap .m-12::before { content:"Loading..." }
.wrap .z-20 svg { display:none!important; }
.wrap .z-20::before { content:"Loading..." }
.progress-bar { display:none!important; } .progress-bar { display:none!important; }
.meta-text { display:none!important; } .meta-text { display:none!important; }
.meta-text-center { display:none!important; }
""" """
# Using constants for these since the variation selector isn't visible. # Using constants for these since the variation selector isn't visible.
...@@ -142,7 +142,7 @@ def save_files(js_data, images, do_make_zip, index): ...@@ -142,7 +142,7 @@ def save_files(js_data, images, do_make_zip, index):
filenames.append(os.path.basename(txt_fullfn)) filenames.append(os.path.basename(txt_fullfn))
fullfns.append(txt_fullfn) fullfns.append(txt_fullfn)
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
# Make Zip # Make Zip
if do_make_zip: if do_make_zip:
...@@ -349,7 +349,7 @@ def interrogate(image): ...@@ -349,7 +349,7 @@ def interrogate(image):
def interrogate_deepbooru(image): def interrogate_deepbooru(image):
prompt = get_deepbooru_tags(image) prompt = deepbooru.model.tag(image)
return gr_show(True) if prompt is None else prompt return gr_show(True) if prompt is None else prompt
...@@ -692,6 +692,9 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -692,6 +692,9 @@ def create_ui(wrap_gradio_gpu_call):
parameters_copypaste.reset() parameters_copypaste.reset()
modules.scripts.scripts_current = modules.scripts.scripts_txt2img
+ modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
@@ -734,7 +737,7 @@ def create_ui(wrap_gradio_gpu_call):
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
with gr.Group():
- custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
+ custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples)
parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
@@ -843,6 +846,9 @@ def create_ui(wrap_gradio_gpu_call):
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
+ modules.scripts.scripts_current = modules.scripts.scripts_img2img
+ modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True)
@@ -913,7 +919,7 @@ def create_ui(wrap_gradio_gpu_call):
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
with gr.Group():
- custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
+ custom_inputs = modules.scripts.scripts_img2img.setup_ui()
img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples)
parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
@@ -1062,6 +1068,8 @@ def create_ui(wrap_gradio_gpu_call):
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
+ modules.scripts.scripts_current = None
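Note on the script-runner refactor visible above: the UI now points a module-level `scripts_current` at the active ScriptRunner, calls `initialize_scripts(is_img2img=...)` before building each tab, drops the `is_img2img` argument from `setup_ui()`, and resets the pointer to None once both tabs are built. The sketch below shows one way the runner side could support this flow; only the names used in the diff are taken from it, the class internals are assumptions since that file is not part of this hunk.

```python
# Hedged sketch of the per-tab script-runner flow implied by the UI changes above.

class ScriptRunner:
    def __init__(self):
        self.scripts = []
        self.is_img2img = False

    def initialize_scripts(self, is_img2img):
        # remember which tab this runner belongs to instead of passing it around later
        self.is_img2img = is_img2img
        self.scripts = []  # a real implementation would instantiate registered scripts here

    def setup_ui(self):
        # no is_img2img argument anymore: the runner already knows its tab
        return [script.ui(self.is_img2img) for script in self.scripts]

scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
scripts_current = None  # set while a tab's UI is being built, reset to None afterwards
```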
with gr.Blocks(analytics_enabled=False) as extras_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
@@ -1249,6 +1257,8 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="")
with gr.Column():
+ with gr.Row():
+ interrupt_preprocessing = gr.Button("Interrupt")
run_preprocess = gr.Button(value="Preprocess", variant='primary')
process_split.change(
@@ -1422,6 +1432,12 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[],
)
+ interrupt_preprocessing.click(
+ fn=lambda: shared.state.interrupt(),
+ inputs=[],
+ outputs=[],
+ )
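The new "Interrupt" button for preprocessing is wired the same way as the other interrupt controls: a Gradio click handler with no inputs or outputs that only flips the shared interrupt flag. A minimal self-contained sketch (the `State` class here is a stand-in for `shared.state`, which is not shown in this diff):

```python
import gradio as gr

class State:
    """Stand-in for shared.state; only the interrupt flag matters here."""
    def __init__(self):
        self.interrupted = False

    def interrupt(self):
        self.interrupted = True

state = State()

with gr.Blocks() as demo:
    interrupt_btn = gr.Button("Interrupt")
    # no inputs and no outputs: the click only has the side effect of setting the flag,
    # which a long-running preprocessing loop is expected to poll and honour
    interrupt_btn.click(fn=lambda: state.interrupt(), inputs=[], outputs=[])
```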
def create_setting_component(key, is_quicksettings=False):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
@@ -1473,16 +1489,9 @@ def create_ui(wrap_gradio_gpu_call):
if comp == dummy_component:
continue
- oldval = opts.data.get(key, None)
- try:
- setattr(opts, key, value)
- except RuntimeError:
- continue
- if oldval != value:
- if opts.data_labels[key].onchange is not None:
- opts.data_labels[key].onchange()
+ if opts.set(key, value):
changed.append(key)
try:
opts.save(shared.config_filename)
except RuntimeError:
@@ -1493,15 +1502,8 @@ def create_ui(wrap_gradio_gpu_call):
if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson()
- oldval = opts.data.get(key, None)
- try:
- setattr(opts, key, value)
- except Exception:
- return gr.update(value=oldval), opts.dumpjson()
- if oldval != value:
- if opts.data_labels[key].onchange is not None:
- opts.data_labels[key].onchange()
+ if not opts.set(key, value):
+ return gr.update(value=getattr(opts, key)), opts.dumpjson()
opts.save(shared.config_filename)
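Both settings paths now delegate to a single `opts.set(key, value)` helper that reports whether the value was actually applied; on failure the UI restores the current value via `getattr(opts, key)`. The helper itself lives outside this hunk, so the following is only a sketch of what it plausibly does, reassembled from the inline logic it replaces:

```python
# Hedged sketch of opts.set(), inferred from the removed inline code above;
# the real method may differ in details.
def set(self, key, value):
    """Apply a settings value; return True only if it changed and was accepted."""
    oldval = self.data.get(key, None)
    if oldval == value:
        return False

    try:
        setattr(self, key, value)  # may raise for read-only/restricted options
    except RuntimeError:
        return False

    if self.data_labels[key].onchange is not None:
        self.data_labels[key].onchange()

    return True
```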
...
@@ -36,9 +36,9 @@ def apply_and_restart(disable_list, update_list):
continue
try:
- ext.pull()
+ ext.fetch_and_reset_hard()
except Exception:
- print(f"Error pulling updates for {ext.name}:", file=sys.stderr)
+ print(f"Error getting updates for {ext.name}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
shared.opts.disabled_extensions = disabled
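The extensions updater switches from `pull()` to a `fetch_and_reset_hard()` method, which avoids merge conflicts when an extension's local checkout has diverged. The method body is not part of this hunk; a plausible GitPython-based sketch:

```python
import git  # GitPython

def fetch_and_reset_hard(extension_path):
    """Hedged sketch: update an extension by fetching and hard-resetting,
    discarding local changes, instead of merging via `git pull`."""
    repo = git.Repo(extension_path)
    repo.git.fetch("--all")
    repo.git.reset("--hard", "origin/HEAD")
```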
...
transformers==4.19.2
diffusers==0.3.0
+ accelerate==0.12.0
basicsr==1.4.2
gfpgan==1.3.8
gradio==3.9
...
@@ -157,7 +157,7 @@ class Script(scripts.Script):
def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):
# Override
if override_sampler:
- p.sampler_index = [sampler.name for sampler in sd_samplers.samplers].index("Euler")
+ p.sampler_name = "Euler"
if override_prompt:
p.prompt = original_prompt
p.negative_prompt = original_negative_prompt
@@ -191,7 +191,7 @@ class Script(scripts.Script):
combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
- sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, p.sampler_index, p.sd_model)
+ sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
sigmas = sampler.model_wrap.get_sigmas(p.steps)
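This script is part of a wider move from integer `sampler_index` to string `sampler_name`: the processing object stores the sampler's name directly and `sd_samplers.create_sampler(name, model)` resolves it. The internals of `create_sampler` are not in this diff; a rough sketch of name-based resolution, assuming the `all_samplers` list seen in the next file and a `constructor` field that is purely illustrative:

```python
# Hedged sketch of name-based sampler creation; all_samplers entries are assumed
# to carry a name and a constructor, neither of which is spelled out in this diff.
def create_sampler(name, model):
    config = next((s for s in all_samplers if s.name == name), None)
    if config is None:
        raise ValueError(f"Unknown sampler: {name}")
    return config.constructor(model)
```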
...
@@ -10,9 +10,9 @@ import numpy as np
import modules.scripts as scripts
import gradio as gr
- from modules import images
+ from modules import images, sd_samplers
from modules.hypernetworks import hypernetwork
- from modules.processing import process_images, Processed, get_correct_sampler, StableDiffusionProcessingTxt2Img
+ from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.sd_samplers
@@ -60,9 +60,9 @@ def apply_order(p, x, xs):
p.prompt = prompt_tmp + p.prompt
- def build_samplers_dict(p):
+ def build_samplers_dict():
samplers_dict = {}
- for i, sampler in enumerate(get_correct_sampler(p)):
+ for i, sampler in enumerate(sd_samplers.all_samplers):
samplers_dict[sampler.name.lower()] = i
for alias in sampler.aliases:
samplers_dict[alias.lower()] = i
@@ -70,7 +70,7 @@ def build_samplers_dict(p):
def apply_sampler(p, x, xs):
- sampler_index = build_samplers_dict(p).get(x.lower(), None)
+ sampler_index = build_samplers_dict().get(x.lower(), None)
if sampler_index is None:
raise RuntimeError(f"Unknown sampler: {x}")
@@ -78,7 +78,7 @@ def apply_sampler(p, x, xs):
def confirm_samplers(p, xs):
- samplers_dict = build_samplers_dict(p)
+ samplers_dict = build_samplers_dict()
for x in xs:
if x.lower() not in samplers_dict.keys():
raise RuntimeError(f"Unknown sampler: {x}")
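`build_samplers_dict()` no longer needs the processing object: it maps every sampler name and alias (lower-cased) to its index in the single `sd_samplers.all_samplers` list. A small runnable illustration with made-up entries:

```python
from collections import namedtuple

SamplerData = namedtuple("SamplerData", ["name", "aliases"])

# hypothetical entries purely for illustration; the real list lives in sd_samplers
all_samplers = [SamplerData("Euler a", ["k_euler_a"]), SamplerData("DDIM", [])]

def build_samplers_dict():
    samplers_dict = {}
    for i, sampler in enumerate(all_samplers):
        samplers_dict[sampler.name.lower()] = i
        for alias in sampler.aliases:
            samplers_dict[alias.lower()] = i
    return samplers_dict

print(build_samplers_dict())  # {'euler a': 0, 'k_euler_a': 0, 'ddim': 1}
```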
...
@@ -40,4 +40,7 @@ export COMMANDLINE_ARGS=""
#export CODEFORMER_COMMIT_HASH=""
#export BLIP_COMMIT_HASH=""
+ # Uncomment to enable accelerated launch
+ #export ACCELERATE="True"
###########################################
@@ -28,15 +28,27 @@ goto :show_stdout_stderr
:activate_venv
set PYTHON="%~dp0%VENV_DIR%\Scripts\Python.exe"
echo venv %PYTHON%
+ if [%ACCELERATE%] == ["True"] goto :accelerate
goto :launch
:skip_venv
+ :accelerate
+ echo "Checking for accelerate"
+ set ACCELERATE="%~dp0%VENV_DIR%\Scripts\accelerate.exe"
+ if EXIST %ACCELERATE% goto :accelerate_launch
:launch
%PYTHON% launch.py %*
pause
exit /b
+ :accelerate_launch
+ echo "Accelerating"
+ %ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py
+ pause
+ exit /b
:show_stdout_stderr
echo.
...
@@ -33,7 +33,10 @@ from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork
queue_lock = threading.Lock()
- server_name = "0.0.0.0" if cmd_opts.listen else cmd_opts.server_name
+ if cmd_opts.server_name:
+ server_name = cmd_opts.server_name
+ else:
+ server_name = "0.0.0.0" if cmd_opts.listen else None
def wrap_queued_call(func):
def f(*args, **kwargs):
@@ -82,6 +85,7 @@ def initialize():
modules.sd_models.load_model()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
+ shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
...
@@ -134,7 +134,15 @@ else
exit 1
fi
- printf "\n%s\n" "${delimiter}"
- printf "Launching launch.py..."
- printf "\n%s\n" "${delimiter}"
- "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
+ if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ]
+ then
+     printf "\n%s\n" "${delimiter}"
+     printf "Accelerating launch.py..."
+     printf "\n%s\n" "${delimiter}"
+     accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@"
+ else
+     printf "\n%s\n" "${delimiter}"
+     printf "Launching launch.py..."
+     printf "\n%s\n" "${delimiter}"
+     "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
+ fi
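Taken together, the launcher changes add an opt-in Accelerate path: with `ACCELERATE="True"` exported (and the `accelerate` binary available, pinned as accelerate==0.12.0 in requirements), both webui.bat and webui.sh run launch.py through `accelerate launch --num_cpu_threads_per_process=6` and otherwise fall back to plain Python. A hedged Python restatement of that dispatch decision, not part of the actual scripts:

```python
# Sketch of the launcher decision the shell/batch scripts above implement.
import os
import shutil
import subprocess
import sys

def run_launcher(args):
    use_accelerate = os.environ.get("ACCELERATE") == "True" and shutil.which("accelerate")
    if use_accelerate:
        cmd = ["accelerate", "launch", "--num_cpu_threads_per_process=6", "launch.py", *args]
    else:
        cmd = [sys.executable, "launch.py", *args]
    return subprocess.call(cmd)
```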