Commit 29e74d6e authored by Melan

Add support for Tensorboard for training embeddings

parent 7f8ab1ee
modules/shared.py
...
@@ -254,6 +254,10 @@ options_templates.update(options_section(('training', "Training"), {
     "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
     "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
     "training_write_csv_every": OptionInfo(500, "Save a csv containing the loss to the log directory every N steps, 0 to disable"),
+    "training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
+    "training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
+    "training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
 }))
 
 options_templates.update(options_section(('sd', "Stable Diffusion"), {
...
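
Entries added to options_templates surface in the web UI's settings (here under the Training section) and are read back at runtime as attributes on shared.opts; that is how the training code in the next file gates all of its TensorBoard work. A minimal sketch of the consuming side, using only names from this diff:

    from modules import shared

    if shared.opts.training_enable_tensorboard:
        flush_secs = shared.opts.training_tensorboard_flush_every  # 120 unless changed in settings
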
modules/textual_inversion/textual_inversion.py
...
@@ -7,9 +7,11 @@ import tqdm
 import html
 import datetime
 import csv
+import numpy as np
+import torchvision.transforms
 
 from PIL import Image, PngImagePlugin
+from torch.utils.tensorboard import SummaryWriter
 
 from modules import shared, devices, sd_hijack, processing, sd_models
 import modules.textual_inversion.dataset
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
...
@@ -199,6 +201,19 @@ def write_loss(log_directory, filename, step, epoch_len, values):
             **values,
         })
 
+def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
+    if shared.opts.training_enable_tensorboard:
+        tensorboard_writer.add_scalar(tag=tag, scalar_value=value, global_step=step)
+
+def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
+    if shared.opts.training_enable_tensorboard:
+        # Convert the PIL image to a torch tensor (HWC uint8, permuted to CHW as add_image expects)
+        img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
+        img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], len(pil_image.getbands()))
+        img_tensor = img_tensor.permute((2, 0, 1))
+
+        tensorboard_writer.add_image(tag, img_tensor, global_step=step)
+
 def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
     assert embedding_name, 'embedding not selected'
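
As an aside, torchvision (imported by this commit) can perform the same PIL-to-tensor conversion in a single call, and SummaryWriter.add_image also accepts HWC arrays directly via its dataformats='HWC' argument. A sketch of the one-liner, not what this commit ships:

    import torchvision.transforms.functional as F
    from PIL import Image

    pil_image = Image.new("RGB", (512, 512))  # stand-in image for illustration
    img_tensor = F.pil_to_tensor(pil_image)   # uint8 tensor, shape (C, H, W)
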
...
@@ -252,6 +267,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
     scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
     optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
 
+    if shared.opts.training_enable_tensorboard:
+        os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
+        tensorboard_writer = SummaryWriter(
+            log_dir=os.path.join(log_directory, "tensorboard"),
+            flush_secs=shared.opts.training_tensorboard_flush_every)
+
     pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
     for i, entries in pbar:
         embedding.step = i + ititial_step
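
Because the writer is created with flush_secs set from training_tensorboard_flush_every, pending events reach disk at most that many seconds after they are logged, so a TensorBoard instance can watch the run live, e.g.:

    tensorboard --logdir <log_directory>/tensorboard

where <log_directory> is the directory passed to train_embedding above.
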
...
@@ -271,6 +292,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
         losses[embedding.step % losses.shape[0]] = loss.item()
 
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
...
@@ -285,6 +307,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
             embedding.save(last_saved_file)
             embedding_yet_to_be_embedded = True
 
+        if shared.opts.training_enable_tensorboard:
+            tensorboard_add_scaler(tensorboard_writer, "Loss/train", losses.mean(), embedding.step)
+            tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", losses.mean(), epoch_step)
+            tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", scheduler.learn_rate, embedding.step)
+            tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", scheduler.learn_rate, epoch_step)
+
         write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
             "loss": f"{losses.mean():.7f}",
             "learn_rate": scheduler.learn_rate
...
@@ -349,6 +377,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
                 embedding_yet_to_be_embedded = False
 
             image.save(last_saved_image)
 
+            if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+                tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
+
             last_saved_image += f", prompt: {preview_text}"
...