Unverified Commit 83ca8dd0, authored by Philpax, committed by GitHub

Merge branch 'AUTOMATIC1111:master' into fix-sd-arch-switch-in-override-settings

parents fa931733 5f4fa942
@@ -127,6 +127,8 @@ Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC
The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).

## Credits
+Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
- Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
- GFPGAN - https://github.com/TencentARC/GFPGAN.git
...
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: modules.xlmr.BertSeriesModelWithTransformation
params:
name: "XLMR-Large"
\ No newline at end of file
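For orientation, configs in this format are typically consumed with OmegaConf plus the `ldm` helpers; a minimal sketch (the file path is illustrative, and in the webui itself loading goes through `modules.sd_models`):

```python
# Minimal sketch of how an inference config like the one above is typically consumed.
# Assumes the CompVis ldm package is importable; the path is illustrative.
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("configs/alt-diffusion-inference.yaml")
model = instantiate_from_config(config.model)  # builds LatentDiffusion from target + params
```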
import random
from modules import script_callbacks, shared
import gradio as gr
art_symbol = '\U0001f3a8' # 🎨
global_prompt = None
related_ids = {"txt2img_prompt", "txt2img_clear_prompt", "img2img_prompt", "img2img_clear_prompt" }
def roll_artist(prompt):
allowed_cats = set([x for x in shared.artist_db.categories() if len(shared.opts.random_artist_categories)==0 or x in shared.opts.random_artist_categories])
artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
return prompt + ", " + artist.name if prompt != '' else artist.name
def add_roll_button(prompt):
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
roll.click(
fn=roll_artist,
_js="update_txt2img_tokens",
inputs=[
prompt,
],
outputs=[
prompt,
]
)
def after_component(component, **kwargs):
global global_prompt
elem_id = kwargs.get('elem_id', None)
if elem_id not in related_ids:
return
if elem_id == "txt2img_prompt":
global_prompt = component
elif elem_id == "txt2img_clear_prompt":
add_roll_button(global_prompt)
elif elem_id == "img2img_prompt":
global_prompt = component
elif elem_id == "img2img_clear_prompt":
add_roll_button(global_prompt)
script_callbacks.on_after_component(after_component)
<div>
<a href="/docs">API</a>
 • 
<a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a>
 • 
<a href="https://gradio.app">Gradio</a>
 • 
<a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false">Reload UI</a>
</div>
This diff is collapsed.
@@ -9,11 +9,19 @@ function dropReplaceImage( imgWrap, files ) {
        return;
    }

+   const tmpFile = files[0];
+
    imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
    const callback = () => {
        const fileInput = imgWrap.querySelector('input[type="file"]');
        if ( fileInput ) {
-           fileInput.files = files;
+           if ( files.length === 0 ) {
+               files = new DataTransfer();
+               files.items.add(tmpFile);
+               fileInput.files = files.files;
+           } else {
+               fileInput.files = files;
+           }
            fileInput.dispatchEvent(new Event('change'));
        }
    };
...
@@ -81,9 +81,6 @@ titles = {
    "vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",

-   "Highres. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition",
-   "Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.",
-
    "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
    "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",
@@ -100,7 +97,13 @@ titles = {
    "Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
    "Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resoluton and lower quality.",
-   "Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality."
+   "Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality.",
+
+   "Hires. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition",
+   "Hires steps": "Number of sampling steps for upscaled picture. If 0, uses same as for original.",
+   "Upscale by": "Adjusts the size of the image by multiplying the original width and height by the selected value. Ignored if either Resize width to or Resize height to are non-zero.",
+   "Resize width to": "Resizes image to this width. If 0, width is inferred from either of two nearby sliders.",
+   "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders."
}
...
@@ -148,8 +148,8 @@ function showGalleryImage() {
    if(e && e.parentElement.tagName == 'DIV'){
        e.style.cursor='pointer'
        e.style.userSelect='none'
-       e.addEventListener('click', function (evt) {
-           if(!opts.js_modal_lightbox) return;
+       e.addEventListener('mousedown', function (evt) {
+           if(!opts.js_modal_lightbox || evt.button != 0) return;
            modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
            showModal(evt)
        }, true);
...
-// various functions for interation with ui.py not large enough to warrant putting them in separate files
+// various functions for interaction with ui.py not large enough to warrant putting them in separate files

function set_theme(theme){
    gradioURL = window.location.href
@@ -19,7 +19,7 @@ function selected_gallery_index(){

function extract_image_from_gallery(gallery){
    if(gallery.length == 1){
-       return gallery[0]
+       return [gallery[0]]
    }

    index = selected_gallery_index()
@@ -28,7 +28,7 @@ function extract_image_from_gallery(gallery){
        return [null]
    }

-   return gallery[index];
+   return [gallery[index]];
}

function args_to_array(args){
@@ -188,6 +188,17 @@ onUiUpdate(function(){
        img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
        img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
    }

+   show_all_pages = gradioApp().getElementById('settings_show_all_pages')
+   settings_tabs = gradioApp().querySelector('#settings div')
+   if(show_all_pages && settings_tabs){
+       settings_tabs.appendChild(show_all_pages)
+       show_all_pages.onclick = function(){
+           gradioApp().querySelectorAll('#settings > div').forEach(function(elem){
+               elem.style.display = "block";
+           })
+       }
+   }
})

let txt2img_textarea, img2img_textarea = undefined;
...
import base64
import io
import time
+import datetime
import uvicorn
from threading import Lock
from io import BytesIO
from gradio.processing_utils import decode_base64_to_file
-from fastapi import APIRouter, Depends, FastAPI, HTTPException
+from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest
@@ -18,7 +19,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list
+from modules.sd_models import checkpoints_list, find_checkpoint_config
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
from typing import List
@@ -67,6 +68,27 @@ def encode_pil_to_base64(image):
        bytes_data = output_bytes.getvalue()
        return base64.b64encode(bytes_data)

+def api_middleware(app: FastAPI):
+    @app.middleware("http")
+    async def log_and_time(req: Request, call_next):
+        ts = time.time()
+        res: Response = await call_next(req)
+        duration = str(round(time.time() - ts, 4))
+        res.headers["X-Process-Time"] = duration
+        endpoint = req.scope.get('path', 'err')
+        if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
+            print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
+                t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
+                code = res.status_code,
+                ver = req.scope.get('http_version', '0.0'),
+                cli = req.scope.get('client', ('0:0.0.0', 0))[0],
+                prot = req.scope.get('scheme', 'err'),
+                method = req.scope.get('method', 'err'),
+                endpoint = endpoint,
+                duration = duration,
+            ))
+        return res
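As a rough client-side sketch of the middleware's effect (assuming a local webui launched with `--api`; the URL and payload fields are the API defaults), every `/sdapi` response now carries the timing header set above:

```python
# Rough client sketch; assumes a local webui started with --api.
import requests

resp = requests.post(
    "http://127.0.0.1:7860/sdapi/v1/txt2img",
    json={"prompt": "a photo of a cat", "steps": 20},
)
print(resp.headers.get("X-Process-Time"))  # e.g. "3.1415", set by log_and_time above
print(len(resp.json()["images"]))          # list of base64-encoded images
```

With `--api-log` enabled, the server additionally prints one line per `/sdapi` request in the format built above.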
class Api:
    def __init__(self, app: FastAPI, queue_lock: Lock):
@@ -79,6 +101,7 @@ class Api:
        self.router = APIRouter()
        self.app = app
        self.queue_lock = queue_lock
+       api_middleware(self.app)
        self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
        self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
@@ -100,6 +123,7 @@ class Api:
        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
        self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
        self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
+       self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
        self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
@@ -128,15 +152,14 @@ class Api:
        )
        if populate.sampler_name:
            populate.sampler_index = None  # prevent a warning later on
-       p = StableDiffusionProcessingTxt2Img(**vars(populate))
-       # Override object param
-
-       shared.state.begin()

        with self.queue_lock:
-           processed = process_images(p)
+           p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate))

-       shared.state.end()
+           shared.state.begin()
+           processed = process_images(p)
+           shared.state.end()

        b64images = list(map(encode_pil_to_base64, processed.images))
@@ -163,16 +186,14 @@ class Api:
        args = vars(populate)
        args.pop('include_init_images', None)  # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
-       p = StableDiffusionProcessingImg2Img(**args)
-       p.init_images = [decode_base64_to_image(x) for x in init_images]
-
-       shared.state.begin()

        with self.queue_lock:
-           processed = process_images(p)
+           p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
+           p.init_images = [decode_base64_to_image(x) for x in init_images]

-       shared.state.end()
+           shared.state.begin()
+           processed = process_images(p)
+           shared.state.end()

        b64images = list(map(encode_pil_to_base64, processed.images))
@@ -305,7 +326,7 @@ class Api:
        return upscalers

    def get_sd_models(self):
-       return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
+       return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]

    def get_hypernetworks(self):
        return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
@@ -330,6 +351,26 @@ class Api:
    def get_artists(self):
        return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]

+   def get_embeddings(self):
+       db = sd_hijack.model_hijack.embedding_db
+
+       def convert_embedding(embedding):
+           return {
+               "step": embedding.step,
+               "sd_checkpoint": embedding.sd_checkpoint,
+               "sd_checkpoint_name": embedding.sd_checkpoint_name,
+               "shape": embedding.shape,
+               "vectors": embedding.vectors,
+           }
+
+       def convert_embeddings(embeddings):
+           return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
+
+       return {
+           "loaded": convert_embeddings(db.word_embeddings),
+           "skipped": convert_embeddings(db.skipped_embeddings),
+       }
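A matching client sketch for the new endpoint (same local-server assumption as above):

```python
# Client sketch for the new /sdapi/v1/embeddings endpoint; assumes a local webui with --api.
import requests

data = requests.get("http://127.0.0.1:7860/sdapi/v1/embeddings").json()
for name, item in data["loaded"].items():
    print(f"{name}: {item['vectors']} vector(s) of length {item['shape']}")
print("skipped:", list(data["skipped"]))
```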
    def refresh_checkpoints(self):
        shared.refresh_checkpoints()
...
@@ -249,3 +249,13 @@ class ArtistItem(BaseModel):
    score: float = Field(title="Score")
    category: str = Field(title="Category")

+class EmbeddingItem(BaseModel):
+    step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
+    sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
+    sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
+    shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
+    vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
+
+class EmbeddingsResponse(BaseModel):
+    loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
+    skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
\ No newline at end of file
@@ -2,9 +2,30 @@ import sys
import traceback

+def print_error_explanation(message):
+    lines = message.strip().split("\n")
+    max_len = max([len(x) for x in lines])
+
+    print('=' * max_len, file=sys.stderr)
+    for line in lines:
+        print(line, file=sys.stderr)
+    print('=' * max_len, file=sys.stderr)
+
+
+def display(e: Exception, task):
+    print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
+    print(traceback.format_exc(), file=sys.stderr)
+
+    message = str(e)
+    if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
+        print_error_explanation("""
+The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
+See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
+        """)
+
+
def run(code, task):
    try:
        code()
    except Exception as e:
-       print(f"{task}: {type(e).__name__}", file=sys.stderr)
-       print(traceback.format_exc(), file=sys.stderr)
+       display(e, task)
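A usage sketch for the reworked error path (the failing callable is hypothetical):

```python
# Hypothetical call site for modules.errors.run; load_model is illustrative.
from modules import errors

def load_model():
    raise RuntimeError("something went wrong")  # stand-in failure

errors.run(load_model, "loading model")
# Prints "loading model: RuntimeError" plus the traceback; if the message matched
# the known SD 2.0 shape mismatch, display() would also print the boxed explanation.
```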
@@ -19,8 +19,6 @@ from modules.shared import opts
import modules.gfpgan_model
from modules.ui import plaintext_to_html
import modules.codeformer_model
-import piexif
-import piexif.helper
import gradio as gr
import safetensors.torch
@@ -58,6 +56,9 @@ cached_images: LruCache = LruCache(max_size=5)
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
    devices.torch_gc()

+   shared.state.begin()
+   shared.state.job = 'extras'
+
    imageArr = []
    # Also keep track of original file names
    imageNameArr = []
@@ -94,6 +95,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
    # Extra operation definitions

    def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+       shared.state.job = 'extras-gfpgan'
        restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
        res = Image.fromarray(restored_img)
@@ -104,6 +106,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
        return (res, info)

    def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+       shared.state.job = 'extras-codeformer'
        restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
        res = Image.fromarray(restored_img)
@@ -114,6 +117,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
        return (res, info)

    def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
+       shared.state.job = 'extras-upscale'
        upscaler = shared.sd_upscalers[scaler_index]
        res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
        if mode == 1 and crop:
@@ -180,6 +184,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
    for image, image_name in zip(imageArr, imageNameArr):
        if image is None:
            return outputs, "Please select an input image.", ''

+       shared.state.textinfo = f'Processing image {image_name}'
+
        existing_pnginfo = image.info or {}
        image = image.convert("RGB")
@@ -193,6 +200,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
        else:
            basename = ''

+       if opts.enable_pnginfo:  # append info before save
+           image.info = existing_pnginfo
+           image.info["extras"] = info
+
        if save_output:
            # Add upscaler name as a suffix.
            suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
@@ -203,10 +214,6 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
            images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
                no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)

-       if opts.enable_pnginfo:
-           image.info = existing_pnginfo
-           image.info["extras"] = info
-
        if extras_mode != 2 or show_extras_results :
            outputs.append(image)
@@ -242,6 +249,9 @@ def run_pnginfo(image):

def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
+   shared.state.begin()
+   shared.state.job = 'model-merge'
+
    def weighted_sum(theta0, theta1, alpha):
        return ((1 - alpha) * theta0) + (alpha * theta1)
@@ -263,8 +273,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
    theta_func1, theta_func2 = theta_funcs[interp_method]

    if theta_func1 and not tertiary_model_info:
+       shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
+       shared.state.end()
        return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]

+   shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
    print(f"Loading {secondary_model_info.filename}...")
    theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
@@ -281,6 +294,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
            theta_1[key] = torch.zeros_like(theta_1[key])
        del theta_2

+   shared.state.textinfo = f"Loading {primary_model_info.filename}..."
    print(f"Loading {primary_model_info.filename}...")
    theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
@@ -291,6 +305,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
            a = theta_0[key]
            b = theta_1[key]

+           shared.state.textinfo = f'Merging layer {key}'
            # this enables merging an inpainting model (A) with another one (B);
            # where normal model would have 4 channels, for latenst space, inpainting model would
            # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
@@ -330,6 +345,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
    output_modelname = os.path.join(ckpt_dir, filename)

+   shared.state.textinfo = f"Saving to {output_modelname}..."
    print(f"Saving to {output_modelname}...")
    _, extension = os.path.splitext(output_modelname)
@@ -341,4 +357,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
    sd_models.list_models()
    print("Checkpoint saved.")

+   shared.state.textinfo = "Checkpoint saved to " + output_modelname
+   shared.state.end()
+
    return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
import base64
import io
+import math
import os
import re
from pathlib import Path

import gradio as gr
from modules.shared import script_path
-from modules import shared
+from modules import shared, ui_tempdir
import tempfile
from PIL import Image
@@ -36,9 +37,12 @@ def quote(text):
def image_from_url_text(filedata):
-   if type(filedata) == dict and filedata["is_file"]:
+   if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
+       filedata = filedata[0]
+
+   if type(filedata) == dict and filedata.get("is_file", False):
        filename = filedata["name"]
-       is_in_right_dir = any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in shared.demo.temp_dirs)
+       is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
        assert is_in_right_dir, 'trying to open image file outside of allowed directories'

        return Image.open(filename)
@@ -93,7 +97,7 @@ def integrate_settings_paste_fields(component_dict):
def create_buttons(tabs_list):
    buttons = {}
    for tab in tabs_list:
-       buttons[tab] = gr.Button(f"Send to {tab}")
+       buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
    return buttons
@@ -102,35 +106,57 @@ def bind_buttons(buttons, send_image, send_generate_info):
    bind_list.append([buttons, send_image, send_generate_info])

+def send_image_and_dimensions(x):
+    if isinstance(x, Image.Image):
+        img = x
+    else:
+        img = image_from_url_text(x)
+
+    if shared.opts.send_size and isinstance(img, Image.Image):
+        w = img.width
+        h = img.height
+    else:
+        w = gr.update()
+        h = gr.update()
+
+    return img, w, h
+
+
def run_bind():
-   for buttons, send_image, send_generate_info in bind_list:
+   for buttons, source_image_component, send_generate_info in bind_list:
        for tab in buttons:
            button = buttons[tab]

-           if send_image and paste_fields[tab]["init_img"]:
-               if type(send_image) == gr.Gallery:
-                   button.click(
-                       fn=lambda x: image_from_url_text(x),
-                       _js="extract_image_from_gallery",
-                       inputs=[send_image],
-                       outputs=[paste_fields[tab]["init_img"]],
-                   )
-               else:
-                   button.click(
-                       fn=lambda x: x,
-                       inputs=[send_image],
-                       outputs=[paste_fields[tab]["init_img"]],
-                   )
+           destination_image_component = paste_fields[tab]["init_img"]
+           fields = paste_fields[tab]["fields"]
+
+           destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
+           destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
+
+           if source_image_component and destination_image_component:
+               if isinstance(source_image_component, gr.Gallery):
+                   func = send_image_and_dimensions if destination_width_component else image_from_url_text
+                   jsfunc = "extract_image_from_gallery"
+               else:
+                   func = send_image_and_dimensions if destination_width_component else lambda x: x
+                   jsfunc = None
+
+               button.click(
+                   fn=func,
+                   _js=jsfunc,
+                   inputs=[source_image_component],
+                   outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
+               )

-           if send_generate_info and paste_fields[tab]["fields"] is not None:
+           if send_generate_info and fields is not None:
                if send_generate_info in paste_fields:
-                   paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (['Size-1', 'Size-2'] if shared.opts.send_size else []) + (["Seed"] if shared.opts.send_seed else [])
+                   paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
                    button.click(
                        fn=lambda *x: x,
                        inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
-                       outputs=[field for field, name in paste_fields[tab]["fields"] if name in paste_field_names],
+                       outputs=[field for field, name in fields if name in paste_field_names],
                    )
                else:
-                   connect_paste(button, paste_fields[tab]["fields"], send_generate_info)
+                   connect_paste(button, fields, send_generate_info)

                button.click(
                    fn=None,
@@ -164,6 +190,34 @@ def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
    return None

+def restore_old_hires_fix_params(res):
+    """for infotexts that specify old First pass size parameter, convert it into
+    width, height, and hr scale"""
+
+    firstpass_width = res.get('First pass size-1', None)
+    firstpass_height = res.get('First pass size-2', None)
+
+    if firstpass_width is None or firstpass_height is None:
+        return
+
+    firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
+    width = int(res.get("Size-1", 512))
+    height = int(res.get("Size-2", 512))
+
+    if firstpass_width == 0 or firstpass_height == 0:
+        # old algorithm for auto-calculating first pass size
+        desired_pixel_count = 512 * 512
+        actual_pixel_count = width * height
+        scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+        firstpass_width = math.ceil(scale * width / 64) * 64
+        firstpass_height = math.ceil(scale * height / 64) * 64
+
+    res['Size-1'] = firstpass_width
+    res['Size-2'] = firstpass_height
+    res['Hires resize-1'] = width
+    res['Hires resize-2'] = height
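A worked example of the auto-calculation branch above (standalone arithmetic):

```python
# Worked example of the old first-pass auto-calculation.
import math

width, height = 1024, 1024  # final Size-1 x Size-2 from the infotext
scale = math.sqrt((512 * 512) / (width * height))       # 0.5
firstpass_width = math.ceil(scale * width / 64) * 64    # 512
firstpass_height = math.ceil(scale * height / 64) * 64  # 512
# The parsed result is then rewritten: Size becomes 512x512 and
# Hires resize becomes 1024x1024.
```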
def parse_generation_parameters(x: str):
    """parses generation parameters string, the one you see in text field under the picture in UI:
```
@@ -221,6 +275,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
    hypernet_hash = res.get("Hypernet hash", None)
    res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)

+   if "Hires resize-1" not in res:
+       res["Hires resize-1"] = 0
+       res["Hires resize-2"] = 0
+
+   restore_old_hires_fix_params(res)

    return res
...
@@ -402,10 +402,8 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
    shared.reload_hypernetworks()

-   return fn

-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
    from modules import images
@@ -417,6 +415,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
    shared.loaded_hypernetwork = Hypernetwork()
    shared.loaded_hypernetwork.load(path)

+   shared.state.job = "train-hypernetwork"
    shared.state.textinfo = "Initializing hypernetwork training..."
    shared.state.job_count = steps
@@ -447,6 +446,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
        return hypernetwork, filename

    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)

+   clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
+   if clip_grad:
+       clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
+
    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
@@ -465,7 +468,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
        shared.parallel_processing_allowed = False
        shared.sd_model.cond_stage_model.to(devices.cpu)
        shared.sd_model.first_stage_model.to(devices.cpu)

    weights = hypernetwork.weights()
    hypernetwork.train_mode()
@@ -524,6 +527,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
            if shared.state.interrupted:
                break

+           if clip_grad:
+               clip_grad_sched.step(hypernetwork.step)
+
            with devices.autocast():
                x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                if tag_drop_out != 0 or shuffle_tags:
@@ -538,14 +544,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
                _loss_step += loss.item()
                scaler.scale(loss).backward()

                # go back until we reach gradient accumulation steps
                if (j + 1) % gradient_step != 0:
                    continue

-               # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
-               # scaler.unscale_(optimizer)
-               # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
-               # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
-               # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
+               if clip_grad:
+                   clip_grad(weights, clip_grad_sched.learn_rate)

                scaler.step(optimizer)
                scaler.update()
                hypernetwork.step += 1
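For reference, the two `clip_grad_mode` values map directly onto stock PyTorch utilities; a standalone sketch:

```python
# Standalone sketch of the two clipping modes (stock PyTorch utilities).
import torch

w = torch.nn.Parameter(torch.randn(4))
w.grad = torch.tensor([0.5, -2.0, 3.0, -0.1])

torch.nn.utils.clip_grad_value_([w], clip_value=1.0)  # "value": clamp each element to [-1, 1]
print(w.grad)  # tensor([ 0.5000, -1.0000,  1.0000, -0.1000])

torch.nn.utils.clip_grad_norm_([w], max_norm=1.0)     # "norm": rescale so the total L2 norm <= 1
print(w.grad.norm())  # tensor(1.)
```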
...
@@ -39,11 +39,14 @@ def image_grid(imgs, batch_size=1, rows=None):
        cols = math.ceil(len(imgs) / rows)

+   params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
+   script_callbacks.image_grid_callback(params)
+
    w, h = imgs[0].size
-   grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
+   grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')

-   for i, img in enumerate(imgs):
-       grid.paste(img, box=(i % cols * w, i // cols * h))
+   for i, img in enumerate(params.imgs):
+       grid.paste(img, box=(i % params.cols * w, i // params.cols * h))

    return grid
@@ -227,16 +230,32 @@ def draw_prompt_matrix(im, width, height, all_prompts):
    return draw_grid_annotations(im, width, height, hor_texts, ver_texts)

-def resize_image(resize_mode, im, width, height):
+def resize_image(resize_mode, im, width, height, upscaler_name=None):
+    """
+    Resizes an image with the specified resize_mode, width, and height.
+
+    Args:
+        resize_mode: The mode to use when resizing the image.
+            0: Resize the image to the specified width and height.
+            1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
+            2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
+        im: The image to resize.
+        width: The width to resize the image to.
+        height: The height to resize the image to.
+        upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
+    """
+
+    upscaler_name = upscaler_name or opts.upscaler_for_img2img
+
    def resize(im, w, h):
-       if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
+       if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
            return im.resize((w, h), resample=LANCZOS)

        scale = max(w / im.width, h / im.height)

        if scale > 1.0:
-           upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
-           assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
+           upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
+           assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"

            upscaler = upscalers[0]
            im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
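A minimal usage sketch of the new parameter (the upscaler name is illustrative; valid names come from `shared.sd_upscalers`):

```python
# Hypothetical call site; "Lanczos" is illustrative -- valid names come from shared.sd_upscalers.
from PIL import Image
from modules import images

im = Image.open("input.png")
# resize_mode 1: fill 768x768 and crop, using the named upscaler
# instead of opts.upscaler_for_img2img
resized = images.resize_image(1, im, 768, 768, upscaler_name="Lanczos")
```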
@@ -525,6 +544,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
        image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)

    elif extension.lower() in (".jpg", ".jpeg", ".webp"):
+       if image_to_save.mode == 'RGBA':
+           image_to_save = image_to_save.convert("RGB")
+
        image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)

        if opts.enable_pnginfo and info is not None:
...
@@ -162,4 +162,4 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
    if opts.do_not_show_images:
        processed.images = []

-   return processed.images, generation_info_js, plaintext_to_html(processed.info)
+   return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
@@ -135,8 +135,9 @@ class InterrogateModels:
        return caption[0]

    def interrogate(self, pil_image):
-       res = None
+       res = ""
+       shared.state.begin()
+       shared.state.job = 'interrogate'

        try:
            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -177,5 +178,6 @@ class InterrogateModels:
            res += "<error>"

        self.unload()
+       shared.state.end()

        return res
@@ -71,10 +71,13 @@ class MemUsageMonitor(threading.Thread):
    def read(self):
        if not self.disabled:
            free, total = torch.cuda.mem_get_info()
+           self.data["free"] = free
            self.data["total"] = total

            torch_stats = torch.cuda.memory_stats(self.device)
+           self.data["active"] = torch_stats["active.all.current"]
            self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
+           self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
            self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
            self.data["system_peak"] = total - self.data["min_free"]
...
@@ -123,6 +123,23 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
        pass

+builtin_upscaler_classes = []
+forbidden_upscaler_classes = set()
+
+
+def list_builtin_upscalers():
+    load_upscalers()
+
+    builtin_upscaler_classes.clear()
+    builtin_upscaler_classes.extend(Upscaler.__subclasses__())
+
+
+def forbid_loaded_nonbuiltin_upscalers():
+    for cls in Upscaler.__subclasses__():
+        if cls not in builtin_upscaler_classes:
+            forbidden_upscaler_classes.add(cls)
+
+
def load_upscalers():
    # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
    # so we'll try to import any _model.py files before looking in __subclasses__
@@ -139,6 +156,9 @@ def load_upscalers():
    datas = []
    commandline_options = vars(shared.cmd_opts)
    for cls in Upscaler.__subclasses__():
+       if cls in forbidden_upscaler_classes:
+           continue
+
        name = cls.__name__
        cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
        scaler = cls(commandline_options.get(cmd_name, None))
...
This diff is collapsed.
@@ -51,6 +51,13 @@ class UiTrainTabParams:
        self.txt2img_preview_params = txt2img_preview_params

+class ImageGridLoopParams:
+    def __init__(self, imgs, cols, rows):
+        self.imgs = imgs
+        self.cols = cols
+        self.rows = rows
+
+
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callback_map = dict(
    callbacks_app_started=[],
@@ -63,6 +70,7 @@ callback_map = dict(
    callbacks_cfg_denoiser=[],
    callbacks_before_component=[],
    callbacks_after_component=[],
+   callbacks_image_grid=[],
)
@@ -155,6 +163,14 @@ def after_component_callback(component, **kwargs):
        report_exception(c, 'after_component_callback')

+def image_grid_callback(params: ImageGridLoopParams):
+    for c in callback_map['callbacks_image_grid']:
+        try:
+            c.callback(params)
+        except Exception:
+            report_exception(c, 'image_grid')
+
+
def add_callback(callbacks, fun):
    stack = [x for x in inspect.stack() if x.filename != __file__]
    filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -255,3 +271,11 @@ def on_before_component(callback):
def on_after_component(callback):
    """register a function to be called after a component is created. See on_before_component for more."""
    add_callback(callback_map['callbacks_after_component'], callback)

+def on_image_grid(callback):
+    """register a function to be called before making an image grid.
+    The callback is called with one argument:
+        - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
+    """
+    add_callback(callback_map['callbacks_image_grid'], callback)
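A minimal extension-script sketch using the new hook (forcing a single-row grid is just an illustration):

```python
# Minimal sketch of the new hook; forcing a single-row grid is just an illustration.
from modules import script_callbacks

def force_single_row(params):  # params is an ImageGridLoopParams
    params.rows = 1
    params.cols = len(params.imgs)

script_callbacks.on_image_grid(force_single_row)
```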
...@@ -5,7 +5,7 @@ import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr

from modules.sd_hijack_optimizations import invokeAI_mps_available
...@@ -35,26 +35,35 @@ def apply_optimizations():
    ldm.modules.diffusionmodules.model.nonlinearity = silu
    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    optimization_method = None

    if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
        print("Applying xformers cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
        optimization_method = 'xformers'
    elif cmd_opts.opt_split_attention_v1:
        print("Applying v1 cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
        optimization_method = 'V1'
    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
        if not invokeAI_mps_available and shared.device.type == 'mps':
            print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
            print("Applying v1 cross attention optimization.")
            ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
            optimization_method = 'V1'
        else:
            print("Applying cross attention optimization (InvokeAI).")
            ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
            optimization_method = 'InvokeAI'
    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
        print("Applying cross attention optimization (Doggettx).")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
        optimization_method = 'Doggettx'

    return optimization_method

def undo_optimizations():
...@@ -68,27 +77,37 @@ def fix_checkpoint():
    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward

class StableDiffusionModelHijack:
    fixes = None
    comments = []
    layers = None
    circular_enabled = False
    clip = None
    optimization_method = None

    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)

    def hijack(self, m):
        if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            model_embeddings = m.cond_stage_model.roberta.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
            m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
            m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
            m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        self.clip = m.cond_stage_model
        apply_optimizations()
        self.optimization_method = apply_optimizations()
        self.clip = m.cond_stage_model

        fix_checkpoint()

        def flatten(el):
...@@ -101,7 +120,11 @@ class StableDiffusionModelHijack:
        self.layers = flatten(m)

    def undo_hijack(self, m):
        if type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            m.cond_stage_model = m.cond_stage_model.wrapped
        elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
...@@ -129,8 +152,8 @@ class StableDiffusionModelHijack:
    def tokenize(self, text):
        _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
        return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)

class EmbeddingsWithFixes(torch.nn.Module):
...
...@@ -5,7 +5,6 @@ import torch
from modules import prompt_parser, devices
from modules.shared import opts

def get_target_prompt_token_count(token_count):
    return math.ceil(max(token_count, 1) / 75) * 75
...@@ -254,10 +253,13 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)
        self.tokenizer = wrapped.tokenizer
        self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]

        vocab = self.tokenizer.get_vocab()

        self.comma_token = vocab.get(',</w>', None)

        self.token_mults = {}
        tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
        tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
        for text, ident in tokens_with_parens:
            mult = 1.0
            for c in text:
...@@ -296,6 +298,6 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
    def encode_embedding_init_text(self, init_text, nvpt):
        embedding_layer = self.wrapped.transformer.text_model.embeddings
        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
        embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
        embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
        return embedded
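The motivation for the vocab.get change is not spelled out in the diff: CLIP's BPE vocabulary spells the comma token as ',</w>', but the XLM-R tokenizer used by AltDiffusion has no such entry, so the old list-index lookup would crash. A toy illustration:

```python
# illustrative token ids only
clip_vocab = {'hello</w>': 100, ',</w>': 267}   # CLIP-style vocab has ',</w>'
xlmr_vocab = {'hello': 100, ',': 6}             # XLM-R-style vocab has no </w> suffix

clip_vocab.get(',</w>', None)   # 267
xlmr_vocab.get(',</w>', None)   # None - whereas [v for k, v in ...][0] would raise IndexError
```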
import open_clip.tokenizer
import torch

from modules import sd_hijack_clip, devices
from modules.shared import opts


class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)

        self.id_start = wrapped.config.bos_token_id
        self.id_end = wrapped.config.eos_token_id
        self.id_pad = wrapped.config.pad_token_id

        self.comma_token = self.tokenizer.get_vocab().get(',', None)  # alt diffusion doesn't have </w> bits for comma

    def encode_with_transformers(self, tokens):
        # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
        # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
        # layer to work with - you have to use the last
        attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
        features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
        z = features['projection_state']

        return z

    def encode_embedding_init_text(self, init_text, nvpt):
        embedding_layer = self.wrapped.roberta.embeddings
        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
        embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)

        return embedded
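A minimal illustration of the padding-mask construction in encode_with_transformers (a pad_token_id of 1 is XLM-R's default and an assumption here):

```python
import torch

id_pad = 1                                      # XLM-RoBERTa pad token id (assumed)
tokens = torch.tensor([[0, 52, 7, 2, 1, 1]])    # <s> ... </s> <pad> <pad>
attention_mask = (tokens != id_pad).to(dtype=torch.int64)
print(attention_mask)                           # tensor([[1, 1, 1, 1, 0, 0]])
```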
...@@ -20,7 +20,7 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))

CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
checkpoints_loaded = collections.OrderedDict()
...@@ -48,6 +48,14 @@ def checkpoint_tiles():
    return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
def find_checkpoint_config(info):
    config = os.path.splitext(info.filename)[0] + ".yaml"
    if os.path.exists(config):
        return config

    return shared.cmd_opts.config
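In short, a .yaml sitting next to the checkpoint overrides the global --config, and nothing is stored on CheckpointInfo anymore; the config is re-resolved on demand. A sketch of the resolution order (hypothetical paths):

```python
from types import SimpleNamespace

# hypothetical checkpoint entry, to illustrate find_checkpoint_config's lookup rule
info = SimpleNamespace(filename="models/Stable-diffusion/robo-diffusion.ckpt")

# if models/Stable-diffusion/robo-diffusion.yaml exists on disk, find_checkpoint_config(info)
# returns it; otherwise it falls back to shared.cmd_opts.config (v1-inference.yaml by default)
config = find_checkpoint_config(info)
```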
def list_models():
    checkpoints_list.clear()
    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
...@@ -73,7 +81,7 @@ def list_models():
    if os.path.exists(cmd_ckpt):
        h = model_hash(cmd_ckpt)
        title, short_model_name = modeltitle(cmd_ckpt, h)
        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
        shared.opts.data['sd_model_checkpoint'] = title
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
...@@ -81,12 +89,7 @@ def list_models():
        h = model_hash(filename)
        title, short_model_name = modeltitle(filename, h)

        basename, _ = os.path.splitext(filename)
        config = basename + ".yaml"
        if not os.path.exists(config):
            config = shared.cmd_opts.config

        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)

def get_closet_checkpoint_match(searchString):
...@@ -168,7 +171,10 @@ def get_state_dict_from_checkpoint(pl_sd):
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
    _, extension = os.path.splitext(checkpoint_file)
    if extension.lower() == ".safetensors":
        pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location)
        device = map_location or shared.weight_load_location
        if device is None:
            device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
        pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
    else:
        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
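The extra dance exists because safetensors.torch.load_file, unlike torch.load, does not accept None as a device, so a concrete device string has to be resolved first. The same pattern in isolation, as a sketch:

```python
import torch
import safetensors.torch

def load_weights(path, device=None):
    # safetensors needs a concrete device string; None is only valid for torch.load
    if path.endswith(".safetensors"):
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        return safetensors.torch.load_file(path, device=device)
    return torch.load(path, map_location=device)
```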
...@@ -228,6 +234,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
    model.sd_model_checkpoint = checkpoint_file
    model.sd_checkpoint_info = checkpoint_info

    model.logvar = model.logvar.to(devices.device)  # fix for training

    sd_vae.delete_base_vae()
    sd_vae.clear_loaded_vae()
    vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
...@@ -276,12 +284,14 @@ def enable_midas_autodownload():
    midas.api.load_model = load_model_wrapper
def load_model(checkpoint_info=None):
    from modules import lowvram, sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()
    checkpoint_config = find_checkpoint_config(checkpoint_info)

    if checkpoint_info.config != shared.cmd_opts.config:
        print(f"Loading config from: {checkpoint_info.config}")
    if checkpoint_config != shared.cmd_opts.config:
        print(f"Loading config from: {checkpoint_config}")

    if shared.sd_model:
        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
...@@ -289,7 +299,7 @@ def load_model(checkpoint_info=None):
        gc.collect()
        devices.torch_gc()

    sd_config = OmegaConf.load(checkpoint_info.config)
    sd_config = OmegaConf.load(checkpoint_config)

    if should_hijack_inpainting(checkpoint_info):
        # Hardcoded config for now...
...@@ -298,9 +308,6 @@ def load_model(checkpoint_info=None):
        sd_config.model.params.unet_config.params.in_channels = 9
        sd_config.model.params.finetune_keys = None

        # Create a "fake" config with a different name so that we know to unload it when switching models.
        checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))

    if not hasattr(sd_config.model.params, "use_ema"):
        sd_config.model.params.use_ema = False
...@@ -310,6 +317,7 @@ def load_model(checkpoint_info=None):
        sd_config.model.params.unet_config.params.use_fp16 = False

    sd_model = instantiate_from_config(sd_config.model)

    load_model_weights(sd_model, checkpoint_info)

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
...@@ -322,23 +330,29 @@ def load_model(checkpoint_info=None):
    sd_model.eval()
    shared.sd_model = sd_model

    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model

    script_callbacks.model_loaded_callback(sd_model)

    print("Model loaded.")
    return sd_model
def reload_model_weights(sd_model=None, info=None):
    from modules import lowvram, devices, sd_hijack
    checkpoint_info = info or select_checkpoint()

    if not sd_model:
        sd_model = shared.sd_model

    current_checkpoint_info = sd_model.sd_checkpoint_info
    checkpoint_config = find_checkpoint_config(current_checkpoint_info)

    if sd_model.sd_model_checkpoint == checkpoint_info.filename:
        return

    if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
    if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
        del sd_model
        checkpoints_loaded.clear()
        load_model(checkpoint_info)
...@@ -351,13 +365,19 @@ def reload_model_weights(sd_model=None, info=None):
    sd_hijack.model_hijack.undo_hijack(sd_model)

    load_model_weights(sd_model, checkpoint_info)
    sd_hijack.model_hijack.hijack(sd_model)
    script_callbacks.model_loaded_callback(sd_model)

    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
        sd_model.to(devices.device)

    try:
        load_model_weights(sd_model, checkpoint_info)
    except Exception as e:
        print("Failed to load checkpoint, restoring previous")
        load_model_weights(sd_model, current_checkpoint_info)
        raise
    finally:
        sd_hijack.model_hijack.hijack(sd_model)
        script_callbacks.model_loaded_callback(sd_model)

        if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
            sd_model.to(devices.device)

    print("Weights loaded.")
    return sd_model
...@@ -97,8 +97,9 @@ sampler_extra_params = {
def setup_img2img_steps(p, steps=None):
    if opts.img2img_fix_steps or steps is not None:
        steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
        t_enc = p.steps - 1
        requested_steps = (steps or p.steps)
        steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
        t_enc = requested_steps - 1
    else:
        steps = p.steps
        t_enc = int(min(p.denoising_strength, 0.999) * steps)
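A worked example of the fix (numbers chosen for illustration): previously t_enc was derived from p.steps even when a different steps argument was passed in, so the two values could disagree. Now both derive from the same requested count:

```python
# illustration only: 20 steps requested at denoising strength 0.5
requested_steps = 20
denoising_strength = 0.5
steps = int(requested_steps / min(denoising_strength, 0.999))  # 40 sigmas in the schedule
t_enc = requested_steps - 1                                    # 19: where decoding starts
# only the tail of the 40-sigma schedule is run, so roughly the
# 20 requested steps are actually executed
```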
...@@ -465,7 +466,9 @@ class KDiffusionSampler:
        if p.sampler_noise_scheduler_override:
            sigmas = p.sampler_noise_scheduler_override(steps)
        elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
            sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
            sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
            sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
        else:
            sigmas = self.model_wrap.get_sigmas(steps)
...
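For comparison, a sketch of the two schedules (the SD1.x-like sigma endpoints shown are approximate and assumed for illustration):

```python
import k_diffusion.sampling

steps = 20
# old behaviour: hardcoded range
old_sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10)
# new behaviour: span the model's own sigma range (illustrative endpoints)
new_sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.0292, sigma_max=14.6146)
```

Unless the use_old_karras_scheduler_sigmas compatibility option is set, the Karras schedule now covers the model's full noise range instead of clipping it to (0.1, 10).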
import torch
import os
import collections
from collections import namedtuple
from modules import shared, devices, script_callbacks
from modules.paths import models_path
...@@ -30,6 +31,7 @@ base_vae = None
loaded_vae_file = None
checkpoint_info = None

checkpoints_loaded = collections.OrderedDict()

def get_base_vae(model):
    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
...@@ -149,13 +151,30 @@ def load_vae(model, vae_file=None):
    global first_load, vae_dict, vae_list, loaded_vae_file
    # save_settings = False

    cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0

    if vae_file:
        assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
        print(f"Loading VAE weights from: {vae_file}")
        store_base_vae(model)
        vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
        vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
        _load_vae_dict(model, vae_dict_1)
        if cache_enabled and vae_file in checkpoints_loaded:
            # use vae checkpoint cache
            print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
            store_base_vae(model)
            _load_vae_dict(model, checkpoints_loaded[vae_file])
        else:
            assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
            print(f"Loading VAE weights from: {vae_file}")
            store_base_vae(model)
            vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
            vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
            _load_vae_dict(model, vae_dict_1)

            if cache_enabled:
                # cache newly loaded vae
                checkpoints_loaded[vae_file] = vae_dict_1.copy()

        # clean up cache if limit is reached
        if cache_enabled:
            while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1:  # we need to count the current model
                checkpoints_loaded.popitem(last=False)  # LRU

    # If vae used is not in dict, update it
    # It will be removed on refresh though
...
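The cache above is a plain OrderedDict used as an LRU, with insertion order standing in for recency; a minimal self-contained sketch of that eviction pattern:

```python
import collections

cache = collections.OrderedDict()
limit = 2  # analogous to shared.opts.sd_vae_checkpoint_cache

for name in ["vae-a.pt", "vae-b.pt", "vae-c.pt", "vae-d.pt"]:
    cache[name] = f"weights of {name}"   # newest entries go to the end
    while len(cache) > limit + 1:        # +1 counts the currently loaded VAE
        cache.popitem(last=False)        # evict the oldest entry

print(list(cache))  # ['vae-b.pt', 'vae-c.pt', 'vae-d.pt']
```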
...@@ -58,14 +58,19 @@ class LearnRateScheduler:
        self.finished = False

    def apply(self, optimizer, step_number):
    def step(self, step_number):
        if step_number < self.end_step:
            return
            return False

        try:
            (self.learn_rate, self.end_step) = next(self.schedules)
        except Exception:
        except StopIteration:
            self.finished = True
            return False

        return True

    def apply(self, optimizer, step_number):
        if not self.step(step_number):
            return

        if self.verbose:
...
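Splitting step() out of apply() lets callers advance the schedule without an optimizer in hand. A hypothetical usage sketch (the "rate:step" schedule string follows the webui's learn-rate syntax; the constructor signature (learn_rate, max_steps) is an assumption):

```python
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=0.005)

scheduler = LearnRateScheduler("0.005:100, 1e-3:1000", 1000)

for global_step in range(1000):
    if scheduler.step(global_step):          # True only when the rate just advanced
        for group in optimizer.param_groups:
            group['lr'] = scheduler.learn_rate
    if scheduler.finished:
        break
```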
...@@ -124,6 +124,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
    files = listfiles(src)

    shared.state.job = "preprocess"
    shared.state.textinfo = "Preprocessing..."
    shared.state.job_count = len(files)
...
...@@ -8,7 +8,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html

def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
    p = StableDiffusionProcessingTxt2Img(
        sd_model=shared.sd_model,
        outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
...@@ -33,8 +33,11 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
        tiling=tiling,
        enable_hr=enable_hr,
        denoising_strength=denoising_strength if enable_hr else None,
        firstphase_width=firstphase_width if enable_hr else None,
        firstphase_height=firstphase_height if enable_hr else None,
        hr_scale=hr_scale,
        hr_upscaler=hr_upscaler,
        hr_second_pass_steps=hr_second_pass_steps,
        hr_resize_x=hr_resize_x,
        hr_resize_y=hr_resize_y,
    )

    p.scripts = modules.scripts.scripts_txt2img
...@@ -59,4 +62,4 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
    if opts.do_not_show_images:
        processed.images = []

    return processed.images, generation_info_js, plaintext_to_html(processed.info)
    return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
import gradio as gr


class ToolButton(gr.Button, gr.components.FormComponent):
    """Small button with single emoji as text, fits inside gradio forms"""

    def __init__(self, **kwargs):
        super().__init__(variant="tool", **kwargs)

    def get_block_name(self):
        return "button"


class FormRow(gr.Row, gr.components.FormComponent):
    """Same as gr.Row but fits inside gradio forms"""

    def get_block_name(self):
        return "row"


class FormGroup(gr.Group, gr.components.FormComponent):
    """Same as gr.Group but fits inside gradio forms"""

    def get_block_name(self):
        return "group"
import os
import tempfile
from collections import namedtuple
from pathlib import Path

import gradio as gr
...@@ -12,10 +13,29 @@ from modules import shared

Savedfile = namedtuple("Savedfile", ["name"])


def register_tmp_file(gradio, filename):
    if hasattr(gradio, 'temp_file_sets'):  # gradio 3.15
        gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}

    if hasattr(gradio, 'temp_dirs'):  # gradio 3.9
        gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}


def check_tmp_file(gradio, filename):
    if hasattr(gradio, 'temp_file_sets'):
        return any([filename in fileset for fileset in gradio.temp_file_sets])

    if hasattr(gradio, 'temp_dirs'):
        return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)

    return False


def save_pil_to_file(pil_image, dir=None):
    already_saved_as = getattr(pil_image, 'already_saved_as', None)
    if already_saved_as and os.path.isfile(already_saved_as):
        shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(os.path.dirname(already_saved_as))}
        register_tmp_file(shared.demo, already_saved_as)
        file_obj = Savedfile(already_saved_as)
        return file_obj
...@@ -44,7 +64,7 @@ def on_tmpdir_changed():
    os.makedirs(shared.opts.temp_dir, exist_ok=True)

    shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(shared.opts.temp_dir)}
    register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))

def cleanup_tmpdr():
...
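The hasattr checks form a small compatibility shim: gradio 3.9 whitelists whole directories via Blocks.temp_dirs, while 3.15 whitelists individual files via temp_file_sets. That is also why on_tmpdir_changed registers a dummy "x" entry inside temp_dir - it marks the directory as servable under both APIs. A sketch of the resulting behaviour (hypothetical paths):

```python
import os

register_tmp_file(shared.demo, os.path.join("/tmp/webui-out", "x"))  # hypothetical temp dir

check_tmp_file(shared.demo, "/tmp/webui-out/00001.png")  # True on gradio 3.9 (parent dir allowed)
check_tmp_file(shared.demo, "/etc/passwd")               # False - never whitelisted
```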
...@@ -53,10 +53,10 @@ class Upscaler:
    def do_upscale(self, img: PIL.Image, selected_model: str):
        return img

    def upscale(self, img: PIL.Image, scale: int, selected_model: str = None):
    def upscale(self, img: PIL.Image, scale, selected_model: str = None):
        self.scale = scale
        dest_w = img.width * scale
        dest_h = img.height * scale
        dest_w = int(img.width * scale)
        dest_h = int(img.height * scale)
        for i in range(3):
            shape = (img.width, img.height)
...
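Dropping the int annotation and truncating the products is what lets fractional factors through (this pairs with the upscale script's 0.05-step slider change further down):

```python
# illustration: a fractional scale factor now yields integral pixel dimensions
width, height, scale = 512, 768, 2.5
dest_w = int(width * scale)   # 1280
dest_h = int(height * scale)  # 1920
```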
from transformers import BertPreTrainedModel, BertModel, BertConfig
import torch.nn as nn
import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
from transformers import XLMRobertaModel, XLMRobertaTokenizer
from typing import Optional

class BertSeriesConfig(BertConfig):
    def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None, project_dim=512, pooler_fn="average", learn_encoder=False, model_type='bert', **kwargs):
        super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
        self.project_dim = project_dim
        self.pooler_fn = pooler_fn
        self.learn_encoder = learn_encoder

class RobertaSeriesConfig(XLMRobertaConfig):
    def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, project_dim=512, pooler_fn='cls', learn_encoder=False, **kwargs):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
        self.project_dim = project_dim
        self.pooler_fn = pooler_fn
        self.learn_encoder = learn_encoder


class BertSeriesModelWithTransformation(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
    config_class = BertSeriesConfig

    def __init__(self, config=None, **kargs):
        # modify initialization for autoloading
        if config is None:
            config = XLMRobertaConfig()
            config.attention_probs_dropout_prob = 0.1
            config.bos_token_id = 0
            config.eos_token_id = 2
            config.hidden_act = 'gelu'
            config.hidden_dropout_prob = 0.1
            config.hidden_size = 1024
            config.initializer_range = 0.02
            config.intermediate_size = 4096
            config.layer_norm_eps = 1e-05
            config.max_position_embeddings = 514
            config.num_attention_heads = 16
            config.num_hidden_layers = 24
            config.output_past = True
            config.pad_token_id = 1
            config.position_embedding_type = "absolute"
            config.type_vocab_size = 1
            config.use_cache = True
            config.vocab_size = 250002
            config.project_dim = 768
            config.learn_encoder = False
        super().__init__(config)
        self.roberta = XLMRobertaModel(config)
        self.transformation = nn.Linear(config.hidden_size, config.project_dim)
        self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
        self.pooler = lambda x: x[:, 0]
        self.post_init()

    def encode(self, c):
        device = next(self.parameters()).device
        text = self.tokenizer(c,
                              truncation=True,
                              max_length=77,
                              return_length=False,
                              return_overflowing_tokens=False,
                              padding="max_length",
                              return_tensors="pt")
        text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
        text["attention_mask"] = torch.tensor(
            text['attention_mask']).to(device)
        features = self(**text)
        return features['projection_state']

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )

        # last module outputs
        sequence_output = outputs[0]

        # project every module
        sequence_output_ln = self.pre_LN(sequence_output)

        # pooler
        pooler_output = self.pooler(sequence_output_ln)
        pooler_output = self.transformation(pooler_output)
        projection_state = self.transformation(outputs.last_hidden_state)

        return {
            'pooler_output': pooler_output,
            'last_hidden_state': outputs.last_hidden_state,
            'hidden_states': outputs.hidden_states,
            'attentions': outputs.attentions,
            'projection_state': projection_state,
            'sequence_out': sequence_output,
        }


class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
    base_model_prefix = 'roberta'
    config_class = RobertaSeriesConfig
\ No newline at end of file
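The projection_state output is what the hijacked AltDiffusion text encoder hands to the UNet: a per-token linear projection of XLM-R's 1024-dim hidden states down to 768, matching context_dim: 768 in the config at the top of this diff. A shape walk-through with illustrative tensors:

```python
import torch

last_hidden_state = torch.randn(1, 77, 1024)   # xlm-roberta-large output (batch, tokens, hidden)
transformation = torch.nn.Linear(1024, 768)    # the model's `transformation` layer
projection_state = transformation(last_hidden_state)
print(projection_state.shape)                  # torch.Size([1, 77, 768])
```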
...@@ -3,9 +3,9 @@ transformers==4.19.2
accelerate==0.12.0
basicsr==1.4.2
gfpgan==1.3.8
gradio==3.9
gradio==3.15.0
numpy==1.23.3
Pillow==9.2.0
Pillow==9.4.0
realesrgan==0.3.0
torch
omegaconf==2.2.3
...@@ -26,5 +26,5 @@ lark==1.1.2
inflection==0.5.1
GitPython==3.1.27
torchsde==0.2.5
safetensors==0.2.5
safetensors==0.2.7
httpcore<=0.15
...@@ -4,7 +4,7 @@ function gradioApp() {
}

function get_uiCurrentTab() {
    return gradioApp().querySelector('.tabs button:not(.border-transparent)')
    return gradioApp().querySelector('#tabs button:not(.border-transparent)')
}

function get_uiCurrentTabContent() {
...
...@@ -19,7 +19,7 @@ class Script(scripts.Script):
    def ui(self, is_img2img):
        info = gr.HTML("<p style=\"margin-bottom:0.75em\">Will upscale the image by the selected scale factor; use width and height sliders to set tile size</p>")
        overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64)
        scale_factor = gr.Slider(minimum=1, maximum=4, step=1, label='Scale Factor', value=2)
        scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0)
        upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
        return [info, overlap, upscaler_index, scale_factor]
...
...@@ -10,7 +10,7 @@ import numpy as np
import modules.scripts as scripts
import gradio as gr

from modules import images, paths, sd_samplers
from modules import images, paths, sd_samplers, processing
from modules.hypernetworks import hypernetwork
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, cmd_opts, state
...@@ -202,7 +202,7 @@ axis_options = [
    AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
    AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
    AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None),
    AxisOption("Upscale latent space for hires.", str, apply_upscale_latent_space, format_value_add_label, None),
    AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None),
    AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None),
    AxisOption("VAE", str, apply_vae, format_value_add_label, None),
    AxisOption("Styles", str, apply_styles, format_value_add_label, None),
...@@ -267,7 +267,6 @@ class SharedSettingsStackHelper(object):
        self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
        self.hypernetwork = opts.sd_hypernetwork
        self.model = shared.sd_model
        self.use_scale_latent_for_hires_fix = opts.use_scale_latent_for_hires_fix
        self.vae = opts.sd_vae

    def __exit__(self, exc_type, exc_value, tb):
...@@ -278,7 +277,6 @@ class SharedSettingsStackHelper(object):
        hypernetwork.apply_strength()

        opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
        opts.data["use_scale_latent_for_hires_fix"] = self.use_scale_latent_for_hires_fix

re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
...@@ -287,6 +285,7 @@ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d
re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*")


class Script(scripts.Script):
    def title(self):
        return "X/Y plot"
...@@ -383,7 +382,7 @@ class Script(scripts.Script):
        ys = process_axis(y_opt, y_values)

        def fix_axis_seeds(axis_opt, axis_list):
            if axis_opt.label in ['Seed','Var. seed']:
            if axis_opt.label in ['Seed', 'Var. seed']:
                return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
            else:
                return axis_list
...@@ -405,12 +404,33 @@ class Script(scripts.Script):
        print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})")
        shared.total_tqdm.updateTotal(total_steps * p.n_iter)

        grid_infotext = [None]

        def cell(x, y):
            pc = copy(p)
            x_opt.apply(pc, x, xs)
            y_opt.apply(pc, y, ys)

            return process_images(pc)
            res = process_images(pc)

            if grid_infotext[0] is None:
                pc.extra_generation_params = copy(pc.extra_generation_params)

                if x_opt.label != 'Nothing':
                    pc.extra_generation_params["X Type"] = x_opt.label
                    pc.extra_generation_params["X Values"] = x_values
                    if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
                        pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs])

                if y_opt.label != 'Nothing':
                    pc.extra_generation_params["Y Type"] = y_opt.label
                    pc.extra_generation_params["Y Values"] = y_values
                    if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
                        pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys])

                grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)

            return res

        with SharedSettingsStackHelper():
            processed = draw_xy_grid(
...@@ -425,6 +445,6 @@ class Script(scripts.Script):
            )

        if opts.grid_save:
            images.save_image(processed.images[0], p.outpath_grids, "xy_grid", extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
            images.save_image(processed.images[0], p.outpath_grids, "xy_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)

        return processed
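grid_infotext is the usual single-element-list trick for assigning to an enclosing scope from a nested function without nonlocal; the first cell fills it in, and save_image then stamps it onto the saved grid. The pattern in miniature:

```python
# the mutable-closure pattern used by grid_infotext, reduced to its essentials
memo = [None]

def cell():
    if memo[0] is None:   # only the first call computes it
        memo[0] = "infotext built from the first cell's parameters"
    return "image"

cell()
print(memo[0])
```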
...@@ -73,8 +73,9 @@
    margin-right: auto;
}

#random_seed, #random_subseed, #reuse_seed, #reuse_subseed, #open_folder{
[id$=_random_seed], [id$=_random_subseed], [id$=_reuse_seed], [id$=_reuse_subseed], #open_folder{
    min-width: auto;
    min-width: 2.3em;
    height: 2.5em;
    flex-grow: 0;
    padding-left: 0.25em;
    padding-right: 0.25em;
...@@ -84,27 +85,28 @@
    display: none;
}

#seed_row, #subseed_row{
[id$=_seed_row], [id$=_subseed_row]{
    gap: 0.5rem;
    padding: 0.6em;
}

#subseed_show_box{
[id$=_subseed_show_box]{
    min-width: auto;
    flex-grow: 0;
}

#subseed_show_box > div{
[id$=_subseed_show_box] > div{
    border: 0;
    height: 100%;
}

#subseed_show{
[id$=_subseed_show]{
    min-width: auto;
    flex-grow: 0;
    padding: 0;
}

#subseed_show label{
[id$=_subseed_show] label{
    height: 100%;
}
...@@ -206,24 +208,24 @@ button{
fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block span{
    position: absolute;
    top: -0.6em;
    top: -0.7em;
    line-height: 1.2em;
    padding: 0 0.5em;
    padding: 0;
    margin: 0;
    margin: 0 0.5em;
    background-color: white;
    border-top: 1px solid #eee;
    border-left: 1px solid #eee;
    border-right: 1px solid #eee;
    box-shadow: 6px 0 6px 0px white, -6px 0 6px 0px white;
    z-index: 300;
}

.dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{
    background-color: rgb(31, 41, 55);
    border-top: 1px solid rgb(55 65 81);
    border-left: 1px solid rgb(55 65 81);
    border-right: 1px solid rgb(55 65 81);
    box-shadow: 6px 0 6px 0px rgb(31, 41, 55), -6px 0 6px 0px rgb(31, 41, 55);
}

#txt2img_column_batch, #img2img_column_batch{
    min-width: min(13.5em, 100%) !important;
}

#settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{
...@@ -232,22 +234,40 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s
    margin-right: 8em;
}

.gr-panel div.flex-col div.justify-between label span{
    margin: 0;
}

#settings .gr-panel div.flex-col div.justify-between div{
    position: relative;
    z-index: 200;
}

input[type="range"]{
    margin: 0.5em 0 -0.3em 0;
}

#txt2img_sampling label{
    padding-left: 0.6em;
    padding-right: 0.6em;
}

#settings{
    display: block;
}

#settings > div{
    border: none;
    margin-left: 10em;
}

#settings > div.flex-wrap{
    float: left;
    display: block;
    margin-left: 0;
    width: 10em;
}

#settings > div.flex-wrap button{
    display: block;
    border: none;
    text-align: left;
}

#settings_result{
    height: 1.4em;
    margin: 0 1.2em;
}

input[type="range"]{
    margin: 0.5em 0 -0.3em 0;
}

#mask_bug_info {
...@@ -501,13 +521,6 @@ input[type="range"]{
    padding: 0;
}

#refresh_sd_model_checkpoint, #refresh_sd_vae, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
    max-width: 2.5em;
    min-width: 2.5em;
    height: 2.4em;
}

canvas[key="mask"] {
    z-index: 12 !important;
    filter: invert();
...@@ -521,7 +534,7 @@ canvas[key="mask"] {
    position: absolute;
    right: 0.5em;
    top: -0.6em;
    z-index: 200;
    z-index: 400;
    width: 8em;
}

#quicksettings .gr-box > div > div > input.gr-text-input {
...@@ -568,6 +581,53 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h
    font-size: 95%;
}

#image_buttons_txt2img button, #image_buttons_img2img button, #image_buttons_extras button{
    min-width: auto;
    padding-left: 0.5em;
    padding-right: 0.5em;
}

.gr-form{
    background-color: white;
}

.dark .gr-form{
    background-color: rgb(31 41 55 / var(--tw-bg-opacity));
}

.gr-button-tool{
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.4em;
    margin: 0.55em 0;
}

#quicksettings .gr-button-tool{
    margin: 0;
}

#img2img_settings > div.gr-form, #txt2img_settings > div.gr-form {
    padding-top: 0.9em;
}

#img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form, #train_tabs div.gr-form .gr-form{
    border: none;
    padding-bottom: 0.5em;
}

footer {
    display: none !important;
}

#footer{
    text-align: center;
}

#footer div{
    display: inline-block;
}

/* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running
...
model:
  base_learning_rate: 1.0e-4
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False # we set this to false because this is an inference only config

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        use_fp16: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
          lossconfig:
            target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
\ No newline at end of file
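For context (not part of the diff): parameterization: "v" makes the UNet predict Salimans & Ho's velocity target rather than the noise epsilon, which is how the SD 2.x "768-v" checkpoints were trained. With the usual variance-preserving schedule, where alpha_t^2 + sigma_t^2 = 1:

```latex
v_t = \alpha_t\,\epsilon - \sigma_t\,x_0,
\qquad
\hat{x}_0 = \alpha_t z_t - \sigma_t \hat{v}_t
```

Loading such a checkpoint under a plain epsilon-prediction config produces incorrect images, which is why it needs this dedicated config file.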
import os
import sys
import threading
import time
import importlib
...@@ -8,7 +9,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware

from modules import import_hook
from modules import import_hook, errors
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
from modules.paths import script_path
...@@ -55,12 +56,20 @@ def initialize():
    gfpgan.setup_model(cmd_opts.gfpgan_models_path)
    shared.face_restorers.append(modules.face_restoration.FaceRestoration())

    modelloader.list_builtin_upscalers()
    modules.scripts.load_scripts()
    modelloader.load_upscalers()

    modules.sd_vae.refresh_vae_list()

    modules.sd_models.load_model()
    try:
        modules.sd_models.load_model()
    except Exception as e:
        errors.display(e, "loading stable diffusion model")
        print("", file=sys.stderr)
        print("Stable diffusion model failed to load, exiting", file=sys.stderr)
        exit(1)

    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
    shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
    shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
...@@ -91,11 +100,11 @@ def initialize():

def setup_cors(app):
    if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
        app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'])
        app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
    elif cmd_opts.cors_allow_origins:
        app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'])
        app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
    elif cmd_opts.cors_allow_origins_regex:
        app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'])
        app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])

def create_api(app):
...@@ -169,23 +178,22 @@ def webui():
        modules.script_callbacks.app_started_callback(shared.demo, app)

        wait_on_server(shared.demo)
        print('Restarting UI...')

        sd_samplers.set_samplers()

        print('Reloading extensions')
        extensions.list_extensions()
        localization.list_localizations(cmd_opts.localizations_dir)

        print('Reloading custom scripts')
        modelloader.forbid_loaded_nonbuiltin_upscalers()
        modules.scripts.reload_scripts()
        modelloader.load_upscalers()

        print('Reloading modules: modules.ui')
        importlib.reload(modules.ui)
        for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
            importlib.reload(module)

        print('Refreshing Model List')
        modules.sd_models.list_models()

        print('Restarting Gradio')

if __name__ == "__main__":
...
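Reloading through sys.modules catches every loaded modules.ui* submodule (ui, ui_components, ui_tempdir, ...), whereas the old importlib.reload(modules.ui) refreshed only the single module object it was handed. The pattern in isolation:

```python
import importlib
import sys

# reload a package and all of its already-imported submodules by name prefix
prefix = "modules.ui"
for name, module in list(sys.modules.items()):  # list() guards against mutation while iterating
    if name.startswith(prefix):
        importlib.reload(module)
```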