
fix: skip embedding[padding_idx] = 0 with TP#1675

Merged
akoumpa merged 3 commits into main from akoumparouli/fix_zeroing_padd_idx
Apr 10, 2026

Conversation

@akoumpa
Contributor

@akoumpa akoumpa commented Apr 3, 2026

What does this PR do ?

Context:
HF may initialize models and zero the embedding row at padding_idx. When the embedding layer is row-wise sharded (tensor parallel), this can cause the following error:

  File "/opt/Automodel/nemo_automodel/components/checkpoint/checkpointing.py", line 538, in initialize_model_weights
    model.initialize_weights()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 124, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/transformers/modeling_utils.py", line 2410, in initialize_weights
    self.smart_apply(self._initialize_weights, self.is_remote_code())
  File "/opt/venv/lib/python3.12/site-packages/transformers/modeling_utils.py", line 2401, in smart_apply
    module.smart_apply(module._initialize_weights, is_remote_code)
  File "/opt/venv/lib/python3.12/site-packages/transformers/modeling_utils.py", line 2403, in smart_apply
    module.smart_apply(fn, is_remote_code)
  File "/opt/venv/lib/python3.12/site-packages/transformers/modeling_utils.py", line 2404, in smart_apply
    fn(self, is_remote_code)
  File "/opt/venv/lib/python3.12/site-packages/transformers/modeling_utils.py", line 2381, in _initialize_weights
    self._init_weights(module)
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 124, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/transformers/modeling_utils.py", line 2327, in _init_weights
    init.zeros_(module.weight[module.padding_idx])
                ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_dispatch.py", line 261, in _dispatch_get_local_results_slow_path
    self.redistribute_local_args(
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_dispatch.py", line 486, in redistribute_local_args
    resharded_local_tensor = redistribute_local_tensor(
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_redistribute.py", line 864, in redistribute_local_tensor
    transform_infos = _gen_transform_infos(
                      ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_redistribute.py", line 826, in _gen_transform_infos
    return _gen_transform_infos_non_cached(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_redistribute.py", line 790, in _gen_transform_infos_non_cached
    assert src_shard_order is not None and dst_shard_order is not None
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The `[]` indexing operator fails in this case: indexing a DTensor sharded along dim 0 triggers a redistribute, which hits the assertion above.
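The merged fix skips the zeroing step when the weight is a TP-sharded DTensor. A minimal sketch of that guard (the helper name is hypothetical; it is not the code merged in this PR):

```python
import torch
import torch.nn as nn


def zero_padding_row_if_safe(embedding: nn.Embedding) -> None:
    # Hypothetical illustration of the fix: skip the in-place zeroing when
    # the weight is a DTensor, since weight[padding_idx] on a dim-0-sharded
    # DTensor triggers a redistribute that can fail.
    if embedding.padding_idx is None:
        return
    if type(embedding.weight).__name__ == "DTensor":
        return  # skip; the sharded weight is handled elsewhere (e.g. checkpoint load)
    with torch.no_grad():
        embedding.weight[embedding.padding_idx].zero_()
```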

Alternatively, we could remap padding_idx into the local shard and zero the row there, as in the helper below, but I'm not sure that's better either.

import torch
import torch.nn as nn


def _zero_dtensor_embedding_padding_row(embedding: nn.Embedding) -> None:
    """Zero the ``padding_idx`` row of a TP-sharded embedding weight via local tensor ops.

    When the weight is a DTensor sharded along dim 0 (vocab-parallel), only the
    rank whose local shard contains the ``padding_idx`` row performs the zeroing.
    For replicated weights or weights sharded on other dims, every rank zeros
    the row in its local tensor.
    """
    padding_idx = embedding.padding_idx
    weight = embedding.weight
    if padding_idx is None:
        return
    # Name check avoids importing DTensor; plain tensors need no special handling.
    if type(weight).__name__ != "DTensor":
        return

    local = weight._local_tensor
    spec = weight._spec

    for mesh_dim, placement in enumerate(spec.placements):
        if placement.is_shard() and placement.dim == 0:
            mesh = spec.mesh
            tp_size = mesh.size(mesh_dim)
            rank = mesh.get_local_rank(mesh_dim)
            vocab_size = weight.shape[0]

            # Balanced split: the first `rem` ranks hold one extra row.
            chunk = vocab_size // tp_size
            rem = vocab_size % tp_size
            if rank < rem:
                local_off = rank * (chunk + 1)
                local_size = chunk + 1
            else:
                local_off = rem * (chunk + 1) + (rank - rem) * chunk
                local_size = chunk

            # Remap the global row index into this rank's local shard.
            local_idx = padding_idx - local_off
            if 0 <= local_idx < local_size:
                with torch.no_grad():
                    local[local_idx].zero_()
            return

    # Replicated (or sharded on a non-vocab dim): every rank zeros locally.
    with torch.no_grad():
        local[padding_idx].zero_()
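The offset arithmetic above can be exercised without a distributed setup. This hypothetical helper mirrors the balanced split used in the function (the first `rem` ranks hold one extra row) so the index remap can be sanity-checked on CPU:

```python
def local_shard_range(vocab_size: int, tp_size: int, rank: int) -> tuple[int, int]:
    # Mirrors the shard-offset arithmetic sketched above: returns the
    # (global_offset, local_size) of `rank`'s dim-0 shard under a balanced
    # split where the first `rem` ranks each hold one extra row.
    chunk, rem = divmod(vocab_size, tp_size)
    if rank < rem:
        return rank * (chunk + 1), chunk + 1
    return rem * (chunk + 1) + (rank - rem) * chunk, chunk
```

With `vocab_size=10, tp_size=4` the shards are rows 0-2, 3-5, 6-7, 8-9, so a padding_idx of 7 maps onto exactly one rank's local shard.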

Changelog

  • Skip the embedding[padding_idx] = 0 initialization step when the embedding weight is a TP-sharded DTensor.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

  • Related to # (issue)

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
@copy-pr-bot

copy-pr-bot bot commented Apr 3, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
@akoumpa
Contributor Author

akoumpa commented Apr 3, 2026

/ok to test de12f90

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
@akoumpa
Contributor Author

akoumpa commented Apr 3, 2026

/ok to test 5d20089

@akoumpa
Contributor Author

akoumpa commented Apr 5, 2026

/claude review

Comment thread nemo_automodel/components/checkpoint/checkpointing.py
@akoumpa akoumpa marked this pull request as ready for review April 10, 2026 03:27
@akoumpa akoumpa added the r0.4.0 Auto-cherrypick to release branch. Apply before merge; cherrypick happens after merge. label Apr 10, 2026
@akoumpa akoumpa merged commit 1389350 into main Apr 10, 2026
50 of 52 checks passed
@akoumpa akoumpa deleted the akoumparouli/fix_zeroing_padd_idx branch April 10, 2026 18:29
svcnvidia-nemo-ci pushed a commit that referenced this pull request Apr 10, 2026
* skip embedding[padding_idx] = 0

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* fix

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* remove code

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

---------

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
akoumpa added a commit that referenced this pull request Apr 10, 2026
…0` (#1771)

fix: skip embedding[padding_idx] = 0 with TP (#1675)

* skip embedding[padding_idx] = 0



* fix



* remove code



---------

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
Co-authored-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com>
vgauraha62 pushed a commit to vgauraha62/Automodel that referenced this pull request Apr 11, 2026
vgauraha62 pushed a commit to vgauraha62/Automodel that referenced this pull request Apr 11, 2026
edjson pushed a commit to edjson/Automodel that referenced this pull request Apr 17, 2026
edjson pushed a commit to edjson/Automodel that referenced this pull request Apr 18, 2026
