|
6 | 6 | logger = logging.get_logger(__name__) |
7 | 7 |
|
8 | 8 |
|
9 | | -# 无损 |
10 | | -FLASH_ATTN_4_AVAILABLE = importlib.util.find_spec("flash_attn.cute.interface") is not None |
11 | | -if FLASH_ATTN_4_AVAILABLE: |
12 | | - logger.info("Flash attention 4 is available") |
13 | | -else: |
14 | | - logger.info("Flash attention 4 is not available") |
def check_module_available(module_path: str, module_name: "str | None" = None) -> bool:
    """Return True if *module_path* is importable, optionally logging the result.

    Args:
        module_path: Dotted import path probed with ``importlib.util.find_spec``.
        module_name: Optional human-readable name; when given, availability is
            logged at INFO level. (String annotation keeps pre-3.10 compatibility.)

    Returns:
        True if a module spec was found, False otherwise (including when probing
        the spec itself raises).
    """
    try:
        available = importlib.util.find_spec(module_path) is not None
    except (ModuleNotFoundError, AttributeError, ValueError):
        # find_spec raises ModuleNotFoundError when a parent package of a
        # dotted path is missing, and can raise AttributeError/ValueError for
        # packages with a broken or unset __spec__; treat all as "unavailable".
        available = False

    if module_name:
        # Lazy %-args defer string formatting until the record is actually emitted.
        if available:
            logger.info("%s is available", module_name)
        else:
            logger.info("%s is not available", module_name)

    return available
27 | 22 |
|
28 | | -XFORMERS_AVAILABLE = importlib.util.find_spec("xformers") is not None |
29 | | -if XFORMERS_AVAILABLE: |
30 | | - logger.info("XFormers is available") |
31 | | -else: |
32 | | - logger.info("XFormers is not available") |

# Lossless attention backends (original comment: 无损, "lossless") —
# availability flags probed once at import time; each probe logs its result.
FLASH_ATTN_4_AVAILABLE = check_module_available("flash_attn.cute.interface", "Flash attention 4")
FLASH_ATTN_3_AVAILABLE = check_module_available("flash_attn_interface", "Flash attention 3")
FLASH_ATTN_2_AVAILABLE = check_module_available("flash_attn", "Flash attention 2")
XFORMERS_AVAILABLE = check_module_available("xformers", "XFormers")
AITER_AVAILABLE = check_module_available("aiter", "Aiter")
33 | 30 |
|
# Torch SDPA ships inside PyTorch itself, so probe the attribute directly
# rather than looking up a module spec like the other backends.
SDPA_AVAILABLE = hasattr(torch.nn.functional, "scaled_dot_product_attention")
logger.info("Torch SDPA is available" if SDPA_AVAILABLE else "Torch SDPA is not available")
39 | 36 |
|
40 | | -AITER_AVAILABLE = importlib.util.find_spec("aiter") is not None |
41 | | -if AITER_AVAILABLE: |
42 | | - logger.info("Aiter is available") |
43 | | -else: |
44 | | - logger.info("Aiter is not available") |
45 | 37 |
|
# Lossy attention backends (original comment: 有损, "lossy") — each probe
# logs its result at import time.
SAGE_ATTN_AVAILABLE = check_module_available("sageattention", "Sage attention")
SPARGE_ATTN_AVAILABLE = check_module_available("spas_sage_attn", "Sparge attention")
VIDEO_SPARSE_ATTN_AVAILABLE = check_module_available("vsa", "Video sparse attention")
58 | 42 |
|
59 | | -VIDEO_SPARSE_ATTN_AVAILABLE = importlib.util.find_spec("vsa") is not None |
60 | | -if VIDEO_SPARSE_ATTN_AVAILABLE: |
61 | | - logger.info("Video sparse attention is available") |
62 | | -else: |
63 | | - logger.info("Video sparse attention is not available") |
64 | | - |
65 | | -NUNCHAKU_AVAILABLE = importlib.util.find_spec("nunchaku") is not None |
# Nunchaku availability; NUNCHAKU_IMPORT_ERROR stays None when available and
# is presumably populated with diagnostics below when it is not — the
# `if not NUNCHAKU_AVAILABLE:` branch continues past this view (confirm there).
NUNCHAKU_AVAILABLE = check_module_available("nunchaku", "Nunchaku")
NUNCHAKU_IMPORT_ERROR = None
67 | | -if NUNCHAKU_AVAILABLE: |
68 | | - logger.info("Nunchaku is available") |
69 | | -else: |
70 | | - logger.info("Nunchaku is not available") |
| 45 | +if not NUNCHAKU_AVAILABLE: |
71 | 46 | import sys |
72 | 47 | torch_version = getattr(torch, "__version__", "unknown") |
73 | 48 | python_version = f"{sys.version_info.major}.{sys.version_info.minor}" |
|
0 commit comments