Unverified Commit 649d79a8 authored by guaneec's avatar guaneec Committed by GitHub

Merge branch 'master' into hn-activation

parents 877d94f9 757264c4
This diff is collapsed.
...@@ -7,6 +7,7 @@ import uvicorn ...@@ -7,6 +7,7 @@ import uvicorn
from fastapi import Body, APIRouter, HTTPException from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, Json from pydantic import BaseModel, Field, Json
from typing import List
import json import json
import io import io
import base64 import base64
...@@ -15,12 +16,12 @@ from PIL import Image ...@@ -15,12 +16,12 @@ from PIL import Image
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
class TextToImageResponse(BaseModel): class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json parameters: Json
info: Json info: Json
class ImageToImageResponse(BaseModel): class ImageToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json parameters: Json
info: Json info: Json
...@@ -65,7 +66,7 @@ class Api: ...@@ -65,7 +66,7 @@ class Api:
i.save(buffer, format="png") i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue())) b64images.append(base64.b64encode(buffer.getvalue()))
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info)) return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())
...@@ -111,7 +112,11 @@ class Api: ...@@ -111,7 +112,11 @@ class Api:
i.save(buffer, format="png") i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue())) b64images.append(base64.b64encode(buffer.getvalue()))
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info)) if (not img2imgreq.include_init_images):
img2imgreq.init_images = None
img2imgreq.mask = None
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())
def extrasapi(self): def extrasapi(self):
raise NotImplementedError raise NotImplementedError
......
...@@ -31,6 +31,7 @@ class ModelDef(BaseModel): ...@@ -31,6 +31,7 @@ class ModelDef(BaseModel):
field_alias: str field_alias: str
field_type: Any field_type: Any
field_value: Any field_value: Any
field_exclude: bool = False
class PydanticModelGenerator: class PydanticModelGenerator:
...@@ -78,7 +79,8 @@ class PydanticModelGenerator: ...@@ -78,7 +79,8 @@ class PydanticModelGenerator:
field=underscore(fields["key"]), field=underscore(fields["key"]),
field_alias=fields["key"], field_alias=fields["key"],
field_type=fields["type"], field_type=fields["type"],
field_value=fields["default"])) field_value=fields["default"],
field_exclude=fields["exclude"] if "exclude" in fields else False))
def generate_model(self): def generate_model(self):
""" """
...@@ -86,7 +88,7 @@ class PydanticModelGenerator: ...@@ -86,7 +88,7 @@ class PydanticModelGenerator:
from the json and overrides provided at initialization from the json and overrides provided at initialization
""" """
fields = { fields = {
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
} }
DynamicModel = create_model(self._model_name, **fields) DynamicModel = create_model(self._model_name, **fields)
DynamicModel.__config__.allow_population_by_field_name = True DynamicModel.__config__.allow_population_by_field_name = True
...@@ -102,5 +104,5 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( ...@@ -102,5 +104,5 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingImg2Img", "StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img, StableDiffusionProcessingImg2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}] [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
).generate_model() ).generate_model()
\ No newline at end of file
...@@ -5,6 +5,7 @@ import html ...@@ -5,6 +5,7 @@ import html
import os import os
import sys import sys
import traceback import traceback
import inspect
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
import torch import torch
...@@ -15,10 +16,12 @@ from modules import devices, processing, sd_models, shared ...@@ -15,10 +16,12 @@ from modules import devices, processing, sd_models, shared
from modules.textual_inversion import textual_inversion from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum from torch import einsum
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
from collections import defaultdict, deque from collections import defaultdict, deque
from statistics import stdev, mean from statistics import stdev, mean
class HypernetworkModule(torch.nn.Module): class HypernetworkModule(torch.nn.Module):
multiplier = 1.0 multiplier = 1.0
activation_dict = { activation_dict = {
...@@ -26,9 +29,12 @@ class HypernetworkModule(torch.nn.Module): ...@@ -26,9 +29,12 @@ class HypernetworkModule(torch.nn.Module):
"leakyrelu": torch.nn.LeakyReLU, "leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU, "elu": torch.nn.ELU,
"swish": torch.nn.Hardswish, "swish": torch.nn.Hardswish,
"tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid,
} }
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False, activate_output=False): def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False, activate_output=False):
super().__init__() super().__init__()
assert layer_structure is not None, "layer_structure must not be None" assert layer_structure is not None, "layer_structure must not be None"
...@@ -65,9 +71,24 @@ class HypernetworkModule(torch.nn.Module): ...@@ -65,9 +71,24 @@ class HypernetworkModule(torch.nn.Module):
else: else:
for layer in self.linear: for layer in self.linear:
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm: if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
layer.weight.data.normal_(mean=0.0, std=0.01) w, b = layer.weight.data, layer.bias.data
layer.bias.data.zero_() if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
normal_(w, mean=0.0, std=0.01)
normal_(b, mean=0.0, std=0.005)
elif weight_init == 'XavierUniform':
xavier_uniform_(w)
zeros_(b)
elif weight_init == 'XavierNormal':
xavier_normal_(w)
zeros_(b)
elif weight_init == 'KaimingUniform':
kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
zeros_(b)
elif weight_init == 'KaimingNormal':
kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
zeros_(b)
else:
raise KeyError(f"Key {weight_init} is not defined as initialization!")
self.to(devices.device) self.to(devices.device)
def fix_old_state_dict(self, state_dict): def fix_old_state_dict(self, state_dict):
...@@ -105,7 +126,7 @@ class Hypernetwork: ...@@ -105,7 +126,7 @@ class Hypernetwork:
filename = None filename = None
name = None name = None
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False, activate_output=False): def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False)
self.filename = None self.filename = None
self.name = name self.name = name
self.layers = {} self.layers = {}
...@@ -114,14 +135,15 @@ class Hypernetwork: ...@@ -114,14 +135,15 @@ class Hypernetwork:
self.sd_checkpoint_name = None self.sd_checkpoint_name = None
self.layer_structure = layer_structure self.layer_structure = layer_structure
self.activation_func = activation_func self.activation_func = activation_func
self.weight_init = weight_init
self.add_layer_norm = add_layer_norm self.add_layer_norm = add_layer_norm
self.use_dropout = use_dropout self.use_dropout = use_dropout
self.activate_output = activate_output self.activate_output = activate_output
for size in enable_sizes or []: for size in enable_sizes or []:
self.layers[size] = ( self.layers[size] = (
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output), HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output), HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
) )
def weights(self): def weights(self):
...@@ -145,6 +167,7 @@ class Hypernetwork: ...@@ -145,6 +167,7 @@ class Hypernetwork:
state_dict['layer_structure'] = self.layer_structure state_dict['layer_structure'] = self.layer_structure
state_dict['activation_func'] = self.activation_func state_dict['activation_func'] = self.activation_func
state_dict['is_layer_norm'] = self.add_layer_norm state_dict['is_layer_norm'] = self.add_layer_norm
state_dict['weight_initialization'] = self.weight_init
state_dict['use_dropout'] = self.use_dropout state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
...@@ -160,16 +183,22 @@ class Hypernetwork: ...@@ -160,16 +183,22 @@ class Hypernetwork:
state_dict = torch.load(filename, map_location='cpu') state_dict = torch.load(filename, map_location='cpu')
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
print(self.layer_structure)
self.activation_func = state_dict.get('activation_func', None) self.activation_func = state_dict.get('activation_func', None)
print(f"Activation function is {self.activation_func}")
self.weight_init = state_dict.get('weight_initialization', 'Normal')
print(f"Weight initialization is {self.weight_init}")
self.add_layer_norm = state_dict.get('is_layer_norm', False) self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.use_dropout = state_dict.get('use_dropout', False) print(f"Layer norm is set to {self.add_layer_norm}")
self.use_dropout = state_dict.get('use_dropout', False
print(f"Dropout usage is set to {self.use_dropout}" )
self.activate_output = state_dict.get('activate_output', True) self.activate_output = state_dict.get('activate_output', True)
for size, sd in state_dict.items(): for size, sd in state_dict.items():
if type(size) == int: if type(size) == int:
self.layers[size] = ( self.layers[size] = (
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output), HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output), HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
) )
self.name = state_dict.get('name', self.name) self.name = state_dict.get('name', self.name)
......
...@@ -8,8 +8,9 @@ import modules.textual_inversion.textual_inversion ...@@ -8,8 +8,9 @@ import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack, shared from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
keys = list(hypernetwork.HypernetworkModule.activation_dict.keys())
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False): def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
# Remove illegal characters from name. # Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- ")) name = "".join( x for x in name if (x.isalnum() or x in "._- "))
...@@ -25,6 +26,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, ...@@ -25,6 +26,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
enable_sizes=[int(x) for x in enable_sizes], enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure, layer_structure=layer_structure,
activation_func=activation_func, activation_func=activation_func,
weight_init=weight_init,
add_layer_norm=add_layer_norm, add_layer_norm=add_layer_norm,
use_dropout=use_dropout, use_dropout=use_dropout,
) )
......
...@@ -277,7 +277,7 @@ invalid_filename_chars = '<>:"/\\|?*\n' ...@@ -277,7 +277,7 @@ invalid_filename_chars = '<>:"/\\|?*\n'
invalid_filename_prefix = ' ' invalid_filename_prefix = ' '
invalid_filename_postfix = ' .' invalid_filename_postfix = ' .'
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+') re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
re_pattern = re.compile(r"([^\[\]]+|\[([^]]+)]|[\[\]]*)") re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$") re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
max_filename_part_length = 128 max_filename_part_length = 128
...@@ -343,7 +343,7 @@ class FilenameGenerator: ...@@ -343,7 +343,7 @@ class FilenameGenerator:
def datetime(self, *args): def datetime(self, *args):
time_datetime = datetime.datetime.now() time_datetime = datetime.datetime.now()
time_format = args[0] if len(args) > 0 else self.default_time_format time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
try: try:
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
except pytz.exceptions.UnknownTimeZoneError as _: except pytz.exceptions.UnknownTimeZoneError as _:
...@@ -362,9 +362,9 @@ class FilenameGenerator: ...@@ -362,9 +362,9 @@ class FilenameGenerator:
for m in re_pattern.finditer(x): for m in re_pattern.finditer(x):
text, pattern = m.groups() text, pattern = m.groups()
res += text
if pattern is None: if pattern is None:
res += text
continue continue
pattern_args = [] pattern_args = []
...@@ -385,11 +385,8 @@ class FilenameGenerator: ...@@ -385,11 +385,8 @@ class FilenameGenerator:
print(f"Error adding [{pattern}] to filename", file=sys.stderr) print(f"Error adding [{pattern}] to filename", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
if replacement is None: if replacement is not None:
res += f'[{pattern}]'
else:
res += str(replacement) res += str(replacement)
continue continue
res += f'[{pattern}]' res += f'[{pattern}]'
......
This diff is collapsed.
...@@ -7,12 +7,14 @@ import tqdm ...@@ -7,12 +7,14 @@ import tqdm
import time import time
from modules import shared, images from modules import shared, images
from modules.paths import models_path
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts
from modules.textual_inversion import autocrop
if cmd_opts.deepdanbooru: if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru import modules.deepbooru as deepbooru
def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2): def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
try: try:
if process_caption: if process_caption:
shared.interrogator.load() shared.interrogator.load()
...@@ -22,7 +24,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce ...@@ -22,7 +24,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts) deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio) preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug)
finally: finally:
...@@ -34,7 +36,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce ...@@ -34,7 +36,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2): def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
width = process_width width = process_width
height = process_height height = process_height
src = os.path.abspath(process_src) src = os.path.abspath(process_src)
...@@ -113,6 +115,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre ...@@ -113,6 +115,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
splitted = image.crop((0, y, to_w, y + to_h)) splitted = image.crop((0, y, to_w, y + to_h))
yield splitted yield splitted
for index, imagefile in enumerate(tqdm.tqdm(files)): for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0] subindex = [0]
filename = os.path.join(src, imagefile) filename = os.path.join(src, imagefile)
...@@ -137,10 +140,35 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre ...@@ -137,10 +140,35 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
ratio = (img.height * width) / (img.width * height) ratio = (img.height * width) / (img.width * height)
inverse_xy = True inverse_xy = True
process_default_resize = True
if process_split and ratio < 1.0 and ratio <= split_threshold: if process_split and ratio < 1.0 and ratio <= split_threshold:
for splitted in split_pic(img, inverse_xy): for splitted in split_pic(img, inverse_xy):
save_pic(splitted, index, existing_caption=existing_caption) save_pic(splitted, index, existing_caption=existing_caption)
else: process_default_resize = False
if process_focal_crop and img.height != img.width:
dnn_model_path = None
try:
dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv"))
except Exception as e:
print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
autocrop_settings = autocrop.Settings(
crop_width = width,
crop_height = height,
face_points_weight = process_focal_crop_face_weight,
entropy_points_weight = process_focal_crop_entropy_weight,
corner_points_weight = process_focal_crop_edges_weight,
annotate_image = process_focal_crop_debug,
dnn_model_path = dnn_model_path,
)
for focal in autocrop.crop_image(img, autocrop_settings):
save_pic(focal, index, existing_caption=existing_caption)
process_default_resize = False
if process_default_resize:
img = images.resize_image(1, img, width, height) img = images.resize_image(1, img, width, height)
save_pic(img, index, existing_caption=existing_caption) save_pic(img, index, existing_caption=existing_caption)
......
...@@ -157,6 +157,9 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): ...@@ -157,6 +157,9 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
with devices.autocast():
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"] ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0) embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
......
...@@ -1238,7 +1238,8 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1238,7 +1238,8 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"]) new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout") new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
...@@ -1260,6 +1261,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1260,6 +1261,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row(): with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies') process_flip = gr.Checkbox(label='Create flipped copies')
process_split = gr.Checkbox(label='Split oversized images') process_split = gr.Checkbox(label='Split oversized images')
process_focal_crop = gr.Checkbox(label='Auto focal point crop')
process_caption = gr.Checkbox(label='Use BLIP for caption') process_caption = gr.Checkbox(label='Use BLIP for caption')
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False) process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
...@@ -1267,6 +1269,12 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1267,6 +1269,12 @@ def create_ui(wrap_gradio_gpu_call):
process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05) process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05)
with gr.Row(visible=False) as process_focal_crop_row:
process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05)
process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05)
process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
process_focal_crop_debug = gr.Checkbox(label='Create debug image')
with gr.Row(): with gr.Row():
with gr.Column(scale=3): with gr.Column(scale=3):
gr.HTML(value="") gr.HTML(value="")
...@@ -1280,6 +1288,12 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1280,6 +1288,12 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[process_split_extra_row], outputs=[process_split_extra_row],
) )
process_focal_crop.change(
fn=lambda show: gr_show(show),
inputs=[process_focal_crop],
outputs=[process_focal_crop_row],
)
with gr.Tab(label="Train"): with gr.Tab(label="Train"):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>") gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
with gr.Row(): with gr.Row():
...@@ -1342,6 +1356,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1342,6 +1356,7 @@ def create_ui(wrap_gradio_gpu_call):
overwrite_old_hypernetwork, overwrite_old_hypernetwork,
new_hypernetwork_layer_structure, new_hypernetwork_layer_structure,
new_hypernetwork_activation_func, new_hypernetwork_activation_func,
new_hypernetwork_initialization_option,
new_hypernetwork_add_layer_norm, new_hypernetwork_add_layer_norm,
new_hypernetwork_use_dropout new_hypernetwork_use_dropout
], ],
...@@ -1367,6 +1382,11 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1367,6 +1382,11 @@ def create_ui(wrap_gradio_gpu_call):
process_caption_deepbooru, process_caption_deepbooru,
process_split_threshold, process_split_threshold,
process_overlap_ratio, process_overlap_ratio,
process_focal_crop,
process_focal_crop_face_weight,
process_focal_crop_entropy_weight,
process_focal_crop_edges_weight,
process_focal_crop_debug,
], ],
outputs=[ outputs=[
ti_output, ti_output,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment