Commit fddb4883 authored by evshiron's avatar evshiron

prototype progress api

parent 99d728b5
import time
from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers from modules.sd_samplers import all_samplers
from modules.extras import run_pnginfo from modules.extras import run_pnginfo
import modules.shared as shared import modules.shared as shared
from modules import devices
import uvicorn import uvicorn
from fastapi import Body, APIRouter, HTTPException from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
...@@ -25,6 +28,37 @@ class ImageToImageResponse(BaseModel): ...@@ -25,6 +28,37 @@ class ImageToImageResponse(BaseModel):
parameters: Json parameters: Json
info: Json info: Json
class ProgressResponse(BaseModel):
progress: float
eta_relative: float
state: Json
# copy from wrap_gradio_gpu_call of webui.py
# because queue lock will be acquired in api handlers
# and time start needs to be set
# the function has been modified into two parts
def before_gpu_call():
devices.torch_gc()
shared.state.sampling_step = 0
shared.state.job_count = -1
shared.state.job_no = 0
shared.state.job_timestamp = shared.state.get_job_timestamp()
shared.state.current_latent = None
shared.state.current_image = None
shared.state.current_image_sampling_step = 0
shared.state.skipped = False
shared.state.interrupted = False
shared.state.textinfo = None
shared.state.time_start = time.time()
def after_gpu_call():
shared.state.job = ""
shared.state.job_count = 0
devices.torch_gc()
class Api: class Api:
def __init__(self, app, queue_lock): def __init__(self, app, queue_lock):
...@@ -33,6 +67,7 @@ class Api: ...@@ -33,6 +67,7 @@ class Api:
self.queue_lock = queue_lock self.queue_lock = queue_lock
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"]) self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"])
def __base64_to_image(self, base64_string): def __base64_to_image(self, base64_string):
# if has a comma, deal with prefix # if has a comma, deal with prefix
...@@ -44,12 +79,12 @@ class Api: ...@@ -44,12 +79,12 @@ class Api:
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index) sampler_index = sampler_to_index(txt2imgreq.sampler_index)
if sampler_index is None: if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found") raise HTTPException(status_code=404, detail="Sampler not found")
populate = txt2imgreq.copy(update={ # Override __init__ params populate = txt2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model, "sd_model": shared.sd_model,
"sampler_index": sampler_index[0], "sampler_index": sampler_index[0],
"do_not_save_samples": True, "do_not_save_samples": True,
"do_not_save_grid": True "do_not_save_grid": True
...@@ -57,9 +92,11 @@ class Api: ...@@ -57,9 +92,11 @@ class Api:
) )
p = StableDiffusionProcessingTxt2Img(**vars(populate)) p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param # Override object param
before_gpu_call()
with self.queue_lock: with self.queue_lock:
processed = process_images(p) processed = process_images(p)
after_gpu_call()
b64images = [] b64images = []
for i in processed.images: for i in processed.images:
buffer = io.BytesIO() buffer = io.BytesIO()
...@@ -67,30 +104,30 @@ class Api: ...@@ -67,30 +104,30 @@ class Api:
b64images.append(base64.b64encode(buffer.getvalue())) b64images.append(base64.b64encode(buffer.getvalue()))
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js()) return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index) sampler_index = sampler_to_index(img2imgreq.sampler_index)
if sampler_index is None: if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found") raise HTTPException(status_code=404, detail="Sampler not found")
init_images = img2imgreq.init_images init_images = img2imgreq.init_images
if init_images is None: if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found") raise HTTPException(status_code=404, detail="Init image not found")
mask = img2imgreq.mask mask = img2imgreq.mask
if mask: if mask:
mask = self.__base64_to_image(mask) mask = self.__base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model, "sd_model": shared.sd_model,
"sampler_index": sampler_index[0], "sampler_index": sampler_index[0],
"do_not_save_samples": True, "do_not_save_samples": True,
"do_not_save_grid": True, "do_not_save_grid": True,
"mask": mask "mask": mask
} }
) )
...@@ -103,9 +140,11 @@ class Api: ...@@ -103,9 +140,11 @@ class Api:
p.init_images = imgs p.init_images = imgs
# Override object param # Override object param
before_gpu_call()
with self.queue_lock: with self.queue_lock:
processed = process_images(p) processed = process_images(p)
after_gpu_call()
b64images = [] b64images = []
for i in processed.images: for i in processed.images:
buffer = io.BytesIO() buffer = io.BytesIO()
...@@ -118,6 +157,28 @@ class Api: ...@@ -118,6 +157,28 @@ class Api:
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js()) return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())
def progressapi(self):
# copy from check_progress_call of ui.py
if shared.state.job_count == 0:
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.js())
# avoid dividing zero
progress = 0.01
if shared.state.job_count > 0:
progress += shared.state.job_no / shared.state.job_count
if shared.state.sampling_steps > 0:
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
time_since_start = time.time() - shared.state.time_start
eta = (time_since_start/progress)
eta_relative = eta-time_since_start
progress = min(progress, 1)
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.js())
def extrasapi(self): def extrasapi(self):
raise NotImplementedError raise NotImplementedError
......
...@@ -146,6 +146,19 @@ class State: ...@@ -146,6 +146,19 @@ class State:
def get_job_timestamp(self): def get_job_timestamp(self):
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp? return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
def js(self):
obj = {
"skipped": self.skipped,
"interrupted": self.skipped,
"job": self.job,
"job_count": self.job_count,
"job_no": self.job_no,
"sampling_step": self.sampling_step,
"sampling_steps": self.sampling_steps,
}
return json.dumps(obj)
state = State() state = State()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment