Commit d85c2cb2 authored by Muhammad Rizqi Nur's avatar Muhammad Rizqi Nur

Merge branch 'master' into gradient-clipping

parents cabd4e3b ac085628
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This source diff could not be displayed because it is too large. You can view the blob instead.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -7,6 +7,7 @@ from typing import Optional ...@@ -7,6 +7,7 @@ from typing import Optional
from fastapi import FastAPI from fastapi import FastAPI
from gradio import Blocks from gradio import Blocks
def report_exception(c, job): def report_exception(c, job):
print(f"Error executing callback {job} for {c.script}", file=sys.stderr) print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
...@@ -45,15 +46,21 @@ class CFGDenoiserParams: ...@@ -45,15 +46,21 @@ class CFGDenoiserParams:
"""Total number of sampling steps planned""" """Total number of sampling steps planned"""
class UiTrainTabParams:
def __init__(self, txt2img_preview_params):
self.txt2img_preview_params = txt2img_preview_params
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callback_map = dict( callback_map = dict(
callbacks_app_started=[], callbacks_app_started=[],
callbacks_model_loaded=[], callbacks_model_loaded=[],
callbacks_ui_tabs=[], callbacks_ui_tabs=[],
callbacks_ui_train_tabs=[],
callbacks_ui_settings=[], callbacks_ui_settings=[],
callbacks_before_image_saved=[], callbacks_before_image_saved=[],
callbacks_image_saved=[], callbacks_image_saved=[],
callbacks_cfg_denoiser=[] callbacks_cfg_denoiser=[],
) )
...@@ -61,6 +68,7 @@ def clear_callbacks(): ...@@ -61,6 +68,7 @@ def clear_callbacks():
for callback_list in callback_map.values(): for callback_list in callback_map.values():
callback_list.clear() callback_list.clear()
def app_started_callback(demo: Optional[Blocks], app: FastAPI): def app_started_callback(demo: Optional[Blocks], app: FastAPI):
for c in callback_map['callbacks_app_started']: for c in callback_map['callbacks_app_started']:
try: try:
...@@ -79,7 +87,7 @@ def model_loaded_callback(sd_model): ...@@ -79,7 +87,7 @@ def model_loaded_callback(sd_model):
def ui_tabs_callback(): def ui_tabs_callback():
res = [] res = []
for c in callback_map['callbacks_ui_tabs']: for c in callback_map['callbacks_ui_tabs']:
try: try:
res += c.callback() or [] res += c.callback() or []
...@@ -89,6 +97,14 @@ def ui_tabs_callback(): ...@@ -89,6 +97,14 @@ def ui_tabs_callback():
return res return res
def ui_train_tabs_callback(params: UiTrainTabParams):
for c in callback_map['callbacks_ui_train_tabs']:
try:
c.callback(params)
except Exception:
report_exception(c, 'callbacks_ui_train_tabs')
def ui_settings_callback(): def ui_settings_callback():
for c in callback_map['callbacks_ui_settings']: for c in callback_map['callbacks_ui_settings']:
try: try:
...@@ -169,6 +185,13 @@ def on_ui_tabs(callback): ...@@ -169,6 +185,13 @@ def on_ui_tabs(callback):
add_callback(callback_map['callbacks_ui_tabs'], callback) add_callback(callback_map['callbacks_ui_tabs'], callback)
def on_ui_train_tabs(callback):
"""register a function to be called when the UI is creating new tabs for the train tab.
Create your new tabs with gr.Tab.
"""
add_callback(callback_map['callbacks_ui_train_tabs'], callback)
def on_ui_settings(callback): def on_ui_settings(callback):
"""register a function to be called before UI settings are populated; add your settings """register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """ by using shared.opts.add_option(shared.OptionInfo(...)) """
......
...@@ -35,6 +35,84 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce ...@@ -35,6 +35,84 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
deepbooru.release_process() deepbooru.release_process()
def listfiles(dirname):
return os.listdir(dirname)
class PreprocessParams:
src = None
dstdir = None
subindex = 0
flip = False
process_caption = False
process_caption_deepbooru = False
preprocess_txt_action = None
def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None):
caption = ""
if params.process_caption:
caption += shared.interrogator.generate_caption(image)
if params.process_caption_deepbooru:
if len(caption) > 0:
caption += ", "
caption += deepbooru.get_tags_from_process(image)
filename_part = params.src
filename_part = os.path.splitext(filename_part)[0]
filename_part = os.path.basename(filename_part)
basename = f"{index:05}-{params.subindex}-{filename_part}"
image.save(os.path.join(params.dstdir, f"{basename}.png"))
if params.preprocess_txt_action == 'prepend' and existing_caption:
caption = existing_caption + ' ' + caption
elif params.preprocess_txt_action == 'append' and existing_caption:
caption = caption + ' ' + existing_caption
elif params.preprocess_txt_action == 'copy' and existing_caption:
caption = existing_caption
caption = caption.strip()
if len(caption) > 0:
with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file:
file.write(caption)
params.subindex += 1
def save_pic(image, index, params, existing_caption=None):
save_pic_with_caption(image, index, params, existing_caption=existing_caption)
if params.flip:
save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption)
def split_pic(image, inverse_xy, width, height, overlap_ratio):
if inverse_xy:
from_w, from_h = image.height, image.width
to_w, to_h = height, width
else:
from_w, from_h = image.width, image.height
to_w, to_h = width, height
h = from_h * to_w // from_w
if inverse_xy:
image = image.resize((h, to_w))
else:
image = image.resize((to_w, h))
split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
y_step = (h - to_h) / (split_count - 1)
for i in range(split_count):
y = int(y_step * i)
if inverse_xy:
splitted = image.crop((y, 0, y + to_h, to_w))
else:
splitted = image.crop((0, y, to_w, y + to_h))
yield splitted
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False): def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
width = process_width width = process_width
...@@ -48,82 +126,28 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre ...@@ -48,82 +126,28 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
os.makedirs(dst, exist_ok=True) os.makedirs(dst, exist_ok=True)
files = os.listdir(src) files = listfiles(src)
shared.state.textinfo = "Preprocessing..." shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files) shared.state.job_count = len(files)
def save_pic_with_caption(image, index, existing_caption=None): params = PreprocessParams()
caption = "" params.dstdir = dst
params.flip = process_flip
if process_caption: params.process_caption = process_caption
caption += shared.interrogator.generate_caption(image) params.process_caption_deepbooru = process_caption_deepbooru
params.preprocess_txt_action = preprocess_txt_action
if process_caption_deepbooru:
if len(caption) > 0:
caption += ", "
caption += deepbooru.get_tags_from_process(image)
filename_part = filename
filename_part = os.path.splitext(filename_part)[0]
filename_part = os.path.basename(filename_part)
basename = f"{index:05}-{subindex[0]}-{filename_part}"
image.save(os.path.join(dst, f"{basename}.png"))
if preprocess_txt_action == 'prepend' and existing_caption:
caption = existing_caption + ' ' + caption
elif preprocess_txt_action == 'append' and existing_caption:
caption = caption + ' ' + existing_caption
elif preprocess_txt_action == 'copy' and existing_caption:
caption = existing_caption
caption = caption.strip()
if len(caption) > 0:
with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
file.write(caption)
subindex[0] += 1
def save_pic(image, index, existing_caption=None):
save_pic_with_caption(image, index, existing_caption=existing_caption)
if process_flip:
save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
def split_pic(image, inverse_xy):
if inverse_xy:
from_w, from_h = image.height, image.width
to_w, to_h = height, width
else:
from_w, from_h = image.width, image.height
to_w, to_h = width, height
h = from_h * to_w // from_w
if inverse_xy:
image = image.resize((h, to_w))
else:
image = image.resize((to_w, h))
split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
y_step = (h - to_h) / (split_count - 1)
for i in range(split_count):
y = int(y_step * i)
if inverse_xy:
splitted = image.crop((y, 0, y + to_h, to_w))
else:
splitted = image.crop((0, y, to_w, y + to_h))
yield splitted
for index, imagefile in enumerate(tqdm.tqdm(files)): for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0] params.subindex = 0
filename = os.path.join(src, imagefile) filename = os.path.join(src, imagefile)
try: try:
img = Image.open(filename).convert("RGB") img = Image.open(filename).convert("RGB")
except Exception: except Exception:
continue continue
params.src = filename
existing_caption = None existing_caption = None
existing_caption_filename = os.path.splitext(filename)[0] + '.txt' existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
if os.path.exists(existing_caption_filename): if os.path.exists(existing_caption_filename):
...@@ -143,8 +167,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre ...@@ -143,8 +167,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
process_default_resize = True process_default_resize = True
if process_split and ratio < 1.0 and ratio <= split_threshold: if process_split and ratio < 1.0 and ratio <= split_threshold:
for splitted in split_pic(img, inverse_xy): for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio):
save_pic(splitted, index, existing_caption=existing_caption) save_pic(splitted, index, params, existing_caption=existing_caption)
process_default_resize = False process_default_resize = False
if process_focal_crop and img.height != img.width: if process_focal_crop and img.height != img.width:
...@@ -165,11 +189,11 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre ...@@ -165,11 +189,11 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
dnn_model_path = dnn_model_path, dnn_model_path = dnn_model_path,
) )
for focal in autocrop.crop_image(img, autocrop_settings): for focal in autocrop.crop_image(img, autocrop_settings):
save_pic(focal, index, existing_caption=existing_caption) save_pic(focal, index, params, existing_caption=existing_caption)
process_default_resize = False process_default_resize = False
if process_default_resize: if process_default_resize:
img = images.resize_image(1, img, width, height) img = images.resize_image(1, img, width, height)
save_pic(img, index, existing_caption=existing_caption) save_pic(img, index, params, existing_caption=existing_caption)
shared.state.nextjob() shared.state.nextjob()
\ No newline at end of file
...@@ -1272,6 +1272,10 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1272,6 +1272,10 @@ def create_ui(wrap_gradio_gpu_call):
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary') train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
train_embedding = gr.Button(value="Train Embedding", variant='primary') train_embedding = gr.Button(value="Train Embedding", variant='primary')
params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
script_callbacks.ui_train_tabs_callback(params)
with gr.Column(): with gr.Column():
progressbar = gr.HTML(elem_id="ti_progressbar") progressbar = gr.HTML(elem_id="ti_progressbar")
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
...@@ -1758,7 +1762,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1758,7 +1762,7 @@ def create_ui(wrap_gradio_gpu_call):
return demo return demo
def load_javascript(raw_response): def reload_javascript():
with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
javascript = f'<script>{jsfile.read()}</script>' javascript = f'<script>{jsfile.read()}</script>'
...@@ -1774,7 +1778,7 @@ def load_javascript(raw_response): ...@@ -1774,7 +1778,7 @@ def load_javascript(raw_response):
javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>" javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"
def template_response(*args, **kwargs): def template_response(*args, **kwargs):
res = raw_response(*args, **kwargs) res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
res.body = res.body.replace( res.body = res.body.replace(
b'</head>', f'{javascript}</head>'.encode("utf8")) b'</head>', f'{javascript}</head>'.encode("utf8"))
res.init_headers() res.init_headers()
...@@ -1783,4 +1787,5 @@ def load_javascript(raw_response): ...@@ -1783,4 +1787,5 @@ def load_javascript(raw_response):
gradio.routes.templates.TemplateResponse = template_response gradio.routes.templates.TemplateResponse = template_response
reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse) if not hasattr(shared, 'GradioTemplateResponseOriginal'):
shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment