3131
3232
3333_LOGGED_ACTIVATION_SHARDINGS = set ()
34+ _LOGGED_LOGICAL_AXES = set ()
3435
3536
3637def get_input_data_sharding (config , mesh ):
3738 """Get the input data sharding for the model"""
3839 return create_sharding (mesh , config .input_data_sharding_logical_axes , rules = config .logical_axis_rules )
3940
4041
41- def maybe_shard_with_name (
42- inputs , named_sharding , shard_mode , debug_sharding = False , extra_stack_level = 0 , logical_axes = None
43- ):
42+ def maybe_shard_with_name (inputs , named_sharding , shard_mode , debug_sharding = False , extra_stack_level = 0 ):
4443 """
4544 In auto shardmode, this function hints inputs follow given named_sharding.
4645 In explicit shardmode, this function enforces inputs following named_sharding.
4746 """
4847 if inputs is None :
4948 return None
50- if debug_sharding and isinstance (inputs , Tracer ) and isinstance (named_sharding , NamedSharding ):
49+ if (
50+ debug_sharding and isinstance (inputs , Tracer ) and isinstance (named_sharding , NamedSharding )
51+ ): # only print pspec for JitTracer
5152 pspec = remove_size_one_mesh_axis (getattr (named_sharding , "spec" ), getattr (named_sharding , "mesh" ))
52- if logical_axes is not None :
53- logical_str = str (logical_axes )
54- else :
55- logical_str = "None"
56- shape_str = str (jax .typeof (inputs ))
57- log_key = (shape_str , tuple (pspec ), extra_stack_level , logical_str )
58-
53+ log_key = (str (jax .typeof (inputs )), tuple (pspec ), extra_stack_level )
5954 if log_key not in _LOGGED_ACTIVATION_SHARDINGS :
60- max_logging .info (
61- f"Activation: { logical_str :<40} -> { str (tuple (pspec )):<30} { shape_str } " , stacklevel = 3 + extra_stack_level
62- )
55+ max_logging .info (f"{ log_key [0 ]:.<80} { log_key [1 ]} ." , stacklevel = 3 + extra_stack_level )
6356 _LOGGED_ACTIVATION_SHARDINGS .add (log_key )
6457 if shard_mode == ShardMode .EXPLICIT :
6558 return reshard (inputs , named_sharding )
@@ -75,14 +68,26 @@ def maybe_shard_with_logical(
7568 """
7669 if inputs is None :
7770 return None
71+
7872 named_sharding = create_sharding (mesh , logical_axes , rules = rules )
73+
74+ if debug_sharding and isinstance (inputs , Tracer ):
75+ log_key = (str (jax .typeof (inputs )), logical_axes , extra_stack_level )
76+
77+ if log_key not in _LOGGED_LOGICAL_AXES :
78+ pspec = remove_size_one_mesh_axis (getattr (named_sharding , "spec" ), getattr (named_sharding , "mesh" ))
79+ pspec_str = str (tuple (pspec )) if pspec else "None"
80+
81+ max_logging .info (f"Logical: { log_key [0 ]:.<60} { log_key [1 ]} " , stacklevel = 3 + extra_stack_level )
82+ max_logging .info (f"{ log_key [0 ]:.<80} { pspec_str } ." , stacklevel = 3 + extra_stack_level )
83+ _LOGGED_LOGICAL_AXES .add (log_key )
84+
7985 return maybe_shard_with_name (
8086 inputs ,
8187 named_sharding ,
8288 shard_mode ,
83- debug_sharding = debug_sharding ,
89+ debug_sharding = False ,
8490 extra_stack_level = extra_stack_level + 1 ,
85- logical_axes = logical_axes ,
8691 )
8792
8893
0 commit comments