Commit fd66f669 authored by Bruno Seoane's avatar Bruno Seoane
parents 31db25ec 89722fb5
......@@ -155,14 +155,15 @@ The documentation was moved from this README over to the project's [wiki](https:
- Swin2SR - https://github.com/mv-lab/swin2sr
- LDSR - https://github.com/Hafiidz/latent-diffusion
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
- Doggettx - Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
- InvokeAI, lstein - Cross Attention layer optimization - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
- Rinon Gal - Textual Inversion - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
- Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
- Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
- Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
- CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
- Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
- xformers - https://github.com/facebookresearch/xformers
- DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
- Security advice - RyotaK
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
addEventListener('keydown', (event) => {
let target = event.originalTarget || event.composedPath()[0];
if (!target.hasAttribute("placeholder")) return;
if (!target.placeholder.toLowerCase().includes("prompt")) return;
if (!target.matches("#toprow textarea.gr-text-input[placeholder]")) return;
if (! (event.metaKey || event.ctrlKey)) return;
......
......@@ -3,8 +3,21 @@ global_progressbars = {}
galleries = {}
galleryObservers = {}
// this tracks laumnches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running
timeoutIds = {}
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
var progressbar = gradioApp().getElementById(id_progressbar)
// gradio 3.8's enlightened approach allows them to create two nested div elements inside each other with same id
// every time you use gr.HTML(elem_id='xxx'), so we handle this here
var progressbar = gradioApp().querySelector("#"+id_progressbar+" #"+id_progressbar)
var progressbarParent
if(progressbar){
progressbarParent = gradioApp().querySelector("#"+id_progressbar)
} else{
progressbar = gradioApp().getElementById(id_progressbar)
progressbarParent = null
}
var skip = id_skip ? gradioApp().getElementById(id_skip) : null
var interrupt = gradioApp().getElementById(id_interrupt)
......@@ -26,18 +39,26 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
global_progressbars[id_progressbar] = progressbar
var mutationObserver = new MutationObserver(function(m){
if(timeoutIds[id_part]) return;
preview = gradioApp().getElementById(id_preview)
gallery = gradioApp().getElementById(id_gallery)
if(preview != null && gallery != null){
preview.style.width = gallery.clientWidth + "px"
preview.style.height = gallery.clientHeight + "px"
if(progressbarParent) progressbar.style.width = progressbarParent.clientWidth + "px"
//only watch gallery if there is a generation process going on
check_gallery(id_gallery);
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
if(!progressDiv){
if(progressDiv){
timeoutIds[id_part] = window.setTimeout(function() {
timeoutIds[id_part] = null
requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt)
}, 500)
} else{
if (skip) {
skip.style.display = "none"
}
......@@ -47,13 +68,10 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
if (galleryObservers[id_gallery]) {
galleryObservers[id_gallery].disconnect();
galleries[id_gallery] = null;
}
}
}
}
window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500)
});
mutationObserver.observe( progressbar, { childList:true, subtree:true })
}
......
......@@ -238,12 +238,15 @@ def tests(argv):
proc.kill()
def start_webui():
print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}")
def start():
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
import webui
webui.webui()
if '--nowebui' in sys.argv:
webui.api_only()
else:
webui.webui()
if __name__ == "__main__":
prepare_enviroment()
start_webui()
start()
......@@ -70,7 +70,7 @@
"None": "Nichts",
"Prompt matrix": "Promptmatrix",
"Prompts from file or textbox": "Prompts aus Datei oder Textfeld",
"X/Y plot": "X/Y Graf",
"X/Y plot": "X/Y Graph",
"Put variable parts at start of prompt": "Variable teile am start des Prompt setzen",
"Iterate seed every line": "Iterate seed every line",
"List of prompt inputs": "List of prompt inputs",
......@@ -455,4 +455,4 @@
"Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.": "Gilt nur für Inpainting-Modelle. Legt fest, wie stark das Originalbild für Inpainting und img2img maskiert werden soll. 1.0 bedeutet vollständig maskiert, was das Standardverhalten ist. 0.0 bedeutet eine vollständig unmaskierte Konditionierung. Niedrigere Werte tragen dazu bei, die Gesamtkomposition des Bildes zu erhalten, sind aber bei großen Änderungen problematisch.",
"List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply.": "Liste von Einstellungsnamen, getrennt durch Kommas, für Einstellungen, die in der Schnellzugriffsleiste oben erscheinen sollen, anstatt in dem üblichen Einstellungs-Tab. Siehe modules/shared.py für Einstellungsnamen. Erfordert einen Neustart zur Anwendung.",
"If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.": "Wenn dieser Wert ungleich Null ist, wird er zum Seed addiert und zur Initialisierung des RNG für Noise bei der Verwendung von Samplern mit Eta verwendet. Dies kann verwendet werden, um noch mehr Variationen von Bildern zu erzeugen, oder um Bilder von anderer Software zu erzeugen, wenn Sie wissen, was Sie tun."
}
\ No newline at end of file
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import base64
import io
import time
import uvicorn
from threading import Lock
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, FastAPI, HTTPException
import modules.shared as shared
from modules import devices
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo
from modules.sd_models import checkpoints_list
from modules.realesrgan_model import get_realesrgan_models
from typing import List
def upscaler_to_index(name: str):
try:
......@@ -29,8 +33,14 @@ def setUpscalers(req: dict):
return reqDict
def encode_pil_to_base64(image):
buffer = io.BytesIO()
image.save(buffer, format="png")
return base64.b64encode(buffer.getvalue())
class Api:
def __init__(self, app, queue_lock):
def __init__(self, app: FastAPI, queue_lock: Lock):
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
......@@ -40,6 +50,19 @@ class Api:
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
self.app.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem])
self.app.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
......@@ -170,12 +193,79 @@ class Api:
progress = min(progress, 1)
shared.state.set_current_image()
current_image = None
if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image)
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
def interruptapi(self):
shared.state.interrupt()
return {}
def get_config(self):
options = {}
for key in shared.opts.data.keys():
metadata = shared.opts.data_labels.get(key)
if(metadata is not None):
options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
else:
options.update({key: shared.opts.data.get(key, None)})
return options
def set_config(self, req: OptionsModel):
reqDict = vars(req)
for o in reqDict:
setattr(shared.opts, o, reqDict[o])
shared.opts.save(shared.config_filename)
return
def get_cmd_flags(self):
return vars(shared.cmd_opts)
def get_samplers(self):
return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers]
def get_upscalers(self):
upscalers = []
for upscaler in shared.sd_upscalers:
u = upscaler.scaler
upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
return upscalers
def get_sd_models(self):
return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
def get_face_restorers(self):
return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
def get_realesrgan_models(self):
return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
def get_promp_styles(self):
styleList = []
for k in shared.prompt_styles.styles:
style = shared.prompt_styles.styles[k]
styleList.append({"name":style[0], "prompt": style[1], "negative_prompr": style[2]})
return styleList
def get_artists_categories(self):
return shared.artist_db.cats
def get_artists(self):
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
def launch(self, server_name, port):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port)
import inspect
from click import prompt
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing import Any, Optional, Union
from typing_extensions import Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
from modules.shared import sd_upscalers
from modules.shared import sd_upscalers, opts, parser
from typing import List
API_NOT_ALLOWED = [
"self",
......@@ -109,12 +109,12 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
).generate_model()
class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
class ImageToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
......@@ -131,6 +131,7 @@ class ExtrasBaseRequest(BaseModel):
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?")
class ExtraBaseResponse(BaseModel):
html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
......@@ -146,10 +147,10 @@ class FileData(BaseModel):
name: str = Field(title="File name")
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse):
images: list[str] = Field(title="Images", description="The generated images in base64 format.")
images: List[str] = Field(title="Images", description="The generated images in base64 format.")
class PNGInfoRequest(BaseModel):
image: str = Field(title="Image", description="The base64 encoded PNG image")
......@@ -165,3 +166,68 @@ class ProgressResponse(BaseModel):
eta_relative: float = Field(title="ETA in secs")
state: dict = Field(title="State", description="The current state snapshot")
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
fields = {}
for key, value in opts.data.items():
metadata = opts.data_labels.get(key)
optType = opts.typemap.get(type(value), type(value))
if (metadata is not None):
fields.update({key: (Optional[optType], Field(
default=metadata.default ,description=metadata.label))})
else:
fields.update({key: (Optional[optType], Field())})
OptionsModel = create_model("Options", **fields)
flags = {}
_options = vars(parser)['_option_string_actions']
for key in _options:
if(_options[key].dest != 'help'):
flag = _options[key]
_type = str
if(_options[key].default != None): _type = type(_options[key].default)
flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
FlagsModel = create_model("Flags", **flags)
class SamplerItem(BaseModel):
name: str = Field(title="Name")
aliases: list[str] = Field(title="Aliases")
options: dict[str, str] = Field(title="Options")
class UpscalerItem(BaseModel):
name: str = Field(title="Name")
model_name: str | None = Field(title="Model Name")
model_path: str | None = Field(title="Path")
model_url: str | None = Field(title="URL")
class SDModelItem(BaseModel):
title: str = Field(title="Title")
model_name: str = Field(title="Model Name")
hash: str = Field(title="Hash")
filename: str = Field(title="Filename")
config: str = Field(title="Config file")
class HypernetworkItem(BaseModel):
name: str = Field(title="Name")
path: str | None = Field(title="Path")
class FaceRestorerItem(BaseModel):
name: str = Field(title="Name")
cmd_dir: str | None = Field(title="Path")
class RealesrganItem(BaseModel):
name: str = Field(title="Name")
path: str | None = Field(title="Path")
scale: int | None = Field(title="Scale")
class PromptStyleItem(BaseModel):
name: str = Field(title="Name")
prompt: str | None = Field(title="Prompt")
negative_prompt: str | None = Field(title="Negative Prompt")
class ArtistItem(BaseModel):
name: str = Field(title="Name")
score: float = Field(title="Score")
category: str = Field(title="Category")
\ No newline at end of file
......@@ -50,6 +50,7 @@ def mod2normal(state_dict):
def resrgan2normal(state_dict, nb=23):
# this code is copied from https://github.com/victorca25/iNNfer
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
re8x = 0
crt_net = {}
items = []
for k, v in state_dict.items():
......@@ -75,10 +76,18 @@ def resrgan2normal(state_dict, nb=23):
crt_net['model.3.bias'] = state_dict['conv_up1.bias']
crt_net['model.6.weight'] = state_dict['conv_up2.weight']
crt_net['model.6.bias'] = state_dict['conv_up2.bias']
crt_net['model.8.weight'] = state_dict['conv_hr.weight']
crt_net['model.8.bias'] = state_dict['conv_hr.bias']
crt_net['model.10.weight'] = state_dict['conv_last.weight']
crt_net['model.10.bias'] = state_dict['conv_last.bias']
if 'conv_up3.weight' in state_dict:
# modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
re8x = 3
crt_net['model.9.weight'] = state_dict['conv_up3.weight']
crt_net['model.9.bias'] = state_dict['conv_up3.bias']
crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
state_dict = crt_net
return state_dict
......
......@@ -136,10 +136,11 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
blended_result: Image.Image = None
image_hash: str = hash(np.array(image.getdata()).tobytes())
for upscaler in params:
upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
upscaling_resize_w, upscaling_resize_h, upscaling_crop)
cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()),
cache_key = LruCache.Key(image_hash=image_hash,
info_hash=hash(info),
args_hash=hash(upscale_args))
cached_entry = cached_images.get(cache_key)
......
......@@ -35,7 +35,8 @@ class HypernetworkModule(torch.nn.Module):
}
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=True):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
......@@ -48,8 +49,8 @@ class HypernetworkModule(torch.nn.Module):
# Add a fully-connected layer
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
# Add an activation func
if activation_func == "linear" or activation_func is None:
# Add an activation func except last layer
if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
pass
elif activation_func in self.activation_dict:
linears.append(self.activation_dict[activation_func]())
......@@ -60,8 +61,8 @@ class HypernetworkModule(torch.nn.Module):
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
# Add dropout expect last layer
if use_dropout and i < len(layer_structure) - 3:
# Add dropout except last layer
if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
linears.append(torch.nn.Dropout(p=0.3))
self.linear = torch.nn.Sequential(*linears)
......@@ -75,7 +76,7 @@ class HypernetworkModule(torch.nn.Module):
w, b = layer.weight.data, layer.bias.data
if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
normal_(w, mean=0.0, std=0.01)
normal_(b, mean=0.0, std=0.005)
normal_(b, mean=0.0, std=0)
elif weight_init == 'XavierUniform':
xavier_uniform_(w)
zeros_(b)
......@@ -127,7 +128,7 @@ class Hypernetwork:
filename = None
name = None
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
self.filename = None
self.name = name
self.layers = {}
......@@ -139,11 +140,15 @@ class Hypernetwork:
self.weight_init = weight_init
self.add_layer_norm = add_layer_norm
self.use_dropout = use_dropout
self.activate_output = activate_output
self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
for size in enable_sizes or []:
self.layers[size] = (
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
)
def weights(self):
......@@ -171,7 +176,9 @@ class Hypernetwork:
state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
state_dict['activate_output'] = self.activate_output
state_dict['last_layer_dropout'] = self.last_layer_dropout
torch.save(state_dict, filename)
def load(self, filename):
......@@ -191,12 +198,17 @@ class Hypernetwork:
print(f"Layer norm is set to {self.add_layer_norm}")
self.use_dropout = state_dict.get('use_dropout', False)
print(f"Dropout usage is set to {self.use_dropout}" )
self.activate_output = state_dict.get('activate_output', True)
print(f"Activate last layer is set to {self.activate_output}")
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
)
self.name = state_dict.get('name', self.name)
......
......@@ -510,8 +510,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
if extension.lower() == '.png':
pnginfo_data = PngImagePlugin.PngInfo()
for k, v in params.pnginfo.items():
pnginfo_data.add_text(k, str(v))
if opts.enable_pnginfo:
for k, v in params.pnginfo.items():
pnginfo_data.add_text(k, str(v))
image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
......
......@@ -81,7 +81,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
mask = None
# Use the EXIF orientation of photos taken by smartphones.
image = ImageOps.exif_transpose(image)
if image is not None:
image = ImageOps.exif_transpose(image)
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
......@@ -137,6 +138,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
if processed is None:
processed = process_images(p)
p.close()
shared.total_tqdm.clear()
generation_info_js = processed.js()
......
......@@ -56,9 +56,9 @@ class InterrogateModels:
import clip
if self.running_on_cpu:
model, preprocess = clip.load(clip_model_name, device="cpu")
model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
else:
model, preprocess = clip.load(clip_model_name)
model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
model.eval()
model = model.to(devices.device_interrogate)
......
......@@ -49,7 +49,7 @@ def expand_crop_region(crop_region, processing_width, processing_height, image_w
ratio_processing = processing_width / processing_height
if ratio_crop_region > ratio_processing:
desired_height = (x2 - x1) * ratio_processing
desired_height = (x2 - x1) / ratio_processing
desired_height_diff = int(desired_height - (y2-y1))
y1 -= desired_height_diff//2
y2 += desired_height_diff - desired_height_diff//2
......
......@@ -85,6 +85,9 @@ def cleanup_models():
src_path = os.path.join(root_path, "ESRGAN")
dest_path = os.path.join(models_path, "ESRGAN")
move_files(src_path, dest_path)
src_path = os.path.join(models_path, "BSRGAN")
dest_path = os.path.join(models_path, "ESRGAN")
move_files(src_path, dest_path, ".pth")
src_path = os.path.join(root_path, "gfpgan")
dest_path = os.path.join(models_path, "GFPGAN")
move_files(src_path, dest_path)
......
......@@ -134,11 +134,7 @@ class StableDiffusionProcessing():
# Dummy zero conditioning if we're not using inpainting model.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
return torch.zeros(
x.shape[0], 5, 1, 1,
dtype=x.dtype,
device=x.device
)
return x.new_zeros(x.shape[0], 5, 1, 1)
height = height or self.height
width = width or self.width
......@@ -156,11 +152,7 @@ class StableDiffusionProcessing():
def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
# Dummy zero conditioning if we're not using inpainting model.
return torch.zeros(
latent_image.shape[0], 5, 1, 1,
dtype=latent_image.dtype,
device=latent_image.device
)
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
# Handle the different mask inputs
if image_mask is not None:
......@@ -174,11 +166,11 @@ class StableDiffusionProcessing():
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
conditioning_mask = torch.round(conditioning_mask)
else:
conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:])
conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
# Create another latent image, this time with a masked version of the original input.
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
conditioning_mask = conditioning_mask.to(source_image.device)
conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
conditioning_image = torch.lerp(
source_image,
source_image * (1.0 - conditioning_mask),
......@@ -199,9 +191,13 @@ class StableDiffusionProcessing():
def init(self, all_prompts, all_seeds, all_subseeds):
pass
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
raise NotImplementedError()
def close(self):
self.sd_model = None
self.sampler = None
class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
......@@ -422,13 +418,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
try:
for k, v in p.override_settings.items():
opts.data[k] = v # we don't call onchange for simplicity which makes changing model, hypernet impossible
setattr(opts, k, v) # we don't call onchange for simplicity which makes changing model, hypernet impossible
res = process_images_inner(p)
finally:
for k, v in stored_opts.items():
opts.data[k] = v
setattr(opts, k, v)
return res
......@@ -505,6 +501,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if len(prompts) == 0:
break
if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
with devices.autocast():
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
......@@ -517,7 +516,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
with devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
samples_ddim = samples_ddim.to(devices.dtype_vae)
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
......@@ -597,9 +596,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
p.scripts.postprocess(p, res)
p.sd_model = None
p.sampler = None
return res
......@@ -648,7 +644,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
......@@ -661,9 +657,28 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
"""saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images"""
def save_intermediate(image, index):
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
return
if not isinstance(image, Image.Image):
image = sd_samplers.sample_to_image(image, index)
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
if opts.use_scale_latent_for_hires_fix:
for i in range(samples.shape[0]):
save_intermediate(samples, i)
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
# Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again.
if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
else:
image_conditioning = self.txt2img_image_conditioning(samples)
else:
decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
......@@ -673,6 +688,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
image = Image.fromarray(x_sample)
save_intermediate(image, i)
image = images.resize_image(0, image, self.width, self.height)
image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
......@@ -684,14 +702,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
shared.state.nextjob()
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
image_conditioning = self.txt2img_image_conditioning(x)
# GC now before running the next img2img to prevent running out of memory
x = None
devices.torch_gc()
......@@ -830,8 +848,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
......@@ -842,4 +859,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
del x
devices.torch_gc()
return samples
\ No newline at end of file
return samples
......@@ -2,6 +2,7 @@ import sys
import traceback
from collections import namedtuple
import inspect
from typing import Optional
from fastapi import FastAPI
from gradio import Blocks
......@@ -26,25 +27,42 @@ class ImageSaveParams:
"""dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
class CFGDenoiserParams:
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
self.x = x
"""Latent image representation in the process of being denoised"""
self.image_cond = image_cond
"""Conditioning image"""
self.sigma = sigma
"""Current sigma noise step value"""
self.sampling_step = sampling_step
"""Current Sampling step number"""
self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned"""
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callbacks_app_started = []
callbacks_model_loaded = []
callbacks_ui_tabs = []
callbacks_ui_settings = []
callbacks_before_image_saved = []
callbacks_image_saved = []
callback_map = dict(
callbacks_app_started=[],
callbacks_model_loaded=[],
callbacks_ui_tabs=[],
callbacks_ui_settings=[],
callbacks_before_image_saved=[],
callbacks_image_saved=[],
callbacks_cfg_denoiser=[]
)
def clear_callbacks():
callbacks_model_loaded.clear()
callbacks_ui_tabs.clear()
callbacks_ui_settings.clear()
callbacks_before_image_saved.clear()
callbacks_image_saved.clear()
for callback_list in callback_map.values():
callback_list.clear()
def app_started_callback(demo: Blocks, app: FastAPI):
for c in callbacks_app_started:
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
for c in callback_map['callbacks_app_started']:
try:
c.callback(demo, app)
except Exception:
......@@ -52,7 +70,7 @@ def app_started_callback(demo: Blocks, app: FastAPI):
def model_loaded_callback(sd_model):
for c in callbacks_model_loaded:
for c in callback_map['callbacks_model_loaded']:
try:
c.callback(sd_model)
except Exception:
......@@ -62,7 +80,7 @@ def model_loaded_callback(sd_model):
def ui_tabs_callback():
res = []
for c in callbacks_ui_tabs:
for c in callback_map['callbacks_ui_tabs']:
try:
res += c.callback() or []
except Exception:
......@@ -72,7 +90,7 @@ def ui_tabs_callback():
def ui_settings_callback():
for c in callbacks_ui_settings:
for c in callback_map['callbacks_ui_settings']:
try:
c.callback()
except Exception:
......@@ -80,7 +98,7 @@ def ui_settings_callback():
def before_image_saved_callback(params: ImageSaveParams):
for c in callbacks_before_image_saved:
for c in callback_map['callbacks_before_image_saved']:
try:
c.callback(params)
except Exception:
......@@ -88,30 +106,54 @@ def before_image_saved_callback(params: ImageSaveParams):
def image_saved_callback(params: ImageSaveParams):
for c in callbacks_image_saved:
for c in callback_map['callbacks_image_saved']:
try:
c.callback(params)
except Exception:
report_exception(c, 'image_saved_callback')
def cfg_denoiser_callback(params: CFGDenoiserParams):
for c in callback_map['callbacks_cfg_denoiser']:
try:
c.callback(params)
except Exception:
report_exception(c, 'cfg_denoiser_callback')
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
callbacks.append(ScriptCallback(filename, fun))
def remove_current_script_callbacks():
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
if filename == 'unknown file':
return
for callback_list in callback_map.values():
for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
callback_list.remove(callback_to_remove)
def remove_callbacks_for_function(callback_func):
for callback_list in callback_map.values():
for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
callback_list.remove(callback_to_remove)
def on_app_started(callback):
"""register a function to be called when the webui started, the gradio `Block` component and
fastapi `FastAPI` object are passed as the arguments"""
add_callback(callbacks_app_started, callback)
add_callback(callback_map['callbacks_app_started'], callback)
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
add_callback(callbacks_model_loaded, callback)
add_callback(callback_map['callbacks_model_loaded'], callback)
def on_ui_tabs(callback):
......@@ -124,13 +166,13 @@ def on_ui_tabs(callback):
title is tab text displayed to user in the UI
elem_id is HTML id for the tab
"""
add_callback(callbacks_ui_tabs, callback)
add_callback(callback_map['callbacks_ui_tabs'], callback)
def on_ui_settings(callback):
"""register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """
add_callback(callbacks_ui_settings, callback)
add_callback(callback_map['callbacks_ui_settings'], callback)
def on_before_image_saved(callback):
......@@ -138,7 +180,7 @@ def on_before_image_saved(callback):
The callback is called with one argument:
- params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
"""
add_callback(callbacks_before_image_saved, callback)
add_callback(callback_map['callbacks_before_image_saved'], callback)
def on_image_saved(callback):
......@@ -146,4 +188,12 @@ def on_image_saved(callback):
The callback is called with one argument:
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
"""
add_callback(callbacks_image_saved, callback)
add_callback(callback_map['callbacks_image_saved'], callback)
def on_cfg_denoiser(callback):
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
The callback is called with one argument:
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
"""
add_callback(callback_map['callbacks_cfg_denoiser'], callback)
......@@ -18,6 +18,9 @@ class Script:
args_to = None
alwayson = False
"""A gr.Group component that has all script's UI inside it"""
group = None
infotext_fields = None
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
......@@ -70,6 +73,19 @@ class Script:
pass
def process_batch(self, p, *args, **kwargs):
"""
Same as process(), but called for every batch.
**kwargs will have those items:
- batch_number - index of current batch, from 0 to number of batches-1
- prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
- seeds - list of seeds for current batch
- subseeds - list of subseeds for current batch
"""
pass
def postprocess(self, p, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
......@@ -218,8 +234,6 @@ class ScriptRunner:
for control in controls:
control.custom_script_source = os.path.basename(script.filename)
if not script.alwayson:
control.visible = False
if script.infotext_fields is not None:
self.infotext_fields += script.infotext_fields
......@@ -229,40 +243,41 @@ class ScriptRunner:
script.args_to = len(inputs)
for script in self.alwayson_scripts:
with gr.Group():
with gr.Group() as group:
create_script_ui(script, inputs, inputs_alwayson)
script.group = group
dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
dropdown.save_to_config = True
inputs[0] = dropdown
for script in self.selectable_scripts:
create_script_ui(script, inputs, inputs_alwayson)
with gr.Group(visible=False) as group:
create_script_ui(script, inputs, inputs_alwayson)
script.group = group
def select_script(script_index):
if 0 < script_index <= len(self.selectable_scripts):
script = self.selectable_scripts[script_index-1]
args_from = script.args_from
args_to = script.args_to
else:
args_from = 0
args_to = 0
selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
return [ui.gr_show(True if i == 0 else args_from <= i < args_to or is_alwayson) for i, is_alwayson in enumerate(inputs_alwayson)]
return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
def init_field(title):
"""called when an initial value is set from ui-config.json to show script's UI components"""
if title == 'None':
return
script_index = self.titles.index(title)
script = self.selectable_scripts[script_index]
for i in range(script.args_from, script.args_to):
inputs[i].visible = True
self.selectable_scripts[script_index].group.visible = True
dropdown.init_field = init_field
dropdown.change(
fn=select_script,
inputs=[dropdown],
outputs=inputs
outputs=[script.group for script in self.selectable_scripts]
)
return inputs
......@@ -294,6 +309,15 @@ class ScriptRunner:
print(f"Error running process: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.process_batch(p, *script_args, **kwargs)
except Exception:
print(f"Error running process_batch: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def postprocess(self, p, processed):
for script in self.alwayson_scripts:
try:
......
......@@ -9,7 +9,7 @@ from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from modules import shared, modelloader, devices, script_callbacks
from modules import shared, modelloader, devices, script_callbacks, sd_vae
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
......@@ -159,13 +159,16 @@ def get_state_dict_from_checkpoint(pl_sd):
return pl_sd
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
def load_model_weights(model, checkpoint_info):
def load_model_weights(model, checkpoint_info, vae_file="auto"):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
if shared.opts.sd_checkpoint_cache > 0 and hasattr(model, "sd_checkpoint_info"):
sd_vae.restore_base_vae(model)
checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy()
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
if checkpoint_info not in checkpoints_loaded:
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
......@@ -182,37 +185,36 @@ def load_model_weights(model, checkpoint_info):
model.to(memory_format=torch.channels_last)
if not shared.cmd_opts.no_half:
vae = model.first_stage_model
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
if shared.cmd_opts.no_half_vae:
model.first_stage_model = None
model.half()
model.first_stage_model = vae
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
vae_file = shared.cmd_opts.vae_path
if os.path.exists(vae_file):
print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
model.first_stage_model.load_state_dict(vae_dict)
model.first_stage_model.to(devices.dtype_vae)
if shared.opts.sd_checkpoint_cache > 0:
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False) # LRU
else:
print(f"Loading weights [{sd_model_hash}] from cache")
checkpoints_loaded.move_to_end(checkpoint_info)
vae_name = sd_vae.get_filename(vae_file) if vae_file else None
vae_message = f" with {vae_name} VAE" if vae_name else ""
print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
model.load_state_dict(checkpoints_loaded[checkpoint_info])
if shared.opts.sd_checkpoint_cache > 0:
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False) # LRU
model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info
sd_vae.load_vae(model, vae_file)
def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
......@@ -263,7 +265,7 @@ def load_model(checkpoint_info=None):
def reload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
if not sd_model:
sd_model = shared.sd_model
......
......@@ -12,6 +12,7 @@ from modules import prompt_parser, devices, processing, images
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
......@@ -92,8 +93,8 @@ def single_sample_to_image(sample):
return Image.fromarray(x_sample)
def sample_to_image(samples):
return single_sample_to_image(samples[0])
def sample_to_image(samples, index=0):
return single_sample_to_image(samples[index])
def samples_to_image_grid(samples):
......@@ -280,6 +281,12 @@ class CFGDenoiser(torch.nn.Module):
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
cfg_denoiser_callback(denoiser_params)
x_in = denoiser_params.x
image_cond_in = denoiser_params.image_cond
sigma_in = denoiser_params.sigma
if tensor.shape[1] == uncond.shape[1]:
cond_in = torch.cat([tensor, uncond])
......
import torch
import os
from collections import namedtuple
from modules import shared, devices, script_callbacks
from modules.paths import models_path
import glob
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
vae_dir = "VAE"
vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
default_vae_dict = {"auto": "auto", "None": "None"}
default_vae_list = ["auto", "None"]
default_vae_values = [default_vae_dict[x] for x in default_vae_list]
vae_dict = dict(default_vae_dict)
vae_list = list(default_vae_list)
first_load = True
base_vae = None
loaded_vae_file = None
checkpoint_info = None
def get_base_vae(model):
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
return base_vae
return None
def store_base_vae(model):
global base_vae, checkpoint_info
if checkpoint_info != model.sd_checkpoint_info:
base_vae = model.first_stage_model.state_dict().copy()
checkpoint_info = model.sd_checkpoint_info
def delete_base_vae():
global base_vae, checkpoint_info
base_vae = None
checkpoint_info = None
def restore_base_vae(model):
global base_vae, checkpoint_info
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
load_vae_dict(model, base_vae)
delete_base_vae()
def get_filename(filepath):
return os.path.splitext(os.path.basename(filepath))[0]
def refresh_vae_list(vae_path=vae_path, model_path=model_path):
global vae_dict, vae_list
res = {}
candidates = [
*glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
*glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
*glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
*glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True)
]
if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
candidates.append(shared.cmd_opts.vae_path)
for filepath in candidates:
name = get_filename(filepath)
res[name] = filepath
vae_list.clear()
vae_list.extend(default_vae_list)
vae_list.extend(list(res.keys()))
vae_dict.clear()
vae_dict.update(res)
vae_dict.update(default_vae_dict)
return vae_list
def resolve_vae(checkpoint_file, vae_file="auto"):
global first_load, vae_dict, vae_list
# if vae_file argument is provided, it takes priority, but not saved
if vae_file and vae_file not in default_vae_list:
if not os.path.isfile(vae_file):
vae_file = "auto"
print("VAE provided as function argument doesn't exist")
# for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
if first_load and shared.cmd_opts.vae_path is not None:
if os.path.isfile(shared.cmd_opts.vae_path):
vae_file = shared.cmd_opts.vae_path
shared.opts.data['sd_vae'] = get_filename(vae_file)
else:
print("VAE provided as command line argument doesn't exist")
# else, we load from settings
if vae_file == "auto" and shared.opts.sd_vae is not None:
# if saved VAE settings isn't recognized, fallback to auto
vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
# if VAE selected but not found, fallback to auto
if vae_file not in default_vae_values and not os.path.isfile(vae_file):
vae_file = "auto"
print("Selected VAE doesn't exist")
# vae-path cmd arg takes priority for auto
if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
if os.path.isfile(shared.cmd_opts.vae_path):
vae_file = shared.cmd_opts.vae_path
print("Using VAE provided as command line argument")
# if still not found, try look for ".vae.pt" beside model
model_path = os.path.splitext(checkpoint_file)[0]
if vae_file == "auto":
vae_file_try = model_path + ".vae.pt"
if os.path.isfile(vae_file_try):
vae_file = vae_file_try
print("Using VAE found beside selected model")
# if still not found, try look for ".vae.ckpt" beside model
if vae_file == "auto":
vae_file_try = model_path + ".vae.ckpt"
if os.path.isfile(vae_file_try):
vae_file = vae_file_try
print("Using VAE found beside selected model")
# No more fallbacks for auto
if vae_file == "auto":
vae_file = None
# Last check, just because
if vae_file and not os.path.exists(vae_file):
vae_file = None
return vae_file
def load_vae(model, vae_file=None):
global first_load, vae_dict, vae_list, loaded_vae_file
# save_settings = False
if vae_file:
print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
load_vae_dict(model, vae_dict_1)
# If vae used is not in dict, update it
# It will be removed on refresh though
vae_opt = get_filename(vae_file)
if vae_opt not in vae_dict:
vae_dict[vae_opt] = vae_file
vae_list.append(vae_opt)
loaded_vae_file = vae_file
"""
# Save current VAE to VAE settings, maybe? will it work?
if save_settings:
if vae_file is None:
vae_opt = "None"
# shared.opts.sd_vae = vae_opt
"""
first_load = False
# don't call this from outside
def load_vae_dict(model, vae_dict_1=None):
if vae_dict_1:
store_base_vae(model)
model.first_stage_model.load_state_dict(vae_dict_1)
else:
restore_base_vae()
model.first_stage_model.to(devices.dtype_vae)
def reload_vae_weights(sd_model=None, vae_file="auto"):
from modules import lowvram, devices, sd_hijack
if not sd_model:
sd_model = shared.sd_model
checkpoint_info = sd_model.sd_checkpoint_info
checkpoint_file = checkpoint_info.filename
vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
if loaded_vae_file == vae_file:
return
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:
sd_model.to(devices.cpu)
sd_hijack.model_hijack.undo_hijack(sd_model)
load_vae(sd_model, vae_file)
sd_hijack.model_hijack.hijack(sd_model)
script_callbacks.model_loaded_callback(sd_model)
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
print(f"VAE Weights loaded.")
return sd_model
This diff is collapsed.
......@@ -235,6 +235,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
unload = shared.opts.unload_models_when_training
if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings")
......@@ -272,6 +273,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
......@@ -328,6 +331,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{embedding_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
shared.sd_model.first_stage_model.to(devices.device)
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
do_not_save_grid=True,
......@@ -355,6 +361,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
processed = processing.process_images(p)
image = processed.images[0]
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
shared.state.current_image = image
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
......@@ -400,6 +409,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
shared.sd_model.first_stage_model.to(devices.device)
return embedding, filename
......
......@@ -25,8 +25,10 @@ def train_embedding(*args):
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
apply_optimizations = shared.opts.training_xattention_optimizations
try:
sd_hijack.undo_optimizations()
if not apply_optimizations:
sd_hijack.undo_optimizations()
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
......@@ -38,5 +40,6 @@ Embedding saved to {html.escape(filename)}
except Exception:
raise
finally:
sd_hijack.apply_optimizations()
if not apply_optimizations:
sd_hijack.apply_optimizations()
......@@ -47,6 +47,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
if processed is None:
processed = process_images(p)
p.close()
shared.total_tqdm.clear()
generation_info_js = processed.js()
......
......@@ -276,16 +276,8 @@ def check_progress_call(id_part):
image = gr_show(False)
preview_visibility = gr_show(False)
if opts.show_progress_every_n_steps > 0:
if shared.parallel_processing_allowed:
if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None:
if opts.show_progress_grid:
shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent)
else:
shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
shared.state.current_image_sampling_step = shared.state.sampling_step
if opts.show_progress_every_n_steps != 0:
shared.state.set_current_image()
image = shared.state.current_image
if image is None:
......@@ -671,6 +663,8 @@ def create_ui(wrap_gradio_gpu_call):
import modules.img2img
import modules.txt2img
reload_javascript()
parameters_copypaste.reset()
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
......@@ -1058,9 +1052,11 @@ def create_ui(wrap_gradio_gpu_call):
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.")
show_extras_results = gr.Checkbox(label='Show result images', value=True)
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
with gr.Tabs(elem_id="extras_resize_mode"):
with gr.TabItem('Scale by'):
upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4)
with gr.TabItem('Scale to'):
with gr.Group():
with gr.Row():
......@@ -1085,8 +1081,6 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Group():
upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False)
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples)
submit.click(
......@@ -1188,8 +1182,8 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys)
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
......@@ -1444,25 +1438,16 @@ def create_ui(wrap_gradio_gpu_call):
def run_settings(*args):
changed = 0
assert not shared.cmd_opts.freeze_settings, "changing settings is disabled"
for key, value, comp in zip(opts.data_labels.keys(), args, components):
if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default):
return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson()
assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
for key, value, comp in zip(opts.data_labels.keys(), args, components):
if comp == dummy_component:
continue
comp_args = opts.data_labels[key].component_args
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
continue
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
continue
oldval = opts.data.get(key, None)
opts.data[key] = value
setattr(opts, key, value)
if oldval != value:
if opts.data_labels[key].onchange is not None:
......@@ -1472,20 +1457,18 @@ def create_ui(wrap_gradio_gpu_call):
opts.save(shared.config_filename)
return f'{changed} settings changed.', opts.dumpjson()
return opts.dumpjson(), f'{changed} settings changed.'
def run_settings_single(value, key):
assert not shared.cmd_opts.freeze_settings, "changing settings is disabled"
if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson()
oldval = opts.data.get(key, None)
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
try:
setattr(opts, key, value)
except Exception:
return gr.update(value=oldval), opts.dumpjson()
opts.data[key] = value
if oldval != value:
if opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange()
......@@ -1570,8 +1553,7 @@ def create_ui(wrap_gradio_gpu_call):
reload_script_bodies.click(
fn=reload_scripts,
inputs=[],
outputs=[],
_js='function(){}'
outputs=[]
)
def request_restart():
......@@ -1583,7 +1565,7 @@ def create_ui(wrap_gradio_gpu_call):
fn=request_restart,
inputs=[],
outputs=[],
_js='function(){restart_reload()}'
_js='restart_reload'
)
if column is not None:
......@@ -1639,9 +1621,9 @@ def create_ui(wrap_gradio_gpu_call):
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
settings_submit.click(
fn=run_settings,
fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
inputs=components,
outputs=[result, text_settings],
outputs=[text_settings, result],
)
for i, k, item in quicksettings_list:
......@@ -1782,4 +1764,3 @@ def load_javascript(raw_response):
reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
reload_javascript()
......@@ -86,7 +86,7 @@ def extension_table():
code += f"""
<tr>
<td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
<td><a href="{html.escape(ext.remote or '')}">{html.escape(ext.remote or '')}</a></td>
<td><a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape(ext.remote or '')}</a></td>
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
</tr>
"""
......
......@@ -10,6 +10,7 @@ import modules.shared
from modules import modelloader, shared
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
from modules.paths import models_path
......@@ -57,7 +58,7 @@ class Upscaler:
dest_w = img.width * scale
dest_h = img.height * scale
for i in range(3):
if img.width >= dest_w and img.height >= dest_h:
if img.width > dest_w and img.height > dest_h:
break
img = self.do_upscale(img, selected_model)
if img.width != dest_w or img.height != dest_h:
......@@ -120,3 +121,17 @@ class UpscalerLanczos(Upscaler):
self.name = "Lanczos"
self.scalers = [UpscalerData("Lanczos", None, self)]
class UpscalerNearest(Upscaler):
scalers = []
def do_upscale(self, img, selected_model=None):
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)
def load_model(self, _):
pass
def __init__(self, dirname=None):
super().__init__(False)
self.name = "Nearest"
self.scalers = [UpscalerData("Nearest", None, self)]
\ No newline at end of file
......@@ -2,7 +2,7 @@ transformers==4.19.2
diffusers==0.3.0
basicsr==1.4.2
gfpgan==1.3.8
gradio==3.5
gradio==3.8
numpy==1.23.3
Pillow==9.2.0
realesrgan==0.3.0
......
......@@ -14,7 +14,7 @@ class Script(scripts.Script):
return cmd_opts.allow_code
def ui(self, is_img2img):
code = gr.Textbox(label="Python code", visible=False, lines=1)
code = gr.Textbox(label="Python code", lines=1)
return [code]
......
......@@ -166,8 +166,7 @@ class Script(scripts.Script):
if override_strength:
p.denoising_strength = 1.0
def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
lat = (p.init_latent.cpu().numpy() * 10).astype(int)
same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st \
......
......@@ -132,7 +132,7 @@ class Script(scripts.Script):
info = gr.HTML("<p style=\"margin-bottom:0.75em\">Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8</p>")
pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128)
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, visible=False)
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8)
direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'])
noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0)
color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05)
......
......@@ -22,8 +22,8 @@ class Script(scripts.Script):
return None
pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128)
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, visible=False)
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", visible=False)
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4)
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index")
direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'])
return [pixels, mask_blur, inpainting_fill, direction]
......
......@@ -83,19 +83,21 @@ def cmdargs(line):
def load_prompt_file(file):
if (file is None):
if file is None:
lines = []
else:
lines = [x.strip() for x in file.decode('utf8', errors='ignore').split("\n")]
return None, "\n".join(lines), gr.update(lines=7)
class Script(scripts.Script):
def title(self):
return "Prompts from file or textbox"
def ui(self, is_img2img):
checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False)
checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False)
prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1)
file = gr.File(label="Upload prompt inputs", type='bytes')
......@@ -106,9 +108,9 @@ class Script(scripts.Script):
# We don't shrink back to 1, because that causes the control to ignore [enter], and it may
# be unclear to the user that shift-enter is needed.
prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt])
return [checkbox_iterate, file, prompt_txt]
return [checkbox_iterate, checkbox_iterate_batch, prompt_txt]
def run(self, p, checkbox_iterate, file, prompt_txt: str):
def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str):
lines = [x.strip() for x in prompt_txt.splitlines()]
lines = [x for x in lines if len(x) > 0]
......@@ -137,7 +139,7 @@ class Script(scripts.Script):
jobs.append(args)
print(f"Will process {len(lines)} lines in {job_count} jobs.")
if (checkbox_iterate and p.seed == -1):
if (checkbox_iterate or checkbox_iterate_batch) and p.seed == -1:
p.seed = int(random.randrange(4294967294))
state.job_count = job_count
......@@ -153,8 +155,7 @@ class Script(scripts.Script):
proc = process_images(copy_p)
images += proc.images
if (checkbox_iterate):
if checkbox_iterate:
p.seed = p.seed + (p.batch_size * p.n_iter)
return Processed(p, images, p.seed, "")
\ No newline at end of file
return Processed(p, images, p.seed, "")
......@@ -18,8 +18,8 @@ class Script(scripts.Script):
def ui(self, is_img2img):
info = gr.HTML("<p style=\"margin-bottom:0.75em\">Will upscale the image to twice the dimensions; use width and height sliders to set tile size</p>")
overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False)
upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", visible=False)
overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64)
upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
return [info, overlap, upscaler_index]
......
......@@ -263,12 +263,12 @@ class Script(scripts.Script):
current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img]
with gr.Row():
x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, visible=False, type="index", elem_id="x_type")
x_values = gr.Textbox(label="X values", visible=False, lines=1)
x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id="x_type")
x_values = gr.Textbox(label="X values", lines=1)
with gr.Row():
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, visible=False, type="index", elem_id="y_type")
y_values = gr.Textbox(label="Y values", visible=False, lines=1)
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id="y_type")
y_values = gr.Textbox(label="Y values", lines=1)
draw_legend = gr.Checkbox(label='Draw legend', value=True)
include_lone_images = gr.Checkbox(label='Include Separate Images', value=False)
......
......@@ -260,6 +260,16 @@ input[type="range"]{
#txt2img_negative_prompt, #img2img_negative_prompt{
}
/* gradio 3.8 adds opacity to progressbar which makes it blink; disable it here */
.transition.opacity-20 {
opacity: 1 !important;
}
/* more gradio's garbage cleanup */
.min-h-\[4rem\] {
min-height: unset !important;
}
#txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
position: absolute;
z-index: 1000;
......@@ -491,7 +501,7 @@ input[type="range"]{
padding: 0;
}
#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
#refresh_sd_model_checkpoint, #refresh_sd_vae, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
max-width: 2.5em;
min-width: 2.5em;
height: 2.4em;
......
import unittest
import requests
class UtilsTests(unittest.TestCase):
def setUp(self):
self.url_options = "http://localhost:7860/sdapi/v1/options"
self.url_cmd_flags = "http://localhost:7860/sdapi/v1/cmd-flags"
self.url_samplers = "http://localhost:7860/sdapi/v1/samplers"
self.url_upscalers = "http://localhost:7860/sdapi/v1/upscalers"
self.url_sd_models = "http://localhost:7860/sdapi/v1/sd-models"
self.url_hypernetworks = "http://localhost:7860/sdapi/v1/hypernetworks"
self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers"
self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models"
self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles"
self.url_artist_categories = "http://localhost:7860/sdapi/v1/artist-categories"
self.url_artists = "http://localhost:7860/sdapi/v1/artists"
def test_options_get(self):
self.assertEqual(requests.get(self.url_options).status_code, 200)
def test_options_write(self):
response = requests.get(self.url_options)
self.assertEqual(response.status_code, 200)
pre_value = response.json()["send_seed"]
self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200)
response = requests.get(self.url_options)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json()["send_seed"], not pre_value)
requests.post(self.url_options, json={"send_seed": pre_value})
def test_cmd_flags(self):
self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200)
def test_samplers(self):
self.assertEqual(requests.get(self.url_samplers).status_code, 200)
def test_upscalers(self):
self.assertEqual(requests.get(self.url_upscalers).status_code, 200)
def test_sd_models(self):
self.assertEqual(requests.get(self.url_sd_models).status_code, 200)
def test_hypernetworks(self):
self.assertEqual(requests.get(self.url_hypernetworks).status_code, 200)
def test_face_restorers(self):
self.assertEqual(requests.get(self.url_face_restorers).status_code, 200)
def test_realesrgan_models(self):
self.assertEqual(requests.get(self.url_realesrgan_models).status_code, 200)
def test_prompt_styles(self):
self.assertEqual(requests.get(self.url_prompt_styles).status_code, 200)
def test_artist_categories(self):
self.assertEqual(requests.get(self.url_artist_categories).status_code, 200)
def test_artists(self):
self.assertEqual(requests.get(self.url_artists).status_code, 200)
\ No newline at end of file
......@@ -21,6 +21,7 @@ import modules.paths
import modules.scripts
import modules.sd_hijack
import modules.sd_models
import modules.sd_vae
import modules.shared as shared
import modules.txt2img
import modules.script_callbacks
......@@ -77,8 +78,10 @@ def initialize():
modules.scripts.load_scripts()
modules.sd_vae.refresh_vae_list()
modules.sd_models.load_model()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
......@@ -114,6 +117,8 @@ def api_only():
app.add_middleware(GZipMiddleware, minimum_size=1000)
api = create_api(app)
modules.script_callbacks.app_started_callback(None, app)
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
......@@ -136,6 +141,12 @@ def webui():
# after initial launch, disable --autolaunch for subsequent restarts
cmd_opts.autolaunch = False
# gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
# an attacker to trick the user into opening a malicious HTML page, which makes a request to the
# running web ui and do whatever the attcker wants, including installing an extension and
# runnnig its code. We disable this here. Suggested by RyotaK.
app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
app.add_middleware(GZipMiddleware, minimum_size=1000)
if launch_api:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment