diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py
index 7b92672af7..b4ea193f06 100644
--- a/tests/pytorch/test_parallel_cross_entropy.py
+++ b/tests/pytorch/test_parallel_cross_entropy.py
@@ -167,3 +167,24 @@ def test_ignore_idx_reduced_loss(self):
             reduce_loss=True,
             ignore_idx=True,
         )
+
+
+def test_non_contiguous_transposed_input():
+    """Regression test: stride(-2) != shape[-1] should not produce wrong results."""
+    s, b, v = 4, 2, 8
+    torch.manual_seed(42)
+    logits = torch.randn(s, b, v, device="cuda")
+    target = torch.randint(0, v, (b, s), device="cuda")
+
+    logits_transposed = logits.transpose(0, 1)  # stride(-2) != shape[-1]
+    logits_contiguous = logits_transposed.contiguous()
+
+    assert logits_transposed.stride(-1) == 1
+    assert logits_transposed.stride(-2) != logits_transposed.shape[-1]
+
+    loss_t = parallel_cross_entropy(logits_transposed, target, 0.0, False, None)
+    loss_c = parallel_cross_entropy(logits_contiguous, target, 0.0, False, None)
+
+    assert torch.allclose(
+        loss_t, loss_c
+    ), f"Non-contiguous transposed input gave wrong results: {loss_t} vs {loss_c}"
diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py
index b574d69e0f..1401383c8f 100644
--- a/transformer_engine/pytorch/triton/cross_entropy.py
+++ b/transformer_engine/pytorch/triton/cross_entropy.py
@@ -49,7 +49,7 @@ def cross_entropy_forward(
     n_non_ignore = torch.zeros(1, dtype=torch.int64, device=_input.device)
 
     # ensure _input and target are contiguous in the last dimension
-    if _input.stride(-1) != 1:
+    if _input.stride(-1) != 1 or (_input.dim() > 1 and _input.stride(-2) != _input.shape[-1]):
         _input = _input.contiguous()
     if target.stride(-1) != 1:
         target = target.contiguous()