Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
278e7c71
Unverified
Commit
278e7c71
authored
Sep 28, 2022
by
AUTOMATIC1111
Committed by
GitHub
Sep 28, 2022
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #1194 from liamkerr/token_count
Token count
parents
1deac2b6
7ca9858c
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
53 additions
and
9 deletions
+53
-9
ui.js
javascript/ui.js
+19
-0
sd_hijack.py
modules/sd_hijack.py
+21
-8
ui.py
modules/ui.py
+9
-1
style.css
style.css
+4
-0
No files found.
javascript/ui.js
View file @
278e7c71
...
...
@@ -182,4 +182,23 @@ onUiUpdate(function(){
});
json_elem
.
parentElement
.
style
.
display
=
"none"
if
(
!
txt2img_textarea
)
{
txt2img_textarea
=
gradioApp
().
querySelector
(
"#txt2img_prompt > label > textarea"
);
txt2img_textarea
?.
addEventListener
(
"input"
,
()
=>
update_token_counter
(
"txt2img_token_button"
));
}
if
(
!
img2img_textarea
)
{
img2img_textarea
=
gradioApp
().
querySelector
(
"#img2img_prompt > label > textarea"
);
img2img_textarea
?.
addEventListener
(
"input"
,
()
=>
update_token_counter
(
"img2img_token_button"
));
}
})
let
txt2img_textarea
,
img2img_textarea
=
undefined
;
let
wait_time
=
800
let
token_timeout
;
function
update_token_counter
(
button_id
)
{
if
(
token_timeout
)
clearTimeout
(
token_timeout
);
token_timeout
=
setTimeout
(()
=>
gradioApp
().
getElementById
(
button_id
)?.
click
(),
wait_time
);
}
modules/sd_hijack.py
View file @
278e7c71
...
...
@@ -180,6 +180,7 @@ class StableDiffusionModelHijack:
dir_mtime
=
None
layers
=
None
circular_enabled
=
False
clip
=
None
def
load_textual_inversion_embeddings
(
self
,
dirname
,
model
):
mt
=
os
.
path
.
getmtime
(
dirname
)
...
...
@@ -242,6 +243,7 @@ class StableDiffusionModelHijack:
model_embeddings
.
token_embedding
=
EmbeddingsWithFixes
(
model_embeddings
.
token_embedding
,
self
)
m
.
cond_stage_model
=
FrozenCLIPEmbedderWithCustomWords
(
m
.
cond_stage_model
,
self
)
self
.
clip
=
m
.
cond_stage_model
if
cmd_opts
.
opt_split_attention_v1
:
ldm
.
modules
.
attention
.
CrossAttention
.
forward
=
split_cross_attention_forward_v1
...
...
@@ -268,6 +270,10 @@ class StableDiffusionModelHijack:
for
layer
in
[
layer
for
layer
in
self
.
layers
if
type
(
layer
)
==
torch
.
nn
.
Conv2d
]:
layer
.
padding_mode
=
'circular'
if
enable
else
'zeros'
def
tokenize
(
self
,
text
):
max_length
=
self
.
clip
.
max_length
-
2
_
,
remade_batch_tokens
,
_
,
_
,
_
,
token_count
=
self
.
clip
.
process_text
([
text
])
return
remade_batch_tokens
[
0
],
token_count
,
max_length
class
FrozenCLIPEmbedderWithCustomWords
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
wrapped
,
hijack
):
...
...
@@ -294,14 +300,16 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if
mult
!=
1.0
:
self
.
token_mults
[
ident
]
=
mult
def
forward
(
self
,
text
):
self
.
hijack
.
fixes
=
[]
self
.
hijack
.
comments
=
[]
remade_batch_tokens
=
[]
def
process_text
(
self
,
text
):
id_start
=
self
.
wrapped
.
tokenizer
.
bos_token_id
id_end
=
self
.
wrapped
.
tokenizer
.
eos_token_id
maxlen
=
self
.
wrapped
.
max_length
used_custom_terms
=
[]
remade_batch_tokens
=
[]
overflowing_words
=
[]
hijack_comments
=
[]
hijack_fixes
=
[]
token_count
=
0
cache
=
{}
batch_tokens
=
self
.
wrapped
.
tokenizer
(
text
,
truncation
=
False
,
add_special_tokens
=
False
)[
"input_ids"
]
...
...
@@ -353,9 +361,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
ovf
=
remade_tokens
[
maxlen
-
2
:]
overflowing_words
=
[
vocab
.
get
(
int
(
x
),
""
)
for
x
in
ovf
]
overflowing_text
=
self
.
wrapped
.
tokenizer
.
convert_tokens_to_string
(
''
.
join
(
overflowing_words
))
self
.
hijack
.
comments
.
append
(
f
"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:
\n
{overflowing_text}
\n
"
)
hijack_comments
.
append
(
f
"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:
\n
{overflowing_text}
\n
"
)
token_count
=
len
(
remade_tokens
)
remade_tokens
=
remade_tokens
+
[
id_end
]
*
(
maxlen
-
2
-
len
(
remade_tokens
))
remade_tokens
=
[
id_start
]
+
remade_tokens
[
0
:
maxlen
-
2
]
+
[
id_end
]
cache
[
tuple_tokens
]
=
(
remade_tokens
,
fixes
,
multipliers
)
...
...
@@ -364,8 +371,14 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
multipliers
=
[
1.0
]
+
multipliers
[
0
:
maxlen
-
2
]
+
[
1.0
]
remade_batch_tokens
.
append
(
remade_tokens
)
self
.
hijack
.
fixes
.
append
(
fixes
)
hijack_
fixes
.
append
(
fixes
)
batch_multipliers
.
append
(
multipliers
)
return
batch_multipliers
,
remade_batch_tokens
,
used_custom_terms
,
hijack_comments
,
hijack_fixes
,
token_count
def
forward
(
self
,
text
):
batch_multipliers
,
remade_batch_tokens
,
used_custom_terms
,
hijack_comments
,
hijack_fixes
,
token_count
=
self
.
process_text
(
text
)
self
.
hijack
.
fixes
=
hijack_fixes
self
.
hijack
.
comments
=
hijack_comments
if
len
(
used_custom_terms
)
>
0
:
self
.
hijack
.
comments
.
append
(
"Used custom terms: "
+
", "
.
join
([
f
'{word} [{checksum}]'
for
word
,
checksum
in
used_custom_terms
]))
...
...
modules/ui.py
View file @
278e7c71
...
...
@@ -22,6 +22,7 @@ from modules.paths import script_path
from
modules.shared
import
opts
,
cmd_opts
import
modules.shared
as
shared
from
modules.sd_samplers
import
samplers
,
samplers_for_img2img
from
modules.sd_hijack
import
model_hijack
import
modules.ldsr_model
import
modules.scripts
import
modules.gfpgan_model
...
...
@@ -333,6 +334,10 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
outputs
=
[
seed
,
dummy_component
]
)
def
update_token_counter
(
text
):
tokens
,
token_count
,
max_length
=
model_hijack
.
tokenize
(
text
)
style_class
=
' class="red"'
if
(
token_count
>
max_length
)
else
""
return
f
"<span {style_class}>{token_count}/{max_length}</span>"
def
create_toprow
(
is_img2img
):
id_part
=
"img2img"
if
is_img2img
else
"txt2img"
...
...
@@ -342,11 +347,14 @@ def create_toprow(is_img2img):
with
gr
.
Row
():
with
gr
.
Column
(
scale
=
80
):
with
gr
.
Row
():
prompt
=
gr
.
Textbox
(
label
=
"Prompt"
,
elem_id
=
"
prompt"
,
show_label
=
False
,
placeholder
=
"Prompt"
,
lines
=
2
)
prompt
=
gr
.
Textbox
(
label
=
"Prompt"
,
elem_id
=
f
"{id_part}_
prompt"
,
show_label
=
False
,
placeholder
=
"Prompt"
,
lines
=
2
)
with
gr
.
Column
(
scale
=
1
,
elem_id
=
"roll_col"
):
roll
=
gr
.
Button
(
value
=
art_symbol
,
elem_id
=
"roll"
,
visible
=
len
(
shared
.
artist_db
.
artists
)
>
0
)
paste
=
gr
.
Button
(
value
=
paste_symbol
,
elem_id
=
"paste"
)
token_counter
=
gr
.
HTML
(
value
=
"<span></span>"
,
elem_id
=
f
"{id_part}_token_counter"
)
hidden_button
=
gr
.
Button
(
visible
=
False
,
elem_id
=
f
"{id_part}_token_button"
)
hidden_button
.
click
(
fn
=
update_token_counter
,
inputs
=
[
prompt
],
outputs
=
[
token_counter
])
with
gr
.
Column
(
scale
=
10
,
elem_id
=
"style_pos_col"
):
prompt_style
=
gr
.
Dropdown
(
label
=
"Style 1"
,
elem_id
=
f
"{id_part}_style_index"
,
choices
=
[
k
for
k
,
v
in
shared
.
prompt_styles
.
styles
.
items
()],
value
=
next
(
iter
(
shared
.
prompt_styles
.
styles
.
keys
())),
visible
=
len
(
shared
.
prompt_styles
.
styles
)
>
1
)
...
...
style.css
View file @
278e7c71
...
...
@@ -389,3 +389,7 @@ input[type="range"]{
border-radius
:
8px
;
display
:
none
;
}
.red
{
color
:
red
;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment