diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index e9177b9a9..ad94a7b6d 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -91,7 +91,7 @@ steps: codecov: true - label: Julia 1.11 (CUDA) - timeout_in_minutes: 20 + timeout_in_minutes: 30 <<: *gputest plugins: - JuliaCI/julia#v1: diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index 6b8c61f9a..9f9b8df4d 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -48,6 +48,13 @@ function Dagger.memory_space(x::CuArray) device_uuid = CUDA.uuid(dev) return CUDAVRAMMemorySpace(myid(), device_id, device_uuid) end +function Dagger.aliasing(x::CuArray{T}) where T + space = Dagger.memory_space(x) + S = typeof(space) + cuptr = pointer(x) + rptr = Dagger.RemotePtr{Cvoid}(UInt64(cuptr), space) + return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) +end Dagger.memory_spaces(proc::CuArrayDeviceProc) = Set([CUDAVRAMMemorySpace(proc.owner, proc.device, proc.device_uuid)]) Dagger.processors(space::CUDAVRAMMemorySpace) = Set([CuArrayDeviceProc(space.owner, space.device, space.device_uuid)]) @@ -75,6 +82,8 @@ function with_context!(space::CUDAVRAMMemorySpace) @assert Dagger.root_worker_id(space) == myid() with_context!(space.device) end +Dagger.with_context!(proc::CuArrayDeviceProc) = with_context!(proc) +Dagger.with_context!(space::CUDAVRAMMemorySpace) = with_context!(space) function with_context(f, x) old_ctx = context() old_stream = stream() diff --git a/ext/IntelExt.jl b/ext/IntelExt.jl index 74253007d..08d54ee81 100644 --- a/ext/IntelExt.jl +++ b/ext/IntelExt.jl @@ -46,6 +46,13 @@ function Dagger.memory_space(x::oneArray) return IntelVRAMMemorySpace(myid(), device_id) end _device_id(dev::ZeDevice) = findfirst(other_dev->other_dev === dev, collect(oneAPI.devices())) +function Dagger.aliasing(x::oneArray{T}) where T + space = Dagger.memory_space(x) + S = typeof(space) + gpu_ptr = pointer(x) + rptr = Dagger.RemotePtr{Cvoid}(UInt64(gpu_ptr), space) + return 
Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) +end Dagger.memory_spaces(proc::oneArrayDeviceProc) = Set([IntelVRAMMemorySpace(proc.owner, proc.device_id)]) Dagger.processors(space::IntelVRAMMemorySpace) = Set([oneArrayDeviceProc(space.owner, space.device_id)]) @@ -68,6 +75,8 @@ function with_context!(space::IntelVRAMMemorySpace) @assert Dagger.root_worker_id(space) == myid() with_context!(space.device_id) end +Dagger.with_context!(proc::oneArrayDeviceProc) = with_context!(proc) +Dagger.with_context!(space::IntelVRAMMemorySpace) = with_context!(space) function with_context(f, x) old_drv = driver() old_dev = device() diff --git a/ext/MetalExt.jl b/ext/MetalExt.jl index 50cfc8905..21cea360a 100644 --- a/ext/MetalExt.jl +++ b/ext/MetalExt.jl @@ -43,6 +43,13 @@ function Dagger.memory_space(x::MtlArray) return MetalVRAMMemorySpace(myid(), device_id) end _device_id(dev::MtlDevice) = findfirst(other_dev->other_dev === dev, Metal.devices()) +function Dagger.aliasing(x::MtlArray{T}) where T + space = Dagger.memory_space(x) + S = typeof(space) + gpu_ptr = pointer(x) + rptr = Dagger.RemotePtr{Cvoid}(UInt64(gpu_ptr), space) + return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) +end Dagger.memory_spaces(proc::MtlArrayDeviceProc) = Set([MetalVRAMMemorySpace(proc.owner, proc.device_id)]) Dagger.processors(space::MetalVRAMMemorySpace) = Set([MtlArrayDeviceProc(space.owner, space.device_id)]) @@ -66,6 +73,8 @@ end function with_context!(space::MetalVRAMMemorySpace) @assert Dagger.root_worker_id(space) == myid() end +Dagger.with_context!(proc::MtlArrayDeviceProc) = with_context!(proc) +Dagger.with_context!(space::MetalVRAMMemorySpace) = with_context!(space) function with_context(f, x) with_context!(x) return f() diff --git a/ext/OpenCLExt.jl b/ext/OpenCLExt.jl index fbf73de72..f8eac930c 100644 --- a/ext/OpenCLExt.jl +++ b/ext/OpenCLExt.jl @@ -44,6 +44,13 @@ function Dagger.memory_space(x::CLArray) idx = findfirst(==(queue), 
QUEUES) return CLMemorySpace(myid(), idx) end +function Dagger.aliasing(x::CLArray{T}) where T + space = Dagger.memory_space(x) + S = typeof(space) + gpu_ptr = pointer(x) + rptr = Dagger.RemotePtr{Cvoid}(UInt64(gpu_ptr), space) + return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) +end Dagger.memory_spaces(proc::CLArrayDeviceProc) = Set([CLMemorySpace(proc.owner, proc.device)]) Dagger.processors(space::CLMemorySpace) = Set([CLArrayDeviceProc(space.owner, space.device)]) @@ -71,6 +78,8 @@ function with_context!(space::CLMemorySpace) @assert Dagger.root_worker_id(space) == myid() with_context!(space.device) end +Dagger.with_context!(proc::CLArrayDeviceProc) = with_context!(proc) +Dagger.with_context!(space::CLMemorySpace) = with_context!(space) function with_context(f, x) old_ctx = cl.context() old_queue = cl.queue() diff --git a/ext/ROCExt.jl b/ext/ROCExt.jl index 288c4744f..773c2bb95 100644 --- a/ext/ROCExt.jl +++ b/ext/ROCExt.jl @@ -39,6 +39,13 @@ end Dagger.root_worker_id(space::ROCVRAMMemorySpace) = space.owner Dagger.memory_space(x::ROCArray) = ROCVRAMMemorySpace(myid(), AMDGPU.device(x).device_id) +function Dagger.aliasing(x::ROCArray{T}) where T + space = Dagger.memory_space(x) + S = typeof(space) + gpu_ptr = pointer(x) + rptr = Dagger.RemotePtr{Cvoid}(UInt64(gpu_ptr), space) + return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) +end Dagger.memory_spaces(proc::ROCArrayDeviceProc) = Set([ROCVRAMMemorySpace(proc.owner, proc.device_id)]) Dagger.processors(space::ROCVRAMMemorySpace) = Set([ROCArrayDeviceProc(space.owner, space.device_id)]) @@ -67,6 +74,8 @@ function with_context!(space::ROCVRAMMemorySpace) @assert Dagger.root_worker_id(space) == myid() with_context!(space.device_id) end +Dagger.with_context!(proc::ROCArrayDeviceProc) = with_context!(proc) +Dagger.with_context!(space::ROCVRAMMemorySpace) = with_context!(space) function with_context(f, x) old_ctx = context() old_device = AMDGPU.device() 
diff --git a/src/Dagger.jl b/src/Dagger.jl index fa30c7c1a..102a76149 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -73,6 +73,9 @@ include("utils/fetch.jl") include("utils/chunks.jl") include("utils/logging.jl") include("submission.jl") +abstract type MemorySpace end +include("utils/memory-span.jl") +include("utils/interval_tree.jl") include("memory-spaces.jl") # Task scheduling @@ -83,7 +86,12 @@ include("utils/caching.jl") include("sch/Sch.jl"); using .Sch # Data dependency task queue -include("datadeps.jl") +include("datadeps/aliasing.jl") +include("datadeps/chunkview.jl") +include("datadeps/remainders.jl") +include("datadeps/queue.jl") + +# Stencils include("utils/haloarray.jl") include("stencil.jl") diff --git a/src/argument.jl b/src/argument.jl index 94246a75e..849486e03 100644 --- a/src/argument.jl +++ b/src/argument.jl @@ -20,6 +20,7 @@ function pos_kw(pos::ArgPosition) @assert pos.kw != :NULL return pos.kw end + mutable struct Argument pos::ArgPosition value @@ -41,6 +42,35 @@ function Base.iterate(arg::Argument, state::Bool) return nothing end end - Base.copy(arg::Argument) = Argument(ArgPosition(arg.pos), arg.value) chunktype(arg::Argument) = chunktype(value(arg)) + +mutable struct TypedArgument{T} + pos::ArgPosition + value::T +end +TypedArgument(pos::Integer, value::T) where T = TypedArgument{T}(ArgPosition(true, pos, :NULL), value) +TypedArgument(kw::Symbol, value::T) where T = TypedArgument{T}(ArgPosition(false, 0, kw), value) +Base.setproperty!(arg::TypedArgument, name::Symbol, value::T) where T = + throw(ArgumentError("Cannot set properties of TypedArgument")) +ispositional(arg::TypedArgument) = ispositional(arg.pos) +iskw(arg::TypedArgument) = iskw(arg.pos) +pos_idx(arg::TypedArgument) = pos_idx(arg.pos) +pos_kw(arg::TypedArgument) = pos_kw(arg.pos) +raw_position(arg::TypedArgument) = raw_position(arg.pos) +value(arg::TypedArgument) = arg.value +valuetype(arg::TypedArgument{T}) where T = T +Base.iterate(arg::TypedArgument) = (arg.pos, true) 
+function Base.iterate(arg::TypedArgument, state::Bool) + if state + return (arg.value, false) + else + return nothing + end +end +Base.copy(arg::TypedArgument{T}) where T = TypedArgument{T}(ArgPosition(arg.pos), arg.value) +chunktype(arg::TypedArgument) = chunktype(value(arg)) + +Argument(arg::TypedArgument) = Argument(arg.pos, arg.value) + +const AnyArgument = Union{Argument, TypedArgument} \ No newline at end of file diff --git a/src/datadeps.jl b/src/datadeps.jl deleted file mode 100644 index d20bda647..000000000 --- a/src/datadeps.jl +++ /dev/null @@ -1,1082 +0,0 @@ -import Graphs: SimpleDiGraph, add_edge!, add_vertex!, inneighbors, outneighbors, nv - -export In, Out, InOut, Deps, spawn_datadeps - -"Specifies a read-only dependency." -struct In{T} - x::T -end -"Specifies a write-only dependency." -struct Out{T} - x::T -end -"Specifies a read-write dependency." -struct InOut{T} - x::T -end -"Specifies one or more dependencies." -struct Deps{T,DT<:Tuple} - x::T - deps::DT -end -Deps(x, deps...) 
= Deps(x, deps) - -struct DataDepsTaskQueue <: AbstractTaskQueue - # The queue above us - upper_queue::AbstractTaskQueue - # The set of tasks that have already been seen - seen_tasks::Union{Vector{Pair{DTaskSpec,DTask}},Nothing} - # The data-dependency graph of all tasks - g::Union{SimpleDiGraph{Int},Nothing} - # The mapping from task to graph ID - task_to_id::Union{Dict{DTask,Int},Nothing} - # How to traverse the dependency graph when launching tasks - traversal::Symbol - # Which scheduler to use to assign tasks to processors - scheduler::Symbol - - # Whether aliasing across arguments is possible - # The fields following only apply when aliasing==true - aliasing::Bool - - function DataDepsTaskQueue(upper_queue; - traversal::Symbol=:inorder, - scheduler::Symbol=:naive, - aliasing::Bool=true) - seen_tasks = Pair{DTaskSpec,DTask}[] - g = SimpleDiGraph() - task_to_id = Dict{DTask,Int}() - return new(upper_queue, seen_tasks, g, task_to_id, traversal, scheduler, - aliasing) - end -end - -function unwrap_inout(arg) - readdep = false - writedep = false - if arg isa In - readdep = true - arg = arg.x - elseif arg isa Out - writedep = true - arg = arg.x - elseif arg isa InOut - readdep = true - writedep = true - arg = arg.x - elseif arg isa Deps - alldeps = Tuple[] - for dep in arg.deps - dep_mod, inner_deps = unwrap_inout(dep) - for (_, readdep, writedep) in inner_deps - push!(alldeps, (dep_mod, readdep, writedep)) - end - end - arg = arg.x - return arg, alldeps - else - readdep = true - end - return arg, Tuple[(identity, readdep, writedep)] -end - -function enqueue!(queue::DataDepsTaskQueue, spec::Pair{DTaskSpec,DTask}) - push!(queue.seen_tasks, spec) -end -function enqueue!(queue::DataDepsTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) - append!(queue.seen_tasks, specs) -end - -_identity_hash(arg, h::UInt=UInt(0)) = ismutable(arg) ? 
objectid(arg) : hash(arg, h) -_identity_hash(arg::SubArray, h::UInt=UInt(0)) = hash(arg.indices, hash(arg.offset1, hash(arg.stride1, _identity_hash(arg.parent, h)))) -_identity_hash(arg::CartesianIndices, h::UInt=UInt(0)) = hash(arg.indices, hash(typeof(arg), h)) - -struct ArgumentWrapper - arg - dep_mod - hash::UInt - - function ArgumentWrapper(arg, dep_mod) - h = hash(dep_mod) - h = _identity_hash(arg, h) - return new(arg, dep_mod, h) - end -end -Base.hash(aw::ArgumentWrapper) = hash(ArgumentWrapper, aw.hash) -Base.:(==)(aw1::ArgumentWrapper, aw2::ArgumentWrapper) = - aw1.hash == aw2.hash - -struct DataDepsAliasingState - # Track original and current data locations - # We track data => space - data_origin::Dict{AliasingWrapper,MemorySpace} - data_locality::Dict{AliasingWrapper,MemorySpace} - - # Track writers ("owners") and readers - ainfos_owner::Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}} - ainfos_readers::Dict{AliasingWrapper,Vector{Pair{DTask,Int}}} - ainfos_overlaps::Dict{AliasingWrapper,Set{AliasingWrapper}} - - # Cache ainfo lookups - ainfo_cache::Dict{ArgumentWrapper,AliasingWrapper} - - function DataDepsAliasingState() - data_origin = Dict{AliasingWrapper,MemorySpace}() - data_locality = Dict{AliasingWrapper,MemorySpace}() - - ainfos_owner = Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}() - ainfos_readers = Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}() - ainfos_overlaps = Dict{AliasingWrapper,Set{AliasingWrapper}}() - - ainfo_cache = Dict{ArgumentWrapper,AliasingWrapper}() - - return new(data_origin, data_locality, - ainfos_owner, ainfos_readers, ainfos_overlaps, - ainfo_cache) - end -end -struct DataDepsNonAliasingState - # Track original and current data locations - # We track data => space - data_origin::IdDict{Any,MemorySpace} - data_locality::IdDict{Any,MemorySpace} - - # Track writers ("owners") and readers - args_owner::IdDict{Any,Union{Pair{DTask,Int},Nothing}} - args_readers::IdDict{Any,Vector{Pair{DTask,Int}}} - - function 
DataDepsNonAliasingState() - data_origin = IdDict{Any,MemorySpace}() - data_locality = IdDict{Any,MemorySpace}() - - args_owner = IdDict{Any,Union{Pair{DTask,Int},Nothing}}() - args_readers = IdDict{Any,Vector{Pair{DTask,Int}}}() - - return new(data_origin, data_locality, - args_owner, args_readers) - end -end -struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState}} - # Whether aliasing is being analyzed - aliasing::Bool - - # The ordered list of tasks and their read/write dependencies - dependencies::Vector{Pair{DTask,Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}}} - - # The mapping of memory space to remote argument copies - remote_args::Dict{MemorySpace,IdDict{Any,Any}} - - # Cache of whether arguments supports in-place move - supports_inplace_cache::IdDict{Any,Bool} - - # The aliasing analysis state - alias_state::State - - function DataDepsState(aliasing::Bool) - dependencies = Pair{DTask,Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}}[] - remote_args = Dict{MemorySpace,IdDict{Any,Any}}() - supports_inplace_cache = IdDict{Any,Bool}() - if aliasing - state = DataDepsAliasingState() - else - state = DataDepsNonAliasingState() - end - return new{typeof(state)}(aliasing, dependencies, remote_args, supports_inplace_cache, state) - end -end - -function aliasing(astate::DataDepsAliasingState, arg, dep_mod) - aw = ArgumentWrapper(arg, dep_mod) - get!(astate.ainfo_cache, aw) do - return AliasingWrapper(aliasing(arg, dep_mod)) - end -end - -function supports_inplace_move(state::DataDepsState, arg) - return get!(state.supports_inplace_cache, arg) do - return supports_inplace_move(arg) - end -end - -# Determine which arguments could be written to, and thus need tracking - -"Whether `arg` has any writedep in this datadeps region." 
-function has_writedep(state::DataDepsState{DataDepsNonAliasingState}, arg, deps) - # Check if we are writing to this memory - writedep = any(dep->dep[3], deps) - if writedep - arg_has_writedep[arg] = true - return true - end - - # Check if another task is writing to this memory - for (_, taskdeps) in state.dependencies - for (_, other_arg_writedep, _, _, other_arg) in taskdeps - other_arg_writedep || continue - if arg === other_arg - return true - end - end - end - - return false -end -""" -Whether `arg` has any writedep at or before executing `task` in this -datadeps region. -""" -function has_writedep(state::DataDepsState, arg, deps, task::DTask) - is_writedep(arg, deps, task) && return true - if state.aliasing - for (other_task, other_taskdeps) in state.dependencies - for (readdep, writedep, other_ainfo, _, _) in other_taskdeps - writedep || continue - for (dep_mod, _, _) in deps - ainfo = aliasing(state.alias_state, arg, dep_mod) - if will_alias(ainfo, other_ainfo) - return true - end - end - end - if task === other_task - return false - end - end - else - for (other_task, other_taskdeps) in state.dependencies - for (readdep, writedep, _, _, other_arg) in other_taskdeps - writedep || continue - if arg === other_arg - return true - end - end - if task === other_task - return false - end - end - end - error("Task isn't in argdeps set") -end -"Whether `arg` is written to by `task`." -function is_writedep(arg, deps, task::DTask) - return any(dep->dep[3], deps) -end - -# Aliasing state setup -function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask) - # Populate task dependencies - dependencies_to_add = Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}() - - # Track the task's arguments and access patterns - for (idx, _arg) in enumerate(spec.fargs) - # Unwrap In/InOut/Out wrappers and record dependencies - arg, deps = unwrap_inout(value(_arg)) - - # Unwrap the Chunk underlying any DTask arguments - arg = arg isa DTask ? 
fetch(arg; raw=true) : arg - - # Skip non-aliasing arguments - type_may_alias(typeof(arg)) || continue - - # Add all aliasing dependencies - for (dep_mod, readdep, writedep) in deps - if state.aliasing - ainfo = aliasing(state.alias_state, arg, dep_mod) - else - ainfo = AliasingWrapper(UnknownAliasing()) - end - push!(dependencies_to_add, (readdep, writedep, ainfo, dep_mod, arg)) - end - - # Populate argument write info - populate_argument_info!(state, arg, deps) - end - - # Track the task result too - # N.B. We state no readdep/writedep because, while we can't model the aliasing info for the task result yet, we don't want to synchronize because of this - push!(dependencies_to_add, (false, false, AliasingWrapper(UnknownAliasing()), identity, task)) - - # Record argument/result dependencies - push!(state.dependencies, task => dependencies_to_add) -end -function populate_argument_info!(state::DataDepsState{DataDepsAliasingState}, arg, deps) - astate = state.alias_state - for (dep_mod, readdep, writedep) in deps - ainfo = aliasing(astate, arg, dep_mod) - - # Initialize owner and readers - if !haskey(astate.ainfos_owner, ainfo) - overlaps = Set{AliasingWrapper}() - push!(overlaps, ainfo) - for other_ainfo in keys(astate.ainfos_owner) - ainfo == other_ainfo && continue - if will_alias(ainfo, other_ainfo) - push!(overlaps, other_ainfo) - push!(astate.ainfos_overlaps[other_ainfo], ainfo) - end - end - astate.ainfos_overlaps[ainfo] = overlaps - astate.ainfos_owner[ainfo] = nothing - astate.ainfos_readers[ainfo] = Pair{DTask,Int}[] - end - - # Assign data owner and locality - if !haskey(astate.data_locality, ainfo) - astate.data_locality[ainfo] = memory_space(arg) - astate.data_origin[ainfo] = memory_space(arg) - end - end -end -function populate_argument_info!(state::DataDepsState{DataDepsNonAliasingState}, arg, deps) - astate = state.alias_state - # Initialize owner and readers - if !haskey(astate.args_owner, arg) - astate.args_owner[arg] = nothing - 
astate.args_readers[arg] = DTask[] - end - - # Assign data owner and locality - if !haskey(astate.data_locality, arg) - astate.data_locality[arg] = memory_space(arg) - astate.data_origin[arg] = memory_space(arg) - end -end -function populate_return_info!(state::DataDepsState{DataDepsAliasingState}, task, space) - astate = state.alias_state - @assert !haskey(astate.data_locality, task) - # FIXME: We don't yet know about ainfos for this task -end -function populate_return_info!(state::DataDepsState{DataDepsNonAliasingState}, task, space) - astate = state.alias_state - @assert !haskey(astate.data_locality, task) - astate.data_locality[task] = space - astate.data_origin[task] = space -end - -""" - supports_inplace_move(x) -> Bool - -Returns `false` if `x` doesn't support being copied into from another object -like `x`, via `move!`. This is used in `spawn_datadeps` to prevent attempting -to copy between values which don't support mutation or otherwise don't have an -implemented `move!` and want to skip in-place copies. When this returns -`false`, datadeps will instead perform out-of-place copies for each non-local -use of `x`, and the data in `x` will not be updated when the `spawn_datadeps` -region returns. 
-""" -supports_inplace_move(x) = true -supports_inplace_move(t::DTask) = supports_inplace_move(fetch(t; raw=true)) -function supports_inplace_move(c::Chunk) - # FIXME: Use MemPool.access_ref - pid = root_worker_id(c.processor) - if pid == myid() - return supports_inplace_move(poolget(c.handle)) - else - return remotecall_fetch(supports_inplace_move, pid, c) - end -end -supports_inplace_move(::Function) = false - -# Read/write dependency management -function get_write_deps!(state::DataDepsState, ainfo_or_arg, task, write_num, syncdeps) - _get_write_deps!(state, ainfo_or_arg, task, write_num, syncdeps) - _get_read_deps!(state, ainfo_or_arg, task, write_num, syncdeps) -end -function get_read_deps!(state::DataDepsState, ainfo_or_arg, task, write_num, syncdeps) - _get_write_deps!(state, ainfo_or_arg, task, write_num, syncdeps) -end - -function _get_write_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num, syncdeps) - astate = state.alias_state - ainfo.inner isa NoAliasing && return - for other_ainfo in astate.ainfos_overlaps[ainfo] - other_task_write_num = astate.ainfos_owner[other_ainfo] - @dagdebug nothing :spawn_datadeps "Considering sync with writer via $ainfo -> $other_ainfo" - other_task_write_num === nothing && continue - other_task, other_write_num = other_task_write_num - write_num == other_write_num && continue - @dagdebug nothing :spawn_datadeps "Sync with writer via $ainfo -> $other_ainfo" - push!(syncdeps, ThunkSyncdep(other_task)) - end -end -function _get_read_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num, syncdeps) - astate = state.alias_state - ainfo.inner isa NoAliasing && return - for other_ainfo in astate.ainfos_overlaps[ainfo] - @dagdebug nothing :spawn_datadeps "Considering sync with reader via $ainfo -> $other_ainfo" - other_tasks = astate.ainfos_readers[other_ainfo] - for (other_task, other_write_num) in other_tasks - write_num == other_write_num && continue - 
@dagdebug nothing :spawn_datadeps "Sync with reader via $ainfo -> $other_ainfo" - push!(syncdeps, ThunkSyncdep(other_task)) - end - end -end -function add_writer!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num) - state.alias_state.ainfos_owner[ainfo] = task=>write_num - empty!(state.alias_state.ainfos_readers[ainfo]) - # Not necessary to assert a read, but conceptually it's true - add_reader!(state, ainfo, task, write_num) -end -function add_reader!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num) - push!(state.alias_state.ainfos_readers[ainfo], task=>write_num) -end - -function _get_write_deps!(state::DataDepsState{DataDepsNonAliasingState}, arg, task, write_num, syncdeps) - other_task_write_num = state.alias_state.args_owner[arg] - if other_task_write_num !== nothing - other_task, other_write_num = other_task_write_num - if write_num != other_write_num - push!(syncdeps, ThunkSyncdep(other_task)) - end - end -end -function _get_read_deps!(state::DataDepsState{DataDepsNonAliasingState}, arg, task, write_num, syncdeps) - for (other_task, other_write_num) in state.alias_state.args_readers[arg] - if write_num != other_write_num - push!(syncdeps, ThunkSyncdep(other_task)) - end - end -end -function add_writer!(state::DataDepsState{DataDepsNonAliasingState}, arg, task, write_num) - state.alias_state.args_owner[arg] = task=>write_num - empty!(state.alias_state.args_readers[arg]) - # Not necessary to assert a read, but conceptually it's true - add_reader!(state, arg, task, write_num) -end -function add_reader!(state::DataDepsState{DataDepsNonAliasingState}, arg, task, write_num) - push!(state.alias_state.args_readers[arg], task=>write_num) -end - -# Make a copy of each piece of data on each worker -# memory_space => {arg => copy_of_arg} -isremotehandle(x) = false -isremotehandle(x::DTask) = true -isremotehandle(x::Chunk) = true -function generate_slot!(state::DataDepsState, dest_space, data) - if 
data isa DTask - data = fetch(data; raw=true) - end - orig_space = memory_space(data) - to_proc = first(processors(dest_space)) - from_proc = first(processors(orig_space)) - dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space) - if orig_space == dest_space && (data isa Chunk || !isremotehandle(data)) - # Fast path for local data or data already in a Chunk - data_chunk = tochunk(data, from_proc) - dest_space_args[data] = data_chunk - @assert processor(data_chunk) in processors(dest_space) || data isa Chunk && processor(data) isa Dagger.OSProc - @assert memory_space(data_chunk) == orig_space - else - to_w = root_worker_id(dest_space) - ctx = Sch.eager_context() - id = rand(Int) - dest_space_args[data] = remotecall_fetch(to_w, from_proc, to_proc, data) do from_proc, to_proc, data - timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) - data_converted = move(from_proc, to_proc, data) - data_chunk = tochunk(data_converted, to_proc) - @assert processor(data_chunk) in processors(dest_space) - @assert memory_space(data_converted) == memory_space(data_chunk) "space mismatch! $(memory_space(data_converted)) != $(memory_space(data_chunk)) ($(typeof(data_converted)) vs. $(typeof(data_chunk))), spaces ($orig_space -> $dest_space)" - if orig_space != dest_space - @assert orig_space != memory_space(data_chunk) "space preserved! $orig_space != $(memory_space(data_chunk)) ($(typeof(data)) vs. 
$(typeof(data_chunk))), spaces ($orig_space -> $dest_space)" - end - return data_chunk - end - timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=dest_space_args[data])) - end - return dest_space_args[data] -end - -struct DataDepsSchedulerState - task_to_spec::Dict{DTask,DTaskSpec} - assignments::Dict{DTask,MemorySpace} - dependencies::Dict{DTask,Set{DTask}} - task_completions::Dict{DTask,UInt64} - space_completions::Dict{MemorySpace,UInt64} - capacities::Dict{MemorySpace,Int} - - function DataDepsSchedulerState() - return new(Dict{DTask,DTaskSpec}(), - Dict{DTask,MemorySpace}(), - Dict{DTask,Set{DTask}}(), - Dict{DTask,UInt64}(), - Dict{MemorySpace,UInt64}(), - Dict{MemorySpace,Int}()) - end -end - -function distribute_tasks!(queue::DataDepsTaskQueue) - #= TODO: Improvements to be made: - # - Support for copying non-AbstractArray arguments - # - Parallelize read copies - # - Unreference unused slots - # - Reuse memory when possible - # - Account for differently-sized data - =# - - # Get the set of all processors to be scheduled on - all_procs = Processor[] - scope = get_compute_scope() - for w in procs() - append!(all_procs, get_processors(OSProc(w))) - end - filter!(proc->!isa(constrain(ExactScope(proc), scope), - InvalidScope), - all_procs) - if isempty(all_procs) - throw(Sch.SchedulingException("No processors available, try widening scope")) - end - exec_spaces = unique(vcat(map(proc->collect(memory_spaces(proc)), all_procs)...)) - if !all(space->space isa CPURAMMemorySpace, exec_spaces) && !all(space->root_worker_id(space) == myid(), exec_spaces) - @warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1 - end - - # Round-robin assign tasks to processors - upper_queue = get_options(:task_queue) - - traversal = queue.traversal - if traversal == :inorder - # As-is - task_order = Colon() - elseif traversal == :bfs - # BFS - 
task_order = Int[1] - to_walk = Int[1] - seen = Set{Int}([1]) - while !isempty(to_walk) - # N.B. next_root has already been seen - next_root = popfirst!(to_walk) - for v in outneighbors(queue.g, next_root) - if !(v in seen) - push!(task_order, v) - push!(seen, v) - push!(to_walk, v) - end - end - end - elseif traversal == :dfs - # DFS (modified with backtracking) - task_order = Int[] - to_walk = Int[1] - seen = Set{Int}() - while length(task_order) < length(queue.seen_tasks) && !isempty(to_walk) - next_root = popfirst!(to_walk) - if !(next_root in seen) - iv = inneighbors(queue.g, next_root) - if all(v->v in seen, iv) - push!(task_order, next_root) - push!(seen, next_root) - ov = outneighbors(queue.g, next_root) - prepend!(to_walk, ov) - else - push!(to_walk, next_root) - end - end - end - else - throw(ArgumentError("Invalid traversal mode: $traversal")) - end - - state = DataDepsState(queue.aliasing) - astate = state.alias_state - sstate = DataDepsSchedulerState() - for proc in all_procs - space = only(memory_spaces(proc)) - get!(()->0, sstate.capacities, space) - sstate.capacities[space] += 1 - end - - # Start launching tasks and necessary copies - write_num = 1 - proc_idx = 1 - pressures = Dict{Processor,Int}() - proc_to_scope_lfu = BasicLFUCache{Processor,AbstractScope}(1024) - for (spec, task) in queue.seen_tasks[task_order] - # Populate all task dependencies - populate_task_info!(state, spec, task) - - task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) - scheduler = queue.scheduler - if scheduler == :naive - raw_args = map(arg->tochunk(value(arg)), spec.fargs) - our_proc = remotecall_fetch(1, all_procs, raw_args) do all_procs, raw_args - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - - @lock sch_state.lock begin - # Calculate costs per processor and select the most optimal - # FIXME: This should consider any already-allocated slots, - # whether they are up-to-date, and if not, the cost of moving - # data to them - 
procs, costs = Sch.estimate_task_costs(sch_state, all_procs, nothing, raw_args) - return first(procs) - end - end - elseif scheduler == :smart - raw_args = map(filter(arg->haskey(astate.data_locality, value(arg)), spec.fargs)) do arg - arg_chunk = tochunk(last(arg)) - # Only the owned slot is valid - # FIXME: Track up-to-date copies and pass all of those - return arg_chunk => data_locality[arg] - end - f_chunk = tochunk(value(f)) - our_proc, task_pressure = remotecall_fetch(1, all_procs, pressures, f_chunk, raw_args) do all_procs, pressures, f, chunks_locality - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - - @lock sch_state.lock begin - tx_rate = sch_state.transfer_rate[] - - costs = Dict{Processor,Float64}() - for proc in all_procs - # Filter out chunks that are already local - chunks_filt = Iterators.filter(((chunk, space)=chunk_locality)->!(proc in processors(space)), chunks_locality) - - # Estimate network transfer costs based on data size - # N.B. `affinity(x)` really means "data size of `x`" - # N.B. 
We treat same-worker transfers as having zero transfer cost - tx_cost = Sch.impute_sum(affinity(chunk)[2] for chunk in chunks_filt) - - # Estimate total cost to move data and get task running after currently-scheduled tasks - est_time_util = get(pressures, proc, UInt64(0)) - costs[proc] = est_time_util + (tx_cost/tx_rate) - end - - # Look up estimated task cost - sig = Sch.signature(sch_state, f, map(first, chunks_locality)) - task_pressure = get(sch_state.signature_time_cost, sig, 1000^3) - - # Shuffle procs around, so equally-costly procs are equally considered - P = randperm(length(all_procs)) - procs = getindex.(Ref(all_procs), P) - - # Sort by lowest cost first - sort!(procs, by=p->costs[p]) - - best_proc = first(procs) - return best_proc, task_pressure - end - end - # FIXME: Pressure should be decreased by pressure of syncdeps on same processor - pressures[our_proc] = get(pressures, our_proc, UInt64(0)) + task_pressure - elseif scheduler == :ultra - args = Base.mapany(spec.fargs) do arg - pos, data = arg - data, _ = unwrap_inout(data) - if data isa DTask - data = fetch(data; raw=true) - end - return pos => tochunk(data) - end - f_chunk = tochunk(value(f)) - task_time = remotecall_fetch(1, f_chunk, args) do f, args - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - return @lock sch_state.lock begin - sig = Sch.signature(sch_state, f, args) - return get(sch_state.signature_time_cost, sig, 1000^3) - end - end - - # FIXME: Copy deps are computed eagerly - deps = get(Set{Any}, spec.options, :syncdeps) - - # Find latest time-to-completion of all syncdeps - deps_completed = UInt64(0) - for dep in deps - haskey(sstate.task_completions, dep) || continue # copy deps aren't recorded - deps_completed = max(deps_completed, sstate.task_completions[dep]) - end - - # Find latest time-to-completion of each memory space - # FIXME: Figure out space completions based on optimal packing - spaces_completed = Dict{MemorySpace,UInt64}() - for space in exec_spaces - completed = 
UInt64(0) - for (task, other_space) in sstate.assignments - space == other_space || continue - completed = max(completed, sstate.task_completions[task]) - end - spaces_completed[space] = completed - end - - # Choose the earliest-available memory space and processor - # FIXME: Consider move time - move_time = UInt64(0) - local our_space_completed - while true - our_space_completed, our_space = findmin(spaces_completed) - our_space_procs = filter(proc->proc in all_procs, processors(our_space)) - if isempty(our_space_procs) - delete!(spaces_completed, our_space) - continue - end - our_proc = rand(our_space_procs) - break - end - - sstate.task_to_spec[task] = spec - sstate.assignments[task] = our_space - sstate.task_completions[task] = our_space_completed + move_time + task_time - elseif scheduler == :roundrobin - our_proc = all_procs[proc_idx] - if task_scope == scope - # all_procs is already limited to scope - else - if isa(constrain(task_scope, scope), InvalidScope) - throw(Sch.SchedulingException("Scopes are not compatible: $(scope), $(task_scope)")) - end - while !proc_in_scope(our_proc, task_scope) - proc_idx = mod1(proc_idx + 1, length(all_procs)) - our_proc = all_procs[proc_idx] - end - end - else - error("Invalid scheduler: $sched") - end - @assert our_proc in all_procs - our_space = only(memory_spaces(our_proc)) - - # Find the scope for this task (and its copies) - if task_scope == scope - # Optimize for the common case, cache the proc=>scope mapping - our_scope = get!(proc_to_scope_lfu, our_proc) do - our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) - return constrain(UnionScope(map(ExactScope, our_procs)...), scope) - end - else - # Use the provided scope and constrain it to the available processors - our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) - our_scope = constrain(UnionScope(map(ExactScope, our_procs)...), task_scope) - end - if our_scope isa InvalidScope - throw(Sch.SchedulingException("Scopes 
are not compatible: $(our_scope.x), $(our_scope.y)")) - end - - f = spec.fargs[1] - f.value = move(ThreadProc(myid(), 1), our_proc, value(f)) - @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" - - # Copy raw task arguments for analysis - task_args = map(copy, spec.fargs) - - # Copy args from local to remote - for (idx, _arg) in enumerate(task_args) - # Is the data writeable? - arg, deps = unwrap_inout(value(_arg)) - arg = arg isa DTask ? fetch(arg; raw=true) : arg - if !type_may_alias(typeof(arg)) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Skipped copy-to (unwritten)" - spec.fargs[idx].value = arg - continue - end - if !supports_inplace_move(state, arg) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Skipped copy-to (non-writeable)" - spec.fargs[idx].value = arg - continue - end - - # Is the source of truth elsewhere? - arg_remote = get!(get!(IdDict{Any,Any}, state.remote_args, our_space), arg) do - generate_slot!(state, our_space, arg) - end - if queue.aliasing - for (dep_mod, _, _) in deps - ainfo = aliasing(astate, arg, dep_mod) - data_space = astate.data_locality[ainfo] - nonlocal = our_space != data_space - if nonlocal - # Add copy-to operation (depends on latest owner of arg) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Enqueueing copy-to: $data_space => $our_space" - arg_local = get!(get!(IdDict{Any,Any}, state.remote_args, data_space), arg) do - generate_slot!(state, data_space, arg) - end - copy_to_scope = our_scope - copy_to_syncdeps = Set{ThunkSyncdep}() - get_write_deps!(state, ainfo, task, write_num, copy_to_syncdeps) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] $(length(copy_to_syncdeps)) syncdeps" - copy_to = Dagger.@spawn scope=copy_to_scope exec_scope=copy_to_scope syncdeps=copy_to_syncdeps meta=true Dagger.move!(dep_mod, our_space, data_space, arg_remote, arg_local) - add_writer!(state, ainfo, copy_to, write_num) - - 
astate.data_locality[ainfo] = our_space - else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Skipped copy-to (local): $data_space" - end - end - else - data_space = astate.data_locality[arg] - nonlocal = our_space != data_space - if nonlocal - # Add copy-to operation (depends on latest owner of arg) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Enqueueing copy-to: $data_space => $our_space" - arg_local = get!(get!(IdDict{Any,Any}, state.remote_args, data_space), arg) do - generate_slot!(state, data_space, arg) - end - copy_to_scope = our_scope - copy_to_syncdeps = Set{ThunkSyncdep}() - get_write_deps!(state, arg, task, write_num, copy_to_syncdeps) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] $(length(copy_to_syncdeps)) syncdeps" - copy_to = Dagger.@spawn scope=copy_to_scope exec_scope=copy_to_scope syncdeps=copy_to_syncdeps Dagger.move!(identity, our_space, data_space, arg_remote, arg_local) - add_writer!(state, arg, copy_to, write_num) - - astate.data_locality[arg] = our_space - else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Skipped copy-to (local): $data_space" - end - end - spec.fargs[idx].value = arg_remote - end - write_num += 1 - - # Validate that we're not accidentally performing a copy - for (idx, _arg) in enumerate(spec.fargs) - _, deps = unwrap_inout(value(task_args[idx])) - # N.B. 
We only do this check when the argument supports in-place - # moves, because for the moment, we are not guaranteeing updates or - # write-back of results - arg = value(_arg) - if is_writedep(arg, deps, task) && supports_inplace_move(state, arg) - arg_space = memory_space(arg) - @assert arg_space == our_space "($(repr(value(f))))[$idx] Tried to pass $(typeof(arg)) from $arg_space to $our_space" - end - end - - # Calculate this task's syncdeps - if spec.options.syncdeps === nothing - spec.options.syncdeps = Set{ThunkSyncdep}() - end - syncdeps = spec.options.syncdeps - for (idx, (_, arg)) in enumerate(task_args) - arg, deps = unwrap_inout(arg) - arg = arg isa DTask ? fetch(arg; raw=true) : arg - type_may_alias(typeof(arg)) || continue - supports_inplace_move(state, arg) || continue - if queue.aliasing - for (dep_mod, _, writedep) in deps - ainfo = aliasing(astate, arg, dep_mod) - if writedep - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Syncing as writer" - get_write_deps!(state, ainfo, task, write_num, syncdeps) - else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Syncing as reader" - get_read_deps!(state, ainfo, task, write_num, syncdeps) - end - end - else - if is_writedep(arg, deps, task) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Syncing as writer" - get_write_deps!(state, arg, task, write_num, syncdeps) - else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Syncing as reader" - get_read_deps!(state, arg, task, write_num, syncdeps) - end - end - end - @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) $(length(syncdeps)) syncdeps" - - # Launch user's task - spec.options.scope = our_scope - spec.options.exec_scope = our_scope - enqueue!(upper_queue, spec=>task) - - # Update read/write tracking for arguments - for (idx, (_, arg)) in enumerate(task_args) - arg, deps = unwrap_inout(arg) - arg = arg isa DTask ? 
fetch(arg; raw=true) : arg - type_may_alias(typeof(arg)) || continue - if queue.aliasing - for (dep_mod, _, writedep) in deps - ainfo = aliasing(astate, arg, dep_mod) - if writedep - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Set as owner" - add_writer!(state, ainfo, task, write_num) - else - add_reader!(state, ainfo, task, write_num) - end - end - else - if is_writedep(arg, deps, task) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Set as owner" - add_writer!(state, arg, task, write_num) - else - add_reader!(state, arg, task, write_num) - end - end - end - - # Update tracking for return value - populate_return_info!(state, task, our_space) - - write_num += 1 - proc_idx = mod1(proc_idx + 1, length(all_procs)) - end - - # Copy args from remote to local - if queue.aliasing - # We need to replay the writes from all tasks in-order (skipping any - # outdated write owners), to ensure that overlapping writes are applied - # in the correct order - - # First, find the latest owners of each live ainfo - arg_writes = IdDict{Any,Vector{Tuple{AliasingWrapper,<:Any,MemorySpace}}}() - for (task, taskdeps) in state.dependencies - for (_, writedep, ainfo, dep_mod, arg) in taskdeps - writedep || continue - haskey(astate.data_locality, ainfo) || continue - @assert haskey(astate.ainfos_owner, ainfo) "Missing ainfo: $ainfo ($dep_mod($(typeof(arg))))" - - # Skip virtual writes from task result aliasing - # FIXME: Make this less bad - if arg isa DTask && dep_mod === identity && ainfo.inner isa UnknownAliasing - continue - end - - # Skip non-writeable arguments - if !supports_inplace_move(state, arg) - @dagdebug nothing :spawn_datadeps "Skipped copy-from (non-writeable)" - continue - end - - # Get the set of writers - ainfo_writes = get!(Vector{Tuple{AliasingWrapper,<:Any,MemorySpace}}, arg_writes, arg) - - #= FIXME: If we fully overlap any writer, evict them - idxs = findall(ainfo_write->overlaps_all(ainfo, ainfo_write[1]), ainfo_writes) - 
deleteat!(ainfo_writes, idxs) - =# - - # Make ourselves the latest writer - push!(ainfo_writes, (ainfo, dep_mod, astate.data_locality[ainfo])) - end - end - - # Then, replay the writes from each owner in-order - # FIXME: write_num should advance across overlapping ainfo's, as - # writes must be ordered sequentially - for (arg, ainfo_writes) in arg_writes - if length(ainfo_writes) > 1 - # FIXME: Remove me - deleteat!(ainfo_writes, 1:length(ainfo_writes)-1) - end - for (ainfo, dep_mod, data_remote_space) in ainfo_writes - # Is the source of truth elsewhere? - data_local_space = astate.data_origin[ainfo] - if data_local_space != data_remote_space - # Add copy-from operation - @dagdebug nothing :spawn_datadeps "[$dep_mod] Enqueueing copy-from: $data_remote_space => $data_local_space" - arg_local = get!(get!(IdDict{Any,Any}, state.remote_args, data_local_space), arg) do - generate_slot!(state, data_local_space, arg) - end - arg_remote = state.remote_args[data_remote_space][arg] - @assert arg_remote !== arg_local - data_local_proc = first(processors(data_local_space)) - copy_from_scope = UnionScope(map(ExactScope, collect(processors(data_local_space)))...) - copy_from_syncdeps = Set{ThunkSyncdep}() - get_write_deps!(state, ainfo, nothing, write_num, copy_from_syncdeps) - @dagdebug nothing :spawn_datadeps "$(length(copy_from_syncdeps)) syncdeps" - copy_from = Dagger.@spawn scope=copy_from_scope exec_scope=copy_from_scope syncdeps=copy_from_syncdeps meta=true Dagger.move!(dep_mod, data_local_space, data_remote_space, arg_local, arg_remote) - else - @dagdebug nothing :spawn_datadeps "[$dep_mod] Skipped copy-from (local): $data_remote_space" - end - end - end - else - for arg in keys(astate.data_origin) - # Is the data previously written? - arg, deps = unwrap_inout(arg) - if !type_may_alias(typeof(arg)) - @dagdebug nothing :spawn_datadeps "Skipped copy-from (immutable)" - end - - # Can the data be written back to? 
- if !supports_inplace_move(state, arg) - @dagdebug nothing :spawn_datadeps "Skipped copy-from (non-writeable)" - end - - # Is the source of truth elsewhere? - data_remote_space = astate.data_locality[arg] - data_local_space = astate.data_origin[arg] - if data_local_space != data_remote_space - # Add copy-from operation - @dagdebug nothing :spawn_datadeps "Enqueueing copy-from: $data_remote_space => $data_local_space" - arg_local = state.remote_args[data_local_space][arg] - arg_remote = state.remote_args[data_remote_space][arg] - @assert arg_remote !== arg_local - data_local_proc = first(processors(data_local_space)) - copy_from_scope = ExactScope(data_local_proc) - copy_from_syncdeps = Set{ThunkSyncdep}() - get_write_deps!(state, arg, nothing, write_num, copy_from_syncdeps) - @dagdebug nothing :spawn_datadeps "$(length(copy_from_syncdeps)) syncdeps" - copy_from = Dagger.@spawn scope=copy_from_scope exec_scope=copy_from_scope syncdeps=copy_from_syncdeps meta=true Dagger.move!(identity, data_local_space, data_remote_space, arg_local, arg_remote) - else - @dagdebug nothing :spawn_datadeps "Skipped copy-from (local): $data_remote_space" - end - end - end -end - -""" - spawn_datadeps(f::Base.Callable; traversal::Symbol=:inorder) - -Constructs a "datadeps" (data dependencies) region and calls `f` within it. -Dagger tasks launched within `f` may wrap their arguments with `In`, `Out`, or -`InOut` to indicate whether the task will read, write, or read+write that -argument, respectively. 
These argument dependencies will be used to specify -which tasks depend on each other based on the following rules: - -- Dependencies across unrelated arguments are independent; only dependencies on arguments which overlap in memory synchronize with each other -- `InOut` is the same as `In` and `Out` applied simultaneously, and synchronizes with the union of the `In` and `Out` effects -- Any two or more `In` dependencies do not synchronize with each other, and may execute in parallel -- An `Out` dependency synchronizes with any previous `In` and `Out` dependencies -- An `In` dependency synchronizes with any previous `Out` dependencies -- If unspecified, an `In` dependency is assumed - -In general, the result of executing tasks following the above rules will be -equivalent to simply executing tasks sequentially and in order of submission. -Of course, if dependencies are incorrectly specified, undefined behavior (and -unexpected results) may occur. - -Unlike other Dagger tasks, tasks executed within a datadeps region are allowed -to write to their arguments when annotated with `Out` or `InOut` -appropriately. - -At the end of executing `f`, `spawn_datadeps` will wait for all launched tasks -to complete, rethrowing the first error, if any. The result of `f` will be -returned from `spawn_datadeps`. - -The keyword argument `traversal` controls the order that tasks are launched by -the scheduler, and may be set to `:bfs` or `:dfs` for Breadth-First Scheduling -or Depth-First Scheduling, respectively. All traversal orders respect the -dependencies and ordering of the launched tasks, but may provide better or -worse performance for a given set of datadeps tasks. This argument is -experimental and subject to change. 
-""" -function spawn_datadeps(f::Base.Callable; static::Bool=true, - traversal::Symbol=:inorder, - scheduler::Union{Symbol,Nothing}=nothing, - aliasing::Bool=true, - launch_wait::Union{Bool,Nothing}=nothing) - if !static - throw(ArgumentError("Dynamic scheduling is no longer available")) - end - wait_all(; check_errors=true) do - scheduler = something(scheduler, DATADEPS_SCHEDULER[], :roundrobin)::Symbol - launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool - if launch_wait - result = spawn_bulk() do - queue = DataDepsTaskQueue(get_options(:task_queue); - traversal, scheduler, aliasing) - with_options(f; task_queue=queue) - distribute_tasks!(queue) - end - else - queue = DataDepsTaskQueue(get_options(:task_queue); - traversal, scheduler, aliasing) - result = with_options(f; task_queue=queue) - distribute_tasks!(queue) - end - return result - end -end -const DATADEPS_SCHEDULER = ScopedValue{Union{Symbol,Nothing}}(nothing) -const DATADEPS_LAUNCH_WAIT = ScopedValue{Union{Bool,Nothing}}(nothing) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl new file mode 100644 index 000000000..260af39ce --- /dev/null +++ b/src/datadeps/aliasing.jl @@ -0,0 +1,852 @@ +import Graphs: SimpleDiGraph, add_edge!, add_vertex!, inneighbors, outneighbors, nv + +export In, Out, InOut, Deps, spawn_datadeps + +#= +============================================================================== + DATADEPS ALIASING AND DATA MOVEMENT SYSTEM +============================================================================== + +This file implements the data dependencies system for Dagger tasks, which allows +tasks to access their arguments in a controlled manner. The system maintains +data coherency across distributed workers by tracking aliasing relationships +and orchestrating data movement operations. 
+ +OVERVIEW: +--------- +The datadeps system enables parallel execution of tasks that modify shared data +by analyzing memory aliasing relationships and scheduling appropriate data +transfers. The core challenge is maintaining coherency when aliased data (e.g., +an array and its views) needs to be accessed by tasks running on different workers. + +KEY CONCEPTS: +------------- + +1. ALIASING ANALYSIS: + - Every mutable argument is analyzed for its memory access pattern + - Memory spans are computed to determine which bytes in memory are accessed + - Arguments that access overlapping memory spans are considered "aliasing" + - Examples: An array A and view(A, 2:3, 2:3) alias each other + +2. DATA LOCALITY TRACKING: + - The system tracks where the "source of truth" for each piece of data lives + - As tasks execute and modify data, the source of truth may move between workers + - Each argument can have its own independent source of truth location + +3. ALIASED OBJECT MANAGEMENT: + - When copying arguments between workers, the system tracks "aliased objects" + - This ensures that if both an array and its view need to be copied to a worker, + only one copy of the underlying array is made, with the view pointing to it + - The aliased_object!() and move_rewrap() functions manage this sharing + +ALIASING INFO: +-------------- + +The system uses different types of aliasing info to represent different types of +aliasing relationships: + +- ContiguousAliasing: Single contiguous memory region (e.g., full array) +- StridedAliasing: Multiple non-contiguous regions (e.g., SubArray) +- DiagonalAliasing: Diagonal elements only (e.g., Diagonal(A)) +- TriangularAliasing: Triangular regions (e.g., UpperTriangular(A)) + +Any two aliasing objects can be compared using the will_alias function to +determine if they overlap. Additionally, any aliasing object can be converted to +a vector of memory spans, which represents the contiguous regions of memory that +the aliasing object covers. 
+ +DATA MOVEMENT FUNCTIONS: +------------------------ + +move!(dep_mod, to_space, from_space, to, from): +- The core in-place data movement function +- dep_mod specifies which part of the data to copy (identity, UpperTriangular, etc.) +- Supports partial copies via RemainderAliasing dependency modifiers + +move_rewrap(...): +- Handles copying of wrapped objects (SubArrays, ChunkViews) +- Ensures aliased objects are reused on destination worker + +read/write_remainder!(...): +- Read/write a span of memory from an object to/from a buffer +- Used by move! to copy the remainder of an aliased object + +THE DISTRIBUTED ALIASING PROBLEM: +--------------------------------- + +In a multithreaded environment, aliasing "just works" because all tasks operate +on the user-provided memory. However, in a distributed environment, arguments +must be copied between workers, which breaks aliasing relationships if care is +not taken. + +Consider this scenario: +```julia +A = rand(4, 4) +vA = view(A, 2:3, 2:3) + +Dagger.spawn_datadeps() do + Dagger.@spawn inc!(InOut(A), 1) # Task 1: increment all of A + Dagger.@spawn inc!(InOut(vA), 2) # Task 2: increment view of A +end +``` + +MULTITHREADED BEHAVIOR (WORKS): +- Both tasks run on the same worker +- They operate on the same memory, with proper dependency tracking +- Task dependencies ensure correct ordering (e.g., Task 1 then Task 2) + +DISTRIBUTED BEHAVIOR (THE PROBLEM): +- Each argument must be copied to the destination worker +- Without special handling, we would copy A and vA independently to another worker +- This creates two separate arrays, breaking the aliasing relationship between A and vA + +THE SOLUTION - PARTIAL DATA MOVEMENT: +------------------------------------- + +The datadeps system solves this by: + +1. 
UNIFIED ALLOCATION: + - When copying aliased objects, ensure only one underlying array exists per worker + - Use aliased_object!() to detect and reuse existing allocations + - Views on the destination worker point to the shared underlying array + +2. PARTIAL DATA TRANSFER: + - Instead of copying entire objects, only transfer the "dirty" regions + - This prevents overwrites of data that has already been updated by another task + - This also minimizes network traffic and overall copy time + - Uses the move!(dep_mod, ...) function with RemainderAliasing dependency modifiers + +3. REMAINDER TRACKING: + - When a task needs the full object, copy partial regions as needed + - When a partial region is updated, track what parts still need updating + - This preserves all updates while avoiding overwrites + +EXAMPLE EXECUTION FLOW: +----------------------- + +Given: A = 4x4 array, vA = view(A, 2:3, 2:3) +Tasks: T1 modifies InOut(A), T2 modifies InOut(vA) + +1. INITIAL STATE: + - A and vA both exist on worker0 (main worker) + - A's data_locality = worker0, vA's data_locality = worker0 + +2. T1 SCHEDULED ON WORKER1: + - Copy A from worker0 to worker1 + - T1 executes, modifying all of A on worker1 + - Update: A's data_locality = worker1, A is now "dirty" on worker1 + +3. T2 SCHEDULED ON WORKER2: + - T2 needs vA, but vA aliases with A (which was modified by T1) + - Copy vA-region of A from worker1 to worker2 + - This is a PARTIAL copy - only the 2:3, 2:3 region + - Create vA on worker2 pointing to the appropriate region of A + - T2 executes, modifying vA region on worker2 + - Update: vA's data_locality = worker2 + +4. FINAL SYNCHRONIZATION: + - Need to copy-back A and vA to worker0 + - A needs to be assembled from: worker1 (non-vA regions of A) + worker2 (vA region of A) + - REMAINDER COPY: Copy non-vA regions from worker1 to worker0 + - REMAINDER COPY: Copy vA region from worker2 to worker0 + +REMAINDER COMPUTATION: +---------------------- + +Remainder computation involves: +1. 
Computing memory spans for all overlapping aliasing objects +2. Finding the set difference: full_object_spans - updated_spans +3. Creating a RemainderAliasing object representing the difference between spans +4. Performing one or more move! calls with this RemainderAliasing object to copy only needed data +=# + +"Specifies a read-only dependency." +struct In{T} + x::T +end +"Specifies a write-only dependency." +struct Out{T} + x::T +end +"Specifies a read-write dependency." +struct InOut{T} + x::T +end +"Specifies one or more dependencies." +struct Deps{T,DT<:Tuple} + x::T + deps::DT +end +Deps(x, deps...) = Deps(x, deps) + +chunktype(::In{T}) where T = T +chunktype(::Out{T}) where T = T +chunktype(::InOut{T}) where T = T +chunktype(::Deps{T,DT}) where {T,DT} = T + +function unwrap_inout(arg) + readdep = false + writedep = false + if arg isa In + readdep = true + arg = arg.x + elseif arg isa Out + writedep = true + arg = arg.x + elseif arg isa InOut + readdep = true + writedep = true + arg = arg.x + elseif arg isa Deps + alldeps = Tuple[] + for dep in arg.deps + dep_mod, inner_deps = unwrap_inout(dep) + for (_, readdep, writedep) in inner_deps + push!(alldeps, (dep_mod, readdep, writedep)) + end + end + arg = arg.x + return arg, alldeps + else + readdep = true + end + return arg, Tuple[(identity, readdep, writedep)] +end + +_identity_hash(arg, h::UInt=UInt(0)) = ismutable(arg) ? 
objectid(arg) : hash(arg, h) +_identity_hash(arg::Chunk, h::UInt=UInt(0)) = hash(arg.handle, hash(Chunk, h)) +_identity_hash(arg::SubArray, h::UInt=UInt(0)) = hash(arg.indices, hash(arg.offset1, hash(arg.stride1, _identity_hash(arg.parent, h)))) +_identity_hash(arg::CartesianIndices, h::UInt=UInt(0)) = hash(arg.indices, hash(typeof(arg), h)) + +struct ArgumentWrapper + arg + dep_mod + hash::UInt + + function ArgumentWrapper(arg, dep_mod) + h = hash(dep_mod) + h = _identity_hash(arg, h) + return new(arg, dep_mod, h) + end +end +Base.hash(aw::ArgumentWrapper) = hash(ArgumentWrapper, aw.hash) +Base.:(==)(aw1::ArgumentWrapper, aw2::ArgumentWrapper) = + aw1.hash == aw2.hash +Base.isequal(aw1::ArgumentWrapper, aw2::ArgumentWrapper) = + aw1.hash == aw2.hash + +struct HistoryEntry + ainfo::AliasingWrapper + space::MemorySpace + write_num::Int +end + +struct AliasedObjectCacheStore + keys::Vector{AbstractAliasing} + derived::Dict{AbstractAliasing,AbstractAliasing} + stored::Dict{MemorySpace,Set{AbstractAliasing}} + values::Dict{MemorySpace,Dict{AbstractAliasing,Chunk}} +end +AliasedObjectCacheStore() = + AliasedObjectCacheStore(Vector{AbstractAliasing}(), + Dict{AbstractAliasing,AbstractAliasing}(), + Dict{MemorySpace,Set{AbstractAliasing}}(), + Dict{MemorySpace,Dict{AbstractAliasing,Chunk}}()) + +function is_stored(cache::AliasedObjectCacheStore, space::MemorySpace, ainfo::AbstractAliasing) + if !haskey(cache.stored, space) + return false + end + if !haskey(cache.derived, ainfo) + return false + end + key = cache.derived[ainfo] + return key in cache.stored[space] +end +function get_stored(cache::AliasedObjectCacheStore, space::MemorySpace, ainfo::AbstractAliasing) + @assert is_stored(cache, space, ainfo) "Cache does not have key $ainfo" + key = cache.derived[ainfo] + return cache.values[space][key] +end +function set_stored!(cache::AliasedObjectCacheStore, dest_space::MemorySpace, value::Chunk, ainfo::AbstractAliasing, orig_space::MemorySpace) + @assert !is_stored(cache, 
dest_space, ainfo) "Cache already has key $ainfo" + if !haskey(cache.derived, ainfo) + push!(cache.keys, ainfo) + cache.derived[ainfo] = ainfo + push!(get!(Set{AbstractAliasing}, cache.stored, orig_space), ainfo) + key = ainfo + else + key = cache.derived[ainfo] + end + value_ainfo = aliasing(value, identity) + cache.derived[value_ainfo] = key + push!(get!(Set{AbstractAliasing}, cache.stored, dest_space), key) + values_dict = get!(Dict{AbstractAliasing,Chunk}, cache.values, dest_space) + values_dict[key] = value + return +end + +struct AliasedObjectCache + space::MemorySpace + chunk::Chunk +end +function is_stored(cache::AliasedObjectCache, ainfo::AbstractAliasing) + wid = root_worker_id(cache.chunk) + if wid != myid() + return remotecall_fetch(is_stored, wid, cache, ainfo) + end + cache_raw = unwrap(cache.chunk)::AliasedObjectCacheStore + return is_stored(cache_raw, cache.space, ainfo) +end +function get_stored(cache::AliasedObjectCache, ainfo::AbstractAliasing) + wid = root_worker_id(cache.chunk) + if wid != myid() + return remotecall_fetch(get_stored, wid, cache, ainfo) + end + cache_raw = unwrap(cache.chunk)::AliasedObjectCacheStore + return get_stored(cache_raw, cache.space, ainfo) +end +function set_stored!(cache::AliasedObjectCache, value::Chunk, ainfo::AbstractAliasing, orig_space::MemorySpace) + wid = root_worker_id(cache.chunk) + if wid != myid() + return remotecall_fetch(set_stored!, wid, cache, value, ainfo, orig_space) + end + cache_raw = unwrap(cache.chunk)::AliasedObjectCacheStore + set_stored!(cache_raw, cache.space, value, ainfo, orig_space) + return +end +function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(x, identity)) + if is_stored(cache, ainfo) + return get_stored(cache, ainfo) + else + y = f(x) + @assert y isa Chunk "Didn't get a Chunk from functor" + @assert memory_space(y) == cache.space "Space mismatch! 
$(memory_space(y)) != $(cache.space)" + if memory_space(x) != cache.space + @assert ainfo != aliasing(y, identity) "Aliasing mismatch! $ainfo == $(aliasing(y, identity))" + end + set_stored!(cache, y, ainfo, memory_space(x)) + return y + end +end + +struct DataDepsState + # The mapping of original raw argument to its Chunk + raw_arg_to_chunk::IdDict{Any,Chunk} + + # The origin memory space of each argument + # Used to track the original location of an argument, for final copy-from + arg_origin::IdDict{Any,MemorySpace} + + # The mapping of memory space to argument to remote argument copies + # Used to replace an argument with its remote copy + remote_args::Dict{MemorySpace,IdDict{Any,Chunk}} + + # The mapping of remote argument to original argument + remote_arg_to_original::IdDict{Any,Any} + + # The mapping of original argument wrapper to remote argument wrapper + remote_arg_w::Dict{ArgumentWrapper,Dict{MemorySpace,ArgumentWrapper}} + + # The mapping of ainfo to argument and dep_mod + # Used to lookup which argument and dep_mod a given ainfo is generated from + # N.B. 
This is a mapping for remote argument copies + ainfo_arg::Dict{AliasingWrapper,Set{ArgumentWrapper}} + + # The history of writes (direct or indirect) to each argument and dep_mod, in terms of ainfos directly written to, and the memory space they were written to + # Updated when a new write happens on an overlapping ainfo + # Used by remainder copies to track which portions of an argument and dep_mod were written to elsewhere, through another argument + arg_history::Dict{ArgumentWrapper,Vector{HistoryEntry}} + + # The mapping of memory space and argument to the memory space of the last direct write + # Used by remainder copies to lookup the "backstop" if any portion of the target ainfo is not updated by the remainder + arg_owner::Dict{ArgumentWrapper,MemorySpace} + + # The overlap of each argument with every other argument, based on the ainfo overlaps + # Incrementally updated as new ainfos are created + # Used for fast history updates + arg_overlaps::Dict{ArgumentWrapper,Set{ArgumentWrapper}} + + # The mapping of, for a given memory space, the backing Chunks that an ainfo references + # Used by slot generation to replace the backing Chunks during move + ainfo_backing_chunk::Chunk{AliasedObjectCacheStore} + + # Cache of argument's supports_inplace_move query result + supports_inplace_cache::IdDict{Any,Bool} + + # Cache of argument and dep_mod to ainfo + # N.B. 
This is a mapping for remote argument copies + ainfo_cache::Dict{ArgumentWrapper,AliasingWrapper} + + # The oracle for aliasing lookups + # Used to populate ainfos_overlaps efficiently + ainfos_lookup::AliasingLookup + + # The overlapping ainfos for each ainfo + # Incrementally updated as new ainfos are created + # Used for fast will_alias lookups + ainfos_overlaps::Dict{AliasingWrapper,Set{AliasingWrapper}} + + # Track writers ("owners") and readers + # Updated as new writer and reader tasks are launched + # Used by task dependency tracking to calculate syncdeps and ensure correct launch ordering + ainfos_owner::Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}} + ainfos_readers::Dict{AliasingWrapper,Vector{Pair{DTask,Int}}} + + function DataDepsState(aliasing::Bool) + if !aliasing + @warn "aliasing=false is no longer supported, aliasing is now always enabled" maxlog=1 + end + + arg_to_chunk = IdDict{Any,Chunk}() + arg_origin = IdDict{Any,MemorySpace}() + remote_args = Dict{MemorySpace,IdDict{Any,Any}}() + remote_arg_to_original = IdDict{Any,Any}() + remote_arg_w = Dict{ArgumentWrapper,Dict{MemorySpace,ArgumentWrapper}}() + ainfo_arg = Dict{AliasingWrapper,Set{ArgumentWrapper}}() + arg_history = Dict{ArgumentWrapper,Vector{HistoryEntry}}() + arg_owner = Dict{ArgumentWrapper,MemorySpace}() + arg_overlaps = Dict{ArgumentWrapper,Set{ArgumentWrapper}}() + ainfo_backing_chunk = tochunk(AliasedObjectCacheStore()) + + supports_inplace_cache = IdDict{Any,Bool}() + ainfo_cache = Dict{ArgumentWrapper,AliasingWrapper}() + + ainfos_lookup = AliasingLookup() + ainfos_overlaps = Dict{AliasingWrapper,Set{AliasingWrapper}}() + + ainfos_owner = Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}() + ainfos_readers = Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}() + + return new(arg_to_chunk, arg_origin, remote_args, remote_arg_to_original, remote_arg_w, ainfo_arg, arg_history, arg_owner, arg_overlaps, ainfo_backing_chunk, + supports_inplace_cache, ainfo_cache, ainfos_lookup, 
ainfos_overlaps, ainfos_owner, ainfos_readers) + end +end + +function supports_inplace_move(state::DataDepsState, arg) + return get!(state.supports_inplace_cache, arg) do + return supports_inplace_move(arg) + end +end + +# Determine which arguments could be written to, and thus need tracking +"Whether `arg` is written to by `task`." +function is_writedep(arg, deps, task::DTask) + return any(dep->dep[3], deps) +end + +# Aliasing state setup +function populate_task_info!(state::DataDepsState, task_args, spec::DTaskSpec, task::DTask) + # Track the task's arguments and access patterns + return map_or_ntuple(task_args) do idx + _arg = task_args[idx] + + # Unwrap the argument + _arg_with_deps = value(_arg) + pos = _arg.pos + + # Unwrap In/InOut/Out wrappers and record dependencies + arg_pre_unwrap, deps = unwrap_inout(_arg_with_deps) + + # Unwrap the Chunk underlying any DTask arguments + arg = arg_pre_unwrap isa DTask ? fetch(arg_pre_unwrap; raw=true) : arg_pre_unwrap + + # Skip non-aliasing arguments or arguments that don't support in-place move + may_alias = type_may_alias(typeof(arg)) + inplace_move = may_alias && supports_inplace_move(state, arg) + if !may_alias || !inplace_move + arg_w = ArgumentWrapper(arg, identity) + if is_typed(spec) + return TypedDataDepsTaskArgument(arg, pos, may_alias, inplace_move, (DataDepsTaskDependency(arg_w, false, false),)) + else + return DataDepsTaskArgument(arg, pos, may_alias, inplace_move, [DataDepsTaskDependency(arg_w, false, false)]) + end + end + + # Generate a Chunk for the argument if necessary + if haskey(state.raw_arg_to_chunk, arg) + arg_chunk = state.raw_arg_to_chunk[arg] + else + if !(arg isa Chunk) + arg_chunk = tochunk(arg) + state.raw_arg_to_chunk[arg] = arg_chunk + else + state.raw_arg_to_chunk[arg] = arg + arg_chunk = arg + end + end + + # Track the origin space of the argument + origin_space = memory_space(arg_chunk) + state.arg_origin[arg_chunk] = origin_space + state.remote_arg_to_original[arg_chunk] = arg_chunk 
+ + # Populate argument info for all aliasing dependencies + # And return the argument, dependencies, and ArgumentWrappers + if is_typed(spec) + deps = Tuple(DataDepsTaskDependency(arg_chunk, dep) for dep in deps) + map_or_ntuple(deps) do dep_idx + dep = deps[dep_idx] + # Populate argument info + populate_argument_info!(state, dep.arg_w, origin_space) + end + return TypedDataDepsTaskArgument(arg_chunk, pos, may_alias, inplace_move, deps) + else + deps = [DataDepsTaskDependency(arg_chunk, dep) for dep in deps] + map_or_ntuple(deps) do dep_idx + dep = deps[dep_idx] + # Populate argument info + populate_argument_info!(state, dep.arg_w, origin_space) + end + return DataDepsTaskArgument(arg_chunk, pos, may_alias, inplace_move, deps) + end + end +end +function populate_argument_info!(state::DataDepsState, arg_w::ArgumentWrapper, origin_space::MemorySpace) + # Initialize ownership and history + if !haskey(state.arg_owner, arg_w) + # N.B. This is valid (even if the backing data is up-to-date elsewhere), + # because we only use this to track the "backstop" if any portion of the + # target ainfo is not updated by the remainder (at which point, this + # is thus the correct owner). + state.arg_owner[arg_w] = origin_space + + # Initialize the overlap set + state.arg_overlaps[arg_w] = Set{ArgumentWrapper}() + end + if !haskey(state.arg_history, arg_w) + state.arg_history[arg_w] = Vector{HistoryEntry}() + end + + # Calculate the ainfo (which will populate ainfo structures and merge history) + aliasing!(state, origin_space, arg_w) +end +# N.B. 
# arg_w must be the original argument wrapper, not a remote copy
"""
    aliasing!(state, target_space, arg_w) -> AliasingWrapper

