Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions docs/src/darray.md
Original file line number Diff line number Diff line change
Expand Up @@ -693,12 +693,13 @@ From `Statistics`:
- `std`

From `LinearAlgebra`:
- `norm`
- `transpose`/`adjoint` (Out-of-place transpose)
- `*` (Out-of-place Matrix-(Matrix/Vector) multiply)
- `mul!` (In-place Matrix-Matrix and Matrix-Vector multiply)
- `cholesky`/`cholesky!` (Out-of-place/In-place Cholesky factorization)
- `lu`/`lu!` (Out-of-place/In-place LU factorization (`NoPivot` only))

From `AbstractFFTs`:
- `fft`/`fft!`
- `ifft`/`ifft!`
4 changes: 2 additions & 2 deletions src/array/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
    LinearAlgebra.norm2(A::DArray{T,N}) where {T,N}

Return the 2-norm of the distributed array `A`, for any dimensionality `N`.

One task is spawned per chunk to sum `LinearAlgebra.norm_sqr` over that chunk's
elements; the per-chunk partial sums are then fetched, added together and
square-rooted. The result has type `real(T)`.
"""
function LinearAlgebra.norm2(A::DArray{T,N}) where {T,N}
    Ac = A.chunks
    # The chunk grid has the same dimensionality as `A`, so the task array does too.
    norms = [Dagger.@spawn mapreduce(LinearAlgebra.norm_sqr, +, chunk) for chunk in Ac]::Array{DTask,N}
    # `init` keeps the reduction well-defined (and typed) when there are no chunks.
    zeroRT = zero(real(T))
    return sqrt(sum(map(norm->fetch(norm)::real(T), norms); init=zeroRT))
end
Expand Down
92 changes: 92 additions & 0 deletions src/array/mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -405,3 +405,95 @@ end
A[i, j] = C[i, j]
end
end
# Distributed matrix-vector multiply entry point used by `mul!`.
#
# Picks compatible block partitionings for `C`, `A` and `B`, hands the three
# arrays to `maybe_copy_buffered` together with those partitionings, and runs the
# blocked GEMV kernel on whatever arrays `maybe_copy_buffered` passes back.
function LinearAlgebra.generic_matvecmul!(
    C::DVector{T},
    transA::Char,
    A::DMatrix{T},
    B::DVector{T},
    _add::LinearAlgebra.MulAddMul,
) where {T}
    blocksC, blocksA, blocksB = _repartition_matvecmul(C, A, B, transA)
    kernel = (Cbuf, Abuf, Bbuf) -> gemv_dagger!(Cbuf, transA, Abuf, Bbuf, _add)
    return maybe_copy_buffered(kernel, C => blocksC, A => blocksA, B => blocksB)
end
# Choose block partitionings for `C = op(A) * B` so that the blocked kernel can
# pair blocks one-to-one.
#
# `transA` is `'N'` for `op(A) = A`; `'T'` or `'C'` mean `A` is used transposed.
# Returns `(Blocks for C, Blocks for A, Blocks for B)`. `C` itself is not read;
# its blocking is derived from the non-contracted block dimension of `A`.
function _repartition_matvecmul(C, A, B, transA::Char)
    partA = A.partitioning.blocksize
    partB = B.partitioning.blocksize
    istransA = transA == 'T' || transA == 'C'
    # Block length of op(A) along its rows (shared with C) ...
    dimA = istransA ? partA[2] : partA[1]
    # ... and along its columns (the contraction dimension, shared with B).
    dimA_other = istransA ? partA[1] : partA[2]

    # If the contraction blockings of A and B don't match, fix them.
    # Uses the smallest blocking of all dimensions, and applies it to BOTH
    # operands: re-blocking only A would leave B with its old block length
    # whenever B's blocks are larger than the minimum.
    if dimA_other != partB[1]
        sz = min(partA[1], partA[2], partB[1])
        partA = istransA ? (sz, partA[2]) : (partA[1], sz)
        partB = (sz,)
    end
    partC = (dimA,)
    return Blocks(partC...), Blocks(partA...), Blocks(partB...)
end
"""
    gemv_dagger!(C, transA, A, B, _add)

Blocked in-place matrix-vector multiply on distributed arrays: updates `C` with
`alpha * op(A) * B + beta * C`, where `alpha`/`beta` come from `_add` and `op(A)`
is selected by `transA` (`'N'` for `A` itself, anything else is forwarded to
`BLAS.gemv!` as a transposed/adjoint product).

One `BLAS.gemv!` task is spawned per (C block, contraction block) pair inside a
`Dagger.spawn_datadeps` region, with the `A` and `B` blocks marked `In` and the
`C` block marked `InOut`.

Throws `DimensionMismatch` when the block grid of `op(A)` does not line up with
the block counts of `B` (columns) or `C` (rows). Returns `C`.
"""
function gemv_dagger!(
    C::DVector{T},
    transA::Char,
    A::DMatrix{T},
    B::DVector{T},
    _add::LinearAlgebra.MulAddMul,
) where {T}
    Ac = A.chunks
    Bc = B.chunks
    Cc = C.chunks
    Amt, Ant = size(Ac)
    Bmt = size(Bc)[1]
    Cmt = size(Cc)[1]

    alpha = T(_add.alpha)
    beta = T(_add.beta)

    # Validate against the block grid of op(A), not of A: when A is used
    # transposed its block rows pair with B and its block columns pair with C.
    notrans = transA == 'N'
    opAmt, opAnt = notrans ? (Amt, Ant) : (Ant, Amt)
    if opAnt != Bmt
        throw(DimensionMismatch(lazy"A has number of blocks ($Amt,$Ant) but B has number of blocks ($Bmt)"))
    end
    if opAmt != Cmt
        throw(DimensionMismatch(lazy"A has number of blocks ($Amt,$Ant) but C has number of blocks ($Cmt)"))
    end

    Dagger.spawn_datadeps() do
        for m in range(1, Cmt)
            if notrans
                # A: NoTrans
                for k in range(1, Ant)
                    # Only the first update applies beta; later ones accumulate.
                    mzone = k == 1 ? beta : one(T)
                    Dagger.@spawn BLAS.gemv!(
                        transA,
                        alpha,
                        In(Ac[m, k]),
                        In(Bc[k]),
                        mzone,
                        InOut(Cc[m]),
                    )
                end
            else
                # A: [Conj]Trans
                for k in range(1, Amt)
                    # Only the first update applies beta; later ones accumulate.
                    mzone = k == 1 ? beta : one(T)
                    Dagger.@spawn BLAS.gemv!(
                        transA,
                        alpha,
                        In(Ac[k, m]),
                        In(Bc[k]),
                        mzone,
                        InOut(Cc[m]),
                    )
                end
            end
        end
    end

    return C
end
14 changes: 14 additions & 0 deletions test/array/linalg/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,17 @@
L2 = LowerTriangular(DArray(A, Blocks(16, 16)))
@test isapprox(L1, L2)
end

@testset "norm" begin
    # Compare the distributed norm against the dense reference for
    # 2-, 1- and 3-dimensional inputs, in that order.
    for dims in ((16, 16), (16,), (16, 16, 16))
        dense = rand(dims...)
        distributed = DArray(dense)
        @test norm(dense) ≈ norm(distributed)
    end
end
85 changes: 81 additions & 4 deletions test/array/linalg/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,87 @@ part_sets_to_test = map(_sizes_to_test) do sz
]
end
parts_to_test = vcat(part_sets_to_test...)
# Run `test_gemm!` for every combination of size pair, partitioning pair and
# element type, grouped under a single "GEMM" testset.
@testset "GEMM" begin
    @testset "Size=$szA*$szB" for (szA, szB) in sizes_to_test
        @testset "Partitioning=$partA*$partB" for (partA,partB) in parts_to_test
            @testset "T=$T" for T in (Float32, Float64, ComplexF32, ComplexF64)
                test_gemm!(T, szA, szB, partA, partB)
            end
        end
    end
end

# Check distributed matrix-vector products against their dense counterparts.
#
# `szA`/`szB` are the matrix and vector sizes, `partA`/`partB` their block
# partitionings, and `T` the element type. Both the out-of-place (`*`) and
# in-place (`mul!`) forms are tested; the adjoint variants only run when `A`'s
# row count equals `B`'s length.
function test_gemv!(T, szA, szB, partA, partB)
    @assert szA[2] == szB[1]
    @assert partA.blocksize[2] == partB.blocksize[1]
    out_len = szA[1]
    out_part = Blocks(partA.blocksize[1])
    can_adjoint = szA[1] == szB[1]

    mat = rand(T, szA...)
    vec_in = rand(T, szB...)

    dmat = distribute(mat, partA)
    dvec = distribute(vec_in, partB)

    ## Out-of-place GEMV
    # No transA
    @test collect(dmat * dvec) ≈ mat * vec_in

    if can_adjoint
        # transA
        @test collect(dmat' * dvec) ≈ mat' * vec_in
    end

    ## In-place GEMV
    # No transA
    expected = zeros(T, out_len)
    dout = distribute(expected, out_part)
    mul!(expected, mat, vec_in)
    mul!(dout, dmat, dvec)
    @test collect(dout) ≈ expected

    if can_adjoint
        # transA
        expected = zeros(T, out_len)
        dout = distribute(expected, out_part)
        mul!(expected, mat', vec_in)
        mul!(dout, dmat', dvec)
        @test collect(dout) ≈ expected
    end
end

# Base (rows, cols) shapes from which the GEMV size cases are derived.
_sizes_to_test = [(4, 4), (7, 7), (12, 12), (16, 16)]
# For each base shape: `matrix size => vector size`, once with the full column
# count and once with half of it.
size_sets_to_test = [
    [(rows, cols) => (cols,), (rows, cols ÷ 2) => (cols ÷ 2,)]
    for (rows, cols) in _sizes_to_test
]
# Flatten the per-shape lists into a single list of cases.
sizes_to_test = vcat(size_sets_to_test...)
# For each base shape: `matrix blocking => vector blocking`, once with the full
# column count and once with half of it, so the column blocking of the matrix
# always equals the blocking of the vector.
part_sets_to_test = [
    [Blocks(rows, cols) => Blocks(cols), Blocks(rows, cols ÷ 2) => Blocks(cols ÷ 2)]
    for (rows, cols) in _sizes_to_test
]
# Flatten the per-shape lists into a single list of partitioning cases.
parts_to_test = vcat(part_sets_to_test...)
# Run `test_gemv!` for every combination of size pair, partitioning pair and
# element type, grouped under a single "GEMV" testset.
@testset "GEMV" begin
# Each (matrix size => vector size) case ...
@testset "Size=$szA*$szB" for (szA, szB) in sizes_to_test
# ... is crossed with each (matrix blocking => vector blocking) case ...
@testset "Partitioning=$partA*$partB" for (partA,partB) in parts_to_test
# ... and with each of the four listed real and complex floating-point types.
@testset "T=$T" for T in (Float32, Float64, ComplexF32, ComplexF64)
test_gemv!(T, szA, szB, partA, partB)
end
end
end
end