Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
dac9b6f1
Commit
dac9b6f1
authored
Nov 27, 2022
by
AUTOMATIC
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add safetensors support for model merging #4869
parent
6074175f
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
35 additions
and
24 deletions
+35
-24
extras.py
modules/extras.py
+14
-12
sd_models.py
modules/sd_models.py
+15
-11
ui.py
modules/ui.py
+6
-1
No files found.
modules/extras.py
View file @
dac9b6f1
...
...
@@ -20,6 +20,7 @@ import modules.codeformer_model
import
piexif
import
piexif.helper
import
gradio
as
gr
import
safetensors.torch
class
LruCache
(
OrderedDict
):
...
...
@@ -249,7 +250,7 @@ def run_pnginfo(image):
return
''
,
geninfo
,
info
def
run_modelmerger
(
primary_model_name
,
secondary_model_name
,
teritary_model_name
,
interp_method
,
multiplier
,
save_as_half
,
custom_name
):
def
run_modelmerger
(
primary_model_name
,
secondary_model_name
,
teritary_model_name
,
interp_method
,
multiplier
,
save_as_half
,
custom_name
,
checkpoint_format
):
def
weighted_sum
(
theta0
,
theta1
,
alpha
):
return
((
1
-
alpha
)
*
theta0
)
+
(
alpha
*
theta1
)
...
...
@@ -264,19 +265,15 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
teritary_model_info
=
sd_models
.
checkpoints_list
.
get
(
teritary_model_name
,
None
)
print
(
f
"Loading {primary_model_info.filename}..."
)
primary_model
=
torch
.
load
(
primary_model_info
.
filename
,
map_location
=
'cpu'
)
theta_0
=
sd_models
.
get_state_dict_from_checkpoint
(
primary_model
)
theta_0
=
sd_models
.
read_state_dict
(
primary_model_info
.
filename
,
map_location
=
'cpu'
)
print
(
f
"Loading {secondary_model_info.filename}..."
)
secondary_model
=
torch
.
load
(
secondary_model_info
.
filename
,
map_location
=
'cpu'
)
theta_1
=
sd_models
.
get_state_dict_from_checkpoint
(
secondary_model
)
theta_1
=
sd_models
.
read_state_dict
(
secondary_model_info
.
filename
,
map_location
=
'cpu'
)
if
teritary_model_info
is
not
None
:
print
(
f
"Loading {teritary_model_info.filename}..."
)
teritary_model
=
torch
.
load
(
teritary_model_info
.
filename
,
map_location
=
'cpu'
)
theta_2
=
sd_models
.
get_state_dict_from_checkpoint
(
teritary_model
)
theta_2
=
sd_models
.
read_state_dict
(
teritary_model_info
.
filename
,
map_location
=
'cpu'
)
else
:
teritary_model
=
None
theta_2
=
None
theta_funcs
=
{
...
...
@@ -295,7 +292,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
theta_1
[
key
]
=
theta_func1
(
theta_1
[
key
],
t2
)
else
:
theta_1
[
key
]
=
torch
.
zeros_like
(
theta_1
[
key
])
del
theta_2
,
teritary_model
del
theta_2
for
key
in
tqdm
.
tqdm
(
theta_0
.
keys
()):
if
'model'
in
key
and
key
in
theta_1
:
...
...
@@ -314,12 +311,17 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
ckpt_dir
=
shared
.
cmd_opts
.
ckpt_dir
or
sd_models
.
model_path
filename
=
primary_model_info
.
model_name
+
'_'
+
str
(
round
(
1
-
multiplier
,
2
))
+
'-'
+
secondary_model_info
.
model_name
+
'_'
+
str
(
round
(
multiplier
,
2
))
+
'-'
+
interp_method
.
replace
(
" "
,
"_"
)
+
'-merged.
ckpt'
filename
=
filename
if
custom_name
==
''
else
(
custom_name
+
'.
ckpt'
)
filename
=
primary_model_info
.
model_name
+
'_'
+
str
(
round
(
1
-
multiplier
,
2
))
+
'-'
+
secondary_model_info
.
model_name
+
'_'
+
str
(
round
(
multiplier
,
2
))
+
'-'
+
interp_method
.
replace
(
" "
,
"_"
)
+
'-merged.
'
+
checkpoint_format
filename
=
filename
if
custom_name
==
''
else
(
custom_name
+
'.
'
+
checkpoint_format
)
output_modelname
=
os
.
path
.
join
(
ckpt_dir
,
filename
)
print
(
f
"Saving to {output_modelname}..."
)
torch
.
save
(
primary_model
,
output_modelname
)
_
,
extension
=
os
.
path
.
splitext
(
output_modelname
)
if
extension
.
lower
()
==
".safetensors"
:
safetensors
.
torch
.
save_file
(
theta_0
,
output_modelname
,
metadata
=
{
"format"
:
"pt"
})
else
:
torch
.
save
(
theta_0
,
output_modelname
)
sd_models
.
list_models
()
...
...
modules/sd_models.py
View file @
dac9b6f1
...
...
@@ -160,6 +160,20 @@ def get_state_dict_from_checkpoint(pl_sd):
return
pl_sd
def
read_state_dict
(
checkpoint_file
,
print_global_state
=
False
,
map_location
=
None
):
_
,
extension
=
os
.
path
.
splitext
(
checkpoint_file
)
if
extension
.
lower
()
==
".safetensors"
:
pl_sd
=
safetensors
.
torch
.
load_file
(
checkpoint_file
,
device
=
map_location
or
shared
.
weight_load_location
)
else
:
pl_sd
=
torch
.
load
(
checkpoint_file
,
map_location
=
map_location
or
shared
.
weight_load_location
)
if
print_global_state
and
"global_step"
in
pl_sd
:
print
(
f
"Global Step: {pl_sd['global_step']}"
)
sd
=
get_state_dict_from_checkpoint
(
pl_sd
)
return
sd
def
load_model_weights
(
model
,
checkpoint_info
,
vae_file
=
"auto"
):
checkpoint_file
=
checkpoint_info
.
filename
sd_model_hash
=
checkpoint_info
.
hash
...
...
@@ -174,17 +188,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
# load from file
print
(
f
"Loading weights [{sd_model_hash}] from {checkpoint_file}"
)
_
,
extension
=
os
.
path
.
splitext
(
checkpoint_file
)
if
extension
.
lower
()
==
".safetensors"
:
pl_sd
=
safetensors
.
torch
.
load_file
(
checkpoint_file
,
device
=
shared
.
weight_load_location
)
else
:
pl_sd
=
torch
.
load
(
checkpoint_file
,
map_location
=
shared
.
weight_load_location
)
if
"global_step"
in
pl_sd
:
print
(
f
"Global Step: {pl_sd['global_step']}"
)
sd
=
get_state_dict_from_checkpoint
(
pl_sd
)
del
pl_sd
sd
=
read_state_dict
(
checkpoint_file
)
model
.
load_state_dict
(
sd
,
strict
=
False
)
del
sd
...
...
modules/ui.py
View file @
dac9b6f1
...
...
@@ -1164,7 +1164,11 @@ def create_ui(wrap_gradio_gpu_call):
custom_name
=
gr
.
Textbox
(
label
=
"Custom Name (Optional)"
)
interp_amount
=
gr
.
Slider
(
minimum
=
0.0
,
maximum
=
1.0
,
step
=
0.05
,
label
=
'Multiplier (M) - set to 0 to get model A'
,
value
=
0.3
)
interp_method
=
gr
.
Radio
(
choices
=
[
"Weighted sum"
,
"Add difference"
],
value
=
"Weighted sum"
,
label
=
"Interpolation Method"
)
save_as_half
=
gr
.
Checkbox
(
value
=
False
,
label
=
"Save as float16"
)
with
gr
.
Row
():
checkpoint_format
=
gr
.
Radio
(
choices
=
[
"ckpt"
,
"safetensors"
],
value
=
"ckpt"
,
label
=
"Checkpoint format"
)
save_as_half
=
gr
.
Checkbox
(
value
=
False
,
label
=
"Save as float16"
)
modelmerger_merge
=
gr
.
Button
(
elem_id
=
"modelmerger_merge"
,
label
=
"Merge"
,
variant
=
'primary'
)
with
gr
.
Column
(
variant
=
'panel'
):
...
...
@@ -1692,6 +1696,7 @@ def create_ui(wrap_gradio_gpu_call):
interp_amount
,
save_as_half
,
custom_name
,
checkpoint_format
,
],
outputs
=
[
submit_result
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment