@@ -63,7 +63,9 @@ def _get_operand_sharding(mesh, collective_op):
6363 return x_sharding , weight_sharding , bias_sharding
6464
6565
66- def _mean_dense (x , weight , bias , input_axes , weight_axes , output_axes , collective_op_set , quantizer_set ):
66+ def _mean_dense (
67+ x , weight , bias , input_axes , weight_axes , output_axes , collective_op_set , quantizer_set
68+ ):
6769 output = dense (
6870 x ,
6971 weight ,
@@ -78,7 +80,9 @@ def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collectiv
7880 return jnp .mean (output .astype (jnp .float32 ))
7981
8082
81- def _value_and_grad_dense (x , weight , bias , input_axes , weight_axes , output_axes , collective_op_set , quantizer_set ):
83+ def _value_and_grad_dense (
84+ x , weight , bias , input_axes , weight_axes , output_axes , collective_op_set , quantizer_set
85+ ):
8286 return jax .jit (jax .value_and_grad (_mean_dense , (0 , 1 , 2 )), static_argnums = (3 , 4 , 5 , 6 ))(
8387 x , weight , bias , input_axes , weight_axes , output_axes , collective_op_set , quantizer_set
8488 )
@@ -205,7 +209,9 @@ def test_te_bf16_reduce_scatter(self):
205209 def test_te_delayed_scaling_fp8_all_gather (self ):
206210 """Test Collective Dense Gradient with FP8 DelayedScaling + AllGather"""
207211 self .args .quantize_recipe = "DelayedScaling"
208- is_supported , reason = is_scaling_mode_supported (get_scaling_mode_from_recipe_name (self .args .quantize_recipe ))
212+ is_supported , reason = is_scaling_mode_supported (
213+ get_scaling_mode_from_recipe_name (self .args .quantize_recipe )
214+ )
209215 if not is_supported :
210216 self .skipTest (reason )
211217 self .args .use_fp8 = True
@@ -215,7 +221,9 @@ def test_te_delayed_scaling_fp8_all_gather(self):
215221 def test_te_delayed_scaling_fp8_reduce_scatter (self ):
216222 """Test Collective Dense Gradient with FP8 DelayedScaling + ReduceScatter"""
217223 self .args .quantize_recipe = "DelayedScaling"
218- is_supported , reason = is_scaling_mode_supported (get_scaling_mode_from_recipe_name (self .args .quantize_recipe ))
224+ is_supported , reason = is_scaling_mode_supported (
225+ get_scaling_mode_from_recipe_name (self .args .quantize_recipe )
226+ )
219227 if not is_supported :
220228 self .skipTest (reason )
221229 self .args .use_fp8 = True
@@ -225,7 +233,9 @@ def test_te_delayed_scaling_fp8_reduce_scatter(self):
225233 def test_te_current_scaling_fp8_all_gather (self ):
226234 """Test Collective Dense Gradient with FP8 Float8CurrentScaling + AllGather"""
227235 self .args .quantize_recipe = "Float8CurrentScaling"
228- is_supported , reason = is_scaling_mode_supported (get_scaling_mode_from_recipe_name (self .args .quantize_recipe ))
236+ is_supported , reason = is_scaling_mode_supported (
237+ get_scaling_mode_from_recipe_name (self .args .quantize_recipe )
238+ )
229239 if not is_supported :
230240 self .skipTest (reason )
231241 self .args .use_fp8 = True
@@ -235,7 +245,9 @@ def test_te_current_scaling_fp8_all_gather(self):
235245 def test_te_current_scaling_fp8_reduce_scatter (self ):
236246 """Test Collective Dense Gradient with FP8 Float8CurrentScaling + ReduceScatter"""
237247 self .args .quantize_recipe = "Float8CurrentScaling"
238- is_supported , reason = is_scaling_mode_supported (get_scaling_mode_from_recipe_name (self .args .quantize_recipe ))
248+ is_supported , reason = is_scaling_mode_supported (
249+ get_scaling_mode_from_recipe_name (self .args .quantize_recipe )
250+ )
239251 if not is_supported :
240252 self .skipTest (reason )
241253 self .args .use_fp8 = True
0 commit comments