Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
8d8a05a3
Commit
8d8a05a3
authored
Jan 04, 2023
by
AUTOMATIC
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
find configs for models at runtime rather than when starting
parent
02d7abf5
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
14 deletions
+22
-14
sd_hijack_inpainting.py
modules/sd_hijack_inpainting.py
+4
-1
sd_models.py
modules/sd_models.py
+18
-13
No files found.
modules/sd_hijack_inpainting.py
View file @
8d8a05a3
...
@@ -97,8 +97,11 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
...
@@ -97,8 +97,11 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
def
should_hijack_inpainting
(
checkpoint_info
):
def
should_hijack_inpainting
(
checkpoint_info
):
from
modules
import
sd_models
ckpt_basename
=
os
.
path
.
basename
(
checkpoint_info
.
filename
)
.
lower
()
ckpt_basename
=
os
.
path
.
basename
(
checkpoint_info
.
filename
)
.
lower
()
cfg_basename
=
os
.
path
.
basename
(
checkpoint_info
.
config
)
.
lower
()
cfg_basename
=
os
.
path
.
basename
(
sd_models
.
find_checkpoint_config
(
checkpoint_info
))
.
lower
()
return
"inpainting"
in
ckpt_basename
and
not
"inpainting"
in
cfg_basename
return
"inpainting"
in
ckpt_basename
and
not
"inpainting"
in
cfg_basename
...
...
modules/sd_models.py
View file @
8d8a05a3
...
@@ -20,7 +20,7 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp
...
@@ -20,7 +20,7 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp
model_dir
=
"Stable-diffusion"
model_dir
=
"Stable-diffusion"
model_path
=
os
.
path
.
abspath
(
os
.
path
.
join
(
models_path
,
model_dir
))
model_path
=
os
.
path
.
abspath
(
os
.
path
.
join
(
models_path
,
model_dir
))
CheckpointInfo
=
namedtuple
(
"CheckpointInfo"
,
[
'filename'
,
'title'
,
'hash'
,
'model_name'
,
'config'
])
CheckpointInfo
=
namedtuple
(
"CheckpointInfo"
,
[
'filename'
,
'title'
,
'hash'
,
'model_name'
])
checkpoints_list
=
{}
checkpoints_list
=
{}
checkpoints_loaded
=
collections
.
OrderedDict
()
checkpoints_loaded
=
collections
.
OrderedDict
()
...
@@ -48,6 +48,14 @@ def checkpoint_tiles():
...
@@ -48,6 +48,14 @@ def checkpoint_tiles():
return
sorted
([
x
.
title
for
x
in
checkpoints_list
.
values
()],
key
=
alphanumeric_key
)
return
sorted
([
x
.
title
for
x
in
checkpoints_list
.
values
()],
key
=
alphanumeric_key
)
def
find_checkpoint_config
(
info
):
config
=
os
.
path
.
splitext
(
info
.
filename
)[
0
]
+
".yaml"
if
os
.
path
.
exists
(
config
):
return
config
return
shared
.
cmd_opts
.
config
def
list_models
():
def
list_models
():
checkpoints_list
.
clear
()
checkpoints_list
.
clear
()
model_list
=
modelloader
.
load_models
(
model_path
=
model_path
,
command_path
=
shared
.
cmd_opts
.
ckpt_dir
,
ext_filter
=
[
".ckpt"
,
".safetensors"
])
model_list
=
modelloader
.
load_models
(
model_path
=
model_path
,
command_path
=
shared
.
cmd_opts
.
ckpt_dir
,
ext_filter
=
[
".ckpt"
,
".safetensors"
])
...
@@ -73,7 +81,7 @@ def list_models():
...
@@ -73,7 +81,7 @@ def list_models():
if
os
.
path
.
exists
(
cmd_ckpt
):
if
os
.
path
.
exists
(
cmd_ckpt
):
h
=
model_hash
(
cmd_ckpt
)
h
=
model_hash
(
cmd_ckpt
)
title
,
short_model_name
=
modeltitle
(
cmd_ckpt
,
h
)
title
,
short_model_name
=
modeltitle
(
cmd_ckpt
,
h
)
checkpoints_list
[
title
]
=
CheckpointInfo
(
cmd_ckpt
,
title
,
h
,
short_model_name
,
shared
.
cmd_opts
.
config
)
checkpoints_list
[
title
]
=
CheckpointInfo
(
cmd_ckpt
,
title
,
h
,
short_model_name
)
shared
.
opts
.
data
[
'sd_model_checkpoint'
]
=
title
shared
.
opts
.
data
[
'sd_model_checkpoint'
]
=
title
elif
cmd_ckpt
is
not
None
and
cmd_ckpt
!=
shared
.
default_sd_model_file
:
elif
cmd_ckpt
is
not
None
and
cmd_ckpt
!=
shared
.
default_sd_model_file
:
print
(
f
"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}"
,
file
=
sys
.
stderr
)
print
(
f
"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}"
,
file
=
sys
.
stderr
)
...
@@ -81,12 +89,7 @@ def list_models():
...
@@ -81,12 +89,7 @@ def list_models():
h
=
model_hash
(
filename
)
h
=
model_hash
(
filename
)
title
,
short_model_name
=
modeltitle
(
filename
,
h
)
title
,
short_model_name
=
modeltitle
(
filename
,
h
)
basename
,
_
=
os
.
path
.
splitext
(
filename
)
checkpoints_list
[
title
]
=
CheckpointInfo
(
filename
,
title
,
h
,
short_model_name
)
config
=
basename
+
".yaml"
if
not
os
.
path
.
exists
(
config
):
config
=
shared
.
cmd_opts
.
config
checkpoints_list
[
title
]
=
CheckpointInfo
(
filename
,
title
,
h
,
short_model_name
,
config
)
def
get_closet_checkpoint_match
(
searchString
):
def
get_closet_checkpoint_match
(
searchString
):
...
@@ -282,9 +285,10 @@ def enable_midas_autodownload():
...
@@ -282,9 +285,10 @@ def enable_midas_autodownload():
def
load_model
(
checkpoint_info
=
None
):
def
load_model
(
checkpoint_info
=
None
):
from
modules
import
lowvram
,
sd_hijack
from
modules
import
lowvram
,
sd_hijack
checkpoint_info
=
checkpoint_info
or
select_checkpoint
()
checkpoint_info
=
checkpoint_info
or
select_checkpoint
()
checkpoint_config
=
find_checkpoint_config
(
checkpoint_info
)
if
checkpoint_
info
.
config
!=
shared
.
cmd_opts
.
config
:
if
checkpoint_config
!=
shared
.
cmd_opts
.
config
:
print
(
f
"Loading config from: {checkpoint_
info.
config}"
)
print
(
f
"Loading config from: {checkpoint_config}"
)
if
shared
.
sd_model
:
if
shared
.
sd_model
:
sd_hijack
.
model_hijack
.
undo_hijack
(
shared
.
sd_model
)
sd_hijack
.
model_hijack
.
undo_hijack
(
shared
.
sd_model
)
...
@@ -292,7 +296,7 @@ def load_model(checkpoint_info=None):
...
@@ -292,7 +296,7 @@ def load_model(checkpoint_info=None):
gc
.
collect
()
gc
.
collect
()
devices
.
torch_gc
()
devices
.
torch_gc
()
sd_config
=
OmegaConf
.
load
(
checkpoint_
info
.
config
)
sd_config
=
OmegaConf
.
load
(
checkpoint_config
)
if
should_hijack_inpainting
(
checkpoint_info
):
if
should_hijack_inpainting
(
checkpoint_info
):
# Hardcoded config for now...
# Hardcoded config for now...
...
@@ -302,7 +306,7 @@ def load_model(checkpoint_info=None):
...
@@ -302,7 +306,7 @@ def load_model(checkpoint_info=None):
sd_config
.
model
.
params
.
finetune_keys
=
None
sd_config
.
model
.
params
.
finetune_keys
=
None
# Create a "fake" config with a different name so that we know to unload it when switching models.
# Create a "fake" config with a different name so that we know to unload it when switching models.
checkpoint_info
=
checkpoint_info
.
_replace
(
config
=
checkpoint_
info
.
config
.
replace
(
".yaml"
,
"-inpainting.yaml"
))
checkpoint_info
=
checkpoint_info
.
_replace
(
config
=
checkpoint_config
.
replace
(
".yaml"
,
"-inpainting.yaml"
))
if
not
hasattr
(
sd_config
.
model
.
params
,
"use_ema"
):
if
not
hasattr
(
sd_config
.
model
.
params
,
"use_ema"
):
sd_config
.
model
.
params
.
use_ema
=
False
sd_config
.
model
.
params
.
use_ema
=
False
...
@@ -343,11 +347,12 @@ def reload_model_weights(sd_model=None, info=None):
...
@@ -343,11 +347,12 @@ def reload_model_weights(sd_model=None, info=None):
sd_model
=
shared
.
sd_model
sd_model
=
shared
.
sd_model
current_checkpoint_info
=
sd_model
.
sd_checkpoint_info
current_checkpoint_info
=
sd_model
.
sd_checkpoint_info
checkpoint_config
=
find_checkpoint_config
(
current_checkpoint_info
)
if
sd_model
.
sd_model_checkpoint
==
checkpoint_info
.
filename
:
if
sd_model
.
sd_model_checkpoint
==
checkpoint_info
.
filename
:
return
return
if
sd_model
.
sd_checkpoint_info
.
config
!=
checkpoint_info
.
config
or
should_hijack_inpainting
(
checkpoint_info
)
!=
should_hijack_inpainting
(
sd_model
.
sd_checkpoint_info
):
if
checkpoint_config
!=
find_checkpoint_config
(
checkpoint_info
)
or
should_hijack_inpainting
(
checkpoint_info
)
!=
should_hijack_inpainting
(
sd_model
.
sd_checkpoint_info
):
del
sd_model
del
sd_model
checkpoints_loaded
.
clear
()
checkpoints_loaded
.
clear
()
load_model
(
checkpoint_info
)
load_model
(
checkpoint_info
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment