⚡ Thunderbolt: softmax — Single FMA Range Reduction #40
…ingle FMA

Added `softmax_v6` utilizing `exp256_ps_v3`, which combines the `r = x - n * ln(2)` range reduction step into a single `_mm256_fnmadd_ps` instruction. Precision loss from avoiding the high/low split is acceptable within the `1e-4` tolerance due to softmax's shift-invariance. This reduces instruction latency and dependency chain size on the critical path. Measured an improvement from 3.73 GFLOP/s to 3.97 GFLOP/s for N=1048576 arrays.

Co-authored-by: bugparty <1510776+bugparty@users.noreply.github.com>
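To make the trade-off concrete, here is a minimal scalar sketch of the two reduction styles, using `std::fma` in place of `_mm256_fnmadd_ps`. The function names and the high/low constants (the common Cephes-style split) are assumptions for illustration; the actual constants in `exp256_ps_v2` may differ.

```cpp
#include <algorithm>
#include <cmath>

// Illustrative constants; only kLn2 appears in the PR text.
constexpr float kLog2e = 1.44269504088896341f;
constexpr float kLn2   = 0.6931471805599453f;  // single rounded ln(2)
constexpr float kLn2Hi = 0.693359375f;         // assumed Cephes-style high part
constexpr float kLn2Lo = -2.12194440e-4f;      // assumed Cephes-style low part

// r = x - n*ln2 in one fused step (scalar analogue of _mm256_fnmadd_ps(n, ln2, x)).
inline float reduce_single(float x, float n) {
    return std::fma(-n, kLn2, x);
}

// r = (x - n*ln2_hi) - n*ln2_lo: two fused steps, the second waits on the first.
inline float reduce_split(float x, float n) {
    const float r = std::fma(-n, kLn2Hi, x);
    return std::fma(-n, kLn2Lo, r);
}

// Largest |single - split| over x in [lo, hi], sampled every `step`.
inline float max_reduction_gap(float lo, float hi, float step) {
    float worst = 0.0f;
    for (float x = lo; x <= hi; x += step) {
        const float n = std::nearbyint(x * kLog2e);  // nearest multiple of ln(2)
        worst = std::max(worst, std::fabs(reduce_single(x, n) - reduce_split(x, n)));
    }
    return worst;
}
```

Calling `max_reduction_gap(-80.0f, 0.0f, 0.37f)` gives a rough feel for how far the two reductions drift apart on the non-positive inputs softmax produces after max subtraction. The gap grows with `|n|` because the rounding error in the single `ln(2)` constant is multiplied by `n`.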
📝 Walkthrough
A new AVX2 softmax kernel (
Changes
Softmax v6 optimization
Sequence Diagram
sequenceDiagram
participant Input as Input buffer
participant VectorMax as Vector max<br/>(4 accumulators)
participant Reduce_Max as reduce_max
participant VectorExp as exp256_ps_v3<br/>(4 accumulators)
participant Reduce_Sum as reduce_sum
participant ScalarExp as std::exp<br/>(tail)
participant Normalize as Normalize<br/>(reciprocal · output)
participant Output as Output buffer
Input->>VectorMax: load 32 elements per iteration
VectorMax->>Reduce_Max: reduce four accumulators
Input->>VectorExp: compute exp for 32 elements per iteration
VectorExp->>Output: store exponent results
VectorExp->>Reduce_Sum: accumulate four sums
Reduce_Sum->>ScalarExp: reduce to scalar, process tail
ScalarExp->>Normalize: final sum
Normalize->>Output: multiply by 1/sum
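The flow in the diagram can be sketched in scalar form as follows. This is not the kernel from the PR: `softmax_sketch` is a hypothetical name, each of the four accumulators covers one element rather than one 8-lane vector, and `std::exp` stands in for `exp256_ps_v3`.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>

// Scalar stand-in for the three stages: max, exp + sum, normalize.
inline void softmax_sketch(const float* in, float* out, std::size_t n) {
    if (n == 0) return;

    // Stage 1: four independent max accumulators, reduced, then a tail loop.
    float m0 = in[0], m1 = in[0], m2 = in[0], m3 = in[0];
    std::size_t i = 0;
    for (; i + 4 <= n; i += 4) {
        m0 = std::max(m0, in[i]);
        m1 = std::max(m1, in[i + 1]);
        m2 = std::max(m2, in[i + 2]);
        m3 = std::max(m3, in[i + 3]);
    }
    float mx = std::max(std::max(m0, m1), std::max(m2, m3));
    for (; i < n; ++i) mx = std::max(mx, in[i]);

    // Stage 2: store exp(x - max) and accumulate four partial sums.
    float s0 = 0.0f, s1 = 0.0f, s2 = 0.0f, s3 = 0.0f;
    for (i = 0; i + 4 <= n; i += 4) {
        out[i]     = std::exp(in[i]     - mx); s0 += out[i];
        out[i + 1] = std::exp(in[i + 1] - mx); s1 += out[i + 1];
        out[i + 2] = std::exp(in[i + 2] - mx); s2 += out[i + 2];
        out[i + 3] = std::exp(in[i + 3] - mx); s3 += out[i + 3];
    }
    float sum = (s0 + s1) + (s2 + s3);
    for (; i < n; ++i) {  // scalar tail
        out[i] = std::exp(in[i] - mx);
        sum += out[i];
    }

    // Stage 3: multiply every stored exponential by the reciprocal of the sum.
    const float inv = 1.0f / sum;
    for (i = 0; i < n; ++i) out[i] *= inv;
}
```

Keeping several accumulators independent is what lets the real kernel overlap the latency of consecutive max and add instructions; the scalar version only mirrors that structure.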
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~22 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
🧹 Nitpick comments (1)
ml_kernels/src/test_naive_ops.cpp (1)
184-211: ⚡ Quick win
Consider adding edge-case coverage for scalar tail and 8-element remainder paths.
The test input has exactly 32 elements, which exercises only the main 32-element unrolled loop. The 8-element loop (lines 235-240 in softmax.h) and scalar tail (lines 243-247) are not exercised. Adding inputs with sizes like 33 or 41 elements would improve coverage.
This matches the existing test pattern for v3–v5, so it's optional to address now.
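A size-parameterized check along these lines could cover the extra paths. It is only a sketch: `softmax_ref` stands in for `ml_kernels::softmax_naive`, `check_softmax_size` is a hypothetical helper, and the `(const float*, float*, std::size_t)` signature is an assumption about how the kernels are called.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Reference softmax; a stand-in for ml_kernels::softmax_naive.
inline void softmax_ref(const float* in, float* out, std::size_t n) {
    float mx = in[0];
    for (std::size_t i = 1; i < n; ++i) mx = in[i] > mx ? in[i] : mx;
    float sum = 0.0f;
    for (std::size_t i = 0; i < n; ++i) {
        out[i] = std::exp(in[i] - mx);
        sum += out[i];
    }
    for (std::size_t i = 0; i < n; ++i) out[i] /= sum;
}

// Assumed kernel signature; adjust to match the real softmax_v6 declaration.
using SoftmaxFn = void (*)(const float*, float*, std::size_t);

// True when `kernel` agrees with the reference within `tol` for an input of
// length n (n > 0) and its output sums to 1 within the same tolerance.
inline bool check_softmax_size(SoftmaxFn kernel, std::size_t n, float tol = 1e-4f) {
    std::vector<float> in(n), want(n), got(n);
    for (std::size_t i = 0; i < n; ++i) {
        in[i] = 0.1f * static_cast<float>(i % 17) - 0.8f;  // deterministic input
    }
    softmax_ref(in.data(), want.data(), n);
    kernel(in.data(), got.data(), n);
    float sum = 0.0f;
    for (std::size_t i = 0; i < n; ++i) {
        if (std::fabs(got[i] - want[i]) > tol) return false;
        sum += got[i];
    }
    return std::fabs(sum - 1.0f) <= tol;
}
```

With the real kernel substituted for the function pointer, sizes 32, 33, and 41 would exercise the main loop alone, the main loop plus scalar tail, and the main loop plus one 8-element step plus scalar tail, respectively.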
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@ml_kernels/src/test_naive_ops.cpp` around lines 184 - 211, The test test_softmax_v6 only covers a 32-element input and misses the 8-element remainder and scalar tail paths in softmax_v6; add additional cases in test_softmax_v6 (or new tests) that call ml_kernels::softmax_v6 with input sizes that exercise the 8-element remainder and scalar tail (e.g., 33 and 41 elements) and validate output equality with ml_kernels::softmax_naive and that the result sums to 1.0f, similar to the existing assertions.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 92d9ca68-3b58-4119-a8a6-36ad72541248
📒 Files selected for processing (4)
.jules/thunderbolt.md
ml_kernels/include/ml_kernels/softmax.h
ml_kernels/src/kernel_bench.cpp
ml_kernels/src/test_naive_ops.cpp
💡 What: Added `softmax_v6`, which uses a new `exp256_ps_v3` evaluation function. It replaces the two-part range reduction `r = x - n * ln2_hi - n * ln2_lo` with a single `_mm256_fnmadd_ps`.

🎯 Why: Splitting the range-reduction constant into high and low parts costs execution cycles and registers. Because the softmax kernel subtracts the maximum array element from every input before exponentiating, the small loss of precision in `exp256_ps_v3` produces only negligible shifts in the softmax output.

🏗️ How: `exp256_ps_v3` was introduced alongside `softmax_v6`. It computes `r = _mm256_fnmadd_ps(n, _mm256_set1_ps(0.6931471805599453f), x);` in place of the two separate `fnmadd_ps` operations in `exp256_ps_v2`.

📊 Impact: Benchmarks on 1M-element arrays (N=1048576) show throughput rising from 3.73 GFLOP/s to 3.97 GFLOP/s, an improvement of roughly 6.4%. Accuracy tests show no values deviating beyond the `1e-4` tolerance.

🖥️ Tested on: Linux sandbox, AVX2-capable Haswell microarchitecture.

🔬 How to reproduce:
`DISABLE_CPU_BINDING=1 ./build/ml_kernels/ml_kernel_bench --filter 'softmax' --iters 500 --warmup 50 --sizes 1048576`

PR created automatically by Jules for task 17482068372732358033 started by @bugparty