Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
d13ce89e
Unverified
Commit
d13ce89e
authored
Oct 15, 2022
by
AUTOMATIC1111
Committed by
GitHub
Oct 15, 2022
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #2573 from raefu/ckpt-cache
add --ckpt-cache option for faster model switching
parents
6a4e8467
af144ebd
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
26 deletions
+33
-26
sd_models.py
modules/sd_models.py
+32
-26
shared.py
modules/shared.py
+1
-0
No files found.
modules/sd_models.py
View file @
d13ce89e
import
glob
import
collections
import
os.path
import
os.path
import
sys
import
sys
from
collections
import
namedtuple
from
collections
import
namedtuple
...
@@ -15,6 +15,7 @@ model_path = os.path.abspath(os.path.join(models_path, model_dir))
...
@@ -15,6 +15,7 @@ model_path = os.path.abspath(os.path.join(models_path, model_dir))
CheckpointInfo
=
namedtuple
(
"CheckpointInfo"
,
[
'filename'
,
'title'
,
'hash'
,
'model_name'
,
'config'
])
CheckpointInfo
=
namedtuple
(
"CheckpointInfo"
,
[
'filename'
,
'title'
,
'hash'
,
'model_name'
,
'config'
])
checkpoints_list
=
{}
checkpoints_list
=
{}
checkpoints_loaded
=
collections
.
OrderedDict
()
try
:
try
:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
...
@@ -132,41 +133,45 @@ def load_model_weights(model, checkpoint_info):
...
@@ -132,41 +133,45 @@ def load_model_weights(model, checkpoint_info):
checkpoint_file
=
checkpoint_info
.
filename
checkpoint_file
=
checkpoint_info
.
filename
sd_model_hash
=
checkpoint_info
.
hash
sd_model_hash
=
checkpoint_info
.
hash
print
(
f
"Loading weights [{sd_model_hash}] from {checkpoint_file}"
)
if
checkpoint_info
not
in
checkpoints_loaded
:
print
(
f
"Loading weights [{sd_model_hash}] from {checkpoint_file}"
)
pl_sd
=
torch
.
load
(
checkpoint_file
,
map_location
=
shared
.
weight_load_location
)
pl_sd
=
torch
.
load
(
checkpoint_file
,
map_location
=
shared
.
weight_load_location
)
if
"global_step"
in
pl_sd
:
print
(
f
"Global Step: {pl_sd['global_step']}"
)
if
"global_step"
in
pl_sd
:
sd
=
get_state_dict_from_checkpoint
(
pl_sd
)
print
(
f
"Global Step: {pl_sd['global_step']}"
)
model
.
load_state_dict
(
sd
,
strict
=
False
)
sd
=
get_state_dict_from_checkpoint
(
pl_sd
)
if
shared
.
cmd_opts
.
opt_channelslast
:
model
.
to
(
memory_format
=
torch
.
channels_last
)
model
.
load_state_dict
(
sd
,
strict
=
False
)
if
not
shared
.
cmd_opts
.
no_half
:
model
.
half
()
if
shared
.
cmd_opts
.
opt_channelslast
:
devices
.
dtype
=
torch
.
float32
if
shared
.
cmd_opts
.
no_half
else
torch
.
float16
model
.
to
(
memory_format
=
torch
.
channels_last
)
devices
.
dtype_vae
=
torch
.
float32
if
shared
.
cmd_opts
.
no_half
or
shared
.
cmd_opts
.
no_half_vae
else
torch
.
float16
if
not
shared
.
cmd_opts
.
no_half
:
vae_file
=
os
.
path
.
splitext
(
checkpoint_file
)[
0
]
+
".vae.pt"
model
.
half
()
devices
.
dtype
=
torch
.
float32
if
shared
.
cmd_opts
.
no_half
else
torch
.
float16
if
not
os
.
path
.
exists
(
vae_file
)
and
shared
.
cmd_opts
.
vae_path
is
not
None
:
devices
.
dtype_vae
=
torch
.
float32
if
shared
.
cmd_opts
.
no_half
or
shared
.
cmd_opts
.
no_half_vae
else
torch
.
float16
vae_file
=
shared
.
cmd_opts
.
vae_path
vae_file
=
os
.
path
.
splitext
(
checkpoint_file
)[
0
]
+
".vae.pt"
if
os
.
path
.
exists
(
vae_file
):
print
(
f
"Loading VAE weights from: {vae_file}"
)
vae_ckpt
=
torch
.
load
(
vae_file
,
map_location
=
shared
.
weight_load_location
)
vae_dict
=
{
k
:
v
for
k
,
v
in
vae_ckpt
[
"state_dict"
]
.
items
()
if
k
[
0
:
4
]
!=
"loss"
}
model
.
first_stage_model
.
load_state_dict
(
vae_dict
)
if
not
os
.
path
.
exists
(
vae_file
)
and
shared
.
cmd_opts
.
vae_path
is
not
None
:
model
.
first_stage_model
.
to
(
devices
.
dtype_vae
)
vae_file
=
shared
.
cmd_opts
.
vae_path
if
os
.
path
.
exists
(
vae_file
):
checkpoints_loaded
[
checkpoint_info
]
=
model
.
state_dict
()
.
copy
()
print
(
f
"Loading VAE weights from: {vae_file}"
)
while
len
(
checkpoints_loaded
)
>
shared
.
opts
.
sd_checkpoint_cache
:
checkpoints_loaded
.
popitem
(
last
=
False
)
# LRU
vae_ckpt
=
torch
.
load
(
vae_file
,
map_location
=
shared
.
weight_load_location
)
else
:
print
(
f
"Loading weights [{sd_model_hash}] from cache"
)
vae_dict
=
{
k
:
v
for
k
,
v
in
vae_ckpt
[
"state_dict"
]
.
items
()
if
k
[
0
:
4
]
!=
"loss"
}
checkpoints_loaded
.
move_to_end
(
checkpoint_info
)
model
.
load_state_dict
(
checkpoints_loaded
[
checkpoint_info
])
model
.
first_stage_model
.
load_state_dict
(
vae_dict
)
model
.
first_stage_model
.
to
(
devices
.
dtype_vae
)
model
.
sd_model_hash
=
sd_model_hash
model
.
sd_model_hash
=
sd_model_hash
model
.
sd_model_checkpoint
=
checkpoint_file
model
.
sd_model_checkpoint
=
checkpoint_file
...
@@ -205,6 +210,7 @@ def reload_model_weights(sd_model, info=None):
...
@@ -205,6 +210,7 @@ def reload_model_weights(sd_model, info=None):
return
return
if
sd_model
.
sd_checkpoint_info
.
config
!=
checkpoint_info
.
config
:
if
sd_model
.
sd_checkpoint_info
.
config
!=
checkpoint_info
.
config
:
checkpoints_loaded
.
clear
()
shared
.
sd_model
=
load_model
()
shared
.
sd_model
=
load_model
()
return
shared
.
sd_model
return
shared
.
sd_model
...
...
modules/shared.py
View file @
d13ce89e
...
@@ -242,6 +242,7 @@ options_templates.update(options_section(('training', "Training"), {
...
@@ -242,6 +242,7 @@ options_templates.update(options_section(('training', "Training"), {
options_templates
.
update
(
options_section
((
'sd'
,
"Stable Diffusion"
),
{
options_templates
.
update
(
options_section
((
'sd'
,
"Stable Diffusion"
),
{
"sd_model_checkpoint"
:
OptionInfo
(
None
,
"Stable Diffusion checkpoint"
,
gr
.
Dropdown
,
lambda
:
{
"choices"
:
modules
.
sd_models
.
checkpoint_tiles
()},
refresh
=
sd_models
.
list_models
),
"sd_model_checkpoint"
:
OptionInfo
(
None
,
"Stable Diffusion checkpoint"
,
gr
.
Dropdown
,
lambda
:
{
"choices"
:
modules
.
sd_models
.
checkpoint_tiles
()},
refresh
=
sd_models
.
list_models
),
"sd_checkpoint_cache"
:
OptionInfo
(
0
,
"Checkpoints to cache in RAM"
,
gr
.
Slider
,
{
"minimum"
:
0
,
"maximum"
:
10
,
"step"
:
1
}),
"sd_hypernetwork"
:
OptionInfo
(
"None"
,
"Hypernetwork"
,
gr
.
Dropdown
,
lambda
:
{
"choices"
:
[
"None"
]
+
[
x
for
x
in
hypernetworks
.
keys
()]},
refresh
=
reload_hypernetworks
),
"sd_hypernetwork"
:
OptionInfo
(
"None"
,
"Hypernetwork"
,
gr
.
Dropdown
,
lambda
:
{
"choices"
:
[
"None"
]
+
[
x
for
x
in
hypernetworks
.
keys
()]},
refresh
=
reload_hypernetworks
),
"sd_hypernetwork_strength"
:
OptionInfo
(
1.0
,
"Hypernetwork strength"
,
gr
.
Slider
,
{
"minimum"
:
0.0
,
"maximum"
:
1.0
,
"step"
:
0.001
}),
"sd_hypernetwork_strength"
:
OptionInfo
(
1.0
,
"Hypernetwork strength"
,
gr
.
Slider
,
{
"minimum"
:
0.0
,
"maximum"
:
1.0
,
"step"
:
0.001
}),
"img2img_color_correction"
:
OptionInfo
(
False
,
"Apply color correction to img2img results to match original colors."
),
"img2img_color_correction"
:
OptionInfo
(
False
,
"Apply color correction to img2img results to match original colors."
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment