Unverified Commit 99da2c5a authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge pull request #6528 from PlasmaPower/vae-safetensors

Add support for loading VAEs from safetensors files
parents dd21af06 cb255fae
import torch import torch
import safetensors.torch
import os import os
import collections import collections
from collections import namedtuple from collections import namedtuple
...@@ -72,8 +73,10 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): ...@@ -72,8 +73,10 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path):
candidates = [ candidates = [
*glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True), *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
*glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True), *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
*glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True),
*glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True), *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
*glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True) *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True),
*glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True),
] ]
if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path): if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
candidates.append(shared.cmd_opts.vae_path) candidates.append(shared.cmd_opts.vae_path)
...@@ -137,6 +140,12 @@ def resolve_vae(checkpoint_file=None, vae_file="auto"): ...@@ -137,6 +140,12 @@ def resolve_vae(checkpoint_file=None, vae_file="auto"):
if os.path.isfile(vae_file_try): if os.path.isfile(vae_file_try):
vae_file = vae_file_try vae_file = vae_file_try
print(f"Using VAE found similar to selected model: {vae_file}") print(f"Using VAE found similar to selected model: {vae_file}")
# if still not found, try look for ".vae.safetensors" beside model
if vae_file == "auto":
vae_file_try = model_path + ".vae.safetensors"
if os.path.isfile(vae_file_try):
vae_file = vae_file_try
print(f"Using VAE found similar to selected model: {vae_file}")
# No more fallbacks for auto # No more fallbacks for auto
if vae_file == "auto": if vae_file == "auto":
vae_file = None vae_file = None
...@@ -163,8 +172,14 @@ def load_vae(model, vae_file=None): ...@@ -163,8 +172,14 @@ def load_vae(model, vae_file=None):
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
print(f"Loading VAE weights from: {vae_file}") print(f"Loading VAE weights from: {vae_file}")
store_base_vae(model) store_base_vae(model)
_, extension = os.path.splitext(vae_file)
if extension.lower() == ".safetensors":
vae_ckpt = safetensors.torch.load_file(vae_file, device=shared.weight_load_location)
else:
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} if "state_dict" in vae_ckpt:
vae_ckpt = vae_ckpt["state_dict"]
vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
_load_vae_dict(model, vae_dict_1) _load_vae_dict(model, vae_dict_1)
if cache_enabled: if cache_enabled:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment