diff --git a/setup.py b/setup.py index e2a4b77f..c30a1078 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ # SOFTWARE. # +import atexit import os import subprocess import sys @@ -34,6 +35,37 @@ ENABLE_SPARSE = os.getenv("ENABLE_SPARSE") +_warning_printed = False + + +def print_platform_warning(): + global _warning_printed + if not PLATFORM and not _warning_printed: + _warning_printed = True + RED = "\033[91m" + YELLOW = "\033[93m" + BOLD = "\033[1m" + RESET = "\033[0m" + + warning_msg = f""" +{RED}{'=' * 80} +{BOLD}⚠️ WARNING: PLATFORM environment variable is not set! ⚠️{RESET} +{RED}{'=' * 80}{RESET} +{YELLOW}Please set PLATFORM to one of: cuda, ascend, musa, maca{RESET} +Example: + {BOLD}export PLATFORM=cuda{RESET} # For CUDA platform +{YELLOW}In CI scenarios only, you don't need to specify PLATFORM. If it's not a CI scenario, please uninstall and then reinstall with PLATFORM specified.{RESET} +{RED}{'=' * 80}{RESET} +""" + # Use write and flush to ensure output even without -v flag + sys.stderr.write(warning_msg) + sys.stderr.flush() + + +if not PLATFORM: + atexit.register(print_platform_warning) + + def enable_sparse() -> bool: return ENABLE_SPARSE is not None and ENABLE_SPARSE.lower() == "true"