Commit d85c2cb2 authored by Muhammad Rizqi Nur

Merge branch 'master' into gradient-clipping

parents cabd4e3b ac085628
@@ -7,6 +7,7 @@ from typing import Optional
 from fastapi import FastAPI
 from gradio import Blocks

+
 def report_exception(c, job):
     print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
     print(traceback.format_exc(), file=sys.stderr)

@@ -45,15 +46,21 @@ class CFGDenoiserParams:
     """Total number of sampling steps planned"""


+class UiTrainTabParams:
+    def __init__(self, txt2img_preview_params):
+        self.txt2img_preview_params = txt2img_preview_params
+
+
 ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
 callback_map = dict(
     callbacks_app_started=[],
     callbacks_model_loaded=[],
     callbacks_ui_tabs=[],
+    callbacks_ui_train_tabs=[],
     callbacks_ui_settings=[],
     callbacks_before_image_saved=[],
     callbacks_image_saved=[],
-    callbacks_cfg_denoiser=[]
+    callbacks_cfg_denoiser=[],
 )

@@ -61,6 +68,7 @@ def clear_callbacks():
     for callback_list in callback_map.values():
         callback_list.clear()

+
 def app_started_callback(demo: Optional[Blocks], app: FastAPI):
     for c in callback_map['callbacks_app_started']:
         try:

@@ -89,6 +97,14 @@ def ui_tabs_callback():
     return res


+def ui_train_tabs_callback(params: UiTrainTabParams):
+    for c in callback_map['callbacks_ui_train_tabs']:
+        try:
+            c.callback(params)
+        except Exception:
+            report_exception(c, 'callbacks_ui_train_tabs')
+
+
 def ui_settings_callback():
     for c in callback_map['callbacks_ui_settings']:
         try:

@@ -169,6 +185,13 @@ def on_ui_tabs(callback):
     add_callback(callback_map['callbacks_ui_tabs'], callback)


+def on_ui_train_tabs(callback):
+    """register a function to be called when the UI is creating new tabs for the train tab.
+    Create your new tabs with gr.Tab.
+    """
+    add_callback(callback_map['callbacks_ui_train_tabs'], callback)
+
+
 def on_ui_settings(callback):
     """register a function to be called before UI settings are populated; add your settings
     by using shared.opts.add_option(shared.OptionInfo(...)) """
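The hunks above appear to come from modules/script_callbacks.py: the merge brings in a UiTrainTabParams container and an on_ui_train_tabs() registration hook so extensions can add their own tabs to the Train tab. A minimal sketch, assuming a typical extension script; the file name, tab label, and contents below are illustrative and not part of this commit:

# extensions/<my-extension>/scripts/train_tab_example.py -- hypothetical extension script
import gradio as gr

from modules import script_callbacks


def add_example_train_tab(params: script_callbacks.UiTrainTabParams):
    # params.txt2img_preview_params carries the txt2img preview controls,
    # should the new tab want to wire up previews.
    with gr.Tab(label="Example tab"):
        gr.Markdown("Controls for the extension's training workflow would go here.")


script_callbacks.on_ui_train_tabs(add_example_train_tab)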
@@ -35,64 +35,62 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
         deepbooru.release_process()


+def listfiles(dirname):
+    return os.listdir(dirname)
+
+
-def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
-    width = process_width
-    height = process_height
-    src = os.path.abspath(process_src)
-    dst = os.path.abspath(process_dst)
-    split_threshold = max(0.0, min(1.0, split_threshold))
-    overlap_ratio = max(0.0, min(0.9, overlap_ratio))
-
-    assert src != dst, 'same directory specified as source and destination'
-
-    os.makedirs(dst, exist_ok=True)
-
-    files = os.listdir(src)
-
-    shared.state.textinfo = "Preprocessing..."
-    shared.state.job_count = len(files)
-
-    def save_pic_with_caption(image, index, existing_caption=None):
+class PreprocessParams:
+    src = None
+    dstdir = None
+    subindex = 0
+    flip = False
+    process_caption = False
+    process_caption_deepbooru = False
+    preprocess_txt_action = None
+
+
+def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None):
     caption = ""

-    if process_caption:
+    if params.process_caption:
         caption += shared.interrogator.generate_caption(image)

-    if process_caption_deepbooru:
+    if params.process_caption_deepbooru:
         if len(caption) > 0:
             caption += ", "
         caption += deepbooru.get_tags_from_process(image)

-    filename_part = filename
+    filename_part = params.src
     filename_part = os.path.splitext(filename_part)[0]
     filename_part = os.path.basename(filename_part)

-    basename = f"{index:05}-{subindex[0]}-{filename_part}"
-    image.save(os.path.join(dst, f"{basename}.png"))
+    basename = f"{index:05}-{params.subindex}-{filename_part}"
+    image.save(os.path.join(params.dstdir, f"{basename}.png"))

-    if preprocess_txt_action == 'prepend' and existing_caption:
+    if params.preprocess_txt_action == 'prepend' and existing_caption:
         caption = existing_caption + ' ' + caption
-    elif preprocess_txt_action == 'append' and existing_caption:
+    elif params.preprocess_txt_action == 'append' and existing_caption:
         caption = caption + ' ' + existing_caption
-    elif preprocess_txt_action == 'copy' and existing_caption:
+    elif params.preprocess_txt_action == 'copy' and existing_caption:
         caption = existing_caption

     caption = caption.strip()

     if len(caption) > 0:
-        with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
+        with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file:
             file.write(caption)

-    subindex[0] += 1
+    params.subindex += 1


-def save_pic(image, index, existing_caption=None):
-    save_pic_with_caption(image, index, existing_caption=existing_caption)
+def save_pic(image, index, params, existing_caption=None):
+    save_pic_with_caption(image, index, params, existing_caption=existing_caption)

-    if process_flip:
-        save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
+    if params.flip:
+        save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption)


-def split_pic(image, inverse_xy):
+def split_pic(image, inverse_xy, width, height, overlap_ratio):
     if inverse_xy:
         from_w, from_h = image.height, image.width
         to_w, to_h = height, width

@@ -116,14 +114,40 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
             yield splitted


+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
+    width = process_width
+    height = process_height
+    src = os.path.abspath(process_src)
+    dst = os.path.abspath(process_dst)
+    split_threshold = max(0.0, min(1.0, split_threshold))
+    overlap_ratio = max(0.0, min(0.9, overlap_ratio))
+
+    assert src != dst, 'same directory specified as source and destination'
+
+    os.makedirs(dst, exist_ok=True)
+
+    files = listfiles(src)
+
+    shared.state.textinfo = "Preprocessing..."
+    shared.state.job_count = len(files)
+
+    params = PreprocessParams()
+    params.dstdir = dst
+    params.flip = process_flip
+    params.process_caption = process_caption
+    params.process_caption_deepbooru = process_caption_deepbooru
+    params.preprocess_txt_action = preprocess_txt_action
+
     for index, imagefile in enumerate(tqdm.tqdm(files)):
-        subindex = [0]
+        params.subindex = 0
         filename = os.path.join(src, imagefile)
         try:
             img = Image.open(filename).convert("RGB")
         except Exception:
             continue

+        params.src = filename
+
         existing_caption = None
         existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
         if os.path.exists(existing_caption_filename):

@@ -143,8 +167,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
         process_default_resize = True

         if process_split and ratio < 1.0 and ratio <= split_threshold:
-            for splitted in split_pic(img, inverse_xy):
-                save_pic(splitted, index, existing_caption=existing_caption)
+            for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio):
+                save_pic(splitted, index, params, existing_caption=existing_caption)
             process_default_resize = False

         if process_focal_crop and img.height != img.width:

@@ -165,11 +189,11 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
                 dnn_model_path = dnn_model_path,
             )
             for focal in autocrop.crop_image(img, autocrop_settings):
-                save_pic(focal, index, existing_caption=existing_caption)
+                save_pic(focal, index, params, existing_caption=existing_caption)
             process_default_resize = False

         if process_default_resize:
             img = images.resize_image(1, img, width, height)
-            save_pic(img, index, existing_caption=existing_caption)
+            save_pic(img, index, params, existing_caption=existing_caption)

         shared.state.nextjob()
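The hunks above look like modules/textual_inversion/preprocess.py: the nested save_pic_with_caption/save_pic/split_pic closures become module-level functions, with per-run state carried explicitly on a PreprocessParams object (and the old subindex = [0] list trick replaced by a plain params.subindex int). A rough usage sketch under those assumptions; it needs the webui modules on the path, and the paths and option values below are illustrative only:

from PIL import Image

from modules.textual_inversion.preprocess import PreprocessParams, save_pic

params = PreprocessParams()
params.dstdir = "outputs/preprocessed"    # illustrative; directory must already exist
params.flip = True                        # also write a mirrored copy
params.process_caption = False            # skip BLIP captioning
params.process_caption_deepbooru = False  # skip deepbooru tagging
params.preprocess_txt_action = 'ignore'   # don't merge with an existing .txt caption

params.src = "inputs/example.png"         # illustrative source image
params.subindex = 0                       # reset for each source image
img = Image.open(params.src).convert("RGB")
save_pic(img, 0, params, existing_caption=None)  # writes 00000-0-example.png plus the flipped 00000-1-example.png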
@@ -1272,6 +1272,10 @@ def create_ui(wrap_gradio_gpu_call):
                         train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
                         train_embedding = gr.Button(value="Train Embedding", variant='primary')

+                params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
+
+                script_callbacks.ui_train_tabs_callback(params)
+
             with gr.Column():
                 progressbar = gr.HTML(elem_id="ti_progressbar")
                 ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)

@@ -1758,7 +1762,7 @@ def create_ui(wrap_gradio_gpu_call):
     return demo


-def load_javascript(raw_response):
+def reload_javascript():
     with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
         javascript = f'<script>{jsfile.read()}</script>'

@@ -1774,7 +1778,7 @@ def load_javascript(raw_response):
     javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"

     def template_response(*args, **kwargs):
-        res = raw_response(*args, **kwargs)
+        res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
         res.body = res.body.replace(
             b'</head>', f'{javascript}</head>'.encode("utf8"))
         res.init_headers()

@@ -1783,4 +1787,5 @@ def load_javascript(raw_response):
     gradio.routes.templates.TemplateResponse = template_response


-reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
+if not hasattr(shared, 'GradioTemplateResponseOriginal'):
+    shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
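The last hunks appear to be from modules/ui.py. Besides invoking the new train-tab callback, they turn load_javascript into reload_javascript and stash the untouched gradio TemplateResponse on shared exactly once, so the function can be called again on a UI reload without wrapping its own wrapper. A small standalone sketch of that idempotent monkey-patching pattern, using json.dumps purely as a stand-in; _original_dumps is a made-up stash attribute, not anything from this commit:

import json

# Save the pristine implementation only the first time this runs,
# mirroring the hasattr(shared, 'GradioTemplateResponseOriginal') guard above.
if not hasattr(json, "_original_dumps"):
    json._original_dumps = json.dumps


def patched_dumps(obj, **kwargs):
    # Delegate to the stashed original, never to a previously installed wrapper.
    kwargs.setdefault("sort_keys", True)
    return json._original_dumps(obj, **kwargs)


json.dumps = patched_dumps  # safe to re-run: the stash above is left untouched
print(json.dumps({"b": 2, "a": 1}))  # {"a": 1, "b": 2}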