Commit 8bcdd504 authored by wywywywy's avatar wywywywy

Add safetensors support to LDSR

parent 685f9631
import os
import gc import gc
import time import time
import warnings import warnings
...@@ -8,6 +9,7 @@ import torchvision ...@@ -8,6 +9,7 @@ import torchvision
from PIL import Image from PIL import Image
from einops import rearrange, repeat from einops import rearrange, repeat
from omegaconf import OmegaConf from omegaconf import OmegaConf
import safetensors.torch
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap from ldm.util import instantiate_from_config, ismap
...@@ -28,8 +30,12 @@ class LDSR: ...@@ -28,8 +30,12 @@ class LDSR:
model: torch.nn.Module = cached_ldsr_model model: torch.nn.Module = cached_ldsr_model
else: else:
print(f"Loading model from {self.modelPath}") print(f"Loading model from {self.modelPath}")
pl_sd = torch.load(self.modelPath, map_location="cpu") _, extension = os.path.splitext(self.modelPath)
sd = pl_sd["state_dict"] if extension.lower() == ".safetensors":
pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
else:
pl_sd = torch.load(self.modelPath, map_location="cpu")
sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
config = OmegaConf.load(self.yamlPath) config = OmegaConf.load(self.yamlPath)
config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1" config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
model: torch.nn.Module = instantiate_from_config(config.model) model: torch.nn.Module = instantiate_from_config(config.model)
......
...@@ -25,6 +25,7 @@ class UpscalerLDSR(Upscaler): ...@@ -25,6 +25,7 @@ class UpscalerLDSR(Upscaler):
yaml_path = os.path.join(self.model_path, "project.yaml") yaml_path = os.path.join(self.model_path, "project.yaml")
old_model_path = os.path.join(self.model_path, "model.pth") old_model_path = os.path.join(self.model_path, "model.pth")
new_model_path = os.path.join(self.model_path, "model.ckpt") new_model_path = os.path.join(self.model_path, "model.ckpt")
safetensors_model_path = os.path.join(self.model_path, "model.safetensors")
if os.path.exists(yaml_path): if os.path.exists(yaml_path):
statinfo = os.stat(yaml_path) statinfo = os.stat(yaml_path)
if statinfo.st_size >= 10485760: if statinfo.st_size >= 10485760:
...@@ -33,8 +34,11 @@ class UpscalerLDSR(Upscaler): ...@@ -33,8 +34,11 @@ class UpscalerLDSR(Upscaler):
if os.path.exists(old_model_path): if os.path.exists(old_model_path):
print("Renaming model from model.pth to model.ckpt") print("Renaming model from model.pth to model.ckpt")
os.rename(old_model_path, new_model_path) os.rename(old_model_path, new_model_path)
model = load_file_from_url(url=self.model_url, model_dir=self.model_path, if os.path.exists(safetensors_model_path):
file_name="model.ckpt", progress=True) model = safetensors_model_path
else:
model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
file_name="model.ckpt", progress=True)
yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path, yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
file_name="project.yaml", progress=True) file_name="project.yaml", progress=True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment