Unverified Commit a407c9f0 authored by DaniAndTheWeb's avatar DaniAndTheWeb Committed by GitHub

Automatic torch install for amd on linux

This commit allows the launch script to automatically download rocm's torch version for AMD GPUs using an external GPU detection script. It also prints the operative system and GPU in use.
parent eaebcf63
...@@ -7,6 +7,7 @@ import shlex ...@@ -7,6 +7,7 @@ import shlex
import platform import platform
import argparse import argparse
import json import json
import detection
dir_repos = "repositories" dir_repos = "repositories"
dir_extensions = "extensions" dir_extensions = "extensions"
...@@ -15,6 +16,12 @@ git = os.environ.get('GIT', "git") ...@@ -15,6 +16,12 @@ git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "") index_url = os.environ.get('INDEX_URL', "")
stored_commit_hash = None stored_commit_hash = None
# Get the GPU vendor and the operating system
gpu = detection.check_gpu()
if os.name == "posix":
os_name = platform.uname().system
else:
os_name = os.name
def commit_hash(): def commit_hash():
global stored_commit_hash global stored_commit_hash
...@@ -173,7 +180,11 @@ def run_extensions_installers(settings_file): ...@@ -173,7 +180,11 @@ def run_extensions_installers(settings_file):
def prepare_environment(): def prepare_environment():
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") if gpu == "AMD" and os_name !="nt":
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2")
else:
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "") commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
...@@ -295,6 +306,8 @@ def tests(test_dir): ...@@ -295,6 +306,8 @@ def tests(test_dir):
def start(): def start():
print(f"Operating System: {os_name}")
print(f"GPU: {gpu}")
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}") print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
import webui import webui
if '--nowebui' in sys.argv: if '--nowebui' in sys.argv:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment