From 5d9ab823e8c29dfeb196a724109d82b6771293d9 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 1 Jan 2026 19:00:54 -0400 Subject: [PATCH] Faster matmul --- src/host/linalg.jl | 81 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index bc599968..c5f3472d 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -436,11 +436,92 @@ function LinearAlgebra.ldiv!(B::AbstractGPUVecOrMat, B end +# XXX: figure out how to do dynamically +MAX_TILE_DIM = 16 ## matrix multiplication # legacy method generic_matmatmul!(C::AbstractArray, A::AbstractArray, B::AbstractArray, a::Number, b::Number) = generic_matmatmul!(C, A, B, MulAddMul(a, b)) +function generic_matmatmul!(C::AbstractGPUMatrix{R}, A::AbstractGPUMatrix{T}, B::AbstractGPUMatrix{S}, add::MulAddMul) where {T<:Number,S<:Number,R<:Number} + N = size(A,1) + Q = size(A,2) + M = size(B,2) + if Q != size(B,1) + throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))")) + end + if size(C,1) != N || size(C,2) != M + throw(DimensionMismatch("result C has dimensions $(size(C)), needs $((N,M))")) + end + if isempty(A) || isempty(B) + return fill!(C, zero(R)) + end + + @kernel unsafe_indices=true function coalesced_matmul_kernel!( + output, @Const(input1), @Const(input2), N, Q, M, + ::Val{BANK} = Val(1), + ) where {BANK} + grow, gcol = @index(Group, NTuple) + tile_row, tile_col = @index(Local, NTuple) + + TILE_DIM = @uniform @groupsize()[1] + + # +1 to avoid bank conflicts on shared memory + tile1 = @localmem(R, (TILE_DIM + BANK, TILE_DIM)) + tile2 = @localmem(R, (TILE_DIM + BANK, TILE_DIM)) + + # private variable for tile output + outval = @private R 1 + @inbounds outval[1] = -zero(R) + + # number of tiles depends on inner dimension + @uniform NUM_TILES = div(Q + TILE_DIM - 1, TILE_DIM) + + # loop over all tiles needed for this calculation + for t in 0:(NUM_TILES - 1) + I = (grow - 1) * TILE_DIM + tile_row + J = (gcol - 1) * TILE_DIM + tile_col + + # load inputs into tiles, with bounds checking for non-square matrices + if I <= N && t * TILE_DIM + tile_col <= Q + @inbounds tile1[tile_row, tile_col] = input1[I, t * TILE_DIM + tile_col] + else + @inbounds tile1[tile_row, tile_col] = zero(R) + end + if J <= M && t * TILE_DIM + tile_row <= Q + @inbounds tile2[tile_row, tile_col] = input2[t * TILE_DIM + tile_row, J] + else + @inbounds tile2[tile_row, tile_col] = zero(R) + end + + # wait for all tiles to be loaded + @synchronize + + I = (grow - 1) * TILE_DIM + tile_row + J = (gcol - 1) * TILE_DIM + tile_col + + # calculate value of spot in output, use temporary value to allow for vectorization + out = zero(R) + @simd for k in 1:TILE_DIM + @inbounds out += tile1[tile_row, k] * tile2[k, tile_col] + end + outval[1] += out + + @synchronize + end + + I = (grow - 1) * TILE_DIM + tile_row + J = (gcol - 1) * TILE_DIM + tile_col + + # save if inbounds + if I <= N && J <= M + @inbounds output[I, J] = add(outval[1], output[I, J]) + end + end + + coalesced_matmul_kernel!(get_backend(C), (MAX_TILE_DIM, MAX_TILE_DIM))(C, A, B, N, Q, M;ndrange=map(x -> ceil(Int,x/MAX_TILE_DIM)*MAX_TILE_DIM, size(C))) + C +end function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::AbstractArray{S}, add::MulAddMul) where {T,S,R} if size(A,2) != size(B,1) throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))