From cde3c9163116f769c688378959931661f780750c Mon Sep 17 00:00:00 2001 From: Bias92 Date: Tue, 10 Mar 2026 00:10:10 +0900 Subject: [PATCH 1/3] Fix cross_entropy_forward stride guard for non-contiguous input Signed-off-by: Bias92 --- transformer_engine/pytorch/triton/cross_entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index b574d69e0f..1401383c8f 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -49,7 +49,7 @@ def cross_entropy_forward( n_non_ignore = torch.zeros(1, dtype=torch.int64, device=_input.device) # ensure _input and target are contiguous in the last dimension - if _input.stride(-1) != 1: + if _input.stride(-1) != 1 or _input.stride(-2) != _input.shape[-1]: _input = _input.contiguous() if target.stride(-1) != 1: target = target.contiguous() From bbe845b1316913d75e8a55ab3b7745a2493773b9 Mon Sep 17 00:00:00 2001 From: Bias92 Date: Tue, 10 Mar 2026 00:20:59 +0900 Subject: [PATCH 2/3] Add regression test for non-contiguous transposed input Signed-off-by: Bias92 --- tests/pytorch/test_parallel_cross_entropy.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py index 7b92672af7..5396328da3 100644 --- a/tests/pytorch/test_parallel_cross_entropy.py +++ b/tests/pytorch/test_parallel_cross_entropy.py @@ -167,3 +167,22 @@ def test_ignore_idx_reduced_loss(self): reduce_loss=True, ignore_idx=True, ) +def test_non_contiguous_transposed_input(): + """Regression test: stride(-2) != shape[-1] should not produce wrong results.""" + s, b, v = 4, 2, 8 + torch.manual_seed(42) + logits = torch.randn(s, b, v, device="cuda") + target = torch.randint(0, v, (b, s), device="cuda") + + logits_transposed = logits.transpose(0, 1) # stride(-2) != shape[-1] + logits_contiguous = logits_transposed.contiguous() + + assert logits_transposed.stride(-1) == 1 + assert logits_transposed.stride(-2) != logits_transposed.shape[-1] + + loss_t = parallel_cross_entropy(logits_transposed, target, 0.0, False, None) + loss_c = parallel_cross_entropy(logits_contiguous, target, 0.0, False, None) + + assert torch.allclose(loss_t, loss_c), ( + f"Non-contiguous transposed input gave wrong results: {loss_t} vs {loss_c}" + ) \ No newline at end of file From c03ceb0b91871032b9ad4940f1acbfac1d0d2b5f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Mar 2026 15:21:51 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_parallel_cross_entropy.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py index 5396328da3..b4ea193f06 100644 --- a/tests/pytorch/test_parallel_cross_entropy.py +++ b/tests/pytorch/test_parallel_cross_entropy.py @@ -167,6 +167,8 @@ def test_ignore_idx_reduced_loss(self): reduce_loss=True, ignore_idx=True, ) + + def test_non_contiguous_transposed_input(): """Regression test: stride(-2) != shape[-1] should not produce wrong results.""" s, b, v = 4, 2, 8 @@ -183,6 +185,6 @@ def test_non_contiguous_transposed_input(): loss_t = parallel_cross_entropy(logits_transposed, target, 0.0, False, None) loss_c = parallel_cross_entropy(logits_contiguous, target, 0.0, False, None) - assert torch.allclose(loss_t, loss_c), ( - f"Non-contiguous transposed input gave wrong results: {loss_t} vs {loss_c}" - ) \ No newline at end of file + assert torch.allclose( + loss_t, loss_c + ), f"Non-contiguous transposed input gave wrong results: {loss_t} vs {loss_c}"