stable-diffusion-webui · Commits

Commit 50fb20ce authored Jan 10, 2023 by AUTOMATIC

Merge branch 'disable_initialization'

Parents: a0ef416a, 0f8603a5
Showing 3 changed files with 133 additions and 7 deletions:

- modules/modelloader.py (+3, -1)
- modules/sd_disable_initialization.py (+95, -0)
- modules/sd_models.py (+35, -6)
modules/modelloader.py
```diff
@@ -10,7 +10,7 @@ from modules.upscaler import Upscaler
 from modules.paths import script_path, models_path
 
-def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list:
+def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
     """
     A one-and done loader to try finding the desired models in specified directories.
@@ -45,6 +45,8 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
             full_path = file
             if os.path.isdir(full_path):
                 continue
+            if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
+                continue
             if len(ext_filter) != 0:
                 model_name, extension = os.path.splitext(file)
                 if extension not in ext_filter:
```
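The new `ext_blacklist` check runs before the ordinary extension filter and matches with `endswith` against the full path, so it can reject compound suffixes. A minimal sketch of the resulting filter order (the file names here are made up for illustration):

```python
import os.path

files = ["v1-5.ckpt", "v1-5.vae.safetensors", "v2-1.safetensors"]
ext_blacklist = [".vae.safetensors"]
ext_filter = [".ckpt", ".safetensors"]

kept = [
    f for f in files
    if not any(f.endswith(x) for x in ext_blacklist)  # blacklist first, on the full suffix
    and os.path.splitext(f)[1] in ext_filter          # then the plain extension filter
]
print(kept)  # ['v1-5.ckpt', 'v2-1.safetensors']
```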
modules/sd_disable_initialization.py (new file, mode 100644)
````python
import ldm.modules.encoders.modules
import open_clip
import torch
import transformers.utils.hub


class DisableInitialization:
    """
    When an object of this class enters a `with` block, it starts:
    - preventing torch's layer initialization functions from working
    - changes CLIP and OpenCLIP to not download model weights
    - changes CLIP to not make requests to check if there is a new version of a file you already have

    When it leaves the block, it reverts everything to how it was before.

    Use it like this:
    ```
    with DisableInitialization():
        do_things()
    ```
    """

    def __enter__(self):
        def do_nothing(*args, **kwargs):
            pass

        def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
            return self.create_model_and_transforms(*args, pretrained=None, **kwargs)

        def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
            return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)

        def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
            args = args[0:3] + ('/', ) + args[4:]  # resolved_archive_file; must set it to something to prevent what seems to be a bug
            return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)

        def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):

            # this file is always 404, prevent making request
            if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json':
                raise transformers.utils.hub.EntryNotFoundError

            try:
                return original(url, *args, local_files_only=True, **kwargs)
            except Exception as e:
                return original(url, *args, local_files_only=False, **kwargs)

        def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)

        def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)

        def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)

        self.init_kaiming_uniform = torch.nn.init.kaiming_uniform_
        self.init_no_grad_normal = torch.nn.init._no_grad_normal_
        self.init_no_grad_uniform_ = torch.nn.init._no_grad_uniform_
        self.create_model_and_transforms = open_clip.create_model_and_transforms
        self.CLIPTextModel_from_pretrained = ldm.modules.encoders.modules.CLIPTextModel.from_pretrained
        self.transformers_modeling_utils_load_pretrained_model = getattr(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', None)
        self.transformers_tokenization_utils_base_cached_file = getattr(transformers.tokenization_utils_base, 'cached_file', None)
        self.transformers_configuration_utils_cached_file = getattr(transformers.configuration_utils, 'cached_file', None)
        self.transformers_utils_hub_get_from_cache = getattr(transformers.utils.hub, 'get_from_cache', None)

        torch.nn.init.kaiming_uniform_ = do_nothing
        torch.nn.init._no_grad_normal_ = do_nothing
        torch.nn.init._no_grad_uniform_ = do_nothing
        open_clip.create_model_and_transforms = create_model_and_transforms_without_pretrained
        ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = CLIPTextModel_from_pretrained
        if self.transformers_modeling_utils_load_pretrained_model is not None:
            transformers.modeling_utils.PreTrainedModel._load_pretrained_model = transformers_modeling_utils_load_pretrained_model
        if self.transformers_tokenization_utils_base_cached_file is not None:
            transformers.tokenization_utils_base.cached_file = transformers_tokenization_utils_base_cached_file
        if self.transformers_configuration_utils_cached_file is not None:
            transformers.configuration_utils.cached_file = transformers_configuration_utils_cached_file
        if self.transformers_utils_hub_get_from_cache is not None:
            transformers.utils.hub.get_from_cache = transformers_utils_hub_get_from_cache

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.nn.init.kaiming_uniform_ = self.init_kaiming_uniform
        torch.nn.init._no_grad_normal_ = self.init_no_grad_normal
        torch.nn.init._no_grad_uniform_ = self.init_no_grad_uniform_
        open_clip.create_model_and_transforms = self.create_model_and_transforms
        ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = self.CLIPTextModel_from_pretrained
        if self.transformers_modeling_utils_load_pretrained_model is not None:
            transformers.modeling_utils.PreTrainedModel._load_pretrained_model = self.transformers_modeling_utils_load_pretrained_model
        if self.transformers_tokenization_utils_base_cached_file is not None:
            transformers.utils.hub.cached_file = self.transformers_tokenization_utils_base_cached_file
        if self.transformers_configuration_utils_cached_file is not None:
            transformers.utils.hub.cached_file = self.transformers_configuration_utils_cached_file
        if self.transformers_utils_hub_get_from_cache is not None:
            transformers.utils.hub.get_from_cache = self.transformers_utils_hub_get_from_cache
````
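The class is a straightforward save-patch-restore context manager: `__enter__` stashes each original callable on `self` and swaps in a wrapper, and `__exit__` puts the originals back. A stripped-down, self-contained analogue of the pattern (the `SkipKaimingInit` name is mine, not part of the commit):

```python
import torch


class SkipKaimingInit:
    """Hypothetical miniature of DisableInitialization: temporarily
    replaces one torch init function with a no-op, then restores it."""

    def __enter__(self):
        self.original = torch.nn.init.kaiming_uniform_

        def do_nothing(*args, **kwargs):
            pass

        torch.nn.init.kaiming_uniform_ = do_nothing
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # __exit__ runs even if the with-body raised, so the patch cannot leak
        torch.nn.init.kaiming_uniform_ = self.original


with SkipKaimingInit():
    layer = torch.nn.Linear(4, 4)  # weight init is skipped; values are whatever the buffer held
```

The cached-file wrappers add one more trick on top of this: they first call the original with `local_files_only=True`, so files already on disk never trigger an HTTP freshness check, and only fall back to a network request when the local lookup fails.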
modules/sd_models.py
```diff
@@ -2,6 +2,7 @@ import collections
 import os.path
 import sys
 import gc
+import time
 from collections import namedtuple
 import torch
 import re
@@ -13,7 +14,7 @@ import ldm.modules.midas as midas
 from ldm.util import instantiate_from_config
 
-from modules import shared, modelloader, devices, script_callbacks, sd_vae
+from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors
 from modules.paths import models_path
 from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
@@ -61,7 +62,7 @@ def find_checkpoint_config(info):
 def list_models():
     checkpoints_list.clear()
-    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
+    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"])
 
     def modeltitle(path, shorthash):
         abspath = os.path.abspath(path)
```
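The blacklist entry needs its own `endswith` check rather than the extension filter because `os.path.splitext` strips only the final extension: a `.vae.safetensors` file still passes an `ext_filter` containing `.safetensors`. A quick check:

```python
import os.path

print(os.path.splitext("v1-5.vae.safetensors"))  # ('v1-5.vae', '.safetensors')
# '.safetensors' is in ext_filter, so without ext_blacklist the VAE file
# would be listed as a loadable checkpoint.
print("v1-5.vae.safetensors".endswith(".vae.safetensors"))  # True: the blacklist catches it
```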
Back in modules/sd_models.py, a small Timer helper lands above load_model:

```diff
@@ -288,6 +289,17 @@ def enable_midas_autodownload():
     midas.api.load_model = load_model_wrapper
 
 
+class Timer:
+    def __init__(self):
+        self.start = time.time()
+
+    def elapsed(self):
+        end = time.time()
+        res = end - self.start
+        self.start = end
+        return res
+
+
 def load_model(checkpoint_info=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
```
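Timer.elapsed() returns the time since the previous call and resets the clock, so consecutive calls measure consecutive phases rather than a running total. A small demonstration (the sleep durations are arbitrary stand-ins):

```python
import time


class Timer:
    def __init__(self):
        self.start = time.time()

    def elapsed(self):
        end = time.time()
        res = end - self.start
        self.start = end  # reset, so the next call measures only the next phase
        return res


timer = Timer()
time.sleep(0.2)                         # stands in for creating the model
elapsed_create = timer.elapsed()        # ~0.2
time.sleep(0.3)                         # stands in for loading weights
elapsed_load_weights = timer.elapsed()  # ~0.3, not ~0.5
print(f"Model loaded in {elapsed_create + elapsed_load_weights:.1f}s "
      f"({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).")
```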
In load_model, model instantiation is wrapped in DisableInitialization with a slow-path fallback, and the new Timer measures each phase:

```diff
@@ -319,10 +331,21 @@ def load_model(checkpoint_info=None):
     if shared.cmd_opts.no_half:
         sd_config.model.params.unet_config.params.use_fp16 = False
 
-    sd_model = instantiate_from_config(sd_config.model)
+    timer = Timer()
+
+    try:
+        with sd_disable_initialization.DisableInitialization():
+            sd_model = instantiate_from_config(sd_config.model)
+    except Exception as e:
+        print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
+        sd_model = instantiate_from_config(sd_config.model)
+
+    elapsed_create = timer.elapsed()
 
     load_model_weights(sd_model, checkpoint_info)
 
+    elapsed_load_weights = timer.elapsed()
+
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
         lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
     else:
```
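The guard matters because the monkey-patches in DisableInitialization depend on library internals: if they don't match the installed transformers/open_clip version, instantiation fails, and the code retries without the context manager instead of aborting the load. The shape of that guard in isolation (create_fast and create_slow are hypothetical stand-ins for instantiate_from_config with and without DisableInitialization):

```python
import sys


def create_fast():
    # stand-in for instantiate_from_config(...) inside DisableInitialization
    raise RuntimeError("patched code path not supported by this library version")


def create_slow():
    # stand-in for plain instantiate_from_config(...)
    return "model"


try:
    sd_model = create_fast()
except Exception:
    print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
    sd_model = create_slow()

print(sd_model)  # "model", reached via the slow path
```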
The remaining hunks time the tail of load_model and the weight reload in reload_model_weights:

```diff
@@ -337,7 +360,9 @@ def load_model(checkpoint_info=None):
     script_callbacks.model_loaded_callback(sd_model)
 
-    print("Model loaded.")
+    elapsed_the_rest = timer.elapsed()
+
+    print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).")
 
     return sd_model
@@ -348,7 +373,7 @@ def reload_model_weights(sd_model=None, info=None):
     if not sd_model:
         sd_model = shared.sd_model
 
     if sd_model is None:  # previous model load failed
         current_checkpoint_info = None
     else:
         current_checkpoint_info = sd_model.sd_checkpoint_info
@@ -370,6 +395,8 @@ def reload_model_weights(sd_model=None, info=None):
     sd_hijack.model_hijack.undo_hijack(sd_model)
 
+    timer = Timer()
+
     try:
         load_model_weights(sd_model, checkpoint_info)
     except Exception as e:
@@ -383,6 +410,8 @@ def reload_model_weights(sd_model=None, info=None):
     if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
         sd_model.to(devices.device)
 
-    print("Weights loaded.")
+    elapsed = timer.elapsed()
+
+    print(f"Weights loaded in {elapsed:.1f}s.")
 
     return sd_model
```