Return (computing and caching on first use) the aliasing info for `arg_w`'s
copy in `target_space`, allocating that remote copy if necessary.
"""
function aliasing!(state::DataDepsState, target_space::MemorySpace, arg_w::ArgumentWrapper)
    # Find (or create) the wrapper for this argument's copy in target_space
    per_space = get!(Dict{MemorySpace,ArgumentWrapper}, state.remote_arg_w, arg_w)
    if haskey(per_space, target_space)
        remote_arg_w = per_space[target_space]
        remote_arg = remote_arg_w.arg
    else
        # Grab the remote copy of the argument, and record its wrapper
        remote_arg = get_or_generate_slot!(state, target_space, arg_w.arg)
        remote_arg_w = ArgumentWrapper(remote_arg, arg_w.dep_mod)
        per_space[target_space] = remote_arg_w
    end

    # Serve the cached result when available
    cached = get(state.ainfo_cache, remote_arg_w, nothing)
    cached === nothing || return cached

    # Calculate and cache the ainfo
    ainfo = AliasingWrapper(aliasing(remote_arg, arg_w.dep_mod))
    state.ainfo_cache[remote_arg_w] = ainfo

    # Update the mapping of ainfo to argument and dep_mod
    push!(get!(()->Set{ArgumentWrapper}(), state.ainfo_arg, ainfo), remote_arg_w)

    # Populate info for the new ainfo
    populate_ainfo!(state, arg_w, ainfo, target_space)

    return ainfo
end

"""
    populate_ainfo!(state, original_arg_w, target_ainfo, target_space)

First-time setup for `target_ainfo`: register it in the overlap oracle, link it
(and its arguments) with every overlapping ainfo, merge overlapping histories,
and initialize its owner/readers slots. No-op when already known.
"""
function populate_ainfo!(state::DataDepsState, original_arg_w::ArgumentWrapper, target_ainfo::AliasingWrapper, target_space::MemorySpace)
    haskey(state.ainfos_owner, target_ainfo) && return

    # Add ourselves to the lookup oracle
    ainfo_idx = push!(state.ainfos_lookup, target_ainfo)

    # Find overlapping ainfos (a set always containing ourselves)
    overlaps = Set{AliasingWrapper}((target_ainfo,))
    for other_ainfo in intersect(state.ainfos_lookup, target_ainfo; ainfo_idx)
        other_ainfo == target_ainfo && continue

        # Mark us and them as overlapping
        push!(overlaps, other_ainfo)
        push!(state.ainfos_overlaps[other_ainfo], target_ainfo)

        # Cross-link overlapping arguments and pull their history into ours
        for other_remote_arg_w in state.ainfo_arg[other_ainfo]
            other_arg = state.remote_arg_to_original[other_remote_arg_w.arg]
            other_arg_w = ArgumentWrapper(other_arg, other_remote_arg_w.dep_mod)
            push!(state.arg_overlaps[original_arg_w], other_arg_w)
            push!(state.arg_overlaps[other_arg_w], original_arg_w)
            merge_history!(state, original_arg_w, other_arg_w)
        end
    end
    state.ainfos_overlaps[target_ainfo] = overlaps

    # Initialize owner and readers
    state.ainfos_owner[target_ainfo] = nothing
    state.ainfos_readers[target_ainfo] = Pair{DTask,Int}[]
end

"""
    merge_history!(state, arg_w, other_arg_w)

Splice `other_arg_w`'s history entries into `arg_w`'s history, keeping it
ordered by `write_num` and skipping entries already present.
"""
function merge_history!(state::DataDepsState, arg_w::ArgumentWrapper, other_arg_w::ArgumentWrapper)
    history = state.arg_history[arg_w]
    @opcounter :merge_history
    @opcounter :merge_history_complexity length(history)
    origin_space = state.arg_origin[other_arg_w.arg]
    for other_entry in state.arg_history[other_arg_w]
        # Probe entry used purely as a `write_num` search key
        probe = HistoryEntry(AliasingWrapper(NoAliasing()), origin_space, other_entry.write_num)
        range = searchsorted(history, probe; by=entry->entry.write_num)
        if isempty(range)
            # No entry with this write_num yet: append at the end
            # NOTE(review): `first(range)` would be the sorted insertion point;
            # appending assumes incoming write_nums only ever grow — confirm.
            idx = length(history) + 1
        else
            # Skip if an identical entry is already recorded
            already_present = any(range) do source_idx
                source_entry = history[source_idx]
                source_entry.ainfo == other_entry.ainfo &&
                    source_entry.space == other_entry.space &&
                    source_entry.write_num == other_entry.write_num
            end
            already_present && continue

            # Insert at the first position with this write_num
            idx = first(range)
        end
        insert!(history, idx, other_entry)
    end
end

# Drop history entries whose effects are already captured elsewhere, once the
# history grows past a size threshold.
function truncate_history!(state::DataDepsState, arg_w::ArgumentWrapper)
    # FIXME: Do this continuously if possible
    haskey(state.arg_history, arg_w) || return
    history = state.arg_history[arg_w]
    length(history) > 100000 || return
    origin_space = state.arg_origin[arg_w.arg]
    @opcounter :truncate_history
    # The remainder computation reports the last index already covered; all
    # entries up to it can be removed
    _, last_idx = compute_remainder_for_arg!(state, origin_space, arg_w, 0; compute_syncdeps=false)
    if last_idx > 0
        @opcounter :truncate_history_removed last_idx
        deleteat!(history, 1:last_idx)
    end
end

"""
    supports_inplace_move(x) -> Bool

Returns `false` if `x` doesn't support being copied into from another object
like `x`, via `move!`. This is used in `spawn_datadeps` to prevent attempting
to copy between values which don't support mutation or otherwise don't have an
implemented `move!` and want to skip in-place copies. When this returns
`false`, datadeps will instead perform out-of-place copies for each non-local
use of `x`, and the data in `x` will not be updated when the `spawn_datadeps`
region returns.
"""
supports_inplace_move(x) = true
supports_inplace_move(t::DTask) = supports_inplace_move(fetch(t; raw=true))
function supports_inplace_move(c::Chunk)
    # FIXME: Use MemPool.access_ref
    owner = root_worker_id(c.processor)
    return owner == myid() ? supports_inplace_move(poolget(c.handle)) :
                             remotecall_fetch(supports_inplace_move, owner, c)
end
supports_inplace_move(::Function) = false

# Read/write dependency management
function get_write_deps!(state::DataDepsState, dest_space::MemorySpace, ainfo::AbstractAliasing, write_num, syncdeps)
    # Writers must sync with both previous writers and previous readers
    _get_write_deps!(state, dest_space, ainfo, write_num, syncdeps)
    _get_read_deps!(state, dest_space, ainfo, write_num, syncdeps)
end
function get_read_deps!(state::DataDepsState, dest_space::MemorySpace, ainfo::AbstractAliasing, write_num, syncdeps)
    # Readers only sync with previous writers, not other readers
    _get_write_deps!(state, dest_space, ainfo, write_num, syncdeps)
end

function _get_write_deps!(state::DataDepsState, dest_space::MemorySpace, ainfo::AbstractAliasing, write_num, syncdeps)
    ainfo.inner isa NoAliasing && return
    for other_ainfo in state.ainfos_overlaps[ainfo]
        @dagdebug nothing :spawn_datadeps_sync "Considering sync with writer via $ainfo -> $other_ainfo"
        owner = state.ainfos_owner[other_ainfo]
        owner === nothing && continue
        other_task, other_write_num = owner
        write_num == other_write_num && continue
        @dagdebug nothing :spawn_datadeps_sync "Sync with writer via $ainfo -> $other_ainfo"
        push!(syncdeps, ThunkSyncdep(other_task))
    end
end
function _get_read_deps!(state::DataDepsState, dest_space::MemorySpace, ainfo::AbstractAliasing, write_num, syncdeps)
    ainfo.inner isa NoAliasing && return
    for other_ainfo in state.ainfos_overlaps[ainfo]
        @dagdebug nothing :spawn_datadeps_sync "Considering sync with reader via $ainfo -> $other_ainfo"
        for (other_task, other_write_num) in state.ainfos_readers[other_ainfo]
            write_num == other_write_num && continue
            @dagdebug nothing :spawn_datadeps_sync "Sync with reader via $ainfo -> $other_ainfo"
            push!(syncdeps, ThunkSyncdep(other_task))
        end
    end
end
function add_writer!(state::DataDepsState, arg_w::ArgumentWrapper, dest_space::MemorySpace, ainfo::AbstractAliasing, task, write_num)
    # This write supersedes all previous owners/readers of the ainfo
    state.ainfos_owner[ainfo] = task=>write_num
    empty!(state.ainfos_readers[ainfo])

    # A new write event resets this argument's history to just this entry
    empty!(state.arg_history[arg_w])
    push!(state.arg_history[arg_w], HistoryEntry(ainfo, dest_space, write_num))

    # Overlapping arguments also see this write in their history
    for other_arg_w in state.arg_overlaps[arg_w]
        other_arg_w == arg_w && continue
        push!(state.arg_history[other_arg_w], HistoryEntry(ainfo, dest_space, write_num))
    end

    # Record the last place we were fully written to
    state.arg_owner[arg_w] = dest_space

    # Not necessary to assert a read, but conceptually it's true
    add_reader!(state, arg_w, dest_space, ainfo, task, write_num)
end
function add_reader!(state::DataDepsState, arg_w::ArgumentWrapper, dest_space::MemorySpace, ainfo::AbstractAliasing, task, write_num)
    push!(state.ainfos_readers[ainfo], task=>write_num)
end

# Make a copy of each piece of data on each worker
# memory_space => {arg => copy_of_arg}
isremotehandle(x) = false
false # (completes the `isremotehandle(x) =` fallback above: plain values are not remote handles)
isremotehandle(x::DTask) = true
isremotehandle(x::Chunk) = true

"""
    generate_slot!(state, dest_space, data) -> Chunk

Create a copy of `data` in `dest_space`, register it in `state`'s remote-args
and reverse mappings, and return the new `Chunk`.

N.B. We do not perform any sync/copy with the current owner of the data,
because all we want here is to make a copy of some version of the data,
even if the data is not up to date.
"""
function generate_slot!(state::DataDepsState, dest_space, data)
    orig_space = memory_space(data)
    to_proc = first(processors(dest_space))
    from_proc = first(processors(orig_space))
    dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space)
    aliased_object_cache = AliasedObjectCache(dest_space, state.ainfo_backing_chunk)
    ctx = Sch.eager_context()
    # Use a UInt id for consistency with the other `timespan_start`/`timespan_finish`
    # telemetry ids in this file (was `rand(Int)`)
    id = rand(UInt)
    @maybelog ctx timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data))
    data_chunk = move_rewrap(aliased_object_cache, from_proc, to_proc, orig_space, dest_space, data)
    @maybelog ctx timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=data_chunk))
    @assert memory_space(data_chunk) == dest_space "space mismatch! $dest_space (dest) != $(memory_space(data_chunk)) (actual) ($(typeof(data)) (data) vs. $(typeof(data_chunk)) (chunk)), spaces ($orig_space -> $dest_space)"
    dest_space_args[data] = data_chunk
    state.remote_arg_to_original[data_chunk] = data

    return dest_space_args[data]
end

"Return the copy of `data` in `dest_space`, creating it on first use."
function get_or_generate_slot!(state, dest_space, data)
    @assert !(data isa ArgumentWrapper)
    space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space)
    if !haskey(space_args, data)
        return generate_slot!(state, dest_space, data)
    end
    return space_args[data]
end

"""
    remotecall_endpoint(f, from_proc, to_proc, from_space, to_space, data) -> Chunk

On `to_proc`'s worker, `move` `data` over, apply `f`, and wrap the result in a
`Chunk` owned by `to_proc`. Runs locally when `to_proc` lives on this worker.
"""
function remotecall_endpoint(f, from_proc, to_proc, from_space, to_space, data)
    to_w = root_worker_id(to_proc)
    if to_w == myid()
        data_converted = f(move(from_proc, to_proc, data))
        return tochunk(data_converted, to_proc)
    end
    return remotecall_fetch(to_w, from_proc, to_proc, to_space, data) do from_proc, to_proc, to_space, data
        data_converted = f(move(from_proc, to_proc, data))
        return tochunk(data_converted, to_proc)
    end
end

# Copy `x` into the destination space as an aliased object (cached per-object)
function rewrap_aliased_object!(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x)
    return aliased_object!(cache, x) do x
        return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, x)
    end
end

# Chunk: unwrap on the owning worker (so we hit the right dispatch), then
# move-rewrap the raw value
function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data::Chunk)
    wid = root_worker_id(data)
    if wid != myid()
        return remotecall_fetch(move_rewrap, wid, cache, from_proc, to_proc, from_space, to_space, data)
    end
    return move_rewrap(cache, from_proc, to_proc, from_space, to_space, unwrap(data))
end

# Generic fallback: copy the whole object into the destination space
function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data)
    return aliased_object!(cache, data) do data
        return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, data)
    end
end

# SubArray: copy the parent once (cached), then re-create the view remotely
function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::SubArray)
    to_w = root_worker_id(to_proc)
    p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, parent(v))
    inds = parentindices(v)
    return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk, inds) do from_proc, to_proc, from_space, to_space, p_chunk, inds
        p_new = move(from_proc, to_proc, p_chunk)
        v_new = view(p_new, inds...)
        return tochunk(v_new, to_proc)
    end
end

# Triangular wrappers: copy the parent once (cached), then re-wrap remotely
# FIXME: Do this programmatically via recursive dispatch
for wrapper in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
    @eval function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::$(wrapper))
        to_w = root_worker_id(to_proc)
        p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, parent(v))
        return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk
            p_new = move(from_proc, to_proc, p_chunk)
            v_new = $(wrapper)(p_new)
            return tochunk(v_new, to_proc)
        end
    end
end

# RefValue: copy the referent once (cached), then re-`Ref` it remotely
function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::Base.RefValue)
    to_w = root_worker_id(to_proc)
    p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, v[])
    return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk
        p_new = move(from_proc, to_proc, p_chunk)
        v_new = Ref(p_new)
        return tochunk(v_new, to_proc)
    end
end
#= FIXME: Make this work so we can automatically move-rewrap recursive objects
function move_rewrap_recursive(cache::AliasedObjectCache,
from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::T) where T + if isstructtype(T) + # Check all object fields (recursive) + for field in fieldnames(T) + value = getfield(x, field) + new_value = aliased_object!(cache, value) do value + return move_rewrap_recursive(cache, from_proc, to_proc, from_space, to_space, value) + end + setfield!(x, field, new_value) + end + return x + else + @warn "Cannot move-rewrap object of type $T" + return x + end +end +move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::String) = x # FIXME: Not necessarily true +move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Symbol) = x +move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Type) = x +=# + +struct DataDepsSchedulerState + task_to_spec::Dict{DTask,DTaskSpec} + assignments::Dict{DTask,MemorySpace} + dependencies::Dict{DTask,Set{DTask}} + task_completions::Dict{DTask,UInt64} + space_completions::Dict{MemorySpace,UInt64} + capacities::Dict{MemorySpace,Int} + + function DataDepsSchedulerState() + return new(Dict{DTask,DTaskSpec}(), + Dict{DTask,MemorySpace}(), + Dict{DTask,Set{DTask}}(), + Dict{DTask,UInt64}(), + Dict{MemorySpace,UInt64}(), + Dict{MemorySpace,Int}()) + end +end \ No newline at end of file diff --git a/src/datadeps/chunkview.jl b/src/datadeps/chunkview.jl new file mode 100644 index 000000000..42f32cca9 --- /dev/null +++ b/src/datadeps/chunkview.jl @@ -0,0 +1,58 @@ +struct ChunkView{N} + chunk::Chunk + slices::NTuple{N, Union{Int, AbstractRange{Int}, Colon}} +end + +function Base.view(c::Chunk, slices...) 
+ if c.domain isa ArrayDomain + nd, sz = ndims(c.domain), size(c.domain) + nd == length(slices) || throw(DimensionMismatch("Expected $nd slices, got $(length(slices))")) + + for (i, s) in enumerate(slices) + if s isa Int + 1 ≤ s ≤ sz[i] || throw(ArgumentError("Index $s out of bounds for dimension $i (size $(sz[i]))")) + elseif s isa AbstractRange + isempty(s) && continue + 1 ≤ first(s) ≤ last(s) ≤ sz[i] || throw(ArgumentError("Range $s out of bounds for dimension $i (size $(sz[i]))")) + elseif s === Colon() + continue + else + throw(ArgumentError("Invalid slice type $(typeof(s)) at dimension $i, Expected Type of Int, AbstractRange, or Colon")) + end + end + end + + return ChunkView(c, slices) +end + +Base.view(c::DTask, slices...) = view(fetch(c; raw=true), slices...) + +function aliasing(x::ChunkView{N}) where N + return remotecall_fetch(root_worker_id(x.chunk.processor), x.chunk, x.slices) do x, slices + x = unwrap(x) + v = view(x, slices...) + return aliasing(v) + end +end +memory_space(x::ChunkView) = memory_space(x.chunk) +isremotehandle(x::ChunkView) = true + +function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, slice::ChunkView) + to_w = root_worker_id(to_proc) + p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, slice.chunk) + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk, slice.slices) do from_proc, to_proc, from_space, to_space, p_chunk, inds + p_new = move(from_proc, to_proc, p_chunk) + v_new = view(p_new, inds...) + return tochunk(v_new, to_proc) + end +end +function move(from_proc::Processor, to_proc::Processor, slice::ChunkView) + to_w = root_worker_id(to_proc) + return remotecall_fetch(to_w, from_proc, to_proc, slice.chunk, slice.slices) do from_proc, to_proc, chunk, slices + chunk_new = move(from_proc, to_proc, chunk) + v_new = view(chunk_new, slices...) 
+ return tochunk(v_new, to_proc) + end +end + +Base.fetch(slice::ChunkView) = view(fetch(slice.chunk), slice.slices...) \ No newline at end of file diff --git a/src/datadeps/queue.jl b/src/datadeps/queue.jl new file mode 100644 index 000000000..12e773d7a --- /dev/null +++ b/src/datadeps/queue.jl @@ -0,0 +1,560 @@ +struct DataDepsTaskQueue <: AbstractTaskQueue + # The queue above us + upper_queue::AbstractTaskQueue + # The set of tasks that have already been seen + seen_tasks::Union{Vector{DTaskPair},Nothing} + # The data-dependency graph of all tasks + g::Union{SimpleDiGraph{Int},Nothing} + # The mapping from task to graph ID + task_to_id::Union{Dict{DTask,Int},Nothing} + # How to traverse the dependency graph when launching tasks + traversal::Symbol + # Which scheduler to use to assign tasks to processors + scheduler::Symbol + + # Whether aliasing across arguments is possible + # The fields following only apply when aliasing==true + aliasing::Bool + + function DataDepsTaskQueue(upper_queue; + traversal::Symbol=:inorder, + scheduler::Symbol=:naive, + aliasing::Bool=true) + seen_tasks = DTaskPair[] + g = SimpleDiGraph() + task_to_id = Dict{DTask,Int}() + return new(upper_queue, seen_tasks, g, task_to_id, traversal, scheduler, + aliasing) + end +end + +function enqueue!(queue::DataDepsTaskQueue, pair::DTaskPair) + push!(queue.seen_tasks, pair) +end +function enqueue!(queue::DataDepsTaskQueue, pairs::Vector{DTaskPair}) + append!(queue.seen_tasks, pairs) +end + +""" + spawn_datadeps(f::Base.Callable; traversal::Symbol=:inorder) + +Constructs a "datadeps" (data dependencies) region and calls `f` within it. +Dagger tasks launched within `f` may wrap their arguments with `In`, `Out`, or +`InOut` to indicate whether the task will read, write, or read+write that +argument, respectively. 
These argument dependencies will be used to specify
which tasks depend on each other based on the following rules:

- Dependencies across unrelated arguments are independent; only dependencies on arguments which overlap in memory synchronize with each other
- `InOut` is the same as `In` and `Out` applied simultaneously, and synchronizes with the union of the `In` and `Out` effects
- Any two or more `In` dependencies do not synchronize with each other, and may execute in parallel
- An `Out` dependency synchronizes with any previous `In` and `Out` dependencies
- An `In` dependency synchronizes with any previous `Out` dependencies
- If unspecified, an `In` dependency is assumed

In general, the result of executing tasks following the above rules will be
equivalent to simply executing tasks sequentially and in order of submission.
Of course, if dependencies are incorrectly specified, undefined behavior (and
unexpected results) may occur.

Unlike other Dagger tasks, tasks executed within a datadeps region are allowed
to write to their arguments when annotated with `Out` or `InOut`
appropriately.

At the end of executing `f`, `spawn_datadeps` will wait for all launched tasks
to complete, rethrowing the first error, if any. The result of `f` will be
returned from `spawn_datadeps`.

The keyword argument `traversal` controls the order that tasks are launched by
the scheduler, and may be set to `:bfs` or `:dfs` for Breadth-First Scheduling
or Depth-First Scheduling, respectively. All traversal orders respect the
dependencies and ordering of the launched tasks, but may provide better or
worse performance for a given set of datadeps tasks. This argument is
experimental and subject to change.
"""
function spawn_datadeps(f::Base.Callable; static::Bool=true,
                        traversal::Symbol=:inorder,
                        scheduler::Union{Symbol,Nothing}=nothing,
                        aliasing::Bool=true,
                        launch_wait::Union{Bool,Nothing}=nothing)
    if !static
        throw(ArgumentError("Dynamic scheduling is no longer available"))
    end
    if !aliasing
        throw(ArgumentError("Aliasing analysis is no longer optional"))
    end
    wait_all(; check_errors=true) do
        scheduler = something(scheduler, DATADEPS_SCHEDULER[], :roundrobin)::Symbol
        launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool
        if launch_wait
            result = spawn_bulk() do
                queue = DataDepsTaskQueue(get_options(:task_queue);
                                          traversal, scheduler, aliasing)
                # Capture `f`'s result so we return it (not the return value
                # of `distribute_tasks!`), matching the docstring and the
                # non-`launch_wait` path below
                f_result = with_options(f; task_queue=queue)
                distribute_tasks!(queue)
                f_result
            end
        else
            queue = DataDepsTaskQueue(get_options(:task_queue);
                                      traversal, scheduler, aliasing)
            result = with_options(f; task_queue=queue)
            distribute_tasks!(queue)
        end
        return result
    end
end
const DATADEPS_SCHEDULER = ScopedValue{Union{Symbol,Nothing}}(nothing)
const DATADEPS_LAUNCH_WAIT = ScopedValue{Union{Bool,Nothing}}(nothing)

"""
    distribute_tasks!(queue::DataDepsTaskQueue)

