stable-diffusion-webui · Commits · 79e39fae

Commit 79e39fae, authored Jan 06, 2023 by AUTOMATIC
CLIP hijack rework

parent 3246a2d6
Showing 5 changed files with 256 additions and 182 deletions (+256 -182):

    modules/sd_hijack.py                              +3    -3
    modules/sd_hijack_clip.py                         +171  -177
    modules/sd_hijack_clip_old.py                     +81   -0
    modules/textual_inversion/textual_inversion.py    +0    -1
    modules/ui.py                                     +1    -1
modules/sd_hijack.py

@@ -150,10 +150,10 @@ class StableDiffusionModelHijack:
     def clear_comments(self):
         self.comments = []
 
-    def tokenize(self, text):
-        _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
-        return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
+    def get_prompt_lengths(self, text):
+        _, token_count = self.clip.process_texts([text])
+        return token_count, self.clip.get_target_prompt_token_count(token_count)
 
 
 class EmbeddingsWithFixes(torch.nn.Module):
...
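The net effect of this hunk: callers that only need lengths no longer receive remade token lists, and the chunk arithmetic moves from a module-level helper in sd_hijack_clip into the embedder itself. A minimal sketch of the count-and-round-up contract, assuming chunk_length = 75 as set in the new code (the standalone helper below is illustrative, not part of the commit):

    import math

    def target_prompt_token_count(token_count, chunk_length=75):
        # round a prompt's token count up to the next multiple of chunk_length;
        # an empty prompt still occupies one full chunk
        return math.ceil(max(token_count, 1) / chunk_length) * chunk_length

    assert target_prompt_token_count(0) == 75     # empty prompt -> one chunk
    assert target_prompt_token_count(75) == 75    # exactly one chunk
    assert target_prompt_token_count(76) == 150   # spills into a second chunk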
modules/sd_hijack_clip.py

 import math
+from collections import namedtuple
 
 import torch
 
 from modules import prompt_parser, devices
 from modules.shared import opts
 
 
-def get_target_prompt_token_count(token_count):
-    return math.ceil(max(token_count, 1) / 75) * 75
+class PromptChunk:
+    """
+    This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt.
+    If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.
+    Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token,
+    so just 75 tokens from prompt.
+    """
+
+    def __init__(self):
+        self.tokens = []
+        self.multipliers = []
+        self.fixes = []
+
+
+PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
+"""This is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt chunk"""
 
 
 class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
...
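To make the new data model concrete: a finished chunk carries parallel tokens and multipliers lists of length 77, plus zero or more PromptChunkFix markers recording where textual inversion vectors must be spliced in. A small sketch of that layout, assuming SD1-style start/end ids (49406/49407) and a stand-in embedding object:

    from collections import namedtuple

    PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])

    id_start, id_end = 49406, 49407                     # assumed SD1 CLIP ids
    tokens = [id_start] + [320, 1125] + [id_end] * 74   # two prompt tokens, padded
    multipliers = [1.0] + [1.1, 1.1] + [1.0] * 74       # "(...)"-style emphasis on both
    fixes = [PromptChunkFix(offset=1, embedding=object())]  # vectors start at offset 1

    assert len(tokens) == len(multipliers) == 77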
@@ -14,17 +30,49 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         super().__init__()
         self.wrapped = wrapped
         self.hijack = hijack
+        self.chunk_length = 75
+
+    def empty_chunk(self):
+        """creates an empty PromptChunk and returns it"""
+
+        chunk = PromptChunk()
+        chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
+        chunk.multipliers = [1.0] * (self.chunk_length + 2)
+        return chunk
+
+    def get_target_prompt_token_count(self, token_count):
+        """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
+
+        return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
 
     def tokenize(self, texts):
+        """Converts a batch of texts into a batch of token ids"""
+
         raise NotImplementedError
 
     def encode_with_transformers(self, tokens):
+        """
+        converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens;
+        All python lists with tokens are assumed to have same length, usually 77.
+        if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
+        model - can be 768 and 1024
+        """
+
         raise NotImplementedError
 
     def encode_embedding_init_text(self, init_text, nvpt):
+        """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
+        transformers. nvpt is used as a maximum length in tokens. If text produces less teokens than nvpt, only this many is returned."""
+
         raise NotImplementedError
 
-    def tokenize_line(self, line, used_custom_terms, hijack_comments):
+    def tokenize_line(self, line):
+        """
+        this transforms a single prompt into a list of PromptChunk objects - as many as needed to
+        represent the prompt.
+        Returns the list and the total number of tokens in the prompt.
+        """
+
         if opts.enable_emphasis:
             parsed = prompt_parser.parse_prompt_attention(line)
         else:
...
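empty_chunk() builds the 77-token padding chunk used when one prompt in a batch needs fewer chunks than another: one start id, then chunk_length + 1 end ids, with every multiplier at 1.0. A quick check of that invariant with placeholder ids:

    id_start, id_end, chunk_length = 49406, 49407, 75  # assumed SD1-style values

    tokens = [id_start] + [id_end] * (chunk_length + 1)
    multipliers = [1.0] * (chunk_length + 2)

    assert len(tokens) == 77 and len(multipliers) == 77
    assert tokens[0] == id_start and all(t == id_end for t in tokens[1:])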
@@ -32,205 +80,152 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         tokenized = self.tokenize([text for text, _ in parsed])
 
-        fixes = []
-        remade_tokens = []
-        multipliers = []
+        chunks = []
+        chunk = PromptChunk()
+        token_count = 0
         last_comma = -1
 
-        for tokens, (text, weight) in zip(tokenized, parsed):
-            i = 0
-            while i < len(tokens):
-                token = tokens[i]
-
-                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+        def next_chunk():
+            """puts current chunk into the list of results and produces the next one - empty"""
+            nonlocal token_count
+            nonlocal last_comma
+            nonlocal chunk
+
+            token_count += len(chunk.tokens)
+            to_add = self.chunk_length - len(chunk.tokens)
+            if to_add > 0:
+                chunk.tokens += [self.id_end] * to_add
+                chunk.multipliers += [1.0] * to_add
+
+            chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
+            chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
+
+            last_comma = -1
+            chunks.append(chunk)
+            chunk = PromptChunk()
+
+        for tokens, (text, weight) in zip(tokenized, parsed):
+            position = 0
+            while position < len(tokens):
+                token = tokens[position]
 
                 if token == self.comma_token:
-                    last_comma = len(remade_tokens)
-                elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
-                    last_comma += 1
-                    reloc_tokens = remade_tokens[last_comma:]
-                    reloc_mults = multipliers[last_comma:]
+                    last_comma = len(chunk.tokens)
 
-                    remade_tokens = remade_tokens[:last_comma]
-                    length = len(remade_tokens)
-
-                    rem = int(math.ceil(length / 75)) * 75 - length
-                    remade_tokens += [self.id_end] * rem + reloc_tokens
-                    multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
+                # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
+                # is a setting that specifies that is there is a comma nearby, the text after comma should be moved out of this chunk and into the next.
+                elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
+                    break_location = last_comma + 1
+
+                    reloc_tokens = chunk.tokens[break_location:]
+                    reloc_mults = chunk.multipliers[break_location:]
+
+                    chunk.tokens = chunk.tokens[:break_location]
+                    chunk.multipliers = chunk.multipliers[:break_location]
+
+                    next_chunk()
+                    chunk.tokens = reloc_tokens
+                    chunk.multipliers = reloc_mults
+
+                if len(chunk.tokens) == self.chunk_length:
+                    next_chunk()
+
+                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
 
                 if embedding is None:
-                    remade_tokens.append(token)
-                    multipliers.append(weight)
-                    i += 1
-                else:
-                    emb_len = int(embedding.vec.shape[0])
-                    iteration = len(remade_tokens) // 75
-                    if (len(remade_tokens) + emb_len) // 75 != iteration:
-                        rem = (75 * (iteration + 1) - len(remade_tokens))
-                        remade_tokens += [self.id_end] * rem
-                        multipliers += [1.0] * rem
-                        iteration += 1
-                    fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
-                    remade_tokens += [0] * emb_len
-                    multipliers += [weight] * emb_len
-                    used_custom_terms.append((embedding.name, embedding.checksum()))
-                    i += embedding_length_in_tokens
+                    chunk.tokens.append(token)
+                    chunk.multipliers.append(weight)
+                    position += 1
+                    continue
+
+                emb_len = int(embedding.vec.shape[0])
+                if len(chunk.tokens) + emb_len > self.chunk_length:
+                    next_chunk()
+
+                chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
+
+                chunk.tokens += [0] * emb_len
+                chunk.multipliers += [weight] * emb_len
+                position += embedding_length_in_tokens
 
-        token_count = len(remade_tokens)
-        prompt_target_length = get_target_prompt_token_count(token_count)
-        tokens_to_add = prompt_target_length - len(remade_tokens)
-
-        remade_tokens = remade_tokens + [self.id_end] * tokens_to_add
-        multipliers = multipliers + [1.0] * tokens_to_add
-
-        return remade_tokens, fixes, multipliers, token_count
+        if len(chunk.tokens) > 0:
+            next_chunk()
+
+        return chunks, token_count
 
-    def process_text(self, texts):
-        used_custom_terms = []
-        remade_batch_tokens = []
-        hijack_comments = []
-        hijack_fixes = []
+    def process_texts(self, texts):
+        """
+        Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
+        length, in tokens, of all texts.
+        """
+
         token_count = 0
 
         cache = {}
-        batch_multipliers = []
+        batch_chunks = []
         for line in texts:
             if line in cache:
-                remade_tokens, fixes, multipliers = cache[line]
+                chunks = cache[line]
             else:
-                remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+                chunks, current_token_count = self.tokenize_line(line)
                 token_count = max(current_token_count, token_count)
 
-                cache[line] = (remade_tokens, fixes, multipliers)
+                cache[line] = chunks
 
-            remade_batch_tokens.append(remade_tokens)
-            hijack_fixes.append(fixes)
-            batch_multipliers.append(multipliers)
+            batch_chunks.append(chunks)
 
-        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+        return batch_chunks, token_count
 
-    def process_text_old(self, texts):
-        id_start = self.id_start
-        id_end = self.id_end
-        maxlen = self.wrapped.max_length  # you get to stay at 77
-        used_custom_terms = []
-        remade_batch_tokens = []
-        hijack_comments = []
-        hijack_fixes = []
-        token_count = 0
-
-        cache = {}
-        batch_tokens = self.tokenize(texts)
-        batch_multipliers = []
-        for tokens in batch_tokens:
-            tuple_tokens = tuple(tokens)
-
-            if tuple_tokens in cache:
-                remade_tokens, fixes, multipliers = cache[tuple_tokens]
-            else:
-                fixes = []
-                remade_tokens = []
-                multipliers = []
-                mult = 1.0
-
-                i = 0
-                while i < len(tokens):
-                    token = tokens[i]
-
-                    embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
-
-                    mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
-                    if mult_change is not None:
-                        mult *= mult_change
-                        i += 1
-                    elif embedding is None:
-                        remade_tokens.append(token)
-                        multipliers.append(mult)
-                        i += 1
-                    else:
-                        emb_len = int(embedding.vec.shape[0])
-                        fixes.append((len(remade_tokens), embedding))
-                        remade_tokens += [0] * emb_len
-                        multipliers += [mult] * emb_len
-                        used_custom_terms.append((embedding.name, embedding.checksum()))
-                        i += embedding_length_in_tokens
-
-                if len(remade_tokens) > maxlen - 2:
-                    vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
-                    ovf = remade_tokens[maxlen - 2:]
-                    overflowing_words = [vocab.get(int(x), "") for x in ovf]
-                    overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
-                    hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-
-                token_count = len(remade_tokens)
-                remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
-                remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
-                cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
-
-            multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
-            multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
-
-            remade_batch_tokens.append(remade_tokens)
-            hijack_fixes.append(fixes)
-            batch_multipliers.append(multipliers)
-
-        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
-    def forward(self, text):
-        use_old = opts.use_old_emphasis_implementation
-        if use_old:
-            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
-        else:
-            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
-
-        self.hijack.comments += hijack_comments
-
-        if len(used_custom_terms) > 0:
-            self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
-
-        if use_old:
-            self.hijack.fixes = hijack_fixes
-            return self.process_tokens(remade_batch_tokens, batch_multipliers)
-
-        z = None
-        i = 0
-        while max(map(len, remade_batch_tokens)) != 0:
-            rem_tokens = [x[75:] for x in remade_batch_tokens]
-            rem_multipliers = [x[75:] for x in batch_multipliers]
-
-            self.hijack.fixes = []
-            for unfiltered in hijack_fixes:
-                fixes = []
-                for fix in unfiltered:
-                    if fix[0] == i:
-                        fixes.append(fix[1])
-                self.hijack.fixes.append(fixes)
-
-            tokens = []
-            multipliers = []
-            for j in range(len(remade_batch_tokens)):
-                if len(remade_batch_tokens[j]) > 0:
-                    tokens.append(remade_batch_tokens[j][:75])
-                    multipliers.append(batch_multipliers[j][:75])
-                else:
-                    tokens.append([self.id_end] * 75)
-                    multipliers.append([1.0] * 75)
-
-            z1 = self.process_tokens(tokens, multipliers)
-            z = z1 if z is None else torch.cat((z, z1), axis=-2)
-
-            remade_batch_tokens = rem_tokens
-            batch_multipliers = rem_multipliers
-            i += 1
-
-        return z
-
-    def process_tokens(self, remade_batch_tokens, batch_multipliers):
-        if not opts.use_old_emphasis_implementation:
-            remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens]
-            batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
-
+    def forward(self, texts):
+        """
+        Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
+        Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
+        be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
+        An example shape returned by this function can be: (2, 77, 768).
+        Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet
+        is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
+        """
+
+        if opts.use_old_emphasis_implementation:
+            import modules.sd_hijack_clip_old
+            return modules.sd_hijack_clip_old.forward_old(self, texts)
+
+        batch_chunks, token_count = self.process_texts(texts)
+
+        used_embeddings = {}
+        chunk_count = max([len(x) for x in batch_chunks])
+
+        zs = []
+        for i in range(chunk_count):
+            batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
+
+            tokens = [x.tokens for x in batch_chunk]
+            multipliers = [x.multipliers for x in batch_chunk]
+            self.hijack.fixes = [x.fixes for x in batch_chunk]
+
+            for fixes in self.hijack.fixes:
+                for position, embedding in fixes:
+                    used_embeddings[embedding.name] = embedding
+
+            z = self.process_tokens(tokens, multipliers)
+            zs.append(z)
+
+        if len(used_embeddings) > 0:
+            embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
+            self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
+
+        return torch.hstack(zs)
+
+    def process_tokens(self, remade_batch_tokens, batch_multipliers):
+        """
+        sends one single prompt chunk to be encoded by transformers neural network.
+        remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
+        there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
+        Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
+        corresponds to one token.
+        """
         tokens = torch.asarray(remade_batch_tokens).to(devices.device)
 
+        # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
         if self.id_end != self.id_pad:
             for batch_pos in range(len(remade_batch_tokens)):
                 index = remade_batch_tokens[batch_pos].index(self.id_end)
...
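The rewritten forward() encodes chunk columns: for each chunk index i it gathers chunk i of every prompt in the batch (padding short prompts with empty_chunk()), runs process_tokens on that batch, and finally concatenates the per-chunk outputs along the token axis. A shape-only sketch of that stacking, with dummy tensors standing in for the encoder:

    import torch

    batch, chunk_len, dim = 2, 77, 768  # SD1-style dimensionality
    chunk_count = 3                     # longest prompt in the batch needs 3 chunks

    # stand-in for process_tokens(): one (B, 77, C) tensor per chunk column
    zs = [torch.zeros(batch, chunk_len, dim) for _ in range(chunk_count)]

    z = torch.hstack(zs)  # hstack concatenates along dim 1 for 3-D tensors
    assert z.shape == (2, 231, 768)  # 3 chunks of 77 tokens each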
@@ -239,8 +234,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         z = self.encode_with_transformers(tokens)
 
         # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
-        batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
-        batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device)
+        batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
         original_mean = z.mean()
         z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
         new_mean = z.mean()
...
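This final hunk simplifies process_tokens: since every chunk now arrives at exactly 77 entries, the old re-padding of batch_multipliers is unnecessary. The emphasis math it feeds is unchanged: scale each token's output by its multiplier, then rescale so the tensor mean matches its pre-scaling value (which the code's own comment calls "likely not correct" but artifact-free in practice). The operation in isolation, on dummy data:

    import torch

    z = torch.randn(1, 77, 768) * 0.1 + 1.0    # stand-in for encoder output
    batch_multipliers = torch.ones(1, 77)
    batch_multipliers[0, 5:10] = 1.4           # an emphasized span, e.g. "(word:1.4)"

    original_mean = z.mean()
    z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
    new_mean = z.mean()
    z = z * (original_mean / new_mean)         # restore the original mean

    assert torch.allclose(z.mean(), original_mean, rtol=1e-4)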
modules/sd_hijack_clip_old.py (new file, 0 → 100644)

from modules import sd_hijack_clip
from modules import shared


def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
    id_start = self.id_start
    id_end = self.id_end
    maxlen = self.wrapped.max_length  # you get to stay at 77
    used_custom_terms = []
    remade_batch_tokens = []
    hijack_comments = []
    hijack_fixes = []
    token_count = 0

    cache = {}
    batch_tokens = self.tokenize(texts)
    batch_multipliers = []
    for tokens in batch_tokens:
        tuple_tokens = tuple(tokens)

        if tuple_tokens in cache:
            remade_tokens, fixes, multipliers = cache[tuple_tokens]
        else:
            fixes = []
            remade_tokens = []
            multipliers = []
            mult = 1.0

            i = 0
            while i < len(tokens):
                token = tokens[i]

                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)

                mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
                if mult_change is not None:
                    mult *= mult_change
                    i += 1
                elif embedding is None:
                    remade_tokens.append(token)
                    multipliers.append(mult)
                    i += 1
                else:
                    emb_len = int(embedding.vec.shape[0])
                    fixes.append((len(remade_tokens), embedding))
                    remade_tokens += [0] * emb_len
                    multipliers += [mult] * emb_len
                    used_custom_terms.append((embedding.name, embedding.checksum()))
                    i += embedding_length_in_tokens

            if len(remade_tokens) > maxlen - 2:
                vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
                ovf = remade_tokens[maxlen - 2:]
                overflowing_words = [vocab.get(int(x), "") for x in ovf]
                overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
                hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")

            token_count = len(remade_tokens)
            remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
            remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
            cache[tuple_tokens] = (remade_tokens, fixes, multipliers)

        multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
        multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]

        remade_batch_tokens.append(remade_tokens)
        hijack_fixes.append(fixes)
        batch_multipliers.append(multipliers)

    return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count


def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
    batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)

    self.hijack.comments += hijack_comments

    if len(used_custom_terms) > 0:
        self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))

    self.hijack.fixes = hijack_fixes
    return self.process_tokens(remade_batch_tokens, batch_multipliers)
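The extracted legacy path differs from the new tokenize_line() in how emphasis is expressed: rather than per-token weights parsed up front by parse_prompt_attention, it folds bracket tokens into a running multiplier via self.token_mults and emits no token for the brackets themselves. A hedged sketch of that accumulation (the ids and the 1.1-per-bracket factor are stand-ins for what webui derives from the tokenizer's vocabulary):

    token_mults = {1000: 1.1, 1001: 1 / 1.1}   # hypothetical ids for '(' and ')'

    tokens = [1000, 42, 43, 1001, 44]           # roughly "(word word) word"
    mult, remade_tokens, multipliers = 1.0, [], []
    for token in tokens:
        mult_change = token_mults.get(token)
        if mult_change is not None:
            mult *= mult_change                 # brackets adjust weight, emit nothing
        else:
            remade_tokens.append(token)
            multipliers.append(mult)

    assert [round(m, 6) for m in multipliers] == [1.1, 1.1, 1.0]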
modules/textual_inversion/textual_inversion.py

@@ -79,7 +79,6 @@ class EmbeddingDatabase:
         self.word_embeddings[embedding.name] = embedding
 
-        # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
         ids = model.cond_stage_model.tokenize([embedding.name])[0]
 
         first_id = ids[0]
...
modules/ui.py

@@ -368,7 +368,7 @@ def update_token_counter(text, steps):
     flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
     prompts = [prompt_text for step, prompt_text in flat_prompts]
-    tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
+    token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
     style_class = ' class="red"' if (token_count > max_length) else ""
     return f"<span {style_class}>{token_count}/{max_length}</span>"
...
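On the UI side the token counter now keys on the count itself (args[0]) instead of position 1 of the old triple. A simplified re-rendering of the span it returns, assuming get_prompt_lengths yields (token_count, max_length) as in the new sd_hijack.py (not byte-identical to the f-string above, which keeps a space before the class attribute):

    def render_token_counter(token_count, max_length):
        # red span when the prompt exceeds its padded chunk budget
        style_class = ' class="red"' if token_count > max_length else ""
        return f"<span{style_class}>{token_count}/{max_length}</span>"

    assert render_token_counter(93, 150) == "<span>93/150</span>"
    assert render_token_counter(160, 150) == '<span class="red">160/150</span>'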