Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
8bcdd504
Commit
8bcdd504
authored
Dec 10, 2022
by
wywywywy
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add safetensors support to LDSR
parent
685f9631
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
4 deletions
+14
-4
ldsr_model_arch.py
extensions-builtin/LDSR/ldsr_model_arch.py
+8
-2
ldsr_model.py
extensions-builtin/LDSR/scripts/ldsr_model.py
+6
-2
No files found.
extensions-builtin/LDSR/ldsr_model_arch.py
View file @
8bcdd504
import
os
import
gc
import
gc
import
time
import
time
import
warnings
import
warnings
...
@@ -8,6 +9,7 @@ import torchvision
...
@@ -8,6 +9,7 @@ import torchvision
from
PIL
import
Image
from
PIL
import
Image
from
einops
import
rearrange
,
repeat
from
einops
import
rearrange
,
repeat
from
omegaconf
import
OmegaConf
from
omegaconf
import
OmegaConf
import
safetensors.torch
from
ldm.models.diffusion.ddim
import
DDIMSampler
from
ldm.models.diffusion.ddim
import
DDIMSampler
from
ldm.util
import
instantiate_from_config
,
ismap
from
ldm.util
import
instantiate_from_config
,
ismap
...
@@ -28,8 +30,12 @@ class LDSR:
...
@@ -28,8 +30,12 @@ class LDSR:
model
:
torch
.
nn
.
Module
=
cached_ldsr_model
model
:
torch
.
nn
.
Module
=
cached_ldsr_model
else
:
else
:
print
(
f
"Loading model from {self.modelPath}"
)
print
(
f
"Loading model from {self.modelPath}"
)
pl_sd
=
torch
.
load
(
self
.
modelPath
,
map_location
=
"cpu"
)
_
,
extension
=
os
.
path
.
splitext
(
self
.
modelPath
)
sd
=
pl_sd
[
"state_dict"
]
if
extension
.
lower
()
==
".safetensors"
:
pl_sd
=
safetensors
.
torch
.
load_file
(
self
.
modelPath
,
device
=
"cpu"
)
else
:
pl_sd
=
torch
.
load
(
self
.
modelPath
,
map_location
=
"cpu"
)
sd
=
pl_sd
[
"state_dict"
]
if
"state_dict"
in
pl_sd
else
pl_sd
config
=
OmegaConf
.
load
(
self
.
yamlPath
)
config
=
OmegaConf
.
load
(
self
.
yamlPath
)
config
.
model
.
target
=
"ldm.models.diffusion.ddpm.LatentDiffusionV1"
config
.
model
.
target
=
"ldm.models.diffusion.ddpm.LatentDiffusionV1"
model
:
torch
.
nn
.
Module
=
instantiate_from_config
(
config
.
model
)
model
:
torch
.
nn
.
Module
=
instantiate_from_config
(
config
.
model
)
...
...
extensions-builtin/LDSR/scripts/ldsr_model.py
View file @
8bcdd504
...
@@ -25,6 +25,7 @@ class UpscalerLDSR(Upscaler):
...
@@ -25,6 +25,7 @@ class UpscalerLDSR(Upscaler):
yaml_path
=
os
.
path
.
join
(
self
.
model_path
,
"project.yaml"
)
yaml_path
=
os
.
path
.
join
(
self
.
model_path
,
"project.yaml"
)
old_model_path
=
os
.
path
.
join
(
self
.
model_path
,
"model.pth"
)
old_model_path
=
os
.
path
.
join
(
self
.
model_path
,
"model.pth"
)
new_model_path
=
os
.
path
.
join
(
self
.
model_path
,
"model.ckpt"
)
new_model_path
=
os
.
path
.
join
(
self
.
model_path
,
"model.ckpt"
)
safetensors_model_path
=
os
.
path
.
join
(
self
.
model_path
,
"model.safetensors"
)
if
os
.
path
.
exists
(
yaml_path
):
if
os
.
path
.
exists
(
yaml_path
):
statinfo
=
os
.
stat
(
yaml_path
)
statinfo
=
os
.
stat
(
yaml_path
)
if
statinfo
.
st_size
>=
10485760
:
if
statinfo
.
st_size
>=
10485760
:
...
@@ -33,8 +34,11 @@ class UpscalerLDSR(Upscaler):
...
@@ -33,8 +34,11 @@ class UpscalerLDSR(Upscaler):
if
os
.
path
.
exists
(
old_model_path
):
if
os
.
path
.
exists
(
old_model_path
):
print
(
"Renaming model from model.pth to model.ckpt"
)
print
(
"Renaming model from model.pth to model.ckpt"
)
os
.
rename
(
old_model_path
,
new_model_path
)
os
.
rename
(
old_model_path
,
new_model_path
)
model
=
load_file_from_url
(
url
=
self
.
model_url
,
model_dir
=
self
.
model_path
,
if
os
.
path
.
exists
(
safetensors_model_path
):
file_name
=
"model.ckpt"
,
progress
=
True
)
model
=
safetensors_model_path
else
:
model
=
load_file_from_url
(
url
=
self
.
model_url
,
model_dir
=
self
.
model_path
,
file_name
=
"model.ckpt"
,
progress
=
True
)
yaml
=
load_file_from_url
(
url
=
self
.
yaml_url
,
model_dir
=
self
.
model_path
,
yaml
=
load_file_from_url
(
url
=
self
.
yaml_url
,
model_dir
=
self
.
model_path
,
file_name
=
"project.yaml"
,
progress
=
True
)
file_name
=
"project.yaml"
,
progress
=
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment