Unverified Commit 994136b9 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge pull request #4294 from evshiron/feat/allow-origins

add --cors-allow-origins cmd opt
parents c9b2eef6 37ba0070
...@@ -86,6 +86,7 @@ parser.add_argument("--nowebui", action='store_true', help="use api=True to laun ...@@ -86,6 +86,7 @@ parser.add_argument("--nowebui", action='store_true', help="use api=True to laun
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False) parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origins", default=None)
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None) parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
......
...@@ -5,6 +5,7 @@ import importlib ...@@ -5,6 +5,7 @@ import importlib
import signal import signal
import threading import threading
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
from modules.paths import script_path from modules.paths import script_path
...@@ -107,6 +108,11 @@ def initialize(): ...@@ -107,6 +108,11 @@ def initialize():
signal.signal(signal.SIGINT, sigint_handler) signal.signal(signal.SIGINT, sigint_handler)
def setup_cors(app):
if cmd_opts.cors_allow_origins:
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'])
def create_api(app): def create_api(app):
from modules.api.api import Api from modules.api.api import Api
api = Api(app, queue_lock) api = Api(app, queue_lock)
...@@ -128,6 +134,7 @@ def api_only(): ...@@ -128,6 +134,7 @@ def api_only():
initialize() initialize()
app = FastAPI() app = FastAPI()
setup_cors(app)
app.add_middleware(GZipMiddleware, minimum_size=1000) app.add_middleware(GZipMiddleware, minimum_size=1000)
api = create_api(app) api = create_api(app)
...@@ -163,6 +170,8 @@ def webui(): ...@@ -163,6 +170,8 @@ def webui():
# runnnig its code. We disable this here. Suggested by RyotaK. # runnnig its code. We disable this here. Suggested by RyotaK.
app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware'] app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
setup_cors(app)
app.add_middleware(GZipMiddleware, minimum_size=1000) app.add_middleware(GZipMiddleware, minimum_size=1000)
if launch_api: if launch_api:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment