
feat: native bf16 GPU elementwise + AdamW kernels (ADR 075 L4) #152

Merged
dndungu merged 9 commits into main from feat/bf16-gpu-elementwise-adamw
Jun 16, 2026


@dndungu dndungu commented Jun 16, 2026


Summary

Adds the native bfloat16 GPU compute path so a full bf16 graph (matmul + elementwise + optimizer) runs on-device instead of falling back to CPU for every !isFloat32[T]() op. bf16 batched GEMM already existed (BFloat16Gemm); this fills in elementwise + AdamW.

Kernels (internal/cuda/kernels/)

  • elementwise_bf16.cu: Add/Sub/Mul/Div (__nv_bfloat162 SIMD), Tanh/Sqrt/Exp/Log (FP32 transcendental, no-fast-math convention), scaled-softmax (FP32 max/sum accumulation), F32<->BF16 conversion.
  • fused_adamw_bf16.cu: on-device AdamW on bf16 params with f32 first-moment + f64 second-moment sidecars — full-precision update, bf16 write-back (standard bf16-weights/high-precision-state recipe). Same bit-pattern scalar ABI as fused_adamw_f32.
  • purego bindings, symbol table, optionalSyms (missing-symbol non-fatal), Makefile SRCS.

Plumbing

  • gpuapi.KernelRunner: bf16 methods (CUDA real impl; FPGA/Metal/ROCm/OpenCL/SYCL not-implemented stubs).
  • compute: isBFloat16 guard + gpuBinaryOpBF16/gpuUnaryOpBF16/gpuSoftmaxBF16 (2-byte elements; getDevicePtr/makeGPUResult are already element-size-generic). bf16 branches in gpuAdd/Sub/Mul/Div/Tanh/Sqrt/Exp/Log/Softmax and GPUFusedAdamW (shares the f32/f64 moment state).
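The type-guard dispatch described in this list can be sketched as follows. This is an illustration under assumptions: `BFloat16` is a local stand-in for float16.BFloat16, `route` and its string labels are hypothetical, and the real guards in the compute package may be implemented differently.

```go
package main

import "fmt"

// BFloat16 stands in for float16.BFloat16 (hypothetical local definition).
type BFloat16 uint16

// isBFloat16 reports whether the type parameter T is BFloat16.
func isBFloat16[T any]() bool {
	var zero T
	_, ok := any(zero).(BFloat16)
	return ok
}

// isFloat32 reports whether the type parameter T is float32.
func isFloat32[T any]() bool {
	var zero T
	_, ok := any(zero).(float32)
	return ok
}

// route names the path a binary op would take (labels are illustrative).
// Assumption: the f32 path does not depend on sameShape in this sketch.
func route[T any](sameShape bool) string {
	switch {
	case isFloat32[T]():
		return "gpu-f32"
	case isBFloat16[T]() && sameShape:
		return "gpu-bf16"
	default:
		// Broadcast bf16 ops and all other element types fall back to CPU.
		return "cpu-fallback"
	}
}

func main() {
	fmt.Println(route[float32](true))
	fmt.Println(route[BFloat16](true))
	fmt.Println(route[BFloat16](false))
	fmt.Println(route[float64](true))
}
```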

Universal quality gate (gpu_bf16_parity_test.go, CUDA-gated)

  • bf16-vs-f32 parity for every op (1–4 bf16 ulps), full AdamW-step parity vs f64 reference + in-place grad-zeroing, and a CPU-path-unchanged invariant.
  • bf16 shares f32's exponent range → does not reopen the ADR-072 forward-conditioning cliff; only the 7-bit mantissa differs (reductions/transcendentals accumulate in FP32).
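A parity bound expressed in bf16 ulps can be checked by comparing bit patterns. The sketch below is one possible way to do that, not the helper used in gpu_bf16_parity_test.go; `ordered`, `ulpDiff`, and `withinULPs` are hypothetical names.

```go
package main

import (
	"fmt"
	"math"
)

// f32ToBF16 rounds a float32 to bf16 bits (round to nearest, ties to even).
// NaN handling is omitted to keep the sketch short.
func f32ToBF16(f float32) uint16 {
	bits := math.Float32bits(f)
	bits += 0x7FFF + ((bits >> 16) & 1)
	return uint16(bits >> 16)
}

// ordered maps sign-magnitude bf16 bits onto a monotonic integer line,
// so +0 and -0 coincide and adjacent finite values differ by exactly 1.
func ordered(b uint16) int32 {
	if b&0x8000 != 0 {
		return -int32(b & 0x7FFF)
	}
	return int32(b)
}

// ulpDiff returns the distance between two bf16 values in units of bf16 ulps.
func ulpDiff(a, b uint16) int32 {
	d := ordered(a) - ordered(b)
	if d < 0 {
		d = -d
	}
	return d
}

// withinULPs reports whether every element of got is within maxULPs of want.
func withinULPs(got, want []uint16, maxULPs int32) bool {
	if len(got) != len(want) {
		return false
	}
	for i := range got {
		if ulpDiff(got[i], want[i]) > maxULPs {
			return false
		}
	}
	return true
}

func main() {
	// Reference values: tanh evaluated in float64, then rounded to bf16.
	inputs := []float64{-2, -0.5, 0.25, 1.5}
	want := make([]uint16, len(inputs))
	got := make([]uint16, len(inputs))
	for i, x := range inputs {
		want[i] = f32ToBF16(float32(math.Tanh(x)))
		got[i] = want[i] + 1 // simulated result that is off by one ulp
	}
	fmt.Println(withinULPs(got, want, 1), withinULPs(got, want, 0))
}
```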

Notes

  • Broadcast bf16 binary ops still fall back to CPU (same-shape on GPU). Generic mechanism, no Wolf-specific code.
  • GB10 (sm_121) GPU parity-test run pending the rebuilt libkernels.so.

dndungu added 4 commits June 16, 2026 12:55
Add native bfloat16 CUDA kernels mirroring the FP16 path:
- elementwise_bf16.cu: Add/Sub/Mul/Div (__nv_bfloat162 SIMD), Tanh/Sqrt/
  Exp/Log (FP32 transcendental), scaled_softmax (FP32 accum), F32<->BF16.
- fused_adamw_bf16.cu: on-device AdamW on bf16 params with f32 first-moment
  and f64 second-moment sidecars (full-precision update, bf16 write-back).
- purego bindings + symbol table + optionalSyms entries + Makefile SRCS.

bf16 shares f32's 8-bit exponent so this does not reopen the ADR-072
forward-conditioning cliff; reductions/transcendentals accumulate in FP32.

Add AddBF16/SubBF16/MulBF16/DivBF16, TanhBF16/SqrtBF16/ExpBF16/LogBF16,
ScaledSoftmaxBF16, F32ToBF16/BF16ToF32, and FusedAdamWBF16 to the
KernelRunner interface. CUDAKernels wires them to the kernels package;
FPGA/Metal/ROCm/OpenCL/SYCL backends return not-implemented stubs.

Route T=float16.BFloat16 to GPU bf16 kernels instead of CPU fallback:
- gpu_bf16.go: isBFloat16 guard + gpuBinaryOpBF16/gpuUnaryOpBF16/
  gpuSoftmaxBF16 helpers (2-byte element size; getDevicePtr/makeGPUResult
  are already element-size-generic).
- gpuAdd/Sub/Mul/Div/Tanh/Sqrt/Exp/Log/Softmax: bf16 branch (binary ops
  same-shape only; broadcast still CPU).
- GPUFusedAdamW: bf16 branch -> FusedAdamWBF16 (shared f32/f64 moment state).
- gpu_bf16_parity_test.go: CUDA-gated bf16-vs-f32 parity gate for every op
  + full AdamW step parity + CPU-path-unchanged invariant.

One-pod GB10 verification: clone bf16 branch, build libkernels.so for
sm_121 with host CUDA, run the CUDA-gated TestGPUBF16_* parity suite.
Token injected at submit time, never persisted. Memory limit per L-0005.
@dndungu dndungu force-pushed the feat/bf16-gpu-elementwise-adamw branch from ac8e2af to 67d05fa Compare June 16, 2026 19:55
dndungu added 5 commits June 16, 2026 15:18
Add bf16 variants of the three forward-only fused normalization kernels,
mirroring the f32 originals (fused_add_rmsnorm, fused_norm_add,
fused_qk_norm_rope):

  - fused_add_rmsnorm_bf16: sum = input+residual, normed = rmsnorm(sum)*weight
  - fused_norm_add_bf16:     output = rmsnorm(input)*weight + residual
  - fused_qk_norm_rope_bf16: per-head RMSNorm + RoPE on combined Q+K heads

Params/weights/outputs are __nv_bfloat16 (2 bytes); all sum-of-squares
reductions and the normalization/RoPE arithmetic accumulate in FP32 and
round to bf16 only on store (no fast-math), so each result equals
round_to_bf16 of the FP32-accurate value -- the bf16-vs-f32 parity oracle.
bf16 shares f32's 8-bit exponent so this does not reopen the ADR-072 cliff.
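The "accumulate in FP32, round to bf16 only on store" rule can be shown with a single-row host reference for fused_add_rmsnorm_bf16. This is a sketch under assumptions, not the kernel: it assumes the normalization reads the unrounded FP32 sum, that `eps` is added to the mean of squares, and the name `fusedAddRMSNorm` is hypothetical.

```go
package main

import (
	"fmt"
	"math"
)

func f32ToBF16(f float32) uint16 {
	bits := math.Float32bits(f)
	bits += 0x7FFF + ((bits >> 16) & 1) // round to nearest, ties to even
	return uint16(bits >> 16)
}

func bf16ToF32(b uint16) float32 { return math.Float32frombits(uint32(b) << 16) }

// fusedAddRMSNorm computes, for one row of bf16 values:
//   sum    = input + residual
//   normed = sum / sqrt(mean(sum^2) + eps) * weight
// All arithmetic is float32; bf16 rounding happens only when storing.
func fusedAddRMSNorm(input, residual, weight []uint16, eps float32) (sum, normed []uint16) {
	n := len(input)
	sum = make([]uint16, n)
	normed = make([]uint16, n)
	acc := make([]float32, n) // unrounded FP32 sums (assumed to feed the norm)
	var ss float32            // FP32 sum of squares
	for i := range input {
		s := bf16ToF32(input[i]) + bf16ToF32(residual[i])
		acc[i] = s
		ss += s * s
		sum[i] = f32ToBF16(s)
	}
	inv := 1 / float32(math.Sqrt(float64(ss/float32(n)+eps)))
	for i := range acc {
		normed[i] = f32ToBF16(acc[i] * inv * bf16ToF32(weight[i]))
	}
	return sum, normed
}

func main() {
	one := f32ToBF16(1)
	row := []uint16{one, one, one, one}
	sum, normed := fusedAddRMSNorm(row, row, row, 1e-6)
	fmt.Println(bf16ToF32(sum[0]), bf16ToF32(normed[0]))
}
```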

These are forward-only: the f32 originals have no backward kernel, so there
is no bf16 backward to mirror. The PatchTST fused encoder fwd/bwd is NOT
ported (large multi-kernel orchestrator; deferred pending GB10 verification).

Wired: extern "C" launchers, purego symbol fields + loader registrations +
optionalSyms entries, and the .cu added to the Makefile SRCS.

GPU verification PENDING -- not yet run on a GB10.

Add FusedAddRMSNormBF16, FusedNormAddBF16, and FusedQKNormRoPEBF16 to the
KernelRunner interface. CUDAKernels implements them by delegating to the new
bf16 launchers; all other backends (FPGA, Metal, ROCm, OpenCL, SYCL) and the
stubKernelRunner in the test get not-implemented stubs so the compile-time
KernelRunner assertions on every implementer keep passing.

GPU verification PENDING.

…DR 075 L4)

Route T=float16.BFloat16 through the new bf16 fused norm kernels:
GPUFusedAddRMSNorm, GPUFusedNormAdd, and GPUFusedQKNormRoPE each gain an
isBFloat16[T]() branch (mirroring the existing GPUFusedAdamW bf16 dispatch)
that calls a native bf16 helper in gpu_bf16.go. The helpers allocate 2-byte
buffers and reuse the element-size-generic getDevicePtr/makeGPUResult, so
the f32 path is untouched. RoPE cos/sin tensors are bf16 to match the
engine's generic tensor type T (upconverted to FP32 inside the kernel).

Add CUDA-gated parity tests (compute/gpu_bf16_fused_norm_parity_test.go)
checking each bf16 GPU op against an f64 reference rounded to bf16 (FP32
reductions -> a few bf16 ulps of slack), mirroring gpu_bf16_parity_test.go:
TestGPUBF16_FusedAddRMSNormParity, TestGPUBF16_FusedNormAddParity,
TestGPUBF16_FusedQKNormRoPEParity. They skip without CUDA and will run on
the GB10 verify pod.

GPU verification PENDING -- parity not yet confirmed on a GB10.

Note on scope: an audit of the Wolf CrossAsset GPU hot path found it uses
LayerNorm (not RMSNorm), GELU (not SwiGLU), no RoPE, and discrete SDPA, so
the only fused kernels it actually exercises on GPU are scaled_softmax and
AdamW (both already bf16). These three norm kernels are general framework
additions, not on the CrossAsset path.

… GB10-exposed)

The bf16 MatMul parity tests (gpu_bf16_matmul_test.go, unchanged from main)
asserted max relative error < 1e-3, but bf16's ~8-bit mantissa gives ~3.9e-3
relative precision; on GB10 these measured 1.2-2.4e-3 and failed. CI never
caught it because CUDA-gated tests skip on the no-GPU runner. 5e-3 is the
bf16-appropriate bound. Exposed while GB10-verifying the bf16 elementwise/
AdamW kernels (all of which passed).

dndungu commented Jun 16, 2026


GB10 verified (sm_121, 2026-06-16T22:43Z): rebuilt libkernels.so from this branch via Spark and ran the CUDA-gated compute suite on the NVIDIA GB10 — all green: TestGPUBF16_BinaryParity (Add/Sub/Mul/Div), UnaryParity (Tanh/Sqrt/Exp/Log), SoftmaxParity, AdamWParity, FusedAddRMSNorm/FusedNormAdd/FusedQKNormRoPE parity, MatMulBF16* (tolerance corrected to bf16-appropriate 5e-3 — pre-existing 1e-3 was too tight and CI never caught it since it skips GPU), and BF16CPUPathUnchanged. ok github.com/zerfoo/ztensor/compute 1.748s.

@dndungu dndungu merged commit a7477a5 into main Jun 16, 2026
1 check passed