
Fix triton cross-entropy for large vocab sizes, support tensor-parallel #466

Open
jlamypoirier wants to merge 11 commits into jlp_entropy_loss_tweaks from jlp_triton_loss

Conversation

@jlamypoirier
Collaborator

@jlamypoirier jlamypoirier commented Jan 31, 2026

✨ Description

Greatly expand the triton implementation of loss kernels:

  • Add cross-entropy with target logits or probabilities (distillation)
  • Add forward and reverse KL
  • Add Z loss
  • Add looped versions to remove the 64K vocab size limitation. It turns out the hard 64K limit is gone in current triton, but the single-block kernels get much slower for larger vocabularies (above 32K, actually), so the looped versions are still the better choice; a sketch of the looped pass follows this list.
  • Add tensor-parallel support for all losses
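
To make the "looped" idea concrete, here is a minimal sketch of a single-pass online-softmax loop over vocab blocks. This is illustrative only, not the PR's kernel; the name, signature, and block layout are made up.

```python
import triton
import triton.language as tl


@triton.jit
def looped_cross_entropy_fwd(  # hypothetical name, not the PR's kernel
    logits_ptr, labels_ptr, losses_ptr, n_cols, logits_stride, BLOCK_SIZE: tl.constexpr
):
    # One program per token; loop over the vocab in BLOCK_SIZE chunks so the
    # kernel is not limited by the maximum single-block size.
    row = tl.program_id(0)
    row_ptr = logits_ptr + row * logits_stride
    label = tl.load(labels_ptr + row)

    # Online softmax: track the running max and the rescaled sum of exponentials.
    row_max = float("-inf")
    sum_exp = 0.0
    for start in range(0, n_cols, BLOCK_SIZE):
        offsets = start + tl.arange(0, BLOCK_SIZE)
        block = tl.load(row_ptr + offsets, mask=offsets < n_cols, other=float("-inf")).to(tl.float32)
        new_max = tl.maximum(row_max, tl.max(block))
        sum_exp = sum_exp * tl.exp(row_max - new_max) + tl.sum(tl.exp(block - new_max))
        row_max = new_max

    # loss = logsumexp(logits) - logits[label]
    label_logit = tl.load(row_ptr + label).to(tl.float32)
    tl.store(losses_ptr + row, tl.log(sum_exp) + row_max - label_logit)
```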

There shouldn't be any reason to use the fused implementations anymore, and to facilitate the transition I removed the parameter altogether so existing configs switch to triton automatically. I added the `use_triton` parameter in case we do need to disable triton at some point.

For a single loss, the implementation should be optimal with respect to reads/writes of the logits and their gradients, which are the bottleneck for these losses. However, it is still sub-optimal with multiple losses because of redundant computations and sub-optimal gradient accumulation; some of this could be addressed in a follow-up PR.
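
For reference, here is a minimal PyTorch sketch of the standard Megatron-style vocab-parallel reduction pattern (not the PR's triton code; the function name and arguments are made up). The point is that each rank only touches its own logit shard and the cross-rank traffic is a few per-token scalars, which is consistent with the tensor-parallel memory numbers in the benchmark below.

```python
import torch
import torch.distributed as dist


def vocab_parallel_cross_entropy(local_logits, labels, vocab_start, group=None):
    # local_logits: (tokens, local_vocab) shard, labels: (tokens,) with global vocab ids.
    # Global max over the sharded vocab dimension.
    global_max = local_logits.max(dim=-1).values
    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=group)

    # Sum of exponentials over the local shard, then summed across ranks.
    sum_exp = torch.exp(local_logits - global_max.unsqueeze(-1)).sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=group)

    # Target logit: only the rank owning the label contributes, the others add zero.
    local_labels = labels - vocab_start
    in_shard = (local_labels >= 0) & (local_labels < local_logits.size(-1))
    safe_idx = local_labels.clamp(0, local_logits.size(-1) - 1).unsqueeze(-1)
    target_logit = torch.where(
        in_shard,
        local_logits.gather(-1, safe_idx).squeeze(-1),
        torch.zeros_like(global_max),
    )
    dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=group)

    # loss = logsumexp(logits) - logits[label], per token.
    return torch.log(sum_exp) + global_max - target_logit
```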

As an example, here is a benchmark I ran for cross-entropy from labels (8K tokens, CUDA time + estimated memory usage):

| Setup | Implementation | CUDA time | Est. memory |
| --- | --- | --- | --- |
| Single GPU, vocab 10K | fused | 0.348 ms | 492.078 MB |
| Single GPU, vocab 10K | triton | 0.169 ms | 163.873 MB |
| Single GPU, vocab 100K | fused | 4.241 ms | 4915.233 MB |
| Single GPU, vocab 100K | triton | 1.709 ms | 1638.433 MB |
| 2 GPUs, vocab 10K | fused | 1.108 ms | 655.606 MB |
| 2 GPUs, vocab 10K | triton | 0.198 ms | 82.084 MB |
| 2 GPUs, vocab 100K | fused | 9.569 ms | 6553.846 MB |
| 2 GPUs, vocab 100K | triton | 0.996 ms | 819.364 MB |

Also, I found out about triton's interpreter mode, which allows testing triton kernels on CPU. I adjusted the triton tests to support it, so almost every test can now be run without GPU access (only the distributed and megatron model tests remain). They still won't run on GitHub CI though, because installing triton there would be a bit difficult.
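
For anyone trying this locally: interpreter mode is enabled through an environment variable, which should be set before triton is imported (the test-suite wiring in this PR may differ):

```python
# Run triton kernels on CPU via the interpreter; set the variable before
# `import triton` (or at least before the first kernel is compiled).
import os
os.environ["TRITON_INTERPRET"] = "1"

import triton  # kernels launched after this point run interpreted on CPU
```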

@jlamypoirier jlamypoirier changed the title from "Fix triton cross-entropy for large vocab sizes" to "Fix triton cross-entropy for large vocab sizes, support tensor-parallel" on Feb 2, 2026
@jlamypoirier jlamypoirier marked this pull request as ready for review February 6, 2026 10:46