
[PyTorch] Fix cross_entropy_forward stride guard for non-contiguous input#2746

Merged
yaox12 merged 3 commits into NVIDIA:main from Bias92:fix/cross-entropy-forward-stride
Mar 10, 2026

Conversation

Contributor

@Bias92 commented Mar 9, 2026

The stride guard in cross_entropy_forward only checks stride(-1) != 1,
which misses transposed tensors where stride(-1) == 1 but stride(-2) != shape[-1].
The Triton kernel then uses the wrong row stride and produces silently incorrect results.

Added stride(-2) check, same approach as the backward fix in #2402.

Fixes #2734
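For illustration, a minimal reproduction sketch of the offending stride pattern (shapes are arbitrary, not taken from the issue):

```python
# Hypothetical repro of the stride pattern the old guard missed.
# A transposed view keeps stride(-1) == 1 while its row stride no
# longer equals the row length, so the Triton kernel would step
# between rows at the wrong offset.
import torch

SQ, B, V = 64, 2, 1024
x = torch.rand(SQ, B, V)   # contiguous: strides (B*V, V, 1) = (2048, 1024, 1)
y = x.transpose(0, 1)      # shape (B, SQ, V), strides (V, B*V, 1) = (1024, 2048, 1)

assert y.stride(-1) == 1             # old guard `stride(-1) != 1` lets this through
assert y.stride(-2) != y.shape[-1]   # but the row stride (2048) != row length (1024)
```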

Before fix:
Non-contiguous: [2.0794, 2.0794, 2.0794, 2.0794] ← wrong (same values)
Contiguous:     [4.1277, 3.7957, 2.1120, 2.5712] ← correct

After fix:
Both return [4.1277, 3.7957, 2.1120, 2.5712] ✓

Signed-off-by: Bias92 <pewpewplay315@gmail.com>
Contributor

greptile-apps bot commented Mar 9, 2026

Greptile Summary

This PR fixes a silent correctness bug in cross_entropy_forward where non-contiguous tensors with stride(-1) == 1 but stride(-2) != shape[-1] (e.g., a transposed tensor) bypassed the existing contiguity guard and caused the Triton kernel to use the wrong row stride, producing silently incorrect loss values.

Key changes:

  • transformer_engine/pytorch/triton/cross_entropy.py: Extends the stride guard from stride(-1) != 1 to also check stride(-2) != shape[-1], matching the pattern of the backward fix in PR #2402 ("Make grad_output contiguous in cross_entropy.py"). When triggered, the input is made contiguous before being passed to the Triton kernels, which also implicitly fixes the backward pass.
  • tests/pytorch/test_parallel_cross_entropy.py: Adds regression test test_non_contiguous_transposed_input that creates a transposed (B, SQ, V) tensor, verifies the preconditions (stride(-1)==1, stride(-2) != shape[-1]), and validates that forward loss matches the result from an explicitly contiguous input.

Confidence Score: 5/5

  • This PR is safe to merge — it is a minimal, targeted fix applying a well-established guard pattern already present in the backward pass.
  • The one-line fix is correct, consistent with the existing pattern in the codebase (PR #2402, "Make grad_output contiguous in cross_entropy.py"), and the regression test properly exercises the bug path. The fix correctly catches the missed stride pattern and prevents silent correctness bugs. No safety concerns identified.
  • No files require special attention.

Last reviewed commit: c03ceb0


  # ensure _input and target are contiguous in the last dimension
- if _input.stride(-1) != 1:
+ if _input.stride(-1) != 1 or _input.stride(-2) != _input.shape[-1]:
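The effect of the change can be sketched with plain tuples standing in for tensor metadata (illustrative only, not the repo's code):

```python
# Sketch: why the old one-condition guard misses a transposed tensor.
# Shapes/strides are plain tuples here; the names are hypothetical.

def row_major_strides(shape):
    """Element strides of a contiguous (row-major) tensor of this shape."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)

def old_guard(shape, strides):
    # original check: only the innermost stride
    return strides[-1] != 1

def new_guard(shape, strides):
    # fixed check: the row stride must also equal the row length
    return strides[-1] != 1 or strides[-2] != shape[-1]

SQ, B, V = 64, 2, 1024
base_shape = (SQ, B, V)
base_strides = row_major_strides(base_shape)  # (2048, 1024, 1)

# transpose(0, 1) swaps the first two entries of both shape and strides
t_shape = (B, SQ, V)
t_strides = (base_strides[1], base_strides[0], base_strides[2])  # (1024, 2048, 1)

print(old_guard(t_shape, t_strides))  # False: old guard lets the bad layout through
print(new_guard(t_shape, t_strides))  # True: new guard triggers a .contiguous() copy
```

For the contiguous base tensor both guards stay quiet, so the fix adds no copies on the common path.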
No regression test for the fixed scenario

The fix correctly catches tensors where stride(-1) == 1 but stride(-2) != shape[-1] (e.g., a tensor of shape (SQ, B, V) transposed to (B, SQ, V) via .transpose(0, 1)). However, no test is added to cover this specific case.

The existing test_swapped_input in tests/pytorch/test_parallel_cross_entropy.py creates a naturally-shaped (SQ, batch, vocab) tensor (contiguous, stride(-2) == V), so it does not exercise the bug path described in the PR.

A minimal regression test would look like:

def test_non_contiguous_transposed_input(self):
    """Regression test for stride(-2) != shape[-1] on non-contiguous input."""
    self.generate_iters(3)
    self.generate_infra(True, 0)
    for _ in range(self.iters):
        batch, SQ, vocab = 2, 64, 1024
        # shape (SQ, batch, vocab) transposed → (batch, SQ, vocab)
        # stride(-1)==1 but stride(-2)==batch*vocab != vocab  ← the old guard missed this
        x = torch.rand((SQ, batch, vocab), dtype=torch.float32, device="cuda").transpose(0, 1)
        assert x.stride(-1) == 1 and x.stride(-2) != x.shape[-1]  # confirm the guard is needed
        self.input_test = x
        ...  # drive through one_iteration_test

Without a test, this class of regression can silently reappear.

Contributor Author

@Bias92 commented Mar 9, 2026

Tested locally on RTX 4060 Ti (WSL2):

  • Reproduced the bug with transposed input (Match: False)
  • Verified fix produces correct results (Match: True)
  • Checked no other incomplete stride guards in the file

Full test suite needs CI — local editable build had C++ extension issues on WSL2.

Bias92 and others added 2 commits March 10, 2026 00:20
Member

ptrendx commented Mar 9, 2026

/te-ci pytorch

Member

ptrendx commented Mar 9, 2026

Thank you for your contribution @Bias92. I will merge once the CI passes.

@yaox12 yaox12 merged commit e6d97ff into NVIDIA:main Mar 10, 2026
20 of 24 checks passed

Development

Successfully merging this pull request may close these issues.

[PyTorch] cross_entropy_forward: stride guard misses non-contiguous transposed input

3 participants