Launch all tasks recorded in `queue`, in the configured traversal order, onto
the available processors, then enqueue the final remote-to-origin copies so
each written argument's origin space holds the up-to-date data.
"""
function distribute_tasks!(queue::DataDepsTaskQueue)
    #= TODO: Improvements to be made:
    # - Support for copying non-AbstractArray arguments
    # - Parallelize read copies
    # - Unreference unused slots
    # - Reuse memory when possible
    # - Account for differently-sized data
    =#

    # Get the set of all processors to be scheduled on
    all_procs = Processor[]
    scope = get_compute_scope()
    for w in procs()
        append!(all_procs, get_processors(OSProc(w)))
    end
    filter!(proc->proc_in_scope(proc, scope), all_procs)
    if isempty(all_procs)
        throw(Sch.SchedulingException("No processors available, try widening scope"))
    end
    scope = UnionScope(map(ExactScope, all_procs))
    exec_spaces = unique(vcat(map(proc->collect(memory_spaces(proc)), all_procs)...))
    if !all(space->space isa CPURAMMemorySpace, exec_spaces) && !all(space->root_worker_id(space) == myid(), exec_spaces)
        @warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1
    end

    # Round-robin assign tasks to processors
    upper_queue = get_options(:task_queue)

    traversal = queue.traversal
    if traversal == :inorder
        # As-is
        task_order = Colon()
    elseif traversal == :bfs
        # BFS
        task_order = Int[1]
        to_walk = Int[1]
        seen = Set{Int}([1])
        while !isempty(to_walk)
            # N.B. next_root has already been seen
            next_root = popfirst!(to_walk)
            for v in outneighbors(queue.g, next_root)
                if !(v in seen)
                    push!(task_order, v)
                    push!(seen, v)
                    push!(to_walk, v)
                end
            end
        end
    elseif traversal == :dfs
        # DFS (modified with backtracking)
        task_order = Int[]
        to_walk = Int[1]
        seen = Set{Int}()
        while length(task_order) < length(queue.seen_tasks) && !isempty(to_walk)
            next_root = popfirst!(to_walk)
            if !(next_root in seen)
                iv = inneighbors(queue.g, next_root)
                if all(v->v in seen, iv)
                    push!(task_order, next_root)
                    push!(seen, next_root)
                    ov = outneighbors(queue.g, next_root)
                    prepend!(to_walk, ov)
                else
                    push!(to_walk, next_root)
                end
            end
        end
    else
        throw(ArgumentError("Invalid traversal mode: $traversal"))
    end

    state = DataDepsState(queue.aliasing)
    sstate = DataDepsSchedulerState()
    for proc in all_procs
        space = only(memory_spaces(proc))
        get!(()->0, sstate.capacities, space)
        sstate.capacities[space] += 1
    end

    # Start launching tasks and necessary copies
    write_num = 1
    proc_idx = 1
    #pressures = Dict{Processor,Int}()
    proc_to_scope_lfu = BasicLFUCache{Processor,AbstractScope}(1024)
    for pair in queue.seen_tasks[task_order]
        spec = pair.spec
        task = pair.task
        write_num, proc_idx = distribute_task!(queue, state, all_procs, scope, spec, task, spec.fargs, proc_to_scope_lfu, write_num, proc_idx)
    end

    # Copy args from remote to local
    # N.B. We sort the keys to ensure a deterministic order for uniformity
    for arg_w in sort(collect(keys(state.arg_owner)); by=arg_w->arg_w.hash)
        arg = arg_w.arg
        origin_space = state.arg_origin[arg]
        remainder, _ = compute_remainder_for_arg!(state, origin_space, arg_w, write_num)
        if remainder isa MultiRemainderAliasing
            origin_scope = UnionScope(map(ExactScope, collect(processors(origin_space)))...)
            enqueue_remainder_copy_from!(state, origin_space, arg_w, remainder, origin_scope, write_num)
        elseif remainder isa FullCopy
            origin_scope = UnionScope(map(ExactScope, collect(processors(origin_space)))...)
            enqueue_copy_from!(state, origin_space, arg_w, origin_scope, write_num)
        else
            @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))"
            @dagdebug nothing :spawn_datadeps "Skipped copy-from (up-to-date): $origin_space"
            ctx = Sch.eager_context()
            id = rand(UInt)
            @maybelog ctx timespan_start(ctx, :datadeps_copy_skip, (;id), (;))
            @maybelog ctx timespan_finish(ctx, :datadeps_copy_skip, (;id), (;thunk_id=0, from_space=origin_space, to_space=origin_space, arg_w, from_arg=arg, to_arg=arg))
        end
    end
end

# A single declared dependency: target argument (+ dep_mod) and read/write flags
struct DataDepsTaskDependency
    arg_w::ArgumentWrapper
    readdep::Bool
    writedep::Bool
end
DataDepsTaskDependency(arg, dep) =
    DataDepsTaskDependency(ArgumentWrapper(arg, dep[1]), dep[2], dep[3])
struct DataDepsTaskArgument
    arg
    pos::ArgPosition
    may_alias::Bool
    inplace_move::Bool
    deps::Vector{DataDepsTaskDependency}
end
struct TypedDataDepsTaskArgument{T,N}
    arg::T
    pos::ArgPosition
    may_alias::Bool
    inplace_move::Bool
    deps::NTuple{N,DataDepsTaskDependency}
end
# Map `f` over the indices of `xs`: Vector input yields a Vector, Tuple input
# yields an NTuple (via `ntuple`, preserving length in the type domain)
map_or_ntuple(f, xs::Vector) = map(f, 1:length(xs))
@inline map_or_ntuple(@specialize(f), xs::NTuple{N,T}) where {N,T} = ntuple(f, Val(N))
# NOTE(review): further down in this function, the scheduler-dispatch `else`
# branch interpolates `$sched`, but the local variable is named `scheduler` —
# that path would throw `UndefVarError` instead of the intended error; confirm
# and rename to `$scheduler`.
function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_procs, scope, spec::DTaskSpec{typed}, task::DTask, fargs, proc_to_scope_lfu, write_num::Int, proc_idx::Int) where typed
    @specialize
spec fargs + + if typed + fargs::Tuple + else + fargs::Vector{Argument} + end + + task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) + scheduler = queue.scheduler + if scheduler == :naive + raw_args = map(arg->tochunk(value(arg)), spec.fargs) + our_proc = remotecall_fetch(1, all_procs, raw_args) do all_procs, raw_args + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + + @lock sch_state.lock begin + # Calculate costs per processor and select the most optimal + # FIXME: This should consider any already-allocated slots, + # whether they are up-to-date, and if not, the cost of moving + # data to them + procs, costs = Sch.estimate_task_costs(sch_state, all_procs, nothing, raw_args) + return first(procs) + end + end + elseif scheduler == :smart + raw_args = map(filter(arg->haskey(state.data_locality, value(arg)), spec.fargs)) do arg + arg_chunk = tochunk(value(arg)) + # Only the owned slot is valid + # FIXME: Track up-to-date copies and pass all of those + return arg_chunk => data_locality[arg] + end + f_chunk = tochunk(value(spec.fargs[1])) + our_proc, task_pressure = remotecall_fetch(1, all_procs, pressures, f_chunk, raw_args) do all_procs, pressures, f, chunks_locality + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + + @lock sch_state.lock begin + tx_rate = sch_state.transfer_rate[] + + costs = Dict{Processor,Float64}() + for proc in all_procs + # Filter out chunks that are already local + chunks_filt = Iterators.filter(((chunk, space)=chunk_locality)->!(proc in processors(space)), chunks_locality) + + # Estimate network transfer costs based on data size + # N.B. `affinity(x)` really means "data size of `x`" + # N.B. 
We treat same-worker transfers as having zero transfer cost + tx_cost = Sch.impute_sum(affinity(chunk)[2] for chunk in chunks_filt) + + # Estimate total cost to move data and get task running after currently-scheduled tasks + est_time_util = get(pressures, proc, UInt64(0)) + costs[proc] = est_time_util + (tx_cost/tx_rate) + end + + # Look up estimated task cost + sig = Sch.signature(sch_state, f, map(first, chunks_locality)) + task_pressure = get(sch_state.signature_time_cost, sig, 1000^3) + + # Shuffle procs around, so equally-costly procs are equally considered + P = randperm(length(all_procs)) + procs = getindex.(Ref(all_procs), P) + + # Sort by lowest cost first + sort!(procs, by=p->costs[p]) + + best_proc = first(procs) + return best_proc, task_pressure + end + end + # FIXME: Pressure should be decreased by pressure of syncdeps on same processor + pressures[our_proc] = get(pressures, our_proc, UInt64(0)) + task_pressure + elseif scheduler == :ultra + args = Base.mapany(spec.fargs) do arg + pos, data = arg + data, _ = unwrap_inout(data) + if data isa DTask + data = fetch(data; move_value=false, unwrap=false) + end + return pos => tochunk(data) + end + f_chunk = tochunk(value(spec.fargs[1])) + task_time = remotecall_fetch(1, f_chunk, args) do f, args + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + return @lock sch_state.lock begin + sig = Sch.signature(sch_state, f, args) + return get(sch_state.signature_time_cost, sig, 1000^3) + end + end + + # FIXME: Copy deps are computed eagerly + deps = @something(spec.options.syncdeps, Set{Any}()) + + # Find latest time-to-completion of all syncdeps + deps_completed = UInt64(0) + for dep in deps + haskey(sstate.task_completions, dep) || continue # copy deps aren't recorded + deps_completed = max(deps_completed, sstate.task_completions[dep]) + end + + # Find latest time-to-completion of each memory space + # FIXME: Figure out space completions based on optimal packing + spaces_completed = Dict{MemorySpace,UInt64}() + 
for space in exec_spaces + completed = UInt64(0) + for (task, other_space) in sstate.assignments + space == other_space || continue + completed = max(completed, sstate.task_completions[task]) + end + spaces_completed[space] = completed + end + + # Choose the earliest-available memory space and processor + # FIXME: Consider move time + move_time = UInt64(0) + local our_space_completed + while true + our_space_completed, our_space = findmin(spaces_completed) + our_space_procs = filter(proc->proc in all_procs, processors(our_space)) + if isempty(our_space_procs) + delete!(spaces_completed, our_space) + continue + end + our_proc = rand(our_space_procs) + break + end + + sstate.task_to_spec[task] = spec + sstate.assignments[task] = our_space + sstate.task_completions[task] = our_space_completed + move_time + task_time + elseif scheduler == :roundrobin + our_proc = all_procs[proc_idx] + if task_scope == scope + # all_procs is already limited to scope + else + if isa(constrain(task_scope, scope), InvalidScope) + throw(Sch.SchedulingException("Scopes are not compatible: $(scope), $(task_scope)")) + end + while !proc_in_scope(our_proc, task_scope) + proc_idx = mod1(proc_idx + 1, length(all_procs)) + our_proc = all_procs[proc_idx] + end + end + else + error("Invalid scheduler: $sched") + end + @assert our_proc in all_procs + our_space = only(memory_spaces(our_proc)) + + # Find the scope for this task (and its copies) + task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) + if task_scope == scope + # Optimize for the common case, cache the proc=>scope mapping + our_scope = get!(proc_to_scope_lfu, our_proc) do + our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) + return constrain(UnionScope(map(ExactScope, our_procs)...), scope) + end + else + # Use the provided scope and constrain it to the available processors + our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) + our_scope = 
constrain(UnionScope(map(ExactScope, our_procs)...), task_scope) + end + if our_scope isa InvalidScope + throw(Sch.SchedulingException("Scopes are not compatible: $(our_scope.x), $(our_scope.y)")) + end + + f = spec.fargs[1] + tid = task.uid + # FIXME: May not be correct to move this under uniformity + #f.value = move(default_processor(), our_proc, value(f)) + @dagdebug tid :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" + + # Copy raw task arguments for analysis + # N.B. Used later for checking dependencies + task_args = map_or_ntuple(idx->copy(spec.fargs[idx]), spec.fargs) + + # Populate all task dependencies + task_arg_ws = populate_task_info!(state, task_args, spec, task) + + # Truncate the history for each argument + map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + map_or_ntuple(arg_ws.deps) do dep_idx + dep = arg_ws.deps[dep_idx] + truncate_history!(state, dep.arg_w) + end + return + end + + # Copy args from local to remote + remote_args = map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + arg = arg_ws.arg + pos = raw_position(arg_ws.pos) + + # Is the data written previously or now? + if !arg_ws.may_alias + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (immutable)" + return arg + end + + # Is the data writeable? + if !arg_ws.inplace_move + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (non-writeable)" + return arg + end + + # Is the source of truth elsewhere? 
+ arg_remote = get_or_generate_slot!(state, our_space, arg) + map_or_ntuple(arg_ws.deps) do dep_idx + dep = arg_ws.deps[dep_idx] + arg_w = dep.arg_w + dep_mod = arg_w.dep_mod + remainder, _ = compute_remainder_for_arg!(state, our_space, arg_w, write_num) + if remainder isa MultiRemainderAliasing + enqueue_remainder_copy_to!(state, our_space, arg_w, remainder, value(f), idx, our_scope, task, write_num) + elseif remainder isa FullCopy + enqueue_copy_to!(state, our_space, arg_w, value(f), idx, our_scope, task, write_num) + else + @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Skipped copy-to (up-to-date): $our_space" + end + end + return arg_remote + end + write_num += 1 + + # Validate that we're not accidentally performing a copy + map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + arg = remote_args[idx] + + # Get the dependencies again as (dep_mod, readdep, writedep) + deps = map_or_ntuple(arg_ws.deps) do dep_idx + dep = arg_ws.deps[dep_idx] + (dep.arg_w.dep_mod, dep.readdep, dep.writedep) + end + + # Check that any mutable and written arguments are already in the correct space + # N.B. 
We only do this check when the argument supports in-place + # moves, because for the moment, we are not guaranteeing updates or + # write-back of results + if is_writedep(arg, deps, task) && arg_ws.may_alias && arg_ws.inplace_move + arg_space = memory_space(arg) + @assert arg_space == our_space "($(repr(value(f))))[$(idx-1)] Tried to pass $(typeof(arg)) from $arg_space to $our_space" + end + end + + # Calculate this task's syncdeps + if spec.options.syncdeps === nothing + spec.options.syncdeps = Set{ThunkSyncdep}() + end + syncdeps = spec.options.syncdeps + map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + arg = arg_ws.arg + arg_ws.may_alias || return + arg_ws.inplace_move || return + map_or_ntuple(arg_ws.deps) do dep_idx + dep = arg_ws.deps[dep_idx] + arg_w = dep.arg_w + ainfo = aliasing!(state, our_space, arg_w) + dep_mod = arg_w.dep_mod + if dep.writedep + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as writer" + get_write_deps!(state, our_space, ainfo, write_num, syncdeps) + else + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as reader" + get_read_deps!(state, our_space, ainfo, write_num, syncdeps) + end + end + return + end + @dagdebug tid :spawn_datadeps "($(repr(value(f)))) Task has $(length(syncdeps)) syncdeps" + + # Launch user's task + new_fargs = map_or_ntuple(task_arg_ws) do idx + if is_typed(spec) + return TypedArgument(task_arg_ws[idx].pos, remote_args[idx]) + else + return Argument(task_arg_ws[idx].pos, remote_args[idx]) + end + end + new_spec = DTaskSpec(new_fargs, spec.options) + new_spec.options.scope = our_scope + new_spec.options.exec_scope = our_scope + new_spec.options.occupancy = Dict(Any=>0) + ctx = Sch.eager_context() + @maybelog ctx timespan_start(ctx, :datadeps_execute, (;thunk_id=task.uid), (;)) + enqueue!(queue.upper_queue, DTaskPair(new_spec, task)) + @maybelog ctx timespan_finish(ctx, :datadeps_execute, (;thunk_id=task.uid), (;space=our_space, 
deps=task_arg_ws, args=remote_args)) + + # Update read/write tracking for arguments + map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + arg = arg_ws.arg + arg_ws.may_alias || return + arg_ws.inplace_move || return + for dep in arg_ws.deps + arg_w = dep.arg_w + ainfo = aliasing!(state, our_space, arg_w) + dep_mod = arg_w.dep_mod + if dep.writedep + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Task set as writer" + add_writer!(state, arg_w, our_space, ainfo, task, write_num) + else + add_reader!(state, arg_w, our_space, ainfo, task, write_num) + end + end + return + end + + write_num += 1 + proc_idx = mod1(proc_idx + 1, length(all_procs)) + + return write_num, proc_idx +end diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl new file mode 100644 index 000000000..ed0a9fc1c --- /dev/null +++ b/src/datadeps/remainders.jl @@ -0,0 +1,518 @@ +# Remainder tracking and computation functions + +""" + RemainderAliasing{S<:MemorySpace} <: AbstractAliasing + +Represents the memory spans that remain after subtracting some regions from a base aliasing object. +This is used to perform partial data copies that only update the "remainder" regions. 
+""" +struct RemainderAliasing{S<:MemorySpace} <: AbstractAliasing + space::S + spans::Vector{Tuple{LocalMemorySpan,LocalMemorySpan}} + syncdeps::Set{ThunkSyncdep} +end +RemainderAliasing(space::S, spans::Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}, syncdeps::Set{ThunkSyncdep}) where S = + RemainderAliasing{S}(space, spans, syncdeps) + +memory_spans(ra::RemainderAliasing) = ra.spans + +Base.hash(ra::RemainderAliasing, h::UInt) = hash(ra.spans, hash(RemainderAliasing, h)) +Base.:(==)(ra1::RemainderAliasing, ra2::RemainderAliasing) = ra1.spans == ra2.spans + +# Add will_alias support for RemainderAliasing +function will_alias(x::RemainderAliasing, y::AbstractAliasing) + return will_alias(memory_spans(x), memory_spans(y)) +end + +function will_alias(x::AbstractAliasing, y::RemainderAliasing) + return will_alias(memory_spans(x), memory_spans(y)) +end + +function will_alias(x::RemainderAliasing, y::RemainderAliasing) + return will_alias(memory_spans(x), memory_spans(y)) +end + +struct MultiRemainderAliasing <: AbstractAliasing + remainders::Vector{<:RemainderAliasing} +end +MultiRemainderAliasing() = MultiRemainderAliasing(RemainderAliasing[]) + +memory_spans(mra::MultiRemainderAliasing) = vcat(memory_spans.(mra.remainders)...) + +Base.hash(mra::MultiRemainderAliasing, h::UInt) = hash(mra.remainders, hash(MultiRemainderAliasing, h)) +Base.:(==)(mra1::MultiRemainderAliasing, mra2::MultiRemainderAliasing) = mra1.remainders == mra2.remainders + +struct FullCopy end + +""" + compute_remainder_for_arg!(state::DataDepsState, + target_space::MemorySpace, + arg_w::ArgumentWrapper) + +Computes what remainder regions need to be copied to `target_space` before a task can access `arg_w`. +Returns a `MultiRemainderAliasing` object representing the remainder, or `NoAliasing()` if no remainder needed. + +The algorithm starts by collecting the memory spans of `arg_w` in `target_space` - this is the "remainder". +When this remainder is empty, the algorithm will be finished. 
+Additionally, a dictionary is created to store the source and destination
+memory spans (for each source memory space) that will be used to create the
+`MultiRemainderAliasing` object - this is the "tracker".
+
+The algorithm walks backwards through the `arg_history` vector for `arg_w`
+(which is an ordered list of all overlapping ainfos that were directly written to (potentially in a different memory space than `target_space`)
+since the last time this `arg_w` was written to). If this ainfo is in `target_space`,
+then it is not under consideration; it is simply subtracted from the remainder with `subtract_spans!`,
+and the algorithm goes to the next ainfo. Otherwise, the algorithm will consider this ainfo for tracking.
+
+For each overlapping ainfo (which lives in a different memory space than `target_space`) to be tracked, there exists a corresponding "mirror" ainfo in
+`target_space`, which is the equivalent of the overlapping ainfo, but in
+`target_space`. This mirror ainfo is assumed to have an identical number of memory spans as the overlapping ainfo,
+and each memory span is assumed to be identical in size, but not necessarily identical in address.
+
+These three sets of memory spans (from the remainder, the overlapping ainfo, and the mirror ainfo) are then passed to `schedule_remainder!`.
+This call will subtract the spans of the mirror ainfo from the remainder (as the two live in the same memory space and thus can be directly compared),
+and will update the remainder accordingly.
+Additionally, it will also use this subtraction to update the tracker, by adding the equivalent spans (mapped from mirror ainfo to overlapping ainfo) to the tracker as the source,
+and the spans of the remainder as the destination.
+
+If the history is exhausted without the remainder becoming empty, then the
+remaining data in `target_space` is assumed to be up-to-date (as the latest write
+to `arg_w` is the furthest back we need to consider). 
+ +Finally, the tracker is converted into a `MultiRemainderAliasing` object, +and returned. +""" +function compute_remainder_for_arg!(state::DataDepsState, + target_space::MemorySpace, + arg_w::ArgumentWrapper, + write_num::Int; compute_syncdeps::Bool=true) + spaces_set = Set{MemorySpace}() + push!(spaces_set, target_space) + owner_space = state.arg_owner[arg_w] + push!(spaces_set, owner_space) + + @label restart + + # Determine all memory spaces of the history + for entry in state.arg_history[arg_w] + push!(spaces_set, entry.space) + end + spaces = collect(spaces_set) + N = length(spaces) + + # Lookup all memory spans for arg_w in these spaces + target_ainfos = Vector{Vector{LocalMemorySpan}}() + for space in spaces + target_space_ainfo = aliasing!(state, space, arg_w) + spans = memory_spans(target_space_ainfo) + push!(target_ainfos, LocalMemorySpan.(spans)) + end + nspans = length(first(target_ainfos)) + @assert all(==(nspans), length.(target_ainfos)) "Aliasing info for $(typeof(arg_w.arg))[$(arg_w.dep_mod)] has different number of spans in different memory spaces" + + # FIXME: This is a hack to ensure that we don't miss any history generated by aliasing(...) 
+ for entry in state.arg_history[arg_w] + if !in(entry.space, spaces) + @opcounter :compute_remainder_for_arg_restart + @goto restart + end + end + + # We may only need to schedule a full copy from the origin space to the + # target space if this is the first time we've written to `arg_w` + if isempty(state.arg_history[arg_w]) + if owner_space != target_space + return FullCopy(), 0 + else + return NoAliasing(), 0 + end + end + + # Create our remainder as an interval tree over all target ainfos + VERIFY_SPAN_CURRENT_OBJECT[] = arg_w.arg + remainder = IntervalTree{ManyMemorySpan{N}}(ManyMemorySpan{N}(ntuple(i -> target_ainfos[i][j], N)) for j in 1:nspans) + for span in remainder + verify_span(span) + end + + # Create our tracker + tracker = Dict{MemorySpace,Tuple{Vector{Tuple{LocalMemorySpan,LocalMemorySpan}},Set{ThunkSyncdep}}}() + + # Walk backwards through the history of writes to this target + # other_ainfo is the overlapping ainfo that was written to + # other_space is the memory space of the overlapping ainfo + last_idx = length(state.arg_history[arg_w]) + for idx in length(state.arg_history[arg_w]):-1:0 + if isempty(remainder) + # All done! 
+ last_idx = idx + break + end + + if idx > 0 + other_entry = state.arg_history[arg_w][idx] + other_ainfo = other_entry.ainfo + other_space = other_entry.space + else + # If we've reached the end of the history, evaluate ourselves + other_ainfo = aliasing!(state, owner_space, arg_w) + other_space = owner_space + end + + # Lookup all memory spans for arg_w in these spaces + other_remote_arg_w = first(collect(state.ainfo_arg[other_ainfo])) + other_arg_w = ArgumentWrapper(state.remote_arg_to_original[other_remote_arg_w.arg], other_remote_arg_w.dep_mod) + other_ainfos = Vector{Vector{LocalMemorySpan}}() + for space in spaces + other_space_ainfo = aliasing!(state, space, other_arg_w) + spans = memory_spans(other_space_ainfo) + push!(other_ainfos, LocalMemorySpan.(spans)) + end + nspans = length(first(other_ainfos)) + other_many_spans = [ManyMemorySpan{N}(ntuple(i -> other_ainfos[i][j], N)) for j in 1:nspans] + foreach(other_many_spans) do span + verify_span(span) + end + + if other_space == target_space + # Only subtract, this data is already up-to-date in target_space + # N.B. We don't add to syncdeps here, because we'll see this ainfo + # in get_write_deps! 
+ @opcounter :compute_remainder_for_arg_subtract + subtract_spans!(remainder, other_many_spans) + continue + end + + # Subtract from remainder and schedule copy in tracker + other_space_idx = something(findfirst(==(other_space), spaces)) + target_space_idx = something(findfirst(==(target_space), spaces)) + tracker_other_space = get!(tracker, other_space) do + (Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}(), Set{ThunkSyncdep}()) + end + @opcounter :compute_remainder_for_arg_schedule + has_overlap = schedule_remainder!(tracker_other_space[1], other_space_idx, target_space_idx, remainder, other_many_spans) + if compute_syncdeps && has_overlap + @assert haskey(state.ainfos_owner, other_ainfo) "[idx $idx] ainfo $(typeof(other_ainfo)) has no owner" + get_read_deps!(state, other_space, other_ainfo, write_num, tracker_other_space[2]) + end + end + VERIFY_SPAN_CURRENT_OBJECT[] = nothing + + if isempty(tracker) || all(tracked->isempty(tracked[1]), values(tracker)) + return NoAliasing(), 0 + end + + # Return scheduled copies and the index of the last ainfo we considered + mra = MultiRemainderAliasing() + for space in spaces + if haskey(tracker, space) + spans, syncdeps = tracker[space] + if !isempty(spans) + push!(mra.remainders, RemainderAliasing(space, spans, syncdeps)) + end + end + end + @assert !isempty(mra.remainders) "Expected at least one remainder (spaces: $spaces, tracker spaces: $(collect(keys(tracker))))" + return mra, last_idx +end + +### Memory Span Set Operations for Remainder Computation + +""" + schedule_remainder!(tracker::Vector, source_space_idx::Int, dest_space_idx::Int, remainder::IntervalTree, other_many_spans::Vector{ManyMemorySpan{N}}) + +Calculates the difference between `remainder` and `other_many_spans`, subtracts +it from `remainder`, and then adds that difference to `tracker` as a scheduled +copy from `other_many_spans` to the subtraced portion of `remainder`. 
+""" +function schedule_remainder!(tracker::Vector, source_space_idx::Int, dest_space_idx::Int, remainder::IntervalTree, other_many_spans::Vector{ManyMemorySpan{N}}) where N + diff = Vector{ManyMemorySpan{N}}() + subtract_spans!(remainder, other_many_spans, diff) + for span in diff + source_span = span.spans[source_space_idx] + dest_span = span.spans[dest_space_idx] + @assert span_len(source_span) == span_len(dest_span) "Source and dest spans are not the same size: $(span_len(source_span)) != $(span_len(dest_span))" + push!(tracker, (source_span, dest_span)) + end + return !isempty(diff) +end + +### Remainder copy functions + +""" + enqueue_remainder_copy_to!(state::DataDepsState, f, target_ainfo::AliasingWrapper, remainder_aliasing, dep_mod, arg, idx, + our_space::MemorySpace, our_scope, task::DTask, write_num::Int) + +Enqueues a copy operation to update the remainder regions of an object before a task runs. +""" +function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::MultiRemainderAliasing, + f, idx, dest_scope, task, write_num::Int) + for remainder in remainder_aliasing.remainders + @assert !isempty(remainder.spans) + enqueue_remainder_copy_to!(state, dest_space, arg_w, remainder, f, idx, dest_scope, task, write_num) + end +end +function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::RemainderAliasing, + f, idx, dest_scope, task, write_num::Int) + dep_mod = arg_w.dep_mod + + # Find the source space for the remainder data + # We need to find where the best version of the target data lives that hasn't been + # overwritten by more recent partial updates + source_space = remainder_aliasing.space + + @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing remainder copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" + + # Get the source and destination arguments + arg_dest 
= state.remote_args[dest_space][arg_w.arg] + arg_source = get_or_generate_slot!(state, source_space, arg_w.arg) + + # Create a copy task for the remainder + remainder_syncdeps = Set{Any}() + target_ainfo = aliasing!(state, dest_space, arg_w) + for syncdep in remainder_aliasing.syncdeps + push!(remainder_syncdeps, syncdep) + end + empty!(remainder_aliasing.syncdeps) # We can't bring these to move! + get_write_deps!(state, dest_space, target_ainfo, write_num, remainder_syncdeps) + + @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Remainder copy-to has $(length(remainder_syncdeps)) syncdeps" + + # Launch the remainder copy task + ctx = Sch.eager_context() + id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) + copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) + @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) + + # This copy task becomes a new writer for the target region + add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) +end +""" + enqueue_remainder_copy_from!(state::DataDepsState, target_ainfo::AliasingWrapper, arg, remainder_aliasing, + origin_space::MemorySpace, origin_scope, write_num::Int) + +Enqueues a copy operation to update the remainder regions of an object back to the original space. 
+""" +function enqueue_remainder_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::MultiRemainderAliasing, + dest_scope, write_num::Int) + for remainder in remainder_aliasing.remainders + @assert !isempty(remainder.spans) + enqueue_remainder_copy_from!(state, dest_space, arg_w, remainder, dest_scope, write_num) + end +end +function enqueue_remainder_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::RemainderAliasing, + dest_scope, write_num::Int) + dep_mod = arg_w.dep_mod + + # Find the source space for the remainder data + # We need to find where the best version of the target data lives that hasn't been + # overwritten by more recent partial updates + source_space = remainder_aliasing.space + + @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Enqueueing remainder copy-from for: $source_space => $dest_space" + + # Get the source and destination arguments + arg_dest = state.remote_args[dest_space][arg_w.arg] + arg_source = get_or_generate_slot!(state, source_space, arg_w.arg) + + # Create a copy task for the remainder + remainder_syncdeps = Set{Any}() + target_ainfo = aliasing!(state, dest_space, arg_w) + for syncdep in remainder_aliasing.syncdeps + push!(remainder_syncdeps, syncdep) + end + empty!(remainder_aliasing.syncdeps) # We can't bring these to move! 
+ get_write_deps!(state, dest_space, target_ainfo, write_num, remainder_syncdeps) + + @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Remainder copy-from has $(length(remainder_syncdeps)) syncdeps" + + # Launch the remainder copy task + ctx = Sch.eager_context() + id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) + copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) + @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) + + # This copy task becomes a new writer for the target region + add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) +end + +# FIXME: Document me +function enqueue_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, + f, idx, dest_scope, task, write_num::Int) + dep_mod = arg_w.dep_mod + source_space = state.arg_owner[arg_w] + target_ainfo = aliasing!(state, dest_space, arg_w) + + @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing full copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" + + # Get the source and destination arguments + arg_dest = state.remote_args[dest_space][arg_w.arg] + arg_source = get_or_generate_slot!(state, source_space, arg_w.arg) + + # Create a copy task for the remainder + copy_syncdeps = Set{Any}() + source_ainfo = aliasing!(state, source_space, arg_w) + target_ainfo = aliasing!(state, dest_space, arg_w) + get_read_deps!(state, source_space, source_ainfo, write_num, copy_syncdeps) + get_write_deps!(state, dest_space, target_ainfo, write_num, copy_syncdeps) + + @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Full copy-to has $(length(copy_syncdeps)) syncdeps" + + # Launch the remainder copy task + 
ctx = Sch.eager_context() + id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) + copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) + @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) + + # This copy task becomes a new writer for the target region + add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) +end +function enqueue_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, + dest_scope, write_num::Int) + dep_mod = arg_w.dep_mod + source_space = state.arg_owner[arg_w] + target_ainfo = aliasing!(state, dest_space, arg_w) + + @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Enqueueing full copy-from: $source_space => $dest_space" + + # Get the source and destination arguments + arg_dest = state.remote_args[dest_space][arg_w.arg] + arg_source = get_or_generate_slot!(state, source_space, arg_w.arg) + + # Create a copy task for the remainder + copy_syncdeps = Set{Any}() + source_ainfo = aliasing!(state, source_space, arg_w) + target_ainfo = aliasing!(state, dest_space, arg_w) + get_read_deps!(state, source_space, source_ainfo, write_num, copy_syncdeps) + get_write_deps!(state, dest_space, target_ainfo, write_num, copy_syncdeps) + + @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Full copy-from has $(length(copy_syncdeps)) syncdeps" + + # Launch the remainder copy task + ctx = Sch.eager_context() + id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) + copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) + @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), 
(;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) + + # This copy task becomes a new writer for the target region + add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) +end + +# Main copy function for RemainderAliasing +function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space::MemorySpace, to::Chunk, from::Chunk) where S + # TODO: Support direct copy between GPU memory spaces + + # Copy the data from the source object + copies = remotecall_fetch(root_worker_id(from_space), from_space, dep_mod, from) do from_space, dep_mod, from + len = sum(span_tuple->span_len(span_tuple[1]), dep_mod.spans) + copies = Vector{UInt8}(undef, len) + from_raw = unwrap(from) + offset = UInt64(1) + with_context!(from_space) + GC.@preserve copies begin + for (from_span, _) in dep_mod.spans + read_remainder!(copies, offset, from_raw, from_span.ptr, from_span.len) + offset += from_span.len + end + end + @assert offset == len+UInt64(1) + return copies + end + + # Copy the data into the destination object + offset = UInt64(1) + to_raw = unwrap(to) + GC.@preserve copies begin + for (_, to_span) in dep_mod.spans + write_remainder!(copies, offset, to_raw, to_span.ptr, to_span.len) + offset += to_span.len + end + @assert offset == length(copies)+UInt64(1) + end + + # Ensure that the data is visible + Core.Intrinsics.atomic_fence(:release) + + return +end + +function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::Array, from_ptr::UInt64, len::UInt64) + elsize = sizeof(eltype(from)) + @assert len / elsize == round(UInt64, len / elsize) "Span length is not an integer multiple of the element size: $(len) / $(elsize) = $(len / elsize) (elsize: $elsize)" + n = UInt64(len / elsize) + from_offset_n = UInt64((from_ptr - UInt64(pointer(from))) / elsize) + UInt64(1) + from_vec = reshape(from, prod(size(from)))::DenseVector{eltype(from)} + # unsafe_wrap(Array, ...) 
doesn't like unaligned memory + unsafe_copyto!(Ptr{eltype(from)}(pointer(copies, copies_offset)), pointer(from_vec, from_offset_n), n) +end +function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::DenseArray, from_ptr::UInt64, len::UInt64) + elsize = sizeof(eltype(from)) + @assert len / elsize == round(UInt64, len / elsize) "Span length is not an integer multiple of the element size: $(len) / $(elsize) = $(len / elsize) (elsize: $elsize)" + n = UInt64(len / elsize) + from_offset_n = UInt64((from_ptr - UInt64(pointer(from))) / elsize) + UInt64(1) + from_vec = reshape(from, prod(size(from)))::DenseVector{eltype(from)} + copies_typed = unsafe_wrap(Vector{eltype(from)}, Ptr{eltype(from)}(pointer(copies, copies_offset)), n) + copyto!(copies_typed, 1, from_vec, Int(from_offset_n), Int(n)) +end +function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from, from_ptr::UInt64, n::UInt64) + real_from = find_object_holding_ptr(from, from_ptr) + return read_remainder!(copies, copies_offset, real_from, from_ptr, n) +end + +function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::Array, to_ptr::UInt64, len::UInt64) + elsize = sizeof(eltype(to)) + @assert len / elsize == round(UInt64, len / elsize) "Span length is not an integer multiple of the element size: $(len) / $(elsize) = $(len / elsize) (elsize: $elsize)" + n = UInt64(len / elsize) + to_offset_n = UInt64((to_ptr - UInt64(pointer(to))) / elsize) + UInt64(1) + to_vec = reshape(to, prod(size(to)))::DenseVector{eltype(to)} + # unsafe_wrap(Array, ...) 
doesn't like unaligned memory + unsafe_copyto!(pointer(to_vec, to_offset_n), Ptr{eltype(to)}(pointer(copies, copies_offset)), n) +end +function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::DenseArray, to_ptr::UInt64, len::UInt64) + elsize = sizeof(eltype(to)) + @assert len / elsize == round(UInt64, len / elsize) "Span length is not an integer multiple of the element size: $(len) / $(elsize) = $(len / elsize) (elsize: $elsize)" + n = UInt64(len / elsize) + to_offset_n = UInt64((to_ptr - UInt64(pointer(to))) / elsize) + UInt64(1) + to_vec = reshape(to, prod(size(to)))::DenseVector{eltype(to)} + copies_typed = unsafe_wrap(Vector{eltype(to)}, Ptr{eltype(to)}(pointer(copies, copies_offset)), n) + copyto!(to_vec, Int(to_offset_n), copies_typed, 1, Int(n)) +end +function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to, to_ptr::UInt64, n::UInt64) + real_to = find_object_holding_ptr(to, to_ptr) + return write_remainder!(copies, copies_offset, real_to, to_ptr, n) +end + +# Remainder copies for common objects +for wrapper in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular, SubArray) + @eval function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::$wrapper, from_ptr::UInt64, n::UInt64) + read_remainder!(copies, copies_offset, parent(from), from_ptr, n) + end + @eval function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::$wrapper, to_ptr::UInt64, n::UInt64) + write_remainder!(copies, copies_offset, parent(to), to_ptr, n) + end +end +# N.B. 
We don't handle pointer aliasing in remainder copies +function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::Base.RefValue, from_ptr::UInt64, n::UInt64) + read_remainder!(copies, copies_offset, from[], from_ptr, n) +end +function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::Base.RefValue, to_ptr::UInt64, n::UInt64) + write_remainder!(copies, copies_offset, to[], to_ptr, n) +end + +function find_object_holding_ptr(A::SparseMatrixCSC, ptr::UInt64) + span = LocalMemorySpan(pointer(A.nzval), length(A.nzval)*sizeof(eltype(A.nzval))) + if span_start(span) <= ptr <= span_end(span) + return A.nzval + end + span = LocalMemorySpan(pointer(A.colptr), length(A.colptr)*sizeof(eltype(A.colptr))) + if span_start(span) <= ptr <= span_end(span) + return A.colptr + end + span = LocalMemorySpan(pointer(A.rowval), length(A.rowval)*sizeof(eltype(A.rowval))) + @assert span_start(span) <= ptr <= span_end(span) "Pointer $ptr not found in SparseMatrixCSC" + return A.rowval +end \ No newline at end of file diff --git a/src/gpu.jl b/src/gpu.jl index 06d749543..fa93f8076 100644 --- a/src/gpu.jl +++ b/src/gpu.jl @@ -100,4 +100,7 @@ function gpu_synchronize(kind::Symbol) gpu_synchronize(Val(kind)) end end -gpu_synchronize(::Val{:CPU}) = nothing \ No newline at end of file +gpu_synchronize(::Val{:CPU}) = nothing + +with_context!(proc::Processor) = nothing +with_context!(space::MemorySpace) = nothing \ No newline at end of file diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index 9f65a1a21..d39e665cc 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -1,8 +1,7 @@ -abstract type MemorySpace end - struct CPURAMMemorySpace <: MemorySpace owner::Int end +CPURAMMemorySpace() = CPURAMMemorySpace(myid()) root_worker_id(space::CPURAMMemorySpace) = space.owner memory_space(x) = CPURAMMemorySpace(myid()) @@ -30,7 +29,7 @@ processors(space::CPURAMMemorySpace) = ### In-place Data Movement function unwrap(x::Chunk) - @assert 
root_worker_id(x.processor) == myid() + @assert x.handle.owner == myid() MemPool.poolget(x.handle) end move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::T, from::F) where {T,F} = @@ -89,44 +88,20 @@ function type_may_alias(::Type{T}) where T return false end -may_alias(::MemorySpace, ::MemorySpace) = true +may_alias(::MemorySpace, ::MemorySpace) = false +may_alias(space1::M, space2::M) where M<:MemorySpace = space1 == space2 may_alias(space1::CPURAMMemorySpace, space2::CPURAMMemorySpace) = space1.owner == space2.owner -struct RemotePtr{T,S<:MemorySpace} <: Ref{T} - addr::UInt - space::S -end -RemotePtr{T}(addr::UInt, space::S) where {T,S} = RemotePtr{T,S}(addr, space) -RemotePtr{T}(ptr::Ptr{V}, space::S) where {T,V,S} = RemotePtr{T,S}(UInt(ptr), space) -RemotePtr{T}(ptr::Ptr{V}) where {T,V} = RemotePtr{T}(UInt(ptr), CPURAMMemorySpace(myid())) -Base.convert(::Type{RemotePtr}, x::Ptr{T}) where T = - RemotePtr(UInt(x), CPURAMMemorySpace(myid())) -Base.convert(::Type{<:RemotePtr{V}}, x::Ptr{T}) where {V,T} = - RemotePtr{V}(UInt(x), CPURAMMemorySpace(myid())) -Base.:+(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr + offset, ptr.space) -Base.:-(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr - offset, ptr.space) -function Base.isless(ptr1::RemotePtr, ptr2::RemotePtr) - @assert ptr1.space == ptr2.space - return ptr1.addr < ptr2.addr -end - -struct MemorySpan{S} - ptr::RemotePtr{Cvoid,S} - len::UInt -end -MemorySpan(ptr::RemotePtr{Cvoid,S}, len::Integer) where S = - MemorySpan{S}(ptr, UInt(len)) - abstract type AbstractAliasing end memory_spans(::T) where T<:AbstractAliasing = throw(ArgumentError("Must define `memory_spans` for `$T`")) memory_spans(x) = memory_spans(aliasing(x)) memory_spans(x, T) = memory_spans(aliasing(x, T)) +### Type-generic aliasing info wrapper -struct AliasingWrapper <: AbstractAliasing +mutable struct AliasingWrapper <: AbstractAliasing inner::AbstractAliasing hash::UInt64 - 
AliasingWrapper(inner::AbstractAliasing) = new(inner, hash(inner)) end memory_spans(x::AliasingWrapper) = memory_spans(x.inner) @@ -135,8 +110,204 @@ equivalent_structure(x::AliasingWrapper, y::AliasingWrapper) = Base.hash(x::AliasingWrapper, h::UInt64) = hash(x.hash, h) Base.isequal(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash Base.:(==)(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash -will_alias(x::AliasingWrapper, y::AliasingWrapper) = - will_alias(x.inner, y.inner) +will_alias(x::AliasingWrapper, y::AliasingWrapper) = will_alias(x.inner, y.inner) + +### Small dictionary type + +struct SmallDict{K,V} <: AbstractDict{K,V} + keys::Vector{K} + vals::Vector{V} +end +SmallDict{K,V}() where {K,V} = SmallDict{K,V}(Vector{K}(), Vector{V}()) +function Base.getindex(d::SmallDict{K,V}, key) where {K,V} + key_idx = findfirst(==(convert(K, key)), d.keys) + if key_idx === nothing + throw(KeyError(key)) + end + return @inbounds d.vals[key_idx] +end +function Base.setindex!(d::SmallDict{K,V}, val, key) where {K,V} + key_conv = convert(K, key) + key_idx = findfirst(==(key_conv), d.keys) + if key_idx === nothing + push!(d.keys, key_conv) + push!(d.vals, convert(V, val)) + else + d.vals[key_idx] = convert(V, val) + end + return val +end +Base.haskey(d::SmallDict{K,V}, key) where {K,V} = in(convert(K, key), d.keys) +Base.keys(d::SmallDict) = d.keys +Base.length(d::SmallDict) = length(d.keys) +Base.iterate(d::SmallDict) = iterate(d, 1) +Base.iterate(d::SmallDict, state) = state > length(d.keys) ? 
nothing : (d.keys[state] => d.vals[state], state+1) + +### Type-stable lookup structure for AliasingWrappers + +struct AliasingLookup + # The set of memory spaces that are being tracked + spaces::Vector{MemorySpace} + # The set of AliasingWrappers that are being tracked + # One entry for each AliasingWrapper + ainfos::Vector{AliasingWrapper} + # The memory spaces for each AliasingWrapper + # One entry for each AliasingWrapper + ainfos_spaces::Vector{Vector{Int}} + # The spans for each AliasingWrapper in each memory space + # One entry for each AliasingWrapper + spans::Vector{SmallDict{Int,Vector{LocalMemorySpan}}} + # The set of AliasingWrappers that only exist in a single memory space + # One entry for each AliasingWrapper + ainfos_only_space::Vector{Int} + # The bounding span for each AliasingWrapper in each memory space + # One entry for each AliasingWrapper + bounding_spans::Vector{SmallDict{Int,LocalMemorySpan}} + # The interval tree of the bounding spans for each AliasingWrapper + # One entry for each MemorySpace + bounding_spans_tree::Vector{IntervalTree{LocatorMemorySpan{Int},UInt64}} + + AliasingLookup() = new(MemorySpace[], + AliasingWrapper[], + Vector{Int}[], + SmallDict{Int,Vector{LocalMemorySpan}}[], + Int[], + SmallDict{Int,LocalMemorySpan}[], + IntervalTree{LocatorMemorySpan{Int},UInt64}[]) +end +function Base.push!(lookup::AliasingLookup, ainfo::AliasingWrapper) + # Update the set of memory spaces and spans, + # and find the bounding spans for this AliasingWrapper + spaces_set = Set{MemorySpace}(lookup.spaces) + self_spaces_set = Set{Int}() + spans = SmallDict{Int,Vector{LocalMemorySpan}}() + for span in memory_spans(ainfo) + space = span.ptr.space + if !in(space, spaces_set) + push!(spaces_set, space) + push!(lookup.spaces, space) + push!(lookup.bounding_spans_tree, IntervalTree{LocatorMemorySpan{Int}}()) + end + space_idx = findfirst(==(space), lookup.spaces) + push!(self_spaces_set, space_idx) + spans_in_space = get!(Vector{LocalMemorySpan}, 
spans, space_idx) + push!(spans_in_space, LocalMemorySpan(span)) + end + push!(lookup.ainfos_spaces, collect(self_spaces_set)) + push!(lookup.spans, spans) + + # Update the set of AliasingWrappers + push!(lookup.ainfos, ainfo) + ainfo_idx = length(lookup.ainfos) + + # Check if the AliasingWrapper only exists in a single memory space + if length(self_spaces_set) == 1 + space_idx = only(self_spaces_set) + push!(lookup.ainfos_only_space, space_idx) + else + push!(lookup.ainfos_only_space, 0) + end + + # Add the bounding spans for this AliasingWrapper + bounding_spans = SmallDict{Int,LocalMemorySpan}() + for space_idx in keys(spans) + space_spans = spans[space_idx] + bound_start = minimum(span_start, space_spans) + bound_end = maximum(span_end, space_spans) + bounding_span = LocalMemorySpan(bound_start, bound_end - bound_start) + bounding_spans[space_idx] = bounding_span + insert!(lookup.bounding_spans_tree[space_idx], LocatorMemorySpan(bounding_span, ainfo_idx)) + end + push!(lookup.bounding_spans, bounding_spans) + + return ainfo_idx +end +struct AliasingLookupFinder + lookup::AliasingLookup + ainfo::AliasingWrapper + ainfo_idx::Int + spaces_idx::Vector{Int} + to_consider::Vector{Int} +end +Base.eltype(::AliasingLookupFinder) = AliasingWrapper +Base.IteratorSize(::AliasingLookupFinder) = Base.SizeUnknown() +# FIXME: We should use a Dict{UInt,Int} to find the ainfo_idx instead of linear search +function Base.intersect(lookup::AliasingLookup, ainfo::AliasingWrapper; ainfo_idx=nothing) + if ainfo_idx === nothing + ainfo_idx = something(findfirst(==(ainfo), lookup.ainfos)) + end + spaces_idx = lookup.ainfos_spaces[ainfo_idx] + to_consider_spans = LocatorMemorySpan{Int}[] + for space_idx in spaces_idx + bounding_spans_tree = lookup.bounding_spans_tree[space_idx] + self_bounding_span = LocatorMemorySpan(lookup.bounding_spans[ainfo_idx][space_idx], 0) + find_overlapping!(bounding_spans_tree, self_bounding_span, to_consider_spans; exact=false) + end + to_consider = 
Int[locator.owner for locator in to_consider_spans] + @assert all(to_consider .> 0) + return AliasingLookupFinder(lookup, ainfo, ainfo_idx, spaces_idx, to_consider) +end +Base.iterate(finder::AliasingLookupFinder) = iterate(finder, 1) +function Base.iterate(finder::AliasingLookupFinder, cursor_ainfo_idx) + ainfo_spaces = nothing + cursor_space_idx = 1 + + # New ainfos enter here + @label ainfo_restart + + # Check if we've exhausted all ainfos + if cursor_ainfo_idx > length(finder.to_consider) + return nothing + end + ainfo_idx = finder.to_consider[cursor_ainfo_idx] + + # Find the appropriate memory spaces for this ainfo + if ainfo_spaces === nothing + ainfo_spaces = finder.lookup.ainfos_spaces[ainfo_idx] + end + + # New memory spaces (for the same ainfo) enter here + @label space_restart + + # Check if we've exhausted all memory spaces for this ainfo, and need to move to the next ainfo + if cursor_space_idx > length(ainfo_spaces) + cursor_ainfo_idx += 1 + ainfo_spaces = nothing + cursor_space_idx = 1 + @goto ainfo_restart + end + + # Find the currently considered memory space for this ainfo + space_idx = ainfo_spaces[cursor_space_idx] + + # Check if this memory space is part of our target ainfo's spaces + if !(space_idx in finder.spaces_idx) + cursor_space_idx += 1 + @goto space_restart + end + + # Check if this ainfo's bounding span is part of our target ainfo's bounding span in this memory space + other_ainfo_bounding_span = finder.lookup.bounding_spans[ainfo_idx][space_idx] + self_bounding_span = finder.lookup.bounding_spans[finder.ainfo_idx][space_idx] + if !spans_overlap(other_ainfo_bounding_span, self_bounding_span) + cursor_space_idx += 1 + @goto space_restart + end + + # We have a overlapping bounds in the same memory space, so check if the ainfos are aliasing + # This is the slow path! 
+ other_ainfo = finder.lookup.ainfos[ainfo_idx] + aliasing = will_alias(finder.ainfo, other_ainfo) + if !aliasing + cursor_ainfo_idx += 1 + ainfo_spaces = nothing + cursor_space_idx = 1 + @goto ainfo_restart + end + + # We overlap, so return the ainfo and the next ainfo index + return other_ainfo, cursor_ainfo_idx+1 +end struct NoAliasing <: AbstractAliasing end memory_spans(::NoAliasing) = MemorySpan{CPURAMMemorySpace}[] @@ -151,8 +322,11 @@ struct CombinedAliasing <: AbstractAliasing end function memory_spans(ca::CombinedAliasing) # FIXME: Don't hardcode CPURAMMemorySpace - all_spans = MemorySpan{CPURAMMemorySpace}[] - for sub_a in ca.sub_ainfos + if length(ca.sub_ainfos) == 0 + return MemorySpan{CPURAMMemorySpace}[] + end + all_spans = memory_spans(ca.sub_ainfos[1]) + for sub_a in ca.sub_ainfos[2:end] append!(all_spans, memory_spans(sub_a)) end return all_spans @@ -213,8 +387,14 @@ end aliasing(::String) = NoAliasing() # FIXME: Not necessarily true aliasing(::Symbol) = NoAliasing() aliasing(::Type) = NoAliasing() -aliasing(x::Chunk, T) = remotecall_fetch(root_worker_id(x.processor), x, T) do x, T - aliasing(unwrap(x), T) +function aliasing(x::Chunk, T) + @assert x.handle isa DRef + if root_worker_id(x.processor) == myid() + return aliasing(unwrap(x), T) + end + return remotecall_fetch(root_worker_id(x.processor), x, T) do x, T + aliasing(unwrap(x), T) + end end aliasing(x::Chunk) = remotecall_fetch(root_worker_id(x.processor), x) do x aliasing(unwrap(x)) @@ -273,13 +453,22 @@ function _memory_spans(a::StridedAliasing{T,N,S}, spans, ptr, dim) where {T,N,S} return spans end -function aliasing(x::SubArray{T,N,A}) where {T,N,A<:Array} +function aliasing(x::SubArray{T,N}) where {T,N} if isbitstype(T) - S = CPURAMMemorySpace - return StridedAliasing{T,ndims(x),S}(RemotePtr{Cvoid}(pointer(parent(x))), - RemotePtr{Cvoid}(pointer(x)), - parentindices(x), - size(x), strides(parent(x))) + p = parent(x) + space = memory_space(p) + S = typeof(space) + parent_ptr = 
RemotePtr{Cvoid}(UInt64(pointer(p)), space) + ptr = RemotePtr{Cvoid}(UInt64(pointer(x)), space) + NA = ndims(p) + raw_inds = parentindices(x) + inds = ntuple(i->raw_inds[i] isa Integer ? (raw_inds[i]:raw_inds[i]) : UnitRange(raw_inds[i]), NA) + sz = ntuple(i->length(inds[i]), NA) + return StridedAliasing{T,NA,S}(parent_ptr, + ptr, + inds, + sz, + strides(p)) else # FIXME: Also ContiguousAliasing of container #return IteratedAliasing(x) @@ -396,76 +585,8 @@ end function will_alias(x_span::MemorySpan, y_span::MemorySpan) may_alias(x_span.ptr.space, y_span.ptr.space) || return false # FIXME: Allow pointer conversion instead of just failing - @assert x_span.ptr.space == y_span.ptr.space + @assert x_span.ptr.space == y_span.ptr.space "Memory spans are in different spaces: $(x_span.ptr.space) vs. $(y_span.ptr.space)" x_end = x_span.ptr + x_span.len - 1 y_end = y_span.ptr + y_span.len - 1 return x_span.ptr <= y_end && y_span.ptr <= x_end end - -struct ChunkView{N} - chunk::Chunk - slices::NTuple{N, Union{Int, AbstractRange{Int}, Colon}} -end - -function Base.view(c::Chunk, slices...) - if c.domain isa ArrayDomain - nd, sz = ndims(c.domain), size(c.domain) - nd == length(slices) || throw(DimensionMismatch("Expected $nd slices, got $(length(slices))")) - - for (i, s) in enumerate(slices) - if s isa Int - 1 ≤ s ≤ sz[i] || throw(ArgumentError("Index $s out of bounds for dimension $i (size $(sz[i]))")) - elseif s isa AbstractRange - isempty(s) && continue - 1 ≤ first(s) ≤ last(s) ≤ sz[i] || throw(ArgumentError("Range $s out of bounds for dimension $i (size $(sz[i]))")) - elseif s === Colon() - continue - else - throw(ArgumentError("Invalid slice type $(typeof(s)) at dimension $i, Expected Type of Int, AbstractRange, or Colon")) - end - end - end - - return ChunkView(c, slices) -end - -Base.view(c::DTask, slices...) = view(fetch(c; raw=true), slices...) 
- -function aliasing(x::ChunkView{N}) where N - remotecall_fetch(root_worker_id(x.chunk.processor), x.chunk, x.slices) do x, slices - x = unwrap(x) - v = view(x, slices...) - return aliasing(v) - end -end -memory_space(x::ChunkView) = memory_space(x.chunk) -isremotehandle(x::ChunkView) = true - -#= -function move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::ChunkView, from::ChunkView) - to_w = root_worker_id(to_space) - @assert to_w == myid() - to_raw = unwrap(to.chunk) - from_w = root_worker_id(from_space) - from_raw = to_w == from_w ? unwrap(from.chunk) : remotecall_fetch(f->copy(unwrap(f)), from_w, from.chunk) - from_view = view(from_raw, from.slices...) - to_view = view(to_raw, to.slices...) - move!(dep_mod, to_space, from_space, to_view, from_view) - return -end -=# - -function move(from_proc::Processor, to_proc::Processor, slice::ChunkView) - if from_proc == to_proc - return view(unwrap(slice.chunk), slice.slices...) - else - # Need to copy the underlying data, so collapse the view - from_w = root_worker_id(from_proc) - data = remotecall_fetch(from_w, slice.chunk, slice.slices) do chunk, slices - copy(view(unwrap(chunk), slices...)) - end - return move(from_proc, to_proc, data) - end -end - -Base.fetch(slice::ChunkView) = view(fetch(slice.chunk), slice.slices...) 
\ No newline at end of file diff --git a/src/queue.jl b/src/queue.jl index c8c6007ec..37947a0ac 100644 --- a/src/queue.jl +++ b/src/queue.jl @@ -1,32 +1,63 @@ -mutable struct DTaskSpec - fargs::Vector{Argument} +mutable struct DTaskSpec{typed,FA<:Tuple} + _fargs::Vector{Argument} + _typed_fargs::FA options::Options end +DTaskSpec(fargs::Vector{Argument}, options::Options) = + DTaskSpec{false, Tuple{}}(fargs, (), options) +DTaskSpec(fargs::FA, options::Options) where FA = + DTaskSpec{true, FA}(Argument[], fargs, options) +is_typed(spec::DTaskSpec{typed}) where typed = typed +function Base.getproperty(spec::DTaskSpec{typed}, field::Symbol) where typed + if field === :fargs + if typed + return getfield(spec, :_typed_fargs) + else + return getfield(spec, :_fargs) + end + else + return getfield(spec, field) + end +end + +struct DTaskPair + spec::DTaskSpec + task::DTask +end +is_typed(pair::DTaskPair) = is_typed(pair.spec) +Base.iterate(pair::DTaskPair) = (pair.spec, true) +function Base.iterate(pair::DTaskPair, state::Bool) + if state + return (pair.task, false) + else + return nothing + end +end abstract type AbstractTaskQueue end function enqueue! 
end struct DefaultTaskQueue <: AbstractTaskQueue end -enqueue!(::DefaultTaskQueue, spec::Pair{DTaskSpec,DTask}) = - eager_launch!(spec) -enqueue!(::DefaultTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) = - eager_launch!(specs) +enqueue!(::DefaultTaskQueue, pair::DTaskPair) = + eager_launch!(pair) +enqueue!(::DefaultTaskQueue, pairs::Vector{DTaskPair}) = + eager_launch!(pairs) -enqueue!(spec::Pair{DTaskSpec,DTask}) = - enqueue!(get_options(:task_queue, DefaultTaskQueue()), spec) -enqueue!(specs::Vector{Pair{DTaskSpec,DTask}}) = - enqueue!(get_options(:task_queue, DefaultTaskQueue()), specs) +enqueue!(pair::DTaskPair) = + enqueue!(get_options(:task_queue, DefaultTaskQueue()), pair) +enqueue!(pairs::Vector{DTaskPair}) = + enqueue!(get_options(:task_queue, DefaultTaskQueue()), pairs) struct LazyTaskQueue <: AbstractTaskQueue - tasks::Vector{Pair{DTaskSpec,DTask}} - LazyTaskQueue() = new(Pair{DTaskSpec,DTask}[]) + tasks::Vector{DTaskPair} + LazyTaskQueue() = new(DTaskPair[]) end -function enqueue!(queue::LazyTaskQueue, spec::Pair{DTaskSpec,DTask}) - push!(queue.tasks, spec) +function enqueue!(queue::LazyTaskQueue, pair::DTaskPair) + push!(queue.tasks, pair) end -function enqueue!(queue::LazyTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) - append!(queue.tasks, specs) +function enqueue!(queue::LazyTaskQueue, pairs::Vector{DTaskPair}) + append!(queue.tasks, pairs) end function spawn_bulk(f::Base.Callable) queue = LazyTaskQueue() @@ -50,25 +81,25 @@ function _add_prev_deps!(queue::InOrderTaskQueue, spec::DTaskSpec) push!(syncdeps, ThunkSyncdep(task)) end end -function enqueue!(queue::InOrderTaskQueue, spec::Pair{DTaskSpec,DTask}) +function enqueue!(queue::InOrderTaskQueue, pair::DTaskPair) if length(queue.prev_tasks) > 0 - _add_prev_deps!(queue, first(spec)) + _add_prev_deps!(queue, pair.spec) empty!(queue.prev_tasks) end - push!(queue.prev_tasks, last(spec)) - enqueue!(queue.upper_queue, spec) + push!(queue.prev_tasks, pair.task) + enqueue!(queue.upper_queue, pair) end 
-function enqueue!(queue::InOrderTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) +function enqueue!(queue::InOrderTaskQueue, pairs::Vector{DTaskPair}) if length(queue.prev_tasks) > 0 - for (spec, task) in specs - _add_prev_deps!(queue, spec) + for pair in pairs + _add_prev_deps!(queue, pair.spec) end empty!(queue.prev_tasks) end - for (spec, task) in specs - push!(queue.prev_tasks, task) + for pair in pairs + push!(queue.prev_tasks, pair.task) end - enqueue!(queue.upper_queue, specs) + enqueue!(queue.upper_queue, pairs) end function spawn_sequential(f::Base.Callable) queue = InOrderTaskQueue(get_options(:task_queue, DefaultTaskQueue())) @@ -79,15 +110,15 @@ struct WaitAllQueue <: AbstractTaskQueue upper_queue::AbstractTaskQueue tasks::Vector{DTask} end -function enqueue!(queue::WaitAllQueue, spec::Pair{DTaskSpec,DTask}) - push!(queue.tasks, spec[2]) - enqueue!(queue.upper_queue, spec) +function enqueue!(queue::WaitAllQueue, pair::DTaskPair) + push!(queue.tasks, pair.task) + enqueue!(queue.upper_queue, pair) end -function enqueue!(queue::WaitAllQueue, specs::Vector{Pair{DTaskSpec,DTask}}) - for (_, task) in specs - push!(queue.tasks, task) +function enqueue!(queue::WaitAllQueue, pairs::Vector{DTaskPair}) + for pair in pairs + push!(queue.tasks, pair.task) end - enqueue!(queue.upper_queue, specs) + enqueue!(queue.upper_queue, pairs) end function wait_all(f; check_errors::Bool=false) queue = WaitAllQueue(get_options(:task_queue, DefaultTaskQueue()), DTask[]) diff --git a/src/sch/util.jl b/src/sch/util.jl index 3f9d7b2f6..d3b7a4804 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -238,7 +238,7 @@ function set_failed!(state, origin, thunk=origin) @dagdebug thunk :finish "Setting as failed" filter!(x -> x !== thunk, state.ready) # N.B. 
If origin === thunk, we assume that the caller has already set the error - if origin !== thunk + if origin !== thunk && !has_result(state, thunk) origin_ex = load_result(state, origin) if origin_ex isa RemoteException origin_ex = origin_ex.captured diff --git a/src/scopes.jl b/src/scopes.jl index ba291bc2b..79190c292 100644 --- a/src/scopes.jl +++ b/src/scopes.jl @@ -40,6 +40,9 @@ struct UnionScope <: AbstractScope push!(scope_set, scope) end end + if isempty(scope_set) + throw(ArgumentError("Cannot construct UnionScope with no inner scopes")) + end return new((collect(scope_set)...,)) end end diff --git a/src/stream.jl b/src/stream.jl index bf1ea4537..c5e6641f6 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -372,21 +372,21 @@ function migrate_stream!(stream::Stream, w::Integer=myid()) end struct StreamingTaskQueue <: AbstractTaskQueue - tasks::Vector{Pair{DTaskSpec,DTask}} + tasks::Vector{DTaskPair} self_streams::Dict{UInt,Any} - StreamingTaskQueue() = new(Pair{DTaskSpec,DTask}[], + StreamingTaskQueue() = new(DTaskPair[], Dict{UInt,Any}()) end -function enqueue!(queue::StreamingTaskQueue, spec::Pair{DTaskSpec,DTask}) - push!(queue.tasks, spec) - initialize_streaming!(queue.self_streams, spec...) 
+function enqueue!(queue::StreamingTaskQueue, pair::DTaskPair) + push!(queue.tasks, pair) + initialize_streaming!(queue.self_streams, pair.spec, pair.task) end -function enqueue!(queue::StreamingTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) - append!(queue.tasks, specs) - for (spec, task) in specs - initialize_streaming!(queue.self_streams, spec, task) +function enqueue!(queue::StreamingTaskQueue, pairs::Vector{DTaskPair}) + append!(queue.tasks, pairs) + for pair in pairs + initialize_streaming!(queue.self_streams, pair.spec, pair.task) end end @@ -458,7 +458,7 @@ function spawn_streaming(f::Base.Callable; teardown::Bool=true) if teardown # Start teardown monitor - dtasks = map(last, queue.tasks)::Vector{DTask} + dtasks = map(pair->pair.task, queue.tasks)::Vector{DTask} Sch.errormonitor_tracked("streaming teardown", Threads.@spawn begin # Wait for any task to finish waitany(dtasks) @@ -663,10 +663,12 @@ function task_to_stream(uid::UInt) end end -function finalize_streaming!(tasks::Vector{Pair{DTaskSpec,DTask}}, self_streams) +function finalize_streaming!(tasks::Vector{DTaskPair}, self_streams) stream_waiter_changes = Dict{UInt,Vector{Pair{UInt,Any}}}() - for (spec, task) in tasks + for pair in tasks + spec = pair.spec + task = pair.task @assert haskey(self_streams, task.uid) our_stream = self_streams[task.uid] diff --git a/src/submission.jl b/src/submission.jl index 2e7b1c836..4ff4f2294 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -268,24 +268,29 @@ function eager_process_elem_submission_to_local!(id_map, arg::Argument) arg.value = Sch.ThunkID(id_map[value(arg).uid], value(arg).thunk_ref) end end -function eager_process_args_submission_to_local!(id_map, spec_pair::Pair{DTaskSpec,DTask}) - spec, task = spec_pair +function eager_process_elem_submission_to_local(id_map, arg::TypedArgument{T}) where T + @assert !(T <: Thunk) "Cannot use `Thunk`s in `@spawn`/`spawn`" + if T <: DTask && haskey(id_map, (value(arg)::DTask).uid) + #=FIXME:UNIQUE=# + return 
Sch.ThunkID(id_map[value(arg).uid], value(arg).thunk_ref) + end + return arg +end +function eager_process_args_submission_to_local!(id_map, spec::DTaskSpec{false}) for arg in spec.fargs eager_process_elem_submission_to_local!(id_map, arg) end end -function eager_process_args_submission_to_local!(id_map, spec_pairs::Vector{Pair{DTaskSpec,DTask}}) - for spec_pair in spec_pairs - eager_process_args_submission_to_local!(id_map, spec_pair) - end +function eager_process_args_submission_to_local(id_map, spec::DTaskSpec{true}) + return ntuple(i->eager_process_elem_submission_to_local(id_map, spec.fargs[i]), length(spec.fargs)) end -function DTaskMetadata(spec::DTaskSpec) - f = value(spec.fargs[1]) +DTaskMetadata(spec::DTaskSpec) = DTaskMetadata(eager_metadata(spec.fargs)) +function eager_metadata(fargs) + f = value(fargs[1]) f = f isa StreamingFunction ? f.f : f - arg_types = ntuple(i->chunktype(value(spec.fargs[i+1])), length(spec.fargs)-1) - return_type = Base.promote_op(f, arg_types...) - return DTaskMetadata(return_type) + arg_types = ntuple(i->chunktype(value(fargs[i+1])), length(fargs)-1) + return Base.promote_op(f, arg_types...) end function eager_spawn(spec::DTaskSpec) @@ -298,48 +303,64 @@ end chunktype(t::DTask) = t.metadata.return_type -function eager_launch!((spec, task)::Pair{DTaskSpec,DTask}) +function eager_launch!(pair::DTaskPair) + spec = pair.spec + task = pair.task + # Assign a name, if specified eager_assign_name!(spec, task) # Lookup DTask -> ThunkID - lock(Sch.EAGER_ID_MAP) do id_map - eager_process_args_submission_to_local!(id_map, spec=>task) + fargs = lock(Sch.EAGER_ID_MAP) do id_map + if is_typed(spec) + return Argument[map(Argument, eager_process_args_submission_to_local(id_map, spec))...] 
+ else + eager_process_args_submission_to_local!(id_map, spec) + return spec.fargs + end end # Submit the task #=FIXME:REALLOC=# thunk_id = eager_submit!(PayloadOne(task.uid, task.future, - spec.fargs, spec.options, true)) + fargs, spec.options, true)) task.thunk_ref = thunk_id.ref end -function eager_launch!(specs::Vector{Pair{DTaskSpec,DTask}}) - ntasks = length(specs) +# FIXME: Don't convert Tuple to Vector{Argument} +function eager_launch!(pairs::Vector{DTaskPair}) + ntasks = length(pairs) # Assign a name, if specified - for (spec, task) in specs - eager_assign_name!(spec, task) + for pair in pairs + eager_assign_name!(pair.spec, pair.task) end #=FIXME:REALLOC_N=# - uids = [task.uid for (_, task) in specs] - futures = [task.future for (_, task) in specs] + uids = [pair.task.uid for pair in pairs] + futures = [pair.task.future for pair in pairs] # Get all functions, args/kwargs, and options #=FIXME:REALLOC_N=# all_fargs = lock(Sch.EAGER_ID_MAP) do id_map # Lookup DTask -> ThunkID - eager_process_args_submission_to_local!(id_map, specs) - [spec.fargs for (spec, _) in specs] + return map(pairs) do pair + spec = pair.spec + if is_typed(spec) + return Argument[map(Argument, eager_process_args_submission_to_local(id_map, spec))...] 
+ else + eager_process_args_submission_to_local!(id_map, spec) + return spec.fargs + end + end end - all_options = Options[spec.options for (spec, _) in specs] + all_options = Options[pair.spec.options for pair in pairs] # Submit the tasks #=FIXME:REALLOC=# thunk_ids = eager_submit!(PayloadMulti(ntasks, uids, futures, all_fargs, all_options, true)) for i in 1:ntasks - task = specs[i][2] + task = pairs[i].task task.thunk_ref = thunk_ids[i].ref end end diff --git a/src/thunk.jl b/src/thunk.jl index 482d66209..e13e299f0 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -17,8 +17,6 @@ function unset!(spec::ThunkSpec, _) spec.id = 0 spec.cache_ref = nothing spec.affinity = nothing - compute_scope = DefaultScope() - result_scope = AnyScope() spec.options = nothing end @@ -186,21 +184,19 @@ function args_kwargs_to_arguments(f, args, kwargs) end return args_kwargs end -function args_kwargs_to_arguments(f, args) - @nospecialize f args - args_kwargs = Argument[] - push!(args_kwargs, Argument(ArgPosition(true, 0, :NULL), f)) - pos_ctr = 1 - for idx in 1:length(args) - pos, arg = args[idx]::Pair - if pos === nothing - push!(args_kwargs, Argument(pos_ctr, arg)) - pos_ctr += 1 +function args_kwargs_to_typedarguments(f, args, kwargs) + nargs = 1 + length(args) + length(kwargs) + return ntuple(nargs) do idx + if idx == 1 + return TypedArgument(ArgPosition(true, 0, :NULL), f) + elseif idx in 2:(1+length(args)) + arg = args[idx-1] + return TypedArgument(idx, arg) else - push!(args_kwargs, Argument(pos, arg)) + kw, value = kwargs[idx-length(args)-1] + return TypedArgument(kw, value) end end - return args_kwargs end """ @@ -491,7 +487,11 @@ function _par(mod, ex::Expr; lazy=true, recur=true, opts=()) @gensym result return quote let - $result = $spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...)) + $result = if $get_task_typed() + $typed_spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...)) + else + $spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...)) + end if 
$(Expr(:islocal, sync_var)) put!($sync_var, schedule(Task(()->fetch($result; raw=true)))) end @@ -516,6 +516,9 @@ function _setindex!_return_value(A, value, idxs...) return value end +const TASK_TYPED = ScopedValue{Bool}(false) +get_task_typed() = TASK_TYPED[] + """ Dagger.spawn(f, args...; kwargs...) -> DTask @@ -526,6 +529,36 @@ Spawns a `DTask` that will call `f(args...; kwargs...)`. Also supports passing a function spawn(f, args...; kwargs...) @nospecialize f args kwargs + # Merge all passed options + if length(args) >= 1 && first(args) isa Options + # N.B. Make a defensive copy in case user aliases Options struct + task_options = copy(first(args)::Options) + args = args[2:end] + else + task_options = Options() + end + + # Process the args and kwargs into Argument form + args_kwargs = args_kwargs_to_arguments(f, args, kwargs) + + return _spawn(args_kwargs, task_options) +end +function typed_spawn(f, args...; kwargs...) + # Merge all passed options + if length(args) >= 1 && first(args) isa Options + # N.B. Make a defensive copy in case user aliases Options struct + task_options = copy(first(args)::Options) + args = args[2:end] + else + task_options = Options() + end + + # Process the args and kwargs into Tuple of TypedArgument form + args_kwargs = args_kwargs_to_typedarguments(f, args, kwargs) + + return _spawn(args_kwargs, task_options) +end +function _spawn(args_kwargs, task_options) # Get all scoped options and determine which propagate beyond this task scoped_options = get_options()::NamedTuple if haskey(scoped_options, :propagates) @@ -539,20 +572,9 @@ function spawn(f, args...; kwargs...) end append!(propagates, keys(scoped_options)::NTuple{N,Symbol} where N) - # Merge all passed options - if length(args) >= 1 && first(args) isa Options - # N.B. Make a defensive copy in case user aliases Options struct - task_options = copy(first(args)::Options) - args = args[2:end] - else - task_options = Options() - end # N.B. 
Merges into task_options options_merge!(task_options, scoped_options; override=false) - # Process the args and kwargs into Pair form - args_kwargs = args_kwargs_to_arguments(f, args, kwargs) - # Get task queue, and don't let it propagate task_queue = get(scoped_options, :task_queue, DefaultTaskQueue())::AbstractTaskQueue filter!(prop -> prop != :task_queue, propagates) @@ -568,7 +590,7 @@ function spawn(f, args...; kwargs...) task = eager_spawn(spec) # Enqueue the task into the task queue - enqueue!(task_queue, spec=>task) + enqueue!(task_queue, DTaskPair(spec, task)) return task end diff --git a/src/utils/chunks.jl b/src/utils/chunks.jl index 400b49332..9f0c3b487 100644 --- a/src/utils/chunks.jl +++ b/src/utils/chunks.jl @@ -174,6 +174,9 @@ function tochunk(x::Chunk, proc=nothing, scope=nothing; rewrap=false, kwargs...) end tochunk(x::Thunk, proc=nothing, scope=nothing; kwargs...) = x +root_worker_id(chunk::Chunk) = root_worker_id(chunk.handle) +root_worker_id(dref::DRef) = dref.owner # FIXME: Migration + function savechunk(data, dir, f) sz = open(joinpath(dir, f), "w") do io serialize(io, MemPool.MMWrap(data)) diff --git a/src/utils/dagdebug.jl b/src/utils/dagdebug.jl index 615030400..873e47e79 100644 --- a/src/utils/dagdebug.jl +++ b/src/utils/dagdebug.jl @@ -35,3 +35,28 @@ macro dagdebug(thunk, category, msg, args...) 
end end) end + +# FIXME: Calculate fast-growth based on clock time, not iteration +const OPCOUNTER_CATEGORIES = Symbol[] +const OPCOUNTER_FAST_GROWTH_THRESHOLD = Ref(10_000_000) +struct OpCounter + value::Threads.Atomic{Int} +end +OpCounter() = OpCounter(Threads.Atomic{Int}(0)) +macro opcounter(category, count=1) + cat_sym = category.value + @gensym old + opcounter_sym = Symbol(:OPCOUNTER_, cat_sym) + if !isdefined(__module__, opcounter_sym) + __module__.eval(:(#=const=# $opcounter_sym = OpCounter())) + end + esc(quote + if $(QuoteNode(cat_sym)) in $OPCOUNTER_CATEGORIES + $old = Threads.atomic_add!($opcounter_sym.value, Int($count)) + if $old > 1 && (mod1($old, $OPCOUNTER_FAST_GROWTH_THRESHOLD[]) == 1 || $count > $OPCOUNTER_FAST_GROWTH_THRESHOLD[]) + println("Fast-growing counter: $($(QuoteNode(cat_sym))) = $($old)") + end + end + end) +end +opcounter(mod::Module, category::Symbol) = getfield(mod, Symbol(:OPCOUNTER_, category)).value[] \ No newline at end of file diff --git a/src/utils/haloarray.jl b/src/utils/haloarray.jl index 1fadbeeb6..e34cf2b05 100644 --- a/src/utils/haloarray.jl +++ b/src/utils/haloarray.jl @@ -99,4 +99,42 @@ Adapt.adapt_structure(to, H::Dagger.HaloArray) = HaloArray(Adapt.adapt(to, H.center), Adapt.adapt.(Ref(to), H.edges), Adapt.adapt.(Ref(to), H.corners), - H.halo_width) \ No newline at end of file + H.halo_width) + +function aliasing(A::HaloArray) + return CombinedAliasing([aliasing(A.center), map(aliasing, A.edges)..., map(aliasing, A.corners)...]) +end +memory_space(A::HaloArray) = memory_space(A.center) +function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, A::HaloArray) + center_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, A.center) + edge_chunks = ntuple(i->rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, A.edges[i]), length(A.edges)) + corner_chunks = ntuple(i->rewrap_aliased_object!(cache, 
from_proc, to_proc, from_space, to_space, A.corners[i]), length(A.corners)) + halo_width = A.halo_width + to_w = root_worker_id(to_proc) + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, center_chunk, edge_chunks, corner_chunks, halo_width) do from_proc, to_proc, from_space, to_space, center_chunk, edge_chunks, corner_chunks, halo_width + center_new = move(from_proc, to_proc, center_chunk) + edges_new = ntuple(i->move(from_proc, to_proc, edge_chunks[i]), length(edge_chunks)) + corners_new = ntuple(i->move(from_proc, to_proc, corner_chunks[i]), length(corner_chunks)) + return tochunk(HaloArray(center_new, edges_new, corners_new, halo_width), to_proc) + end +end +function find_object_holding_ptr(object::HaloArray, ptr::UInt64) + for i in 1:length(object.edges) + edge = object.edges[i] + span = LocalMemorySpan(pointer(edge), length(edge)*sizeof(eltype(edge))) + if span_start(span) <= ptr <= span_end(span) + return edge + end + end + for i in 1:length(object.corners) + corner = object.corners[i] + span = LocalMemorySpan(pointer(corner), length(corner)*sizeof(eltype(corner))) + if span_start(span) <= ptr <= span_end(span) + return corner + end + end + center = object.center + span = LocalMemorySpan(pointer(center), length(center)*sizeof(eltype(center))) + @assert span_start(span) <= ptr <= span_end(span) "Pointer $ptr not found in HaloArray" + return center +end \ No newline at end of file diff --git a/src/utils/interval_tree.jl b/src/utils/interval_tree.jl new file mode 100644 index 000000000..8960dd69d --- /dev/null +++ b/src/utils/interval_tree.jl @@ -0,0 +1,452 @@ +mutable struct IntervalNode{M,E} + span::M + max_end::E # Maximum end value in this subtree + left::Union{IntervalNode{M,E}, Nothing} + right::Union{IntervalNode{M,E}, Nothing} + + IntervalNode(span::M) where M <: MemorySpan = new{M,UInt64}(span, span_end(span), nothing, nothing) + IntervalNode(span::LocalMemorySpan) = new{LocalMemorySpan,UInt64}(span, span_end(span), nothing, 
nothing) + IntervalNode(span::ManyMemorySpan{N}) where N = new{ManyMemorySpan{N},ManyPair{N}}(span, span_end(span), nothing, nothing) + IntervalNode(span::LocatorMemorySpan{T}) where T = new{LocatorMemorySpan{T},UInt64}(span, span_end(span), nothing, nothing) +end + +mutable struct IntervalTree{M,E} + root::Union{IntervalNode{M,E}, Nothing} + + IntervalTree{M}() where M<:MemorySpan = new{M,UInt64}(nothing) + IntervalTree{LocalMemorySpan}() = new{LocalMemorySpan,UInt64}(nothing) + IntervalTree{ManyMemorySpan{N}}() where N = new{ManyMemorySpan{N},ManyPair{N}}(nothing) + IntervalTree{LocatorMemorySpan{T}}() where T = new{LocatorMemorySpan{T},UInt64}(nothing) +end + +# Construct interval tree from unsorted set of spans +function IntervalTree{M}(spans) where M + tree = IntervalTree{M}() + for span in spans + insert!(tree, span) + end + verify_spans(tree) + return tree +end +IntervalTree(spans::Vector{M}) where M = IntervalTree{M}(spans) + +function Base.show(io::IO, tree::IntervalTree) + println(io, "$(typeof(tree)) (with $(length(tree)) spans):") + for (i, span) in enumerate(tree) + println(io, " $i: [$(span_start(span)), $(span_end(span))) (len=$(span_len(span)))") + end +end + +function Base.collect(tree::IntervalTree{M}) where M + result = M[] + for span in tree + push!(result, span) + end + return result +end + +# Useful for debugging when spans get misaligned +function verify_spans(tree::IntervalTree{ManyMemorySpan{N}}) where N + for span in tree + verify_span(span) + end +end + +function Base.iterate(tree::IntervalTree{M}) where M + state = Vector{M}() + if tree.root === nothing + return nothing + end + return iterate(tree.root) +end +function Base.iterate(tree::IntervalTree, state) + return iterate(tree.root, state) +end +function Base.iterate(root::IntervalNode{M,E}) where {M,E} + state = Vector{IntervalNode{M,E}}() + push!(state, root) + return iterate(root, state) +end +function Base.iterate(root::IntervalNode, state) + if isempty(state) + return nothing + 
end + current = popfirst!(state) + if current.right !== nothing + pushfirst!(state, current.right) + end + if current.left !== nothing + pushfirst!(state, current.left) + end + return current.span, state +end + +function Base.length(tree::IntervalTree) + result = 0 + for _ in tree + result += 1 + end + return result +end + +# Update max_end value for a node based on its children +function update_max_end!(node::IntervalNode) + max_end = span_end(node.span) + if node.left !== nothing + max_end = max(max_end, node.left.max_end) + end + if node.right !== nothing + max_end = max(max_end, node.right.max_end) + end + node.max_end = max_end +end + +# Insert a span into the interval tree +function Base.insert!(tree::IntervalTree{M,E}, span::M) where {M,E} + if !isempty(span) + if tree.root === nothing + tree.root = IntervalNode(span) + update_max_end!(tree.root) + return span + end + #tree.root = insert_node!(tree.root, span) + to_update = Vector{IntervalNode{M,E}}() + prev_node = tree.root + cur_node = tree.root + while cur_node !== nothing + if span_start(span) <= span_start(cur_node.span) + cur_node = cur_node.left + else + cur_node = cur_node.right + end + if cur_node !== nothing + prev_node = cur_node + push!(to_update, cur_node) + end + end + if prev_node.left === nothing + prev_node.left = IntervalNode(span) + else + prev_node.right = IntervalNode(span) + end + for node_idx in eachindex(to_update) + node = to_update[node_idx] + update_max_end!(node) + end + end + return span +end + +function insert_node!(::Nothing, span::M) where M + return IntervalNode(span) +end +function insert_node!(root::IntervalNode{M,E}, span::M) where {M,E} + # Use a queue to track the path for updating max_end after insertion + path = Vector{IntervalNode{M,E}}() + current = root + + # Traverse to find the insertion point + while current !== nothing + push!(path, current) + if span_start(span) <= span_start(current.span) + if current.left === nothing + current.left = IntervalNode(span) + 
break + end + current = current.left + else + if current.right === nothing + current.right = IntervalNode(span) + break + end + current = current.right + end + end + + # Update max_end for all ancestors (process in reverse order) + while !isempty(path) + node = pop!(path) + update_max_end!(node) + end + + return root +end + +# Remove a specific span from the tree (split as needed) +function Base.delete!(tree::IntervalTree{M}, span::M) where M + if !isempty(span) + tree.root = delete_node!(tree.root, span) + end + return span +end + +function delete_node!(::Nothing, span::M) where M + return nothing +end +function delete_node!(root::IntervalNode{M,E}, span::M) where {M,E} + # Track the path to the target node: (node, direction_to_child) + path = Vector{Tuple{IntervalNode{M,E}, Symbol}}() + current = root + target = nothing + target_type = :none # :exact or :overlap + + # Phase 1: Search for target node + while current !== nothing + is_exact = span_start(current.span) == span_start(span) && span_len(current.span) == span_len(span) + is_overlap = !is_exact && spans_overlap(current.span, span) + + if is_exact + target = current + target_type = :exact + break + elseif is_overlap + target = current + target_type = :overlap + break + elseif span_start(span) <= span_start(current.span) + push!(path, (current, :left)) + current = current.left + else + push!(path, (current, :right)) + current = current.right + end + end + + if target === nothing + return root + end + + # Phase 2: Compute replacement for target node + original_span = target.span + succ_path = Vector{IntervalNode{M,E}}() # Path to successor (for max_end updates) + local replacement::Union{IntervalNode{M,E}, Nothing} + + if target.left === nothing && target.right === nothing + # Leaf node + replacement = nothing + elseif target.left === nothing + # Only right child + replacement = target.right + elseif target.right === nothing + # Only left child + replacement = target.left + else + # Two children - find and 
remove inorder successor + successor = find_min(target.right) + + if target.right === successor + # Successor is direct right child + target.right = successor.right + else + # Track path to successor for max_end updates + succ_parent = target.right + push!(succ_path, succ_parent) + while succ_parent.left !== successor + succ_parent = succ_parent.left + push!(succ_path, succ_parent) + end + # Remove successor by replacing with its right child + succ_parent.left = successor.right + end + + target.span = successor.span + replacement = target + end + + # Phase 3: Handle overlap case - add remaining portions + if target_type == :overlap + original_start = span_start(original_span) + original_end = span_end(original_span) + del_start = span_start(span) + del_end = span_end(span) + verify_span(span) + + # Left portion: exists if original starts before deleted span + if original_start < del_start + left_end = min(original_end, del_start) + if left_end > original_start + left_span = M(original_start, left_end - original_start) + if !isempty(left_span) + replacement = insert_node!(replacement, left_span) + end + end + end + + # Right portion: exists if original extends beyond deleted span + if original_end > del_end + right_start = max(original_start, del_end) + if original_end > right_start + right_span = M(right_start, original_end - right_start) + if !isempty(right_span) + replacement = insert_node!(replacement, right_span) + end + end + end + end + + # Phase 4: Update parent's child pointer + if isempty(path) + root = replacement + else + parent, dir = path[end] + if dir == :left + parent.left = replacement + else + parent.right = replacement + end + end + + # Phase 5: Update max_end in correct order (bottom-up) + # First: successor path (if any) + for i in length(succ_path):-1:1 + update_max_end!(succ_path[i]) + end + # Second: target node (if it wasn't removed) + if replacement === target + update_max_end!(target) + end + # Third: main path (ancestors of target) + for 
i in length(path):-1:1 + update_max_end!(path[i][1]) + end + + return root +end + +function find_min(node::IntervalNode) + while node.left !== nothing + node = node.left + end + return node +end + +# Find all spans that overlap with the given query span +function find_overlapping(tree::IntervalTree{M}, query::M; exact::Bool=true) where M + result = M[] + find_overlapping!(tree.root, query, result; exact) + return result +end +function find_overlapping!(tree::IntervalTree{M}, query::M, result::Vector{M}; exact::Bool=true) where M + find_overlapping!(tree.root, query, result; exact) + return result +end + +function find_overlapping!(::Nothing, query::M, result::Vector{M}; exact::Bool=true) where M + return +end +function find_overlapping!(node::IntervalNode{M,E}, query::M, result::Vector{M}; exact::Bool=true) where {M,E} + # Use a queue for breadth-first traversal + queue = Vector{IntervalNode{M,E}}() + push!(queue, node) + + while !isempty(queue) + current = popfirst!(queue) + + # Check if current node overlaps with query + if spans_overlap(current.span, query) + if exact + # Get the overlapping portion of the span + overlap = span_diff(current.span, query) + if !isempty(overlap) + push!(result, overlap) + end + else + push!(result, current.span) + end + end + + # Enqueue left subtree if it might contain overlapping intervals + if current.left !== nothing && current.left.max_end > span_start(query) + push!(queue, current.left) + end + + # Enqueue right subtree if query extends beyond current node's start + if current.right !== nothing && span_end(query) > span_start(current.span) + push!(queue, current.right) + end + end +end + +# ============================================================================ +# MAIN SUBTRACTION ALGORITHM +# ============================================================================ + +""" + subtract_spans!(minuend_tree::IntervalTree{M}, subtrahend_spans::Vector{M}, diff=nothing) where M + +Subtract all spans in subtrahend_spans from 
the minuend_tree in-place. +The minuend_tree is modified to contain only the portions that remain after subtraction. + +Time Complexity: O(M log N + M*K) where M = |subtrahend_spans|, N = |minuend nodes|, + K = average overlaps per subtrahend span +Space Complexity: O(1) additional space (modifies tree in-place) + +If `diff` is provided, add the overlapping spans to `diff`. +""" +function subtract_spans!(minuend_tree::IntervalTree{M}, subtrahend_spans::Vector{M}, diff=nothing) where M + for sub_span in subtrahend_spans + subtract_single_span!(minuend_tree, sub_span, diff) + end +end + +""" + subtract_single_span!(tree::IntervalTree, sub_span::MemorySpan, diff=nothing) + +Subtract a single span from the interval tree. This function: +1. Finds all overlapping spans in the tree +2. Removes each overlapping span +3. Adds back the non-overlapping portions (left and/or right remnants) +4. If diff is provided, add the overlapping span to diff +""" +function subtract_single_span!(tree::IntervalTree{M}, sub_span::M, diff=nothing) where M + # Find all spans that overlap with the subtrahend + overlapping_spans = find_overlapping(tree, sub_span) + + # Process each overlapping span + for overlap_span in overlapping_spans + # Remove the overlapping span from the tree + delete!(tree, overlap_span) + + # Calculate and add back the portions that should remain + add_remaining_portions!(tree, overlap_span, sub_span) + + if diff !== nothing && !isempty(overlap_span) + push!(diff, overlap_span) + end + end +end + +""" + add_remaining_portions!(tree::IntervalTree, original::MemorySpan, subtracted::MemorySpan) + +After removing an overlapping span, add back the portions that don't overlap with the subtracted span. +There can be up to two remaining portions: left and right of the subtracted region. 
+""" +function add_remaining_portions!(tree::IntervalTree{M}, original::M, subtracted::M) where M + original_start = span_start(original) + original_end = span_end(original) + sub_start = span_start(subtracted) + sub_end = span_end(subtracted) + + # Left portion: exists if original starts before subtracted + if original_start < sub_start + left_end = min(original_end, sub_start) + if left_end > original_start + left_span = M(original_start, left_end - original_start) + if !isempty(left_span) + insert!(tree, left_span) + end + end + end + + # Right portion: exists if original extends beyond subtracted + if original_end > sub_end + right_start = max(original_start, sub_end) + if original_end > right_start + right_span = M(right_start, original_end - right_start) + if !isempty(right_span) + insert!(tree, right_span) + end + end + end +end diff --git a/src/utils/memory-span.jl b/src/utils/memory-span.jl new file mode 100644 index 000000000..e6140c726 --- /dev/null +++ b/src/utils/memory-span.jl @@ -0,0 +1,144 @@ +### Remote pointer type + +struct RemotePtr{T,S<:MemorySpace} <: Ref{T} + addr::UInt + space::S +end +RemotePtr{T}(addr::UInt, space::S) where {T,S} = RemotePtr{T,S}(addr, space) +RemotePtr{T}(ptr::Ptr{V}, space::S) where {T,V,S} = RemotePtr{T,S}(UInt(ptr), space) +RemotePtr{T}(ptr::Ptr{V}) where {T,V} = RemotePtr{T}(UInt(ptr), CPURAMMemorySpace(myid())) +# FIXME: Don't hardcode CPURAMMemorySpace +RemotePtr(addr::UInt) = RemotePtr{Cvoid}(addr, CPURAMMemorySpace(myid())) +Base.convert(::Type{RemotePtr}, x::Ptr{T}) where T = + RemotePtr(UInt(x), CPURAMMemorySpace(myid())) +Base.convert(::Type{<:RemotePtr{V}}, x::Ptr{T}) where {V,T} = + RemotePtr{V}(UInt(x), CPURAMMemorySpace(myid())) +Base.convert(::Type{UInt}, ptr::RemotePtr) = ptr.addr +Base.:+(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr + offset, ptr.space) +Base.:-(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr - offset, ptr.space) +function 
Base.isless(ptr1::RemotePtr, ptr2::RemotePtr) + @assert ptr1.space == ptr2.space + return ptr1.addr < ptr2.addr +end + +### Generic memory spans + +struct MemorySpan{S} + ptr::RemotePtr{Cvoid,S} + len::UInt +end +MemorySpan(ptr::RemotePtr{Cvoid,S}, len::Integer) where S = + MemorySpan{S}(ptr, UInt(len)) +MemorySpan{S}(addr::UInt, len::Integer) where S = + MemorySpan{S}(RemotePtr{Cvoid,S}(addr), UInt(len)) +Base.isless(a::MemorySpan, b::MemorySpan) = a.ptr < b.ptr +Base.isempty(x::MemorySpan) = x.len == 0 +span_start(span::MemorySpan) = span.ptr.addr +span_len(span::MemorySpan) = span.len +span_end(span::MemorySpan) = span.ptr.addr + span.len +spans_overlap(span1::MemorySpan, span2::MemorySpan) = + span_start(span1) < span_end(span2) && span_start(span2) < span_end(span1) +function span_diff(span1::MemorySpan, span2::MemorySpan) + @assert span1.ptr.space == span2.ptr.space + start = max(span_start(span1), span_start(span2)) + stop = min(span_end(span1), span_end(span2)) + start_ptr = RemotePtr(start, span1.ptr.space) + if start < stop + len = stop - start + return MemorySpan(start_ptr, len) + else + return MemorySpan(start_ptr, 0) + end +end + +### More space-efficient memory spans + +struct LocalMemorySpan + ptr::UInt + len::UInt +end +LocalMemorySpan(span::MemorySpan) = LocalMemorySpan(span.ptr.addr, span.len) +Base.isempty(x::LocalMemorySpan) = x.len == 0 +span_start(span::LocalMemorySpan) = span.ptr +span_len(span::LocalMemorySpan) = span.len +span_end(span::LocalMemorySpan) = span.ptr + span.len +spans_overlap(span1::LocalMemorySpan, span2::LocalMemorySpan) = + span_start(span1) < span_end(span2) && span_start(span2) < span_end(span1) +function span_diff(span1::LocalMemorySpan, span2::LocalMemorySpan) + start = max(span_start(span1), span_start(span2)) + stop = min(span_end(span1), span_end(span2)) + if start < stop + len = stop - start + return LocalMemorySpan(start, len) + else + return LocalMemorySpan(start, 0) + end +end + +# FIXME: Store the length 
separately, since it's shared by all spans +struct ManyMemorySpan{N} + spans::NTuple{N,LocalMemorySpan} +end +Base.isempty(x::ManyMemorySpan) = all(isempty, x.spans) +span_start(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_start(span.spans[i]), N)) +span_len(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_len(span.spans[i]), N)) +span_end(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_end(span.spans[i]), N)) +spans_overlap(span1::ManyMemorySpan{N}, span2::ManyMemorySpan{N}) where N = + # N.B. The spans are assumed to be the same length and relative offset + spans_overlap(span1.spans[1], span2.spans[1]) +function span_diff(span1::ManyMemorySpan{N}, span2::ManyMemorySpan{N}) where N + verify_span(span1) + verify_span(span2) + span = ManyMemorySpan(ntuple(i -> span_diff(span1.spans[i], span2.spans[i]), N)) + matches = ntuple(i->span1.spans[i].ptr == span2.spans[i].ptr, Val(N)) + @assert !(any(matches) && !all(matches)) "Spans only partially match:\n Span1: $span1\n Span2: $span2\n Result: $span" + @assert allequal(span_len, span.spans) "Uneven span_diff result:\n Span1: $span1\n Span2: $span2\n Result: $span" + verify_span(span) + return span +end +const VERIFY_SPAN_CURRENT_OBJECT = TaskLocalValue{Any}(()->nothing) +function verify_span(span::ManyMemorySpan{N}) where N + @assert allequal(span_len, span.spans) "All spans must be the same: $(map(span_len, span.spans))\nWhile processing $(typeof(VERIFY_SPAN_CURRENT_OBJECT[]))" +end + +struct ManyPair{N} <: Unsigned + pairs::NTuple{N,UInt} +end +Base.promote_rule(::Type{ManyPair}, ::Type{T}) where {T<:Integer} = ManyPair +Base.convert(::Type{ManyPair{N}}, pair::ManyPair{N}) where N = pair +Base.convert(::Type{ManyPair{N}}, x::T) where {T<:Integer,N} = ManyPair(ntuple(i -> x, N)) +Base.convert(::Type{ManyPair}, x::ManyPair) = x +Base.:+(x::ManyPair{N}, y::ManyPair{N}) where N = ManyPair(ntuple(i -> x.pairs[i] + y.pairs[i], N)) +Base.:-(x::ManyPair{N}, y::ManyPair{N}) where N = 
ManyPair(ntuple(i -> x.pairs[i] - y.pairs[i], N)) +Base.:-(x::ManyPair) = error("Can't negate a ManyPair") +Base.:(==)(x::ManyPair, y::ManyPair) = x.pairs == y.pairs +Base.isless(x::ManyPair, y::ManyPair) = x.pairs[1] < y.pairs[1] +Base.:(<)(x::ManyPair, y::ManyPair) = x.pairs[1] < y.pairs[1] +Base.string(x::ManyPair) = "ManyPair($(x.pairs))" +Base.show(io::IO, x::ManyPair) = print(io, string(x)) + +ManyMemorySpan{N}(start::ManyPair{N}, len::ManyPair{N}) where N = + ManyMemorySpan{N}(ntuple(i -> LocalMemorySpan(start.pairs[i], len.pairs[i]), N)) + +### Memory spans with ownership info + +struct LocatorMemorySpan{T} + span::LocalMemorySpan + owner::T +end +LocatorMemorySpan{T}(start::UInt64, len::UInt64) where T = # For interval tree + LocatorMemorySpan{T}(LocalMemorySpan(start, len), 0) +Base.isempty(x::LocatorMemorySpan) = span_len(x.span) == 0 +span_start(x::LocatorMemorySpan) = span_start(x.span) +span_end(x::LocatorMemorySpan) = span_end(x.span) +span_len(x::LocatorMemorySpan) = span_len(x.span) +spans_overlap(span1::LocatorMemorySpan{T}, span2::LocatorMemorySpan{T}) where T = + spans_overlap(span1.span, span2.span) +function span_diff(span1::LocatorMemorySpan{T}, span2::LocatorMemorySpan{T}) where T + span = LocatorMemorySpan(span_diff(span1.span, span2.span), 0) + verify_span(span) + return span +end +function verify_span(span::LocatorMemorySpan{T}) where T + verify_span(span.span) +end \ No newline at end of file diff --git a/test/datadeps.jl b/test/datadeps.jl index cd83be95f..faf33b9b9 100644 --- a/test/datadeps.jl +++ b/test/datadeps.jl @@ -1,4 +1,5 @@ -import Dagger: ChunkView, Chunk +import Dagger: ChunkView, Chunk, AbstractAliasing, MemorySpace, ArgumentWrapper +import Dagger: aliasing, memory_space using LinearAlgebra, Graphs @testset "Memory Aliasing" begin @@ -82,7 +83,7 @@ end end function with_logs(f) - Dagger.enable_logging!(;taskdeps=true, taskargs=true) + Dagger.enable_logging!(;taskdeps=true, taskargs=true, timeline=true) try f() return 
Dagger.fetch_logs!() @@ -108,68 +109,296 @@ function taskdeps_for_task(logs::Dict{Int,<:Dict}, tid::Int) end error("Task $tid not found in logs") end -function test_task_dominators(logs::Dict, tid::Int, doms::Vector; all_tids::Vector=[], nondom_check::Bool=false) - g = SimpleDiGraph() - tid_to_v = Dict{Int,Int}() +function all_tasks_in_logs(logs::Dict) + all_tids = Int[] + for w in keys(logs) + _logs = logs[w] + for idx in 1:length(_logs[:core]) + core_log = _logs[:core][idx] + id_log = _logs[:id][idx] + if core_log.category == :add_thunk && core_log.kind == :finish + tid = id_log.thunk_id::Int + push!(all_tids, tid) + end + end + end + return all_tids +end +mutable struct FlowEntry + kind::Symbol + tid::Int + ainfo::AbstractAliasing + to_ainfo::AbstractAliasing + from_space::MemorySpace + to_space::MemorySpace + read::Bool + write::Bool +end +struct FlowCheck + read::Bool + write::Bool + arg_w::ArgumentWrapper + orig_ainfo::AbstractAliasing + orig_space::MemorySpace + function FlowCheck(kind, arg, dep_mod=identity) + if kind == :read + read = true + write = false + elseif kind == :write + read = false + write = true + elseif kind == :readwrite + read = true + write = true + else + error("Invalid kind: $kind") + end + arg_w = maybe_rewrap_arg_w(ArgumentWrapper(arg, dep_mod)) + return new(read, write, arg_w, aliasing(arg, dep_mod), memory_space(arg)) + end +end +struct FlowGraph + g::SimpleDiGraph + tid_to_v::Dict{Int,Int} + FlowGraph() = new(SimpleDiGraph(), Dict{Int,Int}()) +end +struct FlowState + flows::Dict{ArgumentWrapper,Vector{FlowEntry}} + graph::FlowGraph + FlowState() = new(Dict{ArgumentWrapper,Vector{FlowEntry}}(), FlowGraph()) +end +function maybe_rewrap_arg_w(arg_w::ArgumentWrapper) + arg = arg_w.arg + if arg isa DTask + arg = fetch(arg; raw=true) + end + if arg isa Chunk && Dagger.root_worker_id(arg) == myid() + arg = Dagger.unwrap(arg) + end + return ArgumentWrapper(arg, arg_w.dep_mod) +end +function build_dataflow(logs::Dict; verbose::Bool=false) + 
state = FlowState() + orig_ainfos = Dict{AbstractAliasing,AbstractAliasing}() + ainfo_arg_w = Dict{AbstractAliasing,ArgumentWrapper}() + + function add_execute!(arg_w, orig_ainfo, ainfo, tid, space, read, write) + ainfo_flows = get!(Vector{FlowEntry}, state.flows, arg_w) + # Skip duplicates (same arg 2+ times to same task) + dup_idx = findfirst(flow->flow.tid == tid, ainfo_flows) + if dup_idx === nothing + if !haskey(orig_ainfos, ainfo) + orig_ainfos[ainfo] = orig_ainfo + end + if !haskey(ainfo_arg_w, ainfo) + ainfo_arg_w[ainfo] = arg_w + end + verbose && println("Adding execute flow (tid $tid, space $space, read $read, write $write):\n $orig_ainfo ->\n $ainfo") + verbose && println(" $(arg_w.dep_mod), $(arg_w.arg)") + push!(ainfo_flows, FlowEntry(:execute, tid, ainfo, ainfo, space, space, read, write)) + else + # Union read and write fields + ainfo_flows[dup_idx].read |= read + ainfo_flows[dup_idx].write |= write + end + end + function add_copy!(arg_w, from_arg, to_arg, tid, from_space, to_space) + dep_mod = arg_w.dep_mod + from_ainfo = aliasing(from_arg, dep_mod) + to_ainfo = aliasing(to_arg, dep_mod) + if !haskey(orig_ainfos, from_ainfo) + orig_ainfos[from_ainfo] = from_ainfo + end + if !haskey(ainfo_arg_w, from_ainfo) + ainfo_arg_w[from_ainfo] = arg_w + end + if !haskey(ainfo_arg_w, to_ainfo) + ainfo_arg_w[to_ainfo] = arg_w + end + orig_ainfo = orig_ainfos[from_ainfo] + orig_ainfos[to_ainfo] = orig_ainfo + arg_flows = get!(Vector{FlowEntry}, state.flows, arg_w) + verbose && println("Adding copy flow (tid $tid, from_space $from_space, to_space $to_space):\n $orig_ainfo ->\n $to_ainfo") + verbose && println(" $(arg_w.dep_mod), $(arg_w.arg)") + push!(arg_flows, FlowEntry(:copy, tid, from_ainfo, to_ainfo, from_space, to_space, true, true)) + end + + # Populate graph from syncdeps seen = Set{Int}() - to_visit = copy(all_tids) + to_visit = all_tasks_in_logs(logs) while !isempty(to_visit) this_tid = popfirst!(to_visit) this_tid in seen && continue push!(seen, 
this_tid) - if !(this_tid in keys(tid_to_v)) - add_vertex!(g); tid_to_v[this_tid] = nv(g) + if !(this_tid in keys(state.graph.tid_to_v)) + add_vertex!(state.graph.g); state.graph.tid_to_v[this_tid] = nv(state.graph.g) end # Add syncdeps deps = taskdeps_for_task(logs, this_tid) for dep in deps - if !(dep in keys(tid_to_v)) - add_vertex!(g); tid_to_v[dep] = nv(g) + if !(dep in keys(state.graph.tid_to_v)) + add_vertex!(state.graph.g); state.graph.tid_to_v[dep] = nv(state.graph.g) end - add_edge!(g, tid_to_v[this_tid], tid_to_v[dep]) + add_edge!(state.graph.g, state.graph.tid_to_v[this_tid], state.graph.tid_to_v[dep]) push!(to_visit, dep) end end - state = dijkstra_shortest_paths(g, tid_to_v[tid]) - any_failed = false - @test !has_edge(g, tid_to_v[tid], tid_to_v[tid]) - any_failed |= has_edge(g, tid_to_v[tid], tid_to_v[tid]) - for dom in doms - @test state.pathcounts[tid_to_v[dom]] > 0 - if state.pathcounts[tid_to_v[dom]] == 0 - println("Expected dominance for $dom of $tid") - any_failed = true - end - end - if nondom_check - for nondom in all_tids - nondom == tid && continue - nondom in doms && continue - @test state.pathcounts[tid_to_v[nondom]] == 0 - if state.pathcounts[tid_to_v[nondom]] > 0 - println("Expected non-dominance for $nondom of $tid") - any_failed = true + + # Populate flows and graphs from datadeps logs + for w in keys(logs) + _logs = logs[w] + for idx in 1:length(_logs[:core]) + core_log = _logs[:core][idx] + id_log = _logs[:id][idx] + tl_log = _logs[:timeline][idx] + if core_log.category == :datadeps_execute && core_log.kind == :finish + tid = id_log.thunk_id + for (remote_arg, depset) in zip(tl_log.args, tl_log.deps) + for dep in depset.deps + arg_w = maybe_rewrap_arg_w(dep.arg_w) + orig_ainfo = aliasing(arg_w.arg, arg_w.dep_mod) + remote_ainfo = aliasing(remote_arg, arg_w.dep_mod) + space = memory_space(remote_arg) + add_execute!(arg_w, orig_ainfo, remote_ainfo, tid, space, dep.readdep, dep.writedep) + end + end + elseif (core_log.category == 
:datadeps_copy || core_log.category == :datadeps_copy_skip) && core_log.kind == :finish + tid = tl_log.thunk_id + from_space = tl_log.from_space + to_space = tl_log.to_space + from_arg = tl_log.from_arg + to_arg = tl_log.to_arg + arg_w = maybe_rewrap_arg_w(tl_log.arg_w) + add_copy!(arg_w, from_arg, to_arg, tid, from_space, to_space) end end end - # For debugging purposes - if any_failed - println("Failure detected!") - println("Root: $tid") - println("Exp. doms: $doms") - println("All: $all_tids") - e_vs = collect(edges(g)) - e_tids = map(e->Edge(only(filter(tv->tv[2]==src(e), tid_to_v))[1], - only(filter(tv->tv[2]==dst(e), tid_to_v))[1]), - e_vs) - sort!(e_tids) - for e in e_tids - s_tid, d_tid = src(e), dst(e) - println("Edge: $s_tid -(up)> $d_tid") + return state +end +function test_dataflow(state::FlowState, checks...; verbose::Bool=true) + # Check that each ainfo starts and ends in the same space + for arg_w in keys(state.flows) + ainfo = aliasing(arg_w.arg, arg_w.dep_mod) + arg_flows = state.flows[arg_w] + orig_space = memory_space(arg_w.arg) #arg_flows[1].from_space + #=if ainfo != arg_flows[1].ainfo + verbose && println("Ainfo key $(ainfo) is not the same as the first flow's ainfo $(ainfo_flows[1].ainfo)") + return false + end=# + final_space = arg_flows[end].to_space + # FIXME: will_alias doesn't check across spaces + any_writes = any(flows->Dagger.will_alias(flows[1], ainfo) && any(flow->flow.write, flows[2]), state.flows) + if orig_space != final_space + if verbose + println("Arg ($(arg_w.dep_mod), $(arg_w.arg)) starts in $(orig_space) but ends in $(final_space)") + for flow in arg_flows + println(" $(flow.kind) $(flow.tid) $(flow.from_space) -> $(flow.to_space)") + end + end + return false end end + + # Check each flow against the previous flow, ensuring that the previous flow is a dominator of the current flow + # FIXME: Validate non-dominance when unnecessary? 
+ for arg_w in keys(state.flows) + arg_flows = state.flows[arg_w] + for (idx, flow) in enumerate(arg_flows) + if idx > 1 + prev_flow = arg_flows[idx-1] + if !prev_flow.write && !flow.write + # R->R don't depend on each other + continue + end + if !prev_flow.write && flow.write && prev_flow.kind == :execute && flow.kind == :copy && prev_flow.ainfo != flow.to_ainfo + # Copy only writes to a different ainfo, so don't depend on each other + continue + end + if flow.tid == 0 + # Ignore copy skip flows + continue + end + v = state.graph.tid_to_v[flow.tid] + prev_v = state.graph.tid_to_v[prev_flow.tid] + path_state = dijkstra_shortest_paths(state.graph.g, v; allpaths=true) + if path_state.pathcounts[prev_v] == 0 + if verbose + println("Flow $(idx-1) (tid $(prev_flow.tid), $(prev_flow.kind), R:$(prev_flow.read), W:$(prev_flow.write)) is not a dominator of flow $(idx) (tid $(flow.tid), $(flow.kind), R:$(flow.read), W:$(flow.write))") + @show length(state.flows[arg_w]) + for flow in state.flows[arg_w] + println(" $(flow.kind) $(flow.tid) $(flow.from_space) -> $(flow.to_space) (R:$(flow.read), W:$(flow.write))") + end + for flow in state.flows[arg_w] + println(" May write to: $(flow.to_ainfo)") + end + e_vs = collect(edges(state.graph.g)) + e_tids = map(e->Edge(only(filter(tv->tv[2]==src(e), state.graph.tid_to_v))[1], + only(filter(tv->tv[2]==dst(e), state.graph.tid_to_v))[1]), + e_vs) + sort!(e_tids) + for e in e_tids + s_tid, d_tid = src(e), dst(e) + println("Edge: $s_tid -(up)> $d_tid") + end + end + return false + end + end + end + end + + # Walk through each check, ensuring that the current state of the flow matches the check + arg_locations = Dict{ArgumentWrapper,MemorySpace}() + flow_idxs = Dict{ArgumentWrapper,Int}(arg_w=>1 for arg_w in keys(state.flows)) + for (idx, check) in enumerate(checks) + # Record the original location of the ainfo + if !haskey(arg_locations, check.arg_w) + arg_locations[check.arg_w] = check.orig_space + end + + # Try to advance a flow + if 
!haskey(flow_idxs, check.arg_w) + if verbose + @warn "Didn't encounter argument ($(check.arg_w.dep_mod), $(check.arg_w.arg))" + println("Seen arguments:") + for arg_w in keys(state.flows) + println(" ($(arg_w.dep_mod), $(arg_w.arg))") + end + return false + end + end + flow_idx = flow_idxs[check.arg_w] + while true + if flow_idx > length(state.flows[check.arg_w]) + verbose && println("Exhausted all tasks while trying to find $(check.arg_w)") + return false + end + flow = state.flows[check.arg_w][flow_idx] + if flow.kind == :execute + # The current flow state must match the check + if flow.read == check.read && flow.write == check.write + # Match, move on to next check + flow_idx += 1 + break + else + verbose && println("Expected ($(check.read), $(check.write)), got ($(flow.read), $(flow.write))") + return false + end + elseif flow.kind == :copy + # We need to advance our ainfo location + # FIXME: Assert proper data progression (requires more complex tracking of other arguments) + #@assert flow.from_space == arg_locations[check.arg_w] + arg_locations[check.arg_w] = flow.to_space + flow_idx += 1 + end + end + + flow_idxs[check.arg_w] = flow_idx + end + + return true end @everywhere do_nothing(Xs...) 
= nothing @@ -177,16 +406,15 @@ end @everywhere mut_V!(V) = (V .= 1;) function test_datadeps(;args_chunks::Bool, args_thunks::Bool, - args_loc::Int, - aliasing::Bool) + args_loc::Int) # Returns last value - @test Dagger.spawn_datadeps(;aliasing) do + @test Dagger.spawn_datadeps() do 42 end == 42 # Tasks are started and finished as spawn_datadeps returns ts = [] - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do for i in 1:5 t = Dagger.@spawn sleep(0.1) @test !istaskstarted(t) @@ -195,7 +423,7 @@ function test_datadeps(;args_chunks::Bool, @test all(istaskdone, ts) # Rethrows any task exceptions - @test_throws Exception Dagger.spawn_datadeps(;aliasing) do + @test_throws Exception Dagger.spawn_datadeps() do Dagger.@spawn error("Test") end @@ -206,10 +434,13 @@ function test_datadeps(;args_chunks::Bool, A = Dagger.@spawn scope=Dagger.scope(worker=args_loc) copy(A) end + @warn "Negative-test the test_dataflow helper" + # Task return values can be tracked ts = [] + local t1 logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do t1 = Dagger.@spawn fill(42, 1) push!(ts, t1) push!(ts, Dagger.@spawn copyto!(Out(A), In(t1))) @@ -217,273 +448,277 @@ function test_datadeps(;args_chunks::Bool, end tid_1, tid_2 = task_id.(ts) @test fetch(A)[1] == 42.0 - test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) + state = build_dataflow(logs) + # FIXME: We don't record the task as a syncdep, but instead internally `fetch` the chunk - test_task_dominators(logs, tid_2, [#=tid_1=#]; all_tids=[tid_1, tid_2]) + # We don't see the :readwrite because we don't see the use of t1 + #@test test_dataflow(state, FlowCheck(:readwrite, t1)) + @test test_dataflow(state, FlowCheck(:read, t1), FlowCheck(:write, A)) # R->R Non-Aliasing ts = [] logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do push!(ts, Dagger.@spawn do_nothing(In(A))) push!(ts, Dagger.@spawn do_nothing(In(A))) end end tid_1, tid_2 = 
task_id.(ts) - test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) - test_task_dominators(logs, tid_2, []; all_tids=[tid_1, tid_2], nondom_check=false) + state = build_dataflow(logs) + @test test_dataflow(state, FlowCheck(:read, A), FlowCheck(:read, A)) # R->W Aliasing ts = [] logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do push!(ts, Dagger.@spawn do_nothing(In(A))) push!(ts, Dagger.@spawn do_nothing(Out(A))) end end tid_1, tid_2 = task_id.(ts) - test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) - test_task_dominators(logs, tid_2, [tid_1]; all_tids=[tid_1, tid_2]) + state = build_dataflow(logs) + @test test_dataflow(state, FlowCheck(:read, A), FlowCheck(:write, A)) # W->W Aliasing ts = [] logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do push!(ts, Dagger.@spawn do_nothing(Out(A))) push!(ts, Dagger.@spawn do_nothing(Out(A))) end end tid_1, tid_2 = task_id.(ts) - test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) - test_task_dominators(logs, tid_2, [tid_1]; all_tids=[tid_1, tid_2]) + state = build_dataflow(logs) + @test test_dataflow(state, FlowCheck(:write, A), FlowCheck(:write, A)) # R->R Non-Self-Aliasing ts = [] logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do push!(ts, Dagger.@spawn do_nothing(In(A), In(A))) push!(ts, Dagger.@spawn do_nothing(In(A), In(A))) end end tid_1, tid_2 = task_id.(ts) - test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) - test_task_dominators(logs, tid_2, []; all_tids=[tid_1, tid_2]) + state = build_dataflow(logs) + @test test_dataflow(state, FlowCheck(:read, A), FlowCheck(:read, A)) # R->W Self-Aliasing ts = [] logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do push!(ts, Dagger.@spawn do_nothing(In(A), In(A))) push!(ts, Dagger.@spawn do_nothing(Out(A), Out(A))) end end tid_1, tid_2 = task_id.(ts) - test_task_dominators(logs, tid_1, []; 
all_tids=[tid_1, tid_2]) - test_task_dominators(logs, tid_2, [tid_1]; all_tids=[tid_1, tid_2]) + state = build_dataflow(logs) + @test test_dataflow(state, FlowCheck(:read, A), FlowCheck(:write, A)) # W->W Self-Aliasing ts = [] logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do push!(ts, Dagger.@spawn do_nothing(Out(A), Out(A))) push!(ts, Dagger.@spawn do_nothing(Out(A), Out(A))) end end tid_1, tid_2 = task_id.(ts) - test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) - test_task_dominators(logs, tid_2, [tid_1]; all_tids=[tid_1, tid_2]) - - if aliasing - function wrap_chunk_thunk(f, args...) - if args_thunks || args_chunks - result = Dagger.@spawn scope=Dagger.scope(worker=args_loc) f(args...) - if args_thunks - return result - elseif args_chunks - return fetch(result; raw=true) - end - else - # N.B. We don't allocate remotely for raw data - return f(args...) + state = build_dataflow(logs) + @test test_dataflow(state, FlowCheck(:write, A), FlowCheck(:write, A)) + + function wrap_chunk_thunk(f, args...) + if args_thunks || args_chunks + result = Dagger.@spawn scope=Dagger.scope(worker=args_loc) f(args...) + if args_thunks + return result + elseif args_chunks + return fetch(result; raw=true) end + else + # N.B. We don't allocate remotely for raw data + return f(args...) 
end - B = wrap_chunk_thunk(rand, 4, 4) - - # Views - B_ul = wrap_chunk_thunk(view, B, 1:2, 1:2) - B_ur = wrap_chunk_thunk(view, B, 1:2, 3:4) - B_ll = wrap_chunk_thunk(view, B, 3:4, 1:2) - B_lr = wrap_chunk_thunk(view, B, 3:4, 3:4) - B_mid = wrap_chunk_thunk(view, B, 2:3, 2:3) - for (B_name, B_view) in ( - (:B_ul, B_ul), - (:B_ur, B_ur), - (:B_ll, B_ll), - (:B_lr, B_lr), - (:B_mid, B_mid)) - @test Dagger.will_alias(Dagger.aliasing(B), Dagger.aliasing(B_view)) - B_view === B_mid && continue - @test Dagger.will_alias(Dagger.aliasing(B_mid), Dagger.aliasing(B_view)) - end - local t_A, t_B, t_ul, t_ur, t_ll, t_lr, t_mid - local t_ul2, t_ur2, t_ll2, t_lr2 - logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do - t_A = Dagger.@spawn do_nothing(InOut(A)) - t_B = Dagger.@spawn do_nothing(InOut(B)) - t_ul = Dagger.@spawn do_nothing(InOut(B_ul)) - t_ur = Dagger.@spawn do_nothing(InOut(B_ur)) - t_ll = Dagger.@spawn do_nothing(InOut(B_ll)) - t_lr = Dagger.@spawn do_nothing(InOut(B_lr)) - t_mid = Dagger.@spawn do_nothing(InOut(B_mid)) - t_ul2 = Dagger.@spawn do_nothing(InOut(B_ul)) - t_ur2 = Dagger.@spawn do_nothing(InOut(B_ur)) - t_ll2 = Dagger.@spawn do_nothing(InOut(B_ll)) - t_lr2 = Dagger.@spawn do_nothing(InOut(B_lr)) - end + end + B = wrap_chunk_thunk(rand, 4, 4) + + # Views + B_ul = wrap_chunk_thunk(view, B, 1:2, 1:2) + B_ur = wrap_chunk_thunk(view, B, 1:2, 3:4) + B_ll = wrap_chunk_thunk(view, B, 3:4, 1:2) + B_lr = wrap_chunk_thunk(view, B, 3:4, 3:4) + B_mid = wrap_chunk_thunk(view, B, 2:3, 2:3) + for (B_name, B_view) in ( + (:B_ul, B_ul), + (:B_ur, B_ur), + (:B_ll, B_ll), + (:B_lr, B_lr), + (:B_mid, B_mid)) + @test Dagger.will_alias(Dagger.aliasing(B), Dagger.aliasing(B_view)) + B_view === B_mid && continue + @test Dagger.will_alias(Dagger.aliasing(B_mid), Dagger.aliasing(B_view)) + end + local t_A, t_B, t_ul, t_ur, t_ll, t_lr, t_mid + local t_ul2, t_ur2, t_ll2, t_lr2 + logs = with_logs() do + Dagger.spawn_datadeps() do + t_A = Dagger.@spawn do_nothing(InOut(A)) + 
t_B = Dagger.@spawn do_nothing(InOut(B)) + t_ul = Dagger.@spawn do_nothing(InOut(B_ul)) + t_ur = Dagger.@spawn do_nothing(InOut(B_ur)) + t_ll = Dagger.@spawn do_nothing(InOut(B_ll)) + t_lr = Dagger.@spawn do_nothing(InOut(B_lr)) + t_mid = Dagger.@spawn do_nothing(InOut(B_mid)) + t_ul2 = Dagger.@spawn do_nothing(InOut(B_ul)) + t_ur2 = Dagger.@spawn do_nothing(InOut(B_ur)) + t_ll2 = Dagger.@spawn do_nothing(InOut(B_ll)) + t_lr2 = Dagger.@spawn do_nothing(InOut(B_lr)) end - tid_A, tid_B, tid_ul, tid_ur, tid_ll, tid_lr, tid_mid = - task_id.([t_A, t_B, t_ul, t_ur, t_ll, t_lr, t_mid]) - tid_ul2, tid_ur2, tid_ll2, tid_lr2 = - task_id.([t_ul2, t_ur2, t_ll2, t_lr2]) - tids_all = [tid_A, tid_B, tid_ul, tid_ur, tid_ll, tid_lr, tid_mid, - tid_ul2, tid_ur2, tid_ll2, tid_lr2] - test_task_dominators(logs, tid_A, []; all_tids=tids_all) - test_task_dominators(logs, tid_B, []; all_tids=tids_all) - test_task_dominators(logs, tid_ul, [tid_B]; all_tids=tids_all) - test_task_dominators(logs, tid_ur, [tid_B]; all_tids=tids_all) - test_task_dominators(logs, tid_ll, [tid_B]; all_tids=tids_all) - test_task_dominators(logs, tid_lr, [tid_B]; all_tids=tids_all) - test_task_dominators(logs, tid_mid, [tid_B, tid_ul, tid_ur, tid_ll, tid_lr]; all_tids=tids_all) - test_task_dominators(logs, tid_ul2, [tid_B, tid_mid, tid_ul]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_ur2, [tid_B, tid_mid, tid_ur]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_ll2, [tid_B, tid_mid, tid_ll]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_lr2, [tid_B, tid_mid, tid_lr]; all_tids=tids_all, nondom_check=false) - - # (Unit)Upper/LowerTriangular and Diagonal - B_upper = wrap_chunk_thunk(UpperTriangular, B) - B_unitupper = wrap_chunk_thunk(UnitUpperTriangular, B) - B_lower = wrap_chunk_thunk(LowerTriangular, B) - B_unitlower = wrap_chunk_thunk(UnitLowerTriangular, B) - for (B_name, B_view) in ( - (:B_upper, B_upper), - (:B_unitupper, 
B_unitupper), - (:B_lower, B_lower), - (:B_unitlower, B_unitlower)) - @test Dagger.will_alias(Dagger.aliasing(B), Dagger.aliasing(B_view)) - end - @test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B_lower)) - @test !Dagger.will_alias(Dagger.aliasing(B_unitupper), Dagger.aliasing(B_unitlower)) - @test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B_unitupper)) - @test Dagger.will_alias(Dagger.aliasing(B_lower), Dagger.aliasing(B_unitlower)) - - @test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B, Diagonal)) - @test Dagger.will_alias(Dagger.aliasing(B_lower), Dagger.aliasing(B, Diagonal)) - @test !Dagger.will_alias(Dagger.aliasing(B_unitupper), Dagger.aliasing(B, Diagonal)) - @test !Dagger.will_alias(Dagger.aliasing(B_unitlower), Dagger.aliasing(B, Diagonal)) - - local t_A, t_B, t_upper, t_unitupper, t_lower, t_unitlower, t_diag - local t_upper2, t_unitupper2, t_lower2, t_unitlower2 - logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do - t_A = Dagger.@spawn do_nothing(InOut(A)) - t_B = Dagger.@spawn do_nothing(InOut(B)) - t_upper = Dagger.@spawn do_nothing(InOut(B_upper)) - t_unitupper = Dagger.@spawn do_nothing(InOut(B_unitupper)) - t_lower = Dagger.@spawn do_nothing(InOut(B_lower)) - t_unitlower = Dagger.@spawn do_nothing(InOut(B_unitlower)) - t_diag = Dagger.@spawn do_nothing(Deps(B, InOut(Diagonal))) - t_unitlower2 = Dagger.@spawn do_nothing(InOut(B_unitlower)) - t_lower2 = Dagger.@spawn do_nothing(InOut(B_lower)) - t_unitupper2 = Dagger.@spawn do_nothing(InOut(B_unitupper)) - t_upper2 = Dagger.@spawn do_nothing(InOut(B_upper)) - end + end + tid_A, tid_B, tid_ul, tid_ur, tid_ll, tid_lr, tid_mid = + task_id.([t_A, t_B, t_ul, t_ur, t_ll, t_lr, t_mid]) + tid_ul2, tid_ur2, tid_ll2, tid_lr2 = + task_id.([t_ul2, t_ur2, t_ll2, t_lr2]) + tids_all = [tid_A, tid_B, tid_ul, tid_ur, tid_ll, tid_lr, tid_mid, + tid_ul2, tid_ur2, tid_ll2, tid_lr2] + state = build_dataflow(logs) + @test test_dataflow(state, 
FlowCheck(:readwrite, A)) + @test test_dataflow(state, FlowCheck(:readwrite, B)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_ul)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_ur)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_ll)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_lr)) + for arg in [B_ul, B_ur, B_ll, B_lr] + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, arg), FlowCheck(:readwrite, B_mid), FlowCheck(:readwrite, arg)) + end + + # (Unit)Upper/LowerTriangular and Diagonal + B_upper = wrap_chunk_thunk(UpperTriangular, B) + B_unitupper = wrap_chunk_thunk(UnitUpperTriangular, B) + B_lower = wrap_chunk_thunk(LowerTriangular, B) + B_unitlower = wrap_chunk_thunk(UnitLowerTriangular, B) + for (B_name, B_view) in ( + (:B_upper, B_upper), + (:B_unitupper, B_unitupper), + (:B_lower, B_lower), + (:B_unitlower, B_unitlower)) + @test Dagger.will_alias(Dagger.aliasing(B), Dagger.aliasing(B_view)) + end + @test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B_lower)) + @test !Dagger.will_alias(Dagger.aliasing(B_unitupper), Dagger.aliasing(B_unitlower)) + @test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B_unitupper)) + @test Dagger.will_alias(Dagger.aliasing(B_lower), Dagger.aliasing(B_unitlower)) + + @test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B, Diagonal)) + @test Dagger.will_alias(Dagger.aliasing(B_lower), Dagger.aliasing(B, Diagonal)) + @test !Dagger.will_alias(Dagger.aliasing(B_unitupper), Dagger.aliasing(B, Diagonal)) + @test !Dagger.will_alias(Dagger.aliasing(B_unitlower), Dagger.aliasing(B, Diagonal)) + + local t_A, t_B, t_upper, t_unitupper, t_lower, t_unitlower, t_diag + local t_upper2, t_unitupper2, t_lower2, t_unitlower2 + logs = with_logs() do + Dagger.spawn_datadeps() do + t_A = Dagger.@spawn do_nothing(InOut(A)) + t_B = Dagger.@spawn 
do_nothing(InOut(B)) + t_upper = Dagger.@spawn do_nothing(InOut(B_upper)) + t_unitupper = Dagger.@spawn do_nothing(InOut(B_unitupper)) + t_lower = Dagger.@spawn do_nothing(InOut(B_lower)) + t_unitlower = Dagger.@spawn do_nothing(InOut(B_unitlower)) + t_diag = Dagger.@spawn do_nothing(Deps(B, InOut(Diagonal))) + t_unitlower2 = Dagger.@spawn do_nothing(InOut(B_unitlower)) + t_lower2 = Dagger.@spawn do_nothing(InOut(B_lower)) + t_unitupper2 = Dagger.@spawn do_nothing(InOut(B_unitupper)) + t_upper2 = Dagger.@spawn do_nothing(InOut(B_upper)) end - tid_A, tid_B, tid_upper, tid_unitupper, tid_lower, tid_unitlower, tid_diag = - task_id.([t_A, t_B, t_upper, t_unitupper, t_lower, t_unitlower, t_diag]) - tid_upper2, tid_unitupper2, tid_lower2, tid_unitlower2 = - task_id.([t_upper2, t_unitupper2, t_lower2, t_unitlower2]) - tids_all = [tid_A, tid_B, tid_upper, tid_unitupper, tid_lower, tid_unitlower, tid_diag, - tid_upper2, tid_unitupper2, tid_lower2, tid_unitlower2] - test_task_dominators(logs, tid_A, []; all_tids=tids_all) - test_task_dominators(logs, tid_B, []; all_tids=tids_all) - # FIXME: Proper non-dominance checks - test_task_dominators(logs, tid_upper, [tid_B]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_unitupper, [tid_B, tid_upper]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_lower, [tid_B, tid_upper]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_unitlower, [tid_B, tid_lower]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_diag, [tid_B, tid_upper, tid_lower]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_unitlower2, [tid_B, tid_lower, tid_unitlower]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_lower2, [tid_B, tid_lower, tid_unitlower, tid_diag, tid_unitlower2]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_unitupper2, [tid_B, tid_upper, tid_unitupper]; all_tids=tids_all, 
nondom_check=false) - test_task_dominators(logs, tid_upper2, [tid_B, tid_upper, tid_unitupper, tid_diag, tid_unitupper2]; all_tids=tids_all, nondom_check=false) - - # Additional aliasing tests - views_overlap(x, y) = Dagger.will_alias(Dagger.aliasing(x), Dagger.aliasing(y)) - - A = wrap_chunk_thunk(identity, B) - - A_r1 = wrap_chunk_thunk(view, A, 1:1, 1:4) - A_r2 = wrap_chunk_thunk(view, A, 2:2, 1:4) - B_r1 = wrap_chunk_thunk(view, B, 1:1, 1:4) - B_r2 = wrap_chunk_thunk(view, B, 2:2, 1:4) - - A_c1 = wrap_chunk_thunk(view, A, 1:4, 1:1) - A_c2 = wrap_chunk_thunk(view, A, 1:4, 2:2) - B_c1 = wrap_chunk_thunk(view, B, 1:4, 1:1) - B_c2 = wrap_chunk_thunk(view, B, 1:4, 2:2) - - A_mid = wrap_chunk_thunk(view, A, 2:3, 2:3) - B_mid = wrap_chunk_thunk(view, B, 2:3, 2:3) - - @test views_overlap(A_r1, A_r1) - @test views_overlap(B_r1, B_r1) - @test views_overlap(A_c1, A_c1) - @test views_overlap(B_c1, B_c1) - - @test views_overlap(A_r1, B_r1) - @test views_overlap(A_r2, B_r2) - @test views_overlap(A_c1, B_c1) - @test views_overlap(A_c2, B_c2) - - @test !views_overlap(A_r1, A_r2) - @test !views_overlap(B_r1, B_r2) - @test !views_overlap(A_c1, A_c2) - @test !views_overlap(B_c1, B_c2) - - @test views_overlap(A_r1, A_c1) - @test views_overlap(A_r1, B_c1) - @test views_overlap(A_r2, A_c2) - @test views_overlap(A_r2, B_c2) - - for (name, mid) in ((:A_mid, A_mid), (:B_mid, B_mid)) - @test !views_overlap(A_r1, mid) - @test !views_overlap(B_r1, mid) - @test !views_overlap(A_c1, mid) - @test !views_overlap(B_c1, mid) - - @test views_overlap(A_r2, mid) - @test views_overlap(B_r2, mid) - @test views_overlap(A_c2, mid) - @test views_overlap(B_c2, mid) - end - - @test views_overlap(A_mid, A_mid) - @test views_overlap(A_mid, B_mid) - - # SubArray hashing - V = zeros(3) - Dagger.spawn_datadeps(;aliasing) do - Dagger.@spawn mut_V!(InOut(view(V, 1:2))) - Dagger.@spawn mut_V!(InOut(view(V, 2:3))) - end - @test fetch(V) == [1, 1, 1] end + tid_A, tid_B, tid_upper, tid_unitupper, tid_lower, 
tid_unitlower, tid_diag = + task_id.([t_A, t_B, t_upper, t_unitupper, t_lower, t_unitlower, t_diag]) + tid_upper2, tid_unitupper2, tid_lower2, tid_unitlower2 = + task_id.([t_upper2, t_unitupper2, t_lower2, t_unitlower2]) + tids_all = [tid_A, tid_B, tid_upper, tid_unitupper, tid_lower, tid_unitlower, tid_diag, + tid_upper2, tid_unitupper2, tid_lower2, tid_unitlower2] + state = build_dataflow(logs) + @test test_dataflow(state, FlowCheck(:readwrite, A)) + @test test_dataflow(state, FlowCheck(:readwrite, B)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_upper), FlowCheck(:readwrite, B_unitupper)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_upper), FlowCheck(:readwrite, B_lower)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_lower), FlowCheck(:readwrite, B_unitlower)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_upper), FlowCheck(:readwrite, B_lower), + FlowCheck(:readwrite, B, Diagonal)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_lower), FlowCheck(:readwrite, B_unitlower), + FlowCheck(:readwrite, B, Diagonal), FlowCheck(:readwrite, B_unitlower)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_lower), FlowCheck(:readwrite, B_unitlower), + FlowCheck(:readwrite, B, Diagonal), FlowCheck(:readwrite, B_unitlower), FlowCheck(:readwrite, B_lower)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_upper), FlowCheck(:readwrite, B_unitupper), + FlowCheck(:readwrite, B_unitupper)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_upper), FlowCheck(:readwrite, B_unitupper), + FlowCheck(:readwrite, B, Diagonal), FlowCheck(:readwrite, B_unitupper), FlowCheck(:readwrite, B_upper)) + + # Additional aliasing tests + views_overlap(x, y) = Dagger.will_alias(Dagger.aliasing(x), Dagger.aliasing(y)) + + A = wrap_chunk_thunk(identity, B) 
+ + A_r1 = wrap_chunk_thunk(view, A, 1:1, 1:4) + A_r2 = wrap_chunk_thunk(view, A, 2:2, 1:4) + B_r1 = wrap_chunk_thunk(view, B, 1:1, 1:4) + B_r2 = wrap_chunk_thunk(view, B, 2:2, 1:4) + + A_c1 = wrap_chunk_thunk(view, A, 1:4, 1:1) + A_c2 = wrap_chunk_thunk(view, A, 1:4, 2:2) + B_c1 = wrap_chunk_thunk(view, B, 1:4, 1:1) + B_c2 = wrap_chunk_thunk(view, B, 1:4, 2:2) + + A_mid = wrap_chunk_thunk(view, A, 2:3, 2:3) + B_mid = wrap_chunk_thunk(view, B, 2:3, 2:3) + + @test views_overlap(A_r1, A_r1) + @test views_overlap(B_r1, B_r1) + @test views_overlap(A_c1, A_c1) + @test views_overlap(B_c1, B_c1) + + @test views_overlap(A_r1, B_r1) + @test views_overlap(A_r2, B_r2) + @test views_overlap(A_c1, B_c1) + @test views_overlap(A_c2, B_c2) + + @test !views_overlap(A_r1, A_r2) + @test !views_overlap(B_r1, B_r2) + @test !views_overlap(A_c1, A_c2) + @test !views_overlap(B_c1, B_c2) + + @test views_overlap(A_r1, A_c1) + @test views_overlap(A_r1, B_c1) + @test views_overlap(A_r2, A_c2) + @test views_overlap(A_r2, B_c2) + + for (name, mid) in ((:A_mid, A_mid), (:B_mid, B_mid)) + @test !views_overlap(A_r1, mid) + @test !views_overlap(B_r1, mid) + @test !views_overlap(A_c1, mid) + @test !views_overlap(B_c1, mid) + + @test views_overlap(A_r2, mid) + @test views_overlap(B_r2, mid) + @test views_overlap(A_c2, mid) + @test views_overlap(B_c2, mid) + end + + @test views_overlap(A_mid, A_mid) + @test views_overlap(A_mid, B_mid) + + # SubArray hashing + V = zeros(3) + Dagger.spawn_datadeps() do + Dagger.@spawn mut_V!(InOut(view(V, 1:2))) + Dagger.@spawn mut_V!(InOut(view(V, 2:3))) + end + @test fetch(V) == [1, 1, 1] # FIXME: Deps # Outer Scope - exec_procs = fetch.(Dagger.spawn_datadeps(;aliasing) do + exec_procs = fetch.(Dagger.spawn_datadeps() do [Dagger.@spawn Dagger.task_processor() for i in 1:10] end) unique!(exec_procs) @@ -499,7 +734,7 @@ function test_datadeps(;args_chunks::Bool, end # Inner Scope - @test_throws Dagger.Sch.SchedulingException Dagger.spawn_datadeps(;aliasing) do + 
@test_throws Dagger.Sch.SchedulingException Dagger.spawn_datadeps() do Dagger.@spawn scope=Dagger.ExactScope(Dagger.ThreadProc(1, 5000)) 1+1 end @@ -528,7 +763,7 @@ function test_datadeps(;args_chunks::Bool, C = Dagger.@spawn scope=Dagger.scope(worker=args_loc) copy(C) D = Dagger.@spawn scope=Dagger.scope(worker=args_loc) copy(D) end - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do Dagger.@spawn add!(InOut(B), In(A)) Dagger.@spawn add!(InOut(C), In(A)) Dagger.@spawn add!(InOut(C), In(B)) @@ -545,7 +780,7 @@ function test_datadeps(;args_chunks::Bool, elseif args_thunks As = map(A->(Dagger.@spawn scope=Dagger.scope(worker=args_loc) copy(A)), As) end - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do to_reduce = Vector[] push!(to_reduce, As) while !isempty(to_reduce) @@ -576,7 +811,7 @@ function test_datadeps(;args_chunks::Bool, elseif args_thunks M = map(m->(Dagger.@spawn scope=Dagger.scope(worker=args_loc) copy(m)), M) end - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do for k in range(1, mt) Dagger.@spawn LAPACK.potrf!('L', InOut(M[k, k])) for _m in range(k+1, mt) @@ -596,18 +831,16 @@ function test_datadeps(;args_chunks::Bool, @test isapprox(M_dense, expected) end -@testset "$(aliasing ? "With" : "Without") Aliasing Support" for aliasing in (true, false) - @testset "$args_mode Data" for args_mode in (:Raw, :Chunk, :Thunk) - args_chunks = args_mode == :Chunk - args_thunks = args_mode == :Thunk - for nw in (1, 2) - args_loc = nw == 2 ? 2 : 1 - for nt in (1, 2) - if nprocs() >= nw && Threads.nthreads() >= nt - @testset "$nw Workers, $nt Threads" begin - Dagger.with_options(;scope=Dagger.scope(workers=1:nw, threads=1:nt)) do - test_datadeps(;args_chunks, args_thunks, args_loc, aliasing) - end +@testset "$args_mode Data" for args_mode in (:Raw, :Chunk, :Thunk) + args_chunks = args_mode == :Chunk + args_thunks = args_mode == :Thunk + for nw in (1, 2) + args_loc = nw == 2 ? 
2 : 1 + for nt in (1, 2) + if nprocs() >= nw && Threads.nthreads() >= nt + @testset "$nw Workers, $nt Threads" begin + Dagger.with_options(;scope=Dagger.scope(workers=1:nw, threads=1:nt)) do + test_datadeps(;args_chunks, args_thunks, args_loc) end end end diff --git a/test/scopes.jl b/test/scopes.jl index fa5bf1135..55d15b349 100644 --- a/test/scopes.jl +++ b/test/scopes.jl @@ -123,8 +123,8 @@ us_es1_multi_ch = Dagger.tochunk(nothing, OSProc(), UnionScope(es1, es1)) @test fetch(Dagger.@spawn exact_scope_test(us_es1_multi_ch)) == es1.processor - # No inner scopes - @test UnionScope() isa UnionScope + # No inner scopes (disallowed) + @test_throws ArgumentError UnionScope() # Same inner scope @test fetch(Dagger.@spawn exact_scope_test(us_es1_ch, us_es1_ch)) == es1.processor @@ -165,7 +165,7 @@ @test Dagger.scope(:any) isa AnyScope @test Dagger.scope(:default) == DefaultScope() @test_throws ArgumentError Dagger.scope(:blah) - @test Dagger.scope(()) == UnionScope() + @test_throws ArgumentError Dagger.scope(()) @test Dagger.scope(worker=wid1) == Dagger.scope(workers=[wid1]) diff --git a/test/task-affinity.jl b/test/task-affinity.jl index f1e26295a..ce898b476 100644 --- a/test/task-affinity.jl +++ b/test/task-affinity.jl @@ -135,7 +135,7 @@ @testset "Chunk function, scope, compute_scope and result_scope" begin @everywhere g(x, y) = x * 2 + y * 3 - n = cld(numscopes, 3) + n = fld(numscopes, 3) shuffle!(availscopes) scope_a = availscopes[1:n]