Commit 509fd145 authored by Roy Shilkrot's avatar Roy Shilkrot

Merge remote-tracking branch 'upstream/master' into roy.add_simple_interrogate_api

parents bdc90837 dc7425a5
......@@ -29,4 +29,5 @@ notification.mp3
/textual_inversion
.vscode
/extensions
/test/stdout.txt
/test/stderr.txt
* @AUTOMATIC1111
/localizations/ar_AR.json @xmodar @blackneoo
/localizations/de_DE.json @LunixWasTaken
/localizations/es_ES.json @innovaciones
/localizations/fr_FR.json @tumbly
/localizations/it_IT.json @EugenioBuffo
/localizations/ja_JP.json @yuuki76
/localizations/ko_KR.json @36DB
/localizations/pt_BR.json @M-art-ucci
/localizations/ru_RU.json @kabachuha
/localizations/tr_TR.json @camenduru
/localizations/zh_CN.json @dtlnor @bgluminous
/localizations/zh_TW.json @benlisquare
function extensions_apply(_, _){
disable = []
update = []
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
if(x.name.startsWith("enable_") && ! x.checked)
disable.push(x.name.substr(7))
if(x.name.startsWith("update_") && x.checked)
update.push(x.name.substr(7))
})
restart_reload()
return [JSON.stringify(disable), JSON.stringify(update)]
}
function extensions_check(){
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
x.innerHTML = "Loading..."
})
return []
}
\ No newline at end of file
......@@ -75,6 +75,7 @@ titles = {
"Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style use that as placeholder for your prompt when you use the style in the future.",
"Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.",
"Inpainting conditioning mask strength": "Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.",
"vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",
......
......@@ -13,6 +13,15 @@ function showModal(event) {
}
lb.style.display = "block";
lb.focus()
const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
const tabImg2Img = gradioApp().getElementById("tab_img2img")
// show the save button in modal only on txt2img or img2img tabs
if (tabTxt2Img.style.display != "none" || tabImg2Img.style.display != "none") {
gradioApp().getElementById("modal_save").style.display = "inline"
} else {
gradioApp().getElementById("modal_save").style.display = "none"
}
event.stopPropagation()
}
......@@ -81,6 +90,25 @@ function modalImageSwitch(offset) {
}
}
function saveImage(){
const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
const tabImg2Img = gradioApp().getElementById("tab_img2img")
const saveTxt2Img = "save_txt2img"
const saveImg2Img = "save_img2img"
if (tabTxt2Img.style.display != "none") {
gradioApp().getElementById(saveTxt2Img).click()
} else if (tabImg2Img.style.display != "none") {
gradioApp().getElementById(saveImg2Img).click()
} else {
console.error("missing implementation for saving modal of this type")
}
}
function modalSaveImage(event) {
saveImage()
event.stopPropagation()
}
function modalNextImage(event) {
modalImageSwitch(1)
event.stopPropagation()
......@@ -93,6 +121,9 @@ function modalPrevImage(event) {
function modalKeyHandler(event) {
switch (event.key) {
case "s":
saveImage()
break;
case "ArrowLeft":
modalPrevImage(event)
break;
......@@ -198,6 +229,14 @@ document.addEventListener("DOMContentLoaded", function() {
modalTileImage.title = "Preview tiling";
modalControls.appendChild(modalTileImage)
const modalSave = document.createElement("span")
modalSave.className = "modalSave cursor"
modalSave.id = "modal_save"
modalSave.innerHTML = "🖫"
modalSave.addEventListener("click", modalSaveImage, true)
modalSave.title = "Save Image(s)"
modalControls.appendChild(modalSave)
const modalClose = document.createElement('span')
modalClose.className = 'modalClose cursor';
modalClose.innerHTML = '×'
......
......@@ -45,14 +45,14 @@ function switch_to_txt2img(){
return args_to_array(arguments);
}
function switch_to_img2img_img2img(){
function switch_to_img2img(){
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[0].click();
return args_to_array(arguments);
}
function switch_to_img2img_inpaint(){
function switch_to_inpaint(){
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[1].click();
......@@ -65,26 +65,6 @@ function switch_to_extras(){
return args_to_array(arguments);
}
function extract_image_from_gallery_txt2img(gallery){
switch_to_txt2img()
return extract_image_from_gallery(gallery);
}
function extract_image_from_gallery_img2img(gallery){
switch_to_img2img_img2img()
return extract_image_from_gallery(gallery);
}
function extract_image_from_gallery_inpaint(gallery){
switch_to_img2img_inpaint()
return extract_image_from_gallery(gallery);
}
function extract_image_from_gallery_extras(gallery){
switch_to_extras()
return extract_image_from_gallery(gallery);
}
function get_tab_index(tabId){
var res = 0
......
......@@ -128,10 +128,12 @@ def prepare_enviroment():
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
sys.argv += shlex.split(commandline_args)
test_argv = [x for x in sys.argv if x != '--tests']
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
sys.argv, update_check = extract_arg(sys.argv, '--update-check')
sys.argv, run_tests = extract_arg(sys.argv, '--tests')
xformers = '--xformers' in sys.argv
deepdanbooru = '--deepdanbooru' in sys.argv
ngrok = '--ngrok' in sys.argv
......@@ -194,6 +196,26 @@ def prepare_enviroment():
print("Exiting because of --exit argument")
exit(0)
if run_tests:
tests(test_argv)
exit(0)
def tests(argv):
if "--api" not in argv:
argv.append("--api")
print(f"Launching Web UI in another process for testing with arguments: {' '.join(argv[1:])}")
with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
proc = subprocess.Popen([sys.executable, *argv], stdout=stdout, stderr=stderr)
import test.server_poll
test.server_poll.run_tests()
print(f"Stopping Web UI process with id {proc.pid}")
proc.kill()
def start_webui():
print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}")
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI, InterrogateAPI
import time
import uvicorn
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, HTTPException
import modules.shared as shared
from modules import devices
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_pnginfo
import modules.shared as shared
import uvicorn
from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, Json
from typing import List
import json
import io
import base64
from PIL import Image
from modules.extras import run_extras, run_pnginfo
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
class TextToImageResponse(BaseModel):
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json
info: Json
def upscaler_to_index(name: str):
try:
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
except:
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
class ImageToImageResponse(BaseModel):
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json
info: Json
class InterrogateResponse(BaseModel):
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
parameters: Json
info: Json
def setUpscalers(req: dict):
reqDict = vars(req)
reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2)
reqDict.pop('upscaler_1')
reqDict.pop('upscaler_2')
return reqDict
class Api:
......@@ -36,18 +34,14 @@ class Api:
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
def __base64_to_image(self, base64_string):
# if has a comma, deal with prefix
if "," in base64_string:
base64_string = base64_string.split(",")[1]
imgdata = base64.b64decode(base64_string)
# convert base64 to PIL image
return Image.open(io.BytesIO(imgdata))
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
......@@ -63,18 +57,17 @@ class Api:
)
p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param
shared.state.begin()
with self.queue_lock:
processed = process_images(p)
b64images = []
for i in processed.images:
buffer = io.BytesIO()
i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue()))
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())
shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images))
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index)
......@@ -89,7 +82,7 @@ class Api:
mask = img2imgreq.mask
if mask:
mask = self.__base64_to_image(mask)
mask = decode_base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params
......@@ -104,27 +97,89 @@ class Api:
imgs = []
for img in init_images:
img = self.__base64_to_image(img)
img = decode_base64_to_image(img)
imgs = [img] * p.batch_size
p.init_images = imgs
# Override object param
shared.state.begin()
with self.queue_lock:
processed = process_images(p)
b64images = []
for i in processed.images:
buffer = io.BytesIO()
i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue()))
shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images))
if (not img2imgreq.include_init_images):
img2imgreq.init_images = None
img2imgreq.mask = None
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())
return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
def extrasapi(self):
raise NotImplementedError
def extras_single_image_api(self, req: ExtrasSingleImageRequest):
reqDict = setUpscalers(req)
reqDict['image'] = decode_base64_to_image(reqDict['image'])
with self.queue_lock:
result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict)
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
def interrogateapi(self, interrogatereq: InterrogateAPI):
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
reqDict = setUpscalers(req)
def prepareFiles(file):
file = decode_base64_to_file(file.data, file_path=file.name)
file.orig_name = file.name
return file
reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
reqDict.pop('imageList')
with self.queue_lock:
result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict)
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
def pnginfoapi(self, req: PNGInfoRequest):
if(not req.image.strip()):
return PNGInfoResponse(info="")
result = run_pnginfo(decode_base64_to_image(req.image.strip()))
return PNGInfoResponse(info=result[1])
def progressapi(self, req: ProgressRequest = Depends()):
# copy from check_progress_call of ui.py
if shared.state.job_count == 0:
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict())
# avoid dividing zero
progress = 0.01
if shared.state.job_count > 0:
progress += shared.state.job_no / shared.state.job_count
if shared.state.sampling_steps > 0:
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
time_since_start = time.time() - shared.state.time_start
eta = (time_since_start/progress)
eta_relative = eta-time_since_start
progress = min(progress, 1)
current_image = None
if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image)
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
def interrogateapi(self, interrogatereq: InterrogateRequest):
image_b64 = interrogatereq.image
if image_b64 is None:
raise HTTPException(status_code=404, detail="Image not found")
......@@ -139,13 +194,7 @@ class Api:
with self.queue_lock:
processed = shared.interrogator.interrogate(img)
return InterrogateResponse(caption=processed, parameters=json.dumps(vars(interrogatereq)), info=None)
def extrasapi(self):
raise NotImplementedError
def pnginfoapi(self):
raise NotImplementedError
return InterrogateResponse(caption=processed)
def launch(self, server_name, port):
self.app.include_router(self.router)
......
from array import array
from inflection import underscore
from typing import Any, Dict, Optional
import inspect
from click import prompt
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing_extensions import Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
import inspect
from modules.shared import sd_upscalers
API_NOT_ALLOWED = [
"self",
......@@ -112,8 +113,66 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
).generate_model()
InterrogateAPI = PydanticModelGenerator(
"Interrogate",
None,
[{"key": "image", "type": str, "default": None}]
).generate_model()
\ No newline at end of file
class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
class ImageToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
class ExtrasBaseRequest(BaseModel):
resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the choosen size?")
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
class ExtraBaseResponse(BaseModel):
html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
class ExtrasSingleImageRequest(ExtrasBaseRequest):
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
class ExtrasSingleImageResponse(ExtraBaseResponse):
image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
class FileData(BaseModel):
data: str = Field(title="File data", description="Base64 representation of the file")
name: str = Field(title="File name")
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse):
images: list[str] = Field(title="Images", description="The generated images in base64 format.")
class PNGInfoRequest(BaseModel):
image: str = Field(title="Image", description="The base64 encoded PNG image")
class PNGInfoResponse(BaseModel):
info: str = Field(title="Image info", description="A string with all the info the image had")
class ProgressRequest(BaseModel):
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
class ProgressResponse(BaseModel):
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
eta_relative: float = Field(title="ETA in secs")
state: dict = Field(title="State", description="The current state snapshot")
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
class InterrogateRequest(BaseModel):
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
class InterrogateResponse(BaseModel):
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
import os
import sys
import traceback
import git
from modules import paths, shared
extensions = []
extensions_dir = os.path.join(paths.script_path, "extensions")
def active():
return [x for x in extensions if x.enabled]
class Extension:
def __init__(self, name, path, enabled=True):
self.name = name
self.path = path
self.enabled = enabled
self.status = ''
self.can_update = False
repo = None
try:
if os.path.exists(os.path.join(path, ".git")):
repo = git.Repo(path)
except Exception:
print(f"Error reading github repository info from {path}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
if repo is None or repo.bare:
self.remote = None
else:
self.remote = next(repo.remote().urls, None)
self.status = 'unknown'
def list_files(self, subdir, extension):
from modules import scripts
dirpath = os.path.join(self.path, subdir)
if not os.path.isdir(dirpath):
return []
res = []
for filename in sorted(os.listdir(dirpath)):
res.append(scripts.ScriptFile(dirpath, filename, os.path.join(dirpath, filename)))
res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
return res
def check_updates(self):
repo = git.Repo(self.path)
for fetch in repo.remote().fetch("--dry-run"):
if fetch.flags != fetch.HEAD_UPTODATE:
self.can_update = True
self.status = "behind"
return
self.can_update = False
self.status = "latest"
def pull(self):
repo = git.Repo(self.path)
repo.remotes.origin.pull()
def list_extensions():
extensions.clear()
if not os.path.isdir(extensions_dir):
return
for dirname in sorted(os.listdir(extensions_dir)):
path = os.path.join(extensions_dir, dirname)
if not os.path.isdir(path):
continue
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
extensions.append(extension)
This diff is collapsed.
import base64
import io
import os
import re
import gradio as gr
from modules.shared import script_path
from modules import shared
import tempfile
from PIL import Image
re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
type_of_gr_update = type(gr.update())
paste_fields = {}
bind_list = []
def reset():
paste_fields.clear()
bind_list.clear()
def quote(text):
......@@ -20,6 +31,111 @@ def quote(text):
text = text.replace('"', '\\"')
return f'"{text}"'
def image_from_url_text(filedata):
if type(filedata) == dict and filedata["is_file"]:
filename = filedata["name"]
tempdir = os.path.normpath(tempfile.gettempdir())
normfn = os.path.normpath(filename)
assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
return Image.open(filename)
if type(filedata) == list:
if len(filedata) == 0:
return None
filedata = filedata[0]
if filedata.startswith("data:image/png;base64,"):
filedata = filedata[len("data:image/png;base64,"):]
filedata = base64.decodebytes(filedata.encode('utf-8'))
image = Image.open(io.BytesIO(filedata))
return image
def add_paste_fields(tabname, init_img, fields):
paste_fields[tabname] = {"init_img": init_img, "fields": fields}
# backwards compatibility for existing extensions
import modules.ui
if tabname == 'txt2img':
modules.ui.txt2img_paste_fields = fields
elif tabname == 'img2img':
modules.ui.img2img_paste_fields = fields
def integrate_settings_paste_fields(component_dict):
from modules import ui
settings_map = {
'sd_hypernetwork': 'Hypernet',
'sd_hypernetwork_strength': 'Hypernet strength',
'CLIP_stop_at_last_layers': 'Clip skip',
'sd_model_checkpoint': 'Model hash',
}
settings_paste_fields = [
(component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
for k, v in settings_map.items()
]
for tabname, info in paste_fields.items():
if info["fields"] is not None:
info["fields"] += settings_paste_fields
def create_buttons(tabs_list):
buttons = {}
for tab in tabs_list:
buttons[tab] = gr.Button(f"Send to {tab}")
return buttons
#if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab
def bind_buttons(buttons, send_image, send_generate_info):
bind_list.append([buttons, send_image, send_generate_info])
def run_bind():
for buttons, send_image, send_generate_info in bind_list:
for tab in buttons:
button = buttons[tab]
if send_image and paste_fields[tab]["init_img"]:
if type(send_image) == gr.Gallery:
button.click(
fn=lambda x: image_from_url_text(x),
_js="extract_image_from_gallery",
inputs=[send_image],
outputs=[paste_fields[tab]["init_img"]],
)
else:
button.click(
fn=lambda x: x,
inputs=[send_image],
outputs=[paste_fields[tab]["init_img"]],
)
if send_generate_info and paste_fields[tab]["fields"] is not None:
if send_generate_info in paste_fields:
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Size-1', 'Size-2'] + (["Seed"] if shared.opts.send_seed else [])
button.click(
fn=lambda *x: x,
inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
outputs=[field for field, name in paste_fields[tab]["fields"] if name in paste_field_names],
)
else:
connect_paste(button, paste_fields[tab]["fields"], send_generate_info)
button.click(
fn=None,
_js=f"switch_to_{tab}",
inputs=None,
outputs=None,
)
def parse_generation_parameters(x: str):
"""parses generation parameters string, the one you see in text field under the picture in UI:
```
......@@ -68,7 +184,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
return res
def connect_paste(button, paste_fields, input_comp, js=None):
def connect_paste(button, paste_fields, input_comp, jsfunc=None):
def paste_func(prompt):
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
filename = os.path.join(script_path, "params.txt")
......@@ -106,7 +222,9 @@ def connect_paste(button, paste_fields, input_comp, js=None):
button.click(
fn=paste_func,
_js=js,
_js=jsfunc,
inputs=[input_comp],
outputs=[x[0] for x in paste_fields],
)
......@@ -25,6 +25,7 @@ from statistics import stdev, mean
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
"linear": torch.nn.Identity,
"relu": torch.nn.ReLU,
"leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU,
......@@ -208,13 +209,16 @@ def list_hypernetworks(path):
res = {}
for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
name = os.path.splitext(os.path.basename(filename))[0]
# Prevent a hypothetical "None.pt" from being listed.
if name != "None":
res[name] = filename
return res
def load_hypernetwork(filename):
path = shared.hypernetworks.get(filename, None)
if path is not None:
# Prevent any file named "None.pt" from being loaded.
if path is not None and filename != "None":
print(f"Loading hypernetwork {filename}")
try:
shared.loaded_hypernetwork = Hypernetwork()
......@@ -331,7 +335,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
assert hypernetwork_name, 'hypernetwork not selected'
save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork()
......@@ -357,18 +363,25 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
else:
images_dir = None
hypernetwork = shared.loaded_hypernetwork
checkpoint = sd_models.select_checkpoint()
ititial_step = hypernetwork.step or 0
if ititial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
hypernetwork = shared.loaded_hypernetwork
weights = hypernetwork.weights()
for weight in weights:
weight.requires_grad = True
size = len(ds.indexes)
loss_dict = defaultdict(lambda : deque(maxlen = 1024))
losses = torch.zeros((size,))
......@@ -376,20 +389,18 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
previous_mean_loss = 0
print("Mean loss of {} elements".format(size))
last_saved_file = "<none>"
last_saved_image = "<none>"
forced_filename = "<none>"
ititial_step = hypernetwork.step or 0
if ititial_step > steps:
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
weights = hypernetwork.weights()
for weight in weights:
weight.requires_grad = True
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
steps_without_grad = 0
last_saved_file = "<none>"
last_saved_image = "<none>"
forced_filename = "<none>"
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar:
hypernetwork.step = i + ititial_step
......@@ -428,6 +439,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
optimizer.step()
steps_done = hypernetwork.step + 1
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
raise RuntimeError("Loss diverged.")
......@@ -438,19 +451,19 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
pbar.set_description(dataset_loss_info)
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
hypernetwork.save(last_saved_file)
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{previous_mean_loss:.7f}",
"learn_rate": scheduler.learn_rate
})
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
forced_filename = f'{hypernetwork_name}-{hypernetwork.step}'
if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{hypernetwork_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
optimizer.zero_grad()
......@@ -503,13 +516,23 @@ Last saved image: {html.escape(last_saved_image)}<br/>
"""
report_statistics(loss_dict)
checkpoint = sd_models.select_checkpoint()
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
return hypernetwork, filename
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
old_hypernetwork_name = hypernetwork.name
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
try:
hypernetwork.sd_checkpoint = checkpoint.hash
hypernetwork.sd_checkpoint_name = checkpoint.model_name
# Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention).
hypernetwork.name = hypernetwork_name
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt')
hypernetwork.save(filename)
return hypernetwork, filename
except:
hypernetwork.sd_checkpoint = old_sd_checkpoint
hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
hypernetwork.name = old_hypernetwork_name
raise
......@@ -8,7 +8,8 @@ import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
keys = list(hypernetwork.HypernetworkModule.activation_dict.keys())
not_available = ["hardswish", "multiheadattention"]
keys = ["linear"] + list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
# Remove illegal characters from name.
......
......@@ -300,8 +300,8 @@ class FilenameGenerator:
'seed': lambda self: self.seed if self.seed is not None else '',
'steps': lambda self: self.p and self.p.steps,
'cfg': lambda self: self.p and self.p.cfg_scale,
'width': lambda self: self.p and self.p.width,
'height': lambda self: self.p and self.p.height,
'width': lambda self: self.image.width,
'height': lambda self: self.image.height,
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False),
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
......@@ -315,10 +315,11 @@ class FilenameGenerator:
}
default_time_format = '%Y%m%d%H%M%S'
def __init__(self, p, seed, prompt):
def __init__(self, p, seed, prompt, image):
self.p = p
self.seed = seed
self.prompt = prompt
self.image = image
def prompt_no_style(self):
if self.p is None or self.prompt is None:
......@@ -449,7 +450,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
txt_fullfn (`str` or None):
If a text file is saved for this image, this will be its full path. Otherwise None.
"""
namegen = FilenameGenerator(p, seed, prompt)
namegen = FilenameGenerator(p, seed, prompt, image)
if save_to_dirs is None:
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
......
......@@ -19,7 +19,7 @@ import modules.scripts
def process_batch(p, input_dir, output_dir, args):
processing.fix_seed(p)
images = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
images = shared.listfiles(input_dir)
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
......
This diff is collapsed.
This diff is collapsed.
......@@ -3,6 +3,7 @@ import os.path
import sys
from collections import namedtuple
import torch
import re
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
......@@ -36,7 +37,9 @@ def setup_model():
def checkpoint_tiles():
return sorted([x.title for x in checkpoints_list.values()])
convert = lambda name: int(name) if name.isdigit() else name.lower()
alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
def list_models():
......@@ -170,7 +173,9 @@ def load_model_weights(model, checkpoint_info):
print(f"Global Step: {pl_sd['global_step']}")
sd = get_state_dict_from_checkpoint(pl_sd)
missing, extra = model.load_state_dict(sd, strict=False)
del pl_sd
model.load_state_dict(sd, strict=False)
del sd
if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
......@@ -194,6 +199,7 @@ def load_model_weights(model, checkpoint_info):
model.first_stage_model.to(devices.dtype_vae)
if shared.opts.sd_checkpoint_cache > 0:
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False) # LRU
......
This diff is collapsed.
......@@ -42,6 +42,8 @@ class PersonalizedBase(Dataset):
self.lines = lines
assert data_root, 'dataset directory not specified'
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
cond_model = shared.sd_model.cond_stage_model
......@@ -86,12 +88,12 @@ class PersonalizedBase(Dataset):
assert len(self.dataset) > 0, "No images have been found in the dataset."
self.length = len(self.dataset) * repeats // batch_size
self.initial_indexes = np.arange(len(self.dataset))
self.dataset_length = len(self.dataset)
self.indexes = None
self.shuffle()
def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]
self.indexes = np.random.permutation(self.dataset_length)
def create_text(self, filename_text):
text = random.choice(self.lines)
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -23,3 +23,4 @@ torchdiffeq==0.2.3
kornia==0.6.7
lark==1.1.2
inflection==0.5.1
GitPython==3.1.27
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment