Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
c24a314c
Unverified
Commit
c24a314c
authored
Dec 31, 2022
by
AUTOMATIC1111
Committed by
GitHub
Dec 31, 2022
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #6149 from vladmandic/validate-embeddings
validate textual inversion embeddings
parents
f378b8d5
f55ac33d
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
41 additions
and
7 deletions
+41
-7
sd_models.py
modules/sd_models.py
+3
-0
textual_inversion.py
modules/textual_inversion/textual_inversion.py
+38
-5
ui.py
modules/ui.py
+0
-2
No files found.
modules/sd_models.py
View file @
c24a314c
...
...
@@ -325,6 +325,9 @@ def load_model(checkpoint_info=None):
script_callbacks
.
model_loaded_callback
(
sd_model
)
print
(
"Model loaded."
)
sd_hijack
.
model_hijack
.
embedding_db
.
load_textual_inversion_embeddings
(
force_reload
=
True
)
# Reload embeddings after model load as they may or may not fit the model
return
sd_model
...
...
modules/textual_inversion/textual_inversion.py
View file @
c24a314c
...
...
@@ -23,6 +23,8 @@ class Embedding:
self
.
vec
=
vec
self
.
name
=
name
self
.
step
=
step
self
.
shape
=
None
self
.
vectors
=
0
self
.
cached_checksum
=
None
self
.
sd_checkpoint
=
None
self
.
sd_checkpoint_name
=
None
...
...
@@ -57,8 +59,10 @@ class EmbeddingDatabase:
def
__init__
(
self
,
embeddings_dir
):
self
.
ids_lookup
=
{}
self
.
word_embeddings
=
{}
self
.
skipped_embeddings
=
[]
self
.
dir_mtime
=
None
self
.
embeddings_dir
=
embeddings_dir
self
.
expected_shape
=
-
1
def
register_embedding
(
self
,
embedding
,
model
):
...
...
@@ -75,14 +79,35 @@ class EmbeddingDatabase:
return
embedding
def
load_textual_inversion_embeddings
(
self
):
def
get_expected_shape
(
self
):
expected_shape
=
-
1
# initialize with unknown
idx
=
torch
.
tensor
(
0
)
.
to
(
shared
.
device
)
if
expected_shape
==
-
1
:
try
:
# matches sd15 signature
first_embedding
=
shared
.
sd_model
.
cond_stage_model
.
wrapped
.
transformer
.
text_model
.
embeddings
.
token_embedding
.
wrapped
(
idx
)
expected_shape
=
first_embedding
.
shape
[
0
]
except
:
pass
if
expected_shape
==
-
1
:
try
:
# matches sd20 signature
first_embedding
=
shared
.
sd_model
.
cond_stage_model
.
wrapped
.
model
.
token_embedding
.
wrapped
(
idx
)
expected_shape
=
first_embedding
.
shape
[
0
]
except
:
pass
if
expected_shape
==
-
1
:
print
(
'Could not determine expected embeddings shape from model'
)
return
expected_shape
def
load_textual_inversion_embeddings
(
self
,
force_reload
=
False
):
mt
=
os
.
path
.
getmtime
(
self
.
embeddings_dir
)
if
self
.
dir_mtime
is
not
None
and
mt
<=
self
.
dir_mtime
:
if
not
force_reload
and
self
.
dir_mtime
is
not
None
and
mt
<=
self
.
dir_mtime
:
return
self
.
dir_mtime
=
mt
self
.
ids_lookup
.
clear
()
self
.
word_embeddings
.
clear
()
self
.
skipped_embeddings
=
[]
self
.
expected_shape
=
self
.
get_expected_shape
()
def
process_file
(
path
,
filename
):
name
=
os
.
path
.
splitext
(
filename
)[
0
]
...
...
@@ -122,7 +147,14 @@ class EmbeddingDatabase:
embedding
.
step
=
data
.
get
(
'step'
,
None
)
embedding
.
sd_checkpoint
=
data
.
get
(
'sd_checkpoint'
,
None
)
embedding
.
sd_checkpoint_name
=
data
.
get
(
'sd_checkpoint_name'
,
None
)
self
.
register_embedding
(
embedding
,
shared
.
sd_model
)
embedding
.
vectors
=
vec
.
shape
[
0
]
embedding
.
shape
=
vec
.
shape
[
-
1
]
if
(
self
.
expected_shape
==
-
1
)
or
(
self
.
expected_shape
==
embedding
.
shape
):
self
.
register_embedding
(
embedding
,
shared
.
sd_model
)
else
:
self
.
skipped_embeddings
.
append
(
name
)
# print('Skipping embedding {name}: shape was {shape} expected {expected}'.format(name = name, shape = embedding.shape, expected = self.expected_shape))
for
fn
in
os
.
listdir
(
self
.
embeddings_dir
):
try
:
...
...
@@ -137,8 +169,9 @@ class EmbeddingDatabase:
print
(
traceback
.
format_exc
(),
file
=
sys
.
stderr
)
continue
print
(
f
"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings."
)
print
(
"Embeddings:"
,
', '
.
join
(
self
.
word_embeddings
.
keys
()))
print
(
"Textual inversion embeddings {num} loaded: {val}"
.
format
(
num
=
len
(
self
.
word_embeddings
),
val
=
', '
.
join
(
self
.
word_embeddings
.
keys
())))
if
(
len
(
self
.
skipped_embeddings
)
>
0
):
print
(
"Textual inversion embeddings {num} skipped: {val}"
.
format
(
num
=
len
(
self
.
skipped_embeddings
),
val
=
', '
.
join
(
self
.
skipped_embeddings
)))
def
find_embedding_at_position
(
self
,
tokens
,
offset
):
token
=
tokens
[
offset
]
...
...
modules/ui.py
View file @
c24a314c
...
...
@@ -1157,8 +1157,6 @@ def create_ui():
with
gr
.
Column
(
variant
=
'panel'
):
submit_result
=
gr
.
Textbox
(
elem_id
=
"modelmerger_result"
,
show_label
=
False
)
sd_hijack
.
model_hijack
.
embedding_db
.
load_textual_inversion_embeddings
()
with
gr
.
Blocks
(
analytics_enabled
=
False
)
as
train_interface
:
with
gr
.
Row
()
.
style
(
equal_height
=
False
):
gr
.
HTML
(
value
=
"<p style='margin-bottom: 0.7em'>See <b><a href=
\"
https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion
\"
>wiki</a></b> for detailed explanation.</p>"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment