Commit db9ab1a4 authored by Stephen's avatar Stephen Committed by AUTOMATIC1111

[Bugfix][API] - Fix API response for colab users

parent cbb857b6
...@@ -7,6 +7,7 @@ import uvicorn ...@@ -7,6 +7,7 @@ import uvicorn
from fastapi import Body, APIRouter, HTTPException from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, Json from pydantic import BaseModel, Field, Json
from typing import List
import json import json
import io import io
import base64 import base64
...@@ -15,12 +16,12 @@ from PIL import Image ...@@ -15,12 +16,12 @@ from PIL import Image
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
class TextToImageResponse(BaseModel): class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json parameters: Json
info: Json info: Json
class ImageToImageResponse(BaseModel): class ImageToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json parameters: Json
info: Json info: Json
...@@ -41,6 +42,9 @@ class Api: ...@@ -41,6 +42,9 @@ class Api:
# convert base64 to PIL image # convert base64 to PIL image
return Image.open(io.BytesIO(imgdata)) return Image.open(io.BytesIO(imgdata))
def __processed_info_to_json(self, processed):
return json.dumps(processed.info)
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index) sampler_index = sampler_to_index(txt2imgreq.sampler_index)
...@@ -65,7 +69,7 @@ class Api: ...@@ -65,7 +69,7 @@ class Api:
i.save(buffer, format="png") i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue())) b64images.append(base64.b64encode(buffer.getvalue()))
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info)) return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())
...@@ -111,7 +115,12 @@ class Api: ...@@ -111,7 +115,12 @@ class Api:
i.save(buffer, format="png") i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue())) b64images.append(base64.b64encode(buffer.getvalue()))
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info)) if (not img2imgreq.include_init_images):
# remove img2imgreq.init_images and img2imgreq.mask
img2imgreq.init_images = None
img2imgreq.mask = None
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())
def extrasapi(self): def extrasapi(self):
raise NotImplementedError raise NotImplementedError
......
...@@ -31,6 +31,7 @@ class ModelDef(BaseModel): ...@@ -31,6 +31,7 @@ class ModelDef(BaseModel):
field_alias: str field_alias: str
field_type: Any field_type: Any
field_value: Any field_value: Any
field_exclude: bool = False
class PydanticModelGenerator: class PydanticModelGenerator:
...@@ -68,7 +69,7 @@ class PydanticModelGenerator: ...@@ -68,7 +69,7 @@ class PydanticModelGenerator:
field=underscore(k), field=underscore(k),
field_alias=k, field_alias=k,
field_type=field_type_generator(k, v), field_type=field_type_generator(k, v),
field_value=v.default field_value=v.default,
) )
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
] ]
...@@ -78,7 +79,8 @@ class PydanticModelGenerator: ...@@ -78,7 +79,8 @@ class PydanticModelGenerator:
field=underscore(fields["key"]), field=underscore(fields["key"]),
field_alias=fields["key"], field_alias=fields["key"],
field_type=fields["type"], field_type=fields["type"],
field_value=fields["default"])) field_value=fields["default"],
field_exclude=fields["exclude"] if "exclude" in fields else False))
def generate_model(self): def generate_model(self):
""" """
...@@ -86,7 +88,7 @@ class PydanticModelGenerator: ...@@ -86,7 +88,7 @@ class PydanticModelGenerator:
from the json and overrides provided at initialization from the json and overrides provided at initialization
""" """
fields = { fields = {
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
} }
DynamicModel = create_model(self._model_name, **fields) DynamicModel = create_model(self._model_name, **fields)
DynamicModel.__config__.allow_population_by_field_name = True DynamicModel.__config__.allow_population_by_field_name = True
...@@ -102,5 +104,5 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( ...@@ -102,5 +104,5 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingImg2Img", "StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img, StableDiffusionProcessingImg2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}] [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
).generate_model() ).generate_model()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment