Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
ec5e0721
Unverified
Commit
ec5e0721
authored
Dec 10, 2022
by
AUTOMATIC1111
Committed by
GitHub
Dec 10, 2022
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #4841 from R-N/vae-fix-none
Fix None option of VAE selector
parents
e11d0d43
8662b5e5
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
20 deletions
+19
-20
sd_models.py
modules/sd_models.py
+2
-0
sd_vae.py
modules/sd_vae.py
+17
-20
No files found.
modules/sd_models.py
View file @
ec5e0721
...
...
@@ -227,6 +227,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
model
.
sd_model_checkpoint
=
checkpoint_file
model
.
sd_checkpoint_info
=
checkpoint_info
sd_vae
.
delete_base_vae
()
sd_vae
.
clear_loaded_vae
()
vae_file
=
sd_vae
.
resolve_vae
(
checkpoint_file
,
vae_file
=
vae_file
)
sd_vae
.
load_vae
(
model
,
vae_file
)
...
...
modules/sd_vae.py
View file @
ec5e0721
...
...
@@ -4,6 +4,7 @@ from collections import namedtuple
from
modules
import
shared
,
devices
,
script_callbacks
from
modules.paths
import
models_path
import
glob
from
copy
import
deepcopy
model_dir
=
"Stable-diffusion"
...
...
@@ -15,7 +16,7 @@ vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
vae_ignore_keys
=
{
"model_ema.decay"
,
"model_ema.num_updates"
}
default_vae_dict
=
{
"auto"
:
"auto"
,
"None"
:
"None"
}
default_vae_dict
=
{
"auto"
:
"auto"
,
"None"
:
None
,
None
:
None
}
default_vae_list
=
[
"auto"
,
"None"
]
...
...
@@ -39,7 +40,8 @@ def get_base_vae(model):
def
store_base_vae
(
model
):
global
base_vae
,
checkpoint_info
if
checkpoint_info
!=
model
.
sd_checkpoint_info
:
base_vae
=
model
.
first_stage_model
.
state_dict
()
.
copy
()
assert
not
loaded_vae_file
,
"Trying to store non-base VAE!"
base_vae
=
deepcopy
(
model
.
first_stage_model
.
state_dict
())
checkpoint_info
=
model
.
sd_checkpoint_info
...
...
@@ -50,9 +52,11 @@ def delete_base_vae():
def
restore_base_vae
(
model
):
global
base_vae
,
checkpoint_info
global
loaded_vae_file
if
base_vae
is
not
None
and
checkpoint_info
==
model
.
sd_checkpoint_info
:
load_vae_dict
(
model
,
base_vae
)
print
(
"Restoring base VAE"
)
_load_vae_dict
(
model
,
base_vae
)
loaded_vae_file
=
None
delete_base_vae
()
...
...
@@ -148,9 +152,10 @@ def load_vae(model, vae_file=None):
if
vae_file
:
assert
os
.
path
.
isfile
(
vae_file
),
f
"VAE file doesn't exist: {vae_file}"
print
(
f
"Loading VAE weights from: {vae_file}"
)
store_base_vae
(
model
)
vae_ckpt
=
torch
.
load
(
vae_file
,
map_location
=
shared
.
weight_load_location
)
vae_dict_1
=
{
k
:
v
for
k
,
v
in
vae_ckpt
[
"state_dict"
]
.
items
()
if
k
[
0
:
4
]
!=
"loss"
and
k
not
in
vae_ignore_keys
}
load_vae_dict
(
model
,
vae_dict_1
)
_
load_vae_dict
(
model
,
vae_dict_1
)
# If vae used is not in dict, update it
# It will be removed on refresh though
...
...
@@ -158,30 +163,22 @@ def load_vae(model, vae_file=None):
if
vae_opt
not
in
vae_dict
:
vae_dict
[
vae_opt
]
=
vae_file
vae_list
.
append
(
vae_opt
)
elif
loaded_vae_file
:
restore_base_vae
(
model
)
loaded_vae_file
=
vae_file
"""
# Save current VAE to VAE settings, maybe? will it work?
if save_settings:
if vae_file is None:
vae_opt = "None"
# shared.opts.sd_vae = vae_opt
"""
first_load
=
False
# don't call this from outside
def
load_vae_dict
(
model
,
vae_dict_1
=
None
):
if
vae_dict_1
:
store_base_vae
(
model
)
model
.
first_stage_model
.
load_state_dict
(
vae_dict_1
)
else
:
restore_base_vae
()
def
_load_vae_dict
(
model
,
vae_dict_1
):
model
.
first_stage_model
.
load_state_dict
(
vae_dict_1
)
model
.
first_stage_model
.
to
(
devices
.
dtype_vae
)
def
clear_loaded_vae
():
global
loaded_vae_file
loaded_vae_file
=
None
def
reload_vae_weights
(
sd_model
=
None
,
vae_file
=
"auto"
):
from
modules
import
lowvram
,
devices
,
sd_hijack
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment