Commit b48b7999 authored by AUTOMATIC

Merge remote-tracking branch 'flamelaw/master'

parents b0063827 755df94b
modules/hypernetworks/hypernetwork.py
@@ -38,7 +38,7 @@ class HypernetworkModule(torch.nn.Module):
     activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
 
     def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
-                 add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=True):
+                 add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
         super().__init__()
 
         assert layer_structure is not None, "layer_structure must not be None"
@@ -154,16 +154,28 @@ class Hypernetwork:
             HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
                                self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
         )
+        self.eval_mode()
 
     def weights(self):
         res = []
+        for k, layers in self.layers.items():
+            for layer in layers:
+                res += layer.parameters()
+        return res
 
+    def train_mode(self):
         for k, layers in self.layers.items():
             for layer in layers:
                 layer.train()
-                res += layer.trainables()
-
-        return res
+                for param in layer.parameters():
+                    param.requires_grad = True
+
+    def eval_mode(self):
+        for k, layers in self.layers.items():
+            for layer in layers:
+                layer.eval()
+                for param in layer.parameters():
+                    param.requires_grad = False
 
     def save(self, filename):
         state_dict = {}
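Note on the refactor above: train_mode()/eval_mode() bundle two independent PyTorch switches that the old weights() mixed together, the module's train/eval state (which changes the behaviour of layers like Dropout) and each parameter's requires_grad flag (which controls autograd tracking). A minimal sketch of the same pattern on a generic module, using only stock PyTorch (the names here are illustrative, not webui code):

import torch

def set_trainable(module: torch.nn.Module, trainable: bool) -> None:
    module.train(trainable)               # toggles Dropout/BatchNorm behaviour
    for param in module.parameters():
        param.requires_grad = trainable   # toggles gradient tracking

net = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Dropout(0.3))
set_trainable(net, True)    # like Hypernetwork.train_mode()
set_trainable(net, False)   # like Hypernetwork.eval_mode(), as called at the end of __init__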
@@ -367,13 +379,13 @@ def report_statistics(loss_info:dict):
 
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
     # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
     from modules import images
 
     save_hypernetwork_every = save_hypernetwork_every or 0
     create_image_every = create_image_every or 0
-    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
+    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
 
     path = shared.hypernetworks.get(hypernetwork_name, None)
     shared.loaded_hypernetwork = Hypernetwork()
@@ -403,32 +415,30 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
     hypernetwork = shared.loaded_hypernetwork
     checkpoint = sd_models.select_checkpoint()
 
-    ititial_step = hypernetwork.step or 0
-    if ititial_step >= steps:
+    initial_step = hypernetwork.step or 0
+    if initial_step >= steps:
         shared.state.textinfo = f"Model has already been trained beyond specified max steps"
         return hypernetwork, filename
 
-    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
 
     # dataset loading may take a while, so input validations and early returns should be done before this
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
-    with torch.autocast("cuda"):
-        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
+
+    pin_memory = shared.opts.pin_memory
+
+    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
+
+    latent_sampling_method = ds.latent_sampling_method
+
+    dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
 
     if unload:
         shared.sd_model.cond_stage_model.to(devices.cpu)
         shared.sd_model.first_stage_model.to(devices.cpu)
 
-    size = len(ds.indexes)
-    loss_dict = defaultdict(lambda : deque(maxlen = 1024))
-    losses = torch.zeros((size,))
-    previous_mean_losses = [0]
-    previous_mean_loss = 0
-    print("Mean loss of {} elements".format(size))
-
     weights = hypernetwork.weights()
-    for weight in weights:
-        weight.requires_grad = True
+    hypernetwork.train_mode()
 
     # Here we use optimizer from saved HN, or we can specify as UI option.
     if hypernetwork.optimizer_name in optimizer_dict:
@@ -446,131 +456,156 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
             print("Cannot resume from saved optimizer!")
             print(e)
 
+    scaler = torch.cuda.amp.GradScaler()
+
+    batch_size = ds.batch_size
+    gradient_step = ds.gradient_step
+    # n steps = batch_size * gradient_step * n image processed
+    steps_per_epoch = len(ds) // batch_size // gradient_step
+    max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
+    loss_step = 0
+    _loss_step = 0 #internal
+    # size = len(ds.indexes)
+    # loss_dict = defaultdict(lambda : deque(maxlen = 1024))
+    # losses = torch.zeros((size,))
+    # previous_mean_losses = [0]
+    # previous_mean_loss = 0
+    # print("Mean loss of {} elements".format(size))
+
     steps_without_grad = 0
 
     last_saved_file = "<none>"
     last_saved_image = "<none>"
     forced_filename = "<none>"
 
-    pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
-    for i, entries in pbar:
-        hypernetwork.step = i + ititial_step
-        if len(loss_dict) > 0:
-            previous_mean_losses = [i[-1] for i in loss_dict.values()]
-            previous_mean_loss = mean(previous_mean_losses)
-
-        scheduler.apply(optimizer, hypernetwork.step)
-        if scheduler.finished:
-            break
-
-        if shared.state.interrupted:
-            break
-
-        with torch.autocast("cuda"):
-            c = stack_conds([entry.cond for entry in entries]).to(devices.device)
-            # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
-            x = torch.stack([entry.latent for entry in entries]).to(devices.device)
-            loss = shared.sd_model(x, c)[0]
-            del x
-            del c
-
-            losses[hypernetwork.step % losses.shape[0]] = loss.item()
-            for entry in entries:
-                loss_dict[entry.filename].append(loss.item())
-
-            optimizer.zero_grad()
-            weights[0].grad = None
-            loss.backward()
-
-            if weights[0].grad is None:
-                steps_without_grad += 1
-            else:
-                steps_without_grad = 0
-            assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
-
-            optimizer.step()
-
-        steps_done = hypernetwork.step + 1
-
-        if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
-            raise RuntimeError("Loss diverged.")
-
-        if len(previous_mean_losses) > 1:
-            std = stdev(previous_mean_losses)
-        else:
-            std = 0
-        dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
-        pbar.set_description(dataset_loss_info)
-
-        if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
-            # Before saving, change name to match current checkpoint.
-            hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
-            last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
-            hypernetwork.optimizer_name = optimizer_name
-            if shared.opts.save_optimizer_state:
-                hypernetwork.optimizer_state_dict = optimizer.state_dict()
-            save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
-            hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
-
-        textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
-            "loss": f"{previous_mean_loss:.7f}",
-            "learn_rate": scheduler.learn_rate
-        })
-
-        if images_dir is not None and steps_done % create_image_every == 0:
-            forced_filename = f'{hypernetwork_name}-{steps_done}'
-            last_saved_image = os.path.join(images_dir, forced_filename)
-
-            optimizer.zero_grad()
-            shared.sd_model.cond_stage_model.to(devices.device)
-            shared.sd_model.first_stage_model.to(devices.device)
-
-            p = processing.StableDiffusionProcessingTxt2Img(
-                sd_model=shared.sd_model,
-                do_not_save_grid=True,
-                do_not_save_samples=True,
-            )
-
-            if preview_from_txt2img:
-                p.prompt = preview_prompt
-                p.negative_prompt = preview_negative_prompt
-                p.steps = preview_steps
-                p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
-                p.cfg_scale = preview_cfg_scale
-                p.seed = preview_seed
-                p.width = preview_width
-                p.height = preview_height
-            else:
-                p.prompt = entries[0].cond_text
-                p.steps = 20
-
-            preview_text = p.prompt
-
-            processed = processing.process_images(p)
-            image = processed.images[0] if len(processed.images)>0 else None
-
-            if unload:
-                shared.sd_model.cond_stage_model.to(devices.cpu)
-                shared.sd_model.first_stage_model.to(devices.cpu)
-
-            if image is not None:
-                shared.state.current_image = image
-                last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
-                last_saved_image += f", prompt: {preview_text}"
-
-        shared.state.job_no = hypernetwork.step
-
-        shared.state.textinfo = f"""
+    pbar = tqdm.tqdm(total=steps - initial_step)
+    try:
+        for i in range((steps-initial_step) * gradient_step):
+            if scheduler.finished:
+                break
+            if shared.state.interrupted:
+                break
+            for j, batch in enumerate(dl):
+                # works as a drop_last=True for gradient accumulation
+                if j == max_steps_per_epoch:
+                    break
+                scheduler.apply(optimizer, hypernetwork.step)
+                if scheduler.finished:
+                    break
+                if shared.state.interrupted:
+                    break
+
+                with torch.autocast("cuda"):
+                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+                    if tag_drop_out != 0 or shuffle_tags:
+                        shared.sd_model.cond_stage_model.to(devices.device)
+                        c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
+                        shared.sd_model.cond_stage_model.to(devices.cpu)
+                    else:
+                        c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
+                    loss = shared.sd_model(x, c)[0] / gradient_step
+                    del x
+                    del c
+
+                    _loss_step += loss.item()
+                scaler.scale(loss).backward()
+                # go back until we reach gradient accumulation steps
+                if (j + 1) % gradient_step != 0:
+                    continue
+                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
+                # scaler.unscale_(optimizer)
+                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
+                # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
+                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
+                scaler.step(optimizer)
+                scaler.update()
+                hypernetwork.step += 1
+                pbar.update()
+                optimizer.zero_grad(set_to_none=True)
+                loss_step = _loss_step
+                _loss_step = 0
+
+                steps_done = hypernetwork.step + 1
+
+                epoch_num = hypernetwork.step // steps_per_epoch
+                epoch_step = hypernetwork.step % steps_per_epoch
+
+                pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
+                if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
+                    # Before saving, change name to match current checkpoint.
+                    hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
+                    last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
+                    hypernetwork.optimizer_name = optimizer_name
+                    if shared.opts.save_optimizer_state:
+                        hypernetwork.optimizer_state_dict = optimizer.state_dict()
+                    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
+                    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
+
+                textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
+                    "loss": f"{loss_step:.7f}",
+                    "learn_rate": scheduler.learn_rate
+                })
+
+                if images_dir is not None and steps_done % create_image_every == 0:
+                    forced_filename = f'{hypernetwork_name}-{steps_done}'
+                    last_saved_image = os.path.join(images_dir, forced_filename)
+                    hypernetwork.eval_mode()
+                    shared.sd_model.cond_stage_model.to(devices.device)
+                    shared.sd_model.first_stage_model.to(devices.device)
+
+                    p = processing.StableDiffusionProcessingTxt2Img(
+                        sd_model=shared.sd_model,
+                        do_not_save_grid=True,
+                        do_not_save_samples=True,
+                    )
+
+                    if preview_from_txt2img:
+                        p.prompt = preview_prompt
+                        p.negative_prompt = preview_negative_prompt
+                        p.steps = preview_steps
+                        p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
+                        p.cfg_scale = preview_cfg_scale
+                        p.seed = preview_seed
+                        p.width = preview_width
+                        p.height = preview_height
+                    else:
+                        p.prompt = batch.cond_text[0]
+                        p.steps = 20
+                        p.width = training_width
+                        p.height = training_height
+
+                    preview_text = p.prompt
+
+                    processed = processing.process_images(p)
+                    image = processed.images[0] if len(processed.images) > 0 else None
+
+                    if unload:
+                        shared.sd_model.cond_stage_model.to(devices.cpu)
+                        shared.sd_model.first_stage_model.to(devices.cpu)
+                    hypernetwork.train_mode()
+                    if image is not None:
+                        shared.state.current_image = image
+                        last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+                        last_saved_image += f", prompt: {preview_text}"
+
+                shared.state.job_no = hypernetwork.step
+
+                shared.state.textinfo = f"""
 <p>
-Loss: {previous_mean_loss:.7f}<br/>
-Step: {hypernetwork.step}<br/>
-Last prompt: {html.escape(entries[0].cond_text)}<br/>
+Loss: {loss_step:.7f}<br/>
+Step: {steps_done}<br/>
+Last prompt: {html.escape(batch.cond_text[0])}<br/>
 Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
 """
-
-    report_statistics(loss_dict)
+    except Exception:
+        print(traceback.format_exc(), file=sys.stderr)
+    finally:
+        pbar.leave = False
+        pbar.close()
+        hypernetwork.eval_mode()
+        #report_statistics(loss_dict)
 
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
     hypernetwork.optimizer_name = optimizer_name
@@ -579,6 +614,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
     save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
 
     del optimizer
     hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
+    shared.sd_model.cond_stage_model.to(devices.device)
+    shared.sd_model.first_stage_model.to(devices.device)
+
     return hypernetwork, filename
 
 def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
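The heart of this commit is gradient accumulation. Each micro-batch loss is divided by gradient_step, the scaled loss is backpropagated so gradients sum in param.grad across gradient_step batches, and only then does the optimizer step, emulating an effective batch of batch_size * gradient_step with the memory footprint of one micro-batch. A stripped-down sketch of the loop shape used above, with generic model/loader/optimizer stand-ins rather than the webui objects:

import torch

def train_steps(model, loader, optimizer, gradient_step):
    scaler = torch.cuda.amp.GradScaler()
    for j, (x, target) in enumerate(loader):
        with torch.autocast("cuda"):
            # divide so the accumulated gradient matches one large batch
            loss = torch.nn.functional.mse_loss(model(x.cuda()), target.cuda()) / gradient_step
        scaler.scale(loss).backward()      # gradients accumulate in .grad
        if (j + 1) % gradient_step != 0:
            continue                       # not enough micro-batches yet
        scaler.step(optimizer)             # unscales, then optimizer.step()
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

This is also why max_steps_per_epoch rounds len(ds) // batch_size down to a multiple of gradient_step: a partial accumulation group at the end of an epoch would otherwise step with a smaller effective batch.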
modules/sd_hijack.py
@@ -8,9 +8,9 @@ from torch import einsum
 from torch.nn.functional import silu
 
 import modules.textual_inversion.textual_inversion
-from modules import prompt_parser, devices, sd_hijack_optimizations, shared
+from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
 from modules.hypernetworks import hypernetwork
-from modules.shared import cmd_opts
+from modules.shared import opts, device, cmd_opts
 from modules import sd_hijack_clip, sd_hijack_open_clip
 
 from modules.sd_hijack_optimizations import invokeAI_mps_available
@@ -66,6 +66,10 @@ def undo_optimizations():
     ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
 
 
+def fix_checkpoint():
+    ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
+    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
+    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
+
 class StableDiffusionModelHijack:
     fixes = None
@@ -88,6 +92,7 @@ class StableDiffusionModelHijack:
         self.clip = m.cond_stage_model
 
         apply_optimizations()
+        fix_checkpoint()
 
 def flatten(el):
     flattened = [flatten(children) for children in el.children()]
modules/sd_hijack_checkpoint.py (new file)
@@ -0,0 +1,10 @@
+from torch.utils.checkpoint import checkpoint
+
+def BasicTransformerBlock_forward(self, x, context=None):
+    return checkpoint(self._forward, x, context)
+
+def AttentionBlock_forward(self, x):
+    return checkpoint(self._forward, x)
+
+def ResBlock_forward(self, x, emb):
+    return checkpoint(self._forward, x, emb)
\ No newline at end of file
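The new module enables gradient checkpointing: torch.utils.checkpoint.checkpoint runs _forward without storing intermediate activations and recomputes them during backward, trading extra compute for a substantial VRAM saving while training through the frozen UNet. A self-contained sketch of the idea, with a toy block standing in for ldm's BasicTransformerBlock/ResBlock/AttentionBlock:

import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU())

    def _forward(self, x):
        return self.net(x)

    def forward(self, x):
        # activations inside _forward are recomputed on backward instead of kept
        return checkpoint(self._forward, x)

x = torch.randn(4, 64, requires_grad=True)
Block()(x).sum().backward()   # same result as a plain forward, less memory held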
modules/shared.py
@@ -345,8 +345,7 @@ options_templates.update(options_section(('system', "System"), {
 
 options_templates.update(options_section(('training', "Training"), {
     "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
-    "shuffle_tags": OptionInfo(False, "Shuffleing tags by ',' when create texts."),
-    "tag_drop_out": OptionInfo(0, "Dropout tags when create texts", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.1}),
+    "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
     "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
     "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
     "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
modules/textual_inversion/dataset.py
@@ -3,7 +3,7 @@ import numpy as np
 import PIL
 import torch
 from PIL import Image
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, DataLoader
 from torchvision import transforms
 
 import random
@@ -11,25 +11,28 @@ import tqdm
 from modules import devices, shared
 import re
 
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+
 re_numbers_at_start = re.compile(r"^[-\d]+\s*")
 
 
 class DatasetEntry:
-    def __init__(self, filename=None, latent=None, filename_text=None):
+    def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
         self.filename = filename
-        self.latent = latent
         self.filename_text = filename_text
-        self.cond = None
-        self.cond_text = None
+        self.latent_dist = latent_dist
+        self.latent_sample = latent_sample
+        self.cond = cond
+        self.cond_text = cond_text
+        self.pixel_values = pixel_values
 
 
 class PersonalizedBase(Dataset):
-    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
+    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
         re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
 
         self.placeholder_token = placeholder_token
 
-        self.batch_size = batch_size
         self.width = width
         self.height = height
         self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@@ -45,11 +48,16 @@ class PersonalizedBase(Dataset):
         assert os.path.isdir(data_root), "Dataset directory doesn't exist"
         assert os.listdir(data_root), "Dataset directory is empty"
 
-        cond_model = shared.sd_model.cond_stage_model
-
         self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
 
+        self.shuffle_tags = shuffle_tags
+        self.tag_drop_out = tag_drop_out
+
         print("Preparing dataset...")
         for path in tqdm.tqdm(self.image_paths):
+            if shared.state.interrupted:
+                raise Exception("inturrupted")
             try:
                 image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
             except Exception:
@@ -71,37 +79,49 @@ class PersonalizedBase(Dataset):
             npimage = np.array(image).astype(np.uint8)
             npimage = (npimage / 127.5 - 1.0).astype(np.float32)
 
-            torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
-            torchdata = torch.moveaxis(torchdata, 2, 0)
-
-            init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
-            init_latent = init_latent.to(devices.cpu)
-
-            entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
-
-            if include_cond:
+            torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
+            latent_sample = None
+
+            with torch.autocast("cuda"):
+                latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
+
+            if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
+                latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
+                latent_sampling_method = "once"
+                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
+            elif latent_sampling_method == "deterministic":
+                # Works only for DiagonalGaussianDistribution
+                latent_dist.std = 0
+                latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
+                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
+            elif latent_sampling_method == "random":
+                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
+
+            if not (self.tag_drop_out != 0 or self.shuffle_tags):
                 entry.cond_text = self.create_text(filename_text)
-                entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
+
+            if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
+                with torch.autocast("cuda"):
+                    entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
 
             self.dataset.append(entry)
-
-        assert len(self.dataset) > 0, "No images have been found in the dataset."
-        self.length = len(self.dataset) * repeats // batch_size
-
-        self.dataset_length = len(self.dataset)
-        self.indexes = None
-        self.shuffle()
-
-    def shuffle(self):
-        self.indexes = np.random.permutation(self.dataset_length)
+            del torchdata
+            del latent_dist
+            del latent_sample
+
+        self.length = len(self.dataset)
+        assert self.length > 0, "No images have been found in the dataset."
+        self.batch_size = min(batch_size, self.length)
+        self.gradient_step = min(gradient_step, self.length // self.batch_size)
+        self.latent_sampling_method = latent_sampling_method
 
     def create_text(self, filename_text):
         text = random.choice(self.lines)
         text = text.replace("[name]", self.placeholder_token)
 
         tags = filename_text.split(',')
-        if shared.opts.tag_drop_out != 0:
-            tags = [t for t in tags if random.random() > shared.opts.tag_drop_out]
-        if shared.opts.shuffle_tags:
+        if self.tag_drop_out != 0:
+            tags = [t for t in tags if random.random() > self.tag_drop_out]
+        if self.shuffle_tags:
             random.shuffle(tags)
         text = text.replace("[filewords]", ','.join(tags))
         return text
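The three latent_sampling_method values differ in how a latent is drawn from the VAE encoder's posterior (a DiagonalGaussianDistribution with per-element mean and std): 'once' draws a single sample at dataset-build time and caches it, 'deterministic' zeroes the std so every draw is the posterior mean, and 'random' stores the distribution and draws a fresh sample on every __getitem__. A sketch of the distinction, using a stand-in Gaussian rather than ldm's class:

import torch

class DiagonalGaussian:
    # stand-in for ldm's DiagonalGaussianDistribution
    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def sample(self):
        return self.mean + self.std * torch.randn_like(self.mean)

dist = DiagonalGaussian(torch.zeros(4, 8, 8), torch.ones(4, 8, 8))

cached = dist.sample()     # "once": one draw, reused for every epoch
dist.std = 0
mean_only = dist.sample()  # "deterministic": std zeroed, always the mean
# "random": keep dist and call dist.sample() again for each __getitem__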
@@ -110,19 +130,43 @@ class PersonalizedBase(Dataset):
         return self.length
 
     def __getitem__(self, i):
-        res = []
-
-        for j in range(self.batch_size):
-            position = i * self.batch_size + j
-            if position % len(self.indexes) == 0:
-                self.shuffle()
-
-            index = self.indexes[position % len(self.indexes)]
-            entry = self.dataset[index]
-
-            if entry.cond is None:
-                entry.cond_text = self.create_text(entry.filename_text)
-
-            res.append(entry)
-
-        return res
+        entry = self.dataset[i]
+        if self.tag_drop_out != 0 or self.shuffle_tags:
+            entry.cond_text = self.create_text(entry.filename_text)
+        if self.latent_sampling_method == "random":
+            entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
+        return entry
+
+class PersonalizedDataLoader(DataLoader):
+    def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
+        super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
+        if latent_sampling_method == "random":
+            self.collate_fn = collate_wrapper_random
+        else:
+            self.collate_fn = collate_wrapper
+
+
+class BatchLoader:
+    def __init__(self, data):
+        self.cond_text = [entry.cond_text for entry in data]
+        self.cond = [entry.cond for entry in data]
+        self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
+        #self.emb_index = [entry.emb_index for entry in data]
+        #print(self.latent_sample.device)
+
+    def pin_memory(self):
+        self.latent_sample = self.latent_sample.pin_memory()
+        return self
+
+def collate_wrapper(batch):
+    return BatchLoader(batch)
+
+class BatchLoaderRandom(BatchLoader):
+    def __init__(self, data):
+        super().__init__(data)
+
+    def pin_memory(self):
+        return self
+
+def collate_wrapper_random(batch):
+    return BatchLoaderRandom(batch)
\ No newline at end of file
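BatchLoader relies on two DataLoader extension points: collate_fn turns the list of DatasetEntry objects into a single batch object, and when pin_memory=True the loader calls the batch's own pin_memory() method (PyTorch falls back to hasattr(batch, "pin_memory") for custom types), which is why BatchLoaderRandom can opt out; its latents are freshly sampled every epoch, so the cost of pinning each new tensor would likely outweigh the benefit. Roughly, with a generic item type:

import torch
from torch.utils.data import DataLoader, Dataset

class Items(Dataset):
    def __len__(self):
        return 16

    def __getitem__(self, i):
        return {"latent": torch.randn(4), "text": f"prompt {i}"}

class Batch:
    def __init__(self, data):                 # used directly as collate_fn
        self.text = [d["text"] for d in data]
        self.latent = torch.stack([d["latent"] for d in data])

    def pin_memory(self):                     # invoked by DataLoader when pin_memory=True
        self.latent = self.latent.pin_memory()
        return self

dl = DataLoader(Items(), batch_size=4, collate_fn=Batch, pin_memory=True)
batch = next(iter(dl))                        # a Batch object, not a dict of tensors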
modules/textual_inversion/textual_inversion.py
@@ -183,7 +183,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
     if shared.opts.training_write_csv_every == 0:
         return
 
-    if (step + 1) % shared.opts.training_write_csv_every != 0:
+    if step % shared.opts.training_write_csv_every != 0:
         return
 
     write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
@@ -193,21 +193,23 @@ def write_loss(log_directory, filename, step, epoch_len, values):
         if write_csv_header:
             csv_writer.writeheader()
 
-        epoch = step // epoch_len
-        epoch_step = step % epoch_len
+        epoch = (step - 1) // epoch_len
+        epoch_step = (step - 1) % epoch_len
 
         csv_writer.writerow({
-            "step": step + 1,
+            "step": step,
             "epoch": epoch,
-            "epoch_step": epoch_step + 1,
+            "epoch_step": epoch_step,
             **values,
         })
 
-def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
+def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
     assert model_name, f"{name} not selected"
     assert learn_rate, "Learning rate is empty or 0"
     assert isinstance(batch_size, int), "Batch size must be integer"
     assert batch_size > 0, "Batch size must be positive"
+    assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
+    assert gradient_step > 0, "Gradient accumulation step must be positive"
     assert data_root, "Dataset directory is empty"
     assert os.path.isdir(data_root), "Dataset directory doesn't exist"
     assert os.listdir(data_root), "Dataset directory is empty"
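The write_loss changes track a counting convention shift: the new training loops increment the step counter before logging, so the function now receives a 1-based step and both the modulo check and the epoch arithmetic drop their +1 adjustments. A quick check of the new arithmetic with epoch_len = 10:

epoch_len = 10
for step in (1, 10, 11, 25):
    epoch = (step - 1) // epoch_len
    epoch_step = (step - 1) % epoch_len
    print(f"step {step}: epoch {epoch}, epoch_step {epoch_step}")
# step 1: epoch 0, epoch_step 0
# step 10: epoch 0, epoch_step 9
# step 11: epoch 1, epoch_step 0
# step 25: epoch 2, epoch_step 4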
@@ -223,10 +225,10 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
     if save_model_every or create_image_every:
         assert log_directory, "Log directory is empty"
 
-def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
     save_embedding_every = save_embedding_every or 0
     create_image_every = create_image_every or 0
-    validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
+    validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
 
     shared.state.textinfo = "Initializing textual inversion training..."
     shared.state.job_count = steps
@@ -254,161 +256,200 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
     else:
         images_embeds_dir = None
 
-    cond_model = shared.sd_model.cond_stage_model
-
     hijack = sd_hijack.model_hijack
 
     embedding = hijack.embedding_db.word_embeddings[embedding_name]
     checkpoint = sd_models.select_checkpoint()
 
-    ititial_step = embedding.step or 0
-    if ititial_step >= steps:
+    initial_step = embedding.step or 0
+    if initial_step >= steps:
         shared.state.textinfo = f"Model has already been trained beyond specified max steps"
         return embedding, filename
 
-    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
-
+    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
     # dataset loading may take a while, so input validations and early returns should be done before this
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
-    with torch.autocast("cuda"):
-        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
+
+    pin_memory = shared.opts.pin_memory
+
+    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
+
+    latent_sampling_method = ds.latent_sampling_method
+
+    dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
 
     if unload:
         shared.sd_model.first_stage_model.to(devices.cpu)
 
     embedding.vec.requires_grad = True
-    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
-
-    losses = torch.zeros((32,))
+    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
+    scaler = torch.cuda.amp.GradScaler()
+
+    batch_size = ds.batch_size
+    gradient_step = ds.gradient_step
+    # n steps = batch_size * gradient_step * n image processed
+    steps_per_epoch = len(ds) // batch_size // gradient_step
+    max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
+    loss_step = 0
+    _loss_step = 0 #internal
 
     last_saved_file = "<none>"
     last_saved_image = "<none>"
     forced_filename = "<none>"
     embedding_yet_to_be_embedded = False
 
-    pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
-    for i, entries in pbar:
-        embedding.step = i + ititial_step
-
-        scheduler.apply(optimizer, embedding.step)
-        if scheduler.finished:
-            break
-
-        if shared.state.interrupted:
-            break
-
-        with torch.autocast("cuda"):
-            c = cond_model([entry.cond_text for entry in entries])
-            x = torch.stack([entry.latent for entry in entries]).to(devices.device)
-            loss = shared.sd_model(x, c)[0]
-            del x
-
-            losses[embedding.step % losses.shape[0]] = loss.item()
-
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-
-        steps_done = embedding.step + 1
-
-        epoch_num = embedding.step // len(ds)
-        epoch_step = embedding.step % len(ds)
-
-        pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
-
-        if embedding_dir is not None and steps_done % save_embedding_every == 0:
-            # Before saving, change name to match current checkpoint.
-            embedding_name_every = f'{embedding_name}-{steps_done}'
-            last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
-            save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
-            embedding_yet_to_be_embedded = True
-
-        write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
-            "loss": f"{losses.mean():.7f}",
-            "learn_rate": scheduler.learn_rate
-        })
-
-        if images_dir is not None and steps_done % create_image_every == 0:
-            forced_filename = f'{embedding_name}-{steps_done}'
-            last_saved_image = os.path.join(images_dir, forced_filename)
-
-            shared.sd_model.first_stage_model.to(devices.device)
-
-            p = processing.StableDiffusionProcessingTxt2Img(
-                sd_model=shared.sd_model,
-                do_not_save_grid=True,
-                do_not_save_samples=True,
-                do_not_reload_embeddings=True,
-            )
-
-            if preview_from_txt2img:
-                p.prompt = preview_prompt
-                p.negative_prompt = preview_negative_prompt
-                p.steps = preview_steps
-                p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
-                p.cfg_scale = preview_cfg_scale
-                p.seed = preview_seed
-                p.width = preview_width
-                p.height = preview_height
-            else:
-                p.prompt = entries[0].cond_text
-                p.steps = 20
-                p.width = training_width
-                p.height = training_height
-
-            preview_text = p.prompt
-
-            processed = processing.process_images(p)
-            image = processed.images[0]
-
-            if unload:
-                shared.sd_model.first_stage_model.to(devices.cpu)
-
-            shared.state.current_image = image
-
-            if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
-
-                last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
-
-                info = PngImagePlugin.PngInfo()
-                data = torch.load(last_saved_file)
-                info.add_text("sd-ti-embedding", embedding_to_b64(data))
-
-                title = "<{}>".format(data.get('name', '???'))
-
-                try:
-                    vectorSize = list(data['string_to_param'].values())[0].shape[0]
-                except Exception as e:
-                    vectorSize = '?'
-
-                checkpoint = sd_models.select_checkpoint()
-                footer_left = checkpoint.model_name
-                footer_mid = '[{}]'.format(checkpoint.hash)
-                footer_right = '{}v {}s'.format(vectorSize, steps_done)
-
-                captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
-                captioned_image = insert_image_data_embed(captioned_image, data)
-
-                captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
-                embedding_yet_to_be_embedded = False
-
-            last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
-            last_saved_image += f", prompt: {preview_text}"
-
-        shared.state.job_no = embedding.step
-
-        shared.state.textinfo = f"""
+    pbar = tqdm.tqdm(total=steps - initial_step)
+    try:
+        for i in range((steps-initial_step) * gradient_step):
+            if scheduler.finished:
+                break
+            if shared.state.interrupted:
+                break
+            for j, batch in enumerate(dl):
+                # works as a drop_last=True for gradient accumulation
+                if j == max_steps_per_epoch:
+                    break
+                scheduler.apply(optimizer, embedding.step)
+                if scheduler.finished:
+                    break
+                if shared.state.interrupted:
+                    break
+
+                with torch.autocast("cuda"):
+                    # c = stack_conds(batch.cond).to(devices.device)
+                    # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
+                    # print(mask)
+                    # c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
+                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+                    c = shared.sd_model.cond_stage_model(batch.cond_text)
+                    loss = shared.sd_model(x, c)[0] / gradient_step
+                    del x
+
+                    _loss_step += loss.item()
+                scaler.scale(loss).backward()
+
+                # go back until we reach gradient accumulation steps
+                if (j + 1) % gradient_step != 0:
+                    continue
+                scaler.step(optimizer)
+                scaler.update()
+                embedding.step += 1
+                pbar.update()
+                optimizer.zero_grad(set_to_none=True)
+                loss_step = _loss_step
+                _loss_step = 0
+
+                steps_done = embedding.step + 1
+
+                epoch_num = embedding.step // steps_per_epoch
+                epoch_step = embedding.step % steps_per_epoch
+
+                pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
+                if embedding_dir is not None and steps_done % save_embedding_every == 0:
+                    # Before saving, change name to match current checkpoint.
+                    embedding_name_every = f'{embedding_name}-{steps_done}'
+                    last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
+                    #if shared.opts.save_optimizer_state:
+                        #embedding.optimizer_state_dict = optimizer.state_dict()
+                    save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
+                    embedding_yet_to_be_embedded = True
+
+                write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
+                    "loss": f"{loss_step:.7f}",
+                    "learn_rate": scheduler.learn_rate
+                })
+
+                if images_dir is not None and steps_done % create_image_every == 0:
+                    forced_filename = f'{embedding_name}-{steps_done}'
+                    last_saved_image = os.path.join(images_dir, forced_filename)
+
+                    shared.sd_model.first_stage_model.to(devices.device)
+
+                    p = processing.StableDiffusionProcessingTxt2Img(
+                        sd_model=shared.sd_model,
+                        do_not_save_grid=True,
+                        do_not_save_samples=True,
+                        do_not_reload_embeddings=True,
+                    )
+
+                    if preview_from_txt2img:
+                        p.prompt = preview_prompt
+                        p.negative_prompt = preview_negative_prompt
+                        p.steps = preview_steps
+                        p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
+                        p.cfg_scale = preview_cfg_scale
+                        p.seed = preview_seed
+                        p.width = preview_width
+                        p.height = preview_height
+                    else:
+                        p.prompt = batch.cond_text[0]
+                        p.steps = 20
+                        p.width = training_width
+                        p.height = training_height
+
+                    preview_text = p.prompt
+
+                    processed = processing.process_images(p)
+                    image = processed.images[0] if len(processed.images) > 0 else None
+
+                    if unload:
+                        shared.sd_model.first_stage_model.to(devices.cpu)
+
+                    if image is not None:
+                        shared.state.current_image = image
+                        last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+                        last_saved_image += f", prompt: {preview_text}"
+
+                    if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
+
+                        last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
+
+                        info = PngImagePlugin.PngInfo()
+                        data = torch.load(last_saved_file)
+                        info.add_text("sd-ti-embedding", embedding_to_b64(data))
+
+                        title = "<{}>".format(data.get('name', '???'))
+
+                        try:
+                            vectorSize = list(data['string_to_param'].values())[0].shape[0]
+                        except Exception as e:
+                            vectorSize = '?'
+
+                        checkpoint = sd_models.select_checkpoint()
+                        footer_left = checkpoint.model_name
+                        footer_mid = '[{}]'.format(checkpoint.hash)
+                        footer_right = '{}v {}s'.format(vectorSize, steps_done)
+
+                        captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
+                        captioned_image = insert_image_data_embed(captioned_image, data)
+
+                        captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+                        embedding_yet_to_be_embedded = False
+
+                    last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+                    last_saved_image += f", prompt: {preview_text}"
+
+                shared.state.job_no = embedding.step
+
+                shared.state.textinfo = f"""
 <p>
-Loss: {losses.mean():.7f}<br/>
-Step: {embedding.step}<br/>
-Last prompt: {html.escape(entries[0].cond_text)}<br/>
+Loss: {loss_step:.7f}<br/>
+Step: {steps_done}<br/>
+Last prompt: {html.escape(batch.cond_text[0])}<br/>
 Last saved embedding: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
 """
-
-    filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
-    save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
-    shared.sd_model.first_stage_model.to(devices.device)
+        filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
+        save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
+    except Exception:
+        print(traceback.format_exc(), file=sys.stderr)
+        pass
+    finally:
+        pbar.leave = False
+        pbar.close()
+        shared.sd_model.first_stage_model.to(devices.device)
 
     return embedding, filename
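An easy-to-miss change in the hunk above: the embedding optimizer is now AdamW(..., weight_decay=0.0). AdamW applies decoupled weight decay, multiplying every parameter by (1 - lr * weight_decay) each step, and with a single trained vector that steady shrink acts directly on the only weights being learned; disabling it appears deliberate here, though the commit message does not say so. A tiny demonstration of the decay acting even with a zero gradient:

import torch

vec = torch.nn.Parameter(torch.ones(4))
opt = torch.optim.AdamW([vec], lr=0.1, weight_decay=0.1)
vec.grad = torch.zeros_like(vec)   # zero gradient: any change is pure decay
opt.step()
print(vec.data[0].item())          # ~0.99, shrunk by lr * weight_decay

vec2 = torch.nn.Parameter(torch.ones(4))
opt2 = torch.optim.AdamW([vec2], lr=0.1, weight_decay=0.0)
vec2.grad = torch.zeros_like(vec2)
opt2.step()
print(vec2.data[0].item())         # 1.0, untouched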
modules/ui.py
@@ -1240,7 +1240,7 @@ def create_ui(wrap_gradio_gpu_call):
         with gr.Column():
             with gr.Row():
                 interrupt_preprocessing = gr.Button("Interrupt")
-                run_preprocess = gr.Button(value="Preprocess", variant='primary')
+                run_preprocess = gr.Button(value="Preprocess", variant='primary')
 
     process_split.change(
         fn=lambda show: gr_show(show),
@@ -1267,6 +1267,7 @@ def create_ui(wrap_gradio_gpu_call):
             hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
 
             batch_size = gr.Number(label='Batch size', value=1, precision=0)
+            gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0)
             dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
             log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
             template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
@@ -1277,6 +1278,11 @@ def create_ui(wrap_gradio_gpu_call):
             save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
             save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
             preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
+            with gr.Row():
+                shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False)
+                tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0)
+            with gr.Row():
+                latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'])
 
             with gr.Row():
                 interrupt_training = gr.Button(value="Interrupt")
@@ -1365,11 +1371,15 @@ def create_ui(wrap_gradio_gpu_call):
             train_embedding_name,
             embedding_learn_rate,
             batch_size,
+            gradient_step,
             dataset_directory,
             log_directory,
             training_width,
             training_height,
             steps,
+            shuffle_tags,
+            tag_drop_out,
+            latent_sampling_method,
             create_image_every,
             save_embedding_every,
             template_file,
@@ -1390,11 +1400,15 @@ def create_ui(wrap_gradio_gpu_call):
             train_hypernetwork_name,
             hypernetwork_learn_rate,
             batch_size,
+            gradient_step,
             dataset_directory,
             log_directory,
             training_width,
             training_height,
             steps,
+            shuffle_tags,
+            tag_drop_out,
+            latent_sampling_method,
             create_image_every,
             save_embedding_every,
             template_file,