Commit a1a37633 authored by AUTOMATIC's avatar AUTOMATIC

make existing script loading and new preload code use same code for loading modules

limit extension preload scripts to just one file named preload.py
parent e5690d0b
import os import os
import sys import sys
import traceback import traceback
from importlib.machinery import SourceFileLoader
import git import git
...@@ -85,23 +84,3 @@ def list_extensions(): ...@@ -85,23 +84,3 @@ def list_extensions():
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions) extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
extensions.append(extension) extensions.append(extension)
def preload_extensions(parser):
if not os.path.isdir(extensions_dir):
return
for dirname in sorted(os.listdir(extensions_dir)):
path = os.path.join(extensions_dir, dirname)
if not os.path.isdir(path):
continue
for file in os.listdir(path):
if "preload.py" in file:
full_file = os.path.join(path, file)
print(f"Got preload file: {full_file}")
try:
ext = SourceFileLoader("preload", full_file).load_module()
parser = ext.preload(parser)
except Exception as e:
print(f"Exception preloading script: {e}")
return parser
\ No newline at end of file
import os
import sys
import traceback
from types import ModuleType
def load_module(path):
with open(path, "r", encoding="utf8") as file:
text = file.read()
compiled = compile(text, path, 'exec')
module = ModuleType(os.path.basename(path))
exec(compiled, module.__dict__)
return module
def preload_extensions(extensions_dir, parser):
if not os.path.isdir(extensions_dir):
return
for dirname in sorted(os.listdir(extensions_dir)):
preload_script = os.path.join(extensions_dir, dirname, "preload.py")
if not os.path.isfile(preload_script):
continue
try:
module = load_module(preload_script)
if hasattr(module, 'preload'):
module.preload(parser)
except Exception:
print(f"Error running preload() for {preload_script}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
...@@ -6,7 +6,7 @@ from collections import namedtuple ...@@ -6,7 +6,7 @@ from collections import namedtuple
import gradio as gr import gradio as gr
from modules.processing import StableDiffusionProcessing from modules.processing import StableDiffusionProcessing
from modules import shared, paths, script_callbacks, extensions from modules import shared, paths, script_callbacks, extensions, script_loading
AlwaysVisible = object() AlwaysVisible = object()
...@@ -161,13 +161,7 @@ def load_scripts(): ...@@ -161,13 +161,7 @@ def load_scripts():
sys.path = [scriptfile.basedir] + sys.path sys.path = [scriptfile.basedir] + sys.path
current_basedir = scriptfile.basedir current_basedir = scriptfile.basedir
with open(scriptfile.path, "r", encoding="utf8") as file: module = script_loading.load_module(scriptfile.path)
text = file.read()
from types import ModuleType
compiled = compile(text, scriptfile.path, 'exec')
module = ModuleType(scriptfile.filename)
exec(compiled, module.__dict__)
for key, script_class in module.__dict__.items(): for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script): if type(script_class) == type and issubclass(script_class, Script):
...@@ -328,19 +322,13 @@ class ScriptRunner: ...@@ -328,19 +322,13 @@ class ScriptRunner:
def reload_sources(self, cache): def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)): for si, script in list(enumerate(self.scripts)):
with open(script.filename, "r", encoding="utf8") as file:
args_from = script.args_from args_from = script.args_from
args_to = script.args_to args_to = script.args_to
filename = script.filename filename = script.filename
text = file.read()
from types import ModuleType
module = cache.get(filename, None) module = cache.get(filename, None)
if module is None: if module is None:
compiled = compile(text, filename, 'exec') module = script_loading.load_module(script.filename)
module = ModuleType(script.filename)
exec(compiled, module.__dict__)
cache[filename] = module cache[filename] = module
for key, script_class in module.__dict__.items(): for key, script_class in module.__dict__.items():
......
...@@ -3,7 +3,6 @@ import datetime ...@@ -3,7 +3,6 @@ import datetime
import json import json
import os import os
import sys import sys
from collections import OrderedDict
import time import time
import gradio as gr import gradio as gr
...@@ -15,7 +14,7 @@ import modules.memmon ...@@ -15,7 +14,7 @@ import modules.memmon
import modules.sd_models import modules.sd_models
import modules.styles import modules.styles
import modules.devices as devices import modules.devices as devices
from modules import sd_samplers, sd_models, localization, sd_vae, extensions from modules import sd_samplers, sd_models, localization, sd_vae, extensions, script_loading
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path from modules.paths import models_path, script_path, sd_path
...@@ -91,7 +90,7 @@ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requ ...@@ -91,7 +90,7 @@ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requ
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
extensions.preload_extensions(parser) script_loading.preload_extensions(extensions.extensions_dir, parser)
cmd_opts = parser.parse_args() cmd_opts = parser.parse_args()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment