Skip to content

Commit dcb87fc

Browse files
author
Sharon Yu
committed
fix comments
1 parent 16c96d9 commit dcb87fc

3 files changed

Lines changed: 33 additions & 53 deletions

File tree

src/MaxText/maxtext_utils.py

Lines changed: 11 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1244,48 +1244,23 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=No
12441244
params = {"params": params}
12451245
if not hasattr(params_sharding, "params"):
12461246
params_sharding = {"params": params_sharding}
1247+
if logical_annotations and not hasattr(logical_annotations, "params"):
1248+
logical_annotations = {"params": logical_annotations}
12471249

12481250
leaves_params, _ = jax.tree_util.tree_flatten_with_path(params)
12491251
leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding)
1252+
leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations.params)
12501253

1251-
leaves_logical = []
1252-
has_logical = False
1253-
if logical_annotations and hasattr(logical_annotations, "params"):
1254-
try:
1255-
leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations.params)
1256-
if len(leaves_params) == len(leaves_logical):
1257-
has_logical = True
1258-
else:
1259-
max_logging.warning("Warning: Logical annotations tree structure mismatch. Skipping logical info.")
1260-
except Exception as e: # pylint: disable=broad-exception-caught
1261-
max_logging.warning(f"Warning: Failed to process logical annotations: {e}. Skipping logical info.")
1262-
1263-
if not has_logical:
1264-
leaves_logical = [(None, None)] * len(leaves_params)
1265-
1266-
if len(leaves_params) != len(leaves_sharding):
1267-
max_logging.warning("Warning: Params and Sharding tree mismatch.")
1268-
return
1269-
1270-
for i, (path, leaf_val) in enumerate(leaves_params):
1271-
_, leaf_sharding = leaves_sharding[i]
1272-
leaf_logical_val = leaves_logical[i][1] if has_logical else None
1254+
for i, ((path, leaf_val), (_, leaf_sharding)) in enumerate(zip(leaves_params, leaves_sharding)):
1255+
path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
1256+
shape = jax.typeof(leaf_val)
1257+
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
1258+
pspec_str = str(tuple(pspec))
12731259

1274-
path_str = "/".join(str(p.key if hasattr(p, "key") else getattr(p, "name", "?")) for p in path)
1275-
1276-
shape = str(jax.typeof(leaf_val))
1277-
1278-
pspec_str = "N/A"
1279-
if hasattr(leaf_sharding, "spec"):
1280-
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
1281-
pspec_str = str(tuple(pspec))
1282-
elif leaf_sharding is not None:
1283-
pspec_str = str(leaf_sharding)
1284-
1285-
if has_logical and leaf_logical_val is not None:
1260+
logical_str = "N/A"
1261+
if leaves_logical:
1262+
_, leaf_logical_val = leaves_logical[i]
12861263
logical_str = str(leaf_logical_val)
1287-
else:
1288-
logical_str = "N/A"
12891264

12901265
message = f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}"
12911266
max_logging.info(message)

src/MaxText/model_creation_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,8 @@ def create_sharded_state():
160160
maxtext_utils.print_shardings_params(
161161
params=sharded_state,
162162
params_sharding=out_shardings,
163-
logical_annotations=specs,
164163
mesh=model.mesh,
164+
logical_annotations=specs,
165165
)
166166
if config.load_parameters_path:
167167
try:

src/MaxText/sharding.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,35 +31,28 @@
3131

3232

3333
_LOGGED_ACTIVATION_SHARDINGS = set()
34+
_LOGGED_LOGICAL_AXES = set()
3435

3536

3637
def get_input_data_sharding(config, mesh):
3738
"""Get the input data sharding for the model"""
3839
return create_sharding(mesh, config.input_data_sharding_logical_axes, rules=config.logical_axis_rules)
3940

4041

41-
def maybe_shard_with_name(
42-
inputs, named_sharding, shard_mode, debug_sharding=False, extra_stack_level=0, logical_axes=None
43-
):
42+
def maybe_shard_with_name(inputs, named_sharding, shard_mode, debug_sharding=False, extra_stack_level=0):
4443
"""
4544
In auto shardmode, this function hints inputs follow given named_sharding.
4645
In explicit shardmode, this function enforces inputs following named_sharding.
4746
"""
4847
if inputs is None:
4948
return None
50-
if debug_sharding and isinstance(inputs, Tracer) and isinstance(named_sharding, NamedSharding):
49+
if (
50+
debug_sharding and isinstance(inputs, Tracer) and isinstance(named_sharding, NamedSharding)
51+
): # only print pspec for JitTracer
5152
pspec = remove_size_one_mesh_axis(getattr(named_sharding, "spec"), getattr(named_sharding, "mesh"))
52-
if logical_axes is not None:
53-
logical_str = str(logical_axes)
54-
else:
55-
logical_str = "None"
56-
shape_str = str(jax.typeof(inputs))
57-
log_key = (shape_str, tuple(pspec), extra_stack_level, logical_str)
58-
53+
log_key = (str(jax.typeof(inputs)), tuple(pspec), extra_stack_level)
5954
if log_key not in _LOGGED_ACTIVATION_SHARDINGS:
60-
max_logging.info(
61-
f"Activation: {logical_str:<40} -> {str(tuple(pspec)):<30} {shape_str}", stacklevel=3 + extra_stack_level
62-
)
55+
max_logging.info(f"{log_key[0]:.<80} {log_key[1]}.", stacklevel=3 + extra_stack_level)
6356
_LOGGED_ACTIVATION_SHARDINGS.add(log_key)
6457
if shard_mode == ShardMode.EXPLICIT:
6558
return reshard(inputs, named_sharding)
@@ -75,14 +68,26 @@ def maybe_shard_with_logical(
7568
"""
7669
if inputs is None:
7770
return None
71+
7872
named_sharding = create_sharding(mesh, logical_axes, rules=rules)
73+
74+
if debug_sharding and isinstance(inputs, Tracer):
75+
log_key = (str(jax.typeof(inputs)), logical_axes, extra_stack_level)
76+
77+
if log_key not in _LOGGED_LOGICAL_AXES:
78+
pspec = remove_size_one_mesh_axis(getattr(named_sharding, "spec"), getattr(named_sharding, "mesh"))
79+
pspec_str = str(tuple(pspec)) if pspec else "None"
80+
81+
max_logging.info(f"Logical: {log_key[0]:.<60} {log_key[1]}", stacklevel=3 + extra_stack_level)
82+
max_logging.info(f"{log_key[0]:.<80} {pspec_str}.", stacklevel=3 + extra_stack_level)
83+
_LOGGED_LOGICAL_AXES.add(log_key)
84+
7985
return maybe_shard_with_name(
8086
inputs,
8187
named_sharding,
8288
shard_mode,
83-
debug_sharding=debug_sharding,
89+
debug_sharding=False,
8490
extra_stack_level=extra_stack_level + 1,
85-
logical_axes=logical_axes,
8691
)
8792

8893

0 commit comments

Comments
 (0)