Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
50fb20ce
Commit
50fb20ce
authored
Jan 10, 2023
by
AUTOMATIC
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'disable_initialization'
parents
a0ef416a
0f8603a5
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
133 additions
and
7 deletions
+133
-7
modelloader.py
modules/modelloader.py
+3
-1
sd_disable_initialization.py
modules/sd_disable_initialization.py
+95
-0
sd_models.py
modules/sd_models.py
+35
-6
No files found.
modules/modelloader.py
View file @
50fb20ce
...
...
@@ -10,7 +10,7 @@ from modules.upscaler import Upscaler
from
modules.paths
import
script_path
,
models_path
def
load_models
(
model_path
:
str
,
model_url
:
str
=
None
,
command_path
:
str
=
None
,
ext_filter
=
None
,
download_name
=
None
)
->
list
:
def
load_models
(
model_path
:
str
,
model_url
:
str
=
None
,
command_path
:
str
=
None
,
ext_filter
=
None
,
download_name
=
None
,
ext_blacklist
=
None
)
->
list
:
"""
A one-and done loader to try finding the desired models in specified directories.
...
...
@@ -45,6 +45,8 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
full_path
=
file
if
os
.
path
.
isdir
(
full_path
):
continue
if
ext_blacklist
is
not
None
and
any
([
full_path
.
endswith
(
x
)
for
x
in
ext_blacklist
]):
continue
if
len
(
ext_filter
)
!=
0
:
model_name
,
extension
=
os
.
path
.
splitext
(
file
)
if
extension
not
in
ext_filter
:
...
...
modules/sd_disable_initialization.py
0 → 100644
View file @
50fb20ce
import
ldm.modules.encoders.modules
import
open_clip
import
torch
import
transformers.utils.hub
class
DisableInitialization
:
"""
When an object of this class enters a `with` block, it starts:
- preventing torch's layer initialization functions from working
- changes CLIP and OpenCLIP to not download model weights
- changes CLIP to not make requests to check if there is a new version of a file you already have
When it leaves the block, it reverts everything to how it was before.
Use it like this:
```
with DisableInitialization():
do_things()
```
"""
def
__enter__
(
self
):
def
do_nothing
(
*
args
,
**
kwargs
):
pass
def
create_model_and_transforms_without_pretrained
(
*
args
,
pretrained
=
None
,
**
kwargs
):
return
self
.
create_model_and_transforms
(
*
args
,
pretrained
=
None
,
**
kwargs
)
def
CLIPTextModel_from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
return
self
.
CLIPTextModel_from_pretrained
(
None
,
*
model_args
,
config
=
pretrained_model_name_or_path
,
state_dict
=
{},
**
kwargs
)
def
transformers_modeling_utils_load_pretrained_model
(
*
args
,
**
kwargs
):
args
=
args
[
0
:
3
]
+
(
'/'
,
)
+
args
[
4
:]
# resolved_archive_file; must set it to something to prevent what seems to be a bug
return
self
.
transformers_modeling_utils_load_pretrained_model
(
*
args
,
**
kwargs
)
def
transformers_utils_hub_get_file_from_cache
(
original
,
url
,
*
args
,
**
kwargs
):
# this file is always 404, prevent making request
if
url
==
'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json'
:
raise
transformers
.
utils
.
hub
.
EntryNotFoundError
try
:
return
original
(
url
,
*
args
,
local_files_only
=
True
,
**
kwargs
)
except
Exception
as
e
:
return
original
(
url
,
*
args
,
local_files_only
=
False
,
**
kwargs
)
def
transformers_utils_hub_get_from_cache
(
url
,
*
args
,
local_files_only
=
False
,
**
kwargs
):
return
transformers_utils_hub_get_file_from_cache
(
self
.
transformers_utils_hub_get_from_cache
,
url
,
*
args
,
**
kwargs
)
def
transformers_tokenization_utils_base_cached_file
(
url
,
*
args
,
local_files_only
=
False
,
**
kwargs
):
return
transformers_utils_hub_get_file_from_cache
(
self
.
transformers_tokenization_utils_base_cached_file
,
url
,
*
args
,
**
kwargs
)
def
transformers_configuration_utils_cached_file
(
url
,
*
args
,
local_files_only
=
False
,
**
kwargs
):
return
transformers_utils_hub_get_file_from_cache
(
self
.
transformers_configuration_utils_cached_file
,
url
,
*
args
,
**
kwargs
)
self
.
init_kaiming_uniform
=
torch
.
nn
.
init
.
kaiming_uniform_
self
.
init_no_grad_normal
=
torch
.
nn
.
init
.
_no_grad_normal_
self
.
init_no_grad_uniform_
=
torch
.
nn
.
init
.
_no_grad_uniform_
self
.
create_model_and_transforms
=
open_clip
.
create_model_and_transforms
self
.
CLIPTextModel_from_pretrained
=
ldm
.
modules
.
encoders
.
modules
.
CLIPTextModel
.
from_pretrained
self
.
transformers_modeling_utils_load_pretrained_model
=
getattr
(
transformers
.
modeling_utils
.
PreTrainedModel
,
'_load_pretrained_model'
,
None
)
self
.
transformers_tokenization_utils_base_cached_file
=
getattr
(
transformers
.
tokenization_utils_base
,
'cached_file'
,
None
)
self
.
transformers_configuration_utils_cached_file
=
getattr
(
transformers
.
configuration_utils
,
'cached_file'
,
None
)
self
.
transformers_utils_hub_get_from_cache
=
getattr
(
transformers
.
utils
.
hub
,
'get_from_cache'
,
None
)
torch
.
nn
.
init
.
kaiming_uniform_
=
do_nothing
torch
.
nn
.
init
.
_no_grad_normal_
=
do_nothing
torch
.
nn
.
init
.
_no_grad_uniform_
=
do_nothing
open_clip
.
create_model_and_transforms
=
create_model_and_transforms_without_pretrained
ldm
.
modules
.
encoders
.
modules
.
CLIPTextModel
.
from_pretrained
=
CLIPTextModel_from_pretrained
if
self
.
transformers_modeling_utils_load_pretrained_model
is
not
None
:
transformers
.
modeling_utils
.
PreTrainedModel
.
_load_pretrained_model
=
transformers_modeling_utils_load_pretrained_model
if
self
.
transformers_tokenization_utils_base_cached_file
is
not
None
:
transformers
.
tokenization_utils_base
.
cached_file
=
transformers_tokenization_utils_base_cached_file
if
self
.
transformers_configuration_utils_cached_file
is
not
None
:
transformers
.
configuration_utils
.
cached_file
=
transformers_configuration_utils_cached_file
if
self
.
transformers_utils_hub_get_from_cache
is
not
None
:
transformers
.
utils
.
hub
.
get_from_cache
=
transformers_utils_hub_get_from_cache
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
torch
.
nn
.
init
.
kaiming_uniform_
=
self
.
init_kaiming_uniform
torch
.
nn
.
init
.
_no_grad_normal_
=
self
.
init_no_grad_normal
torch
.
nn
.
init
.
_no_grad_uniform_
=
self
.
init_no_grad_uniform_
open_clip
.
create_model_and_transforms
=
self
.
create_model_and_transforms
ldm
.
modules
.
encoders
.
modules
.
CLIPTextModel
.
from_pretrained
=
self
.
CLIPTextModel_from_pretrained
if
self
.
transformers_modeling_utils_load_pretrained_model
is
not
None
:
transformers
.
modeling_utils
.
PreTrainedModel
.
_load_pretrained_model
=
self
.
transformers_modeling_utils_load_pretrained_model
if
self
.
transformers_tokenization_utils_base_cached_file
is
not
None
:
transformers
.
utils
.
hub
.
cached_file
=
self
.
transformers_tokenization_utils_base_cached_file
if
self
.
transformers_configuration_utils_cached_file
is
not
None
:
transformers
.
utils
.
hub
.
cached_file
=
self
.
transformers_configuration_utils_cached_file
if
self
.
transformers_utils_hub_get_from_cache
is
not
None
:
transformers
.
utils
.
hub
.
get_from_cache
=
self
.
transformers_utils_hub_get_from_cache
modules/sd_models.py
View file @
50fb20ce
...
...
@@ -2,6 +2,7 @@ import collections
import
os.path
import
sys
import
gc
import
time
from
collections
import
namedtuple
import
torch
import
re
...
...
@@ -13,7 +14,7 @@ import ldm.modules.midas as midas
from
ldm.util
import
instantiate_from_config
from
modules
import
shared
,
modelloader
,
devices
,
script_callbacks
,
sd_vae
from
modules
import
shared
,
modelloader
,
devices
,
script_callbacks
,
sd_vae
,
sd_disable_initialization
,
errors
from
modules.paths
import
models_path
from
modules.sd_hijack_inpainting
import
do_inpainting_hijack
,
should_hijack_inpainting
...
...
@@ -61,7 +62,7 @@ def find_checkpoint_config(info):
def
list_models
():
checkpoints_list
.
clear
()
model_list
=
modelloader
.
load_models
(
model_path
=
model_path
,
command_path
=
shared
.
cmd_opts
.
ckpt_dir
,
ext_filter
=
[
".ckpt"
,
".safetensors"
])
model_list
=
modelloader
.
load_models
(
model_path
=
model_path
,
command_path
=
shared
.
cmd_opts
.
ckpt_dir
,
ext_filter
=
[
".ckpt"
,
".safetensors"
]
,
ext_blacklist
=
[
".vae.safetensors"
]
)
def
modeltitle
(
path
,
shorthash
):
abspath
=
os
.
path
.
abspath
(
path
)
...
...
@@ -288,6 +289,17 @@ def enable_midas_autodownload():
midas
.
api
.
load_model
=
load_model_wrapper
class
Timer
:
def
__init__
(
self
):
self
.
start
=
time
.
time
()
def
elapsed
(
self
):
end
=
time
.
time
()
res
=
end
-
self
.
start
self
.
start
=
end
return
res
def
load_model
(
checkpoint_info
=
None
):
from
modules
import
lowvram
,
sd_hijack
checkpoint_info
=
checkpoint_info
or
select_checkpoint
()
...
...
@@ -319,10 +331,21 @@ def load_model(checkpoint_info=None):
if
shared
.
cmd_opts
.
no_half
:
sd_config
.
model
.
params
.
unet_config
.
params
.
use_fp16
=
False
sd_model
=
instantiate_from_config
(
sd_config
.
model
)
timer
=
Timer
()
try
:
with
sd_disable_initialization
.
DisableInitialization
():
sd_model
=
instantiate_from_config
(
sd_config
.
model
)
except
Exception
as
e
:
print
(
'Failed to create model quickly; will retry using slow method.'
,
file
=
sys
.
stderr
)
sd_model
=
instantiate_from_config
(
sd_config
.
model
)
elapsed_create
=
timer
.
elapsed
()
load_model_weights
(
sd_model
,
checkpoint_info
)
elapsed_load_weights
=
timer
.
elapsed
()
if
shared
.
cmd_opts
.
lowvram
or
shared
.
cmd_opts
.
medvram
:
lowvram
.
setup_for_low_vram
(
sd_model
,
shared
.
cmd_opts
.
medvram
)
else
:
...
...
@@ -337,7 +360,9 @@ def load_model(checkpoint_info=None):
script_callbacks
.
model_loaded_callback
(
sd_model
)
print
(
"Model loaded."
)
elapsed_the_rest
=
timer
.
elapsed
()
print
(
f
"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights)."
)
return
sd_model
...
...
@@ -348,7 +373,7 @@ def reload_model_weights(sd_model=None, info=None):
if
not
sd_model
:
sd_model
=
shared
.
sd_model
if
sd_model
is
None
:
# previous model load failed
if
sd_model
is
None
:
# previous model load failed
current_checkpoint_info
=
None
else
:
current_checkpoint_info
=
sd_model
.
sd_checkpoint_info
...
...
@@ -370,6 +395,8 @@ def reload_model_weights(sd_model=None, info=None):
sd_hijack
.
model_hijack
.
undo_hijack
(
sd_model
)
timer
=
Timer
()
try
:
load_model_weights
(
sd_model
,
checkpoint_info
)
except
Exception
as
e
:
...
...
@@ -383,6 +410,8 @@ def reload_model_weights(sd_model=None, info=None):
if
not
shared
.
cmd_opts
.
lowvram
and
not
shared
.
cmd_opts
.
medvram
:
sd_model
.
to
(
devices
.
device
)
print
(
"Weights loaded."
)
elapsed
=
timer
.
elapsed
()
print
(
f
"Weights loaded in {elapsed:.1f}s."
)
return
sd_model
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment