diff --git a/docs/src/darray.md b/docs/src/darray.md
index 34dadb333..3b30289a0 100644
--- a/docs/src/darray.md
+++ b/docs/src/darray.md
@@ -693,12 +693,13 @@ From `Statistics`:
 - `std`
 
 From `LinearAlgebra`:
+- `norm`
 - `transpose`/`adjoint` (Out-of-place transpose)
 - `*` (Out-of-place Matrix-(Matrix/Vector) multiply)
-- `mul!` (In-place Matrix-Matrix multiply)
+- `mul!` (In-place Matrix-Matrix and Matrix-Vector multiply)
 - `cholesky`/`cholesky!` (In-place/Out-of-place Cholesky factorization)
 - `lu`/`lu!` (In-place/Out-of-place LU factorization (`NoPivot` only))
 
 From `AbstractFFTs`:
 - `fft`/`fft!`
-- `ifft`/`ifft!`
\ No newline at end of file
+- `ifft`/`ifft!`
diff --git a/src/array/linalg.jl b/src/array/linalg.jl
index c0ed52fae..bed0848d2 100644
--- a/src/array/linalg.jl
+++ b/src/array/linalg.jl
@@ -1,6 +1,8 @@
-function LinearAlgebra.norm2(A::DArray{T,2}) where T
+# 2-norm of a `DArray` of any rank: one task per chunk reduces that chunk with
+# `LinearAlgebra.norm_sqr` and `+`; the partial sums are fetched, summed and rooted.
+function LinearAlgebra.norm2(A::DArray{T,N}) where {T,N}
     Ac = A.chunks
-    norms = [Dagger.@spawn mapreduce(LinearAlgebra.norm_sqr, +, chunk) for chunk in Ac]::Matrix{DTask}
+    norms = [Dagger.@spawn mapreduce(LinearAlgebra.norm_sqr, +, chunk) for chunk in Ac]::Array{DTask,N}
     zeroRT = zero(real(T))
     return sqrt(sum(map(norm->fetch(norm)::real(T), norms); init=zeroRT))
 end
diff --git a/src/array/mul.jl b/src/array/mul.jl
index 44e5ab948..18ce58bd1 100644
--- a/src/array/mul.jl
+++ b/src/array/mul.jl
@@ -405,3 +405,109 @@ end
         A[i, j] = C[i, j]
     end
 end
+# Distributed matrix-vector multiply on `DArray` operands.
+# Picks block sizes with `_repartition_matvecmul`, passes the operands and
+# those block sizes to `maybe_copy_buffered`, and runs `gemv_dagger!` on the
+# arrays it yields. Returns whatever `maybe_copy_buffered` returns.
+function LinearAlgebra.generic_matvecmul!(
+    C::DVector{T},
+    transA::Char,
+    A::DMatrix{T},
+    B::DVector{T},
+    _add::LinearAlgebra.MulAddMul,
+) where {T}
+    partC, partA, partB = _repartition_matvecmul(C, A, B, transA)
+    return maybe_copy_buffered(C=>partC, A=>partA, B=>partB) do C, A, B
+        return gemv_dagger!(C, transA, A, B, _add)
+    end
+end
+
+# Choose block sizes `(partC, partA, partB)` for a matrix-vector multiply.
+# A `transA` of 'T' or 'C' means `A` is used transposed, so its first block
+# dimension is the inner (reduction) one; otherwise its second is. `C` is
+# currently unused: its block size is always taken from `A`'s outer dimension.
+function _repartition_matvecmul(C, A, B, transA::Char)
+    partA = A.partitioning.blocksize
+    partB = B.partitioning.blocksize
+    istransA = transA == 'T' || transA == 'C'
+    outerA = istransA ? partA[2] : partA[1]
+    innerA = istransA ? partA[1] : partA[2]
+
+    # If the inner block size of A and the block size of B don't match, fix
+    # both of them. Uses the smallest blocking of all dimensions.
+    if innerA != partB[1]
+        sz = minimum((partA[1], partA[2], partB[1]))
+        partA = istransA ? (sz, partA[2]) : (partA[1], sz)
+        # B must be re-blocked too, otherwise its blocks would not line up with A's
+        partB = (sz,)
+    end
+    partC = (outerA,)
+    return Blocks(partC...), Blocks(partA...), Blocks(partB...)
+end
+
+# Blocked GEMV: issues one `BLAS.gemv!` task per pair of (A block, B block)
+# inside a `spawn_datadeps` region, accumulating into the blocks of `C`.
+# `transA == 'N'` uses `A` as stored; any other value indexes `A`'s block grid
+# transposed and forwards `transA` to BLAS. Throws `DimensionMismatch` when the
+# block grids of `A`, `B` and `C` are incompatible. Returns `C`.
+function gemv_dagger!(
+    C::DVector{T},
+    transA::Char,
+    A::DMatrix{T},
+    B::DVector{T},
+    _add::LinearAlgebra.MulAddMul,
+) where {T}
+    Ac = A.chunks
+    Bc = B.chunks
+    Cc = C.chunks
+    Amt, Ant = size(Ac)
+    Bmt = size(Bc)[1]
+    Cmt = size(Cc)[1]
+
+    alpha = T(_add.alpha)
+    beta = T(_add.beta)
+
+    # Block grid of op(A): its rows pair with C's blocks, its columns with B's
+    opAmt, opAnt = transA == 'N' ? (Amt, Ant) : (Ant, Amt)
+    if opAnt != Bmt
+        throw(DimensionMismatch(lazy"A has number of blocks ($Amt,$Ant) but B has number of blocks ($Bmt)"))
+    end
+    if opAmt != Cmt
+        throw(DimensionMismatch(lazy"A has number of blocks ($Amt,$Ant) but C has number of blocks ($Cmt)"))
+    end
+
+    Dagger.spawn_datadeps() do
+        for m in range(1, Cmt)
+            if transA == 'N'
+                # A: NoTrans
+                for k in range(1, Ant)
+                    # Only the first partial product scales C by beta
+                    mzone = k == 1 ? beta : one(T)
+                    Dagger.@spawn BLAS.gemv!(
+                        transA,
+                        alpha,
+                        In(Ac[m, k]),
+                        In(Bc[k]),
+                        mzone,
+                        InOut(Cc[m]),
+                    )
+                end
+            else
+                # A: [Conj]Trans
+                for k in range(1, Amt)
+                    mzone = k == 1 ? beta : one(T)
+                    Dagger.@spawn BLAS.gemv!(
+                        transA,
+                        alpha,
+                        In(Ac[k, m]),
+                        In(Bc[k]),
+                        mzone,
+                        InOut(Cc[m]),
+                    )
+                end
+            end
+        end
+    end
+
+    return C
+end
diff --git a/test/array/linalg/core.jl b/test/array/linalg/core.jl
index 87c3d3d71..4c66ac34a 100644
--- a/test/array/linalg/core.jl
+++ b/test/array/linalg/core.jl
@@ -9,3 +9,17 @@
     L2 = LowerTriangular(DArray(A, Blocks(16, 16)))
     @test isapprox(L1, L2)
 end
+
+@testset "norm" begin
+    A = rand(16, 16)
+    DA = DArray(A)
+    @test isapprox(norm(A), norm(DA))
+
+    A = rand(16)
+    DA = DArray(A)
+    @test isapprox(norm(A), norm(DA))
+
+    A = rand(16, 16, 16)
+    DA = DArray(A)
+    @test isapprox(norm(A), norm(DA))
+end
diff --git a/test/array/linalg/matmul.jl b/test/array/linalg/matmul.jl
index e15f74329..ebc9d8ddc 100644
--- a/test/array/linalg/matmul.jl
+++ b/test/array/linalg/matmul.jl
@@ -136,10 +136,95 @@ part_sets_to_test = map(_sizes_to_test) do sz
     ]
 end
 parts_to_test = vcat(part_sets_to_test...)
-@testset "Size=$szA*$szB" for (szA, szB) in sizes_to_test
-    @testset "Partitioning=$partA*$partB" for (partA,partB) in parts_to_test
-        @testset "T=$T" for T in (Float32, Float64, ComplexF32, ComplexF64)
-            test_gemm!(T, szA, szB, partA, partB)
+@testset "GEMM" begin
+    @testset "Size=$szA*$szB" for (szA, szB) in sizes_to_test
+        @testset "Partitioning=$partA*$partB" for (partA,partB) in parts_to_test
+            @testset "T=$T" for T in (Float32, Float64, ComplexF32, ComplexF64)
+                test_gemm!(T, szA, szB, partA, partB)
+            end
+        end
+    end
+end
+
+function test_gemv!(T, szA, szB, partA, partB)
+    @assert szA[2] == szB[1]
+    szC = (szA[1],)
+    @assert partA.blocksize[2] == partB.blocksize[1]
+    partC = Blocks(partA.blocksize[1],)
+
+    A = rand(T, szA...)
+    B = rand(T, szB...)
+
+    DA = distribute(A, partA)
+    DB = distribute(B, partB)
+
+    ## Out-of-place gemv
+    # No transA
+    DC = DA * DB
+    C = A * B
+    @test collect(DC) ≈ C
+
+    if szA[1] == szB[1]
+        # transA
+        DC = DA' * DB
+        C = A' * B
+        @test collect(DC) ≈ C
+    end
+
+    ## In-place gemv
+    # No transA
+    C = zeros(T, szC...)
+    DC = distribute(C, partC)
+    mul!(C, A, B)
+    mul!(DC, DA, DB)
+    @test collect(DC) ≈ C
+
+    if szA[1] == szB[1]
+        # transA
+        C = zeros(T, szC...)
+        DC = distribute(C, partC)
+        mul!(C, A', B)
+        mul!(DC, DA', DB)
+        @test collect(DC) ≈ C
+    end
+end
+
+_sizes_to_test = [
+    (4, 4),
+    (7, 7),
+    (12, 12),
+    (16, 16),
+]
+size_sets_to_test = map(_sizes_to_test) do sz
+    rows, cols = sz
+    return [
+        (rows, cols) => (cols,),
+        (rows, cols ÷ 2) => (cols ÷ 2,),
+    ]
+end
+sizes_to_test = vcat(size_sets_to_test...)
+part_sets_to_test = map(_sizes_to_test) do sz
+    rows, cols = sz
+    return [
+        Blocks(rows, cols) => Blocks(cols,),
+        Blocks(rows, cols ÷ 2) => Blocks(cols ÷ 2,),
+    ]
+end
+parts_to_test = vcat(part_sets_to_test...)
+@testset "GEMV" begin
+    @testset "Size=$szA*$szB" for (szA, szB) in sizes_to_test
+        @testset "Partitioning=$partA*$partB" for (partA,partB) in parts_to_test
+            @testset "T=$T" for T in (Float32, Float64, ComplexF32, ComplexF64)
+                test_gemv!(T, szA, szB, partA, partB)
+            end
         end
     end
+    @testset "Mismatched partitioning" begin
+        # A's inner block size (4) differs from B's block size (8)
+        A = rand(8, 8)
+        B = rand(8)
+        DA = distribute(A, Blocks(4, 4))
+        DB = distribute(B, Blocks(8))
+        @test collect(DA * DB) ≈ A * B
+    end
 end