Unverified commit 3d8256e4, authored by AUTOMATIC1111, committed by GitHub

Merge pull request #6017 from hitomi/master

Add memory cache for VAE weights
Parents: d81636a0 893933e0
 import torch
 import os
+import collections
 from collections import namedtuple
 from modules import shared, devices, script_callbacks
 from modules.paths import models_path
...
@@ -30,6 +31,7 @@ base_vae = None
 loaded_vae_file = None
 checkpoint_info = None

+checkpoints_loaded = collections.OrderedDict()

 def get_base_vae(model):
     if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
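Reviewer note: the new checkpoints_loaded is a collections.OrderedDict because the eviction code later in this patch relies on insertion order; popitem(last=False) always removes the entry that was inserted first. A minimal sketch of that behavior, independent of this patch (the keys are illustrative):

    import collections

    cache = collections.OrderedDict()
    cache["a.vae.pt"] = "weights-a"   # inserted first
    cache["b.vae.pt"] = "weights-b"
    cache["c.vae.pt"] = "weights-c"

    # popitem(last=False) removes the oldest insertion, i.e. FIFO eviction
    evicted_key, _ = cache.popitem(last=False)
    print(evicted_key)   # a.vae.pt
    print(list(cache))   # ['b.vae.pt', 'c.vae.pt']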
...
@@ -149,13 +151,30 @@ def load_vae(model, vae_file=None):
     global first_load, vae_dict, vae_list, loaded_vae_file
     # save_settings = False
+    cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
     if vae_file:
-        assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
-        print(f"Loading VAE weights from: {vae_file}")
-        store_base_vae(model)
-        vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
-        vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
-        _load_vae_dict(model, vae_dict_1)
+        if cache_enabled and vae_file in checkpoints_loaded:
+            # use vae checkpoint cache
+            print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
+            store_base_vae(model)
+            _load_vae_dict(model, checkpoints_loaded[vae_file])
+        else:
+            assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
+            print(f"Loading VAE weights from: {vae_file}")
+            store_base_vae(model)
+            vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
+            vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+            _load_vae_dict(model, vae_dict_1)
+            if cache_enabled:
+                # cache newly loaded vae
+                checkpoints_loaded[vae_file] = vae_dict_1.copy()
+        # clean up cache if limit is reached
+        if cache_enabled:
+            while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
+                checkpoints_loaded.popitem(last=False)  # LRU
     # If vae used is not in dict, update it
     # It will be removed on refresh though
...
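Two details of the caching branch above are easy to miss. First, the eviction threshold is sd_vae_checkpoint_cache + 1 because the weights of the currently loaded VAE also sit in checkpoints_loaded, so a limit of N allows N cached entries plus the active one. Second, nothing re-inserts or moves an entry on a cache hit, so despite the "# LRU" comment the eviction order is strictly first-in-first-out. A self-contained sketch of the same pattern (the function and variable names here are illustrative, not part of the patch):

    import collections

    CACHE_LIMIT = 2  # stand-in for shared.opts.sd_vae_checkpoint_cache
    checkpoints_loaded = collections.OrderedDict()

    def load_weights_cached(path, loader):
        # Return weights for path, consulting the in-memory cache first.
        if CACHE_LIMIT > 0 and path in checkpoints_loaded:
            return checkpoints_loaded[path]  # hit: entry is not reordered (FIFO, not LRU)
        weights = loader(path)
        if CACHE_LIMIT > 0:
            checkpoints_loaded[path] = weights
            # +1 because the currently active weights also occupy a slot
            while len(checkpoints_loaded) > CACHE_LIMIT + 1:
                checkpoints_loaded.popitem(last=False)  # evict the oldest entry
        return weights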
@@ -356,6 +356,7 @@ options_templates.update(options_section(('training', "Training"), {
 options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
     "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+    "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
     "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
     "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
     "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
...
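For context on how the new slider is consumed: each OptionInfo registered in options_templates becomes readable as an attribute of shared.opts under its key, which is how load_vae picks the limit up. A minimal sketch of the consuming side (the attribute access is taken from the patch itself):

    from modules import shared  # same module imported at the top of the VAE file

    limit = shared.opts.sd_vae_checkpoint_cache  # integer from the slider, 0..10
    cache_enabled = limit > 0                    # 0 (the default) disables the cache

With the default value of 0, behavior is unchanged: cache_enabled stays False and every VAE switch re-reads the weights from disk.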