Commit 02d7abf5 authored by AUTOMATIC's avatar AUTOMATIC

helpful error message when trying to load 2.0 without config

failing to load model weights from settings won't break generation for currently loaded model anymore
parent 7e549468
...@@ -2,9 +2,30 @@ import sys ...@@ -2,9 +2,30 @@ import sys
import traceback import traceback
def print_error_explanation(message):
lines = message.strip().split("\n")
max_len = max([len(x) for x in lines])
print('=' * max_len, file=sys.stderr)
for line in lines:
print(line, file=sys.stderr)
print('=' * max_len, file=sys.stderr)
def display(e: Exception, task):
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
message = str(e)
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
print_error_explanation("""
The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file.
See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
""")
def run(code, task): def run(code, task):
try: try:
code() code()
except Exception as e: except Exception as e:
print(f"{task}: {type(e).__name__}", file=sys.stderr) display(task, e)
print(traceback.format_exc(), file=sys.stderr)
...@@ -278,6 +278,7 @@ def enable_midas_autodownload(): ...@@ -278,6 +278,7 @@ def enable_midas_autodownload():
midas.api.load_model = load_model_wrapper midas.api.load_model = load_model_wrapper
def load_model(checkpoint_info=None): def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint() checkpoint_info = checkpoint_info or select_checkpoint()
...@@ -312,6 +313,7 @@ def load_model(checkpoint_info=None): ...@@ -312,6 +313,7 @@ def load_model(checkpoint_info=None):
sd_config.model.params.unet_config.params.use_fp16 = False sd_config.model.params.unet_config.params.use_fp16 = False
sd_model = instantiate_from_config(sd_config.model) sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info) load_model_weights(sd_model, checkpoint_info)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
...@@ -340,6 +342,8 @@ def reload_model_weights(sd_model=None, info=None): ...@@ -340,6 +342,8 @@ def reload_model_weights(sd_model=None, info=None):
if not sd_model: if not sd_model:
sd_model = shared.sd_model sd_model = shared.sd_model
current_checkpoint_info = sd_model.sd_checkpoint_info
if sd_model.sd_model_checkpoint == checkpoint_info.filename: if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return return
...@@ -356,8 +360,13 @@ def reload_model_weights(sd_model=None, info=None): ...@@ -356,8 +360,13 @@ def reload_model_weights(sd_model=None, info=None):
sd_hijack.model_hijack.undo_hijack(sd_model) sd_hijack.model_hijack.undo_hijack(sd_model)
try:
load_model_weights(sd_model, checkpoint_info) load_model_weights(sd_model, checkpoint_info)
except Exception as e:
print("Failed to load checkpoint, restoring previous")
load_model_weights(sd_model, current_checkpoint_info)
raise
finally:
sd_hijack.model_hijack.hijack(sd_model) sd_hijack.model_hijack.hijack(sd_model)
script_callbacks.model_loaded_callback(sd_model) script_callbacks.model_loaded_callback(sd_model)
...@@ -365,4 +374,5 @@ def reload_model_weights(sd_model=None, info=None): ...@@ -365,4 +374,5 @@ def reload_model_weights(sd_model=None, info=None):
sd_model.to(devices.device) sd_model.to(devices.device)
print("Weights loaded.") print("Weights loaded.")
return sd_model return sd_model
...@@ -14,7 +14,7 @@ import modules.interrogate ...@@ -14,7 +14,7 @@ import modules.interrogate
import modules.memmon import modules.memmon
import modules.styles import modules.styles
import modules.devices as devices import modules.devices as devices
from modules import localization, sd_vae, extensions, script_loading from modules import localization, sd_vae, extensions, script_loading, errors
from modules.paths import models_path, script_path, sd_path from modules.paths import models_path, script_path, sd_path
...@@ -494,7 +494,12 @@ class Options: ...@@ -494,7 +494,12 @@ class Options:
return False return False
if self.data_labels[key].onchange is not None: if self.data_labels[key].onchange is not None:
try:
self.data_labels[key].onchange() self.data_labels[key].onchange()
except Exception as e:
errors.display(e, f"changing setting {key} to {value}")
setattr(self, key, oldval)
return False
return True return True
......
...@@ -9,7 +9,7 @@ from fastapi import FastAPI ...@@ -9,7 +9,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
from modules import import_hook from modules import import_hook, errors
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
from modules.paths import script_path from modules.paths import script_path
...@@ -61,7 +61,15 @@ def initialize(): ...@@ -61,7 +61,15 @@ def initialize():
modelloader.load_upscalers() modelloader.load_upscalers()
modules.sd_vae.refresh_vae_list() modules.sd_vae.refresh_vae_list()
try:
modules.sd_models.load_model() modules.sd_models.load_model()
except Exception as e:
errors.display(e, "loading stable diffusion model")
print("", file=sys.stderr)
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
exit(1)
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment