 # See LICENSE for license information.
 
 import os
-import pytest
 import subprocess
 from pathlib import Path
-import transformer_engine.pytorch as te
 
+import pytest
 import torch
 
+import transformer_engine.pytorch as te
+from transformer_engine.pytorch import fp8
 
-fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
-mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
 NUM_PROCS: int = torch.cuda.device_count()
 
+# Each entry: (recipe_class_name, check_fn)
+_FP8_RECIPE_CONFIGS = [
+    ("DelayedScaling", fp8.check_fp8_support),
+    ("Float8CurrentScaling", fp8.check_fp8_support),
+    ("Float8BlockScaling", fp8.check_fp8_block_scaling_support),
+    ("MXFP8BlockScaling", fp8.check_mxfp8_support),
+    ("NVFP4BlockScaling", fp8.check_nvfp4_support),
+]
+
+
+def _parametrize_fp8_recipes():
+    """Generate pytest.param objects with xfail marks for unsupported FP8 recipes."""
+    params = []
+    for name, check_fn in _FP8_RECIPE_CONFIGS:
+        supported, reason = check_fn()
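+        # The xfail mark only takes effect when the recipe is unsupported on this platform.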
+        params.append(
+            pytest.param(
+                name,
+                id=name,
+                marks=pytest.mark.xfail(condition=not supported, reason=reason),
+            )
+        )
+    return params
+
+
+@pytest.fixture(params=_parametrize_fp8_recipes())
+def fp_recipe(request):
44+ """Parametrized fixture providing FP8 recipe Hydra overrides for each supported TE recipe."""
+    return request.param
+
 
 def _run_test(fp_init, sharding_dims, recipe, layer_type):
     test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py"
@@ -32,28 +61,17 @@ def _run_test(fp_init, sharding_dims, recipe, layer_type):
     test_cmd += ["--recipe", recipe]
     test_cmd += ["--layer-type", layer_type]
 
-    result = subprocess.run(test_cmd, env=os.environ, check=True)
+    subprocess.run(test_cmd, env=os.environ, check=True)
 
 
 @pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs")
 @pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
 @pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
 @pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2]))
 @pytest.mark.parametrize("fp8_init", (False, True))
-@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
 @pytest.mark.parametrize("layer_type", ("LayerNormLinear", "TransformerLayer"))
-def test_distributed(fp8_init, sharding_dims, recipe, layer_type):
-
-    # Skip invalid configurations
-    if torch.cuda.device_count() < 4:
-        pytest.skip("FSDP2 test requires at least 4 GPUs")
-
-    if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    elif not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-
-    _run_test(fp8_init, sharding_dims, recipe, layer_type)
+def test_distributed(fp8_init, sharding_dims, fp_recipe, layer_type):
+    _run_test(fp8_init, sharding_dims, fp_recipe, layer_type)
 
 
 ## ── FusedAdam + FSDP2 tests ─────────────────────────────────────────
@@ -77,80 +95,48 @@ def _run_fused_adam_test(test_name, recipe="delayed_scaling"):
 
 
 @pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
-def test_fsdp2_fused_adam_fp8_master_weights(recipe):
+def test_fsdp2_fused_adam_fp8_master_weights(fp_recipe):
     """FusedAdam(master_weights=True) + FSDP2 + quantized_model_init."""
-    if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    _run_fused_adam_test("fused_adam_fp8_master_weights", recipe)
+    _run_fused_adam_test("fused_adam_fp8_master_weights", fp_recipe)
 
 
 @pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
-@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
-def test_fsdp2_fused_adam_bf16(recipe):
+def test_fsdp2_fused_adam_bf16(fp_recipe):
     """FusedAdam(master_weights=True) + FSDP2 + bf16 params (no FP8)."""
-    if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    elif not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    _run_fused_adam_test("fused_adam_bf16", recipe)
+    _run_fused_adam_test("fused_adam_bf16", fp_recipe)
 
 
 @pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
-def test_fsdp2_fused_adam_fp8_no_master(recipe):
+def test_fsdp2_fused_adam_fp8_no_master(fp_recipe):
     """FusedAdam(master_weights=False) + FSDP2 + FP8 params."""
-    if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    _run_fused_adam_test("fused_adam_fp8_no_master", recipe)
+    _run_fused_adam_test("fused_adam_fp8_no_master", fp_recipe)
 
 
 @pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
-@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
-def test_fsdp2_fused_adam_bf16_store_param_remainders(recipe):
+def test_fsdp2_fused_adam_bf16_store_param_remainders(fp_recipe):
     """FusedAdam(master_weights=True, store_param_remainders=True) + FSDP2 + bf16."""
-    if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    elif not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    _run_fused_adam_test("fused_adam_bf16_store_param_remainders", recipe)
+    _run_fused_adam_test("fused_adam_bf16_store_param_remainders", fp_recipe)
 
 
 @pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
-def test_fsdp2_dcp_save_load(recipe):
+def test_fsdp2_dcp_save_load(fp_recipe):
     """Distributed checkpoint save/load with FSDP2 + FP8 + FusedAdam."""
-    if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    _run_fused_adam_test("dcp_save_load", recipe)
+    _run_fused_adam_test("dcp_save_load", fp_recipe)
 
 
 @pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
-def test_fsdp2_dcp_output_parity(recipe):
+def test_fsdp2_dcp_output_parity(fp_recipe):
     """DCP save/load round-trip into a fresh model produces identical outputs."""
-    if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    _run_fused_adam_test("dcp_output_parity", recipe)
+    _run_fused_adam_test("dcp_output_parity", fp_recipe)
 
 
 @pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
-def test_fsdp2_safetensors_fp32_export(recipe):
+def test_fsdp2_safetensors_fp32_export(fp_recipe):
     """Export FP32 model from optimizer master weights to safetensors."""
-    if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    _run_fused_adam_test("safetensors_fp32_export", recipe)
+    _run_fused_adam_test("safetensors_fp32_export", fp_recipe)
 
 
 @pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
 @pytest.mark.xfail(
     reason=(
         "fuse_wgrad_accumulation is incompatible with vanilla FSDP2: "
@@ -161,11 +147,9 @@ def test_fsdp2_safetensors_fp32_export(recipe):
     raises=subprocess.CalledProcessError,
     strict=True,
 )
-def test_fsdp2_fuse_wgrad_accumulation(recipe):
+def test_fsdp2_fuse_wgrad_accumulation(fp_recipe):
     """fuse_wgrad_accumulation=True + FSDP2 -- expected to fail."""
-    if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    _run_fused_adam_test("fuse_wgrad_accumulation", recipe)
+    _run_fused_adam_test("fuse_wgrad_accumulation", fp_recipe)
 
 
 def test_dummy() -> None:
@@ -175,3 +159,10 @@ def test_dummy() -> None:
 
     """
     pass
+
+
+"""
+TODO:
+- async DCP tests
+
+"""