Commit ac90cf38 authored by Tim Patton's avatar Tim Patton

safetensors optional for now

parent 210cb4c1
...@@ -4,7 +4,6 @@ import sys ...@@ -4,7 +4,6 @@ import sys
import gc import gc
from collections import namedtuple from collections import namedtuple
import torch import torch
from safetensors.torch import load_file, save_file
import re import re
from omegaconf import OmegaConf from omegaconf import OmegaConf
...@@ -149,6 +148,10 @@ def torch_load(model_filename, model_info, map_override=None): ...@@ -149,6 +148,10 @@ def torch_load(model_filename, model_info, map_override=None):
# safely load weights # safely load weights
# TODO: safetensors supports zero copy fast load to gpu, see issue #684. # TODO: safetensors supports zero copy fast load to gpu, see issue #684.
# GPU only for now, see https://github.com/huggingface/safetensors/issues/95 # GPU only for now, see https://github.com/huggingface/safetensors/issues/95
try:
from safetensors.torch import load_file
except ImportError as e:
raise ImportError(f"The model is in safetensors format and it is not installed, use `pip install safetensors`: {e}")
return load_file(model_filename, device='cuda') return load_file(model_filename, device='cuda')
else: else:
return torch.load(model_filename, map_location=map_override) return torch.load(model_filename, map_location=map_override)
...@@ -157,6 +160,10 @@ def torch_save(model, output_filename): ...@@ -157,6 +160,10 @@ def torch_save(model, output_filename):
basename, exttype = os.path.splitext(output_filename) basename, exttype = os.path.splitext(output_filename)
if(checkpoint_types[exttype] == 'safetensors'): if(checkpoint_types[exttype] == 'safetensors'):
# [===== >] Reticulating brines... # [===== >] Reticulating brines...
try:
from safetensors.torch import save_file
except ImportError as e:
raise ImportError(f"Export as safetensors selected, yet it is not installed, use `pip install safetensors`: {e}")
save_file(model, output_filename, metadata={"format": "pt"}) save_file(model, output_filename, metadata={"format": "pt"})
else: else:
torch.save(model, output_filename) torch.save(model, output_filename)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment