stable-diffusion-webui · Commit 675b51eb (unverified)
Authored Nov 02, 2022 by AUTOMATIC1111; committed by GitHub, Nov 02, 2022

Merge pull request #3986 from R-N/vae-picker

VAE Selector

Parents: e359268b, a5409a6e

Showing 6 changed files with 233 additions and 28 deletions.
Files changed:
  models/VAE/Put VAE here.txt   +0   -0
  modules/sd_models.py          +17  -24
  modules/sd_vae.py             +207 -0
  modules/shared.py             +5   -3
  style.css                     +1   -1
  webui.py                      +3   -0
models/VAE/Put VAE here.txt (new file, mode 0 → 100644; empty placeholder)
modules/sd_models.py
@@ -9,7 +9,7 @@ from omegaconf import OmegaConf
 from ldm.util import instantiate_from_config
 
-from modules import shared, modelloader, devices, script_callbacks
+from modules import shared, modelloader, devices, script_callbacks, sd_vae
 from modules.paths import models_path
 from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
@@ -159,14 +159,15 @@ def get_state_dict_from_checkpoint(pl_sd):
     return pl_sd
 
 
-vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
-
-def load_model_weights(model, checkpoint_info):
+def load_model_weights(model, checkpoint_info, vae_file="auto"):
     checkpoint_file = checkpoint_info.filename
     sd_model_hash = checkpoint_info.hash
 
-    if checkpoint_info not in checkpoints_loaded:
+    vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
+
+    checkpoint_key = checkpoint_info
+
+    if checkpoint_key not in checkpoints_loaded:
         print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
 
         pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
@@ -187,32 +188,24 @@ def load_model_weights(model, checkpoint_info):
         devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
         devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
 
-        vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
-
-        if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
-            vae_file = shared.cmd_opts.vae_path
-
-        if os.path.exists(vae_file):
-            print(f"Loading VAE weights from: {vae_file}")
-            vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
-            vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
-            model.first_stage_model.load_state_dict(vae_dict)
-
-        model.first_stage_model.to(devices.dtype_vae)
-
         if shared.opts.sd_checkpoint_cache > 0:
-            checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+            # if PR #4035 were to get merged, restore base VAE first before caching
+            checkpoints_loaded[checkpoint_key] = model.state_dict().copy()
             while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
                 checkpoints_loaded.popitem(last=False)  # LRU
     else:
-        print(f"Loading weights [{sd_model_hash}] from cache")
-        checkpoints_loaded.move_to_end(checkpoint_info)
-        model.load_state_dict(checkpoints_loaded[checkpoint_info])
+        vae_name = sd_vae.get_filename(vae_file)
+        print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache")
+        checkpoints_loaded.move_to_end(checkpoint_key)
+        model.load_state_dict(checkpoints_loaded[checkpoint_key])
 
     model.sd_model_hash = sd_model_hash
     model.sd_model_checkpoint = checkpoint_file
     model.sd_checkpoint_info = checkpoint_info
 
+    sd_vae.load_vae(model, vae_file)
+
 
 def load_model(checkpoint_info=None):
     from modules import lowvram, sd_hijack
@@ -263,7 +256,7 @@ def load_model(checkpoint_info=None):
 def reload_model_weights(sd_model=None, info=None):
     from modules import lowvram, devices, sd_hijack
     checkpoint_info = info or select_checkpoint()
 
     if not sd_model:
         sd_model = shared.sd_model
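The net effect of the sd_models.py changes is that load_model_weights() no longer hard-codes the <checkpoint>.vae.pt lookup: it asks sd_vae.resolve_vae() for a path up front and applies it through sd_vae.load_vae() after the main weights are in place, so the RAM cache entries stay keyed by checkpoint alone. A minimal usage sketch of the entry point this enables (not part of the commit; the VAE path is hypothetical):

# Sketch: swap the active VAE at runtime without reloading the checkpoint.
from modules import shared, sd_vae

# "auto" walks the fallback chain (function argument -> --vae-path ->
# saved "sd_vae" setting -> .vae.pt/.vae.ckpt beside the checkpoint).
sd_vae.reload_vae_weights(shared.sd_model, vae_file="auto")

# Or force a specific file (hypothetical path, for illustration only):
sd_vae.reload_vae_weights(shared.sd_model, vae_file="models/VAE/my-finetuned.vae.pt")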
modules/sd_vae.py (new file, mode 0 → 100644)
import torch
import os
from collections import namedtuple
from modules import shared, devices, script_callbacks
from modules.paths import models_path
import glob


model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
vae_dir = "VAE"
vae_path = os.path.abspath(os.path.join(models_path, vae_dir))


vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}


default_vae_dict = {"auto": "auto", "None": "None"}
default_vae_list = ["auto", "None"]


default_vae_values = [default_vae_dict[x] for x in default_vae_list]
vae_dict = dict(default_vae_dict)
vae_list = list(default_vae_list)
first_load = True


base_vae = None
loaded_vae_file = None
checkpoint_info = None


def get_base_vae(model):
    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
        return base_vae
    return None


def store_base_vae(model):
    global base_vae, checkpoint_info
    if checkpoint_info != model.sd_checkpoint_info:
        base_vae = model.first_stage_model.state_dict().copy()
        checkpoint_info = model.sd_checkpoint_info


def delete_base_vae():
    global base_vae, checkpoint_info
    base_vae = None
    checkpoint_info = None


def restore_base_vae(model):
    global base_vae, checkpoint_info
    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
        load_vae_dict(model, base_vae)
    delete_base_vae()


def get_filename(filepath):
    return os.path.splitext(os.path.basename(filepath))[0]


def refresh_vae_list(vae_path=vae_path, model_path=model_path):
    global vae_dict, vae_list
    res = {}
    candidates = [
        *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
        *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
        *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
        *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True)
    ]
    if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
        candidates.append(shared.cmd_opts.vae_path)
    for filepath in candidates:
        name = get_filename(filepath)
        res[name] = filepath
    vae_list.clear()
    vae_list.extend(default_vae_list)
    vae_list.extend(list(res.keys()))
    vae_dict.clear()
    vae_dict.update(res)
    vae_dict.update(default_vae_dict)
    return vae_list


def resolve_vae(checkpoint_file, vae_file="auto"):
    global first_load, vae_dict, vae_list

    # if vae_file argument is provided, it takes priority, but not saved
    if vae_file and vae_file not in default_vae_list:
        if not os.path.isfile(vae_file):
            vae_file = "auto"
            print("VAE provided as function argument doesn't exist")
    # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
    if first_load and shared.cmd_opts.vae_path is not None:
        if os.path.isfile(shared.cmd_opts.vae_path):
            vae_file = shared.cmd_opts.vae_path
            shared.opts.data['sd_vae'] = get_filename(vae_file)
        else:
            print("VAE provided as command line argument doesn't exist")
    # else, we load from settings
    if vae_file == "auto" and shared.opts.sd_vae is not None:
        # if saved VAE settings isn't recognized, fallback to auto
        vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
        # if VAE selected but not found, fallback to auto
        if vae_file not in default_vae_values and not os.path.isfile(vae_file):
            vae_file = "auto"
            print("Selected VAE doesn't exist")
    # vae-path cmd arg takes priority for auto
    if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
        if os.path.isfile(shared.cmd_opts.vae_path):
            vae_file = shared.cmd_opts.vae_path
            print("Using VAE provided as command line argument")
    # if still not found, try look for ".vae.pt" beside model
    model_path = os.path.splitext(checkpoint_file)[0]
    if vae_file == "auto":
        vae_file_try = model_path + ".vae.pt"
        if os.path.isfile(vae_file_try):
            vae_file = vae_file_try
            print("Using VAE found beside selected model")
    # if still not found, try look for ".vae.ckpt" beside model
    if vae_file == "auto":
        vae_file_try = model_path + ".vae.ckpt"
        if os.path.isfile(vae_file_try):
            vae_file = vae_file_try
            print("Using VAE found beside selected model")
    # No more fallbacks for auto
    if vae_file == "auto":
        vae_file = None
    # Last check, just because
    if vae_file and not os.path.exists(vae_file):
        vae_file = None

    return vae_file


def load_vae(model, vae_file=None):
    global first_load, vae_dict, vae_list, loaded_vae_file
    # save_settings = False

    if vae_file:
        print(f"Loading VAE weights from: {vae_file}")
        vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
        vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
        load_vae_dict(model, vae_dict_1)

        # If vae used is not in dict, update it
        # It will be removed on refresh though
        vae_opt = get_filename(vae_file)
        if vae_opt not in vae_dict:
            vae_dict[vae_opt] = vae_file
            vae_list.append(vae_opt)

    loaded_vae_file = vae_file

    """
    # Save current VAE to VAE settings, maybe? will it work?
    if save_settings:
        if vae_file is None:
            vae_opt = "None"

        # shared.opts.sd_vae = vae_opt
    """

    first_load = False


# don't call this from outside
def load_vae_dict(model, vae_dict_1=None):
    if vae_dict_1:
        store_base_vae(model)
        model.first_stage_model.load_state_dict(vae_dict_1)
    else:
        restore_base_vae()
    model.first_stage_model.to(devices.dtype_vae)


def reload_vae_weights(sd_model=None, vae_file="auto"):
    from modules import lowvram, devices, sd_hijack

    if not sd_model:
        sd_model = shared.sd_model

    checkpoint_info = sd_model.sd_checkpoint_info
    checkpoint_file = checkpoint_info.filename
    vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)

    if loaded_vae_file == vae_file:
        return

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.send_everything_to_cpu()
    else:
        sd_model.to(devices.cpu)

    sd_hijack.model_hijack.undo_hijack(sd_model)

    load_vae(sd_model, vae_file)

    sd_hijack.model_hijack.hijack(sd_model)
    script_callbacks.model_loaded_callback(sd_model)

    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
        sd_model.to(devices.device)

    print(f"VAE Weights loaded.")
    return sd_model
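resolve_vae() above encodes the selector's priority chain. A hypothetical walkthrough of its fallback order (the checkpoint path is illustrative, not from the commit):

# Assuming a checkpoint at models/Stable-diffusion/v1-5.ckpt and the saved
# setting shared.opts.sd_vae == "auto":
from modules import sd_vae

# 1. an explicit vae_file argument wins if the file exists (never persisted)
# 2. on first load only, --vae-path wins and is saved to shared.opts.data['sd_vae']
# 3. otherwise the saved "sd_vae" name is looked up in vae_dict
# 4. for "auto": --vae-path, then v1-5.vae.pt, then v1-5.vae.ckpt beside the model
# 5. if nothing matches, None is returned and the checkpoint's built-in VAE stays
vae_file = sd_vae.resolve_vae("models/Stable-diffusion/v1-5.ckpt", vae_file="auto")
print(vae_file)  # a concrete path, or None when no fallback applies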
modules/shared.py
@@ -15,7 +15,7 @@ import modules.memmon
 import modules.sd_models
 import modules.styles
 import modules.devices as devices
-from modules import sd_samplers, sd_models, localization
+from modules import sd_samplers, sd_models, localization, sd_vae
 from modules.hypernetworks import hypernetwork
 from modules.paths import models_path, script_path, sd_path
@@ -319,6 +319,7 @@ options_templates.update(options_section(('training', "Training"), {
 options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
     "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+    "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list),
     "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
     "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
     "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
@@ -437,11 +438,12 @@ class Options:
         if bad_settings > 0:
             print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
 
-    def onchange(self, key, func):
+    def onchange(self, key, func, call=True):
         item = self.data_labels.get(key)
         item.onchange = func
 
-        func()
+        if call:
+            func()
 
     def dumpjson(self):
         d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
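The new call parameter lets a handler be registered without firing immediately, which the webui.py change further below depends on. A short sketch of both behaviours (the keys are real settings from this commit; the handlers are hypothetical):

from modules import shared

# default call=True: the handler also runs once at registration time
shared.opts.onchange("sd_model_checkpoint", lambda: print("checkpoint changed"))

# call=False: register only; the handler first fires when the user edits the
# setting, avoiding a redundant VAE reload at startup
shared.opts.onchange("sd_vae", lambda: print("vae changed"), call=False)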
style.css
@@ -501,7 +501,7 @@ input[type="range"]{
     padding: 0;
 }
 
-#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
+#refresh_sd_model_checkpoint, #refresh_sd_vae, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
     max-width: 2.5em;
     min-width: 2.5em;
     height: 2.4em;
webui.py
@@ -21,6 +21,7 @@ import modules.paths
 import modules.scripts
 import modules.sd_hijack
 import modules.sd_models
+import modules.sd_vae
 import modules.shared as shared
 import modules.txt2img
 import modules.script_callbacks
@@ -77,8 +78,10 @@ def initialize():
     modules.scripts.load_scripts()
 
+    modules.sd_vae.refresh_vae_list()
     modules.sd_models.load_model()
     shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
+    shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
     shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
     shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
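Ordering matters in initialize(): refresh_vae_list() must populate sd_vae's registry before load_model() runs, because load_model_weights() resolves the VAE through that registry, and the "sd_vae" handler is registered with call=False since the initial load has already applied the resolved VAE. A condensed view of that startup sequence (simplified; wrap_queued_call omitted):

import modules.sd_models
import modules.sd_vae
import modules.shared as shared

modules.sd_vae.refresh_vae_list()  # scan models/VAE and the model folders first
modules.sd_models.load_model()     # load_model_weights() resolves the VAE via sd_vae
shared.opts.onchange("sd_vae", lambda: modules.sd_vae.reload_vae_weights(), call=False)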