diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index de4002a6b..025df0979 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -56,6 +56,11 @@ function Dagger.aliasing(x::CuArray{T}) where T return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) end +function Dagger.unsafe_free!(x::CuArray) + CUDA.unsafe_free!(x) + return +end + Dagger.memory_spaces(proc::CuArrayDeviceProc) = Set([CUDAVRAMMemorySpace(proc.owner, proc.device, proc.device_uuid)]) Dagger.processors(space::CUDAVRAMMemorySpace) = Set([CuArrayDeviceProc(space.owner, space.device, space.device_uuid)]) @@ -284,11 +289,14 @@ end # Adapt BLAS/LAPACK functions import LinearAlgebra: BLAS, LAPACK +_keep_blas_functions = Set(["iamax"]) for lib in [BLAS, LAPACK] for name in names(lib; all=true) name == nameof(lib) && continue startswith(string(name), '#') && continue - endswith(string(name), '!') || continue + if !endswith(string(name), '!') && !any(endswith(string(name), func) for func in _keep_blas_functions) + continue + end for culib in [CUBLAS, CUSOLVER] if name in names(culib; all=true) @@ -300,6 +308,37 @@ for lib in [BLAS, LAPACK] end end +# Adapt RefValue +Dagger.move(from_proc::CPUProc, to_proc::CuArrayDeviceProc, x::Base.RefValue) = + Dagger.GPURef(Dagger.move(from_proc, to_proc, x[]), only(Dagger.memory_spaces(to_proc))) +Dagger.move(from_proc::CuArrayDeviceProc, to_proc::CPUProc, x::Dagger.GPURef{T,CUDAVRAMMemorySpace} where T) = + Ref(Dagger.move(from_proc, to_proc, x[])) +function Dagger.move!(dep_mod, to_space::CPURAMMemorySpace, from_space::CUDAVRAMMemorySpace, to::Base.RefValue, from::Dagger.GPURef) + if Dagger.type_may_alias(typeof(from[])) + Dagger.move!(dep_mod, to_space, from_space, to[], from[]) + else + to[] = dep_mod(from[]) + end + return +end +function Dagger.move!(dep_mod, to_space::CUDAVRAMMemorySpace, from_space::CPURAMMemorySpace, to::Dagger.GPURef, from::Base.RefValue) + if Dagger.type_may_alias(typeof(from[])) + Dagger.move!(dep_mod, to_space, from_space, to[], from[]) + else + to[] = dep_mod(from[]) + end + return +end +function Dagger.move!(dep_mod, to_space::CUDAVRAMMemorySpace, from_space::CUDAVRAMMemorySpace, to::Dagger.GPURef, from::Dagger.GPURef) + if Dagger.type_may_alias(typeof(from[])) + Dagger.move!(dep_mod, to_space, from_space, to[], from[]) + else + to[] = dep_mod(from[]) + end + return +end + +# Adapt HaloArray CuArray(H::Dagger.HaloArray) = convert(CuArray, H) Base.convert(::Type{C}, H::Dagger.HaloArray) where {C<:CuArray} = Dagger.HaloArray(C(H.center), diff --git a/ext/IntelExt.jl b/ext/IntelExt.jl index 6ef4fc45a..43471c3e8 100644 --- a/ext/IntelExt.jl +++ b/ext/IntelExt.jl @@ -55,6 +55,11 @@ function Dagger.aliasing(x::oneArray{T}) where T return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) end +function Dagger.unsafe_free!(x::oneArray) + oneAPI.unsafe_free!(x) + return +end + Dagger.memory_spaces(proc::oneArrayDeviceProc) = Set([IntelVRAMMemorySpace(proc.owner, proc.device_id)]) Dagger.processors(space::IntelVRAMMemorySpace) = Set([oneArrayDeviceProc(space.owner, space.device_id)]) @@ -278,6 +283,37 @@ function Dagger.execute!(proc::oneArrayDeviceProc, f, args...; kwargs...) end end +# Adapt RefValue +Dagger.move(from_proc::CPUProc, to_proc::oneArrayDeviceProc, x::Base.RefValue) = + Dagger.GPURef(Dagger.move(from_proc, to_proc, x[]), only(Dagger.memory_spaces(to_proc))) +Dagger.move(from_proc::oneArrayDeviceProc, to_proc::CPUProc, x::Dagger.GPURef{T,IntelVRAMMemorySpace} where T) = + Ref(Dagger.move(from_proc, to_proc, x[])) +function Dagger.move!(dep_mod, to_space::CPURAMMemorySpace, from_space::IntelVRAMMemorySpace, to::Base.RefValue, from::Dagger.GPURef) + if Dagger.type_may_alias(typeof(from[])) + Dagger.move!(dep_mod, to_space, from_space, to[], from[]) + else + to[] = dep_mod(from[]) + end + return +end +function Dagger.move!(dep_mod, to_space::IntelVRAMMemorySpace, from_space::CPURAMMemorySpace, to::Dagger.GPURef, from::Base.RefValue) + if Dagger.type_may_alias(typeof(from[])) + Dagger.move!(dep_mod, to_space, from_space, to[], from[]) + else + to[] = dep_mod(from[]) + end + return +end +function Dagger.move!(dep_mod, to_space::IntelVRAMMemorySpace, from_space::IntelVRAMMemorySpace, to::Dagger.GPURef, from::Dagger.GPURef) + if Dagger.type_may_alias(typeof(from[])) + Dagger.move!(dep_mod, to_space, from_space, to[], from[]) + else + to[] = dep_mod(from[]) + end + return +end + +# Adapt HaloArray oneArray(H::Dagger.HaloArray) = convert(oneArray, H) Base.convert(::Type{C}, H::Dagger.HaloArray) where {C<:oneArray} = Dagger.HaloArray(C(H.center), diff --git a/ext/MetalExt.jl b/ext/MetalExt.jl index 43a15a38f..129ebdc98 100644 --- a/ext/MetalExt.jl +++ b/ext/MetalExt.jl @@ -52,6 +52,11 @@ function Dagger.aliasing(x::MtlArray{T}) where T return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) end +function Dagger.unsafe_free!(x::MtlArray) + Metal.unsafe_free!(x) + return +end + Dagger.memory_spaces(proc::MtlArrayDeviceProc) = Set([MetalVRAMMemorySpace(proc.owner, proc.device_id)]) Dagger.processors(space::MetalVRAMMemorySpace) = Set([MtlArrayDeviceProc(space.owner, space.device_id)]) @@ -284,6 +289,37 @@ function Dagger.execute!(proc::MtlArrayDeviceProc, f, args...; kwargs...) end end +# Adapt RefValue +Dagger.move(from_proc::CPUProc, to_proc::MtlArrayDeviceProc, x::Base.RefValue) = + Dagger.GPURef(Dagger.move(from_proc, to_proc, x[]), only(Dagger.memory_spaces(to_proc))) +Dagger.move(from_proc::MtlArrayDeviceProc, to_proc::CPUProc, x::Dagger.GPURef{T,MetalVRAMMemorySpace} where T) = + Ref(Dagger.move(from_proc, to_proc, x[])) +function Dagger.move!(dep_mod, to_space::CPURAMMemorySpace, from_space::MetalVRAMMemorySpace, to::Base.RefValue, from::Dagger.GPURef) + if Dagger.type_may_alias(typeof(from[])) + Dagger.move!(dep_mod, to_space, from_space, to[], from[]) + else + to[] = dep_mod(from[]) + end + return +end +function Dagger.move!(dep_mod, to_space::MetalVRAMMemorySpace, from_space::CPURAMMemorySpace, to::Dagger.GPURef, from::Base.RefValue) + if Dagger.type_may_alias(typeof(from[])) + Dagger.move!(dep_mod, to_space, from_space, to[], from[]) + else + to[] = dep_mod(from[]) + end + return +end +function Dagger.move!(dep_mod, to_space::MetalVRAMMemorySpace, from_space::MetalVRAMMemorySpace, to::Dagger.GPURef, from::Dagger.GPURef) + if Dagger.type_may_alias(typeof(from[])) + Dagger.move!(dep_mod, to_space, from_space, to[], from[]) + else + to[] = dep_mod(from[]) + end + return +end + +# Adapt HaloArray MtlArray(H::Dagger.HaloArray) = convert(MtlArray, H) Base.convert(::Type{C}, H::Dagger.HaloArray) where {C<:MtlArray} = Dagger.HaloArray(C(H.center), diff --git a/ext/OpenCLExt.jl b/ext/OpenCLExt.jl index 144a1ba77..783e3521a 100644 --- a/ext/OpenCLExt.jl +++ b/ext/OpenCLExt.jl @@ -52,6 +52,11 @@ function Dagger.aliasing(x::CLArray{T}) where T return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) end +function Dagger.unsafe_free!(x::CLArray) + cl.unsafe_free!(x) + return +end + Dagger.memory_spaces(proc::CLArrayDeviceProc) = Set([CLMemorySpace(proc.owner, proc.device)]) Dagger.processors(space::CLMemorySpace) = Set([CLArrayDeviceProc(space.owner, space.device)]) @@ -251,6 +256,37 @@ function Dagger.execute!(proc::CLArrayDeviceProc, f, args...; kwargs...) end end +# Adapt RefValue +Dagger.move(from_proc::CPUProc, to_proc::CLArrayDeviceProc, x::Base.RefValue) = + Dagger.GPURef(Dagger.move(from_proc, to_proc, x[]), only(Dagger.memory_spaces(to_proc))) +Dagger.move(from_proc::CLArrayDeviceProc, to_proc::CPUProc, x::Dagger.GPURef{T,CLMemorySpace} where T) = + Ref(Dagger.move(from_proc, to_proc, x[])) +function Dagger.move!(dep_mod, to_space::CPURAMMemorySpace, from_space::CLMemorySpace, to::Base.RefValue, from::Dagger.GPURef) + if Dagger.type_may_alias(typeof(from[])) + Dagger.move!(dep_mod, to_space, from_space, to[], from[]) + else + to[] = dep_mod(from[]) + end + return +end +function Dagger.move!(dep_mod, to_space::CLMemorySpace, from_space::CPURAMMemorySpace, to::Dagger.GPURef, from::Base.RefValue) + if Dagger.type_may_alias(typeof(from[])) + Dagger.move!(dep_mod, to_space, from_space, to[], from[]) + else + to[] = dep_mod(from[]) + end + return +end +function Dagger.move!(dep_mod, to_space::CLMemorySpace, from_space::CLMemorySpace, to::Dagger.GPURef, from::Dagger.GPURef) + if Dagger.type_may_alias(typeof(from[])) + Dagger.move!(dep_mod, to_space, from_space, to[], from[]) + else + to[] = dep_mod(from[]) + end + return +end + +# Adapt HaloArray CLArray(H::Dagger.HaloArray) = convert(CLArray, H) Base.convert(::Type{C}, H::Dagger.HaloArray) where {C<:CLArray} = Dagger.HaloArray(C(H.center), diff --git a/ext/ROCExt.jl b/ext/ROCExt.jl index 444ff919d..a8c7aa98d 100644 --- a/ext/ROCExt.jl +++ b/ext/ROCExt.jl @@ -47,6 +47,11 @@ function Dagger.aliasing(x::ROCArray{T}) where T return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) end +function Dagger.unsafe_free!(x::ROCArray) + AMDGPU.unsafe_free!(x) + return +end + Dagger.memory_spaces(proc::ROCArrayDeviceProc) = Set([ROCVRAMMemorySpace(proc.owner, proc.device_id)]) Dagger.processors(space::ROCVRAMMemorySpace) = Set([ROCArrayDeviceProc(space.owner, space.device_id)]) @@ -231,24 +236,6 @@ Dagger.move(from_proc::CPUProc, to_proc::ROCArrayDeviceProc, x::Function) = x Dagger.move(from_proc::CPUProc, to_proc::ROCArrayDeviceProc, x::Chunk{T}) where {T<:Function} = Dagger.move(from_proc, to_proc, fetch(x)) -# Adapt BLAS/LAPACK functions -import LinearAlgebra: BLAS, LAPACK -for lib in [BLAS, LAPACK] - for name in names(lib; all=true) - name == nameof(lib) && continue - startswith(string(name), '#') && continue - endswith(string(name), '!') || continue - - for roclib in [rocBLAS, rocSOLVER] - if name in names(roclib; all=true) - fn = getproperty(lib, name) - rocfn = getproperty(roclib, name) - @eval Dagger.move(from_proc::CPUProc, to_proc::ROCArrayDeviceProc, ::$(typeof(fn))) = $rocfn - end - end - end -end - # Task execution function Dagger.execute!(proc::ROCArrayDeviceProc, f, args...; kwargs...) @nospecialize f args kwargs @@ -270,6 +257,58 @@ function Dagger.execute!(proc::ROCArrayDeviceProc, f, args...; kwargs...) end end +# Adapt BLAS/LAPACK functions +import LinearAlgebra: BLAS, LAPACK +_keep_blas_functions = Set(["iamax"]) +for lib in [BLAS, LAPACK] + for name in names(lib; all=true) + name == nameof(lib) && continue + startswith(string(name), '#') && continue + if !endswith(string(name), '!') && !any(endswith(string(name), func) for func in _keep_blas_functions) + continue + end + + for roclib in [rocBLAS, rocSOLVER] + if name in names(roclib; all=true) + fn = getproperty(lib, name) + rocfn = getproperty(roclib, name) + @eval Dagger.move(from_proc::CPUProc, to_proc::ROCArrayDeviceProc, ::$(typeof(fn))) = $rocfn + end + end + end +end + +# Adapt RefValue +Dagger.move(from_proc::CPUProc, to_proc::ROCArrayDeviceProc, x::Base.RefValue) = + Dagger.GPURef(Dagger.move(from_proc, to_proc, x[]), only(Dagger.memory_spaces(to_proc))) +Dagger.move(from_proc::ROCArrayDeviceProc, to_proc::CPUProc, x::Dagger.GPURef{T,ROCVRAMMemorySpace} where T) = + Ref(Dagger.move(from_proc, to_proc, x[])) +function Dagger.move!(dep_mod, to_space::CPURAMMemorySpace, from_space::ROCVRAMMemorySpace, to::Base.RefValue, from::Dagger.GPURef) + if Dagger.type_may_alias(typeof(from[])) + Dagger.move!(dep_mod, to_space, from_space, to[], from[]) + else + to[] = dep_mod(from[]) + end + return +end +function Dagger.move!(dep_mod, to_space::ROCVRAMMemorySpace, from_space::CPURAMMemorySpace, to::Dagger.GPURef, from::Base.RefValue) + if Dagger.type_may_alias(typeof(from[])) + Dagger.move!(dep_mod, to_space, from_space, to[], from[]) + else + to[] = dep_mod(from[]) + end + return +end +function Dagger.move!(dep_mod, to_space::ROCVRAMMemorySpace, from_space::ROCVRAMMemorySpace, to::Dagger.GPURef, from::Dagger.GPURef) + if Dagger.type_may_alias(typeof(from[])) + Dagger.move!(dep_mod, to_space, from_space, to[], from[]) + else + to[] = dep_mod(from[]) + end + return +end + +# Adapt HaloArray ROCArray(H::Dagger.HaloArray) = convert(ROCArray, H) Base.convert(::Type{C}, H::Dagger.HaloArray) where {C<:ROCArray} = Dagger.HaloArray(C(H.center), diff --git a/src/array/alloc.jl b/src/array/alloc.jl index fe92ae1e1..aa1050210 100644 --- a/src/array/alloc.jl +++ b/src/array/alloc.jl @@ -206,3 +206,12 @@ function Base.view(A::AbstractArray{T,N}, p::Blocks{N}) where {T,N} end Base.view(A::AbstractArray, ::AutoBlocks) = view(A, auto_blocks(size(A))) + +function unsafe_free!(A::DArray) + spawn_datadeps() do + for chunk in A.chunks + scope = UnionScope(map(ExactScope, collect(processors(memory_space(chunk))))) + Dagger.@spawn scope=scope unsafe_free!(chunk) + end + end +end diff --git a/src/array/copy.jl b/src/array/copy.jl index 7ed815daf..7d92566e6 100644 --- a/src/array/copy.jl +++ b/src/array/copy.jl @@ -18,6 +18,18 @@ function copy_buffered(f, args...) for (buf_arg, arg) in zip(buffered_args, real_args) copyto!(arg, buf_arg) end + + # Free the buffers + foreach(unsafe_free!, buffered_args) + + # If the result is one of the buffered args, return the corresponding + # original arg instead (since we've already copied data back to it, + # and the buffer has been freed) + result_idx = findfirst(buf_arg -> buf_arg === result, buffered_args) + if result_idx !== nothing + return real_args[result_idx] + end + return result end function allocate_copy_buffer(part::Blocks{N}, A::DArray{T,N}) where {T,N} diff --git a/src/array/linalg.jl b/src/array/linalg.jl index 3bf8e20e0..c244410d1 100644 --- a/src/array/linalg.jl +++ b/src/array/linalg.jl @@ -219,22 +219,43 @@ function LinearAlgebra.ldiv!(C::Cholesky{T,<:DMatrix}, B::DVecOrMat) where T parent_A = factors dB = B isa DVecOrMat ? B : (B isa AbstractMatrix ? view(B, factors.partitioning) : view(B, AutoBlocks())) min_bsa = min(parent_A.partitioning.blocksize...) - - if C.uplo == 'U' - # A = U'U → solve U'y = B, then Ux = y - maybe_copy_buffered(parent_A => Blocks(min_bsa, min_bsa), dB => Blocks(min_bsa, min_bsa)) do pA, pB - Dagger.trsm!('L', 'U', trans, 'N', alpha, pA, pB) - end - maybe_copy_buffered(parent_A => Blocks(min_bsa, min_bsa), dB => Blocks(min_bsa, min_bsa)) do pA, pB - Dagger.trsm!('L', 'U', 'N', 'N', alpha, pA, pB) + partB = Blocks(ntuple(_->min_bsa, ndims(B))...) + + if B isa DVector + if C.uplo == 'U' + # A = U'U → solve U'y = B, then Ux = y + maybe_copy_buffered(parent_A => Blocks(min_bsa, min_bsa), dB => partB) do pA, pB + Dagger.trsv!('U', trans, 'N', alpha, pA, pB) + end + maybe_copy_buffered(parent_A => Blocks(min_bsa, min_bsa), dB => partB) do pA, pB + Dagger.trsv!('U', 'N', 'N', alpha, pA, pB) + end + else + # A = LL' → solve Ly = B, then L'x = y + maybe_copy_buffered(parent_A => Blocks(min_bsa, min_bsa), dB => partB) do pA, pB + Dagger.trsv!('L', 'N', 'N', alpha, pA, pB) + end + maybe_copy_buffered(parent_A => Blocks(min_bsa, min_bsa), dB => partB) do pA, pB + Dagger.trsv!('L', trans, 'N', alpha, pA, pB) + end end else - # A = LL' → solve Ly = B, then L'x = y - maybe_copy_buffered(parent_A => Blocks(min_bsa, min_bsa), dB => Blocks(min_bsa, min_bsa)) do pA, pB - Dagger.trsm!('L', 'L', 'N', 'N', alpha, pA, pB) - end - maybe_copy_buffered(parent_A => Blocks(min_bsa, min_bsa), dB => Blocks(min_bsa, min_bsa)) do pA, pB - Dagger.trsm!('L', 'L', trans, 'N', alpha, pA, pB) + if C.uplo == 'U' + # A = U'U → solve U'y = B, then Ux = y + maybe_copy_buffered(parent_A => Blocks(min_bsa, min_bsa), dB => partB) do pA, pB + Dagger.trsm!('L', 'U', trans, 'N', alpha, pA, pB) + end + maybe_copy_buffered(parent_A => Blocks(min_bsa, min_bsa), dB => partB) do pA, pB + Dagger.trsm!('L', 'U', 'N', 'N', alpha, pA, pB) + end + else + # A = LL' → solve Ly = B, then L'x = y + maybe_copy_buffered(parent_A => Blocks(min_bsa, min_bsa), dB => partB) do pA, pB + Dagger.trsm!('L', 'L', 'N', 'N', alpha, pA, pB) + end + maybe_copy_buffered(parent_A => Blocks(min_bsa, min_bsa), dB => partB) do pA, pB + Dagger.trsm!('L', 'L', trans, 'N', alpha, pA, pB) + end end end diff --git a/src/array/lu.jl b/src/array/lu.jl index b4d1448cf..0b9279849 100644 --- a/src/array/lu.jl +++ b/src/array/lu.jl @@ -62,7 +62,7 @@ function search_and_update_ipiv!(ipiv_chunk::AbstractVector{Int}, info::Ref{Int} offdiag_blocks::Vararg{AbstractMatrix{T}}) where T # Search diagonal block column p (rows p:end) diag_col = view(diag_block, p:min(mb, m-(k-1)*mb), p:p) - max_idx = LinearAlgebra.BLAS.iamax(diag_col[:]) + max_idx = move(task_processor(), LinearAlgebra.BLAS.iamax)(diag_col[:]) best_piv_idx = (p - 1) + max_idx best_piv_val = diag_col[max_idx] best_block = 1 @@ -70,7 +70,7 @@ function search_and_update_ipiv!(ipiv_chunk::AbstractVector{Int}, info::Ref{Int} # Search off-diagonal block columns for (bi, blk) in enumerate(offdiag_blocks) col = view(blk, :, p:p) - idx = LinearAlgebra.BLAS.iamax(col[:]) + idx = move(task_processor(), LinearAlgebra.BLAS.iamax)(col[:]) val = col[idx] abs_best = best_piv_val isa Real ? abs(best_piv_val) : abs(real(best_piv_val)) + abs(imag(best_piv_val)) abs_val = val isa Real ? abs(val) : abs(real(val)) + abs(imag(val)) @@ -100,48 +100,29 @@ function swaprows_panel!(A::AbstractMatrix{T}, M::AbstractMatrix{T}, ipiv_chunk: end end -@static if !isdefined(LinearAlgebra.BLAS, :geru!) - for elty in (:Float64, :Float32) - @eval begin - geru!(α::$elty, x::AbstractVector{$elty}, y::AbstractVector{$elty}, A::AbstractMatrix{$elty}) = - LinearAlgebra.BLAS.ger!(α, x, y, A) - end - end - for (fname, elty) in ((:zgeru_,:ComplexF64), (:cgeru_,:ComplexF32)) - @eval begin - function geru!(α::$elty, x::AbstractVector{$elty}, y::AbstractVector{$elty}, A::AbstractMatrix{$elty}) - Base.require_one_based_indexing(A, x, y) - m, n = size(A) - if m != length(x) || n != length(y) - throw(DimensionMismatch(lazy"A has size ($m,$n), x has length $(length(x)), y has length $(length(y))")) - end - px, stx = LinearAlgebra.BLAS.vec_pointer_stride(x, ArgumentError("input vector with 0 stride is not allowed")) - py, sty = LinearAlgebra.BLAS.vec_pointer_stride(y, ArgumentError("input vector with 0 stride is not allowed")) - GC.@preserve x y ccall((LinearAlgebra.BLAS.@blasfunc($fname), LinearAlgebra.BLAS.libblas), Cvoid, - (Ref{Int64}, Ref{Int64}, Ref{$elty}, Ptr{$elty}, - Ref{Int64}, Ptr{$elty}, Ref{Int64}, Ptr{$elty}, - Ref{Int64}), - m, n, α, px, stx, py, sty, A, max(1,stride(A,2))) - A - end - end - end -else - geru! = LinearAlgebra.BLAS.geru! +@kernel function _geru_kernel!(alpha, x, y, A) + i, j = @index(Global, NTuple) + @inbounds A[i, j] = A[i, j] + alpha * x[i] * y[j] +end + +function geru!(α::T, x::AbstractVector{T}, y::AbstractVector{T}, A::AbstractMatrix{T}) where T + isempty(A) && return A + Kernel(_geru_kernel!)(α, x, y, A; ndrange=size(A)) + return A end # Update panel on the diagonal block (rows p+1:end). Receives the full block. function update_panel_diag!(A::AbstractMatrix{T}, p::Int, row_end::Int) where T M = view(A, p+1:row_end, :) Acinv = one(T) / A[p,p] - LinearAlgebra.BLAS.scal!(Acinv, view(M, :, p)) + view(M, :, p) .= Acinv .* view(M, :, p) geru!(-one(T), view(M, :, p), view(A, p, p+1:size(A,2)), view(M, :, p+1:size(M,2))) end # Update panel on an off-diagonal block. Receives full Chunks. function update_panel_offdiag!(M::AbstractMatrix{T}, A::AbstractMatrix{T}, p::Int) where T Acinv = one(T) / A[p,p] - LinearAlgebra.BLAS.scal!(Acinv, view(M, :, p)) + view(M, :, p) .= Acinv .* view(M, :, p) geru!(-one(T), view(M, :, p), view(A, p, p+1:size(A,2)), view(M, :, p+1:size(M,2))) end diff --git a/src/array/operators.jl b/src/array/operators.jl index 62db4400e..f93f09a75 100644 --- a/src/array/operators.jl +++ b/src/array/operators.jl @@ -134,6 +134,39 @@ end Base.first(A::DArray) = A[begin] Base.last(A::DArray) = A[end] +# Addition and subtraction + +function elementwise_op!(f, C, A, B) + @assert size(C) == size(A) == size(B) + C .= f.(A, B) + return +end +function elementwise_op(f, A::DArray, B::DArray) + if size(A) != size(B) + throw(DimensionMismatch("Sizes of A and B must match")) + end + A_part = A.partitioning + B_part = B.partitioning + if A.partitioning != B.partitioning + B_part = A_part + end + C = similar(A) + maybe_copy_buffered(B=>B_part, A=>A_part) do B, A + Ac = A.chunks + Bc = B.chunks + Cc = C.chunks + Dagger.spawn_datadeps() do + for idx in eachindex(Cc) + Dagger.@spawn elementwise_op!(f, Out(Cc[idx]), In(Ac[idx]), In(Bc[idx])) + end + end + return + end + return C +end +Base.:(+)(A::DArray, B::DArray) = elementwise_op(+, A, B) +Base.:(-)(A::DArray, B::DArray) = elementwise_op(-, A, B) + # In-place operations function imap!(f, A) diff --git a/src/array/trsm.jl b/src/array/trsm.jl index 4535a3cb6..65e87c5d5 100644 --- a/src/array/trsm.jl +++ b/src/array/trsm.jl @@ -28,7 +28,7 @@ function trsv!(uplo::Char, trans::Char, diag::Char, alpha::T, A::DMatrix{T}, B:: lalpha = (k == 1) ? alpha : zone Dagger.@spawn BLAS.trsv!('U', trans, diag, In(Ac[k, k]), InOut(Bc[k])) for i in k+1:Bnt - Dagger.@spawn BLAS.gemv!(trans, mzone, In(Ac[k, i]), In(Bc[i]), lalpha, InOut(Bc[k])) + Dagger.@spawn BLAS.gemv!(trans, mzone, In(Ac[k, i]), In(Bc[k]), lalpha, InOut(Bc[i])) end end end @@ -46,7 +46,7 @@ function trsv!(uplo::Char, trans::Char, diag::Char, alpha::T, A::DMatrix{T}, B:: lalpha = (k == Bnt) ? alpha : zone Dagger.@spawn BLAS.trsv!('L', trans, diag, In(Ac[k, k]), InOut(Bc[k])) for i in 1:k-1 - Dagger.@spawn BLAS.gemv!(trans, mzone, In(Ac[k, i]), In(Bc[i]), lalpha, InOut(Bc[k])) + Dagger.@spawn BLAS.gemv!(trans, mzone, In(Ac[k, i]), In(Bc[k]), lalpha, InOut(Bc[i])) end end end diff --git a/src/cancellation.jl b/src/cancellation.jl index ff9e19fcb..93a5aafee 100644 --- a/src/cancellation.jl +++ b/src/cancellation.jl @@ -99,8 +99,7 @@ function _cancel!(state, tid, force, graceful, halt_sch) tid !== nothing && task.id != tid && continue @dagdebug tid :cancel "Cancelling ready task" ex = DTaskFailedException(task, task, InterruptException()) - Sch.store_result!(state, task, ex; error=true) - Sch.finish_failed!(state, task, task) + Sch.set_failed!(state, task; ex) end if tid === nothing empty!(state.ready) @@ -114,8 +113,7 @@ function _cancel!(state, tid, force, graceful, halt_sch) tid !== nothing && task.id != tid && continue @dagdebug tid :cancel "Cancelling waiting task" ex = DTaskFailedException(task, task, InterruptException()) - Sch.store_result!(state, task, ex; error=true) - Sch.finish_failed!(state, task, task) + Sch.set_failed!(state, task; ex) end if tid === nothing empty!(state.waiting) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index 64ce11be5..2c6655491 100644 --- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -246,14 +246,19 @@ struct AliasedObjectCacheStore derived::Dict{AbstractAliasing,AbstractAliasing} stored::Dict{MemorySpace,Set{AbstractAliasing}} values::Dict{MemorySpace,Dict{AbstractAliasing,Chunk}} + originals::Set{AbstractAliasing} end AliasedObjectCacheStore() = AliasedObjectCacheStore(Vector{AbstractAliasing}(), Dict{AbstractAliasing,AbstractAliasing}(), Dict{MemorySpace,Set{AbstractAliasing}}(), - Dict{MemorySpace,Dict{AbstractAliasing,Chunk}}()) + Dict{MemorySpace,Dict{AbstractAliasing,Chunk}}(), + Set{AbstractAliasing}()) function is_stored(cache::AliasedObjectCacheStore, space::MemorySpace, ainfo::AbstractAliasing) + if !(ainfo in cache.originals) + push!(cache.originals, ainfo) + end if !haskey(cache.stored, space) return false end diff --git a/src/datadeps/queue.jl b/src/datadeps/queue.jl index 3e8d89d50..ad832b57c 100644 --- a/src/datadeps/queue.jl +++ b/src/datadeps/queue.jl @@ -147,6 +147,25 @@ function distribute_tasks!(queue::DataDepsTaskQueue) @maybelog ctx timespan_finish(ctx, :datadeps_copy_skip, (;id), (;thunk_id=0, from_space=origin_space, to_space=origin_space, arg_w, from_arg=arg, to_arg=arg)) end end + write_num += 1 + + # Free all allocated buffers + obj_cache = unwrap(state.ainfo_backing_chunk) + for remote_space in keys(obj_cache.values) + for (ainfo, remote_arg) in obj_cache.values[remote_space] + if !(ainfo in obj_cache.originals) + # We allocated this buffer, we can free it + remote_proc = first(processors(remote_space)) + free_scope = ExactScope(remote_proc) + free_syncdeps = Set{ThunkSyncdep}() + # FIXME: Send ainfo through aliasing! to calculate overlaps + #ainfo = AliasingWrapper(aliasing(arg, identity)) + @assert haskey(state.ainfo_arg, ainfo) "Ainfo not found in state.ainfo_arg: $ainfo" + get_write_deps!(state, remote_space, ainfo, write_num, free_syncdeps) + fetch(Dagger.@spawn scope=free_scope syncdeps=free_syncdeps Dagger.unsafe_free!(remote_arg); raw=true) + end + end + end end struct DataDepsTaskDependency arg_w::ArgumentWrapper diff --git a/src/gpu.jl b/src/gpu.jl index fa93f8076..fac44893d 100644 --- a/src/gpu.jl +++ b/src/gpu.jl @@ -103,4 +103,34 @@ end gpu_synchronize(::Val{:CPU}) = nothing with_context!(proc::Processor) = nothing -with_context!(space::MemorySpace) = nothing \ No newline at end of file +with_context!(space::MemorySpace) = nothing + +# Adapt RefValue +mutable struct GPURef{T,S<:MemorySpace} <: Ref{T} + value::T + space::S # This is ignored for aliasing +end +Base.getindex(x::GPURef) = x.value +Base.setindex!(x::GPURef, value) = x.value = value +# FIXME: Wire up with adapt +function aliasing(x::GPURef) + addr = UInt(Base.pointer_from_objref(x) + fieldoffset(typeof(x), 1)) + ptr = RemotePtr{Cvoid}(addr, x.space) + ainfo = ObjectAliasing(ptr, sizeof(x.value)) + return CombinedAliasing([ainfo]) +end +memory_space(x::GPURef) = x.space +function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::GPURef, from_ptr::UInt64, n::UInt64) + if from_ptr == UInt64(Base.pointer_from_objref(from) + fieldoffset(typeof(from), 1)) + unsafe_copyto!(pointer(copies, copies_offset), Ptr{UInt8}(from_ptr), n) + else + read_remainder!(copies, copies_offset, from[], from_ptr, n) + end +end +function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::GPURef, to_ptr::UInt64, n::UInt64) + if to_ptr == UInt64(Base.pointer_from_objref(to) + fieldoffset(typeof(to), 1)) + unsafe_copyto!(Ptr{UInt8}(to_ptr), pointer(copies, copies_offset), n) + else + write_remainder!(copies, copies_offset, to[], to_ptr, n) + end +end \ No newline at end of file diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index 1184f34dd..5c33dd743 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -602,3 +602,11 @@ function will_alias(x_span::MemorySpan, y_span::MemorySpan) y_end = y_span.ptr + y_span.len - 1 return x_span.ptr <= y_end && y_span.ptr <= x_end end + +### Unsafe Free + +unsafe_free!(x::Chunk) = remotecall_wait(root_worker_id(x), x) do x + unsafe_free!(unwrap(x)) +end +unsafe_free!(x::DTask) = unsafe_free!(fetch(x; raw=true)) +unsafe_free!(x) = nothing # Do nothing by default \ No newline at end of file diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index f0bed125e..a9a7f74c0 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -563,7 +563,7 @@ end if haskey(state.errored, task) # An error was eagerly propagated to this task @dagdebug task :schedule "Task received upstream error, finishing" - finish_failed!(state, task) + set_failed!(state, task) else # This shouldn't have happened @dagdebug task :schedule "Scheduling inconsistency: Task being scheduled is already cached!" @@ -598,8 +598,7 @@ end @something(options.result_scope, AnyScope())) if scope isa InvalidScope ex = SchedulingException("compute_scope and result_scope are not compatible: $(scope.x), $(scope.y)") - store_result!(state, task, ex; error=true) - finish_failed!(state, task) + set_failed!(state, task; ex) @goto pop_task end for arg in task.inputs @@ -615,8 +614,7 @@ end scope = constrain(scope, chunk.scope) if scope isa InvalidScope ex = SchedulingException("Current scope and argument Chunk scope are not compatible: $(scope.x), $(scope.y)") - store_result!(state, task, ex; error=true) - finish_failed!(state, task) + set_failed!(state, task; ex) @goto pop_task end end @@ -677,8 +675,7 @@ end end ex = SchedulingException("No processors available, try widening scope") - store_result!(state, task, ex; error=true) - finish_failed!(state, task) + set_failed!(state, task; ex) @dagdebug task :schedule "No processors available, skipping" sorted_procs_cleanup() costs_cleanup() @@ -750,7 +747,7 @@ function finish_task!(ctx, state, node, thunk_failed) pop!(state.running, node) delete!(state.running_on, node) if thunk_failed - set_failed!(state, node) + set_failed!(state, node; ex=load_result(state, node)) end schedule_dependents!(state, node, thunk_failed) fill_registered_futures!(state, node, thunk_failed) diff --git a/src/sch/util.jl b/src/sch/util.jl index d3b7a4804..dc8b62d30 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -232,39 +232,54 @@ end const RESCHEDULE_SYNCDEPS_SEEN_CACHE = TaskLocalValue{ReusableCache{Set{Thunk},Nothing}}(()->ReusableCache(Set{Thunk}, nothing, 1)) "Marks `thunk` and all dependent thunks as failed." -function set_failed!(state, origin, thunk=origin) +function set_failed!(state, origin::Thunk, thunk::Thunk=origin; ex=nothing) @assert islocked(state.lock) has_result(state, thunk) && return @dagdebug thunk :finish "Setting as failed" - filter!(x -> x !== thunk, state.ready) - # N.B. If origin === thunk, we assume that the caller has already set the error - if origin !== thunk && !has_result(state, thunk) - origin_ex = load_result(state, origin) - if origin_ex isa RemoteException - origin_ex = origin_ex.captured - end - ex = DTaskFailedException(thunk, origin, origin_ex) + + if origin === thunk && ex !== nothing store_result!(state, thunk, ex; error=true) end - finish_failed!(state, thunk, origin) -end -function finish_failed!(state, thunk, origin=nothing) - @assert islocked(state.lock) - fill_registered_futures!(state, thunk, true) - if haskey(state.waiting_data, thunk) - for dep in state.waiting_data[thunk] - haskey(state.waiting, dep) && - delete!(state.waiting, dep) - haskey(state.errored, dep) && - continue - origin !== nothing && set_failed!(state, origin, dep) + + seen = Set{Thunk}() + to_visit = Thunk[thunk] + ctr = 0 + while !isempty(to_visit) + ctr += 1 + if ctr > 10000 + error("set_failed! is stuck in a loop") + end + thunk = pop!(to_visit) + push!(seen, thunk) + + filter!(x -> x !== thunk, state.ready) + + if !has_result(state, thunk) && origin !== thunk + origin_ex = load_result(state, origin) + if origin_ex isa RemoteException + origin_ex = origin_ex.captured + end + if origin_ex isa DTaskFailedException + origin_ex = origin_ex.ex + end + ex = DTaskFailedException(thunk, origin, origin_ex) + store_result!(state, thunk, ex; error=true) + end + + fill_registered_futures!(state, thunk, true) + if haskey(state.waiting_data, thunk) + for dep in state.waiting_data[thunk] + haskey(state.errored, dep) && continue + dep in seen && continue + push!(to_visit, dep) + end + delete!(state.waiting_data, thunk) + thunk.sch_accessible = false + delete_unused_task!(state, thunk) + end + if haskey(state.waiting, thunk) + delete!(state.waiting, thunk) end - delete!(state.waiting_data, thunk) - thunk.sch_accessible = false - delete_unused_task!(state, thunk) - end - if haskey(state.waiting, thunk) - delete!(state.waiting, thunk) end end