Commit 6074175f authored by AUTOMATIC's avatar AUTOMATIC

add safetensors to requirements

parent f108782e
...@@ -5,6 +5,7 @@ import gc ...@@ -5,6 +5,7 @@ import gc
from collections import namedtuple from collections import namedtuple
import torch import torch
import re import re
import safetensors.torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
...@@ -173,14 +174,12 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): ...@@ -173,14 +174,12 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
# load from file # load from file
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
if checkpoint_file.endswith(".safetensors"): _, extension = os.path.splitext(checkpoint_file)
try: if extension.lower() == ".safetensors":
from safetensors.torch import load_file pl_sd = safetensors.torch.load_file(checkpoint_file, device=shared.weight_load_location)
except ImportError as e:
raise ImportError(f"The model is in safetensors format and it is not installed, use `pip install safetensors`: {e}")
pl_sd = load_file(checkpoint_file, device=shared.weight_load_location)
else: else:
pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
if "global_step" in pl_sd: if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}") print(f"Global Step: {pl_sd['global_step']}")
......
...@@ -26,3 +26,4 @@ lark==1.1.2 ...@@ -26,3 +26,4 @@ lark==1.1.2
inflection==0.5.1 inflection==0.5.1
GitPython==3.1.27 GitPython==3.1.27
torchsde==0.2.5 torchsde==0.2.5
safetensors==0.2.5
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment