Administrator / stable-diffusion-webui / Commits / d2ac95fa

Commit d2ac95fa, authored Jan 27, 2023 by AUTOMATIC
remove the need to place configs near models
parent 7a14c8ab
Showing 10 changed files with 360 additions and 151 deletions (+360, -151):
    configs/instruct-pix2pix.yaml          +99    -0
    configs/v1-inpainting-inference.yaml   +19   -17
    modules/api/api.py                      +3    -2
    modules/devices.py                      +8    -4
    modules/sd_hijack_inpainting.py         +0    -9
    modules/sd_models.py                  +113  -115
    modules/sd_models_config.py            +65    -0
    modules/shared.py                       +4    -3
    modules/shared_items.py                +14    -1
    modules/timer.py                       +35    -0
configs/instruct-pix2pix.yaml (new file, 0 → 100644)

# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
# See more details in LICENSE.

model:
  base_learning_rate: 1.0e-04
  target: modules.models.diffusion.ddpm_edit.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: edited
    cond_stage_key: edit
    # image_size: 64
    # image_size: 32
    image_size: 16
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: hybrid
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: true
    load_ema: true

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 0 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 128
    num_workers: 1
    wrap: false
    validation:
      target: edit_dataset.EditDataset
      params:
        path: data/clip-filtered-dataset
        cache_dir: data/
        cache_name: data_10k
        split: val
        min_text_sim: 0.2
        min_image_sim: 0.75
        min_direction_sim: 0.2
        max_samples_per_prompt: 1
        min_resize_res: 512
        max_resize_res: 512
        crop_res: 512
        output_as_edit: False
        real_input: True
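The config now ships with the webui instead of having to sit next to the model file. A minimal sketch of how a config like this is consumed — OmegaConf.load and instantiate_from_config are the calls modules/sd_models.py uses below; running it assumes the ldm repo and the webui's modules package are importable and the path is relative to the webui root:

```python
# Minimal sketch, assuming the ldm package and webui modules are on sys.path.
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

sd_config = OmegaConf.load("configs/instruct-pix2pix.yaml")

# dotted access mirrors the YAML structure
assert sd_config.model.params.unet_config.params.in_channels == 8

model = instantiate_from_config(sd_config.model)  # builds ddpm_edit.LatentDiffusion
```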
v2-inference-v.yaml → configs/v1-inpainting-inference.yaml (renamed)

@@ -1,8 +1,7 @@
 model:
-  base_learning_rate: 1.0e-4
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  base_learning_rate: 7.5e-05
+  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
   params:
-    parameterization: "v"
     linear_start: 0.00085
     linear_end: 0.0120
     num_timesteps_cond: 1
...
@@ -12,29 +11,36 @@ model:
     cond_stage_key: "txt"
     image_size: 64
     channels: 4
-    cond_stage_trainable: false
-    conditioning_key: crossattn
+    cond_stage_trainable: false   # Note: different from the one we trained before
+    conditioning_key: hybrid   # important
     monitor: val/loss_simple_ema
     scale_factor: 0.18215
-    use_ema: False # we set this to false because this is an inference only config
+    finetune_keys: null
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]

     unet_config:
       target: ldm.modules.diffusionmodules.openaimodel.UNetModel
       params:
-        use_checkpoint: True
-        use_fp16: True
         image_size: 32 # unused
-        in_channels: 4
+        in_channels: 9  # 4 data + 4 downscaled image + 1 mask
         out_channels: 4
         model_channels: 320
         attention_resolutions: [ 4, 2, 1 ]
         num_res_blocks: 2
         channel_mult: [ 1, 2, 4, 4 ]
-        num_head_channels: 64 # need to fix for flash-attn
+        num_heads: 8
         use_spatial_transformer: True
-        use_linear_in_transformer: True
         transformer_depth: 1
-        context_dim: 1024
+        context_dim: 768
+        use_checkpoint: True
         legacy: False
...
@@ -43,7 +49,6 @@ model:
         embed_dim: 4
         monitor: val/rec_loss
         ddconfig:
-          #attn_type: "vanilla-xformers"
           double_z: true
           z_channels: 4
           resolution: 256
...
@@ -62,7 +67,4 @@ model:
         target: torch.nn.Identity

     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
-        freeze: True
-        layer: "penultimate"
\ No newline at end of file
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
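The `in_channels: 9` line is the crux of this config: the inpainting UNet takes the noisy latent concatenated with the masked-image latent and the mask, exactly as the inline comment says. An illustrative sketch (shapes assume a 512×512 image and its 64×64 latent):

```python
# Why in_channels is 9 for the inpainting UNet: the model input is the noisy
# latent concatenated with the masked-image latent and the mask.
import torch

noisy_latent = torch.randn(1, 4, 64, 64)         # 4 data channels
masked_image_latent = torch.randn(1, 4, 64, 64)  # 4 downscaled image channels
mask = torch.ones(1, 1, 64, 64)                  # 1 mask channel

unet_input = torch.cat([noisy_latent, masked_image_latent, mask], dim=1)
print(unet_input.shape)  # torch.Size([1, 9, 64, 64])
```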
modules/api/api.py

...
@@ -18,7 +18,8 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
 from modules.textual_inversion.preprocess import preprocess
 from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
 from PIL import PngImagePlugin, Image
-from modules.sd_models import checkpoints_list, find_checkpoint_config
+from modules.sd_models import checkpoints_list
+from modules.sd_models_config import find_checkpoint_config_near_filename
 from modules.realesrgan_model import get_realesrgan_models
 from modules import devices
 from typing import List
...
@@ -387,7 +388,7 @@ class Api:
         ]

     def get_sd_models(self):
-        return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
+        return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]

     def get_hypernetworks(self):
         return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
...
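get_sd_models backs the /sdapi/v1/sd-models route, so API clients now see either a .yaml path found next to the checkpoint or null when the config will instead be guessed from the state dict at load time. A hedged client-side sketch, assuming a local webui started with --api on the default port:

```python
# Hypothetical query against a local webui started with --api.
import requests

models = requests.get("http://127.0.0.1:7860/sdapi/v1/sd-models").json()
for m in models:
    # "config" is now find_checkpoint_config_near_filename(x): a .yaml path
    # next to the checkpoint, or None when the config will be guessed instead
    print(m["title"], m["config"])
```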
modules/devices.py

...
@@ -34,14 +34,18 @@ def get_cuda_device_string():
     return "cuda"


-def get_optimal_device():
+def get_optimal_device_name():
     if torch.cuda.is_available():
-        return torch.device(get_cuda_device_string())
+        return get_cuda_device_string()

     if has_mps():
-        return torch.device("mps")
+        return "mps"

-    return cpu
+    return "cpu"
+
+
+def get_optimal_device():
+    return torch.device(get_optimal_device_name())


 def get_device_for(task):
...
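Returning a device name rather than a torch.device lets callers fold the default into an `or` chain of optional string settings, which is exactly what read_state_dict does below. A stand-alone sketch of the pattern (the fallback here is a simplified stand-in that skips the MPS branch):

```python
import torch

def get_optimal_device_name():
    # simplified stand-in for modules/devices.py after this commit (no MPS check)
    return "cuda" if torch.cuda.is_available() else "cpu"

map_location = None            # per-call override
weight_load_location = None    # e.g. the shared.weight_load_location setting

device = map_location or weight_load_location or get_optimal_device_name()
print(device)  # "cuda" on a CUDA machine, otherwise "cpu"
```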
modules/sd_hijack_inpainting.py

...
@@ -96,15 +96,6 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
     return x_prev, pred_x0, e_t


-def should_hijack_inpainting(checkpoint_info):
-    from modules import sd_models
-
-    ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
-    cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
-
-    return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
-
-
 def do_inpainting_hijack():
     # p_sample_plms is needed because PLMS can't work with dicts as conditionings
...
modules/sd_models.py

...
@@ -2,8 +2,6 @@ import collections
 import os.path
 import sys
 import gc
-import time
-from collections import namedtuple
 import torch
 import re
 import safetensors.torch
...
@@ -14,10 +12,10 @@ import ldm.modules.midas as midas
 from ldm.util import instantiate_from_config

-from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes
+from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
 from modules.paths import models_path
-from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
-from modules.sd_hijack_ip2p import should_hijack_ip2p
+from modules.sd_hijack_inpainting import do_inpainting_hijack
+from modules.timer import Timer

 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(models_path, model_dir))
...
@@ -99,17 +97,6 @@ def checkpoint_tiles():
     return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)


-def find_checkpoint_config(info):
-    if info is None:
-        return shared.cmd_opts.config
-
-    config = os.path.splitext(info.filename)[0] + ".yaml"
-    if os.path.exists(config):
-        return config
-
-    return shared.cmd_opts.config
-
-
 def list_models():
     checkpoints_list.clear()
     checkpoint_alisases.clear()
...
@@ -215,9 +202,7 @@ def get_state_dict_from_checkpoint(pl_sd):
 def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
     _, extension = os.path.splitext(checkpoint_file)
     if extension.lower() == ".safetensors":
-        device = map_location or shared.weight_load_location
-        if device is None:
-            device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
+        device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
         pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
     else:
         pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
...
@@ -229,60 +214,74 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
     return sd


-def load_model_weights(model, checkpoint_info: CheckpointInfo):
+def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
+    sd_model_hash = checkpoint_info.calculate_shorthash()
+    timer.record("calculate hash")
+
+    if checkpoint_info in checkpoints_loaded:
+        # use checkpoint cache
+        print(f"Loading weights [{sd_model_hash}] from cache")
+        return checkpoints_loaded[checkpoint_info]
+
+    print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
+    res = read_state_dict(checkpoint_info.filename)
+    timer.record("load weights from disk")
+
+    return res
+
+
+def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     title = checkpoint_info.title
     sd_model_hash = checkpoint_info.calculate_shorthash()
+    timer.record("calculate hash")
+
     if checkpoint_info.title != title:
         shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

-    cache_enabled = shared.opts.sd_checkpoint_cache > 0
+    if state_dict is None:
+        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

-    if cache_enabled and checkpoint_info in checkpoints_loaded:
-        # use checkpoint cache
-        print(f"Loading weights [{sd_model_hash}] from cache")
-        model.load_state_dict(checkpoints_loaded[checkpoint_info])
-    else:
-        # load from file
-        print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
+    model.load_state_dict(state_dict, strict=False)
+    del state_dict
+    timer.record("apply weights to model")

-        sd = read_state_dict(checkpoint_info.filename)
-        model.load_state_dict(sd, strict=False)
-        del sd
-
-        if cache_enabled:
-            # cache newly loaded model
-            checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+    if shared.opts.sd_checkpoint_cache > 0:
+        # cache newly loaded model
+        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()

     if shared.cmd_opts.opt_channelslast:
         model.to(memory_format=torch.channels_last)
+        timer.record("apply channels_last")

     if not shared.cmd_opts.no_half:
         vae = model.first_stage_model
         depth_model = getattr(model, 'depth_model', None)

         # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
         if shared.cmd_opts.no_half_vae:
             model.first_stage_model = None
         # with --upcast-sampling, don't convert the depth model weights to float16
         if shared.cmd_opts.upcast_sampling and depth_model:
             model.depth_model = None

         model.half()
         model.first_stage_model = vae
         if depth_model:
             model.depth_model = depth_model

+        timer.record("apply half()")
+
     devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
     devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
     devices.dtype_unet = model.model.diffusion_model.dtype
     devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16

     model.first_stage_model.to(devices.dtype_vae)
+    timer.record("apply dtype to VAE")

-    if cache_enabled:
-        # clean up cache if limit is reached
-        while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
-            checkpoints_loaded.popitem(last=False)  # LRU
+    # clean up cache if limit is reached
+    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+        checkpoints_loaded.popitem(last=False)

     model.sd_model_hash = sd_model_hash
     model.sd_model_checkpoint = checkpoint_info.filename
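checkpoints_loaded is a collections.OrderedDict elsewhere in this module, so popitem(last=False) evicts the oldest cached checkpoint first. A runnable miniature of the eviction loop above:

```python
# Miniature of the cache-eviction pattern: an OrderedDict with FIFO eviction
# once the configured limit is exceeded.
import collections

sd_checkpoint_cache = 1  # the "Checkpoints to cache in RAM" setting
checkpoints_loaded = collections.OrderedDict()

for name in ["model-a", "model-b", "model-c"]:
    checkpoints_loaded[name] = {"fake": "state_dict"}
    while len(checkpoints_loaded) > sd_checkpoint_cache:
        checkpoints_loaded.popitem(last=False)  # drop the oldest entry

print(list(checkpoints_loaded))  # ['model-c']
```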
...
@@ -295,6 +294,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo):
     sd_vae.clear_loaded_vae()
     vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
     sd_vae.load_vae(model, vae_file, vae_source)
+    timer.record("load VAE")


 def enable_midas_autodownload():
...
@@ -340,24 +340,20 @@ def enable_midas_autodownload():
     midas.api.load_model = load_model_wrapper


-class Timer:
-    def __init__(self):
-        self.start = time.time()
+def repair_config(sd_config):

-    def elapsed(self):
-        end = time.time()
-        res = end - self.start
-        self.start = end
-        return res
+    if not hasattr(sd_config.model.params, "use_ema"):
+        sd_config.model.params.use_ema = False
+
+    if shared.cmd_opts.no_half:
+        sd_config.model.params.unet_config.params.use_fp16 = False
+    elif shared.cmd_opts.upcast_sampling:
+        sd_config.model.params.unet_config.params.use_fp16 = True


-def load_model(checkpoint_info=None):
+def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
-    checkpoint_config = find_checkpoint_config(checkpoint_info)
-
-    if checkpoint_config != shared.cmd_opts.config:
-        print(f"Loading config from: {checkpoint_config}")

     if shared.sd_model:
         sd_hijack.model_hijack.undo_hijack(shared.sd_model)
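repair_config replaces the per-model config surgery that used to live inline in load_model: whatever .yaml was found or guessed gets patched for inference under the current command-line flags. A minimal sketch of its effect on an OmegaConf tree (cmd_opts here is a stand-in for shared.cmd_opts, not a webui API):

```python
# Minimal sketch of repair_config's effect, assuming an OmegaConf config
# shaped like the YAML files above.
from omegaconf import OmegaConf

sd_config = OmegaConf.create({"model": {"params": {"unet_config": {"params": {}}}}})

class cmd_opts:  # stand-in for shared.cmd_opts
    no_half = False
    upcast_sampling = True

if not hasattr(sd_config.model.params, "use_ema"):
    sd_config.model.params.use_ema = False  # default missing use_ema for inference

if cmd_opts.no_half:
    sd_config.model.params.unet_config.params.use_fp16 = False
elif cmd_opts.upcast_sampling:
    sd_config.model.params.unet_config.params.use_fp16 = True

print(sd_config.model.params.use_ema)                        # False
print(sd_config.model.params.unet_config.params.use_fp16)    # True
```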
...
@@ -365,38 +361,27 @@ def load_model(checkpoint_info=None):
         gc.collect()
         devices.torch_gc()

-    sd_config = OmegaConf.load(checkpoint_config)
-
-    if should_hijack_inpainting(checkpoint_info):
-        # Hardcoded config for now...
-        sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
-        sd_config.model.params.conditioning_key = "hybrid"
-        sd_config.model.params.unet_config.params.in_channels = 9
-        sd_config.model.params.finetune_keys = None
-
-    if should_hijack_ip2p(checkpoint_info):
-        sd_config.model.target = "modules.models.diffusion.ddpm_edit.LatentDiffusion"
-        sd_config.model.params.conditioning_key = "hybrid"
-        sd_config.model.params.first_stage_key = "edited"
-        sd_config.model.params.cond_stage_key = "edit"
-        sd_config.model.params.image_size = 16
-        sd_config.model.params.unet_config.params.in_channels = 8
-        sd_config.model.params.unet_config.params.out_channels = 4
-
     do_inpainting_hijack()

-    if not hasattr(sd_config.model.params, "use_ema"):
-        sd_config.model.params.use_ema = False
+    timer = Timer()

-    if shared.cmd_opts.no_half:
-        sd_config.model.params.unet_config.params.use_fp16 = False
-    elif shared.cmd_opts.upcast_sampling:
-        sd_config.model.params.unet_config.params.use_fp16 = True
+    if already_loaded_state_dict is not None:
+        state_dict = already_loaded_state_dict
+    else:
+        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

-    timer = Timer()
+    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)

-    sd_model = None
+    timer.record("find config")
+
+    sd_config = OmegaConf.load(checkpoint_config)
+    repair_config(sd_config)
+
+    timer.record("load config")
+
+    print(f"Creating model from config: {checkpoint_config}")
+
+    sd_model = None

     try:
         with sd_disable_initialization.DisableInitialization():
             sd_model = instantiate_from_config(sd_config.model)
...
@@ -407,29 +392,35 @@ def load_model(checkpoint_info=None):
         print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
         sd_model = instantiate_from_config(sd_config.model)

-    elapsed_create = timer.elapsed()
+    sd_model.used_config = checkpoint_config

-    load_model_weights(sd_model, checkpoint_info)
+    timer.record("create model")

-    elapsed_load_weights = timer.elapsed()
+    load_model_weights(sd_model, checkpoint_info, state_dict, timer)

     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
         lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
     else:
         sd_model.to(shared.device)

+    timer.record("move model to device")
+
     sd_hijack.model_hijack.hijack(sd_model)

+    timer.record("hijack")
+
     sd_model.eval()
     shared.sd_model = sd_model

     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model

+    timer.record("load textual inversion embeddings")
+
     script_callbacks.model_loaded_callback(sd_model)

-    elapsed_the_rest = timer.elapsed()
+    timer.record("scripts callbacks")

-    print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).")
+    print(f"Model loaded in {timer.summary()}.")

     return sd_model
...
@@ -440,6 +431,7 @@ def reload_model_weights(sd_model=None, info=None):
     if not sd_model:
         sd_model = shared.sd_model
+
     if sd_model is None:  # previous model load failed
         current_checkpoint_info = None
     else:
@@ -447,14 +439,6 @@ def reload_model_weights(sd_model=None, info=None):
         if sd_model.sd_model_checkpoint == checkpoint_info.filename:
             return

-    checkpoint_config = find_checkpoint_config(current_checkpoint_info)
-
-    if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info) or should_hijack_ip2p(checkpoint_info) != should_hijack_ip2p(sd_model.sd_checkpoint_info):
-        del sd_model
-        checkpoints_loaded.clear()
-        load_model(checkpoint_info)
-        return shared.sd_model
-
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
         lowvram.send_everything_to_cpu()
     else:
...
@@ -464,21 +448,35 @@ def reload_model_weights(sd_model=None, info=None):
     timer = Timer()

+    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
+
+    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
+
+    timer.record("find config")
+
+    if sd_model is None or checkpoint_config != sd_model.used_config:
+        del sd_model
+        checkpoints_loaded.clear()
+        load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
+        return shared.sd_model
+
     try:
-        load_model_weights(sd_model, checkpoint_info)
+        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
     except Exception as e:
         print("Failed to load checkpoint, restoring previous")
-        load_model_weights(sd_model, current_checkpoint_info)
+        load_model_weights(sd_model, current_checkpoint_info, None, timer)
         raise
     finally:
         sd_hijack.model_hijack.hijack(sd_model)
+        timer.record("hijack")
+
         script_callbacks.model_loaded_callback(sd_model)
+        timer.record("script callbacks")

         if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
             sd_model.to(devices.device)
+            timer.record("move model to device")

-    elapsed = timer.elapsed()
-
-    print(f"Weights loaded in {elapsed:.1f}s.")
+    print(f"Weights loaded in {timer.summary()}.")

     return sd_model
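The net effect of the sd_models.py changes: the checkpoint is read from disk once, the config is chosen from the tensors themselves, and the same state dict is reused to populate the model. A hedged usage sketch, assuming a running webui process where the modules package is importable and list_models() has run:

```python
# Hedged usage sketch; load_model() with no arguments falls back to
# select_checkpoint(), reads the state dict once, picks a config for it,
# and reuses the same tensors for the weight load.
from modules import sd_models

model = sd_models.load_model()
print(model.used_config)  # the .yaml found near the file, or the guessed one
```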
modules/sd_models_config.py (new file, 0 → 100644)

import re
import os

from modules import shared, paths

sd_configs_path = shared.sd_configs_path
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")


config_default = shared.sd_default_config
config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")

re_parametrization_v = re.compile(r'-v\b')


def guess_model_config_from_state_dict(sd, filename):
    fn = os.path.basename(filename)

    sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
    diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
    roberta_weight = sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None)

    if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
        if re.search(re_parametrization_v, fn) or "v2-1_768" in fn:
            return config_sd2v
        else:
            return config_sd2

    if diffusion_model_input is not None:
        if diffusion_model_input.shape[1] == 9:
            return config_inpainting
        if diffusion_model_input.shape[1] == 8:
            return config_instruct_pix2pix

    if roberta_weight is not None:
        return config_alt_diffusion

    return config_default


def find_checkpoint_config(state_dict, info):
    if info is None:
        return guess_model_config_from_state_dict(state_dict, "")

    config = find_checkpoint_config_near_filename(info)
    if config is not None:
        return config

    return guess_model_config_from_state_dict(state_dict, info.filename)


def find_checkpoint_config_near_filename(info):
    if info is None:
        return None

    config = os.path.splitext(info.filename)[0] + ".yaml"
    if os.path.exists(config):
        return config

    return None
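The dispatch is purely shape-based: SD2 checkpoints are recognized by their 1024-wide text-projection weight, inpainting and instruct-pix2pix models by the width of the UNet's first convolution. A runnable illustration with toy tensors in place of real checkpoints (classify is a condensed restatement of the branches above, not a webui function):

```python
# Toy tensors stand in for real checkpoints; the key names are the ones
# guess_model_config_from_state_dict actually checks.
import torch

def classify(sd, filename=""):
    w = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight')
    x = sd.get('model.diffusion_model.input_blocks.0.0.weight')
    if w is not None and w.shape[1] == 1024:
        return "SD2 (v-parameterization if the filename matches -v or v2-1_768)"
    if x is not None and x.shape[1] == 9:
        return "v1 inpainting"
    if x is not None and x.shape[1] == 8:
        return "instruct-pix2pix"
    return "default v1"

inpaint_sd = {'model.diffusion_model.input_blocks.0.0.weight': torch.zeros(320, 9, 3, 3)}
ip2p_sd = {'model.diffusion_model.input_blocks.0.0.weight': torch.zeros(320, 8, 3, 3)}
print(classify(inpaint_sd))  # v1 inpainting
print(classify(ip2p_sd))     # instruct-pix2pix
```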
modules/shared.py

...
@@ -13,13 +13,14 @@ import modules.interrogate
 import modules.memmon
 import modules.styles
 import modules.devices as devices
-from modules import localization, sd_vae, extensions, script_loading, errors, ui_components, shared_items
+from modules import localization, extensions, script_loading, errors, ui_components, shared_items
 from modules.paths import models_path, script_path

 demo = None

-sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
+sd_configs_path = os.path.join(script_path, "configs")
+sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")

 sd_model_file = os.path.join(script_path, 'model.ckpt')
 default_sd_model_file = sd_model_file
...
@@ -391,7 +392,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
     "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
     "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
-    "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
+    "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
     "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
     "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
     "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
...
modules/shared_items.py

...
@@ -4,7 +4,20 @@ def realesrgan_models_names():
     import modules.realesrgan_model

     return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]


 def postprocessing_scripts():
     import modules.scripts

-    return modules.scripts.scripts_postproc.scripts
\ No newline at end of file
+    return modules.scripts.scripts_postproc.scripts
+
+
+def sd_vae_items():
+    import modules.sd_vae
+
+    return ["Automatic", "None"] + list(modules.sd_vae.vae_dict)
+
+
+def refresh_vae_list():
+    import modules.sd_vae
+
+    return modules.sd_vae.refresh_vae_list
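Like the existing helpers in this file, the new functions import modules.sd_vae inside the body; that is what lets shared.py drop sd_vae from its module-level imports above. A generic sketch of the deferred-import pattern (vae_choices is a hypothetical name, not a webui function):

```python
import importlib

def vae_choices():
    # deferred import: modules.sd_vae is resolved on first call rather than
    # when this module is imported, so there is no import cycle at startup
    sd_vae = importlib.import_module("modules.sd_vae")
    return ["Automatic", "None"] + list(sd_vae.vae_dict)
```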
modules/timer.py (new file, 0 → 100644)

import time


class Timer:
    def __init__(self):
        self.start = time.time()
        self.records = {}
        self.total = 0

    def elapsed(self):
        end = time.time()
        res = end - self.start
        self.start = end
        return res

    def record(self, category, extra_time=0):
        e = self.elapsed()

        if category not in self.records:
            self.records[category] = 0

        self.records[category] += e + extra_time
        self.total += e + extra_time

    def summary(self):
        res = f"{self.total:.1f}s"

        additions = [x for x in self.records.items() if x[1] >= 0.1]
        if not additions:
            return res

        res += " ("
        res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
        res += ")"

        return res
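A usage sketch of the new Timer, mirroring how load_model threads it through the load path (the sleeps stand in for real work, and the exact numbers in the sample output will vary):

```python
# Usage sketch; assumes the webui's modules/ directory is on sys.path.
import time
from modules.timer import Timer

timer = Timer()
time.sleep(0.25)
timer.record("load weights from disk")
time.sleep(0.05)   # below the 0.1s threshold, so omitted from the summary
timer.record("find config")
time.sleep(0.15)
timer.record("apply weights to model")

print(f"Model loaded in {timer.summary()}.")
# e.g.: Model loaded in 0.5s (load weights from disk: 0.3s, apply weights to model: 0.2s).
```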