Skip to content

Conversation

@jonahsamost
Copy link

We optimize this kernel with a sort of checkpointing (see CHECKPOINT_INTERVAL in the code). It reduces intermediate buffer writes/reads by 4x. The backward recomputes values between checkpoints instead of reading them. The forward only writes every CHECKPOINT_INTERVAL amount of times. Also use fast math intrinsics where we can. End up with 2.3x forward and 1.03x backward speedups (about ~1.3x overall).

Had to fix some build issues i had as well.

Tested on 5090

fused_scan (N=4194304, 512x64x128, combined=512x64x384)
  	forward             74.0 us  442.54 M elem/s
  	backward           106.1 us  308.80 M elem/s
  	forward (checkpointed)   31.6 us  1038.33 M elem/s
  	backward (checkpointed)  103.4 us  316.86 M elem/s
  checkpointed forward correctness: out=ok(3.29e-05) next_state=ok(3.58e-05)
  checkpointed backward correctness: grad_combined=ok(5.43e-05) grad_state=ok(4.77e-07)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant