Commit c0ee1488 authored by AUTOMATIC's avatar AUTOMATIC

add support for running with gradio 3.9 installed

parent fda1ed18
...@@ -7,7 +7,7 @@ from pathlib import Path ...@@ -7,7 +7,7 @@ from pathlib import Path
import gradio as gr import gradio as gr
from modules.shared import script_path from modules.shared import script_path
from modules import shared from modules import shared, ui_tempdir
import tempfile import tempfile
from PIL import Image from PIL import Image
...@@ -39,7 +39,7 @@ def quote(text): ...@@ -39,7 +39,7 @@ def quote(text):
def image_from_url_text(filedata): def image_from_url_text(filedata):
if type(filedata) == dict and filedata["is_file"]: if type(filedata) == dict and filedata["is_file"]:
filename = filedata["name"] filename = filedata["name"]
is_in_right_dir = any([filename in fileset for fileset in shared.demo.temp_file_sets]) is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
assert is_in_right_dir, 'trying to open image file outside of allowed directories' assert is_in_right_dir, 'trying to open image file outside of allowed directories'
return Image.open(filename) return Image.open(filename)
......
import os import os
import tempfile import tempfile
from collections import namedtuple from collections import namedtuple
from pathlib import Path
import gradio as gr import gradio as gr
...@@ -12,10 +13,28 @@ from modules import shared ...@@ -12,10 +13,28 @@ from modules import shared
Savedfile = namedtuple("Savedfile", ["name"]) Savedfile = namedtuple("Savedfile", ["name"])
def register_tmp_file(gradio, filename):
if hasattr(gradio, 'temp_file_sets'): # gradio 3.15
gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
if hasattr(gradio, 'temp_dirs'): # gradio 3.9
gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}
def check_tmp_file(gradio, filename):
if hasattr(gradio, 'temp_file_sets'):
return any([filename in fileset for fileset in gradio.temp_file_sets])
if hasattr(gradio, 'temp_dirs'):
return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
return False
def save_pil_to_file(pil_image, dir=None): def save_pil_to_file(pil_image, dir=None):
already_saved_as = getattr(pil_image, 'already_saved_as', None) already_saved_as = getattr(pil_image, 'already_saved_as', None)
if already_saved_as and os.path.isfile(already_saved_as): if already_saved_as and os.path.isfile(already_saved_as):
shared.demo.temp_file_sets[0] = shared.demo.temp_file_sets[0] | {os.path.abspath(already_saved_as)} register_tmp_file(shared.demo, already_saved_as)
file_obj = Savedfile(already_saved_as) file_obj = Savedfile(already_saved_as)
return file_obj return file_obj
...@@ -45,7 +64,7 @@ def on_tmpdir_changed(): ...@@ -45,7 +64,7 @@ def on_tmpdir_changed():
os.makedirs(shared.opts.temp_dir, exist_ok=True) os.makedirs(shared.opts.temp_dir, exist_ok=True)
shared.demo.temp_file_sets[0] = shared.demo.temp_file_sets[0] | {os.path.abspath(shared.opts.temp_dir)} register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
def cleanup_tmpdr(): def cleanup_tmpdr():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment