Commit f6ab5a39 authored by brkirch's avatar brkirch

Merge branch 'AUTOMATIC1111:master' into sub-quad_attn_opt

parents d782a959 3e22e294
......@@ -127,6 +127,8 @@ Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC
The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).
## Credits
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
- Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
- GFPGAN - https://github.com/TencentARC/GFPGAN.git
......
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: modules.xlmr.BertSeriesModelWithTransformation
params:
name: "XLMR-Large"
\ No newline at end of file
import random
from modules import script_callbacks, shared
import gradio as gr
art_symbol = '\U0001f3a8' # 🎨
global_prompt = None
related_ids = {"txt2img_prompt", "txt2img_clear_prompt", "img2img_prompt", "img2img_clear_prompt" }
def roll_artist(prompt):
allowed_cats = set([x for x in shared.artist_db.categories() if len(shared.opts.random_artist_categories)==0 or x in shared.opts.random_artist_categories])
artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
return prompt + ", " + artist.name if prompt != '' else artist.name
def add_roll_button(prompt):
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
roll.click(
fn=roll_artist,
_js="update_txt2img_tokens",
inputs=[
prompt,
],
outputs=[
prompt,
]
)
def after_component(component, **kwargs):
global global_prompt
elem_id = kwargs.get('elem_id', None)
if elem_id not in related_ids:
return
if elem_id == "txt2img_prompt":
global_prompt = component
elif elem_id == "txt2img_clear_prompt":
add_roll_button(global_prompt)
elif elem_id == "img2img_prompt":
global_prompt = component
elif elem_id == "img2img_clear_prompt":
add_roll_button(global_prompt)
script_callbacks.on_after_component(after_component)
<div>
<a href="/docs">API</a>
 • 
<a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a>
 • 
<a href="https://gradio.app">Gradio</a>
 • 
<a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false">Reload UI</a>
</div>
This diff is collapsed.
......@@ -19,7 +19,7 @@ function selected_gallery_index(){
function extract_image_from_gallery(gallery){
if(gallery.length == 1){
return gallery[0]
return [gallery[0]]
}
index = selected_gallery_index()
......@@ -28,7 +28,7 @@ function extract_image_from_gallery(gallery){
return [null]
}
return gallery[index];
return [gallery[index]];
}
function args_to_array(args){
......@@ -188,6 +188,17 @@ onUiUpdate(function(){
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
}
show_all_pages = gradioApp().getElementById('settings_show_all_pages')
settings_tabs = gradioApp().querySelector('#settings div')
if(show_all_pages && settings_tabs){
settings_tabs.appendChild(show_all_pages)
show_all_pages.onclick = function(){
gradioApp().querySelectorAll('#settings > div').forEach(function(elem){
elem.style.display = "block";
})
}
}
})
let txt2img_textarea, img2img_textarea = undefined;
......
......@@ -100,6 +100,7 @@ class Api:
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
......@@ -121,7 +122,6 @@ class Api:
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
populate = txt2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True
......@@ -129,15 +129,14 @@ class Api:
)
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param
shared.state.begin()
with self.queue_lock:
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate))
shared.state.begin()
processed = process_images(p)
shared.state.end()
shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images))
......@@ -153,7 +152,6 @@ class Api:
mask = decode_base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
"sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True,
......@@ -165,16 +163,14 @@ class Api:
args = vars(populate)
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
p = StableDiffusionProcessingImg2Img(**args)
p.init_images = [decode_base64_to_image(x) for x in init_images]
shared.state.begin()
with self.queue_lock:
processed = process_images(p)
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
p.init_images = [decode_base64_to_image(x) for x in init_images]
shared.state.end()
shared.state.begin()
processed = process_images(p)
shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images))
......@@ -332,6 +328,26 @@ class Api:
def get_artists(self):
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
def get_embeddings(self):
db = sd_hijack.model_hijack.embedding_db
def convert_embedding(embedding):
return {
"step": embedding.step,
"sd_checkpoint": embedding.sd_checkpoint,
"sd_checkpoint_name": embedding.sd_checkpoint_name,
"shape": embedding.shape,
"vectors": embedding.vectors,
}
def convert_embeddings(embeddings):
return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
return {
"loaded": convert_embeddings(db.word_embeddings),
"skipped": convert_embeddings(db.skipped_embeddings),
}
def refresh_checkpoints(self):
shared.refresh_checkpoints()
......
......@@ -249,3 +249,13 @@ class ArtistItem(BaseModel):
score: float = Field(title="Score")
category: str = Field(title="Category")
class EmbeddingItem(BaseModel):
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
class EmbeddingsResponse(BaseModel):
loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
\ No newline at end of file
......@@ -303,6 +303,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
result_is_inpainting_model = True
else:
assert a.shape == b.shape, f'Incompatible shapes for layer {key}: A is {a.shape}, and B is {b.shape}'
theta_0[key] = theta_func2(a, b, multiplier)
if save_as_half:
......
import base64
import io
import math
import os
import re
from pathlib import Path
import gradio as gr
from modules.shared import script_path
from modules import shared
from modules import shared, ui_tempdir
import tempfile
from PIL import Image
......@@ -36,9 +37,12 @@ def quote(text):
def image_from_url_text(filedata):
if type(filedata) == dict and filedata["is_file"]:
if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
filedata = filedata[0]
if type(filedata) == dict and filedata.get("is_file", False):
filename = filedata["name"]
is_in_right_dir = any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in shared.demo.temp_dirs)
is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
return Image.open(filename)
......@@ -93,7 +97,7 @@ def integrate_settings_paste_fields(component_dict):
def create_buttons(tabs_list):
buttons = {}
for tab in tabs_list:
buttons[tab] = gr.Button(f"Send to {tab}")
buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
return buttons
......@@ -102,35 +106,57 @@ def bind_buttons(buttons, send_image, send_generate_info):
bind_list.append([buttons, send_image, send_generate_info])
def send_image_and_dimensions(x):
if isinstance(x, Image.Image):
img = x
else:
img = image_from_url_text(x)
if shared.opts.send_size and isinstance(img, Image.Image):
w = img.width
h = img.height
else:
w = gr.update()
h = gr.update()
return img, w, h
def run_bind():
for buttons, send_image, send_generate_info in bind_list:
for buttons, source_image_component, send_generate_info in bind_list:
for tab in buttons:
button = buttons[tab]
if send_image and paste_fields[tab]["init_img"]:
if type(send_image) == gr.Gallery:
button.click(
fn=lambda x: image_from_url_text(x),
_js="extract_image_from_gallery",
inputs=[send_image],
outputs=[paste_fields[tab]["init_img"]],
)
destination_image_component = paste_fields[tab]["init_img"]
fields = paste_fields[tab]["fields"]
destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
if source_image_component and destination_image_component:
if isinstance(source_image_component, gr.Gallery):
func = send_image_and_dimensions if destination_width_component else image_from_url_text
jsfunc = "extract_image_from_gallery"
else:
button.click(
fn=lambda x: x,
inputs=[send_image],
outputs=[paste_fields[tab]["init_img"]],
)
func = send_image_and_dimensions if destination_width_component else lambda x: x
jsfunc = None
button.click(
fn=func,
_js=jsfunc,
inputs=[source_image_component],
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
)
if send_generate_info and paste_fields[tab]["fields"] is not None:
if send_generate_info and fields is not None:
if send_generate_info in paste_fields:
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (['Size-1', 'Size-2'] if shared.opts.send_size else []) + (["Seed"] if shared.opts.send_seed else [])
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
button.click(
fn=lambda *x: x,
inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
outputs=[field for field, name in paste_fields[tab]["fields"] if name in paste_field_names],
outputs=[field for field, name in fields if name in paste_field_names],
)
else:
connect_paste(button, paste_fields[tab]["fields"], send_generate_info)
connect_paste(button, fields, send_generate_info)
button.click(
fn=None,
......@@ -164,6 +190,35 @@ def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
return None
def restore_old_hires_fix_params(res):
"""for infotexts that specify old First pass size parameter, convert it into
width, height, and hr scale"""
firstpass_width = res.get('First pass size-1', None)
firstpass_height = res.get('First pass size-2', None)
if firstpass_width is None or firstpass_height is None:
return
firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
width = int(res.get("Size-1", 512))
height = int(res.get("Size-2", 512))
if firstpass_width == 0 or firstpass_height == 0:
# old algorithm for auto-calculating first pass size
desired_pixel_count = 512 * 512
actual_pixel_count = width * height
scale = math.sqrt(desired_pixel_count / actual_pixel_count)
firstpass_width = math.ceil(scale * width / 64) * 64
firstpass_height = math.ceil(scale * height / 64) * 64
hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height
res['Size-1'] = firstpass_width
res['Size-2'] = firstpass_height
res['Hires upscale'] = hr_scale
def parse_generation_parameters(x: str):
"""parses generation parameters string, the one you see in text field under the picture in UI:
```
......@@ -221,6 +276,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
hypernet_hash = res.get("Hypernet hash", None)
res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
restore_old_hires_fix_params(res)
return res
......
......@@ -39,11 +39,14 @@ def image_grid(imgs, batch_size=1, rows=None):
cols = math.ceil(len(imgs) / rows)
params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
script_callbacks.image_grid_callback(params)
w, h = imgs[0].size
grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
for i, img in enumerate(params.imgs):
grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
return grid
......@@ -227,16 +230,32 @@ def draw_prompt_matrix(im, width, height, all_prompts):
return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
def resize_image(resize_mode, im, width, height):
def resize_image(resize_mode, im, width, height, upscaler_name=None):
"""
Resizes an image with the specified resize_mode, width, and height.
Args:
resize_mode: The mode to use when resizing the image.
0: Resize the image to the specified width and height.
1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
im: The image to resize.
width: The width to resize the image to.
height: The height to resize the image to.
upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
"""
upscaler_name = upscaler_name or opts.upscaler_for_img2img
def resize(im, w, h):
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
return im.resize((w, h), resample=LANCZOS)
scale = max(w / im.width, h / im.height)
if scale > 1.0:
upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"
upscaler = upscalers[0]
im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
......@@ -525,6 +544,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
elif extension.lower() in (".jpg", ".jpeg", ".webp"):
if image_to_save.mode == 'RGBA':
image_to_save = image_to_save.convert("RGB")
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
if opts.enable_pnginfo and info is not None:
......
......@@ -162,4 +162,4 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
if opts.do_not_show_images:
processed.images = []
return processed.images, generation_info_js, plaintext_to_html(processed.info)
return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
......@@ -135,7 +135,7 @@ class InterrogateModels:
return caption[0]
def interrogate(self, pil_image):
res = None
res = ""
try:
......
......@@ -71,10 +71,13 @@ class MemUsageMonitor(threading.Thread):
def read(self):
if not self.disabled:
free, total = torch.cuda.mem_get_info()
self.data["free"] = free
self.data["total"] = total
torch_stats = torch.cuda.memory_stats(self.device)
self.data["active"] = torch_stats["active.all.current"]
self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
self.data["system_peak"] = total - self.data["min_free"]
......
......@@ -123,6 +123,23 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
pass
builtin_upscaler_classes = []
forbidden_upscaler_classes = set()
def list_builtin_upscalers():
load_upscalers()
builtin_upscaler_classes.clear()
builtin_upscaler_classes.extend(Upscaler.__subclasses__())
def forbid_loaded_nonbuiltin_upscalers():
for cls in Upscaler.__subclasses__():
if cls not in builtin_upscaler_classes:
forbidden_upscaler_classes.add(cls)
def load_upscalers():
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__
......@@ -139,6 +156,9 @@ def load_upscalers():
datas = []
commandline_options = vars(shared.cmd_opts)
for cls in Upscaler.__subclasses__():
if cls in forbidden_upscaler_classes:
continue
name = cls.__name__
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
scaler = cls(commandline_options.get(cmd_name, None))
......
......@@ -239,7 +239,7 @@ class StableDiffusionProcessing():
class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
self.images = images_list
self.prompt = p.prompt
self.negative_prompt = p.negative_prompt
......@@ -247,6 +247,7 @@ class Processed:
self.subseed = subseed
self.subseed_strength = p.subseed_strength
self.info = info
self.comments = comments
self.width = p.width
self.height = p.height
self.sampler_name = p.sampler_name
......@@ -646,7 +647,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc()
res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
if p.scripts is not None:
p.scripts.postprocess(p, res)
......@@ -657,14 +658,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
self.denoising_strength = denoising_strength
self.firstphase_width = firstphase_width
self.firstphase_height = firstphase_height
self.truncate_x = 0
self.truncate_y = 0
self.hr_scale = hr_scale
self.hr_upscaler = hr_upscaler
if firstphase_width != 0 or firstphase_height != 0:
print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
self.hr_scale = self.width / firstphase_width
self.width = firstphase_width
self.height = firstphase_height
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
......@@ -673,47 +678,29 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
state.job_count = state.job_count * 2
self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
if self.firstphase_width == 0 or self.firstphase_height == 0:
desired_pixel_count = 512 * 512
actual_pixel_count = self.width * self.height
scale = math.sqrt(desired_pixel_count / actual_pixel_count)
self.firstphase_width = math.ceil(scale * self.width / 64) * 64
self.firstphase_height = math.ceil(scale * self.height / 64) * 64
firstphase_width_truncated = int(scale * self.width)
firstphase_height_truncated = int(scale * self.height)
else:
width_ratio = self.width / self.firstphase_width
height_ratio = self.height / self.firstphase_height
if width_ratio > height_ratio:
firstphase_width_truncated = self.firstphase_width
firstphase_height_truncated = self.firstphase_width * self.height / self.width
else:
firstphase_width_truncated = self.firstphase_height * self.width / self.height
firstphase_height_truncated = self.firstphase_height
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
self.extra_generation_params["Hires upscale"] = self.hr_scale
if self.hr_upscaler is not None:
self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
if self.enable_hr and latent_scale_mode is None:
assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
target_width = int(self.width * self.hr_scale)
target_height = int(self.height * self.hr_scale)
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
def save_intermediate(image, index):
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
return
......@@ -722,11 +709,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
if opts.use_scale_latent_for_hires_fix:
if latent_scale_mode is not None:
for i in range(samples.shape[0]):
save_intermediate(samples, i)
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode)
# Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again.
......@@ -746,7 +733,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
save_intermediate(image, i)
image = images.resize_image(0, image, self.width, self.height)
image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
batch_images.append(image)
......@@ -763,7 +750,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
# GC now before running the next img2img to prevent running out of memory
x = None
......
......@@ -51,6 +51,13 @@ class UiTrainTabParams:
self.txt2img_preview_params = txt2img_preview_params
class ImageGridLoopParams:
def __init__(self, imgs, cols, rows):
self.imgs = imgs
self.cols = cols
self.rows = rows
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callback_map = dict(
callbacks_app_started=[],
......@@ -63,6 +70,7 @@ callback_map = dict(
callbacks_cfg_denoiser=[],
callbacks_before_component=[],
callbacks_after_component=[],
callbacks_image_grid=[],
)
......@@ -155,6 +163,14 @@ def after_component_callback(component, **kwargs):
report_exception(c, 'after_component_callback')
def image_grid_callback(params: ImageGridLoopParams):
for c in callback_map['callbacks_image_grid']:
try:
c.callback(params)
except Exception:
report_exception(c, 'image_grid')
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
......@@ -255,3 +271,11 @@ def on_before_component(callback):
def on_after_component(callback):
"""register a function to be called after a component is created. See on_before_component for more."""
add_callback(callback_map['callbacks_after_component'], callback)
def on_image_grid(callback):
"""register a function to be called before making an image grid.
The callback is called with one argument:
- params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
"""
add_callback(callback_map['callbacks_image_grid'], callback)
......@@ -5,7 +5,7 @@ import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
......@@ -65,6 +65,7 @@ def fix_checkpoint():
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
class StableDiffusionModelHijack:
fixes = None
comments = []
......@@ -75,17 +76,25 @@ class StableDiffusionModelHijack:
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
def hijack(self, m):
if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
self.clip = m.cond_stage_model
apply_optimizations()
self.clip = m.cond_stage_model
fix_checkpoint()
def flatten(el):
......@@ -98,7 +107,11 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
def undo_hijack(self, m):
if type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
m.cond_stage_model = m.cond_stage_model.wrapped
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
......@@ -126,8 +139,8 @@ class StableDiffusionModelHijack:
def tokenize(self, text):
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
class EmbeddingsWithFixes(torch.nn.Module):
......
......@@ -5,7 +5,6 @@ import torch
from modules import prompt_parser, devices
from modules.shared import opts
def get_target_prompt_token_count(token_count):
return math.ceil(max(token_count, 1) / 75) * 75
......@@ -254,10 +253,13 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
def __init__(self, wrapped, hijack):
super().__init__(wrapped, hijack)
self.tokenizer = wrapped.tokenizer
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
vocab = self.tokenizer.get_vocab()
self.comma_token = vocab.get(',</w>', None)
self.token_mults = {}
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
for c in text:
......@@ -296,6 +298,6 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
def encode_embedding_init_text(self, init_text, nvpt):
embedding_layer = self.wrapped.transformer.text_model.embeddings
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
return embedded
This diff is collapsed.
import open_clip.tokenizer
import torch
from modules import sd_hijack_clip, devices
from modules.shared import opts
class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
def __init__(self, wrapped, hijack):
super().__init__(wrapped, hijack)
self.id_start = wrapped.config.bos_token_id
self.id_end = wrapped.config.eos_token_id
self.id_pad = wrapped.config.pad_token_id
self.comma_token = self.tokenizer.get_vocab().get(',', None) # alt diffusion doesn't have </w> bits for comma
def encode_with_transformers(self, tokens):
# there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
# trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
# layer to work with - you have to use the last
attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
z = features['projection_state']
return z
def encode_embedding_init_text(self, init_text, nvpt):
embedding_layer = self.wrapped.roberta.embeddings
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
return embedded
......@@ -228,6 +228,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info
model.logvar = model.logvar.to(devices.device) # fix for training
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
......@@ -322,9 +324,12 @@ def load_model(checkpoint_info=None):
sd_model.eval()
shared.sd_model = sd_model
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
script_callbacks.model_loaded_callback(sd_model)
print("Model loaded.")
return sd_model
......
......@@ -465,7 +465,9 @@ class KDiffusionSampler:
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
else:
sigmas = self.model_wrap.get_sigmas(steps)
......
import torch
import os
import collections
from collections import namedtuple
from modules import shared, devices, script_callbacks
from modules.paths import models_path
......@@ -30,6 +31,7 @@ base_vae = None
loaded_vae_file = None
checkpoint_info = None
checkpoints_loaded = collections.OrderedDict()
def get_base_vae(model):
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
......@@ -149,13 +151,30 @@ def load_vae(model, vae_file=None):
global first_load, vae_dict, vae_list, loaded_vae_file
# save_settings = False
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
if vae_file:
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
print(f"Loading VAE weights from: {vae_file}")
store_base_vae(model)
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
_load_vae_dict(model, vae_dict_1)
if cache_enabled and vae_file in checkpoints_loaded:
# use vae checkpoint cache
print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
store_base_vae(model)
_load_vae_dict(model, checkpoints_loaded[vae_file])
else:
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
print(f"Loading VAE weights from: {vae_file}")
store_base_vae(model)
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
_load_vae_dict(model, vae_dict_1)
if cache_enabled:
# cache newly loaded vae
checkpoints_loaded[vae_file] = vae_dict_1.copy()
# clean up cache if limit is reached
if cache_enabled:
while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
checkpoints_loaded.popitem(last=False) # LRU
# If vae used is not in dict, update it
# It will be removed on refresh though
......
......@@ -23,7 +23,7 @@ demo = None
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(script_path, "v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
......@@ -113,6 +113,17 @@ restricted_opts = {
"outdir_save",
}
ui_reorder_categories = [
"sampler",
"dimensions",
"cfg",
"seed",
"checkboxes",
"hires_fix",
"batch",
"scripts",
]
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
......@@ -172,7 +183,7 @@ class State:
def dict(self):
obj = {
"skipped": self.skipped,
"interrupted": self.skipped,
"interrupted": self.interrupted,
"job": self.job,
"job_count": self.job_count,
"job_no": self.job_no,
......@@ -331,7 +342,6 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
"use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"),
}))
options_templates.update(options_section(('face-restoration', "Face restoration"), {
......@@ -360,6 +370,7 @@ options_templates.update(options_section(('training', "Training"), {
options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
......@@ -371,13 +382,17 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", gr.ColorPicker, {}),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
......@@ -409,7 +424,10 @@ options_templates.update(options_section(('ui', "User interface"), {
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
"dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"),
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/ing2img UI item order"),
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
}))
......@@ -543,6 +561,12 @@ opts = Options()
if os.path.exists(config_filename):
opts.load(config_filename)
latent_upscale_default_mode = "Latent"
latent_upscale_modes = {
"Latent": "bilinear",
"Latent (nearest)": "nearest",
}
sd_upscalers = []
sd_model = None
......
......@@ -23,6 +23,8 @@ class Embedding:
self.vec = vec
self.name = name
self.step = step
self.shape = None
self.vectors = 0
self.cached_checksum = None
self.sd_checkpoint = None
self.sd_checkpoint_name = None
......@@ -57,8 +59,10 @@ class EmbeddingDatabase:
def __init__(self, embeddings_dir):
self.ids_lookup = {}
self.word_embeddings = {}
self.skipped_embeddings = {}
self.dir_mtime = None
self.embeddings_dir = embeddings_dir
self.expected_shape = -1
def register_embedding(self, embedding, model):
......@@ -75,20 +79,24 @@ class EmbeddingDatabase:
return embedding
def load_textual_inversion_embeddings(self):
def get_expected_shape(self):
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
return vec.shape[1]
def load_textual_inversion_embeddings(self, force_reload = False):
mt = os.path.getmtime(self.embeddings_dir)
if self.dir_mtime is not None and mt <= self.dir_mtime:
if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
return
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
self.skipped_embeddings.clear()
self.expected_shape = self.get_expected_shape()
def process_file(path, filename):
name = os.path.splitext(filename)[0]
data = []
if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
embed_image = Image.open(path)
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
......@@ -122,7 +130,13 @@ class EmbeddingDatabase:
embedding.step = data.get('step', None)
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
self.register_embedding(embedding, shared.sd_model)
embedding.vectors = vec.shape[0]
embedding.shape = vec.shape[-1]
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
else:
self.skipped_embeddings[name] = embedding
for fn in os.listdir(self.embeddings_dir):
try:
......@@ -137,8 +151,9 @@ class EmbeddingDatabase:
print(traceback.format_exc(), file=sys.stderr)
continue
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
print("Embeddings:", ', '.join(self.word_embeddings.keys()))
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
if len(self.skipped_embeddings) > 0:
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
......@@ -267,7 +282,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
return embedding, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
# dataset loading may take a while, so input validations and early returns should be done before this
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
......@@ -295,7 +310,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
loss_step = 0
_loss_step = 0 #internal
last_saved_file = "<none>"
last_saved_image = "<none>"
forced_filename = "<none>"
......
......@@ -8,7 +8,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
......@@ -33,8 +33,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
tiling=tiling,
enable_hr=enable_hr,
denoising_strength=denoising_strength if enable_hr else None,
firstphase_width=firstphase_width if enable_hr else None,
firstphase_height=firstphase_height if enable_hr else None,
hr_scale=hr_scale,
hr_upscaler=hr_upscaler,
)
p.scripts = modules.scripts.scripts_txt2img
......@@ -59,4 +59,4 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
if opts.do_not_show_images:
processed.images = []
return processed.images, generation_info_js, plaintext_to_html(processed.info)
return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
This diff is collapsed.
import gradio as gr
class ToolButton(gr.Button, gr.components.FormComponent):
"""Small button with single emoji as text, fits inside gradio forms"""
def __init__(self, **kwargs):
super().__init__(variant="tool", **kwargs)
def get_block_name(self):
return "button"
class FormRow(gr.Row, gr.components.FormComponent):
"""Same as gr.Row but fits inside gradio forms"""
def get_block_name(self):
return "row"
class FormGroup(gr.Group, gr.components.FormComponent):
"""Same as gr.Row but fits inside gradio forms"""
def get_block_name(self):
return "group"
import os
import tempfile
from collections import namedtuple
from pathlib import Path
import gradio as gr
......@@ -12,10 +13,29 @@ from modules import shared
Savedfile = namedtuple("Savedfile", ["name"])
def register_tmp_file(gradio, filename):
if hasattr(gradio, 'temp_file_sets'): # gradio 3.15
gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
if hasattr(gradio, 'temp_dirs'): # gradio 3.9
gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}
def check_tmp_file(gradio, filename):
if hasattr(gradio, 'temp_file_sets'):
return any([filename in fileset for fileset in gradio.temp_file_sets])
if hasattr(gradio, 'temp_dirs'):
return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
return False
def save_pil_to_file(pil_image, dir=None):
already_saved_as = getattr(pil_image, 'already_saved_as', None)
if already_saved_as and os.path.isfile(already_saved_as):
shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(os.path.dirname(already_saved_as))}
register_tmp_file(shared.demo, already_saved_as)
file_obj = Savedfile(already_saved_as)
return file_obj
......@@ -44,7 +64,7 @@ def on_tmpdir_changed():
os.makedirs(shared.opts.temp_dir, exist_ok=True)
shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(shared.opts.temp_dir)}
register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
def cleanup_tmpdr():
......
......@@ -53,10 +53,10 @@ class Upscaler:
def do_upscale(self, img: PIL.Image, selected_model: str):
return img
def upscale(self, img: PIL.Image, scale: int, selected_model: str = None):
def upscale(self, img: PIL.Image, scale, selected_model: str = None):
self.scale = scale
dest_w = img.width * scale
dest_h = img.height * scale
dest_w = int(img.width * scale)
dest_h = int(img.height * scale)
for i in range(3):
shape = (img.width, img.height)
......
from transformers import BertPreTrainedModel,BertModel,BertConfig
import torch.nn as nn
import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional
class BertSeriesConfig(BertConfig):
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
self.project_dim = project_dim
self.pooler_fn = pooler_fn
self.learn_encoder = learn_encoder
class RobertaSeriesConfig(XLMRobertaConfig):
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.project_dim = project_dim
self.pooler_fn = pooler_fn
self.learn_encoder = learn_encoder
class BertSeriesModelWithTransformation(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
config_class = BertSeriesConfig
def __init__(self, config=None, **kargs):
# modify initialization for autoloading
if config is None:
config = XLMRobertaConfig()
config.attention_probs_dropout_prob= 0.1
config.bos_token_id=0
config.eos_token_id=2
config.hidden_act='gelu'
config.hidden_dropout_prob=0.1
config.hidden_size=1024
config.initializer_range=0.02
config.intermediate_size=4096
config.layer_norm_eps=1e-05
config.max_position_embeddings=514
config.num_attention_heads=16
config.num_hidden_layers=24
config.output_past=True
config.pad_token_id=1
config.position_embedding_type= "absolute"
config.type_vocab_size= 1
config.use_cache=True
config.vocab_size= 250002
config.project_dim = 768
config.learn_encoder = False
super().__init__(config)
self.roberta = XLMRobertaModel(config)
self.transformation = nn.Linear(config.hidden_size,config.project_dim)
self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
self.pooler = lambda x: x[:,0]
self.post_init()
def encode(self,c):
device = next(self.parameters()).device
text = self.tokenizer(c,
truncation=True,
max_length=77,
return_length=False,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt")
text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
text["attention_mask"] = torch.tensor(
text['attention_mask']).to(device)
features = self(**text)
return features['projection_state']
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) :
r"""
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.roberta(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=True,
return_dict=return_dict,
)
# last module outputs
sequence_output = outputs[0]
# project every module
sequence_output_ln = self.pre_LN(sequence_output)
# pooler
pooler_output = self.pooler(sequence_output_ln)
pooler_output = self.transformation(pooler_output)
projection_state = self.transformation(outputs.last_hidden_state)
return {
'pooler_output':pooler_output,
'last_hidden_state':outputs.last_hidden_state,
'hidden_states':outputs.hidden_states,
'attentions':outputs.attentions,
'projection_state':projection_state,
'sequence_out': sequence_output
}
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
base_model_prefix = 'roberta'
config_class= RobertaSeriesConfig
\ No newline at end of file
......@@ -3,7 +3,7 @@ transformers==4.19.2
accelerate==0.12.0
basicsr==1.4.2
gfpgan==1.3.8
gradio==3.9
gradio==3.15.0
numpy==1.23.3
Pillow==9.2.0
realesrgan==0.3.0
......
......@@ -19,7 +19,7 @@ class Script(scripts.Script):
def ui(self, is_img2img):
info = gr.HTML("<p style=\"margin-bottom:0.75em\">Will upscale the image by the selected scale factor; use width and height sliders to set tile size</p>")
overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64)
scale_factor = gr.Slider(minimum=1, maximum=4, step=1, label='Scale Factor', value=2)
scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0)
upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
return [info, overlap, upscaler_index, scale_factor]
......
......@@ -202,7 +202,7 @@ axis_options = [
AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None),
AxisOption("Upscale latent space for hires.", str, apply_upscale_latent_space, format_value_add_label, None),
AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None),
AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None),
AxisOption("VAE", str, apply_vae, format_value_add_label, None),
AxisOption("Styles", str, apply_styles, format_value_add_label, None),
......@@ -267,7 +267,6 @@ class SharedSettingsStackHelper(object):
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
self.hypernetwork = opts.sd_hypernetwork
self.model = shared.sd_model
self.use_scale_latent_for_hires_fix = opts.use_scale_latent_for_hires_fix
self.vae = opts.sd_vae
def __exit__(self, exc_type, exc_value, tb):
......@@ -278,7 +277,6 @@ class SharedSettingsStackHelper(object):
hypernetwork.apply_strength()
opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
opts.data["use_scale_latent_for_hires_fix"] = self.use_scale_latent_for_hires_fix
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
......
......@@ -73,8 +73,9 @@
margin-right: auto;
}
#random_seed, #random_subseed, #reuse_seed, #reuse_subseed, #open_folder{
min-width: auto;
[id$=_random_seed], [id$=_random_subseed], [id$=_reuse_seed], [id$=_reuse_subseed], #open_folder{
min-width: 2.3em;
height: 2.5em;
flex-grow: 0;
padding-left: 0.25em;
padding-right: 0.25em;
......@@ -84,27 +85,28 @@
display: none;
}
#seed_row, #subseed_row{
[id$=_seed_row], [id$=_subseed_row]{
gap: 0.5rem;
padding: 0.6em;
}
#subseed_show_box{
[id$=_subseed_show_box]{
min-width: auto;
flex-grow: 0;
}
#subseed_show_box > div{
[id$=_subseed_show_box] > div{
border: 0;
height: 100%;
}
#subseed_show{
[id$=_subseed_show]{
min-width: auto;
flex-grow: 0;
padding: 0;
}
#subseed_show label{
[id$=_subseed_show] label{
height: 100%;
}
......@@ -206,24 +208,24 @@ button{
fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block span{
position: absolute;
top: -0.6em;
top: -0.7em;
line-height: 1.2em;
padding: 0 0.5em;
margin: 0;
padding: 0;
margin: 0 0.5em;
background-color: white;
border-top: 1px solid #eee;
border-left: 1px solid #eee;
border-right: 1px solid #eee;
box-shadow: 6px 0 6px 0px white, -6px 0 6px 0px white;
z-index: 300;
}
.dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{
background-color: rgb(31, 41, 55);
border-top: 1px solid rgb(55 65 81);
border-left: 1px solid rgb(55 65 81);
border-right: 1px solid rgb(55 65 81);
box-shadow: 6px 0 6px 0px rgb(31, 41, 55), -6px 0 6px 0px rgb(31, 41, 55);
}
#txt2img_column_batch, #img2img_column_batch{
min-width: min(13.5em, 100%) !important;
}
#settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{
......@@ -232,22 +234,40 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s
margin-right: 8em;
}
.gr-panel div.flex-col div.justify-between label span{
margin: 0;
}
#settings .gr-panel div.flex-col div.justify-between div{
position: relative;
z-index: 200;
}
input[type="range"]{
margin: 0.5em 0 -0.3em 0;
#settings{
display: block;
}
#txt2img_sampling label{
padding-left: 0.6em;
padding-right: 0.6em;
#settings > div{
border: none;
margin-left: 10em;
}
#settings > div.flex-wrap{
float: left;
display: block;
margin-left: 0;
width: 10em;
}
#settings > div.flex-wrap button{
display: block;
border: none;
text-align: left;
}
#settings_result{
height: 1.4em;
margin: 0 1.2em;
}
input[type="range"]{
margin: 0.5em 0 -0.3em 0;
}
#mask_bug_info {
......@@ -501,13 +521,6 @@ input[type="range"]{
padding: 0;
}
#refresh_sd_model_checkpoint, #refresh_sd_vae, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
max-width: 2.5em;
min-width: 2.5em;
height: 2.4em;
}
canvas[key="mask"] {
z-index: 12 !important;
filter: invert();
......@@ -521,7 +534,7 @@ canvas[key="mask"] {
position: absolute;
right: 0.5em;
top: -0.6em;
z-index: 200;
z-index: 400;
width: 8em;
}
#quicksettings .gr-box > div > div > input.gr-text-input {
......@@ -568,6 +581,53 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h
font-size: 95%;
}
#image_buttons_txt2img button, #image_buttons_img2img button, #image_buttons_extras button{
min-width: auto;
padding-left: 0.5em;
padding-right: 0.5em;
}
.gr-form{
background-color: white;
}
.dark .gr-form{
background-color: rgb(31 41 55 / var(--tw-bg-opacity));
}
.gr-button-tool{
max-width: 2.5em;
min-width: 2.5em !important;
height: 2.4em;
margin: 0.55em 0;
}
#quicksettings .gr-button-tool{
margin: 0;
}
#img2img_settings > div.gr-form, #txt2img_settings > div.gr-form {
padding-top: 0.9em;
}
#img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form{
border: none;
padding-bottom: 0.5em;
}
footer {
display: none !important;
}
#footer{
text-align: center;
}
#footer div{
display: inline-block;
}
/* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running
......
model:
base_learning_rate: 1.0e-4
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
parameterization: "v"
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False # we set this to false because this is an inference only config
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: True
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"
\ No newline at end of file
import os
import sys
import threading
import time
import importlib
......@@ -55,8 +56,8 @@ def initialize():
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
modelloader.list_builtin_upscalers()
modules.scripts.load_scripts()
modelloader.load_upscalers()
modules.sd_vae.refresh_vae_list()
......@@ -169,23 +170,22 @@ def webui():
modules.script_callbacks.app_started_callback(shared.demo, app)
wait_on_server(shared.demo)
print('Restarting UI...')
sd_samplers.set_samplers()
print('Reloading extensions')
extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir)
print('Reloading custom scripts')
modelloader.forbid_loaded_nonbuiltin_upscalers()
modules.scripts.reload_scripts()
modelloader.load_upscalers()
print('Reloading modules: modules.ui')
importlib.reload(modules.ui)
print('Refreshing Model List')
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
importlib.reload(module)
modules.sd_models.list_models()
print('Restarting Gradio')
if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment