Skip to content

Commit f314655

Browse files
pre-commit-ci[bot] and phu0ngng
authored and committed
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent e0905bd commit f314655

3 files changed

Lines changed: 36 additions & 12 deletions

File tree

examples/jax/collective_gemm/test_dense_grad.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@ def _get_operand_sharding(mesh, collective_op):
6363
return x_sharding, weight_sharding, bias_sharding
6464

6565

66-
def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set):
66+
def _mean_dense(
67+
x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set
68+
):
6769
output = dense(
6870
x,
6971
weight,
@@ -78,7 +80,9 @@ def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collectiv
7880
return jnp.mean(output.astype(jnp.float32))
7981

8082

81-
def _value_and_grad_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set):
83+
def _value_and_grad_dense(
84+
x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set
85+
):
8286
return jax.jit(jax.value_and_grad(_mean_dense, (0, 1, 2)), static_argnums=(3, 4, 5, 6))(
8387
x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set
8488
)
@@ -205,7 +209,9 @@ def test_te_bf16_reduce_scatter(self):
205209
def test_te_delayed_scaling_fp8_all_gather(self):
206210
"""Test Collective Dense Gradient with FP8 DelayedScaling + AllGather"""
207211
self.args.quantize_recipe = "DelayedScaling"
208-
is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe))
212+
is_supported, reason = is_scaling_mode_supported(
213+
get_scaling_mode_from_recipe_name(self.args.quantize_recipe)
214+
)
209215
if not is_supported:
210216
self.skipTest(reason)
211217
self.args.use_fp8 = True
@@ -215,7 +221,9 @@ def test_te_delayed_scaling_fp8_all_gather(self):
215221
def test_te_delayed_scaling_fp8_reduce_scatter(self):
216222
"""Test Collective Dense Gradient with FP8 DelayedScaling + ReduceScatter"""
217223
self.args.quantize_recipe = "DelayedScaling"
218-
is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe))
224+
is_supported, reason = is_scaling_mode_supported(
225+
get_scaling_mode_from_recipe_name(self.args.quantize_recipe)
226+
)
219227
if not is_supported:
220228
self.skipTest(reason)
221229
self.args.use_fp8 = True
@@ -225,7 +233,9 @@ def test_te_delayed_scaling_fp8_reduce_scatter(self):
225233
def test_te_current_scaling_fp8_all_gather(self):
226234
"""Test Collective Dense Gradient with FP8 Float8CurrentScaling + AllGather"""
227235
self.args.quantize_recipe = "Float8CurrentScaling"
228-
is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe))
236+
is_supported, reason = is_scaling_mode_supported(
237+
get_scaling_mode_from_recipe_name(self.args.quantize_recipe)
238+
)
229239
if not is_supported:
230240
self.skipTest(reason)
231241
self.args.use_fp8 = True
@@ -235,7 +245,9 @@ def test_te_current_scaling_fp8_all_gather(self):
235245
def test_te_current_scaling_fp8_reduce_scatter(self):
236246
"""Test Collective Dense Gradient with FP8 Float8CurrentScaling + ReduceScatter"""
237247
self.args.quantize_recipe = "Float8CurrentScaling"
238-
is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe))
248+
is_supported, reason = is_scaling_mode_supported(
249+
get_scaling_mode_from_recipe_name(self.args.quantize_recipe)
250+
)
239251
if not is_supported:
240252
self.skipTest(reason)
241253
self.args.use_fp8 = True

examples/jax/collective_gemm/test_gemm.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,9 @@ def test_te_bf16_reduce_scatter_with_dp(self):
208208
def test_te_delayed_scaling_fp8_all_gather_with_dp(self):
209209
"""Test Collective GEMM with FP8 DelayedScaling + AllGather"""
210210
self.args.quantize_recipe = "DelayedScaling"
211-
is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe))
211+
is_supported, reason = is_scaling_mode_supported(
212+
get_scaling_mode_from_recipe_name(self.args.quantize_recipe)
213+
)
212214
if not is_supported:
213215
self.skipTest(reason)
214216
self.args.use_fp8 = True
@@ -218,7 +220,9 @@ def test_te_delayed_scaling_fp8_all_gather_with_dp(self):
218220
def test_te_delayed_scaling_fp8_reduce_scatter_with_dp(self):
219221
"""Test Collective GEMM with FP8 DelayedScaling + ReduceScatter"""
220222
self.args.quantize_recipe = "DelayedScaling"
221-
is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe))
223+
is_supported, reason = is_scaling_mode_supported(
224+
get_scaling_mode_from_recipe_name(self.args.quantize_recipe)
225+
)
222226
if not is_supported:
223227
self.skipTest(reason)
224228
self.args.use_fp8 = True
@@ -228,7 +232,9 @@ def test_te_delayed_scaling_fp8_reduce_scatter_with_dp(self):
228232
def test_te_current_scaling_fp8_all_gather_with_dp(self):
229233
"""Test Collective GEMM with FP8 Float8CurrentScaling + AllGather"""
230234
self.args.quantize_recipe = "Float8CurrentScaling"
231-
is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe))
235+
is_supported, reason = is_scaling_mode_supported(
236+
get_scaling_mode_from_recipe_name(self.args.quantize_recipe)
237+
)
232238
if not is_supported:
233239
self.skipTest(reason)
234240
self.args.use_fp8 = True
@@ -238,7 +244,9 @@ def test_te_current_scaling_fp8_all_gather_with_dp(self):
238244
def test_te_current_scaling_fp8_reduce_scatter_with_dp(self):
239245
"""Test Collective GEMM with FP8 Float8CurrentScaling + ReduceScatter"""
240246
self.args.quantize_recipe = "Float8CurrentScaling"
241-
is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe))
247+
is_supported, reason = is_scaling_mode_supported(
248+
get_scaling_mode_from_recipe_name(self.args.quantize_recipe)
249+
)
242250
if not is_supported:
243251
self.skipTest(reason)
244252
self.args.use_fp8 = True

examples/jax/collective_gemm/test_layernorm_mlp_grad.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,9 @@ def test_te_bf16_layernorm_mlp_grad(self):
265265
def test_te_delayed_scaling_fp8_layernorm_mlp_grad(self):
266266
"""Test Collective LayerNorm MLP Gradient with FP8 DelayedScaling"""
267267
self.args.quantize_recipe = "DelayedScaling"
268-
is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe))
268+
is_supported, reason = is_scaling_mode_supported(
269+
get_scaling_mode_from_recipe_name(self.args.quantize_recipe)
270+
)
269271
if not is_supported:
270272
self.skipTest(reason)
271273
self.args.use_fp8 = True
@@ -274,7 +276,9 @@ def test_te_delayed_scaling_fp8_layernorm_mlp_grad(self):
274276
def test_te_current_scaling_fp8_layernorm_mlp_grad(self):
275277
"""Test Collective LayerNorm MLP Gradient with FP8 Float8CurrentScaling"""
276278
self.args.quantize_recipe = "Float8CurrentScaling"
277-
is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe))
279+
is_supported, reason = is_scaling_mode_supported(
280+
get_scaling_mode_from_recipe_name(self.args.quantize_recipe)
281+
)
278282
if not is_supported:
279283
self.skipTest(reason)
280284
self.args.use_fp8 = True

0 commit comments

Comments (0)