Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
d04e3e92
Commit
d04e3e92
authored
Jan 28, 2023
by
AUTOMATIC
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
automatically detect v-parameterization for SD2 checkpoints
parent
4aa7f5b5
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
48 additions
and
5 deletions
+48
-5
sd_hijack.py
modules/sd_hijack.py
+2
-0
sd_models_config.py
modules/sd_models_config.py
+46
-5
No files found.
modules/sd_hijack.py
View file @
d04e3e92
...
@@ -131,6 +131,8 @@ class StableDiffusionModelHijack:
...
@@ -131,6 +131,8 @@ class StableDiffusionModelHijack:
m
.
cond_stage_model
.
wrapped
.
model
.
token_embedding
=
m
.
cond_stage_model
.
wrapped
.
model
.
token_embedding
.
wrapped
m
.
cond_stage_model
.
wrapped
.
model
.
token_embedding
=
m
.
cond_stage_model
.
wrapped
.
model
.
token_embedding
.
wrapped
m
.
cond_stage_model
=
m
.
cond_stage_model
.
wrapped
m
.
cond_stage_model
=
m
.
cond_stage_model
.
wrapped
undo_optimizations
()
self
.
apply_circular
(
False
)
self
.
apply_circular
(
False
)
self
.
layers
=
None
self
.
layers
=
None
self
.
clip
=
None
self
.
clip
=
None
...
...
modules/sd_models_config.py
View file @
d04e3e92
import
re
import
re
import
os
import
os
from
modules
import
shared
,
paths
import
torch
from
modules
import
shared
,
paths
,
sd_disable_initialization
sd_configs_path
=
shared
.
sd_configs_path
sd_configs_path
=
shared
.
sd_configs_path
sd_repo_configs_path
=
os
.
path
.
join
(
paths
.
paths
[
'Stable Diffusion'
],
"configs"
,
"stable-diffusion"
)
sd_repo_configs_path
=
os
.
path
.
join
(
paths
.
paths
[
'Stable Diffusion'
],
"configs"
,
"stable-diffusion"
)
...
@@ -16,12 +18,51 @@ config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml"
...
@@ -16,12 +18,51 @@ config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml"
config_instruct_pix2pix
=
os
.
path
.
join
(
sd_configs_path
,
"instruct-pix2pix.yaml"
)
config_instruct_pix2pix
=
os
.
path
.
join
(
sd_configs_path
,
"instruct-pix2pix.yaml"
)
config_alt_diffusion
=
os
.
path
.
join
(
sd_configs_path
,
"alt-diffusion-inference.yaml"
)
config_alt_diffusion
=
os
.
path
.
join
(
sd_configs_path
,
"alt-diffusion-inference.yaml"
)
re_parametrization_v
=
re
.
compile
(
r'-v\b'
)
def
is_using_v_parameterization_for_sd2
(
state_dict
):
"""
Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome.
"""
def
guess_model_config_from_state_dict
(
sd
,
filename
):
import
ldm.modules.diffusionmodules.openaimodel
fn
=
os
.
path
.
basename
(
filename
)
from
modules
import
devices
device
=
devices
.
cpu
with
sd_disable_initialization
.
DisableInitialization
():
unet
=
ldm
.
modules
.
diffusionmodules
.
openaimodel
.
UNetModel
(
use_checkpoint
=
True
,
use_fp16
=
False
,
image_size
=
32
,
in_channels
=
4
,
out_channels
=
4
,
model_channels
=
320
,
attention_resolutions
=
[
4
,
2
,
1
],
num_res_blocks
=
2
,
channel_mult
=
[
1
,
2
,
4
,
4
],
num_head_channels
=
64
,
use_spatial_transformer
=
True
,
use_linear_in_transformer
=
True
,
transformer_depth
=
1
,
context_dim
=
1024
,
legacy
=
False
)
unet
.
eval
()
with
torch
.
no_grad
():
unet_sd
=
{
k
.
replace
(
"model.diffusion_model."
,
""
):
v
for
k
,
v
in
state_dict
.
items
()
if
"model.diffusion_model."
in
k
}
unet
.
load_state_dict
(
unet_sd
,
strict
=
True
)
unet
.
to
(
device
=
device
,
dtype
=
torch
.
float
)
test_cond
=
torch
.
ones
((
1
,
2
,
1024
),
device
=
device
)
*
0.5
x_test
=
torch
.
ones
((
1
,
4
,
8
,
8
),
device
=
device
)
*
0.5
out
=
(
unet
(
x_test
,
torch
.
asarray
([
999
],
device
=
device
),
context
=
test_cond
)
-
x_test
)
.
mean
()
.
item
()
return
out
<
-
1
def
guess_model_config_from_state_dict
(
sd
,
filename
):
sd2_cond_proj_weight
=
sd
.
get
(
'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
,
None
)
sd2_cond_proj_weight
=
sd
.
get
(
'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
,
None
)
diffusion_model_input
=
sd
.
get
(
'model.diffusion_model.input_blocks.0.0.weight'
,
None
)
diffusion_model_input
=
sd
.
get
(
'model.diffusion_model.input_blocks.0.0.weight'
,
None
)
...
@@ -31,7 +72,7 @@ def guess_model_config_from_state_dict(sd, filename):
...
@@ -31,7 +72,7 @@ def guess_model_config_from_state_dict(sd, filename):
if
sd2_cond_proj_weight
is
not
None
and
sd2_cond_proj_weight
.
shape
[
1
]
==
1024
:
if
sd2_cond_proj_weight
is
not
None
and
sd2_cond_proj_weight
.
shape
[
1
]
==
1024
:
if
diffusion_model_input
.
shape
[
1
]
==
9
:
if
diffusion_model_input
.
shape
[
1
]
==
9
:
return
config_sd2_inpainting
return
config_sd2_inpainting
elif
re
.
search
(
re_parametrization_v
,
fn
):
elif
is_using_v_parameterization_for_sd2
(
sd
):
return
config_sd2v
return
config_sd2v
else
:
else
:
return
config_sd2
return
config_sd2
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment