[PyTorch] Fix cross_entropy_forward stride guard for non-contiguous input #2746
Conversation
Signed-off-by: Bias92 <pewpewplay315@gmail.com>
Greptile Summary: This PR fixes a silent correctness bug in the `cross_entropy_forward` stride guard.
Confidence Score: 5/5
Last reviewed commit: c03ceb0
  # ensure _input and target are contiguous in the last dimension
- if _input.stride(-1) != 1:
+ if _input.stride(-1) != 1 or _input.stride(-2) != _input.shape[-1]:
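To make the difference between the two conditions concrete, they can be written as standalone predicates over `(shape, strides)` pairs. This is a minimal sketch for illustration only; `old_guard_ok` and `is_rowwise_packed` are hypothetical names, not Transformer Engine code, which checks `_input.stride(-1)` and `_input.stride(-2)` on the tensor directly:

```python
def old_guard_ok(shape, strides):
    """Old guard: only required the last dimension to be packed."""
    return strides[-1] == 1

def is_rowwise_packed(shape, strides):
    """New guard: last dim packed AND row stride equal to the row length."""
    return strides[-1] == 1 and strides[-2] == shape[-1]

# (SQ, B, V) = (64, 2, 1024) transposed to (B, SQ, V): strides (1024, 2048, 1)
shape, strides = (2, 64, 1024), (1024, 2048, 1)
old_guard_ok(shape, strides)       # True  -> the old guard let it through
is_rowwise_packed(shape, strides)  # False -> the new guard catches it
```

The transposed tensor passes the old check because its innermost stride is still 1, yet its rows are 2048 elements apart rather than 1024.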
No regression test for the fixed scenario
The fix correctly catches tensors where stride(-1) == 1 but stride(-2) != shape[-1] (e.g., a tensor of shape (SQ, B, V) transposed to (B, SQ, V) via .transpose(0, 1)). However, no test is added to cover this specific case.
The existing test_swapped_input in tests/pytorch/test_parallel_cross_entropy.py creates a naturally-shaped (SQ, batch, vocab) tensor (contiguous, stride(-2) == V), so it does not exercise the bug path described in the PR.
A minimal regression test would look like:
def test_non_contiguous_transposed_input(self):
    """Regression test for stride(-2) != shape[-1] on non-contiguous input."""
    self.generate_iters(3)
    self.generate_infra(True, 0)
    for _ in range(self.iters):
        batch, SQ, vocab = 2, 64, 1024
        # shape (SQ, batch, vocab) transposed → (batch, SQ, vocab):
        # stride(-1) == 1 but stride(-2) == batch * vocab != vocab ← the old guard missed this
        x = torch.rand((SQ, batch, vocab), dtype=torch.float32, device="cuda").transpose(0, 1)
        assert x.stride(-1) == 1 and x.stride(-2) != x.shape[-1]  # confirm the guard is needed
        self.input_test = x
        ...  # drive through one_iteration_test

Without a test, this class of regression can silently reappear.
Tested locally on RTX 4060 Ti (WSL2).
The full test suite needs CI; the local editable build had C++ extension issues on WSL2.
Signed-off-by: Bias92 <pewpewplay315@gmail.com>
for more information, see https://pre-commit.ci
/te-ci pytorch
Thank you for your contribution @Bias92. I will merge once the CI passes.
The stride guard in cross_entropy_forward only checks stride(-1) != 1,
which misses transposed tensors where stride(-1) == 1 but stride(-2) != shape[-1].
The Triton kernel then uses the wrong row stride and produces silently incorrect results.
Added a stride(-2) check, following the same approach as the backward fix in #2402.
Fixes #2734
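To see how such a tensor arises: `.transpose` only swaps shape and stride entries without copying data. The stride arithmetic can be checked without PyTorch; `contiguous_strides` below is an illustrative helper, not library code:

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, in elements."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

SQ, B, V = 64, 2, 1024
strides = contiguous_strides((SQ, B, V))          # [B*V, V, 1] = [2048, 1024, 1]

# transpose(0, 1) swaps the first two shape/stride entries; data is untouched
t_shape = (B, SQ, V)
t_strides = (strides[1], strides[0], strides[2])  # (1024, 2048, 1)

# the last dim is packed, so the old guard passed this tensor...
assert t_strides[-1] == 1
# ...but the row stride (2048) is not the row length (1024), so a kernel
# assuming packed rows reads the wrong elements
assert t_strides[-2] != t_shape[-1]
```

This is exactly the layout the new `stride(-2) != shape[-1]` check rejects, which explains the repeated values in the "before fix" output below.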
Before fix:

Non-contiguous: [2.0794, 2.0794, 2.0794, 2.0794] ← wrong (same values)
Contiguous: [4.1277, 3.7957, 2.1120, 2.5712] ← correct
After fix:

Both return [4.1277, 3.7957, 2.1120, 2.5712] ✓