Administrator / stable-diffusion-webui / Commits / 17a2076f

Unverified commit 17a2076f, authored Oct 30, 2022 by AUTOMATIC1111, committed by GitHub on Oct 30, 2022.
Merge pull request #3928 from R-N/validate-before-load
Optimize training a little
Parents: 3dc9a43f, 3d58510f
Showing 3 changed files with 107 additions and 51 deletions:

modules/hypernetworks/hypernetwork.py            +43  -26
modules/textual_inversion/dataset.py              +2   -0
modules/textual_inversion/textual_inversion.py   +62  -25
modules/hypernetworks/hypernetwork.py
...
@@ -335,7 +335,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
     from modules import images
 
-    assert hypernetwork_name, 'hypernetwork not selected'
+    save_hypernetwork_every = save_hypernetwork_every or 0
+    create_image_every = create_image_every or 0
+    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
 
     path = shared.hypernetworks.get(hypernetwork_name, None)
     shared.loaded_hypernetwork = Hypernetwork()
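The hunk above replaces the lone assert on the hypernetwork name with up-front validation of every user-supplied training input, via the shared validate_train_inputs helper added to textual_inversion.py further down in this diff. A sketch of how a bad input now surfaces before any heavy work starts (all values below are made up for illustration):

from modules.textual_inversion import textual_inversion

try:
    textual_inversion.validate_train_inputs(
        "my-hypernet",        # model_name
        "0.00001",            # learn_rate
        1,                    # batch_size
        "/no/such/dir",       # data_root -- trips the isdir assert
        "style.txt",          # template_file
        10000,                # steps
        500,                  # save_model_every (save_hypernetwork_every here)
        500,                  # create_image_every
        "textual_inversion",  # log_directory
        name="hypernetwork",
    )
except AssertionError as err:
    print(f"Training refused: {err}")  # -> Training refused: Dataset directory doesn't exist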
...
@@ -361,39 +363,44 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     else:
         images_dir = None
 
+    hypernetwork = shared.loaded_hypernetwork
+    checkpoint = sd_models.select_checkpoint()
+
+    ititial_step = hypernetwork.step or 0
+    if ititial_step >= steps:
+        shared.state.textinfo = f"Model has already been trained beyond specified max steps"
+        return hypernetwork, filename
+
+    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+
+    # dataset loading may take a while, so input validations and early returns should be done before this
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
         ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
 
     if unload:
         shared.sd_model.cond_stage_model.to(devices.cpu)
         shared.sd_model.first_stage_model.to(devices.cpu)
 
-    hypernetwork = shared.loaded_hypernetwork
-    weights = hypernetwork.weights()
-    for weight in weights:
-        weight.requires_grad = True
-
     size = len(ds.indexes)
     loss_dict = defaultdict(lambda: deque(maxlen=1024))
     losses = torch.zeros((size,))
     previous_mean_losses = [0]
     previous_mean_loss = 0
     print("Mean loss of {} elements".format(size))
 
-    last_saved_file = "<none>"
-    last_saved_image = "<none>"
-    forced_filename = "<none>"
-
-    ititial_step = hypernetwork.step or 0
-    if ititial_step > steps:
-        return hypernetwork, filename
-
-    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+    weights = hypernetwork.weights()
+    for weight in weights:
+        weight.requires_grad = True
 
     # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
     optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
 
     steps_without_grad = 0
 
+    last_saved_file = "<none>"
+    last_saved_image = "<none>"
+    forced_filename = "<none>"
+
    pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
     for i, entries in pbar:
         hypernetwork.step = i + ititial_step
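This reshuffle is the "optimize training a little" part: the step check, scheduler construction, and early return now run before PersonalizedBase is built, so a model already at or past max steps no longer pays for dataset loading just to bail out. (ititial_step is a pre-existing misspelling of initial_step in the source, preserved by the diff.) The ordering in isolation, as a runnable sketch with hypothetical names:

import time

def load_expensive_dataset():
    time.sleep(5)                      # stand-in for image loading and preprocessing
    return ["img1", "img2"]

def train(current_step, max_steps):
    initial_step = current_step or 0
    if initial_step >= max_steps:      # cheap check first...
        return "already trained"       # ...so the expensive load below is never paid for
    ds = load_expensive_dataset()      # only reached when training will actually proceed
    return f"training on {len(ds)} items from step {initial_step}"

print(train(10000, 10000))             # returns immediately: already trained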
...
@@ -446,9 +453,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
         if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
             # Before saving, change name to match current checkpoint.
-            hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
-            last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
-            hypernetwork.save(last_saved_file)
+            hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
+            last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
+            save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
 
         textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
             "loss": f"{previous_mean_loss:.7f}",
...
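The periodic save now builds the step-suffixed filename in a separate variable instead of mutating hypernetwork.name, and routes the write through the new save_hypernetwork helper. One hedged observation: this call passes the base hypernetwork_name as the name to save under, while the analogous embedding code below passes the step-suffixed embedding_name_every; the diff doesn't say whether that asymmetry is intentional. The naming convention itself, with illustrative values:

hypernetwork_name = "anime-style"       # illustrative
steps_done = 5000

hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
print(f'{hypernetwork_name_every}.pt')  # -> anime-style-5000.pt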
@@ -509,13 +516,23 @@ Last saved image: {html.escape(last_saved_image)}<br/>
 """
 
     report_statistics(loss_dict)
 
-    checkpoint = sd_models.select_checkpoint()
-
-    hypernetwork.sd_checkpoint = checkpoint.hash
-    hypernetwork.sd_checkpoint_name = checkpoint.model_name
-    # Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention).
-    hypernetwork.name = hypernetwork_name
-    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt')
-    hypernetwork.save(filename)
+    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
 
     return hypernetwork, filename
 
+
+def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
+    old_hypernetwork_name = hypernetwork.name
+    old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
+    old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
+    try:
+        hypernetwork.sd_checkpoint = checkpoint.hash
+        hypernetwork.sd_checkpoint_name = checkpoint.model_name
+        hypernetwork.name = hypernetwork_name
+        hypernetwork.save(filename)
+    except:
+        hypernetwork.sd_checkpoint = old_sd_checkpoint
+        hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
+        hypernetwork.name = old_hypernetwork_name
+        raise
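save_hypernetwork snapshots the old metadata, applies the new values, writes, and rolls everything back before re-raising if the write fails, so a failed save no longer leaves the in-memory object renamed or pointing at the wrong checkpoint. The rollback pattern on its own, with a toy class (not the project's API):

class ToyModel:
    name = "base"
    sd_checkpoint = None

    def save(self, filename):
        raise IOError("disk full")                 # simulate a failed write

m = ToyModel()
old_name, old_ckpt = m.name, m.sd_checkpoint
try:
    m.name = "base-5000"
    m.sd_checkpoint = "abc1234"
    m.save("base-5000.pt")
except Exception:
    m.name, m.sd_checkpoint = old_name, old_ckpt   # roll back; the real helper re-raises here

print(m.name)                                      # -> base (unchanged despite the failed save)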
modules/textual_inversion/dataset.py
...
@@ -42,6 +42,8 @@ class PersonalizedBase(Dataset):
         self.lines = lines
 
         assert data_root, 'dataset directory not specified'
+        assert os.path.isdir(data_root), "Dataset directory doesn't exist"
+        assert os.listdir(data_root), "Dataset directory is empty"
 
         cond_model = shared.sd_model.cond_stage_model
...
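The two new asserts turn what used to be a confusing downstream failure (training against a missing or empty folder) into an immediate, readable error. What the second one catches, demonstrated on a throwaway directory:

import os
import tempfile

with tempfile.TemporaryDirectory() as empty_dir:
    assert os.path.isdir(empty_dir)                # exists, so the isdir assert passes
    try:
        assert os.listdir(empty_dir), "Dataset directory is empty"
    except AssertionError as err:
        print(err)                                 # -> Dataset directory is empty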
modules/textual_inversion/textual_inversion.py
...
@@ -119,7 +119,7 @@ class EmbeddingDatabase:
         vec = emb.detach().to(devices.device, dtype=torch.float32)
         embedding = Embedding(vec, name)
         embedding.step = data.get('step', None)
-        embedding.sd_checkpoint = data.get('hash', None)
+        embedding.sd_checkpoint = data.get('sd_checkpoint', None)
         embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
         self.register_embedding(embedding, shared.sd_model)
...
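A one-line fix riding along with the PR: the loader was reading the checkpoint hash from a 'hash' key, but the field is apparently stored under 'sd_checkpoint' (matching the attribute name used everywhere else in this diff), so sd_checkpoint always loaded as None. Because dict.get has a default, the wrong key failed silently:

data = {'step': 1000, 'sd_checkpoint': 'abc1234', 'sd_checkpoint_name': 'v1-5'}  # illustrative payload

print(data.get('hash', None))           # -> None: wrong key, silently
print(data.get('sd_checkpoint', None))  # -> abc1234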
@@ -204,9 +204,30 @@ def write_loss(log_directory, filename, step, epoch_len, values):
         **values,
     })
 
+def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
+    assert model_name, f"{name} not selected"
+    assert learn_rate, "Learning rate is empty or 0"
+    assert isinstance(batch_size, int), "Batch size must be integer"
+    assert batch_size > 0, "Batch size must be positive"
+    assert data_root, "Dataset directory is empty"
+    assert os.path.isdir(data_root), "Dataset directory doesn't exist"
+    assert os.listdir(data_root), "Dataset directory is empty"
+    assert template_file, "Prompt template file is empty"
+    assert os.path.isfile(template_file), "Prompt template file doesn't exist"
+    assert steps, "Max steps is empty or 0"
+    assert isinstance(steps, int), "Max steps must be integer"
+    assert steps > 0, "Max steps must be positive"
+    assert isinstance(save_model_every, int), "Save {name} must be integer"
+    assert save_model_every >= 0, "Save {name} must be positive or 0"
+    assert isinstance(create_image_every, int), "Create image must be integer"
+    assert create_image_every >= 0, "Create image must be positive or 0"
+
+    if save_model_every or create_image_every:
+        assert log_directory, "Log directory is empty"
+
 def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
-    assert embedding_name, 'embedding not selected'
+    save_embedding_every = save_embedding_every or 0
+    create_image_every = create_image_every or 0
+    validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
 
     shared.state.textinfo = "Initializing textual inversion training..."
     shared.state.job_count = steps
...
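One nit in the new validator as committed: the messages "Save {name} must be integer" and "Save {name} must be positive or 0" lack the f prefix, so {name} is printed literally rather than interpolated:

name = "hypernetwork"

print("Save {name} must be integer")   # as committed -> Save {name} must be integer
print(f"Save {name} must be integer")  # with the f prefix -> Save hypernetwork must be integer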
@@ -232,17 +253,28 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
         os.makedirs(images_embeds_dir, exist_ok=True)
     else:
         images_embeds_dir = None
 
     cond_model = shared.sd_model.cond_stage_model
 
+    hijack = sd_hijack.model_hijack
+
+    embedding = hijack.embedding_db.word_embeddings[embedding_name]
+    checkpoint = sd_models.select_checkpoint()
+
+    ititial_step = embedding.step or 0
+    if ititial_step >= steps:
+        shared.state.textinfo = f"Model has already been trained beyond specified max steps"
+        return embedding, filename
+
+    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+
+    # dataset loading may take a while, so input validations and early returns should be done before this
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
         ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
 
-    hijack = sd_hijack.model_hijack
-
-    embedding = hijack.embedding_db.word_embeddings[embedding_name]
     embedding.vec.requires_grad = True
+    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
 
     losses = torch.zeros((32,))
...
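As in the hypernetwork file, everything that can fail or return early is hoisted above the dataset load, and the optimizer is now created right where requires_grad is enabled on the embedding vector, keeping the two visibly coupled. A standalone reminder of why that flag matters to AdamW (shapes and learning rate are illustrative):

import torch

vec = torch.zeros(2, 8)
vec.requires_grad = True                        # without this, backward() records no gradient for vec
optimizer = torch.optim.AdamW([vec], lr=1e-3)

loss = (vec - 1.0).pow(2).sum()
loss.backward()
optimizer.step()                                # nudges vec toward 1.0 in place
print(vec.grad is not None, vec.mean().item()) # -> True, roughly 1e-3 after one step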
@@ -251,13 +283,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
     forced_filename = "<none>"
     embedding_yet_to_be_embedded = False
 
-    ititial_step = embedding.step or 0
-    if ititial_step > steps:
-        return embedding, filename
-
-    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
-    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
-
     pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
     for i, entries in pbar:
         embedding.step = i + ititial_step
...
@@ -290,9 +315,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
         if embedding_dir is not None and steps_done % save_embedding_every == 0:
             # Before saving, change name to match current checkpoint.
-            embedding.name = f'{embedding_name}-{steps_done}'
-            last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
-            embedding.save(last_saved_file)
+            embedding_name_every = f'{embedding_name}-{steps_done}'
+            last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
+            save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
             embedding_yet_to_be_embedded = True
 
         write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
...
@@ -373,14 +398,26 @@ Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
 """
 
-    checkpoint = sd_models.select_checkpoint()
-
-    embedding.sd_checkpoint = checkpoint.hash
-    embedding.sd_checkpoint_name = checkpoint.model_name
-    embedding.cached_checksum = None
-    # Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention).
-    embedding.name = embedding_name
-    filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding.name}.pt')
-    embedding.save(filename)
+    filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
+    save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
 
     return embedding, filename
 
+
+def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
+    old_embedding_name = embedding.name
+    old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
+    old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
+    old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
+    try:
+        embedding.sd_checkpoint = checkpoint.hash
+        embedding.sd_checkpoint_name = checkpoint.model_name
+        if remove_cached_checksum:
+            embedding.cached_checksum = None
+        embedding.name = embedding_name
+        embedding.save(filename)
+    except:
+        embedding.sd_checkpoint = old_sd_checkpoint
+        embedding.sd_checkpoint_name = old_sd_checkpoint_name
+        embedding.name = old_embedding_name
+        embedding.cached_checksum = old_cached_checksum
+        raise
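save_embedding mirrors save_hypernetwork but additionally snapshots and restores cached_checksum, since a stored checksum goes stale as soon as the metadata changes; remove_cached_checksum clears it on a successful save. The hasattr guards let the helper accept objects created before these fields existed. That guard in isolation:

class LegacyEmbedding:                 # illustrative: predates the newer metadata fields
    name = "old-style"

emb = LegacyEmbedding()
old_checksum = emb.cached_checksum if hasattr(emb, "cached_checksum") else None
print(old_checksum)                    # -> None, instead of raising AttributeError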