Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
12f4f476
Unverified
Commit
12f4f476
authored
Oct 11, 2022
by
AUTOMATIC1111
Committed by
GitHub
Oct 11, 2022
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #1795 from MarkovInequality/learnschedule
Added learning_rate scheduling for TI
parents
d7474a51
419e539f
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
51 additions
and
5 deletions
+51
-5
textual_inversion.py
modules/textual_inversion/textual_inversion.py
+50
-4
ui.py
modules/ui.py
+1
-1
No files found.
modules/textual_inversion/textual_inversion.py
View file @
12f4f476
...
...
@@ -189,8 +189,6 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
embedding
=
hijack
.
embedding_db
.
word_embeddings
[
embedding_name
]
embedding
.
vec
.
requires_grad
=
True
optimizer
=
torch
.
optim
.
AdamW
([
embedding
.
vec
],
lr
=
learn_rate
)
losses
=
torch
.
zeros
((
32
,))
last_saved_file
=
"<none>"
...
...
@@ -200,12 +198,27 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if
ititial_step
>
steps
:
return
embedding
,
filename
tr_img_len
=
len
([
os
.
path
.
join
(
data_root
,
file_path
)
for
file_path
in
os
.
listdir
(
data_root
)])
epoch_len
=
(
tr_img_len
*
num_repeats
)
+
tr_img_len
scheduleIter
=
iter
(
LearnSchedule
(
learn_rate
,
steps
,
ititial_step
))
(
learn_rate
,
end_step
)
=
next
(
scheduleIter
)
print
(
f
'Training at rate of {learn_rate} until step {end_step}'
)
optimizer
=
torch
.
optim
.
AdamW
([
embedding
.
vec
],
lr
=
learn_rate
)
pbar
=
tqdm
.
tqdm
(
enumerate
(
ds
),
total
=
steps
-
ititial_step
)
for
i
,
(
x
,
text
,
_
)
in
pbar
:
embedding
.
step
=
i
+
ititial_step
if
embedding
.
step
>
steps
:
break
if
embedding
.
step
>
end_step
:
try
:
(
learn_rate
,
end_step
)
=
next
(
scheduleIter
)
except
:
break
tqdm
.
tqdm
.
write
(
f
'Training at rate of {learn_rate} until step {end_step}'
)
for
pg
in
optimizer
.
param_groups
:
pg
[
'lr'
]
=
learn_rate
if
shared
.
state
.
interrupted
:
break
...
...
@@ -276,3 +289,36 @@ Last saved image: {html.escape(last_saved_image)}<br/>
return
embedding
,
filename
class
LearnSchedule
:
def
__init__
(
self
,
learn_rate
,
max_steps
,
cur_step
=
0
):
pairs
=
learn_rate
.
split
(
','
)
self
.
rates
=
[]
self
.
it
=
0
self
.
maxit
=
0
for
i
,
pair
in
enumerate
(
pairs
):
tmp
=
pair
.
split
(
':'
)
if
len
(
tmp
)
==
2
:
step
=
int
(
tmp
[
1
])
if
step
>
cur_step
:
self
.
rates
.
append
((
float
(
tmp
[
0
]),
min
(
step
,
max_steps
)))
self
.
maxit
+=
1
if
step
>
max_steps
:
return
elif
step
==
-
1
:
self
.
rates
.
append
((
float
(
tmp
[
0
]),
max_steps
))
self
.
maxit
+=
1
return
else
:
self
.
rates
.
append
((
float
(
tmp
[
0
]),
max_steps
))
self
.
maxit
+=
1
return
def
__iter__
(
self
):
return
self
def
__next__
(
self
):
if
self
.
it
<
self
.
maxit
:
self
.
it
+=
1
return
self
.
rates
[
self
.
it
-
1
]
else
:
raise
StopIteration
modules/ui.py
View file @
12f4f476
...
...
@@ -1070,7 +1070,7 @@ def create_ui(wrap_gradio_gpu_call):
gr
.
HTML
(
value
=
"<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>"
)
train_embedding_name
=
gr
.
Dropdown
(
label
=
'Embedding'
,
choices
=
sorted
(
sd_hijack
.
model_hijack
.
embedding_db
.
word_embeddings
.
keys
()))
train_hypernetwork_name
=
gr
.
Dropdown
(
label
=
'Hypernetwork'
,
choices
=
[
x
for
x
in
shared
.
hypernetworks
.
keys
()])
learn_rate
=
gr
.
Number
(
label
=
'Learning rate'
,
value
=
5.0e-03
)
learn_rate
=
gr
.
Textbox
(
label
=
'Learning rate'
,
placeholder
=
"Learning rate"
,
value
=
"5.0e-03"
)
dataset_directory
=
gr
.
Textbox
(
label
=
'Dataset directory'
,
placeholder
=
"Path to directory with input images"
)
log_directory
=
gr
.
Textbox
(
label
=
'Log directory'
,
placeholder
=
"Path to directory where to write outputs"
,
value
=
"textual_inversion"
)
template_file
=
gr
.
Textbox
(
label
=
'Prompt template file'
,
value
=
os
.
path
.
join
(
script_path
,
"textual_inversion_templates"
,
"style_filewords.txt"
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment