Commit b8435e63 authored by evshiron's avatar evshiron

add --cors-allow-origins cmd opt

parent 89722fb5
...@@ -86,6 +86,7 @@ parser.add_argument("--nowebui", action='store_true', help="use api=True to laun ...@@ -86,6 +86,7 @@ parser.add_argument("--nowebui", action='store_true', help="use api=True to laun
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False) parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origins", default=None)
cmd_opts = parser.parse_args() cmd_opts = parser.parse_args()
restricted_opts = { restricted_opts = {
...@@ -147,9 +148,9 @@ class State: ...@@ -147,9 +148,9 @@ class State:
self.interrupted = True self.interrupted = True
def nextjob(self): def nextjob(self):
if opts.show_progress_every_n_steps == -1: if opts.show_progress_every_n_steps == -1:
self.do_set_current_image() self.do_set_current_image()
self.job_no += 1 self.job_no += 1
self.sampling_step = 0 self.sampling_step = 0
self.current_image_sampling_step = 0 self.current_image_sampling_step = 0
...@@ -198,7 +199,7 @@ class State: ...@@ -198,7 +199,7 @@ class State:
return return
if self.current_latent is None: if self.current_latent is None:
return return
if opts.show_progress_grid: if opts.show_progress_grid:
self.current_image = sd_samplers.samples_to_image_grid(self.current_latent) self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
else: else:
......
...@@ -5,6 +5,7 @@ import importlib ...@@ -5,6 +5,7 @@ import importlib
import signal import signal
import threading import threading
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
from modules.paths import script_path from modules.paths import script_path
...@@ -93,6 +94,11 @@ def initialize(): ...@@ -93,6 +94,11 @@ def initialize():
signal.signal(signal.SIGINT, sigint_handler) signal.signal(signal.SIGINT, sigint_handler)
def setup_cors(app):
if cmd_opts.cors_allow_origins:
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'])
def create_api(app): def create_api(app):
from modules.api.api import Api from modules.api.api import Api
api = Api(app, queue_lock) api = Api(app, queue_lock)
...@@ -114,6 +120,7 @@ def api_only(): ...@@ -114,6 +120,7 @@ def api_only():
initialize() initialize()
app = FastAPI() app = FastAPI()
setup_cors(app)
app.add_middleware(GZipMiddleware, minimum_size=1000) app.add_middleware(GZipMiddleware, minimum_size=1000)
api = create_api(app) api = create_api(app)
...@@ -147,6 +154,8 @@ def webui(): ...@@ -147,6 +154,8 @@ def webui():
# runnnig its code. We disable this here. Suggested by RyotaK. # runnnig its code. We disable this here. Suggested by RyotaK.
app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware'] app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
setup_cors(app)
app.add_middleware(GZipMiddleware, minimum_size=1000) app.add_middleware(GZipMiddleware, minimum_size=1000)
if launch_api: if launch_api:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment