Commit bddebe09 authored by Shondoit's avatar Shondoit

Save Optimizer next to TI embedding

Also add check to load only .PT and .BIN files as embeddings. (since we add .optim files in the same directory)
parent c0ee1488
...@@ -355,7 +355,7 @@ options_templates.update(options_section(('system', "System"), { ...@@ -355,7 +355,7 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
......
...@@ -28,6 +28,7 @@ class Embedding: ...@@ -28,6 +28,7 @@ class Embedding:
self.cached_checksum = None self.cached_checksum = None
self.sd_checkpoint = None self.sd_checkpoint = None
self.sd_checkpoint_name = None self.sd_checkpoint_name = None
self.optimizer_state_dict = None
def save(self, filename): def save(self, filename):
embedding_data = { embedding_data = {
...@@ -41,6 +42,13 @@ class Embedding: ...@@ -41,6 +42,13 @@ class Embedding:
torch.save(embedding_data, filename) torch.save(embedding_data, filename)
if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
optimizer_saved_dict = {
'hash': self.checksum(),
'optimizer_state_dict': self.optimizer_state_dict,
}
torch.save(optimizer_saved_dict, filename + '.optim')
def checksum(self): def checksum(self):
if self.cached_checksum is not None: if self.cached_checksum is not None:
return self.cached_checksum return self.cached_checksum
...@@ -95,9 +103,10 @@ class EmbeddingDatabase: ...@@ -95,9 +103,10 @@ class EmbeddingDatabase:
self.expected_shape = self.get_expected_shape() self.expected_shape = self.get_expected_shape()
def process_file(path, filename): def process_file(path, filename):
name = os.path.splitext(filename)[0] name, ext = os.path.splitext(filename)
ext = ext.upper()
if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']: if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
embed_image = Image.open(path) embed_image = Image.open(path)
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
data = embedding_from_b64(embed_image.text['sd-ti-embedding']) data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
...@@ -105,8 +114,10 @@ class EmbeddingDatabase: ...@@ -105,8 +114,10 @@ class EmbeddingDatabase:
else: else:
data = extract_image_data_embed(embed_image) data = extract_image_data_embed(embed_image)
name = data.get('name', name) name = data.get('name', name)
else: elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu") data = torch.load(path, map_location="cpu")
else:
return
# textual inversion embeddings # textual inversion embeddings
if 'string_to_param' in data: if 'string_to_param' in data:
...@@ -300,6 +311,20 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ ...@@ -300,6 +311,20 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
embedding.vec.requires_grad = True embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
if shared.opts.save_optimizer_state:
optimizer_state_dict = None
if os.path.exists(filename + '.optim'):
optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
if embedding.checksum() == optimizer_saved_dict.get('hash', None):
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
if optimizer_state_dict is not None:
optimizer.load_state_dict(optimizer_state_dict)
print("Loaded existing optimizer from checkpoint")
else:
print("No saved optimizer exists in checkpoint")
scaler = torch.cuda.amp.GradScaler() scaler = torch.cuda.amp.GradScaler()
batch_size = ds.batch_size batch_size = ds.batch_size
...@@ -366,9 +391,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ ...@@ -366,9 +391,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
# Before saving, change name to match current checkpoint. # Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}' embedding_name_every = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt') last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
#if shared.opts.save_optimizer_state: save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
#embedding.optimizer_state_dict = optimizer.state_dict()
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
embedding_yet_to_be_embedded = True embedding_yet_to_be_embedded = True
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, { write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
...@@ -458,7 +481,7 @@ Last saved image: {html.escape(last_saved_image)}<br/> ...@@ -458,7 +481,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>
""" """
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
except Exception: except Exception:
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
pass pass
...@@ -470,7 +493,7 @@ Last saved image: {html.escape(last_saved_image)}<br/> ...@@ -470,7 +493,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
return embedding, filename return embedding, filename
def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True): def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
old_embedding_name = embedding.name old_embedding_name = embedding.name
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
...@@ -481,6 +504,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache ...@@ -481,6 +504,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache
if remove_cached_checksum: if remove_cached_checksum:
embedding.cached_checksum = None embedding.cached_checksum = None
embedding.name = embedding_name embedding.name = embedding_name
embedding.optimizer_state_dict = optimizer.state_dict()
embedding.save(filename) embedding.save(filename)
except: except:
embedding.sd_checkpoint = old_sd_checkpoint embedding.sd_checkpoint = old_sd_checkpoint
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment