Commit bd68e35d authored by flamelaw's avatar flamelaw

Gradient accumulation, autocast fix, new latent sampling method, etc

parent 47a44c7e
This diff is collapsed.
...@@ -8,7 +8,7 @@ from torch import einsum ...@@ -8,7 +8,7 @@ from torch import einsum
from torch.nn.functional import silu from torch.nn.functional import silu
import modules.textual_inversion.textual_inversion import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
from modules.shared import opts, device, cmd_opts from modules.shared import opts, device, cmd_opts
from modules.sd_hijack_optimizations import invokeAI_mps_available from modules.sd_hijack_optimizations import invokeAI_mps_available
...@@ -59,6 +59,10 @@ def undo_optimizations(): ...@@ -59,6 +59,10 @@ def undo_optimizations():
def get_target_prompt_token_count(token_count): def get_target_prompt_token_count(token_count):
return math.ceil(max(token_count, 1) / 75) * 75 return math.ceil(max(token_count, 1) / 75) * 75
def fix_checkpoint():
ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
class StableDiffusionModelHijack: class StableDiffusionModelHijack:
fixes = None fixes = None
...@@ -78,6 +82,7 @@ class StableDiffusionModelHijack: ...@@ -78,6 +82,7 @@ class StableDiffusionModelHijack:
self.clip = m.cond_stage_model self.clip = m.cond_stage_model
apply_optimizations() apply_optimizations()
fix_checkpoint()
def flatten(el): def flatten(el):
flattened = [flatten(children) for children in el.children()] flattened = [flatten(children) for children in el.children()]
...@@ -303,7 +308,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ...@@ -303,7 +308,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
else: else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
self.hijack.comments += hijack_comments self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0: if len(used_custom_terms) > 0:
......
from torch.utils.checkpoint import checkpoint
def BasicTransformerBlock_forward(self, x, context=None):
return checkpoint(self._forward, x, context)
def AttentionBlock_forward(self, x):
return checkpoint(self._forward, x)
def ResBlock_forward(self, x, emb):
return checkpoint(self._forward, x, emb)
\ No newline at end of file
...@@ -322,8 +322,7 @@ options_templates.update(options_section(('system', "System"), { ...@@ -322,8 +322,7 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"shuffle_tags": OptionInfo(False, "Shuffleing tags by ',' when create texts."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
"tag_drop_out": OptionInfo(0, "Dropout tags when create texts", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.1}),
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
......
...@@ -3,7 +3,7 @@ import numpy as np ...@@ -3,7 +3,7 @@ import numpy as np
import PIL import PIL
import torch import torch
from PIL import Image from PIL import Image
from torch.utils.data import Dataset from torch.utils.data import Dataset, DataLoader
from torchvision import transforms from torchvision import transforms
import random import random
...@@ -11,25 +11,28 @@ import tqdm ...@@ -11,25 +11,28 @@ import tqdm
from modules import devices, shared from modules import devices, shared
import re import re
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
re_numbers_at_start = re.compile(r"^[-\d]+\s*") re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class DatasetEntry: class DatasetEntry:
def __init__(self, filename=None, latent=None, filename_text=None): def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
self.filename = filename self.filename = filename
self.latent = latent
self.filename_text = filename_text self.filename_text = filename_text
self.cond = None self.latent_dist = latent_dist
self.cond_text = None self.latent_sample = latent_sample
self.cond = cond
self.cond_text = cond_text
self.pixel_values = pixel_values
class PersonalizedBase(Dataset): class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1): def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token self.placeholder_token = placeholder_token
self.batch_size = batch_size
self.width = width self.width = width
self.height = height self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.flip = transforms.RandomHorizontalFlip(p=flip_p)
...@@ -45,11 +48,16 @@ class PersonalizedBase(Dataset): ...@@ -45,11 +48,16 @@ class PersonalizedBase(Dataset):
assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty" assert os.listdir(data_root), "Dataset directory is empty"
cond_model = shared.sd_model.cond_stage_model
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
self.shuffle_tags = shuffle_tags
self.tag_drop_out = tag_drop_out
print("Preparing dataset...") print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths): for path in tqdm.tqdm(self.image_paths):
if shared.state.interrupted:
raise Exception("inturrupted")
try: try:
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
except Exception: except Exception:
...@@ -71,37 +79,58 @@ class PersonalizedBase(Dataset): ...@@ -71,37 +79,58 @@ class PersonalizedBase(Dataset):
npimage = np.array(image).astype(np.uint8) npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32) npimage = (npimage / 127.5 - 1.0).astype(np.float32)
torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32) torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
torchdata = torch.moveaxis(torchdata, 2, 0) latent_sample = None
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() with torch.autocast("cuda"):
init_latent = init_latent.to(devices.cpu) latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent) if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
if include_cond: latent_sampling_method = "once"
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
elif latent_sampling_method == "deterministic":
# Works only for DiagonalGaussianDistribution
latent_dist.std = 0
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
elif latent_sampling_method == "random":
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
if not (self.tag_drop_out != 0 or self.shuffle_tags):
entry.cond_text = self.create_text(filename_text) entry.cond_text = self.create_text(filename_text)
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
self.dataset.append(entry) if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
with torch.autocast("cuda"):
assert len(self.dataset) > 0, "No images have been found in the dataset." entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
self.length = len(self.dataset) * repeats // batch_size # elif not include_cond:
# _, _, _, _, hijack_fixes, token_count = cond_model.process_text([entry.cond_text])
# max_n = token_count // 75
# index_list = [ [] for _ in range(max_n + 1) ]
# for n, (z, _) in hijack_fixes[0]:
# index_list[n].append(z)
# with torch.autocast("cuda"):
# entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
# entry.emb_index = index_list
self.dataset_length = len(self.dataset) self.dataset.append(entry)
self.indexes = None del torchdata
self.shuffle() del latent_dist
del latent_sample
def shuffle(self): self.length = len(self.dataset)
self.indexes = np.random.permutation(self.dataset_length) assert self.length > 0, "No images have been found in the dataset."
self.batch_size = min(batch_size, self.length)
self.gradient_step = min(gradient_step, self.length // self.batch_size)
self.latent_sampling_method = latent_sampling_method
def create_text(self, filename_text): def create_text(self, filename_text):
text = random.choice(self.lines) text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token) text = text.replace("[name]", self.placeholder_token)
tags = filename_text.split(',') tags = filename_text.split(',')
if shared.opts.tag_drop_out != 0: if self.tag_drop_out != 0:
tags = [t for t in tags if random.random() > shared.opts.tag_drop_out] tags = [t for t in tags if random.random() > self.tag_drop_out]
if shared.opts.shuffle_tags: if self.shuffle_tags:
random.shuffle(tags) random.shuffle(tags)
text = text.replace("[filewords]", ','.join(tags)) text = text.replace("[filewords]", ','.join(tags))
return text return text
...@@ -110,19 +139,28 @@ class PersonalizedBase(Dataset): ...@@ -110,19 +139,28 @@ class PersonalizedBase(Dataset):
return self.length return self.length
def __getitem__(self, i): def __getitem__(self, i):
res = [] entry = self.dataset[i]
if self.tag_drop_out != 0 or self.shuffle_tags:
for j in range(self.batch_size): entry.cond_text = self.create_text(entry.filename_text)
position = i * self.batch_size + j if self.latent_sampling_method == "random":
if position % len(self.indexes) == 0: entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist)
self.shuffle() return entry
index = self.indexes[position % len(self.indexes)] class PersonalizedDataLoader(DataLoader):
entry = self.dataset[index] def __init__(self, *args, **kwargs):
super(PersonalizedDataLoader, self).__init__(shuffle=True, drop_last=True, *args, **kwargs)
if entry.cond is None: self.collate_fn = collate_wrapper
entry.cond_text = self.create_text(entry.filename_text)
res.append(entry) class BatchLoader:
def __init__(self, data):
return res self.cond_text = [entry.cond_text for entry in data]
self.cond = [entry.cond for entry in data]
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
def pin_memory(self):
self.latent_sample = self.latent_sample.pin_memory()
return self
def collate_wrapper(batch):
return BatchLoader(batch)
\ No newline at end of file
...@@ -1262,7 +1262,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1262,7 +1262,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column(): with gr.Column():
with gr.Row(): with gr.Row():
interrupt_preprocessing = gr.Button("Interrupt") interrupt_preprocessing = gr.Button("Interrupt")
run_preprocess = gr.Button(value="Preprocess", variant='primary') run_preprocess = gr.Button(value="Preprocess", variant='primary')
process_split.change( process_split.change(
fn=lambda show: gr_show(show), fn=lambda show: gr_show(show),
...@@ -1289,6 +1289,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1289,6 +1289,7 @@ def create_ui(wrap_gradio_gpu_call):
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
batch_size = gr.Number(label='Batch size', value=1, precision=0) batch_size = gr.Number(label='Batch size', value=1, precision=0)
gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
...@@ -1299,6 +1300,11 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1299,6 +1300,11 @@ def create_ui(wrap_gradio_gpu_call):
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
with gr.Row():
shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False)
tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0)
with gr.Row():
latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'])
with gr.Row(): with gr.Row():
interrupt_training = gr.Button(value="Interrupt") interrupt_training = gr.Button(value="Interrupt")
...@@ -1387,11 +1393,15 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1387,11 +1393,15 @@ def create_ui(wrap_gradio_gpu_call):
train_embedding_name, train_embedding_name,
embedding_learn_rate, embedding_learn_rate,
batch_size, batch_size,
gradient_step,
dataset_directory, dataset_directory,
log_directory, log_directory,
training_width, training_width,
training_height, training_height,
steps, steps,
shuffle_tags,
tag_drop_out,
latent_sampling_method,
create_image_every, create_image_every,
save_embedding_every, save_embedding_every,
template_file, template_file,
...@@ -1412,11 +1422,15 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1412,11 +1422,15 @@ def create_ui(wrap_gradio_gpu_call):
train_hypernetwork_name, train_hypernetwork_name,
hypernetwork_learn_rate, hypernetwork_learn_rate,
batch_size, batch_size,
gradient_step,
dataset_directory, dataset_directory,
log_directory, log_directory,
training_width, training_width,
training_height, training_height,
steps, steps,
shuffle_tags,
tag_drop_out,
latent_sampling_method,
create_image_every, create_image_every,
save_embedding_every, save_embedding_every,
template_file, template_file,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment