Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
5f1dfbbc
Commit
5f1dfbbc
authored
Dec 24, 2022
by
Vladimir Mandic
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
implement train api
parent
c5bdba20
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
132 additions
and
28 deletions
+132
-28
api.py
modules/api/api.py
+93
-1
models.py
modules/api/models.py
+9
-0
hypernetwork.py
modules/hypernetworks/hypernetwork.py
+26
-0
ui.py
modules/hypernetworks/ui.py
+4
-27
No files found.
modules/api/api.py
View file @
5f1dfbbc
...
@@ -10,13 +10,17 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
...
@@ -10,13 +10,17 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
from
secrets
import
compare_digest
from
secrets
import
compare_digest
import
modules.shared
as
shared
import
modules.shared
as
shared
from
modules
import
sd_samplers
,
deepbooru
from
modules
import
sd_samplers
,
deepbooru
,
sd_hijack
from
modules.api.models
import
*
from
modules.api.models
import
*
from
modules.processing
import
StableDiffusionProcessingTxt2Img
,
StableDiffusionProcessingImg2Img
,
process_images
from
modules.processing
import
StableDiffusionProcessingTxt2Img
,
StableDiffusionProcessingImg2Img
,
process_images
from
modules.extras
import
run_extras
,
run_pnginfo
from
modules.extras
import
run_extras
,
run_pnginfo
from
modules.textual_inversion.textual_inversion
import
create_embedding
,
train_embedding
from
modules.textual_inversion.preprocess
import
preprocess
from
modules.hypernetworks.hypernetwork
import
create_hypernetwork
,
train_hypernetwork
from
PIL
import
PngImagePlugin
,
Image
from
PIL
import
PngImagePlugin
,
Image
from
modules.sd_models
import
checkpoints_list
from
modules.sd_models
import
checkpoints_list
from
modules.realesrgan_model
import
get_realesrgan_models
from
modules.realesrgan_model
import
get_realesrgan_models
from
modules
import
devices
from
typing
import
List
from
typing
import
List
def
upscaler_to_index
(
name
:
str
):
def
upscaler_to_index
(
name
:
str
):
...
@@ -97,6 +101,11 @@ class Api:
...
@@ -97,6 +101,11 @@ class Api:
self
.
add_api_route
(
"/sdapi/v1/artist-categories"
,
self
.
get_artists_categories
,
methods
=
[
"GET"
],
response_model
=
List
[
str
])
self
.
add_api_route
(
"/sdapi/v1/artist-categories"
,
self
.
get_artists_categories
,
methods
=
[
"GET"
],
response_model
=
List
[
str
])
self
.
add_api_route
(
"/sdapi/v1/artists"
,
self
.
get_artists
,
methods
=
[
"GET"
],
response_model
=
List
[
ArtistItem
])
self
.
add_api_route
(
"/sdapi/v1/artists"
,
self
.
get_artists
,
methods
=
[
"GET"
],
response_model
=
List
[
ArtistItem
])
self
.
add_api_route
(
"/sdapi/v1/refresh-checkpoints"
,
self
.
refresh_checkpoints
,
methods
=
[
"POST"
])
self
.
add_api_route
(
"/sdapi/v1/refresh-checkpoints"
,
self
.
refresh_checkpoints
,
methods
=
[
"POST"
])
self
.
add_api_route
(
"/sdapi/v1/create/embedding"
,
self
.
create_embedding
,
methods
=
[
"POST"
],
response_model
=
CreateResponse
)
self
.
add_api_route
(
"/sdapi/v1/create/hypernetwork"
,
self
.
create_hypernetwork
,
methods
=
[
"POST"
],
response_model
=
CreateResponse
)
self
.
add_api_route
(
"/sdapi/v1/preprocess"
,
self
.
preprocess
,
methods
=
[
"POST"
],
response_model
=
PreprocessResponse
)
self
.
add_api_route
(
"/sdapi/v1/train/embedding"
,
self
.
train_embedding
,
methods
=
[
"POST"
],
response_model
=
TrainResponse
)
self
.
add_api_route
(
"/sdapi/v1/train/hypernetwork"
,
self
.
train_hypernetwork
,
methods
=
[
"POST"
],
response_model
=
TrainResponse
)
def
add_api_route
(
self
,
path
:
str
,
endpoint
,
**
kwargs
):
def
add_api_route
(
self
,
path
:
str
,
endpoint
,
**
kwargs
):
if
shared
.
cmd_opts
.
api_auth
:
if
shared
.
cmd_opts
.
api_auth
:
...
@@ -326,6 +335,89 @@ class Api:
...
@@ -326,6 +335,89 @@ class Api:
def
refresh_checkpoints
(
self
):
def
refresh_checkpoints
(
self
):
shared
.
refresh_checkpoints
()
shared
.
refresh_checkpoints
()
def
create_embedding
(
self
,
args
:
dict
):
try
:
shared
.
state
.
begin
()
filename
=
create_embedding
(
**
args
)
# create empty embedding
sd_hijack
.
model_hijack
.
embedding_db
.
load_textual_inversion_embeddings
()
# reload embeddings so new one can be immediately used
shared
.
state
.
end
()
return
CreateResponse
(
info
=
"create embedding filename: {filename}"
.
format
(
filename
=
filename
))
except
AssertionError
as
e
:
shared
.
state
.
end
()
return
TrainResponse
(
info
=
"create embedding error: {error}"
.
format
(
error
=
e
))
def
create_hypernetwork
(
self
,
args
:
dict
):
try
:
shared
.
state
.
begin
()
filename
=
create_hypernetwork
(
**
args
)
# create empty embedding
shared
.
state
.
end
()
return
CreateResponse
(
info
=
"create hypernetwork filename: {filename}"
.
format
(
filename
=
filename
))
except
AssertionError
as
e
:
shared
.
state
.
end
()
return
TrainResponse
(
info
=
"create hypernetwork error: {error}"
.
format
(
error
=
e
))
def
preprocess
(
self
,
args
:
dict
):
try
:
shared
.
state
.
begin
()
preprocess
(
**
args
)
# quick operation unless blip/booru interrogation is enabled
shared
.
state
.
end
()
return
PreprocessResponse
(
info
=
'preprocess complete'
)
except
KeyError
as
e
:
shared
.
state
.
end
()
return
PreprocessResponse
(
info
=
"preprocess error: invalid token: {error}"
.
format
(
error
=
e
))
except
AssertionError
as
e
:
shared
.
state
.
end
()
return
PreprocessResponse
(
info
=
"preprocess error: {error}"
.
format
(
error
=
e
))
except
FileNotFoundError
as
e
:
shared
.
state
.
end
()
return
PreprocessResponse
(
info
=
'preprocess error: {error}'
.
format
(
error
=
e
))
def
train_embedding
(
self
,
args
:
dict
):
try
:
shared
.
state
.
begin
()
apply_optimizations
=
shared
.
opts
.
training_xattention_optimizations
error
=
None
filename
=
''
if
not
apply_optimizations
:
sd_hijack
.
undo_optimizations
()
try
:
embedding
,
filename
=
train_embedding
(
**
args
)
# can take a long time to complete
except
Exception
as
e
:
error
=
e
finally
:
if
not
apply_optimizations
:
sd_hijack
.
apply_optimizations
()
shared
.
state
.
end
()
return
TrainResponse
(
info
=
"train embedding complete: filename: {filename} error: {error}"
.
format
(
filename
=
filename
,
error
=
error
))
except
AssertionError
as
msg
:
shared
.
state
.
end
()
return
TrainResponse
(
info
=
"train embedding error: {msg}"
.
format
(
msg
=
msg
))
def
train_hypernetwork
(
self
,
args
:
dict
):
try
:
shared
.
state
.
begin
()
initial_hypernetwork
=
shared
.
loaded_hypernetwork
apply_optimizations
=
shared
.
opts
.
training_xattention_optimizations
error
=
None
filename
=
''
if
not
apply_optimizations
:
sd_hijack
.
undo_optimizations
()
try
:
hypernetwork
,
filename
=
train_hypernetwork
(
*
args
)
except
Exception
as
e
:
error
=
e
finally
:
shared
.
loaded_hypernetwork
=
initial_hypernetwork
shared
.
sd_model
.
cond_stage_model
.
to
(
devices
.
device
)
shared
.
sd_model
.
first_stage_model
.
to
(
devices
.
device
)
if
not
apply_optimizations
:
sd_hijack
.
apply_optimizations
()
shared
.
state
.
end
()
return
TrainResponse
(
info
=
"train embedding complete: filename: {filename} error: {error}"
.
format
(
filename
=
filename
,
error
=
error
))
except
AssertionError
as
msg
:
shared
.
state
.
end
()
return
TrainResponse
(
info
=
"train embedding error: {error}"
.
format
(
error
=
error
))
def
launch
(
self
,
server_name
,
port
):
def
launch
(
self
,
server_name
,
port
):
self
.
app
.
include_router
(
self
.
router
)
self
.
app
.
include_router
(
self
.
router
)
uvicorn
.
run
(
self
.
app
,
host
=
server_name
,
port
=
port
)
uvicorn
.
run
(
self
.
app
,
host
=
server_name
,
port
=
port
)
modules/api/models.py
View file @
5f1dfbbc
...
@@ -175,6 +175,15 @@ class InterrogateRequest(BaseModel):
...
@@ -175,6 +175,15 @@ class InterrogateRequest(BaseModel):
class
InterrogateResponse
(
BaseModel
):
class
InterrogateResponse
(
BaseModel
):
caption
:
str
=
Field
(
default
=
None
,
title
=
"Caption"
,
description
=
"The generated caption for the image."
)
caption
:
str
=
Field
(
default
=
None
,
title
=
"Caption"
,
description
=
"The generated caption for the image."
)
class
TrainResponse
(
BaseModel
):
info
:
str
=
Field
(
title
=
"Train info"
,
description
=
"Response string from train embedding or hypernetwork task."
)
class
CreateResponse
(
BaseModel
):
info
:
str
=
Field
(
title
=
"Create info"
,
description
=
"Response string from create embedding or hypernetwork task."
)
class
PreprocessResponse
(
BaseModel
):
info
:
str
=
Field
(
title
=
"Preprocess info"
,
description
=
"Response string from preprocessing task."
)
fields
=
{}
fields
=
{}
for
key
,
metadata
in
opts
.
data_labels
.
items
():
for
key
,
metadata
in
opts
.
data_labels
.
items
():
value
=
opts
.
data
.
get
(
key
)
value
=
opts
.
data
.
get
(
key
)
...
...
modules/hypernetworks/hypernetwork.py
View file @
5f1dfbbc
...
@@ -378,6 +378,32 @@ def report_statistics(loss_info:dict):
...
@@ -378,6 +378,32 @@ def report_statistics(loss_info:dict):
print
(
e
)
print
(
e
)
def
create_hypernetwork
(
name
,
enable_sizes
,
overwrite_old
,
layer_structure
=
None
,
activation_func
=
None
,
weight_init
=
None
,
add_layer_norm
=
False
,
use_dropout
=
False
):
# Remove illegal characters from name.
name
=
""
.
join
(
x
for
x
in
name
if
(
x
.
isalnum
()
or
x
in
"._- "
))
fn
=
os
.
path
.
join
(
shared
.
cmd_opts
.
hypernetwork_dir
,
f
"{name}.pt"
)
if
not
overwrite_old
:
assert
not
os
.
path
.
exists
(
fn
),
f
"file {fn} already exists"
if
type
(
layer_structure
)
==
str
:
layer_structure
=
[
float
(
x
.
strip
())
for
x
in
layer_structure
.
split
(
","
)]
hypernet
=
modules
.
hypernetworks
.
hypernetwork
.
Hypernetwork
(
name
=
name
,
enable_sizes
=
[
int
(
x
)
for
x
in
enable_sizes
],
layer_structure
=
layer_structure
,
activation_func
=
activation_func
,
weight_init
=
weight_init
,
add_layer_norm
=
add_layer_norm
,
use_dropout
=
use_dropout
,
)
hypernet
.
save
(
fn
)
shared
.
reload_hypernetworks
()
return
fn
def
train_hypernetwork
(
hypernetwork_name
,
learn_rate
,
batch_size
,
gradient_step
,
data_root
,
log_directory
,
training_width
,
training_height
,
steps
,
shuffle_tags
,
tag_drop_out
,
latent_sampling_method
,
create_image_every
,
save_hypernetwork_every
,
template_file
,
preview_from_txt2img
,
preview_prompt
,
preview_negative_prompt
,
preview_steps
,
preview_sampler_index
,
preview_cfg_scale
,
preview_seed
,
preview_width
,
preview_height
):
def
train_hypernetwork
(
hypernetwork_name
,
learn_rate
,
batch_size
,
gradient_step
,
data_root
,
log_directory
,
training_width
,
training_height
,
steps
,
shuffle_tags
,
tag_drop_out
,
latent_sampling_method
,
create_image_every
,
save_hypernetwork_every
,
template_file
,
preview_from_txt2img
,
preview_prompt
,
preview_negative_prompt
,
preview_steps
,
preview_sampler_index
,
preview_cfg_scale
,
preview_seed
,
preview_width
,
preview_height
):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
...
...
modules/hypernetworks/ui.py
View file @
5f1dfbbc
...
@@ -3,39 +3,16 @@ import os
...
@@ -3,39 +3,16 @@ import os
import
re
import
re
import
gradio
as
gr
import
gradio
as
gr
import
modules.textual_inversion.preprocess
import
modules.hypernetworks.hypernetwork
import
modules.textual_inversion.textual_inversion
from
modules
import
devices
,
sd_hijack
,
shared
from
modules
import
devices
,
sd_hijack
,
shared
from
modules.hypernetworks
import
hypernetwork
not_available
=
[
"hardswish"
,
"multiheadattention"
]
not_available
=
[
"hardswish"
,
"multiheadattention"
]
keys
=
list
(
x
for
x
in
hypernetwork
.
HypernetworkModule
.
activation_dict
.
keys
()
if
x
not
in
not_available
)
keys
=
list
(
x
for
x
in
modules
.
hypernetworks
.
hypernetwork
.
HypernetworkModule
.
activation_dict
.
keys
()
if
x
not
in
not_available
)
def
create_hypernetwork
(
name
,
enable_sizes
,
overwrite_old
,
layer_structure
=
None
,
activation_func
=
None
,
weight_init
=
None
,
add_layer_norm
=
False
,
use_dropout
=
False
):
def
create_hypernetwork
(
name
,
enable_sizes
,
overwrite_old
,
layer_structure
=
None
,
activation_func
=
None
,
weight_init
=
None
,
add_layer_norm
=
False
,
use_dropout
=
False
):
# Remove illegal characters from name.
filename
=
modules
.
hypernetworks
.
hypernetwork
.
create_hypernetwork
(
name
,
enable_sizes
,
overwrite_old
,
layer_structure
,
activation_func
,
weight_init
,
add_layer_norm
,
use_dropout
)
name
=
""
.
join
(
x
for
x
in
name
if
(
x
.
isalnum
()
or
x
in
"._- "
))
fn
=
os
.
path
.
join
(
shared
.
cmd_opts
.
hypernetwork_dir
,
f
"{name}.pt"
)
return
gr
.
Dropdown
.
update
(
choices
=
sorted
([
x
for
x
in
shared
.
hypernetworks
.
keys
()])),
f
"Created: {filename}"
,
""
if
not
overwrite_old
:
assert
not
os
.
path
.
exists
(
fn
),
f
"file {fn} already exists"
if
type
(
layer_structure
)
==
str
:
layer_structure
=
[
float
(
x
.
strip
())
for
x
in
layer_structure
.
split
(
","
)]
hypernet
=
modules
.
hypernetworks
.
hypernetwork
.
Hypernetwork
(
name
=
name
,
enable_sizes
=
[
int
(
x
)
for
x
in
enable_sizes
],
layer_structure
=
layer_structure
,
activation_func
=
activation_func
,
weight_init
=
weight_init
,
add_layer_norm
=
add_layer_norm
,
use_dropout
=
use_dropout
,
)
hypernet
.
save
(
fn
)
shared
.
reload_hypernetworks
()
return
gr
.
Dropdown
.
update
(
choices
=
sorted
([
x
for
x
in
shared
.
hypernetworks
.
keys
()])),
f
"Created: {fn}"
,
""
def
train_hypernetwork
(
*
args
):
def
train_hypernetwork
(
*
args
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment