Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
893933e0
Commit
893933e0
authored
Dec 25, 2022
by
hitomi
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add memory cache for VAE weights
parent
c6f347b8
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
26 additions
and
6 deletions
+26
-6
sd_vae.py
modules/sd_vae.py
+25
-6
shared.py
modules/shared.py
+1
-0
No files found.
modules/sd_vae.py
View file @
893933e0
import
torch
import
torch
import
os
import
os
import
collections
from
collections
import
namedtuple
from
collections
import
namedtuple
from
modules
import
shared
,
devices
,
script_callbacks
from
modules
import
shared
,
devices
,
script_callbacks
from
modules.paths
import
models_path
from
modules.paths
import
models_path
...
@@ -30,6 +31,7 @@ base_vae = None
...
@@ -30,6 +31,7 @@ base_vae = None
loaded_vae_file
=
None
loaded_vae_file
=
None
checkpoint_info
=
None
checkpoint_info
=
None
checkpoints_loaded
=
collections
.
OrderedDict
()
def
get_base_vae
(
model
):
def
get_base_vae
(
model
):
if
base_vae
is
not
None
and
checkpoint_info
==
model
.
sd_checkpoint_info
and
model
:
if
base_vae
is
not
None
and
checkpoint_info
==
model
.
sd_checkpoint_info
and
model
:
...
@@ -149,13 +151,30 @@ def load_vae(model, vae_file=None):
...
@@ -149,13 +151,30 @@ def load_vae(model, vae_file=None):
global
first_load
,
vae_dict
,
vae_list
,
loaded_vae_file
global
first_load
,
vae_dict
,
vae_list
,
loaded_vae_file
# save_settings = False
# save_settings = False
cache_enabled
=
shared
.
opts
.
sd_vae_checkpoint_cache
>
0
if
vae_file
:
if
vae_file
:
assert
os
.
path
.
isfile
(
vae_file
),
f
"VAE file doesn't exist: {vae_file}"
if
cache_enabled
and
vae_file
in
checkpoints_loaded
:
print
(
f
"Loading VAE weights from: {vae_file}"
)
# use vae checkpoint cache
store_base_vae
(
model
)
print
(
f
"Loading VAE weights [{get_filename(vae_file)}] from cache"
)
vae_ckpt
=
torch
.
load
(
vae_file
,
map_location
=
shared
.
weight_load_location
)
store_base_vae
(
model
)
vae_dict_1
=
{
k
:
v
for
k
,
v
in
vae_ckpt
[
"state_dict"
]
.
items
()
if
k
[
0
:
4
]
!=
"loss"
and
k
not
in
vae_ignore_keys
}
_load_vae_dict
(
model
,
checkpoints_loaded
[
vae_file
])
_load_vae_dict
(
model
,
vae_dict_1
)
else
:
assert
os
.
path
.
isfile
(
vae_file
),
f
"VAE file doesn't exist: {vae_file}"
print
(
f
"Loading VAE weights from: {vae_file}"
)
store_base_vae
(
model
)
vae_ckpt
=
torch
.
load
(
vae_file
,
map_location
=
shared
.
weight_load_location
)
vae_dict_1
=
{
k
:
v
for
k
,
v
in
vae_ckpt
[
"state_dict"
]
.
items
()
if
k
[
0
:
4
]
!=
"loss"
and
k
not
in
vae_ignore_keys
}
_load_vae_dict
(
model
,
vae_dict_1
)
if
cache_enabled
:
# cache newly loaded vae
checkpoints_loaded
[
vae_file
]
=
vae_dict_1
.
copy
()
# clean up cache if limit is reached
if
cache_enabled
:
while
len
(
checkpoints_loaded
)
>
shared
.
opts
.
sd_vae_checkpoint_cache
+
1
:
# we need to count the current model
checkpoints_loaded
.
popitem
(
last
=
False
)
# LRU
# If vae used is not in dict, update it
# If vae used is not in dict, update it
# It will be removed on refresh though
# It will be removed on refresh though
...
...
modules/shared.py
View file @
893933e0
...
@@ -356,6 +356,7 @@ options_templates.update(options_section(('training', "Training"), {
...
@@ -356,6 +356,7 @@ options_templates.update(options_section(('training', "Training"), {
options_templates
.
update
(
options_section
((
'sd'
,
"Stable Diffusion"
),
{
options_templates
.
update
(
options_section
((
'sd'
,
"Stable Diffusion"
),
{
"sd_model_checkpoint"
:
OptionInfo
(
None
,
"Stable Diffusion checkpoint"
,
gr
.
Dropdown
,
lambda
:
{
"choices"
:
list_checkpoint_tiles
()},
refresh
=
refresh_checkpoints
),
"sd_model_checkpoint"
:
OptionInfo
(
None
,
"Stable Diffusion checkpoint"
,
gr
.
Dropdown
,
lambda
:
{
"choices"
:
list_checkpoint_tiles
()},
refresh
=
refresh_checkpoints
),
"sd_checkpoint_cache"
:
OptionInfo
(
0
,
"Checkpoints to cache in RAM"
,
gr
.
Slider
,
{
"minimum"
:
0
,
"maximum"
:
10
,
"step"
:
1
}),
"sd_checkpoint_cache"
:
OptionInfo
(
0
,
"Checkpoints to cache in RAM"
,
gr
.
Slider
,
{
"minimum"
:
0
,
"maximum"
:
10
,
"step"
:
1
}),
"sd_vae_checkpoint_cache"
:
OptionInfo
(
0
,
"VAE Checkpoints to cache in RAM"
,
gr
.
Slider
,
{
"minimum"
:
0
,
"maximum"
:
10
,
"step"
:
1
}),
"sd_vae"
:
OptionInfo
(
"auto"
,
"SD VAE"
,
gr
.
Dropdown
,
lambda
:
{
"choices"
:
sd_vae
.
vae_list
},
refresh
=
sd_vae
.
refresh_vae_list
),
"sd_vae"
:
OptionInfo
(
"auto"
,
"SD VAE"
,
gr
.
Dropdown
,
lambda
:
{
"choices"
:
sd_vae
.
vae_list
},
refresh
=
sd_vae
.
refresh_vae_list
),
"sd_vae_as_default"
:
OptionInfo
(
False
,
"Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"
),
"sd_vae_as_default"
:
OptionInfo
(
False
,
"Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"
),
"sd_hypernetwork"
:
OptionInfo
(
"None"
,
"Hypernetwork"
,
gr
.
Dropdown
,
lambda
:
{
"choices"
:
[
"None"
]
+
[
x
for
x
in
hypernetworks
.
keys
()]},
refresh
=
reload_hypernetworks
),
"sd_hypernetwork"
:
OptionInfo
(
"None"
,
"Hypernetwork"
,
gr
.
Dropdown
,
lambda
:
{
"choices"
:
[
"None"
]
+
[
x
for
x
in
hypernetworks
.
keys
()]},
refresh
=
reload_hypernetworks
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment