stable-diffusion-webui, commit 63781563
Authored Nov 20, 2022 by Tim Patton
Parent: ac7ecd2d

Generalize SD torch load/save to implement safetensor merging compat
Showing 3 changed files with 1840 additions and 1826 deletions:

    modules/extras.py       +8     -7
    modules/sd_models.py    +18    -7
    modules/ui.py           +1814  -1812
modules/extras.py
@@ -249,7 +249,7 @@ def run_pnginfo(image):
     return '', geninfo, info
 
 
-def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
+def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, save_as_safetensors, custom_name):
     def weighted_sum(theta0, theta1, alpha):
         return ((1 - alpha) * theta0) + (alpha * theta1)
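In the hunk above, weighted_sum is plain linear interpolation between corresponding checkpoint tensors. A minimal standalone sketch of the same formula (the tensor values and the "model.weight" key are made up for illustration):

    import torch

    def weighted_sum(theta0, theta1, alpha):
        # alpha = 0 keeps theta0 unchanged; alpha = 1 returns theta1.
        return ((1 - alpha) * theta0) + (alpha * theta1)

    # Hypothetical state dicts sharing a single key.
    theta_0 = {"model.weight": torch.tensor([1.0, 2.0])}
    theta_1 = {"model.weight": torch.tensor([3.0, 6.0])}

    merged = {k: weighted_sum(theta_0[k], theta_1[k], alpha=0.3) for k in theta_0}
    print(merged["model.weight"])  # tensor([1.6000, 3.2000])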
@@ -264,16 +264,16 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
     teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
 
     print(f"Loading {primary_model_info.filename}...")
-    primary_model = torch.load(primary_model_info.filename, map_location='cpu')
+    primary_model = sd_models.torch_load(primary_model_info.filename, primary_model_info, map_override='cpu')
     theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
 
     print(f"Loading {secondary_model_info.filename}...")
-    secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
+    secondary_model = sd_models.torch_load(secondary_model_info.filename, primary_model_info, map_override='cpu')
     theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
 
     if teritary_model_info is not None:
         print(f"Loading {teritary_model_info.filename}...")
-        teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
+        teritary_model = sd_models.torch_load(teritary_model_info.filename, teritary_model_info, map_override='cpu')
         theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
     else:
         teritary_model = None
@@ -314,12 +314,13 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
     ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
 
-    filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
-    filename = filename if custom_name == '' else (custom_name + '.ckpt')
+    output_exttype = '.safetensors' if save_as_safetensors else '.ckpt'
+    filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged' + output_exttype
+    filename = filename if custom_name == '' else (custom_name + output_exttype)
     output_modelname = os.path.join(ckpt_dir, filename)
 
     print(f"Saving to {output_modelname}...")
-    torch.save(primary_model, output_modelname)
+    sd_models.torch_save(primary_model, output_modelname)
 
     sd_models.list_models()
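To make the new naming concrete: with multiplier 0.3, interpolation method "Weighted sum", no custom name, and save_as_safetensors enabled, the construction above yields a name like the following (the model names modelA and modelB are made up for illustration):

    multiplier = 0.3
    save_as_safetensors = True
    interp_method = "Weighted sum"

    output_exttype = '.safetensors' if save_as_safetensors else '.ckpt'
    filename = ("modelA" + '_' + str(round(1-multiplier, 2)) + '-' +
                "modelB" + '_' + str(round(multiplier, 2)) + '-' +
                interp_method.replace(" ", "_") + '-merged' + output_exttype)
    print(filename)  # modelA_0.7-modelB_0.3-Weighted_sum-merged.safetensors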
modules/sd_models.py
@@ -4,7 +4,7 @@ import sys
 import gc
 from collections import namedtuple
 import torch
-from safetensors.torch import load_file
+from safetensors.torch import load_file, save_file
 import re
 from omegaconf import OmegaConf
@@ -143,6 +143,22 @@ def transform_checkpoint_dict_key(k):
     return k
 
 
+def torch_load(model_filename, model_info, map_override=None):
+    map_override = shared.weight_load_location if not map_override else map_override
+    if(checkpoint_types[model_info.exttype] == 'safetensors'):
+        # safely load weights
+        # TODO: safetensors supports zero copy fast load to gpu, see issue #684
+        return load_file(model_filename, device=map_override)
+    else:
+        return torch.load(model_filename, map_location=map_override)
+
+def torch_save(model, output_filename):
+    basename, exttype = os.path.splitext(output_filename)
+    if(checkpoint_types[exttype] == 'safetensors'):
+        # [===== >] Reticulating brines...
+        save_file(model, output_filename, metadata={"format": "pt"})
+    else:
+        torch.save(model, output_filename)
+
 def get_state_dict_from_checkpoint(pl_sd):
     if "state_dict" in pl_sd:
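Both new helpers dispatch on checkpoint_types, which is defined elsewhere in modules/sd_models.py and does not appear in this diff. A plausible shape for that mapping, assuming it keys loader formats by file extension (this definition is an assumption, not code from the commit):

    # Assumed definition; the real one lives elsewhere in modules/sd_models.py.
    checkpoint_types = {
        '.ckpt': 'torch',               # pickled PyTorch checkpoint
        '.safetensors': 'safetensors',  # safetensors tensor archive
    }

With a mapping like this, torch_load routes .safetensors files to safetensors.torch.load_file, which reads raw tensors without unpickling arbitrary Python objects, while .ckpt files still go through torch.load.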
@@ -175,12 +191,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
         # load from file
         print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
 
-        if(checkpoint_types[checkpoint_info.exttype] == 'safetensors'):
-            # safely load weights
-            # TODO: safetensors supports zero copy fast load to gpu, see issue #684
-            pl_sd = load_file(checkpoint_file, device=shared.weight_load_location)
-        else:
-            pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
+        pl_sd = torch_load(checkpoint_file, checkpoint_info)
 
         if "global_step" in pl_sd:
             print(f"Global Step: {pl_sd['global_step']}")
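On the TODO carried through this refactor: safetensors can materialize tensors directly on a target device, so a future change could skip the CPU round trip. A minimal sketch, assuming a CUDA device is available (the file name and device string are illustrative, not part of this commit):

    from safetensors.torch import load_file

    # safetensors memory-maps the file and places each tensor on the
    # requested device as it is read, avoiding an intermediate CPU copy.
    state_dict = load_file("model.safetensors", device="cuda:0")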
modules/ui.py

This source diff could not be displayed because it is too large; the full file can be viewed as a blob instead.