feat: native bf16 GPU elementwise + AdamW kernels (ADR 075 L4)#152
Merged
Conversation
Add native bfloat16 CUDA kernels mirroring the FP16 path: - elementwise_bf16.cu: Add/Sub/Mul/Div (__nv_bfloat162 SIMD), Tanh/Sqrt/ Exp/Log (FP32 transcendental), scaled_softmax (FP32 accum), F32<->BF16. - fused_adamw_bf16.cu: on-device AdamW on bf16 params with f32 first-moment and f64 second-moment sidecars (full-precision update, bf16 write-back). - purego bindings + symbol table + optionalSyms entries + Makefile SRCS. bf16 shares f32's 8-bit exponent so this does not reopen the ADR-072 forward-conditioning cliff; reductions/transcendentals accumulate in FP32.
Add AddBF16/SubBF16/MulBF16/DivBF16, TanhBF16/SqrtBF16/ExpBF16/LogBF16, ScaledSoftmaxBF16, F32ToBF16/BF16ToF32, and FusedAdamWBF16 to the KernelRunner interface. CUDAKernels wires them to the kernels package; FPGA/Metal/ROCm/OpenCL/SYCL backends return not-implemented stubs.
Route T=float16.BFloat16 to GPU bf16 kernels instead of CPU fallback: - gpu_bf16.go: isBFloat16 guard + gpuBinaryOpBF16/gpuUnaryOpBF16/ gpuSoftmaxBF16 helpers (2-byte element size; getDevicePtr/makeGPUResult are already element-size-generic). - gpuAdd/Sub/Mul/Div/Tanh/Sqrt/Exp/Log/Softmax: bf16 branch (binary ops same-shape only; broadcast still CPU). - GPUFusedAdamW: bf16 branch -> FusedAdamWBF16 (shared f32/f64 moment state). - gpu_bf16_parity_test.go: CUDA-gated bf16-vs-f32 parity gate for every op + full AdamW step parity + CPU-path-unchanged invariant.
One-pod GB10 verification: clone bf16 branch, build libkernels.so for sm_121 with host CUDA, run the CUDA-gated TestGPUBF16_* parity suite. Token injected at submit time, never persisted. memory limit per L-0005.
ac8e2af to
67d05fa
Compare
Add bf16 variants of the three forward-only fused normalization kernels, mirroring the f32 originals (fused_add_rmsnorm, fused_norm_add, fused_qk_norm_rope): - fused_add_rmsnorm_bf16: sum = input+residual, normed = rmsnorm(sum)*weight - fused_norm_add_bf16: output = rmsnorm(input)*weight + residual - fused_qk_norm_rope_bf16: per-head RMSNorm + RoPE on combined Q+K heads Params/weights/outputs are __nv_bfloat16 (2 bytes); all sum-of-squares reductions and the normalization/RoPE arithmetic accumulate in FP32 and round to bf16 only on store (no fast-math), so each result equals round_to_bf16 of the FP32-accurate value -- the bf16-vs-f32 parity oracle. bf16 shares f32's 8-bit exponent so this does not reopen the ADR-072 cliff. These are forward-only: the f32 originals have no backward kernel, so there is no bf16 backward to mirror. The PatchTST fused encoder fwd/bwd is NOT ported (large multi-kernel orchestrator; deferred pending GB10 verification). Wired: extern "C" launchers, purego symbol fields + loader registrations + optionalSyms entries, and the .cu added to the Makefile SRCS. GPU verification PENDING -- not yet run on a GB10.
Add FusedAddRMSNormBF16, FusedNormAddBF16, and FusedQKNormRoPEBF16 to the KernelRunner interface. CUDAKernels implements them by delegating to the new bf16 launchers; all other backends (FPGA, Metal, ROCm, OpenCL, SYCL) and the stubKernelRunner in the test get not-implemented stubs so the compile-time KernelRunner assertions on every implementer keep passing. GPU verification PENDING.
…DR 075 L4) Route T=float16.BFloat16 through the new bf16 fused norm kernels: GPUFusedAddRMSNorm, GPUFusedNormAdd, and GPUFusedQKNormRoPE each gain an isBFloat16[T]() branch (mirroring the existing GPUFusedAdamW bf16 dispatch) that calls a native bf16 helper in gpu_bf16.go. The helpers allocate 2-byte buffers and reuse the element-size-generic getDevicePtr/makeGPUResult, so the f32 path is untouched. RoPE cos/sin tensors are bf16 to match the engine's generic tensor type T (upconverted to FP32 inside the kernel). Add CUDA-gated parity tests (compute/gpu_bf16_fused_norm_parity_test.go) checking each bf16 GPU op against an f64 reference rounded to bf16 (FP32 reductions -> a few bf16 ulps of slack), mirroring gpu_bf16_parity_test.go: TestGPUBF16_FusedAddRMSNormParity, TestGPUBF16_FusedNormAddParity, TestGPUBF16_FusedQKNormRoPEParity. They skip without CUDA and will run on the GB10 verify pod. GPU verification PENDING -- parity not yet confirmed on a GB10. Note on scope: an audit of the Wolf CrossAsset GPU hot path found it uses LayerNorm (not RMSNorm), GELU (not SwiGLU), no RoPE, and discrete SDPA, so the only fused kernels it actually exercises on GPU are scaled_softmax and AdamW (both already bf16). These three norm kernels are general framework additions, not on the CrossAsset path.
… GB10-exposed) The bf16 MatMul parity tests (gpu_bf16_matmul_test.go, unchanged from main) asserted max relative error < 1e-3, but bf16's ~8-bit mantissa gives ~3.9e-3 relative precision; on GB10 these measured 1.2-2.4e-3 and failed. CI never caught it because CUDA-gated tests skip on the no-GPU runner. 5e-3 is the bf16-appropriate bound. Exposed while GB10-verifying the bf16 elementwise/ AdamW kernels (all of which passed).
Contributor
Author
|
GB10 verified (sm_121, 2026-06-16T22:43Z): rebuilt libkernels.so from this branch via Spark and ran the CUDA-gated compute suite on the NVIDIA GB10 — all green: TestGPUBF16_BinaryParity (Add/Sub/Mul/Div), UnaryParity (Tanh/Sqrt/Exp/Log), SoftmaxParity, AdamWParity, FusedAddRMSNorm/FusedNormAdd/FusedQKNormRoPE parity, MatMulBF16* (tolerance corrected to bf16-appropriate 5e-3 — pre-existing 1e-3 was too tight and CI never caught it since it skips GPU), and BF16CPUPathUnchanged. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds the native bfloat16 GPU compute path so a full bf16 graph (matmul + elementwise + optimizer) runs on-device instead of falling back to CPU for every
!isFloat32[T]()op. bf16 batched GEMM already existed (BFloat16Gemm); this fills in elementwise + AdamW.Kernels (
internal/cuda/kernels/)elementwise_bf16.cu: Add/Sub/Mul/Div (__nv_bfloat162SIMD), Tanh/Sqrt/Exp/Log (FP32 transcendental, no-fast-math convention), scaled-softmax (FP32 max/sum accumulation), F32<->BF16 conversion.fused_adamw_bf16.cu: on-device AdamW on bf16 params with f32 first-moment + f64 second-moment sidecars — full-precision update, bf16 write-back (standard bf16-weights/high-precision-state recipe). Same bit-pattern scalar ABI asfused_adamw_f32.optionalSyms(missing-symbol non-fatal), Makefile SRCS.Plumbing
gpuapi.KernelRunner: bf16 methods (CUDA real impl; FPGA/Metal/ROCm/OpenCL/SYCL not-implemented stubs).compute:isBFloat16guard +gpuBinaryOpBF16/gpuUnaryOpBF16/gpuSoftmaxBF16(2-byte elements;getDevicePtr/makeGPUResultare already element-size-generic). bf16 branches in gpuAdd/Sub/Mul/Div/Tanh/Sqrt/Exp/Log/Softmax andGPUFusedAdamW(shares the f32/f64 moment state).Universal quality gate (
gpu_bf16_parity_test.go, CUDA-gated)Notes