Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
fddb4883
Commit
fddb4883
authored
Oct 26, 2022
by
evshiron
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
prototype progress api
parent
99d728b5
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
88 additions
and
14 deletions
+88
-14
api.py
modules/api/api.py
+75
-14
shared.py
modules/shared.py
+13
-0
No files found.
modules/api/api.py
View file @
fddb4883
import
time
from
modules.api.models
import
StableDiffusionTxt2ImgProcessingAPI
,
StableDiffusionImg2ImgProcessingAPI
from
modules.api.models
import
StableDiffusionTxt2ImgProcessingAPI
,
StableDiffusionImg2ImgProcessingAPI
from
modules.processing
import
StableDiffusionProcessingTxt2Img
,
StableDiffusionProcessingImg2Img
,
process_images
from
modules.processing
import
StableDiffusionProcessingTxt2Img
,
StableDiffusionProcessingImg2Img
,
process_images
from
modules.sd_samplers
import
all_samplers
from
modules.sd_samplers
import
all_samplers
from
modules.extras
import
run_pnginfo
from
modules.extras
import
run_pnginfo
import
modules.shared
as
shared
import
modules.shared
as
shared
from
modules
import
devices
import
uvicorn
import
uvicorn
from
fastapi
import
Body
,
APIRouter
,
HTTPException
from
fastapi
import
Body
,
APIRouter
,
HTTPException
from
fastapi.responses
import
JSONResponse
from
fastapi.responses
import
JSONResponse
...
@@ -25,6 +28,37 @@ class ImageToImageResponse(BaseModel):
...
@@ -25,6 +28,37 @@ class ImageToImageResponse(BaseModel):
parameters
:
Json
parameters
:
Json
info
:
Json
info
:
Json
class
ProgressResponse
(
BaseModel
):
progress
:
float
eta_relative
:
float
state
:
Json
# copy from wrap_gradio_gpu_call of webui.py
# because queue lock will be acquired in api handlers
# and time start needs to be set
# the function has been modified into two parts
def
before_gpu_call
():
devices
.
torch_gc
()
shared
.
state
.
sampling_step
=
0
shared
.
state
.
job_count
=
-
1
shared
.
state
.
job_no
=
0
shared
.
state
.
job_timestamp
=
shared
.
state
.
get_job_timestamp
()
shared
.
state
.
current_latent
=
None
shared
.
state
.
current_image
=
None
shared
.
state
.
current_image_sampling_step
=
0
shared
.
state
.
skipped
=
False
shared
.
state
.
interrupted
=
False
shared
.
state
.
textinfo
=
None
shared
.
state
.
time_start
=
time
.
time
()
def
after_gpu_call
():
shared
.
state
.
job
=
""
shared
.
state
.
job_count
=
0
devices
.
torch_gc
()
class
Api
:
class
Api
:
def
__init__
(
self
,
app
,
queue_lock
):
def
__init__
(
self
,
app
,
queue_lock
):
...
@@ -33,6 +67,7 @@ class Api:
...
@@ -33,6 +67,7 @@ class Api:
self
.
queue_lock
=
queue_lock
self
.
queue_lock
=
queue_lock
self
.
app
.
add_api_route
(
"/sdapi/v1/txt2img"
,
self
.
text2imgapi
,
methods
=
[
"POST"
])
self
.
app
.
add_api_route
(
"/sdapi/v1/txt2img"
,
self
.
text2imgapi
,
methods
=
[
"POST"
])
self
.
app
.
add_api_route
(
"/sdapi/v1/img2img"
,
self
.
img2imgapi
,
methods
=
[
"POST"
])
self
.
app
.
add_api_route
(
"/sdapi/v1/img2img"
,
self
.
img2imgapi
,
methods
=
[
"POST"
])
self
.
app
.
add_api_route
(
"/sdapi/v1/progress"
,
self
.
progressapi
,
methods
=
[
"GET"
])
def
__base64_to_image
(
self
,
base64_string
):
def
__base64_to_image
(
self
,
base64_string
):
# if has a comma, deal with prefix
# if has a comma, deal with prefix
...
@@ -44,12 +79,12 @@ class Api:
...
@@ -44,12 +79,12 @@ class Api:
def
text2imgapi
(
self
,
txt2imgreq
:
StableDiffusionTxt2ImgProcessingAPI
):
def
text2imgapi
(
self
,
txt2imgreq
:
StableDiffusionTxt2ImgProcessingAPI
):
sampler_index
=
sampler_to_index
(
txt2imgreq
.
sampler_index
)
sampler_index
=
sampler_to_index
(
txt2imgreq
.
sampler_index
)
if
sampler_index
is
None
:
if
sampler_index
is
None
:
raise
HTTPException
(
status_code
=
404
,
detail
=
"Sampler not found"
)
raise
HTTPException
(
status_code
=
404
,
detail
=
"Sampler not found"
)
populate
=
txt2imgreq
.
copy
(
update
=
{
# Override __init__ params
populate
=
txt2imgreq
.
copy
(
update
=
{
# Override __init__ params
"sd_model"
:
shared
.
sd_model
,
"sd_model"
:
shared
.
sd_model
,
"sampler_index"
:
sampler_index
[
0
],
"sampler_index"
:
sampler_index
[
0
],
"do_not_save_samples"
:
True
,
"do_not_save_samples"
:
True
,
"do_not_save_grid"
:
True
"do_not_save_grid"
:
True
...
@@ -57,9 +92,11 @@ class Api:
...
@@ -57,9 +92,11 @@ class Api:
)
)
p
=
StableDiffusionProcessingTxt2Img
(
**
vars
(
populate
))
p
=
StableDiffusionProcessingTxt2Img
(
**
vars
(
populate
))
# Override object param
# Override object param
before_gpu_call
()
with
self
.
queue_lock
:
with
self
.
queue_lock
:
processed
=
process_images
(
p
)
processed
=
process_images
(
p
)
after_gpu_call
()
b64images
=
[]
b64images
=
[]
for
i
in
processed
.
images
:
for
i
in
processed
.
images
:
buffer
=
io
.
BytesIO
()
buffer
=
io
.
BytesIO
()
...
@@ -67,30 +104,30 @@ class Api:
...
@@ -67,30 +104,30 @@ class Api:
b64images
.
append
(
base64
.
b64encode
(
buffer
.
getvalue
()))
b64images
.
append
(
base64
.
b64encode
(
buffer
.
getvalue
()))
return
TextToImageResponse
(
images
=
b64images
,
parameters
=
json
.
dumps
(
vars
(
txt2imgreq
)),
info
=
processed
.
js
())
return
TextToImageResponse
(
images
=
b64images
,
parameters
=
json
.
dumps
(
vars
(
txt2imgreq
)),
info
=
processed
.
js
())
def
img2imgapi
(
self
,
img2imgreq
:
StableDiffusionImg2ImgProcessingAPI
):
def
img2imgapi
(
self
,
img2imgreq
:
StableDiffusionImg2ImgProcessingAPI
):
sampler_index
=
sampler_to_index
(
img2imgreq
.
sampler_index
)
sampler_index
=
sampler_to_index
(
img2imgreq
.
sampler_index
)
if
sampler_index
is
None
:
if
sampler_index
is
None
:
raise
HTTPException
(
status_code
=
404
,
detail
=
"Sampler not found"
)
raise
HTTPException
(
status_code
=
404
,
detail
=
"Sampler not found"
)
init_images
=
img2imgreq
.
init_images
init_images
=
img2imgreq
.
init_images
if
init_images
is
None
:
if
init_images
is
None
:
raise
HTTPException
(
status_code
=
404
,
detail
=
"Init image not found"
)
raise
HTTPException
(
status_code
=
404
,
detail
=
"Init image not found"
)
mask
=
img2imgreq
.
mask
mask
=
img2imgreq
.
mask
if
mask
:
if
mask
:
mask
=
self
.
__base64_to_image
(
mask
)
mask
=
self
.
__base64_to_image
(
mask
)
populate
=
img2imgreq
.
copy
(
update
=
{
# Override __init__ params
populate
=
img2imgreq
.
copy
(
update
=
{
# Override __init__ params
"sd_model"
:
shared
.
sd_model
,
"sd_model"
:
shared
.
sd_model
,
"sampler_index"
:
sampler_index
[
0
],
"sampler_index"
:
sampler_index
[
0
],
"do_not_save_samples"
:
True
,
"do_not_save_samples"
:
True
,
"do_not_save_grid"
:
True
,
"do_not_save_grid"
:
True
,
"mask"
:
mask
"mask"
:
mask
}
}
)
)
...
@@ -103,9 +140,11 @@ class Api:
...
@@ -103,9 +140,11 @@ class Api:
p
.
init_images
=
imgs
p
.
init_images
=
imgs
# Override object param
# Override object param
before_gpu_call
()
with
self
.
queue_lock
:
with
self
.
queue_lock
:
processed
=
process_images
(
p
)
processed
=
process_images
(
p
)
after_gpu_call
()
b64images
=
[]
b64images
=
[]
for
i
in
processed
.
images
:
for
i
in
processed
.
images
:
buffer
=
io
.
BytesIO
()
buffer
=
io
.
BytesIO
()
...
@@ -118,6 +157,28 @@ class Api:
...
@@ -118,6 +157,28 @@ class Api:
return
ImageToImageResponse
(
images
=
b64images
,
parameters
=
json
.
dumps
(
vars
(
img2imgreq
)),
info
=
processed
.
js
())
return
ImageToImageResponse
(
images
=
b64images
,
parameters
=
json
.
dumps
(
vars
(
img2imgreq
)),
info
=
processed
.
js
())
def
progressapi
(
self
):
# copy from check_progress_call of ui.py
if
shared
.
state
.
job_count
==
0
:
return
ProgressResponse
(
progress
=
0
,
eta_relative
=
0
,
state
=
shared
.
state
.
js
())
# avoid dividing zero
progress
=
0.01
if
shared
.
state
.
job_count
>
0
:
progress
+=
shared
.
state
.
job_no
/
shared
.
state
.
job_count
if
shared
.
state
.
sampling_steps
>
0
:
progress
+=
1
/
shared
.
state
.
job_count
*
shared
.
state
.
sampling_step
/
shared
.
state
.
sampling_steps
time_since_start
=
time
.
time
()
-
shared
.
state
.
time_start
eta
=
(
time_since_start
/
progress
)
eta_relative
=
eta
-
time_since_start
progress
=
min
(
progress
,
1
)
return
ProgressResponse
(
progress
=
progress
,
eta_relative
=
eta_relative
,
state
=
shared
.
state
.
js
())
def
extrasapi
(
self
):
def
extrasapi
(
self
):
raise
NotImplementedError
raise
NotImplementedError
...
...
modules/shared.py
View file @
fddb4883
...
@@ -146,6 +146,19 @@ class State:
...
@@ -146,6 +146,19 @@ class State:
def
get_job_timestamp
(
self
):
def
get_job_timestamp
(
self
):
return
datetime
.
datetime
.
now
()
.
strftime
(
"
%
Y
%
m
%
d
%
H
%
M
%
S"
)
# shouldn't this return job_timestamp?
return
datetime
.
datetime
.
now
()
.
strftime
(
"
%
Y
%
m
%
d
%
H
%
M
%
S"
)
# shouldn't this return job_timestamp?
def
js
(
self
):
obj
=
{
"skipped"
:
self
.
skipped
,
"interrupted"
:
self
.
skipped
,
"job"
:
self
.
job
,
"job_count"
:
self
.
job_count
,
"job_no"
:
self
.
job_no
,
"sampling_step"
:
self
.
sampling_step
,
"sampling_steps"
:
self
.
sampling_steps
,
}
return
json
.
dumps
(
obj
)
state
=
State
()
state
=
State
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment