Commit a4d691f

refactoring fsdp2 tests
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 96c123e · commit a4d691f

2 files changed: 69 additions & 86 deletions

tests/pytorch/distributed/run_fsdp2_model.py

Lines changed: 10 additions & 18 deletions
@@ -9,12 +9,8 @@
 import argparse

 import transformer_engine.pytorch as te
-from transformer_engine.common.recipe import (
-    Format,
-    DelayedScaling,
-    Float8CurrentScaling,
-    MXFP8BlockScaling,
-)
+import transformer_engine.common.recipe
+from transformer_engine.common.recipe import Format

 import torch
 import torch.distributed as dist
@@ -43,7 +39,10 @@ def _parse_args(argv=None, namespace=None):
     parser.add_argument("--seq-length", type=int, default=128, help="Sequence length of input")
     parser.add_argument("--params-dtype", type=str, default="float32", help="Parameter dtype.")
     parser.add_argument(
-        "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
+        "--fp8-init",
+        action="store_true",
+        default=False,
+        help="Initialize primary weights in FP8.",
     )
     parser.add_argument(
         "--recipe",
@@ -111,14 +110,7 @@ def get_te_layer_from_string(layer_name):


 def get_recipe_from_string(recipe, fp8_format=Format.HYBRID):
-    if recipe == "delayed_scaling":
-        return DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
-    elif recipe == "current_scaling":
-        return Float8CurrentScaling(fp8_format=fp8_format)
-    elif recipe == "mx_fp8_block_scaling":
-        return MXFP8BlockScaling(fp8_format=fp8_format)
-    else:
-        raise ValueError(f"Unknown quantizer type: {recipe}")
+    return getattr(transformer_engine.common.recipe, recipe)(fp8_format=fp8_format)


 def init_te_model(config):
@@ -292,13 +284,13 @@ def _train(args):
         build_model_context_args["enabled"] = True
         build_model_context_args["recipe"] = fp8_recipe

-    dist_print(f"Memory before model init: {torch.cuda.memory_allocated(device)/1e6} MB")
+    dist_print(f"Memory before model init: {torch.cuda.memory_allocated(device) / 1e6} MB")
     # Create the model on the meta/cuda device as per args
     with build_model_context(**build_model_context_args):
         model, inp_shape, out_shape = init_te_model(args)
     dist_print(
         f"Memory after model init on device {args.device}:"
-        f" {torch.cuda.memory_allocated(device)/1e6} MB"
+        f" {torch.cuda.memory_allocated(device) / 1e6} MB"
     )

     # Creating a DeviceMesh for fully_shard
@@ -319,7 +311,7 @@ def _train(args):
     dist_print(f" Sharded parameters materialized and initialized on cuda device.")

     dist_print(
-        f"FSDP2 model in cuda, memory allocated: {torch.cuda.memory_allocated(device)/1e6} MB"
+        f"FSDP2 model in cuda, memory allocated: {torch.cuda.memory_allocated(device) / 1e6} MB"
     )

     optimizer = optim.Adam(model.parameters(), lr=1e-3)
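
Note on the get_recipe_from_string change above: the explicit if/elif dispatch over snake_case names is replaced by a getattr lookup on transformer_engine.common.recipe, so the --recipe value must now be the recipe class name itself (e.g. DelayedScaling), DelayedScaling no longer receives the amax_history_len/amax_compute_algo arguments the old branch supplied, and an unknown name now surfaces as an AttributeError rather than the previous ValueError. A minimal, self-contained sketch of the lookup pattern — the "recipes" namespace below is a hypothetical stand-in for transformer_engine.common.recipe so the snippet runs without the library:

# Sketch of the getattr-based recipe dispatch; "recipes" is a stand-in module.
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass
class DelayedScaling:
    fp8_format: str = "HYBRID"


@dataclass
class Float8CurrentScaling:
    fp8_format: str = "HYBRID"


recipes = SimpleNamespace(
    DelayedScaling=DelayedScaling,
    Float8CurrentScaling=Float8CurrentScaling,
)


def get_recipe_from_string(recipe, fp8_format="HYBRID"):
    """Look up a recipe class by its name and instantiate it."""
    try:
        recipe_cls = getattr(recipes, recipe)
    except AttributeError as exc:
        raise ValueError(f"Unknown recipe: {recipe}") from exc
    return recipe_cls(fp8_format=fp8_format)


print(get_recipe_from_string("DelayedScaling"))
# -> DelayedScaling(fp8_format='HYBRID')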

tests/pytorch/distributed/test_torch_fsdp2.py

Lines changed: 59 additions & 68 deletions
@@ -3,18 +3,47 @@
 # See LICENSE for license information.

 import os
-import pytest
 import subprocess
 from pathlib import Path
-import transformer_engine.pytorch as te

+import pytest
 import torch

+import transformer_engine.pytorch as te
+from transformer_engine.pytorch import fp8

-fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
-mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
 NUM_PROCS: int = torch.cuda.device_count()

+# Each entry: (recipe_class_name, hydra_overrides, check_fn)
+_FP8_RECIPE_CONFIGS = [
+    ("DelayedScaling", fp8.check_fp8_support),
+    ("Float8CurrentScaling", fp8.check_fp8_support),
+    ("Float8BlockScaling", fp8.check_fp8_block_scaling_support),
+    ("MXFP8BlockScaling", fp8.check_mxfp8_support),
+    ["NVFP4BlockScaling", fp8.check_nvfp4_support],
+]
+
+
+def _parametrize_fp8_recipes():
+    """Generate pytest.param objects with xfail marks for unsupported FP8 recipes."""
+    params = []
+    for name, check_fn in _FP8_RECIPE_CONFIGS:
+        supported, reason = check_fn()
+        params.append(
+            pytest.param(
+                name,
+                id=name,
+                marks=pytest.mark.xfail(condition=not supported, reason=reason),
+            )
+        )
+    return params
+
+
+@pytest.fixture(params=_parametrize_fp8_recipes())
+def fp_recipe(request):
+    """Parametrized fixture providing FP8 recipe Hydra overrides for each supported TE recipe."""
+    return request.param
+

 def _run_test(fp_init, sharding_dims, recipe, layer_type):
     test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py"
@@ -32,28 +61,17 @@ def _run_test(fp_init, sharding_dims, recipe, layer_type):
     test_cmd += ["--recipe", recipe]
     test_cmd += ["--layer-type", layer_type]

-    result = subprocess.run(test_cmd, env=os.environ, check=True)
+    subprocess.run(test_cmd, env=os.environ, check=True)


 @pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs")
 @pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
 @pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
 @pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2]))
 @pytest.mark.parametrize("fp8_init", (False, True))
-@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
 @pytest.mark.parametrize("layer_type", ("LayerNormLinear", "TransformerLayer"))
-def test_distributed(fp8_init, sharding_dims, recipe, layer_type):
-
-    # Skip invalid configurations
-    if torch.cuda.device_count() < 4:
-        pytest.skip("FSDP2 test requires at least 4 GPUs")
-
-    if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    elif not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-
-    _run_test(fp8_init, sharding_dims, recipe, layer_type)
+def test_distributed(fp8_init, sharding_dims, fp_recipe, layer_type):
+    _run_test(fp8_init, sharding_dims, fp_recipe, layer_type)


 ## ── FusedAdam + FSDP2 tests ─────────────────────────────────────────
@@ -77,80 +95,48 @@ def _run_fused_adam_test(test_name, recipe="delayed_scaling"):


 @pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
-def test_fsdp2_fused_adam_fp8_master_weights(recipe):
+def test_fsdp2_fused_adam_fp8_master_weights(fp_recipe):
     """FusedAdam(master_weights=True) + FSDP2 + quantized_model_init."""
-    if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    _run_fused_adam_test("fused_adam_fp8_master_weights", recipe)
+    _run_fused_adam_test("fused_adam_fp8_master_weights", fp_recipe)


 @pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
-@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
-def test_fsdp2_fused_adam_bf16(recipe):
+def test_fsdp2_fused_adam_bf16(fp_recipe):
     """FusedAdam(master_weights=True) + FSDP2 + bf16 params (no FP8)."""
-    if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    elif not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    _run_fused_adam_test("fused_adam_bf16", recipe)
+    _run_fused_adam_test("fused_adam_bf16", fp_recipe)


 @pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
-def test_fsdp2_fused_adam_fp8_no_master(recipe):
+def test_fsdp2_fused_adam_fp8_no_master(fp_recipe):
     """FusedAdam(master_weights=False) + FSDP2 + FP8 params."""
-    if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    _run_fused_adam_test("fused_adam_fp8_no_master", recipe)
+    _run_fused_adam_test("fused_adam_fp8_no_master", fp_recipe)


 @pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
-@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
-def test_fsdp2_fused_adam_bf16_store_param_remainders(recipe):
+def test_fsdp2_fused_adam_bf16_store_param_remainders(fp_recipe):
     """FusedAdam(master_weights=True, store_param_remainders=True) + FSDP2 + bf16."""
-    if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    elif not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    _run_fused_adam_test("fused_adam_bf16_store_param_remainders", recipe)
+    _run_fused_adam_test("fused_adam_bf16_store_param_remainders", fp_recipe)


 @pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
-def test_fsdp2_dcp_save_load(recipe):
+def test_fsdp2_dcp_save_load(fp_recipe):
     """Distributed checkpoint save/load with FSDP2 + FP8 + FusedAdam."""
-    if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    _run_fused_adam_test("dcp_save_load", recipe)
+    _run_fused_adam_test("dcp_save_load", fp_recipe)


 @pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
-def test_fsdp2_dcp_output_parity(recipe):
+def test_fsdp2_dcp_output_parity(fp_recipe):
     """DCP save/load round-trip into a fresh model produces identical outputs."""
-    if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    _run_fused_adam_test("dcp_output_parity", recipe)
+    _run_fused_adam_test("dcp_output_parity", fp_recipe)


 @pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
-def test_fsdp2_safetensors_fp32_export(recipe):
+def test_fsdp2_safetensors_fp32_export(fp_recipe):
     """Export FP32 model from optimizer master weights to safetensors."""
-    if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    _run_fused_adam_test("safetensors_fp32_export", recipe)
+    _run_fused_adam_test("safetensors_fp32_export", fp_recipe)


 @pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
 @pytest.mark.xfail(
     reason=(
         "fuse_wgrad_accumulation is incompatible with vanilla FSDP2: "
@@ -161,11 +147,9 @@ def test_fsdp2_safetensors_fp32_export(recipe):
     raises=subprocess.CalledProcessError,
     strict=True,
 )
-def test_fsdp2_fuse_wgrad_accumulation(recipe):
+def test_fsdp2_fuse_wgrad_accumulation(fp_recipe):
     """fuse_wgrad_accumulation=True + FSDP2 -- expected to fail."""
-    if recipe == "mx_fp8_block_scaling" and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    _run_fused_adam_test("fuse_wgrad_accumulation", recipe)
+    _run_fused_adam_test("fuse_wgrad_accumulation", fp_recipe)


 def test_dummy() -> None:
@@ -175,3 +159,10 @@ def test_dummy() -> None:

     """
     pass
+
+
+"""
+TODO:
+- async DCP tests
+
+"""

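The test-side change above folds the per-test skip logic into a single parametrized fp_recipe fixture whose entries carry xfail marks when the corresponding hardware check fails. A small, self-contained sketch of that pattern follows; the check functions and feature names are hypothetical stand-ins for calls such as fp8.check_fp8_support(), each assumed to return a (supported, reason) pair:

# Sketch of the capability-gated fixture pattern; run with pytest.
import pytest


def _check_supported():
    # Stand-in for a real capability probe, e.g. fp8.check_fp8_support().
    return True, ""


def _check_unsupported():
    return False, "required hardware feature not available"


# Each entry: (feature_name, check_fn)
_FEATURE_CONFIGS = [
    ("FeatureA", _check_supported),
    ("FeatureB", _check_unsupported),
]


def _parametrize_features():
    """Build pytest.param entries, marking unsupported features as xfail."""
    params = []
    for name, check_fn in _FEATURE_CONFIGS:
        supported, reason = check_fn()
        params.append(
            pytest.param(
                name,
                id=name,
                marks=pytest.mark.xfail(condition=not supported, reason=reason),
            )
        )
    return params


@pytest.fixture(params=_parametrize_features())
def feature(request):
    return request.param


def _use_feature(name):
    # Stand-in for launching the real distributed test; fails when unsupported.
    supported, reason = dict(_FEATURE_CONFIGS)[name]()
    if not supported:
        raise RuntimeError(reason)


def test_feature(feature):
    # "FeatureB" raises here, so it is reported as an expected failure (xfail).
    _use_feature(feature)

Compared with the removed pytest.skip guards, unsupported recipes stay visible in the test report as expected failures, and the duplicated per-test boilerplate disappears.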