Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
c1512ef9
Unverified
Commit
c1512ef9
authored
Dec 25, 2022
by
AUTOMATIC1111
Committed by
GitHub
Dec 25, 2022
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #5999 from vladmandic/trainapi
implement train api
parents
8eef9d8e
5f1dfbbc
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
132 additions
and
28 deletions
+132
-28
api.py
modules/api/api.py
+93
-1
models.py
modules/api/models.py
+9
-0
hypernetwork.py
modules/hypernetworks/hypernetwork.py
+26
-0
ui.py
modules/hypernetworks/ui.py
+4
-27
No files found.
modules/api/api.py
View file @
c1512ef9
...
...
@@ -10,13 +10,17 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
from
secrets
import
compare_digest
import
modules.shared
as
shared
from
modules
import
sd_samplers
,
deepbooru
from
modules
import
sd_samplers
,
deepbooru
,
sd_hijack
from
modules.api.models
import
*
from
modules.processing
import
StableDiffusionProcessingTxt2Img
,
StableDiffusionProcessingImg2Img
,
process_images
from
modules.extras
import
run_extras
,
run_pnginfo
from
modules.textual_inversion.textual_inversion
import
create_embedding
,
train_embedding
from
modules.textual_inversion.preprocess
import
preprocess
from
modules.hypernetworks.hypernetwork
import
create_hypernetwork
,
train_hypernetwork
from
PIL
import
PngImagePlugin
,
Image
from
modules.sd_models
import
checkpoints_list
from
modules.realesrgan_model
import
get_realesrgan_models
from
modules
import
devices
from
typing
import
List
def
upscaler_to_index
(
name
:
str
):
...
...
@@ -97,6 +101,11 @@ class Api:
self
.
add_api_route
(
"/sdapi/v1/artist-categories"
,
self
.
get_artists_categories
,
methods
=
[
"GET"
],
response_model
=
List
[
str
])
self
.
add_api_route
(
"/sdapi/v1/artists"
,
self
.
get_artists
,
methods
=
[
"GET"
],
response_model
=
List
[
ArtistItem
])
self
.
add_api_route
(
"/sdapi/v1/refresh-checkpoints"
,
self
.
refresh_checkpoints
,
methods
=
[
"POST"
])
self
.
add_api_route
(
"/sdapi/v1/create/embedding"
,
self
.
create_embedding
,
methods
=
[
"POST"
],
response_model
=
CreateResponse
)
self
.
add_api_route
(
"/sdapi/v1/create/hypernetwork"
,
self
.
create_hypernetwork
,
methods
=
[
"POST"
],
response_model
=
CreateResponse
)
self
.
add_api_route
(
"/sdapi/v1/preprocess"
,
self
.
preprocess
,
methods
=
[
"POST"
],
response_model
=
PreprocessResponse
)
self
.
add_api_route
(
"/sdapi/v1/train/embedding"
,
self
.
train_embedding
,
methods
=
[
"POST"
],
response_model
=
TrainResponse
)
self
.
add_api_route
(
"/sdapi/v1/train/hypernetwork"
,
self
.
train_hypernetwork
,
methods
=
[
"POST"
],
response_model
=
TrainResponse
)
def
add_api_route
(
self
,
path
:
str
,
endpoint
,
**
kwargs
):
if
shared
.
cmd_opts
.
api_auth
:
...
...
@@ -326,6 +335,89 @@ class Api:
def
refresh_checkpoints
(
self
):
shared
.
refresh_checkpoints
()
def
create_embedding
(
self
,
args
:
dict
):
try
:
shared
.
state
.
begin
()
filename
=
create_embedding
(
**
args
)
# create empty embedding
sd_hijack
.
model_hijack
.
embedding_db
.
load_textual_inversion_embeddings
()
# reload embeddings so new one can be immediately used
shared
.
state
.
end
()
return
CreateResponse
(
info
=
"create embedding filename: {filename}"
.
format
(
filename
=
filename
))
except
AssertionError
as
e
:
shared
.
state
.
end
()
return
TrainResponse
(
info
=
"create embedding error: {error}"
.
format
(
error
=
e
))
def
create_hypernetwork
(
self
,
args
:
dict
):
try
:
shared
.
state
.
begin
()
filename
=
create_hypernetwork
(
**
args
)
# create empty embedding
shared
.
state
.
end
()
return
CreateResponse
(
info
=
"create hypernetwork filename: {filename}"
.
format
(
filename
=
filename
))
except
AssertionError
as
e
:
shared
.
state
.
end
()
return
TrainResponse
(
info
=
"create hypernetwork error: {error}"
.
format
(
error
=
e
))
def
preprocess
(
self
,
args
:
dict
):
try
:
shared
.
state
.
begin
()
preprocess
(
**
args
)
# quick operation unless blip/booru interrogation is enabled
shared
.
state
.
end
()
return
PreprocessResponse
(
info
=
'preprocess complete'
)
except
KeyError
as
e
:
shared
.
state
.
end
()
return
PreprocessResponse
(
info
=
"preprocess error: invalid token: {error}"
.
format
(
error
=
e
))
except
AssertionError
as
e
:
shared
.
state
.
end
()
return
PreprocessResponse
(
info
=
"preprocess error: {error}"
.
format
(
error
=
e
))
except
FileNotFoundError
as
e
:
shared
.
state
.
end
()
return
PreprocessResponse
(
info
=
'preprocess error: {error}'
.
format
(
error
=
e
))
def
train_embedding
(
self
,
args
:
dict
):
try
:
shared
.
state
.
begin
()
apply_optimizations
=
shared
.
opts
.
training_xattention_optimizations
error
=
None
filename
=
''
if
not
apply_optimizations
:
sd_hijack
.
undo_optimizations
()
try
:
embedding
,
filename
=
train_embedding
(
**
args
)
# can take a long time to complete
except
Exception
as
e
:
error
=
e
finally
:
if
not
apply_optimizations
:
sd_hijack
.
apply_optimizations
()
shared
.
state
.
end
()
return
TrainResponse
(
info
=
"train embedding complete: filename: {filename} error: {error}"
.
format
(
filename
=
filename
,
error
=
error
))
except
AssertionError
as
msg
:
shared
.
state
.
end
()
return
TrainResponse
(
info
=
"train embedding error: {msg}"
.
format
(
msg
=
msg
))
def
train_hypernetwork
(
self
,
args
:
dict
):
try
:
shared
.
state
.
begin
()
initial_hypernetwork
=
shared
.
loaded_hypernetwork
apply_optimizations
=
shared
.
opts
.
training_xattention_optimizations
error
=
None
filename
=
''
if
not
apply_optimizations
:
sd_hijack
.
undo_optimizations
()
try
:
hypernetwork
,
filename
=
train_hypernetwork
(
*
args
)
except
Exception
as
e
:
error
=
e
finally
:
shared
.
loaded_hypernetwork
=
initial_hypernetwork
shared
.
sd_model
.
cond_stage_model
.
to
(
devices
.
device
)
shared
.
sd_model
.
first_stage_model
.
to
(
devices
.
device
)
if
not
apply_optimizations
:
sd_hijack
.
apply_optimizations
()
shared
.
state
.
end
()
return
TrainResponse
(
info
=
"train embedding complete: filename: {filename} error: {error}"
.
format
(
filename
=
filename
,
error
=
error
))
except
AssertionError
as
msg
:
shared
.
state
.
end
()
return
TrainResponse
(
info
=
"train embedding error: {error}"
.
format
(
error
=
error
))
def
launch
(
self
,
server_name
,
port
):
self
.
app
.
include_router
(
self
.
router
)
uvicorn
.
run
(
self
.
app
,
host
=
server_name
,
port
=
port
)
modules/api/models.py
View file @
c1512ef9
...
...
@@ -175,6 +175,15 @@ class InterrogateRequest(BaseModel):
class
InterrogateResponse
(
BaseModel
):
caption
:
str
=
Field
(
default
=
None
,
title
=
"Caption"
,
description
=
"The generated caption for the image."
)
class
TrainResponse
(
BaseModel
):
info
:
str
=
Field
(
title
=
"Train info"
,
description
=
"Response string from train embedding or hypernetwork task."
)
class
CreateResponse
(
BaseModel
):
info
:
str
=
Field
(
title
=
"Create info"
,
description
=
"Response string from create embedding or hypernetwork task."
)
class
PreprocessResponse
(
BaseModel
):
info
:
str
=
Field
(
title
=
"Preprocess info"
,
description
=
"Response string from preprocessing task."
)
fields
=
{}
for
key
,
metadata
in
opts
.
data_labels
.
items
():
value
=
opts
.
data
.
get
(
key
)
...
...
modules/hypernetworks/hypernetwork.py
View file @
c1512ef9
...
...
@@ -378,6 +378,32 @@ def report_statistics(loss_info:dict):
print
(
e
)
def
create_hypernetwork
(
name
,
enable_sizes
,
overwrite_old
,
layer_structure
=
None
,
activation_func
=
None
,
weight_init
=
None
,
add_layer_norm
=
False
,
use_dropout
=
False
):
# Remove illegal characters from name.
name
=
""
.
join
(
x
for
x
in
name
if
(
x
.
isalnum
()
or
x
in
"._- "
))
fn
=
os
.
path
.
join
(
shared
.
cmd_opts
.
hypernetwork_dir
,
f
"{name}.pt"
)
if
not
overwrite_old
:
assert
not
os
.
path
.
exists
(
fn
),
f
"file {fn} already exists"
if
type
(
layer_structure
)
==
str
:
layer_structure
=
[
float
(
x
.
strip
())
for
x
in
layer_structure
.
split
(
","
)]
hypernet
=
modules
.
hypernetworks
.
hypernetwork
.
Hypernetwork
(
name
=
name
,
enable_sizes
=
[
int
(
x
)
for
x
in
enable_sizes
],
layer_structure
=
layer_structure
,
activation_func
=
activation_func
,
weight_init
=
weight_init
,
add_layer_norm
=
add_layer_norm
,
use_dropout
=
use_dropout
,
)
hypernet
.
save
(
fn
)
shared
.
reload_hypernetworks
()
return
fn
def
train_hypernetwork
(
hypernetwork_name
,
learn_rate
,
batch_size
,
gradient_step
,
data_root
,
log_directory
,
training_width
,
training_height
,
steps
,
shuffle_tags
,
tag_drop_out
,
latent_sampling_method
,
create_image_every
,
save_hypernetwork_every
,
template_file
,
preview_from_txt2img
,
preview_prompt
,
preview_negative_prompt
,
preview_steps
,
preview_sampler_index
,
preview_cfg_scale
,
preview_seed
,
preview_width
,
preview_height
):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
...
...
modules/hypernetworks/ui.py
View file @
c1512ef9
...
...
@@ -3,39 +3,16 @@ import os
import
re
import
gradio
as
gr
import
modules.textual_inversion.preprocess
import
modules.textual_inversion.textual_inversion
import
modules.hypernetworks.hypernetwork
from
modules
import
devices
,
sd_hijack
,
shared
from
modules.hypernetworks
import
hypernetwork
not_available
=
[
"hardswish"
,
"multiheadattention"
]
keys
=
list
(
x
for
x
in
hypernetwork
.
HypernetworkModule
.
activation_dict
.
keys
()
if
x
not
in
not_available
)
keys
=
list
(
x
for
x
in
modules
.
hypernetworks
.
hypernetwork
.
HypernetworkModule
.
activation_dict
.
keys
()
if
x
not
in
not_available
)
def
create_hypernetwork
(
name
,
enable_sizes
,
overwrite_old
,
layer_structure
=
None
,
activation_func
=
None
,
weight_init
=
None
,
add_layer_norm
=
False
,
use_dropout
=
False
):
# Remove illegal characters from name.
name
=
""
.
join
(
x
for
x
in
name
if
(
x
.
isalnum
()
or
x
in
"._- "
))
filename
=
modules
.
hypernetworks
.
hypernetwork
.
create_hypernetwork
(
name
,
enable_sizes
,
overwrite_old
,
layer_structure
,
activation_func
,
weight_init
,
add_layer_norm
,
use_dropout
)
fn
=
os
.
path
.
join
(
shared
.
cmd_opts
.
hypernetwork_dir
,
f
"{name}.pt"
)
if
not
overwrite_old
:
assert
not
os
.
path
.
exists
(
fn
),
f
"file {fn} already exists"
if
type
(
layer_structure
)
==
str
:
layer_structure
=
[
float
(
x
.
strip
())
for
x
in
layer_structure
.
split
(
","
)]
hypernet
=
modules
.
hypernetworks
.
hypernetwork
.
Hypernetwork
(
name
=
name
,
enable_sizes
=
[
int
(
x
)
for
x
in
enable_sizes
],
layer_structure
=
layer_structure
,
activation_func
=
activation_func
,
weight_init
=
weight_init
,
add_layer_norm
=
add_layer_norm
,
use_dropout
=
use_dropout
,
)
hypernet
.
save
(
fn
)
shared
.
reload_hypernetworks
()
return
gr
.
Dropdown
.
update
(
choices
=
sorted
([
x
for
x
in
shared
.
hypernetworks
.
keys
()])),
f
"Created: {fn}"
,
""
return
gr
.
Dropdown
.
update
(
choices
=
sorted
([
x
for
x
in
shared
.
hypernetworks
.
keys
()])),
f
"Created: {filename}"
,
""
def
train_hypernetwork
(
*
args
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment