Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions tests/pytorch/test_parallel_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,24 @@ def test_ignore_idx_reduced_loss(self):
reduce_loss=True,
ignore_idx=True,
)


def test_non_contiguous_transposed_input():
    """Regression test: stride(-2) != shape[-1] should not produce wrong results."""
    seq_len, batch, vocab = 4, 2, 8
    torch.manual_seed(42)

    # Build a contiguous (seq, batch, vocab) tensor, then view it as
    # (batch, seq, vocab) via transpose so the rows are no longer densely
    # packed: stride(-1) stays 1 but stride(-2) becomes batch*vocab != vocab.
    base = torch.randn(seq_len, batch, vocab, device="cuda")
    labels = torch.randint(0, vocab, (batch, seq_len), device="cuda")

    logits_transposed = base.transpose(0, 1)  # stride(-2) != shape[-1]
    logits_contiguous = logits_transposed.contiguous()

    # Sanity-check that this input actually exercises the guard under test.
    assert logits_transposed.stride(-1) == 1
    assert logits_transposed.stride(-2) != logits_transposed.shape[-1]

    # Both layouts must yield identical losses.
    loss_t = parallel_cross_entropy(logits_transposed, labels, 0.0, False, None)
    loss_c = parallel_cross_entropy(logits_contiguous, labels, 0.0, False, None)

    assert torch.allclose(
        loss_t, loss_c
    ), f"Non-contiguous transposed input gave wrong results: {loss_t} vs {loss_c}"
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/triton/cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def cross_entropy_forward(
n_non_ignore = torch.zeros(1, dtype=torch.int64, device=_input.device)

# ensure _input and target are contiguous in the last dimension
if _input.stride(-1) != 1:
if _input.stride(-1) != 1 or _input.stride(-2) != _input.shape[-1]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR adds no regression test for the fixed scenario.

The fix correctly catches tensors where stride(-1) == 1 but stride(-2) != shape[-1] (e.g., a tensor of shape (SQ, B, V) transposed to (B, SQ, V) via .transpose(0, 1)). However, no test is added to cover this specific case.

The existing test_swapped_input in tests/pytorch/test_parallel_cross_entropy.py creates a naturally-shaped (SQ, batch, vocab) tensor (contiguous, stride(-2) == V), so it does not exercise the bug path described in the PR.

A minimal regression test would look like:

def test_non_contiguous_transposed_input(self):
    """Regression test for stride(-2) != shape[-1] on non-contiguous input."""
    self.generate_iters(3)
    self.generate_infra(True, 0)
    for _ in range(self.iters):
        batch, SQ, vocab = 2, 64, 1024
        # shape (SQ, batch, vocab) transposed → (batch, SQ, vocab)
        # stride(-1)==1 but stride(-2)==batch*vocab != vocab  ← the old guard missed this
        x = torch.rand((SQ, batch, vocab), dtype=torch.float32, device="cuda").transpose(0, 1)
        assert x.stride(-1) == 1 and x.stride(-2) != x.shape[-1]  # confirm the guard is needed
        self.input_test = x
        ...  # drive through one_iteration_test

Without a test, this class of regression can silently reappear.

_input = _input.contiguous()
if target.stride(-1) != 1:
target = target.contiguous()
Expand Down
Loading