Commit 7ff1ef77 authored by AUTOMATIC's avatar AUTOMATIC

add a message about new torch/xformers version and a way to upgrade by...

add a message about new torch/xformers version and a way to upgrade by specifying a commandline flag
parent 56f63cd4
...@@ -208,6 +208,7 @@ def prepare_environment(): ...@@ -208,6 +208,7 @@ def prepare_environment():
sys.argv, _ = extract_arg(sys.argv, '-f') sys.argv, _ = extract_arg(sys.argv, '-f')
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test') sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers') sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
sys.argv, update_check = extract_arg(sys.argv, '--update-check') sys.argv, update_check = extract_arg(sys.argv, '--update-check')
sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests') sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
sys.argv, skip_install = extract_arg(sys.argv, '--skip-install') sys.argv, skip_install = extract_arg(sys.argv, '--skip-install')
...@@ -219,7 +220,7 @@ def prepare_environment(): ...@@ -219,7 +220,7 @@ def prepare_environment():
print(f"Python {sys.version}") print(f"Python {sys.version}")
print(f"Commit hash: {commit}") print(f"Commit hash: {commit}")
if not is_installed("torch") or not is_installed("torchvision"): if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch") run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")
if not skip_torch_cuda_test: if not skip_torch_cuda_test:
......
...@@ -8,6 +8,7 @@ import re ...@@ -8,6 +8,7 @@ import re
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
from packaging import version
from modules import import_hook, errors, extra_networks from modules import import_hook, errors, extra_networks
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
...@@ -49,7 +50,32 @@ else: ...@@ -49,7 +50,32 @@ else:
server_name = "0.0.0.0" if cmd_opts.listen else None server_name = "0.0.0.0" if cmd_opts.listen else None
def check_versions():
expected_torch_version = "1.13.1"
if version.parse(torch.__version__) < version.parse(expected_torch_version):
errors.print_error_explanation(f"""
You are running torch {torch.__version__}.
The program is tested to work with torch {expected_torch_version}.
To reinstall the desired version, run with commandline flag --reinstall-torch.
Beware that this will cause a lot of large files to be downloaded.
""".strip())
expected_xformers_version = "0.0.16rc425"
if shared.xformers_available:
import xformers
if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
errors.print_error_explanation(f"""
You are running xformers {xformers.__version__}.
The program is tested to work with xformers {expected_xformers_version}.
To reinstall the desired version, run with commandline flag --reinstall-xformers.
""".strip())
def initialize(): def initialize():
check_versions()
extensions.list_extensions() extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir) localization.list_localizations(cmd_opts.localizations_dir)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment