From cadbe02711dcb5d50888a0e081d2aa1a21a11d94 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Sun, 10 Aug 2025 18:10:14 +0000 Subject: [PATCH 01/28] datadeps: Fix views and implement remainder copies This commit fixes two major issues in Datadeps that caused incorrect results with views and ChunkView. First, when presented with arguments that alias (such as an `Array` and a view of that array), `generate_slot` would separately `move` these values onto the destination processor (without considering how they alias with each other), which could break the aliasing that they previously had on their originating processor. This meant that certain algorithms which used views together with the underlying arrays as arguments would get incorrect results in a distributed setting, because `generate_slot` would break aliasing and cause data to not be updated correctly during copies. This commit adds helpers (specifically `aliased_object!`) which allow objects like views and `ChunkView` to declare the underlying parent array as an object that may need to be tracked separately from the surrounding structure; these helpers keep track of other such declared objects that have been allocated on the destination processor, and replace the source object with the destination object during `move`. By default, all arguments are now provided directly to `aliased_object!` to perform this replacement, but this can be customized by overloading `move_rewrap` (which `SubArray` and `ChunkView` now overload). Secondly, even with objects now properly aliasing on remote processors, Datadeps did not have a clear way to copy only the changed portions of an argument. For example, when only a view of an array is updated on a remote processor, and the next task will then need the full parent array on the same remote processor, how does Datadeps copy over only the portions of the parent array that aren't yet up-to-date on the remote? 
The answer is that it didn't; it would do a full copy of the parent array to the remote, which would then destroy the changes made to the underlying view. This commit overhauls the copying machinery to properly calculate this difference (termed the "remainder"), based on the target ainfo and all previously-updated ainfos, and schedules a "remainder copy" to copy only the exact bytes that are not yet updated on the remote. Additionally, it may schedule copies from multiple other remote processors to the "target" remote processor as necessary, in case portions of an aliased object exist on multiple distinct processors. This machinery is driven by a new interval tree implementation, which allows efficient calculation of differences between sets of memory spans, and uses `unsafe_copyto!` to handle arbitrary data. --- src/Dagger.jl | 8 +- src/datadeps.jl | 1082 --------------------------------- src/datadeps/aliasing.jl | 689 +++++++++++++++++++++ src/datadeps/chunkview.jl | 64 ++ src/datadeps/interval_tree.jl | 349 +++++++++++ src/datadeps/queue.jl | 500 +++++++++++++++ src/datadeps/remainders.jl | 407 +++++++++++++ src/memory-spaces.jl | 96 +-- src/utils/dagdebug.jl | 25 + 9 files changed, 2074 insertions(+), 1146 deletions(-) delete mode 100644 src/datadeps.jl create mode 100644 src/datadeps/aliasing.jl create mode 100644 src/datadeps/chunkview.jl create mode 100644 src/datadeps/interval_tree.jl create mode 100644 src/datadeps/queue.jl create mode 100644 src/datadeps/remainders.jl diff --git a/src/Dagger.jl b/src/Dagger.jl index fa30c7c1a..987963b34 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -83,7 +83,13 @@ include("utils/caching.jl") include("sch/Sch.jl"); using .Sch # Data dependency task queue -include("datadeps.jl") +include("datadeps/aliasing.jl") +include("datadeps/chunkview.jl") +include("datadeps/interval_tree.jl") +include("datadeps/remainders.jl") +include("datadeps/queue.jl") + +# Stencils include("utils/haloarray.jl") include("stencil.jl") diff 
--git a/src/datadeps.jl b/src/datadeps.jl deleted file mode 100644 index d20bda647..000000000 --- a/src/datadeps.jl +++ /dev/null @@ -1,1082 +0,0 @@ -import Graphs: SimpleDiGraph, add_edge!, add_vertex!, inneighbors, outneighbors, nv - -export In, Out, InOut, Deps, spawn_datadeps - -"Specifies a read-only dependency." -struct In{T} - x::T -end -"Specifies a write-only dependency." -struct Out{T} - x::T -end -"Specifies a read-write dependency." -struct InOut{T} - x::T -end -"Specifies one or more dependencies." -struct Deps{T,DT<:Tuple} - x::T - deps::DT -end -Deps(x, deps...) = Deps(x, deps) - -struct DataDepsTaskQueue <: AbstractTaskQueue - # The queue above us - upper_queue::AbstractTaskQueue - # The set of tasks that have already been seen - seen_tasks::Union{Vector{Pair{DTaskSpec,DTask}},Nothing} - # The data-dependency graph of all tasks - g::Union{SimpleDiGraph{Int},Nothing} - # The mapping from task to graph ID - task_to_id::Union{Dict{DTask,Int},Nothing} - # How to traverse the dependency graph when launching tasks - traversal::Symbol - # Which scheduler to use to assign tasks to processors - scheduler::Symbol - - # Whether aliasing across arguments is possible - # The fields following only apply when aliasing==true - aliasing::Bool - - function DataDepsTaskQueue(upper_queue; - traversal::Symbol=:inorder, - scheduler::Symbol=:naive, - aliasing::Bool=true) - seen_tasks = Pair{DTaskSpec,DTask}[] - g = SimpleDiGraph() - task_to_id = Dict{DTask,Int}() - return new(upper_queue, seen_tasks, g, task_to_id, traversal, scheduler, - aliasing) - end -end - -function unwrap_inout(arg) - readdep = false - writedep = false - if arg isa In - readdep = true - arg = arg.x - elseif arg isa Out - writedep = true - arg = arg.x - elseif arg isa InOut - readdep = true - writedep = true - arg = arg.x - elseif arg isa Deps - alldeps = Tuple[] - for dep in arg.deps - dep_mod, inner_deps = unwrap_inout(dep) - for (_, readdep, writedep) in inner_deps - push!(alldeps, (dep_mod, 
readdep, writedep)) - end - end - arg = arg.x - return arg, alldeps - else - readdep = true - end - return arg, Tuple[(identity, readdep, writedep)] -end - -function enqueue!(queue::DataDepsTaskQueue, spec::Pair{DTaskSpec,DTask}) - push!(queue.seen_tasks, spec) -end -function enqueue!(queue::DataDepsTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) - append!(queue.seen_tasks, specs) -end - -_identity_hash(arg, h::UInt=UInt(0)) = ismutable(arg) ? objectid(arg) : hash(arg, h) -_identity_hash(arg::SubArray, h::UInt=UInt(0)) = hash(arg.indices, hash(arg.offset1, hash(arg.stride1, _identity_hash(arg.parent, h)))) -_identity_hash(arg::CartesianIndices, h::UInt=UInt(0)) = hash(arg.indices, hash(typeof(arg), h)) - -struct ArgumentWrapper - arg - dep_mod - hash::UInt - - function ArgumentWrapper(arg, dep_mod) - h = hash(dep_mod) - h = _identity_hash(arg, h) - return new(arg, dep_mod, h) - end -end -Base.hash(aw::ArgumentWrapper) = hash(ArgumentWrapper, aw.hash) -Base.:(==)(aw1::ArgumentWrapper, aw2::ArgumentWrapper) = - aw1.hash == aw2.hash - -struct DataDepsAliasingState - # Track original and current data locations - # We track data => space - data_origin::Dict{AliasingWrapper,MemorySpace} - data_locality::Dict{AliasingWrapper,MemorySpace} - - # Track writers ("owners") and readers - ainfos_owner::Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}} - ainfos_readers::Dict{AliasingWrapper,Vector{Pair{DTask,Int}}} - ainfos_overlaps::Dict{AliasingWrapper,Set{AliasingWrapper}} - - # Cache ainfo lookups - ainfo_cache::Dict{ArgumentWrapper,AliasingWrapper} - - function DataDepsAliasingState() - data_origin = Dict{AliasingWrapper,MemorySpace}() - data_locality = Dict{AliasingWrapper,MemorySpace}() - - ainfos_owner = Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}() - ainfos_readers = Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}() - ainfos_overlaps = Dict{AliasingWrapper,Set{AliasingWrapper}}() - - ainfo_cache = Dict{ArgumentWrapper,AliasingWrapper}() - - return 
new(data_origin, data_locality, - ainfos_owner, ainfos_readers, ainfos_overlaps, - ainfo_cache) - end -end -struct DataDepsNonAliasingState - # Track original and current data locations - # We track data => space - data_origin::IdDict{Any,MemorySpace} - data_locality::IdDict{Any,MemorySpace} - - # Track writers ("owners") and readers - args_owner::IdDict{Any,Union{Pair{DTask,Int},Nothing}} - args_readers::IdDict{Any,Vector{Pair{DTask,Int}}} - - function DataDepsNonAliasingState() - data_origin = IdDict{Any,MemorySpace}() - data_locality = IdDict{Any,MemorySpace}() - - args_owner = IdDict{Any,Union{Pair{DTask,Int},Nothing}}() - args_readers = IdDict{Any,Vector{Pair{DTask,Int}}}() - - return new(data_origin, data_locality, - args_owner, args_readers) - end -end -struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState}} - # Whether aliasing is being analyzed - aliasing::Bool - - # The ordered list of tasks and their read/write dependencies - dependencies::Vector{Pair{DTask,Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}}} - - # The mapping of memory space to remote argument copies - remote_args::Dict{MemorySpace,IdDict{Any,Any}} - - # Cache of whether arguments supports in-place move - supports_inplace_cache::IdDict{Any,Bool} - - # The aliasing analysis state - alias_state::State - - function DataDepsState(aliasing::Bool) - dependencies = Pair{DTask,Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}}[] - remote_args = Dict{MemorySpace,IdDict{Any,Any}}() - supports_inplace_cache = IdDict{Any,Bool}() - if aliasing - state = DataDepsAliasingState() - else - state = DataDepsNonAliasingState() - end - return new{typeof(state)}(aliasing, dependencies, remote_args, supports_inplace_cache, state) - end -end - -function aliasing(astate::DataDepsAliasingState, arg, dep_mod) - aw = ArgumentWrapper(arg, dep_mod) - get!(astate.ainfo_cache, aw) do - return AliasingWrapper(aliasing(arg, dep_mod)) - end -end - -function 
supports_inplace_move(state::DataDepsState, arg) - return get!(state.supports_inplace_cache, arg) do - return supports_inplace_move(arg) - end -end - -# Determine which arguments could be written to, and thus need tracking - -"Whether `arg` has any writedep in this datadeps region." -function has_writedep(state::DataDepsState{DataDepsNonAliasingState}, arg, deps) - # Check if we are writing to this memory - writedep = any(dep->dep[3], deps) - if writedep - arg_has_writedep[arg] = true - return true - end - - # Check if another task is writing to this memory - for (_, taskdeps) in state.dependencies - for (_, other_arg_writedep, _, _, other_arg) in taskdeps - other_arg_writedep || continue - if arg === other_arg - return true - end - end - end - - return false -end -""" -Whether `arg` has any writedep at or before executing `task` in this -datadeps region. -""" -function has_writedep(state::DataDepsState, arg, deps, task::DTask) - is_writedep(arg, deps, task) && return true - if state.aliasing - for (other_task, other_taskdeps) in state.dependencies - for (readdep, writedep, other_ainfo, _, _) in other_taskdeps - writedep || continue - for (dep_mod, _, _) in deps - ainfo = aliasing(state.alias_state, arg, dep_mod) - if will_alias(ainfo, other_ainfo) - return true - end - end - end - if task === other_task - return false - end - end - else - for (other_task, other_taskdeps) in state.dependencies - for (readdep, writedep, _, _, other_arg) in other_taskdeps - writedep || continue - if arg === other_arg - return true - end - end - if task === other_task - return false - end - end - end - error("Task isn't in argdeps set") -end -"Whether `arg` is written to by `task`." 
-function is_writedep(arg, deps, task::DTask) - return any(dep->dep[3], deps) -end - -# Aliasing state setup -function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask) - # Populate task dependencies - dependencies_to_add = Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}() - - # Track the task's arguments and access patterns - for (idx, _arg) in enumerate(spec.fargs) - # Unwrap In/InOut/Out wrappers and record dependencies - arg, deps = unwrap_inout(value(_arg)) - - # Unwrap the Chunk underlying any DTask arguments - arg = arg isa DTask ? fetch(arg; raw=true) : arg - - # Skip non-aliasing arguments - type_may_alias(typeof(arg)) || continue - - # Add all aliasing dependencies - for (dep_mod, readdep, writedep) in deps - if state.aliasing - ainfo = aliasing(state.alias_state, arg, dep_mod) - else - ainfo = AliasingWrapper(UnknownAliasing()) - end - push!(dependencies_to_add, (readdep, writedep, ainfo, dep_mod, arg)) - end - - # Populate argument write info - populate_argument_info!(state, arg, deps) - end - - # Track the task result too - # N.B. 
We state no readdep/writedep because, while we can't model the aliasing info for the task result yet, we don't want to synchronize because of this - push!(dependencies_to_add, (false, false, AliasingWrapper(UnknownAliasing()), identity, task)) - - # Record argument/result dependencies - push!(state.dependencies, task => dependencies_to_add) -end -function populate_argument_info!(state::DataDepsState{DataDepsAliasingState}, arg, deps) - astate = state.alias_state - for (dep_mod, readdep, writedep) in deps - ainfo = aliasing(astate, arg, dep_mod) - - # Initialize owner and readers - if !haskey(astate.ainfos_owner, ainfo) - overlaps = Set{AliasingWrapper}() - push!(overlaps, ainfo) - for other_ainfo in keys(astate.ainfos_owner) - ainfo == other_ainfo && continue - if will_alias(ainfo, other_ainfo) - push!(overlaps, other_ainfo) - push!(astate.ainfos_overlaps[other_ainfo], ainfo) - end - end - astate.ainfos_overlaps[ainfo] = overlaps - astate.ainfos_owner[ainfo] = nothing - astate.ainfos_readers[ainfo] = Pair{DTask,Int}[] - end - - # Assign data owner and locality - if !haskey(astate.data_locality, ainfo) - astate.data_locality[ainfo] = memory_space(arg) - astate.data_origin[ainfo] = memory_space(arg) - end - end -end -function populate_argument_info!(state::DataDepsState{DataDepsNonAliasingState}, arg, deps) - astate = state.alias_state - # Initialize owner and readers - if !haskey(astate.args_owner, arg) - astate.args_owner[arg] = nothing - astate.args_readers[arg] = DTask[] - end - - # Assign data owner and locality - if !haskey(astate.data_locality, arg) - astate.data_locality[arg] = memory_space(arg) - astate.data_origin[arg] = memory_space(arg) - end -end -function populate_return_info!(state::DataDepsState{DataDepsAliasingState}, task, space) - astate = state.alias_state - @assert !haskey(astate.data_locality, task) - # FIXME: We don't yet know about ainfos for this task -end -function populate_return_info!(state::DataDepsState{DataDepsNonAliasingState}, task, 
space) - astate = state.alias_state - @assert !haskey(astate.data_locality, task) - astate.data_locality[task] = space - astate.data_origin[task] = space -end - -""" - supports_inplace_move(x) -> Bool - -Returns `false` if `x` doesn't support being copied into from another object -like `x`, via `move!`. This is used in `spawn_datadeps` to prevent attempting -to copy between values which don't support mutation or otherwise don't have an -implemented `move!` and want to skip in-place copies. When this returns -`false`, datadeps will instead perform out-of-place copies for each non-local -use of `x`, and the data in `x` will not be updated when the `spawn_datadeps` -region returns. -""" -supports_inplace_move(x) = true -supports_inplace_move(t::DTask) = supports_inplace_move(fetch(t; raw=true)) -function supports_inplace_move(c::Chunk) - # FIXME: Use MemPool.access_ref - pid = root_worker_id(c.processor) - if pid == myid() - return supports_inplace_move(poolget(c.handle)) - else - return remotecall_fetch(supports_inplace_move, pid, c) - end -end -supports_inplace_move(::Function) = false - -# Read/write dependency management -function get_write_deps!(state::DataDepsState, ainfo_or_arg, task, write_num, syncdeps) - _get_write_deps!(state, ainfo_or_arg, task, write_num, syncdeps) - _get_read_deps!(state, ainfo_or_arg, task, write_num, syncdeps) -end -function get_read_deps!(state::DataDepsState, ainfo_or_arg, task, write_num, syncdeps) - _get_write_deps!(state, ainfo_or_arg, task, write_num, syncdeps) -end - -function _get_write_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num, syncdeps) - astate = state.alias_state - ainfo.inner isa NoAliasing && return - for other_ainfo in astate.ainfos_overlaps[ainfo] - other_task_write_num = astate.ainfos_owner[other_ainfo] - @dagdebug nothing :spawn_datadeps "Considering sync with writer via $ainfo -> $other_ainfo" - other_task_write_num === nothing && continue - other_task, 
other_write_num = other_task_write_num - write_num == other_write_num && continue - @dagdebug nothing :spawn_datadeps "Sync with writer via $ainfo -> $other_ainfo" - push!(syncdeps, ThunkSyncdep(other_task)) - end -end -function _get_read_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num, syncdeps) - astate = state.alias_state - ainfo.inner isa NoAliasing && return - for other_ainfo in astate.ainfos_overlaps[ainfo] - @dagdebug nothing :spawn_datadeps "Considering sync with reader via $ainfo -> $other_ainfo" - other_tasks = astate.ainfos_readers[other_ainfo] - for (other_task, other_write_num) in other_tasks - write_num == other_write_num && continue - @dagdebug nothing :spawn_datadeps "Sync with reader via $ainfo -> $other_ainfo" - push!(syncdeps, ThunkSyncdep(other_task)) - end - end -end -function add_writer!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num) - state.alias_state.ainfos_owner[ainfo] = task=>write_num - empty!(state.alias_state.ainfos_readers[ainfo]) - # Not necessary to assert a read, but conceptually it's true - add_reader!(state, ainfo, task, write_num) -end -function add_reader!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num) - push!(state.alias_state.ainfos_readers[ainfo], task=>write_num) -end - -function _get_write_deps!(state::DataDepsState{DataDepsNonAliasingState}, arg, task, write_num, syncdeps) - other_task_write_num = state.alias_state.args_owner[arg] - if other_task_write_num !== nothing - other_task, other_write_num = other_task_write_num - if write_num != other_write_num - push!(syncdeps, ThunkSyncdep(other_task)) - end - end -end -function _get_read_deps!(state::DataDepsState{DataDepsNonAliasingState}, arg, task, write_num, syncdeps) - for (other_task, other_write_num) in state.alias_state.args_readers[arg] - if write_num != other_write_num - push!(syncdeps, ThunkSyncdep(other_task)) - end - end -end -function 
add_writer!(state::DataDepsState{DataDepsNonAliasingState}, arg, task, write_num) - state.alias_state.args_owner[arg] = task=>write_num - empty!(state.alias_state.args_readers[arg]) - # Not necessary to assert a read, but conceptually it's true - add_reader!(state, arg, task, write_num) -end -function add_reader!(state::DataDepsState{DataDepsNonAliasingState}, arg, task, write_num) - push!(state.alias_state.args_readers[arg], task=>write_num) -end - -# Make a copy of each piece of data on each worker -# memory_space => {arg => copy_of_arg} -isremotehandle(x) = false -isremotehandle(x::DTask) = true -isremotehandle(x::Chunk) = true -function generate_slot!(state::DataDepsState, dest_space, data) - if data isa DTask - data = fetch(data; raw=true) - end - orig_space = memory_space(data) - to_proc = first(processors(dest_space)) - from_proc = first(processors(orig_space)) - dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space) - if orig_space == dest_space && (data isa Chunk || !isremotehandle(data)) - # Fast path for local data or data already in a Chunk - data_chunk = tochunk(data, from_proc) - dest_space_args[data] = data_chunk - @assert processor(data_chunk) in processors(dest_space) || data isa Chunk && processor(data) isa Dagger.OSProc - @assert memory_space(data_chunk) == orig_space - else - to_w = root_worker_id(dest_space) - ctx = Sch.eager_context() - id = rand(Int) - dest_space_args[data] = remotecall_fetch(to_w, from_proc, to_proc, data) do from_proc, to_proc, data - timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) - data_converted = move(from_proc, to_proc, data) - data_chunk = tochunk(data_converted, to_proc) - @assert processor(data_chunk) in processors(dest_space) - @assert memory_space(data_converted) == memory_space(data_chunk) "space mismatch! $(memory_space(data_converted)) != $(memory_space(data_chunk)) ($(typeof(data_converted)) vs. 
$(typeof(data_chunk))), spaces ($orig_space -> $dest_space)" - if orig_space != dest_space - @assert orig_space != memory_space(data_chunk) "space preserved! $orig_space != $(memory_space(data_chunk)) ($(typeof(data)) vs. $(typeof(data_chunk))), spaces ($orig_space -> $dest_space)" - end - return data_chunk - end - timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=dest_space_args[data])) - end - return dest_space_args[data] -end - -struct DataDepsSchedulerState - task_to_spec::Dict{DTask,DTaskSpec} - assignments::Dict{DTask,MemorySpace} - dependencies::Dict{DTask,Set{DTask}} - task_completions::Dict{DTask,UInt64} - space_completions::Dict{MemorySpace,UInt64} - capacities::Dict{MemorySpace,Int} - - function DataDepsSchedulerState() - return new(Dict{DTask,DTaskSpec}(), - Dict{DTask,MemorySpace}(), - Dict{DTask,Set{DTask}}(), - Dict{DTask,UInt64}(), - Dict{MemorySpace,UInt64}(), - Dict{MemorySpace,Int}()) - end -end - -function distribute_tasks!(queue::DataDepsTaskQueue) - #= TODO: Improvements to be made: - # - Support for copying non-AbstractArray arguments - # - Parallelize read copies - # - Unreference unused slots - # - Reuse memory when possible - # - Account for differently-sized data - =# - - # Get the set of all processors to be scheduled on - all_procs = Processor[] - scope = get_compute_scope() - for w in procs() - append!(all_procs, get_processors(OSProc(w))) - end - filter!(proc->!isa(constrain(ExactScope(proc), scope), - InvalidScope), - all_procs) - if isempty(all_procs) - throw(Sch.SchedulingException("No processors available, try widening scope")) - end - exec_spaces = unique(vcat(map(proc->collect(memory_spaces(proc)), all_procs)...)) - if !all(space->space isa CPURAMMemorySpace, exec_spaces) && !all(space->root_worker_id(space) == myid(), exec_spaces) - @warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1 - end 
- - # Round-robin assign tasks to processors - upper_queue = get_options(:task_queue) - - traversal = queue.traversal - if traversal == :inorder - # As-is - task_order = Colon() - elseif traversal == :bfs - # BFS - task_order = Int[1] - to_walk = Int[1] - seen = Set{Int}([1]) - while !isempty(to_walk) - # N.B. next_root has already been seen - next_root = popfirst!(to_walk) - for v in outneighbors(queue.g, next_root) - if !(v in seen) - push!(task_order, v) - push!(seen, v) - push!(to_walk, v) - end - end - end - elseif traversal == :dfs - # DFS (modified with backtracking) - task_order = Int[] - to_walk = Int[1] - seen = Set{Int}() - while length(task_order) < length(queue.seen_tasks) && !isempty(to_walk) - next_root = popfirst!(to_walk) - if !(next_root in seen) - iv = inneighbors(queue.g, next_root) - if all(v->v in seen, iv) - push!(task_order, next_root) - push!(seen, next_root) - ov = outneighbors(queue.g, next_root) - prepend!(to_walk, ov) - else - push!(to_walk, next_root) - end - end - end - else - throw(ArgumentError("Invalid traversal mode: $traversal")) - end - - state = DataDepsState(queue.aliasing) - astate = state.alias_state - sstate = DataDepsSchedulerState() - for proc in all_procs - space = only(memory_spaces(proc)) - get!(()->0, sstate.capacities, space) - sstate.capacities[space] += 1 - end - - # Start launching tasks and necessary copies - write_num = 1 - proc_idx = 1 - pressures = Dict{Processor,Int}() - proc_to_scope_lfu = BasicLFUCache{Processor,AbstractScope}(1024) - for (spec, task) in queue.seen_tasks[task_order] - # Populate all task dependencies - populate_task_info!(state, spec, task) - - task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) - scheduler = queue.scheduler - if scheduler == :naive - raw_args = map(arg->tochunk(value(arg)), spec.fargs) - our_proc = remotecall_fetch(1, all_procs, raw_args) do all_procs, raw_args - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - - @lock sch_state.lock 
begin - # Calculate costs per processor and select the most optimal - # FIXME: This should consider any already-allocated slots, - # whether they are up-to-date, and if not, the cost of moving - # data to them - procs, costs = Sch.estimate_task_costs(sch_state, all_procs, nothing, raw_args) - return first(procs) - end - end - elseif scheduler == :smart - raw_args = map(filter(arg->haskey(astate.data_locality, value(arg)), spec.fargs)) do arg - arg_chunk = tochunk(last(arg)) - # Only the owned slot is valid - # FIXME: Track up-to-date copies and pass all of those - return arg_chunk => data_locality[arg] - end - f_chunk = tochunk(value(f)) - our_proc, task_pressure = remotecall_fetch(1, all_procs, pressures, f_chunk, raw_args) do all_procs, pressures, f, chunks_locality - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - - @lock sch_state.lock begin - tx_rate = sch_state.transfer_rate[] - - costs = Dict{Processor,Float64}() - for proc in all_procs - # Filter out chunks that are already local - chunks_filt = Iterators.filter(((chunk, space)=chunk_locality)->!(proc in processors(space)), chunks_locality) - - # Estimate network transfer costs based on data size - # N.B. `affinity(x)` really means "data size of `x`" - # N.B. 
We treat same-worker transfers as having zero transfer cost - tx_cost = Sch.impute_sum(affinity(chunk)[2] for chunk in chunks_filt) - - # Estimate total cost to move data and get task running after currently-scheduled tasks - est_time_util = get(pressures, proc, UInt64(0)) - costs[proc] = est_time_util + (tx_cost/tx_rate) - end - - # Look up estimated task cost - sig = Sch.signature(sch_state, f, map(first, chunks_locality)) - task_pressure = get(sch_state.signature_time_cost, sig, 1000^3) - - # Shuffle procs around, so equally-costly procs are equally considered - P = randperm(length(all_procs)) - procs = getindex.(Ref(all_procs), P) - - # Sort by lowest cost first - sort!(procs, by=p->costs[p]) - - best_proc = first(procs) - return best_proc, task_pressure - end - end - # FIXME: Pressure should be decreased by pressure of syncdeps on same processor - pressures[our_proc] = get(pressures, our_proc, UInt64(0)) + task_pressure - elseif scheduler == :ultra - args = Base.mapany(spec.fargs) do arg - pos, data = arg - data, _ = unwrap_inout(data) - if data isa DTask - data = fetch(data; raw=true) - end - return pos => tochunk(data) - end - f_chunk = tochunk(value(f)) - task_time = remotecall_fetch(1, f_chunk, args) do f, args - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - return @lock sch_state.lock begin - sig = Sch.signature(sch_state, f, args) - return get(sch_state.signature_time_cost, sig, 1000^3) - end - end - - # FIXME: Copy deps are computed eagerly - deps = get(Set{Any}, spec.options, :syncdeps) - - # Find latest time-to-completion of all syncdeps - deps_completed = UInt64(0) - for dep in deps - haskey(sstate.task_completions, dep) || continue # copy deps aren't recorded - deps_completed = max(deps_completed, sstate.task_completions[dep]) - end - - # Find latest time-to-completion of each memory space - # FIXME: Figure out space completions based on optimal packing - spaces_completed = Dict{MemorySpace,UInt64}() - for space in exec_spaces - completed = 
UInt64(0) - for (task, other_space) in sstate.assignments - space == other_space || continue - completed = max(completed, sstate.task_completions[task]) - end - spaces_completed[space] = completed - end - - # Choose the earliest-available memory space and processor - # FIXME: Consider move time - move_time = UInt64(0) - local our_space_completed - while true - our_space_completed, our_space = findmin(spaces_completed) - our_space_procs = filter(proc->proc in all_procs, processors(our_space)) - if isempty(our_space_procs) - delete!(spaces_completed, our_space) - continue - end - our_proc = rand(our_space_procs) - break - end - - sstate.task_to_spec[task] = spec - sstate.assignments[task] = our_space - sstate.task_completions[task] = our_space_completed + move_time + task_time - elseif scheduler == :roundrobin - our_proc = all_procs[proc_idx] - if task_scope == scope - # all_procs is already limited to scope - else - if isa(constrain(task_scope, scope), InvalidScope) - throw(Sch.SchedulingException("Scopes are not compatible: $(scope), $(task_scope)")) - end - while !proc_in_scope(our_proc, task_scope) - proc_idx = mod1(proc_idx + 1, length(all_procs)) - our_proc = all_procs[proc_idx] - end - end - else - error("Invalid scheduler: $sched") - end - @assert our_proc in all_procs - our_space = only(memory_spaces(our_proc)) - - # Find the scope for this task (and its copies) - if task_scope == scope - # Optimize for the common case, cache the proc=>scope mapping - our_scope = get!(proc_to_scope_lfu, our_proc) do - our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) - return constrain(UnionScope(map(ExactScope, our_procs)...), scope) - end - else - # Use the provided scope and constrain it to the available processors - our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) - our_scope = constrain(UnionScope(map(ExactScope, our_procs)...), task_scope) - end - if our_scope isa InvalidScope - throw(Sch.SchedulingException("Scopes 
are not compatible: $(our_scope.x), $(our_scope.y)")) - end - - f = spec.fargs[1] - f.value = move(ThreadProc(myid(), 1), our_proc, value(f)) - @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" - - # Copy raw task arguments for analysis - task_args = map(copy, spec.fargs) - - # Copy args from local to remote - for (idx, _arg) in enumerate(task_args) - # Is the data writeable? - arg, deps = unwrap_inout(value(_arg)) - arg = arg isa DTask ? fetch(arg; raw=true) : arg - if !type_may_alias(typeof(arg)) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Skipped copy-to (unwritten)" - spec.fargs[idx].value = arg - continue - end - if !supports_inplace_move(state, arg) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Skipped copy-to (non-writeable)" - spec.fargs[idx].value = arg - continue - end - - # Is the source of truth elsewhere? - arg_remote = get!(get!(IdDict{Any,Any}, state.remote_args, our_space), arg) do - generate_slot!(state, our_space, arg) - end - if queue.aliasing - for (dep_mod, _, _) in deps - ainfo = aliasing(astate, arg, dep_mod) - data_space = astate.data_locality[ainfo] - nonlocal = our_space != data_space - if nonlocal - # Add copy-to operation (depends on latest owner of arg) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Enqueueing copy-to: $data_space => $our_space" - arg_local = get!(get!(IdDict{Any,Any}, state.remote_args, data_space), arg) do - generate_slot!(state, data_space, arg) - end - copy_to_scope = our_scope - copy_to_syncdeps = Set{ThunkSyncdep}() - get_write_deps!(state, ainfo, task, write_num, copy_to_syncdeps) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] $(length(copy_to_syncdeps)) syncdeps" - copy_to = Dagger.@spawn scope=copy_to_scope exec_scope=copy_to_scope syncdeps=copy_to_syncdeps meta=true Dagger.move!(dep_mod, our_space, data_space, arg_remote, arg_local) - add_writer!(state, ainfo, copy_to, write_num) - - 
astate.data_locality[ainfo] = our_space - else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Skipped copy-to (local): $data_space" - end - end - else - data_space = astate.data_locality[arg] - nonlocal = our_space != data_space - if nonlocal - # Add copy-to operation (depends on latest owner of arg) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Enqueueing copy-to: $data_space => $our_space" - arg_local = get!(get!(IdDict{Any,Any}, state.remote_args, data_space), arg) do - generate_slot!(state, data_space, arg) - end - copy_to_scope = our_scope - copy_to_syncdeps = Set{ThunkSyncdep}() - get_write_deps!(state, arg, task, write_num, copy_to_syncdeps) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] $(length(copy_to_syncdeps)) syncdeps" - copy_to = Dagger.@spawn scope=copy_to_scope exec_scope=copy_to_scope syncdeps=copy_to_syncdeps Dagger.move!(identity, our_space, data_space, arg_remote, arg_local) - add_writer!(state, arg, copy_to, write_num) - - astate.data_locality[arg] = our_space - else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Skipped copy-to (local): $data_space" - end - end - spec.fargs[idx].value = arg_remote - end - write_num += 1 - - # Validate that we're not accidentally performing a copy - for (idx, _arg) in enumerate(spec.fargs) - _, deps = unwrap_inout(value(task_args[idx])) - # N.B. 
We only do this check when the argument supports in-place - # moves, because for the moment, we are not guaranteeing updates or - # write-back of results - arg = value(_arg) - if is_writedep(arg, deps, task) && supports_inplace_move(state, arg) - arg_space = memory_space(arg) - @assert arg_space == our_space "($(repr(value(f))))[$idx] Tried to pass $(typeof(arg)) from $arg_space to $our_space" - end - end - - # Calculate this task's syncdeps - if spec.options.syncdeps === nothing - spec.options.syncdeps = Set{ThunkSyncdep}() - end - syncdeps = spec.options.syncdeps - for (idx, (_, arg)) in enumerate(task_args) - arg, deps = unwrap_inout(arg) - arg = arg isa DTask ? fetch(arg; raw=true) : arg - type_may_alias(typeof(arg)) || continue - supports_inplace_move(state, arg) || continue - if queue.aliasing - for (dep_mod, _, writedep) in deps - ainfo = aliasing(astate, arg, dep_mod) - if writedep - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Syncing as writer" - get_write_deps!(state, ainfo, task, write_num, syncdeps) - else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Syncing as reader" - get_read_deps!(state, ainfo, task, write_num, syncdeps) - end - end - else - if is_writedep(arg, deps, task) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Syncing as writer" - get_write_deps!(state, arg, task, write_num, syncdeps) - else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Syncing as reader" - get_read_deps!(state, arg, task, write_num, syncdeps) - end - end - end - @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) $(length(syncdeps)) syncdeps" - - # Launch user's task - spec.options.scope = our_scope - spec.options.exec_scope = our_scope - enqueue!(upper_queue, spec=>task) - - # Update read/write tracking for arguments - for (idx, (_, arg)) in enumerate(task_args) - arg, deps = unwrap_inout(arg) - arg = arg isa DTask ? 
fetch(arg; raw=true) : arg - type_may_alias(typeof(arg)) || continue - if queue.aliasing - for (dep_mod, _, writedep) in deps - ainfo = aliasing(astate, arg, dep_mod) - if writedep - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Set as owner" - add_writer!(state, ainfo, task, write_num) - else - add_reader!(state, ainfo, task, write_num) - end - end - else - if is_writedep(arg, deps, task) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Set as owner" - add_writer!(state, arg, task, write_num) - else - add_reader!(state, arg, task, write_num) - end - end - end - - # Update tracking for return value - populate_return_info!(state, task, our_space) - - write_num += 1 - proc_idx = mod1(proc_idx + 1, length(all_procs)) - end - - # Copy args from remote to local - if queue.aliasing - # We need to replay the writes from all tasks in-order (skipping any - # outdated write owners), to ensure that overlapping writes are applied - # in the correct order - - # First, find the latest owners of each live ainfo - arg_writes = IdDict{Any,Vector{Tuple{AliasingWrapper,<:Any,MemorySpace}}}() - for (task, taskdeps) in state.dependencies - for (_, writedep, ainfo, dep_mod, arg) in taskdeps - writedep || continue - haskey(astate.data_locality, ainfo) || continue - @assert haskey(astate.ainfos_owner, ainfo) "Missing ainfo: $ainfo ($dep_mod($(typeof(arg))))" - - # Skip virtual writes from task result aliasing - # FIXME: Make this less bad - if arg isa DTask && dep_mod === identity && ainfo.inner isa UnknownAliasing - continue - end - - # Skip non-writeable arguments - if !supports_inplace_move(state, arg) - @dagdebug nothing :spawn_datadeps "Skipped copy-from (non-writeable)" - continue - end - - # Get the set of writers - ainfo_writes = get!(Vector{Tuple{AliasingWrapper,<:Any,MemorySpace}}, arg_writes, arg) - - #= FIXME: If we fully overlap any writer, evict them - idxs = findall(ainfo_write->overlaps_all(ainfo, ainfo_write[1]), ainfo_writes) - 
deleteat!(ainfo_writes, idxs) - =# - - # Make ourselves the latest writer - push!(ainfo_writes, (ainfo, dep_mod, astate.data_locality[ainfo])) - end - end - - # Then, replay the writes from each owner in-order - # FIXME: write_num should advance across overlapping ainfo's, as - # writes must be ordered sequentially - for (arg, ainfo_writes) in arg_writes - if length(ainfo_writes) > 1 - # FIXME: Remove me - deleteat!(ainfo_writes, 1:length(ainfo_writes)-1) - end - for (ainfo, dep_mod, data_remote_space) in ainfo_writes - # Is the source of truth elsewhere? - data_local_space = astate.data_origin[ainfo] - if data_local_space != data_remote_space - # Add copy-from operation - @dagdebug nothing :spawn_datadeps "[$dep_mod] Enqueueing copy-from: $data_remote_space => $data_local_space" - arg_local = get!(get!(IdDict{Any,Any}, state.remote_args, data_local_space), arg) do - generate_slot!(state, data_local_space, arg) - end - arg_remote = state.remote_args[data_remote_space][arg] - @assert arg_remote !== arg_local - data_local_proc = first(processors(data_local_space)) - copy_from_scope = UnionScope(map(ExactScope, collect(processors(data_local_space)))...) - copy_from_syncdeps = Set{ThunkSyncdep}() - get_write_deps!(state, ainfo, nothing, write_num, copy_from_syncdeps) - @dagdebug nothing :spawn_datadeps "$(length(copy_from_syncdeps)) syncdeps" - copy_from = Dagger.@spawn scope=copy_from_scope exec_scope=copy_from_scope syncdeps=copy_from_syncdeps meta=true Dagger.move!(dep_mod, data_local_space, data_remote_space, arg_local, arg_remote) - else - @dagdebug nothing :spawn_datadeps "[$dep_mod] Skipped copy-from (local): $data_remote_space" - end - end - end - else - for arg in keys(astate.data_origin) - # Is the data previously written? - arg, deps = unwrap_inout(arg) - if !type_may_alias(typeof(arg)) - @dagdebug nothing :spawn_datadeps "Skipped copy-from (immutable)" - end - - # Can the data be written back to? 
- if !supports_inplace_move(state, arg) - @dagdebug nothing :spawn_datadeps "Skipped copy-from (non-writeable)" - end - - # Is the source of truth elsewhere? - data_remote_space = astate.data_locality[arg] - data_local_space = astate.data_origin[arg] - if data_local_space != data_remote_space - # Add copy-from operation - @dagdebug nothing :spawn_datadeps "Enqueueing copy-from: $data_remote_space => $data_local_space" - arg_local = state.remote_args[data_local_space][arg] - arg_remote = state.remote_args[data_remote_space][arg] - @assert arg_remote !== arg_local - data_local_proc = first(processors(data_local_space)) - copy_from_scope = ExactScope(data_local_proc) - copy_from_syncdeps = Set{ThunkSyncdep}() - get_write_deps!(state, arg, nothing, write_num, copy_from_syncdeps) - @dagdebug nothing :spawn_datadeps "$(length(copy_from_syncdeps)) syncdeps" - copy_from = Dagger.@spawn scope=copy_from_scope exec_scope=copy_from_scope syncdeps=copy_from_syncdeps meta=true Dagger.move!(identity, data_local_space, data_remote_space, arg_local, arg_remote) - else - @dagdebug nothing :spawn_datadeps "Skipped copy-from (local): $data_remote_space" - end - end - end -end - -""" - spawn_datadeps(f::Base.Callable; traversal::Symbol=:inorder) - -Constructs a "datadeps" (data dependencies) region and calls `f` within it. -Dagger tasks launched within `f` may wrap their arguments with `In`, `Out`, or -`InOut` to indicate whether the task will read, write, or read+write that -argument, respectively. 
These argument dependencies will be used to specify -which tasks depend on each other based on the following rules: - -- Dependencies across unrelated arguments are independent; only dependencies on arguments which overlap in memory synchronize with each other -- `InOut` is the same as `In` and `Out` applied simultaneously, and synchronizes with the union of the `In` and `Out` effects -- Any two or more `In` dependencies do not synchronize with each other, and may execute in parallel -- An `Out` dependency synchronizes with any previous `In` and `Out` dependencies -- An `In` dependency synchronizes with any previous `Out` dependencies -- If unspecified, an `In` dependency is assumed - -In general, the result of executing tasks following the above rules will be -equivalent to simply executing tasks sequentially and in order of submission. -Of course, if dependencies are incorrectly specified, undefined behavior (and -unexpected results) may occur. - -Unlike other Dagger tasks, tasks executed within a datadeps region are allowed -to write to their arguments when annotated with `Out` or `InOut` -appropriately. - -At the end of executing `f`, `spawn_datadeps` will wait for all launched tasks -to complete, rethrowing the first error, if any. The result of `f` will be -returned from `spawn_datadeps`. - -The keyword argument `traversal` controls the order that tasks are launched by -the scheduler, and may be set to `:bfs` or `:dfs` for Breadth-First Scheduling -or Depth-First Scheduling, respectively. All traversal orders respect the -dependencies and ordering of the launched tasks, but may provide better or -worse performance for a given set of datadeps tasks. This argument is -experimental and subject to change. 
-""" -function spawn_datadeps(f::Base.Callable; static::Bool=true, - traversal::Symbol=:inorder, - scheduler::Union{Symbol,Nothing}=nothing, - aliasing::Bool=true, - launch_wait::Union{Bool,Nothing}=nothing) - if !static - throw(ArgumentError("Dynamic scheduling is no longer available")) - end - wait_all(; check_errors=true) do - scheduler = something(scheduler, DATADEPS_SCHEDULER[], :roundrobin)::Symbol - launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool - if launch_wait - result = spawn_bulk() do - queue = DataDepsTaskQueue(get_options(:task_queue); - traversal, scheduler, aliasing) - with_options(f; task_queue=queue) - distribute_tasks!(queue) - end - else - queue = DataDepsTaskQueue(get_options(:task_queue); - traversal, scheduler, aliasing) - result = with_options(f; task_queue=queue) - distribute_tasks!(queue) - end - return result - end -end -const DATADEPS_SCHEDULER = ScopedValue{Union{Symbol,Nothing}}(nothing) -const DATADEPS_LAUNCH_WAIT = ScopedValue{Union{Bool,Nothing}}(nothing) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl new file mode 100644 index 000000000..833e3fa99 --- /dev/null +++ b/src/datadeps/aliasing.jl @@ -0,0 +1,689 @@ +import Graphs: SimpleDiGraph, add_edge!, add_vertex!, inneighbors, outneighbors, nv + +export In, Out, InOut, Deps, spawn_datadeps + +#= +============================================================================== + DATADEPS ALIASING AND DATA MOVEMENT SYSTEM +============================================================================== + +This file implements the data dependencies system for Dagger tasks, which allows +tasks to access their arguments in a controlled manner. The system maintains +data coherency across distributed workers by tracking aliasing relationships +and orchestrating data movement operations. 

OVERVIEW:
---------
The datadeps system enables parallel execution of tasks that modify shared data
by analyzing memory aliasing relationships and scheduling the required data
transfers. The core challenge is maintaining coherency when aliased data (e.g.,
an array and its views) needs to be accessed by tasks running on different workers.

KEY CONCEPTS:
-------------

1. ALIASING ANALYSIS:
   - Every mutable argument is analyzed for its memory access pattern
   - Memory spans are computed to determine which bytes in memory are accessed
   - Objects whose memory spans overlap are considered "aliasing"
   - Example: an array A and view(A, 2:3, 2:3) alias each other

2. DATA LOCALITY TRACKING:
   - The system tracks where the "source of truth" for each piece of data lives
   - As tasks execute and modify data, the source of truth may move between workers
   - Each aliasing region can have its own independent source-of-truth location

3. ALIASED OBJECT MANAGEMENT:
   - When copying arguments between workers, the system tracks "aliased objects"
   - This ensures that if both an array and its view need to be copied to a worker,
     only one copy of the underlying array is made, with the view pointing to it
   - The aliased_object!() and declare_aliased_object!() functions manage this sharing

ALIASING INFO:
--------------

The system uses different types of aliasing info to represent different types of
aliasing relationships:

- ContiguousAliasing: Single contiguous memory region (e.g., full array)
- StridedAliasing: Multiple non-contiguous regions (e.g., SubArray)
- DiagonalAliasing: Diagonal elements only (e.g., Diagonal(A))
- TriangularAliasing: Triangular regions (e.g., UpperTriangular(A))

Any two aliasing objects can be compared using the will_alias function to
determine if they overlap. Additionally, any aliasing object can be converted to
a vector of memory spans, which represents the contiguous regions of memory that
the aliasing object covers.
+ +DATA MOVEMENT FUNCTIONS: +------------------------ + +move!(dep_mod, to_space, from_space, to, from): +- The core in-place data movement function +- dep_mod specifies which part of the data to copy (identity, UpperTriangular, etc.) +- Supports partial copies via RemainderAliasing dependency modifiers + +move_rewrap(...): +- Handles copying of wrapped objects (SubArrays, ChunkViews) +- Ensures aliased objects are reused on destination worker + +read/write_remainder!(...): +- Read/write a span of memory from an object to/from a buffer +- Used by move! to copy the remainder of an aliased object + +THE DISTRIBUTED ALIASING PROBLEM: +--------------------------------- + +In a multithreaded environment, aliasing "just works" because all tasks operate +on the user-provided memory. However, in a distributed environment, arguments +must be copied between workers, which breaks aliasing relationships if care is +not taken. + +Consider this scenario: +```julia +A = rand(4, 4) +vA = view(A, 2:3, 2:3) + +Dagger.spawn_datadeps() do + Dagger.@spawn inc!(InOut(A), 1) # Task 1: increment all of A + Dagger.@spawn inc!(InOut(vA), 2) # Task 2: increment view of A +end +``` + +MULTITHREADED BEHAVIOR (WORKS): +- Both tasks run on the same worker +- They operate on the same memory, with proper dependency tracking +- Task dependencies ensure correct ordering (e.g., Task 1 then Task 2) + +DISTRIBUTED BEHAVIOR (THE PROBLEM): +- Tasks may be scheduled on different workers +- Each argument must be copied to the destination worker +- Without special handling, we would copy A to worker1 and vA to worker2 +- This creates two separate arrays, breaking the aliasing relationship +- Updates to the view on worker2 don't affect the array on worker1 + +THE SOLUTION - PARTIAL DATA MOVEMENT: +------------------------------------- + +The datadeps system solves this by: + +1. 
UNIFIED ALLOCATION: + - When copying aliased objects, ensure only one underlying array exists per worker + - Use aliased_object!() to detect and reuse existing allocations + - Views on the destination worker point to the shared underlying array + +2. PARTIAL DATA TRANSFER: + - Instead of copying entire objects, only transfer the "dirty" regions + - This prevents overwrites of data that has already been updated by another task + - This also minimizes network traffic and overall copy time + - Uses the move!(dep_mod, ...) function with RemainderAliasing dependency modifiers + +3. REMAINDER TRACKING: + - When a task needs the full object, copy partial regions as needed + - When a partial region is updated, track what parts still need updating + - This preserves all updates while avoiding overwrites + +EXAMPLE EXECUTION FLOW: +----------------------- + +Given: A = 4x4 array, vA = view(A, 2:3, 2:3) +Tasks: T1 modifies InOut(A), T2 modifies InOut(vA) + +1. INITIAL STATE: + - A and vA both exist on worker0 (main worker) + - A's data_locality = worker0, vA's data_locality = worker0 + +2. T1 SCHEDULED ON WORKER1: + - Copy A from worker0 to worker1 + - T1 executes, modifying all of A on worker1 + - Update: A's data_locality = worker1, A is now "dirty" on worker1 + +3. T2 SCHEDULED ON WORKER2: + - T2 needs vA, but vA aliases with A (which was modified by T1) + - Copy vA-region of A from worker1 to worker2 + - This is a PARTIAL copy - only the 2:3, 2:3 region + - Create vA on worker2 pointing to the appropriate region of A + - T2 executes, modifying vA region on worker2 + - Update: vA's data_locality = worker2 + +4. FINAL SYNCHRONIZATION: + - Need to copy-back A and vA to worker0 + - A needs to be assembled from: worker1 (non-vA regions of A) + worker2 (vA region of A) + - REMAINDER COPY: Copy non-vA regions from worker1 to worker0 + - REMAINDER COPY: Copy vA region from worker2 to worker0 + +REMAINDER COMPUTATION: +---------------------- + +Remainder computation involves: +1. 
"Marks a task argument as read-only."
struct In{T}
    x::T
end
"Marks a task argument as write-only."
struct Out{T}
    x::T
end
"Marks a task argument as both read and written."
struct InOut{T}
    x::T
end
"Attaches one or more dependency specifiers to a single argument."
struct Deps{T,DT<:Tuple}
    x::T
    deps::DT
end
Deps(x, deps...) = Deps(x, deps)

"""
    unwrap_inout(arg) -> (value, deps)

Strip any `In`/`Out`/`InOut`/`Deps` wrapper from `arg`, returning the wrapped
value together with a `Tuple[]` of `(dep_mod, readdep, writedep)` entries.
A bare (unwrapped) argument is treated as read-only.
"""
function unwrap_inout(arg)
    if arg isa In
        return arg.x, Tuple[(identity, true, false)]
    elseif arg isa Out
        return arg.x, Tuple[(identity, false, true)]
    elseif arg isa InOut
        return arg.x, Tuple[(identity, true, true)]
    elseif arg isa Deps
        # Flatten nested specifiers: each inner wrapper contributes its own
        # (dep_mod, read, write) entries, keyed on the modifier it wraps.
        alldeps = Tuple[]
        for dep in arg.deps
            dep_mod, inner_deps = unwrap_inout(dep)
            for (_, readdep, writedep) in inner_deps
                push!(alldeps, (dep_mod, readdep, writedep))
            end
        end
        return arg.x, alldeps
    else
        # No wrapper: assume a read-only dependency
        return arg, Tuple[(identity, true, false)]
    end
end
# Identity-aware hashing used to key ArgumentWrapper lookups: mutable values
# are identified by object identity (objectid), immutable values structurally
# via `hash`.
# FIX: the seed `h` is now always folded in. Previously the mutable branch
# returned bare `objectid(arg)`, discarding `h` — so the dep_mod hash mixed in
# by ArgumentWrapper's constructor was lost, collapsing all dep_mods of a
# given mutable argument into one hash (and, since `==` compares hashes, into
# equality).
_identity_hash(arg, h::UInt=UInt(0)) = ismutable(arg) ? hash(objectid(arg), h) : hash(arg, h)
_identity_hash(arg::Chunk, h::UInt=UInt(0)) = hash(arg.handle, hash(Chunk, h))
_identity_hash(arg::SubArray, h::UInt=UInt(0)) = hash(arg.indices, hash(arg.offset1, hash(arg.stride1, _identity_hash(arg.parent, h))))
_identity_hash(arg::CartesianIndices, h::UInt=UInt(0)) = hash(arg.indices, hash(typeof(arg), h))

"""
    ArgumentWrapper(arg, dep_mod)

Pairs a task argument with a dependency modifier (`identity`,
`UpperTriangular`, etc.), precomputing an identity-based hash so the pair can
be used efficiently as a dictionary key.
"""
struct ArgumentWrapper
    arg
    dep_mod
    hash::UInt

    function ArgumentWrapper(arg, dep_mod)
        h = hash(dep_mod)
        h = _identity_hash(arg, h)
        return new(arg, dep_mod, h)
    end
end
# FIX: define the seeded (two-argument) form so hashing an ArgumentWrapper
# nested inside tuples/structs is consistent with `==`; the unseeded
# `hash(aw)` derives from this automatically via Base. Previously only the
# one-arg form was defined, so two-arg hashing fell back to the default
# objectid-based hash, violating the hash/== contract for nested use.
Base.hash(aw::ArgumentWrapper, h::UInt) = hash(aw.hash, hash(ArgumentWrapper, h))
# NOTE(review): equality is by precomputed hash only; a hash collision would
# make two distinct wrappers compare equal — confirm this is acceptable.
Base.:(==)(aw1::ArgumentWrapper, aw2::ArgumentWrapper) =
    aw1.hash == aw2.hash
Base.isequal(aw1::ArgumentWrapper, aw2::ArgumentWrapper) =
    aw1.hash == aw2.hash

# A single (direct or indirect) write event against an argument: the ainfo
# that was written, the memory space it was written in, and the write number.
struct HistoryEntry
    ainfo::AliasingWrapper
    space::MemorySpace
    write_num::Int
end

"Mutable bookkeeping for one `spawn_datadeps` region."
struct DataDepsState
    # The mapping of original raw argument to its Chunk
    raw_arg_to_chunk::IdDict{Any,Chunk}

    # The origin memory space of each argument
    # Used to track the original location of an argument, for final copy-from
    arg_origin::IdDict{Any,MemorySpace}

    # The mapping of memory space to argument to remote argument copies
    # Used to replace an argument with its remote copy
    remote_args::Dict{MemorySpace,IdDict{Any,Chunk}}

    # The mapping of remote argument to original argument
    remote_arg_to_original::IdDict{Any,Any}

    # The mapping of ainfo to argument and dep_mod
    # Used to lookup which argument and dep_mod a given ainfo is generated from
    # N.B. This is a mapping for remote argument copies
    ainfo_arg::Dict{AliasingWrapper,ArgumentWrapper}

    # The history of writes (direct or indirect) to each argument and dep_mod,
    # in terms of ainfos directly written to, and the memory space they were
    # written to. Updated when a new write happens on an overlapping ainfo.
    # Used by remainder copies to track which portions of an argument and
    # dep_mod were written to elsewhere, through another argument
    arg_history::Dict{ArgumentWrapper,Vector{HistoryEntry}}

    # The mapping of argument (and dep_mod) to the memory space of the last
    # direct write. Used by remainder copies to lookup the "backstop" if any
    # portion of the target ainfo is not updated by the remainder
    arg_owner::Dict{ArgumentWrapper,MemorySpace}

    # The overlap of each argument with every other argument, based on the
    # ainfo overlaps. Incrementally updated as new ainfos are created.
    # Used for fast history updates
    arg_overlaps::Dict{ArgumentWrapper,Set{ArgumentWrapper}}

    # The mapping of, for a given memory space, the backing Chunks that an
    # ainfo references. Used by slot generation to replace the backing Chunks
    # during move
    ainfo_backing_chunk::Dict{MemorySpace,Dict{AbstractAliasing,Chunk}}

    # Cache of argument's supports_inplace_move query result
    supports_inplace_cache::IdDict{Any,Bool}

    # Cache of argument and dep_mod to ainfo
    # N.B. This is a mapping for remote argument copies
    ainfo_cache::Dict{ArgumentWrapper,AliasingWrapper}

    # The overlapping ainfos for each ainfo
    # Incrementally updated as new ainfos are created
    # Used for fast will_alias lookups
    ainfos_overlaps::Dict{AliasingWrapper,Set{AliasingWrapper}}

    # Track writers ("owners") and readers
    # Updated as new writer and reader tasks are launched
    # Used by task dependency tracking to calculate syncdeps and ensure
    # correct launch ordering
    ainfos_owner::Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}
    ainfos_readers::Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}

    function DataDepsState(aliasing::Bool)
        if !aliasing
            @warn "aliasing=false is no longer supported, aliasing is now always enabled" maxlog=1
        end

        raw_arg_to_chunk = IdDict{Any,Chunk}()
        arg_origin = IdDict{Any,MemorySpace}()
        # FIX: match the declared field type exactly (was IdDict{Any,Any}),
        # avoiding a silent convert-and-copy on construction.
        remote_args = Dict{MemorySpace,IdDict{Any,Chunk}}()
        remote_arg_to_original = IdDict{Any,Any}()
        ainfo_arg = Dict{AliasingWrapper,ArgumentWrapper}()
        arg_history = Dict{ArgumentWrapper,Vector{HistoryEntry}}()
        arg_owner = Dict{ArgumentWrapper,MemorySpace}()
        arg_overlaps = Dict{ArgumentWrapper,Set{ArgumentWrapper}}()
        ainfo_backing_chunk = Dict{MemorySpace,Dict{AbstractAliasing,Chunk}}()

        supports_inplace_cache = IdDict{Any,Bool}()
        ainfo_cache = Dict{ArgumentWrapper,AliasingWrapper}()

        ainfos_overlaps = Dict{AliasingWrapper,Set{AliasingWrapper}}()

        ainfos_owner = Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}()
        ainfos_readers = Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}()

        # FIX: arguments are now passed in field-declaration order. The
        # previous ordering (arg_history after ainfo_backing_chunk) assigned
        # mismatched Dict types to four fields and only worked because empty
        # Dicts convert silently to any Dict type.
        return new(raw_arg_to_chunk, arg_origin, remote_args, remote_arg_to_original,
                   ainfo_arg, arg_history, arg_owner, arg_overlaps, ainfo_backing_chunk,
                   supports_inplace_cache, ainfo_cache, ainfos_overlaps,
                   ainfos_owner, ainfos_readers)
    end
end
# Compute (and cache) the AliasingWrapper for `arg_w` as it exists in
# `target_space`, materializing a remote copy of the argument if necessary.
# N.B. arg_w must be the original argument wrapper, not a remote copy
function aliasing!(state::DataDepsState, target_space::MemorySpace, arg_w::ArgumentWrapper)
    # Grab the remote copy of the argument, and calculate the ainfo
    remote_arg = get_or_generate_slot!(state, target_space, arg_w.arg)
    remote_arg_w = ArgumentWrapper(remote_arg, arg_w.dep_mod)

    # Check if we already have the result cached
    if haskey(state.ainfo_cache, remote_arg_w)
        return state.ainfo_cache[remote_arg_w]
    end

    # Calculate the ainfo
    ainfo = AliasingWrapper(aliasing(remote_arg, arg_w.dep_mod))

    # Cache the result
    state.ainfo_cache[remote_arg_w] = ainfo

    # Update the mapping of ainfo to argument and dep_mod
    state.ainfo_arg[ainfo] = remote_arg_w

    # Populate info for the new ainfo
    populate_ainfo!(state, arg_w, ainfo, target_space)

    return ainfo
end

# Cached variant of the global `supports_inplace_move(arg)` query; the
# per-argument answer is memoized in the region's state.
function supports_inplace_move(state::DataDepsState, arg)
    return get!(state.supports_inplace_cache, arg) do
        return supports_inplace_move(arg)
    end
end

# Determine which arguments could be written to, and thus need tracking
"Whether `arg` is written to by `task`."
function is_writedep(arg, deps, task::DTask)
    # deps entries are (dep_mod, readdep, writedep) tuples; the third slot is
    # the write flag.
    return any(dep->dep[3], deps)
end

# Aliasing state setup: record each trackable argument of `spec`, converting
# raw arguments to Chunks and registering origin-space/history bookkeeping.
function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask)
    # Track the task's arguments and access patterns
    for (idx, _arg) in enumerate(spec.fargs)
        arg = value(_arg)

        # Unwrap In/InOut/Out wrappers and record dependencies
        arg, deps = unwrap_inout(arg)

        # Unwrap the Chunk underlying any DTask arguments
        arg = arg isa DTask ? fetch(arg; raw=true) : arg

        # Skip non-aliasing arguments
        type_may_alias(typeof(arg)) || continue

        # Skip arguments not supporting in-place move
        supports_inplace_move(state, arg) || continue

        # Generate a Chunk for the argument if necessary
        if haskey(state.raw_arg_to_chunk, arg)
            arg = state.raw_arg_to_chunk[arg]
        else
            if !(arg isa Chunk)
                new_arg = tochunk(arg)
                state.raw_arg_to_chunk[arg] = new_arg
                arg = new_arg
            else
                state.raw_arg_to_chunk[arg] = arg
            end
        end

        # Track the origin space of the argument
        origin_space = memory_space(arg)
        state.arg_origin[arg] = origin_space
        state.remote_arg_to_original[arg] = arg

        # Populate argument info for all aliasing dependencies
        for (dep_mod, _, _) in deps
            # Generate an ArgumentWrapper for the argument
            aw = ArgumentWrapper(arg, dep_mod)

            # Populate argument info
            populate_argument_info!(state, aw, origin_space)
        end
    end
end

# Initialize per-argument ownership/overlap/history structures (idempotent),
# then compute the argument's ainfo in its origin space.
function populate_argument_info!(state::DataDepsState, arg_w::ArgumentWrapper, origin_space::MemorySpace)
    # Initialize ownership and history
    if !haskey(state.arg_owner, arg_w)
        # N.B. This is valid (even if the backing data is up-to-date elsewhere),
        # because we only use this to track the "backstop" if any portion of the
        # target ainfo is not updated by the remainder (at which point, this
        # is thus the correct owner).
        state.arg_owner[arg_w] = origin_space

        # Initialize the overlap set
        state.arg_overlaps[arg_w] = Set{ArgumentWrapper}()
    end
    if !haskey(state.arg_history, arg_w)
        state.arg_history[arg_w] = Vector{HistoryEntry}()
    end

    # Calculate the ainfo (which will populate ainfo structures and merge history)
    aliasing!(state, origin_space, arg_w)
end

# Register a newly-seen ainfo: compute its overlap set against all known
# ainfos, cross-link the owning arguments' overlap sets, and merge histories
# of overlapping arguments. Idempotent for already-known ainfos.
function populate_ainfo!(state::DataDepsState, original_arg_w::ArgumentWrapper, target_ainfo::AliasingWrapper, target_space::MemorySpace)
    # Initialize owner and readers
    if !haskey(state.ainfos_owner, target_ainfo)
        # N.B. O(n) scan over all known ainfos for each new ainfo
        overlaps = Set{AliasingWrapper}()
        push!(overlaps, target_ainfo)
        for other_ainfo in keys(state.ainfos_owner)
            target_ainfo == other_ainfo && continue
            if will_alias(target_ainfo, other_ainfo)
                # Mark us and them as overlapping
                push!(overlaps, other_ainfo)
                push!(state.ainfos_overlaps[other_ainfo], target_ainfo)

                # Add overlapping history to our own
                other_remote_arg_w = state.ainfo_arg[other_ainfo]
                other_arg = state.remote_arg_to_original[other_remote_arg_w.arg]
                other_arg_w = ArgumentWrapper(other_arg, other_remote_arg_w.dep_mod)
                push!(state.arg_overlaps[original_arg_w], other_arg_w)
                push!(state.arg_overlaps[other_arg_w], original_arg_w)
                merge_history!(state, original_arg_w, other_arg_w)
            end
        end
        state.ainfos_overlaps[target_ainfo] = overlaps
        state.ainfos_owner[target_ainfo] = nothing
        state.ainfos_readers[target_ainfo] = Pair{DTask,Int}[]
    end
end

# Merge `other_arg_w`'s write history into `arg_w`'s, keeping entries sorted
# by write_num and skipping exact-duplicate entries.
function merge_history!(state::DataDepsState, arg_w::ArgumentWrapper, other_arg_w::ArgumentWrapper)
    history = state.arg_history[arg_w]
    @opcounter :merge_history
    @opcounter :merge_history_complexity length(history)
    largest_value_update!(length(history))
    origin_space = state.arg_origin[other_arg_w.arg]
    for other_entry in state.arg_history[other_arg_w]
        # Probe entry carrying only the write_num, used as a sort key below
        write_num_tuple = HistoryEntry(AliasingWrapper(NoAliasing()), origin_space, other_entry.write_num)
        range = searchsorted(history, write_num_tuple; by=x->x.write_num)
        if !isempty(range)
            # Find and skip duplicates
            match = false
            for source_idx in range
                source_entry = history[source_idx]
                if source_entry.ainfo == other_entry.ainfo &&
                   source_entry.space == other_entry.space &&
                   source_entry.write_num == other_entry.write_num
                    match = true
                    break
                end
            end
            match && continue

            # Insert at the first position
            idx = first(range)
        else
            # Insert at the last position
            # NOTE(review): when no entry with this write_num exists, this
            # appends at the end rather than at the sorted insertion point
            # (first(range) of the empty range); if write_nums can arrive
            # out of order this would break the sort invariant — confirm.
            idx = length(history) + 1
        end
        insert!(history, idx, other_entry)
    end
end

# Bound history growth: once an argument's history exceeds a threshold,
# compute the remainder up to the current state and drop fully-covered
# (obsolete) leading entries.
function truncate_history!(state::DataDepsState, arg_w::ArgumentWrapper)
    # FIXME: Do this continuously if possible
    if haskey(state.arg_history, arg_w) && length(state.arg_history[arg_w]) > 100000
        origin_space = state.arg_origin[arg_w.arg]
        @opcounter :truncate_history
        _, last_idx = compute_remainder_for_arg!(state, origin_space, arg_w, 0; compute_syncdeps=false)
        if last_idx > 0
            @opcounter :truncate_history_removed last_idx
            deleteat!(state.arg_history[arg_w], 1:last_idx)
        end
    end
end
"""
    supports_inplace_move(x) -> Bool

Returns `false` if `x` doesn't support being copied into from another object
like `x`, via `move!`. This is used in `spawn_datadeps` to prevent attempting
to copy between values which don't support mutation or otherwise don't have an
implemented `move!` and want to skip in-place copies. When this returns
`false`, datadeps will instead perform out-of-place copies for each non-local
use of `x`, and the data in `x` will not be updated when the `spawn_datadeps`
region returns.
"""
supports_inplace_move(x) = true
supports_inplace_move(t::DTask) = supports_inplace_move(fetch(t; raw=true))
function supports_inplace_move(c::Chunk)
    # Query the value backing the Chunk, locally or on its owning worker
    # FIXME: Use MemPool.access_ref
    pid = root_worker_id(c.processor)
    if pid == myid()
        return supports_inplace_move(poolget(c.handle))
    else
        return remotecall_fetch(supports_inplace_move, pid, c)
    end
end
# Functions are never moved in-place
supports_inplace_move(::Function) = false

# Read/write dependency management
#
# Collect the tasks that a new WRITER of `ainfo` must wait on.
function get_write_deps!(state::DataDepsState, dest_space::MemorySpace, ainfo::AbstractAliasing, write_num, syncdeps)
    # We need to sync with both writers and readers
    _get_write_deps!(state, dest_space, ainfo, write_num, syncdeps)
    _get_read_deps!(state, dest_space, ainfo, write_num, syncdeps)
end
# Collect the tasks that a new READER of `ainfo` must wait on.
function get_read_deps!(state::DataDepsState, dest_space::MemorySpace, ainfo::AbstractAliasing, write_num, syncdeps)
    # We only need to sync with writers, not readers
    _get_write_deps!(state, dest_space, ainfo, write_num, syncdeps)
end

# Add the current owner (writer) of every ainfo overlapping `ainfo` to
# `syncdeps`, skipping owners launched in the same write generation.
function _get_write_deps!(state::DataDepsState, dest_space::MemorySpace, ainfo::AbstractAliasing, write_num, syncdeps)
    ainfo.inner isa NoAliasing && return
    for other_ainfo in state.ainfos_overlaps[ainfo]
        other_task_write_num = state.ainfos_owner[other_ainfo]
        @dagdebug nothing :spawn_datadeps_sync "Considering sync with writer via $ainfo -> $other_ainfo"
        other_task_write_num === nothing && continue
        other_task, other_write_num = other_task_write_num
        write_num == other_write_num && continue
        @dagdebug nothing :spawn_datadeps_sync "Sync with writer via $ainfo -> $other_ainfo"
        push!(syncdeps, ThunkSyncdep(other_task))
    end
end
# Add all readers of every ainfo overlapping `ainfo` to `syncdeps`, skipping
# readers launched in the same write generation.
function _get_read_deps!(state::DataDepsState, dest_space::MemorySpace, ainfo::AbstractAliasing, write_num, syncdeps)
    ainfo.inner isa NoAliasing && return
    for other_ainfo in state.ainfos_overlaps[ainfo]
        @dagdebug nothing :spawn_datadeps_sync "Considering sync with reader via $ainfo -> $other_ainfo"
        other_tasks = state.ainfos_readers[other_ainfo]
        for (other_task, other_write_num) in other_tasks
            write_num == other_write_num && continue
            @dagdebug nothing :spawn_datadeps_sync "Sync with reader via $ainfo -> $other_ainfo"
            push!(syncdeps, ThunkSyncdep(other_task))
        end
    end
end
# Record `task` as the new owner (writer) of `ainfo` in `dest_space`,
# resetting this argument's history and propagating the write event to all
# overlapping arguments' histories.
function add_writer!(state::DataDepsState, arg_w::ArgumentWrapper, dest_space::MemorySpace, ainfo::AbstractAliasing, task, write_num)
    state.ainfos_owner[ainfo] = task=>write_num
    empty!(state.ainfos_readers[ainfo])

    # Clear the history for this target, since this is a new write event
    empty!(state.arg_history[arg_w])

    # Add our own history
    push!(state.arg_history[arg_w], HistoryEntry(ainfo, dest_space, write_num))

    # Find overlapping arguments and update their history
    for other_arg_w in state.arg_overlaps[arg_w]
        other_arg_w == arg_w && continue
        push!(state.arg_history[other_arg_w], HistoryEntry(ainfo, dest_space, write_num))
    end

    # Record the last place we were fully written to
    state.arg_owner[arg_w] = dest_space

    # Not necessary to assert a read, but conceptually it's true
    add_reader!(state, arg_w, dest_space, ainfo, task, write_num)
end
# Record `task` as a reader of `ainfo`.
function add_reader!(state::DataDepsState, arg_w::ArgumentWrapper, dest_space::MemorySpace, ainfo::AbstractAliasing, task, write_num)
    push!(state.ainfos_readers[ainfo], task=>write_num)
end

# Make a copy of each piece of data on each worker
# memory_space => {arg => copy_of_arg}
#
# Whether `x` is a handle to remotely-held data (and thus needs rewrapping
# when moved, rather than plain tochunk).
isremotehandle(x) = false
isremotehandle(x::DTask) = true
isremotehandle(x::Chunk) = true
+ orig_space = memory_space(data) + to_proc = first(processors(dest_space)) + from_proc = first(processors(orig_space)) + dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space) + ALIASED_OBJECT_CACHE[] = get!(Dict{AbstractAliasing,Chunk}, state.ainfo_backing_chunk, dest_space) + if orig_space == dest_space && (data isa Chunk || !isremotehandle(data)) + # Fast path for local data that's already in a Chunk or not a remote handle needing rewrapping + data_chunk = tochunk(data, from_proc) + else + ctx = Sch.eager_context() + id = rand(Int) + @maybelog ctx timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) + data_chunk = move_rewrap(from_proc, to_proc, data) + @maybelog ctx timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=data_chunk)) + end + @assert memory_space(data_chunk) == dest_space "space mismatch! $dest_space (dest) != $(memory_space(data_chunk)) (actual) ($(typeof(data)) (data) vs. 
$(typeof(data_chunk)) (chunk)), spaces ($orig_space -> $dest_space)" + dest_space_args[data] = data_chunk + state.remote_arg_to_original[data_chunk] = data + + ALIASED_OBJECT_CACHE[] = nothing + + return dest_space_args[data] +end +function get_or_generate_slot!(state, dest_space, data) + @assert !(data isa ArgumentWrapper) + if !haskey(state.remote_args, dest_space) + state.remote_args[dest_space] = IdDict{Any,Any}() + end + if !haskey(state.remote_args[dest_space], data) + return generate_slot!(state, dest_space, data) + end + return state.remote_args[dest_space][data] +end +function move_rewrap(from_proc::Processor, to_proc::Processor, data) + return aliased_object!(data) do data + to_w = root_worker_id(to_proc) + return remotecall_fetch(to_w, from_proc, to_proc, data) do from_proc, to_proc, data + data_converted = move(from_proc, to_proc, data) + return tochunk(data_converted, to_proc) + end + end +end +const ALIASED_OBJECT_CACHE = TaskLocalValue{Union{Dict{AbstractAliasing,Chunk}, Nothing}}(()->nothing) +@warn "Document these public methods" maxlog=1 +# TODO: Use state to cache aliasing() results +function declare_aliased_object!(x; ainfo=aliasing(x, identity)) + cache = ALIASED_OBJECT_CACHE[] + cache[ainfo] = x +end +function aliased_object!(x; ainfo=aliasing(x, identity)) + cache = ALIASED_OBJECT_CACHE[] + if haskey(cache, ainfo) + y = cache[ainfo] + else + @assert x isa Chunk "x must be a Chunk\nUse functor form of aliased_object!" 
+ cache[ainfo] = x + y = x + end + return y +end +function aliased_object!(f, x; ainfo=aliasing(x, identity)) + cache = ALIASED_OBJECT_CACHE[] + if haskey(cache, ainfo) + y = cache[ainfo] + else + y = f(x) + @assert y isa Chunk "Didn't get a Chunk from functor" + cache[ainfo] = y + end + return y +end +function aliased_object_unwrap!(x::Chunk) + y = unwrap(x) + ainfo = aliasing(y, identity) + return unwrap(aliased_object!(x; ainfo)) +end + +struct DataDepsSchedulerState + task_to_spec::Dict{DTask,DTaskSpec} + assignments::Dict{DTask,MemorySpace} + dependencies::Dict{DTask,Set{DTask}} + task_completions::Dict{DTask,UInt64} + space_completions::Dict{MemorySpace,UInt64} + capacities::Dict{MemorySpace,Int} + + function DataDepsSchedulerState() + return new(Dict{DTask,DTaskSpec}(), + Dict{DTask,MemorySpace}(), + Dict{DTask,Set{DTask}}(), + Dict{DTask,UInt64}(), + Dict{MemorySpace,UInt64}(), + Dict{MemorySpace,Int}()) + end +end \ No newline at end of file diff --git a/src/datadeps/chunkview.jl b/src/datadeps/chunkview.jl new file mode 100644 index 000000000..04b581c17 --- /dev/null +++ b/src/datadeps/chunkview.jl @@ -0,0 +1,64 @@ +struct ChunkView{N} + chunk::Chunk + slices::NTuple{N, Union{Int, AbstractRange{Int}, Colon}} +end + +function Base.view(c::Chunk, slices...) 
+ if c.domain isa ArrayDomain + nd, sz = ndims(c.domain), size(c.domain) + nd == length(slices) || throw(DimensionMismatch("Expected $nd slices, got $(length(slices))")) + + for (i, s) in enumerate(slices) + if s isa Int + 1 ≤ s ≤ sz[i] || throw(ArgumentError("Index $s out of bounds for dimension $i (size $(sz[i]))")) + elseif s isa AbstractRange + isempty(s) && continue + 1 ≤ first(s) ≤ last(s) ≤ sz[i] || throw(ArgumentError("Range $s out of bounds for dimension $i (size $(sz[i]))")) + elseif s === Colon() + continue + else + throw(ArgumentError("Invalid slice type $(typeof(s)) at dimension $i, Expected Type of Int, AbstractRange, or Colon")) + end + end + end + + return ChunkView(c, slices) +end + +Base.view(c::DTask, slices...) = view(fetch(c; raw=true), slices...) + +aliasing(x::ChunkView) = + throw(ConcurrencyViolationError("Cannot query aliasing of a ChunkView directly")) +memory_space(x::ChunkView) = memory_space(x.chunk) +isremotehandle(x::ChunkView) = true + +# This definition is here because it's so similar to ChunkView +function move_rewrap(from_proc::Processor, to_proc::Processor, v::SubArray) + to_w = root_worker_id(to_proc) + p_chunk = aliased_object!(parent(v)) do p + return remotecall_fetch(to_w, from_proc, to_proc, p) do from_proc, to_proc, p + return tochunk(move(from_proc, to_proc, p), to_proc) + end + end + inds = parentindices(v) + return remotecall_fetch(to_w, from_proc, to_proc, p_chunk, inds) do from_proc, to_proc, p_chunk, inds + p_new = move(from_proc, to_proc, p_chunk) + v_new = view(p_new, inds...) 
+ return tochunk(v_new, to_proc) + end +end +function move_rewrap(from_proc::Processor, to_proc::Processor, slice::ChunkView) + to_w = root_worker_id(to_proc) + p_chunk = aliased_object!(slice.chunk) do p_chunk + return remotecall_fetch(to_w, from_proc, to_proc, p_chunk) do from_proc, to_proc, p_chunk + return tochunk(move(from_proc, to_proc, p_chunk), to_proc) + end + end + return remotecall_fetch(to_w, from_proc, to_proc, p_chunk, slice.slices) do from_proc, to_proc, p_chunk, inds + p_new = move(from_proc, to_proc, p_chunk) + v_new = view(p_new, inds...) + return tochunk(v_new, to_proc) + end +end + +Base.fetch(slice::ChunkView) = view(fetch(slice.chunk), slice.slices...) \ No newline at end of file diff --git a/src/datadeps/interval_tree.jl b/src/datadeps/interval_tree.jl new file mode 100644 index 000000000..1075f5912 --- /dev/null +++ b/src/datadeps/interval_tree.jl @@ -0,0 +1,349 @@ +# Get the start address of a span +span_start(span::MemorySpan) = span.ptr.addr +span_start(span::LocalMemorySpan) = span.ptr +span_start(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_start(span.spans[i]), N)) +# Get the length of a span +span_len(span::MemorySpan) = span.len +span_len(span::LocalMemorySpan) = span.len +span_len(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_len(span.spans[i]), N)) + +# Get the end address of a span +span_end(span::MemorySpan) = span.ptr.addr + span.len +span_end(span::LocalMemorySpan) = span.ptr + span.len +span_end(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_end(span.spans[i]), N)) +mutable struct IntervalNode{M,E} + span::M + max_end::E # Maximum end value in this subtree + left::Union{IntervalNode{M,E}, Nothing} + right::Union{IntervalNode{M,E}, Nothing} + + IntervalNode(span::M) where M <: MemorySpan = new{M,UInt64}(span, span_end(span), nothing, nothing) + IntervalNode(span::LocalMemorySpan) = new{LocalMemorySpan,UInt64}(span, span_end(span), nothing, nothing) + 
IntervalNode(span::ManyMemorySpan{N}) where N = new{ManyMemorySpan{N},ManyPair{N}}(span, span_end(span), nothing, nothing) +end + +mutable struct IntervalTree{M,E} + root::Union{IntervalNode{M,E}, Nothing} + + IntervalTree{M}() where M<:MemorySpan = new{M,UInt64}(nothing) + IntervalTree{LocalMemorySpan}() = new{LocalMemorySpan,UInt64}(nothing) + IntervalTree{ManyMemorySpan{N}}() where N = new{ManyMemorySpan{N},ManyPair{N}}(nothing) +end + +# Construct interval tree from unsorted set of spans +function IntervalTree{M}(spans) where M + tree = IntervalTree{M}() + for span in spans + insert!(tree, span) + end + return tree +end +IntervalTree(spans::Vector{M}) where M = IntervalTree{M}(spans) + +function Base.show(io::IO, tree::IntervalTree) + println(io, "$(typeof(tree)) (with $(length(tree)) spans):") + for (i, span) in enumerate(tree) + println(io, " $i: [$(span_start(span)), $(span_end(span))) (len=$(span_len(span)))") + end +end + +function Base.collect(tree::IntervalTree{M}) where M + result = M[] + for span in tree + push!(result, span) + end + return result +end + +function Base.iterate(tree::IntervalTree{M}) where M + state = Vector{M}() + if tree.root === nothing + return nothing + end + return iterate(tree.root) +end +function Base.iterate(tree::IntervalTree, state) + return iterate(tree.root, state) +end +function Base.iterate(root::IntervalNode{M,E}) where {M,E} + state = Vector{IntervalNode{M,E}}() + push!(state, root) + return iterate(root, state) +end +function Base.iterate(root::IntervalNode, state) + if isempty(state) + return nothing + end + current = popfirst!(state) + if current.right !== nothing + pushfirst!(state, current.right) + end + if current.left !== nothing + pushfirst!(state, current.left) + end + return current.span, state +end + +function Base.length(tree::IntervalTree) + result = 0 + for _ in tree + result += 1 + end + return result +end + +# Update max_end value for a node based on its children +function 
update_max_end!(node::IntervalNode) + node.max_end = span_end(node.span) + if node.left !== nothing + node.max_end = max(node.max_end, node.left.max_end) + end + if node.right !== nothing + node.max_end = max(node.max_end, node.right.max_end) + end +end + +# Insert a span into the interval tree +function Base.insert!(tree::IntervalTree{M}, span::M) where M + if !isempty(span) + tree.root = insert_node!(tree.root, span) + end + return span +end + +function insert_node!(::Nothing, span::M) where M + return IntervalNode(span) +end +function insert_node!(node::IntervalNode{M,E}, span::M) where {M,E} + if span_start(span) <= span_start(node.span) + node.left = insert_node!(node.left, span) + else + node.right = insert_node!(node.right, span) + end + + update_max_end!(node) + return node +end + +# Remove a specific span from the tree (split as needed) +function Base.delete!(tree::IntervalTree{M}, span::M) where M + if !isempty(span) + tree.root = delete_node!(tree.root, span) + end + return span +end + +function delete_node!(::Nothing, span::M) where M + return nothing +end +function delete_node!(node::IntervalNode{M,E}, span::M) where {M,E} + # Check for exact match first + if span_start(node.span) == span_start(span) && span_len(node.span) == span_len(span) + # Exact match, remove the node + if node.left === nothing && node.right === nothing + return nothing + elseif node.left === nothing + return node.right + elseif node.right === nothing + return node.left + else + # Node has two children - replace with inorder successor + successor = find_min(node.right) + node.span = successor.span + node.right = delete_node!(node.right, successor.span) + end + # Check for overlap + elseif spans_overlap(node.span, span) + # Handle overlapping spans by removing current node and adding remainders + original_span = node.span + + # Remove the current node first (same logic as exact match) + if node.left === nothing && node.right === nothing + # Leaf node - remove it and create a new 
subtree with remainders + remaining_node = nothing + elseif node.left === nothing + remaining_node = node.right + elseif node.right === nothing + remaining_node = node.left + else + # Node has two children - replace with inorder successor + successor = find_min(node.right) + node.span = successor.span + node.right = delete_node!(node.right, successor.span) + remaining_node = node + end + + # Calculate and insert the remaining portions + original_start = span_start(original_span) + original_end = span_end(original_span) + del_start = span_start(span) + del_end = span_end(span) + + # Left portion: exists if original starts before deleted span + if original_start < del_start + left_end = min(original_end, del_start) + if left_end > original_start + left_span = M(original_start, left_end - original_start) + if !isempty(left_span) + remaining_node = insert_node!(remaining_node, left_span) + end + end + end + + # Right portion: exists if original extends beyond deleted span + if original_end > del_end + right_start = max(original_start, del_end) + if original_end > right_start + right_span = M(right_start, original_end - right_start) + if !isempty(right_span) + remaining_node = insert_node!(remaining_node, right_span) + end + end + end + + return remaining_node + elseif span_start(span) <= span_start(node.span) + node.left = delete_node!(node.left, span) + else + node.right = delete_node!(node.right, span) + end + + if node !== nothing + update_max_end!(node) + end + return node +end + +function find_min(node::IntervalNode) + while node.left !== nothing + node = node.left + end + return node +end + +# Check if two spans overlap +function spans_overlap(span1::MemorySpan, span2::MemorySpan) + return span_start(span1) < span_end(span2) && span_start(span2) < span_end(span1) +end +function spans_overlap(span1::LocalMemorySpan, span2::LocalMemorySpan) + return span_start(span1) < span_end(span2) && span_start(span2) < span_end(span1) +end +function 
spans_overlap(span1::ManyMemorySpan{N}, span2::ManyMemorySpan{N}) where N + # N.B. The spans are assumed to be the same length and relative offset + return spans_overlap(span1.spans[1], span2.spans[1]) +end + +# Find all spans that overlap with the given query span +function find_overlapping(tree::IntervalTree{M}, query::M) where M + result = M[] + find_overlapping!(tree.root, query, result) + return result +end + +function find_overlapping!(::Nothing, query::M, result::Vector{M}) where M + return +end +function find_overlapping!(node::IntervalNode{M,E}, query::M, result::Vector{M}) where {M,E} + # Check if current node overlaps with query + if spans_overlap(node.span, query) + # Get the overlapping portion of the span + overlap_start = max(span_start(node.span), span_start(query)) + overlap_end = min(span_end(node.span), span_end(query)) + overlap = M(overlap_start, overlap_end - overlap_start) + push!(result, overlap) + end + + # Recursively search left subtree if it might contain overlapping intervals + if node.left !== nothing && node.left.max_end > span_start(query) + find_overlapping!(node.left, query, result) + end + + # Recursively search right subtree if query extends beyond current node's start + if node.right !== nothing && span_end(query) > span_start(node.span) + find_overlapping!(node.right, query, result) + end +end + +# ============================================================================ +# MAIN SUBTRACTION ALGORITHM +# ============================================================================ + +""" + subtract_spans!(minuend_tree::IntervalTree{M}, subtrahend_spans::Vector{M}, diff=nothing) where M + +Subtract all spans in subtrahend_spans from the minuend_tree in-place. +The minuend_tree is modified to contain only the portions that remain after subtraction. 
+ +Time Complexity: O(M log N + M*K) where M = |subtrahend_spans|, N = |minuend nodes|, + K = average overlaps per subtrahend span +Space Complexity: O(1) additional space (modifies tree in-place) + +If `diff` is provided, add the overlapping spans to `diff`. +""" +function subtract_spans!(minuend_tree::IntervalTree{M}, subtrahend_spans::Vector{M}, diff=nothing) where M + for sub_span in subtrahend_spans + subtract_single_span!(minuend_tree, sub_span, diff) + end +end + +""" + subtract_single_span!(tree::IntervalTree, sub_span::MemorySpan, diff=nothing) + +Subtract a single span from the interval tree. This function: +1. Finds all overlapping spans in the tree +2. Removes each overlapping span +3. Adds back the non-overlapping portions (left and/or right remnants) +4. If diff is provided, add the overlapping span to diff +""" +function subtract_single_span!(tree::IntervalTree{M}, sub_span::M, diff=nothing) where M + # Find all spans that overlap with the subtrahend + overlapping_spans = find_overlapping(tree, sub_span) + + # Process each overlapping span + for overlap_span in overlapping_spans + # Remove the overlapping span from the tree + delete!(tree, overlap_span) + + # Calculate and add back the portions that should remain + add_remaining_portions!(tree, overlap_span, sub_span) + + if diff !== nothing && !isempty(overlap_span) + push!(diff, overlap_span) + end + end +end + +""" + add_remaining_portions!(tree::IntervalTree, original::MemorySpan, subtracted::MemorySpan) + +After removing an overlapping span, add back the portions that don't overlap with the subtracted span. +There can be up to two remaining portions: left and right of the subtracted region. 
+""" +function add_remaining_portions!(tree::IntervalTree{M}, original::M, subtracted::M) where M + original_start = span_start(original) + original_end = span_end(original) + sub_start = span_start(subtracted) + sub_end = span_end(subtracted) + + # Left portion: exists if original starts before subtracted + if original_start < sub_start + left_end = min(original_end, sub_start) + if left_end > original_start + left_span = M(original_start, left_end - original_start) + if !isempty(left_span) + insert!(tree, left_span) + end + end + end + + # Right portion: exists if original extends beyond subtracted + if original_end > sub_end + right_start = max(original_start, sub_end) + if original_end > right_start + right_span = M(right_start, original_end - right_start) + if !isempty(right_span) + insert!(tree, right_span) + end + end + end +end \ No newline at end of file diff --git a/src/datadeps/queue.jl b/src/datadeps/queue.jl new file mode 100644 index 000000000..6fc85bd22 --- /dev/null +++ b/src/datadeps/queue.jl @@ -0,0 +1,500 @@ +struct DataDepsTaskQueue <: AbstractTaskQueue + # The queue above us + upper_queue::AbstractTaskQueue + # The set of tasks that have already been seen + seen_tasks::Union{Vector{Pair{DTaskSpec,DTask}},Nothing} + # The data-dependency graph of all tasks + g::Union{SimpleDiGraph{Int},Nothing} + # The mapping from task to graph ID + task_to_id::Union{Dict{DTask,Int},Nothing} + # How to traverse the dependency graph when launching tasks + traversal::Symbol + # Which scheduler to use to assign tasks to processors + scheduler::Symbol + + # Whether aliasing across arguments is possible + # The fields following only apply when aliasing==true + aliasing::Bool + + function DataDepsTaskQueue(upper_queue; + traversal::Symbol=:inorder, + scheduler::Symbol=:naive, + aliasing::Bool=true) + seen_tasks = Pair{DTaskSpec,DTask}[] + g = SimpleDiGraph() + task_to_id = Dict{DTask,Int}() + return new(upper_queue, seen_tasks, g, task_to_id, traversal, scheduler, + 
aliasing) + end +end + +function enqueue!(queue::DataDepsTaskQueue, spec::Pair{DTaskSpec,DTask}) + push!(queue.seen_tasks, spec) +end +function enqueue!(queue::DataDepsTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) + append!(queue.seen_tasks, specs) +end + +""" + spawn_datadeps(f::Base.Callable; traversal::Symbol=:inorder) + +Constructs a "datadeps" (data dependencies) region and calls `f` within it. +Dagger tasks launched within `f` may wrap their arguments with `In`, `Out`, or +`InOut` to indicate whether the task will read, write, or read+write that +argument, respectively. These argument dependencies will be used to specify +which tasks depend on each other based on the following rules: + +- Dependencies across unrelated arguments are independent; only dependencies on arguments which overlap in memory synchronize with each other +- `InOut` is the same as `In` and `Out` applied simultaneously, and synchronizes with the union of the `In` and `Out` effects +- Any two or more `In` dependencies do not synchronize with each other, and may execute in parallel +- An `Out` dependency synchronizes with any previous `In` and `Out` dependencies +- An `In` dependency synchronizes with any previous `Out` dependencies +- If unspecified, an `In` dependency is assumed + +In general, the result of executing tasks following the above rules will be +equivalent to simply executing tasks sequentially and in order of submission. +Of course, if dependencies are incorrectly specified, undefined behavior (and +unexpected results) may occur. + +Unlike other Dagger tasks, tasks executed within a datadeps region are allowed +to write to their arguments when annotated with `Out` or `InOut` +appropriately. + +At the end of executing `f`, `spawn_datadeps` will wait for all launched tasks +to complete, rethrowing the first error, if any. The result of `f` will be +returned from `spawn_datadeps`. 
+ +The keyword argument `traversal` controls the order that tasks are launched by +the scheduler, and may be set to `:bfs` or `:dfs` for Breadth-First Scheduling +or Depth-First Scheduling, respectively. All traversal orders respect the +dependencies and ordering of the launched tasks, but may provide better or +worse performance for a given set of datadeps tasks. This argument is +experimental and subject to change. +""" +function spawn_datadeps(f::Base.Callable; static::Bool=true, + traversal::Symbol=:inorder, + scheduler::Union{Symbol,Nothing}=nothing, + aliasing::Bool=true, + launch_wait::Union{Bool,Nothing}=nothing) + if !static + throw(ArgumentError("Dynamic scheduling is no longer available")) + end + wait_all(; check_errors=true) do + scheduler = something(scheduler, DATADEPS_SCHEDULER[], :roundrobin)::Symbol + launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool + if launch_wait + result = spawn_bulk() do + queue = DataDepsTaskQueue(get_options(:task_queue); + traversal, scheduler, aliasing) + with_options(f; task_queue=queue) + distribute_tasks!(queue) + end + else + queue = DataDepsTaskQueue(get_options(:task_queue); + traversal, scheduler, aliasing) + result = with_options(f; task_queue=queue) + distribute_tasks!(queue) + end + return result + end +end +const DATADEPS_SCHEDULER = ScopedValue{Union{Symbol,Nothing}}(nothing) +const DATADEPS_LAUNCH_WAIT = ScopedValue{Union{Bool,Nothing}}(nothing) + +function distribute_tasks!(queue::DataDepsTaskQueue) + #= TODO: Improvements to be made: + # - Support for copying non-AbstractArray arguments + # - Parallelize read copies + # - Unreference unused slots + # - Reuse memory when possible + # - Account for differently-sized data + =# + + # Get the set of all processors to be scheduled on + all_procs = Processor[] + scope = get_compute_scope() + for w in procs() + append!(all_procs, get_processors(OSProc(w))) + end + filter!(proc->!isa(constrain(ExactScope(proc), scope), + InvalidScope), + 
            all_procs)
    if isempty(all_procs)
        throw(Sch.SchedulingException("No processors available, try widening scope"))
    end
    exec_spaces = unique(vcat(map(proc->collect(memory_spaces(proc)), all_procs)...))
    if !all(space->space isa CPURAMMemorySpace, exec_spaces) && !all(space->root_worker_id(space) == myid(), exec_spaces)
        @warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1
    end

    # Round-robin assign tasks to processors
    upper_queue = get_options(:task_queue)

    # Compute the launch order over queue.seen_tasks indices
    traversal = queue.traversal
    if traversal == :inorder
        # As-is
        task_order = Colon()
    elseif traversal == :bfs
        # BFS
        task_order = Int[1]
        to_walk = Int[1]
        seen = Set{Int}([1])
        while !isempty(to_walk)
            # N.B. next_root has already been seen
            next_root = popfirst!(to_walk)
            for v in outneighbors(queue.g, next_root)
                if !(v in seen)
                    push!(task_order, v)
                    push!(seen, v)
                    push!(to_walk, v)
                end
            end
        end
    elseif traversal == :dfs
        # DFS (modified with backtracking)
        task_order = Int[]
        to_walk = Int[1]
        seen = Set{Int}()
        while length(task_order) < length(queue.seen_tasks) && !isempty(to_walk)
            next_root = popfirst!(to_walk)
            if !(next_root in seen)
                iv = inneighbors(queue.g, next_root)
                # Only emit a node once all of its predecessors are emitted
                if all(v->v in seen, iv)
                    push!(task_order, next_root)
                    push!(seen, next_root)
                    ov = outneighbors(queue.g, next_root)
                    prepend!(to_walk, ov)
                else
                    # Re-queue for later (backtracking)
                    push!(to_walk, next_root)
                end
            end
        end
    else
        throw(ArgumentError("Invalid traversal mode: $traversal"))
    end

    state = DataDepsState(queue.aliasing)
    sstate = DataDepsSchedulerState()
    for proc in all_procs
        space = only(memory_spaces(proc))
        get!(()->0, sstate.capacities, space)
        sstate.capacities[space] += 1
    end

    # Start launching tasks and necessary copies
    write_num = 1
    proc_idx = 1
    pressures = Dict{Processor,Int}()
    proc_to_scope_lfu = BasicLFUCache{Processor,AbstractScope}(1024)
    for (spec, task) in queue.seen_tasks[task_order]
        # Populate all task dependencies
        populate_task_info!(state, spec, task)

        task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope())
        scheduler = queue.scheduler
        if scheduler == :naive
            raw_args = map(arg->tochunk(value(arg)), spec.fargs)
            our_proc = remotecall_fetch(1, all_procs, raw_args) do all_procs, raw_args
                Sch.init_eager()
                sch_state = Sch.EAGER_STATE[]

                @lock sch_state.lock begin
                    # Calculate costs per processor and select the most optimal
                    # FIXME: This should consider any already-allocated slots,
                    # whether they are up-to-date, and if not, the cost of moving
                    # data to them
                    procs, costs = Sch.estimate_task_costs(sch_state, all_procs, nothing, raw_args)
                    return first(procs)
                end
            end
        elseif scheduler == :smart
            raw_args = map(filter(arg->haskey(state.data_locality, value(arg)), spec.fargs)) do arg
                arg_chunk = tochunk(value(arg))
                # Only the owned slot is valid
                # FIXME: Track up-to-date copies and pass all of those
                # NOTE(review): `data_locality` here is a bare name; the filter
                # above uses `state.data_locality` — this likely should be
                # `state.data_locality[value(arg)]`; verify before enabling :smart
                return arg_chunk => data_locality[arg]
            end
            f_chunk = tochunk(value(spec.fargs[1]))
            our_proc, task_pressure = remotecall_fetch(1, all_procs, pressures, f_chunk, raw_args) do all_procs, pressures, f, chunks_locality
                Sch.init_eager()
                sch_state = Sch.EAGER_STATE[]

                @lock sch_state.lock begin
                    tx_rate = sch_state.transfer_rate[]

                    costs = Dict{Processor,Float64}()
                    for proc in all_procs
                        # Filter out chunks that are already local
                        chunks_filt = Iterators.filter(((chunk, space)=chunk_locality)->!(proc in processors(space)), chunks_locality)

                        # Estimate network transfer costs based on data size
                        # N.B. `affinity(x)` really means "data size of `x`"
                        # N.B. We treat same-worker transfers as having zero transfer cost
                        tx_cost = Sch.impute_sum(affinity(chunk)[2] for chunk in chunks_filt)

                        # Estimate total cost to move data and get task running after currently-scheduled tasks
                        est_time_util = get(pressures, proc, UInt64(0))
                        costs[proc] = est_time_util + (tx_cost/tx_rate)
                    end

                    # Look up estimated task cost
                    sig = Sch.signature(sch_state, f, map(first, chunks_locality))
                    task_pressure = get(sch_state.signature_time_cost, sig, 1000^3)

                    # Shuffle procs around, so equally-costly procs are equally considered
                    P = randperm(length(all_procs))
                    procs = getindex.(Ref(all_procs), P)

                    # Sort by lowest cost first
                    sort!(procs, by=p->costs[p])

                    best_proc = first(procs)
                    return best_proc, task_pressure
                end
            end
            # FIXME: Pressure should be decreased by pressure of syncdeps on same processor
            pressures[our_proc] = get(pressures, our_proc, UInt64(0)) + task_pressure
        elseif scheduler == :ultra
            args = Base.mapany(spec.fargs) do arg
                pos, data = arg
                data, _ = unwrap_inout(data)
                if data isa DTask
                    data = fetch(data; raw=true)
                end
                return pos => tochunk(data)
            end
            f_chunk = tochunk(value(spec.fargs[1]))
            task_time = remotecall_fetch(1, f_chunk, args) do f, args
                Sch.init_eager()
                sch_state = Sch.EAGER_STATE[]
                return @lock sch_state.lock begin
                    sig = Sch.signature(sch_state, f, args)
                    return get(sch_state.signature_time_cost, sig, 1000^3)
                end
            end

            # FIXME: Copy deps are computed eagerly
            deps = @something(spec.options.syncdeps, Set{Any}())

            # Find latest time-to-completion of all syncdeps
            deps_completed = UInt64(0)
            for dep in deps
                haskey(sstate.task_completions, dep) || continue # copy deps aren't recorded
                deps_completed = max(deps_completed, sstate.task_completions[dep])
            end

            # Find latest time-to-completion of each memory space
            # FIXME: Figure out space completions based on optimal packing
            spaces_completed = Dict{MemorySpace,UInt64}()
            for space in exec_spaces
                completed = UInt64(0)
                for (task, other_space) in sstate.assignments
                    space == other_space || continue
                    completed = max(completed, sstate.task_completions[task])
                end
                spaces_completed[space] = completed
            end

            # Choose the earliest-available memory space and processor
            # FIXME: Consider move time
            move_time = UInt64(0)
            # NOTE(review): only `our_space_completed` is declared `local`;
            # `our_space` and `our_proc` are assigned inside the `while` soft
            # scope and may be undefined after the loop — likely needs
            # `local our_space, our_proc` as well; verify the :ultra path
            local our_space_completed
            while true
                our_space_completed, our_space = findmin(spaces_completed)
                our_space_procs = filter(proc->proc in all_procs, processors(our_space))
                if isempty(our_space_procs)
                    delete!(spaces_completed, our_space)
                    continue
                end
                our_proc = rand(our_space_procs)
                break
            end

            sstate.task_to_spec[task] = spec
            sstate.assignments[task] = our_space
            sstate.task_completions[task] = our_space_completed + move_time + task_time
        elseif scheduler == :roundrobin
            our_proc = all_procs[proc_idx]
            if task_scope == scope
                # all_procs is already limited to scope
            else
                if isa(constrain(task_scope, scope), InvalidScope)
                    throw(Sch.SchedulingException("Scopes are not compatible: $(scope), $(task_scope)"))
                end
                # Advance round-robin until we land on a processor in scope
                while !proc_in_scope(our_proc, task_scope)
                    proc_idx = mod1(proc_idx + 1, length(all_procs))
                    our_proc = all_procs[proc_idx]
                end
            end
        else
            # NOTE(review): `sched` is undefined here (the variable is
            # `scheduler`); this path would throw UndefVarError instead of
            # the intended error message
            error("Invalid scheduler: $sched")
        end
        @assert our_proc in all_procs
        our_space = only(memory_spaces(our_proc))

        # Find the scope for this task (and its copies)
        if task_scope == scope
            # Optimize for the common case, cache the proc=>scope mapping
            our_scope = get!(proc_to_scope_lfu, our_proc) do
                our_procs = filter(proc->proc in all_procs, collect(processors(our_space)))
                return constrain(UnionScope(map(ExactScope, our_procs)...), scope)
            end
        else
            # Use the provided scope and constrain it to the available processors
            our_procs = filter(proc->proc in all_procs, collect(processors(our_space)))
            our_scope = constrain(UnionScope(map(ExactScope, our_procs)...), task_scope)
        end
        if our_scope isa InvalidScope
            throw(Sch.SchedulingException("Scopes are not compatible: $(our_scope.x), $(our_scope.y)"))
        end

        f = spec.fargs[1]
        f.value = move(ThreadProc(myid(), 1), our_proc, value(f))
        @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)"

        # Copy raw task arguments for analysis
        task_args = map(copy, spec.fargs)

        # Generate a list of ArgumentWrappers for each task argument
        task_arg_ws = map(task_args) do _arg
            arg = value(_arg)
            arg, deps = unwrap_inout(arg)
            arg = arg isa DTask ? fetch(arg; raw=true) : arg
            if !type_may_alias(typeof(arg)) || !supports_inplace_move(state, arg)
                return [(ArgumentWrapper(arg, identity), false, false)]
            end

            # Get the Chunk for the argument
            arg = state.raw_arg_to_chunk[arg]

            arg_ws = Tuple{ArgumentWrapper,Bool,Bool}[]
            for (dep_mod, readdep, writedep) in deps
                push!(arg_ws, (ArgumentWrapper(arg, dep_mod), readdep, writedep))
            end
            return arg_ws
        end
        task_arg_ws = task_arg_ws::Vector{Vector{Tuple{ArgumentWrapper,Bool,Bool}}}

        # Truncate the history for each argument
        for arg_ws in task_arg_ws
            for (arg_w, _, _) in arg_ws
                truncate_history!(state, arg_w)
            end
        end

        # Copy args from local to remote
        for (idx, arg_ws) in enumerate(task_arg_ws)
            arg = first(arg_ws)[1].arg
            pos = raw_position(task_args[idx])

            # Is the data written previously or now?
            if !type_may_alias(typeof(arg))
                @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (immutable)"
                spec.fargs[idx].value = arg
                continue
            end

            # Is the data writeable?
            if !supports_inplace_move(state, arg)
                @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (non-writeable)"
                spec.fargs[idx].value = arg
                continue
            end

            # Is the source of truth elsewhere?
            arg_remote = get_or_generate_slot!(state, our_space, arg)
            for (arg_w, _, _) in arg_ws
                dep_mod = arg_w.dep_mod
                # Compute which bytes are stale on our_space, and schedule
                # either a remainder copy, a full copy, or nothing
                remainder, _ = compute_remainder_for_arg!(state, our_space, arg_w, write_num)
                if remainder isa MultiRemainderAliasing
                    enqueue_remainder_copy_to!(state, our_space, arg_w, remainder, value(f), idx, our_scope, task, write_num)
                elseif remainder isa FullCopy
                    enqueue_copy_to!(state, our_space, arg_w, value(f), idx, our_scope, task, write_num)
                else
                    @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))"
                    @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Skipped copy-to (up-to-date): $our_space"
                end
            end
            spec.fargs[idx].value = arg_remote
        end
        write_num += 1

        # Validate that we're not accidentally performing a copy
        for (idx, _arg) in enumerate(spec.fargs)
            arg = value(_arg)
            _, deps = unwrap_inout(value(task_args[idx]))
            # N.B. We only do this check when the argument supports in-place
            # moves, because for the moment, we are not guaranteeing updates or
            # write-back of results
            if is_writedep(arg, deps, task) && supports_inplace_move(state, arg)
                arg_space = memory_space(arg)
                @assert arg_space == our_space "($(repr(value(f))))[$(idx-1)] Tried to pass $(typeof(arg)) from $arg_space to $our_space"
            end
        end

        # Calculate this task's syncdeps
        if spec.options.syncdeps === nothing
            spec.options.syncdeps = Set{Any}()
        end
        syncdeps = spec.options.syncdeps
        for (idx, arg_ws) in enumerate(task_arg_ws)
            arg = first(arg_ws)[1].arg
            type_may_alias(typeof(arg)) || continue
            supports_inplace_move(state, arg) || continue
            for (arg_w, _, writedep) in arg_ws
                ainfo = aliasing!(state, our_space, arg_w)
                dep_mod = arg_w.dep_mod
                if writedep
                    @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as writer"
                    get_write_deps!(state, our_space, ainfo, write_num, syncdeps)
                else
                    @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as reader"
                    get_read_deps!(state, our_space, ainfo, write_num, syncdeps)
                end
            end
        end
        @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Task has $(length(syncdeps)) syncdeps"

        # Launch user's task
        spec.options.scope = our_scope
        spec.options.exec_scope = our_scope
        enqueue!(upper_queue, spec=>task)

        # Update read/write tracking for arguments
        for (idx, arg_ws) in enumerate(task_arg_ws)
            arg = first(arg_ws)[1].arg
            type_may_alias(typeof(arg)) || continue
            for (arg_w, _, writedep) in arg_ws
                ainfo = aliasing!(state, our_space, arg_w)
                dep_mod = arg_w.dep_mod
                if writedep
                    @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Task set as writer"
                    add_writer!(state, arg_w, our_space, ainfo, task, write_num)
                else
                    add_reader!(state, arg_w, our_space, ainfo, task, write_num)
                end
            end
        end

        write_num += 1
        proc_idx = mod1(proc_idx + 1, length(all_procs))
    end

    # Copy args from remote to local
    # N.B. We sort the keys to ensure a deterministic order for uniformity
    for arg_w in sort(collect(keys(state.arg_owner)); by=arg_w->arg_w.hash)
        arg = arg_w.arg
        origin_space = state.arg_origin[arg]
        remainder, _ = compute_remainder_for_arg!(state, origin_space, arg_w, write_num)
        if remainder isa MultiRemainderAliasing
            origin_scope = UnionScope(map(ExactScope, collect(processors(origin_space)))...)
            enqueue_remainder_copy_from!(state, origin_space, arg_w, remainder, origin_scope, write_num)
        elseif remainder isa FullCopy
            origin_scope = UnionScope(map(ExactScope, collect(processors(origin_space)))...)
+ enqueue_copy_from!(state, origin_space, arg_w, origin_scope, write_num) + else + @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" + @dagdebug nothing :spawn_datadeps "Skipped copy-from (up-to-date): $origin_space" + end + end +end diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl new file mode 100644 index 000000000..b6f3d3c51 --- /dev/null +++ b/src/datadeps/remainders.jl @@ -0,0 +1,407 @@ +# Remainder tracking and computation functions + +""" + RemainderAliasing{S<:MemorySpace} <: AbstractAliasing + +Represents the memory spans that remain after subtracting some regions from a base aliasing object. +This is used to perform partial data copies that only update the "remainder" regions. +""" +struct RemainderAliasing{S<:MemorySpace} <: AbstractAliasing + space::S + spans::Vector{Tuple{LocalMemorySpan,LocalMemorySpan}} + syncdeps::Set{ThunkSyncdep} +end +RemainderAliasing(space::S, spans::Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}, syncdeps::Set{ThunkSyncdep}) where S = + RemainderAliasing{S}(space, spans, syncdeps) + +memory_spans(ra::RemainderAliasing) = ra.spans + +Base.hash(ra::RemainderAliasing, h::UInt) = hash(ra.spans, hash(RemainderAliasing, h)) +Base.:(==)(ra1::RemainderAliasing, ra2::RemainderAliasing) = ra1.spans == ra2.spans + +# Add will_alias support for RemainderAliasing +function will_alias(x::RemainderAliasing, y::AbstractAliasing) + return will_alias(memory_spans(x), memory_spans(y)) +end + +function will_alias(x::AbstractAliasing, y::RemainderAliasing) + return will_alias(memory_spans(x), memory_spans(y)) +end + +function will_alias(x::RemainderAliasing, y::RemainderAliasing) + return will_alias(memory_spans(x), memory_spans(y)) +end + +struct MultiRemainderAliasing <: AbstractAliasing + remainders::Vector{<:RemainderAliasing} +end +MultiRemainderAliasing() = MultiRemainderAliasing(RemainderAliasing[]) + +memory_spans(mra::MultiRemainderAliasing) = vcat(memory_spans.(mra.remainders)...) 
+ +Base.hash(mra::MultiRemainderAliasing, h::UInt) = hash(mra.remainders, hash(MultiRemainderAliasing, h)) +Base.:(==)(mra1::MultiRemainderAliasing, mra2::MultiRemainderAliasing) = mra1.remainders == mra2.remainders + +struct FullCopy end + +""" + compute_remainder_for_arg!(state::DataDepsState, + target_space::MemorySpace, + arg_w::ArgumentWrapper) + +Computes what remainder regions need to be copied to `target_space` before a task can access `arg_w`. +Returns a `MultiRemainderAliasing` object representing the remainder, or `NoAliasing()` if no remainder needed. + +The algorithm starts by collecting the memory spans of `arg_w` in `target_space` - this is the "remainder". +When this remainder is empty, the algorithm will be finished. +Additionally, a dictionary is created to store the source and destination +memory spans (for each source memory space) that will be used to create the +`MultiRemainderAliasing` object - this is the "tracker". + +The algorithm walks backwards through the `arg_history` vector for `arg_w` +(which is an ordered list of all overlapping ainfos that were directly written to (potentially in a different memory space than `target_space`) +since the last time this `arg_w` was written to). If this ainfo is in `target_space`, +then it is not under consideration; it is simply subtracted from the remainder with `subtract_spans!`, +and the algorithm goes to the next ainfo. Otherwise, the algorithm will consider this ainfo for tracking. + +For each overlapping ainfo (which lives in a different memory space than `target_space`) to be tracked, there exists a corresponding "mirror" ainfo in +`target_space`, which is the equivalent of the overlapping ainfo, but in +`target_space`. This mirror ainfo is assumed to have an identical number of memory spans as the overlapping ainfo, +and each memory span is assumed to be identical in size, but not necessarily identical in address. 
+ +These three sets of memory spans (from the remainder, the overlapping ainfo, and the mirror ainfo) are then passed to `schedule_remainder!`. +This call will subtract the spans of the mirror ainfo from the remainder (as the two live in the same memory space and thus can be directly compared), +and will update the remainder accordingly. +Additionally, it will also use this subtraction to update the tracker, by adding the equivalent spans (mapped from mirror ainfo to overlapping ainfo) to the tracker as the source, +and the spans of the remainder as the destination.
+ for entry in state.arg_history[arg_w] + if !in(entry.space, spaces) + @opcounter :compute_remainder_for_arg_restart + @goto restart + end + end + + # We may only need to schedule a full copy from the origin space to the + # target space if this is the first time we've written to `arg_w` + if isempty(state.arg_history[arg_w]) + if owner_space != target_space + return FullCopy(), 0 + else + return NoAliasing(), 0 + end + end + + # Create our remainder as an interval tree over all target ainfos + remainder = IntervalTree{ManyMemorySpan{N}}(ManyMemorySpan{N}(ntuple(i -> target_ainfos[i][j], N)) for j in 1:nspans) + + # Create our tracker + tracker = Dict{MemorySpace,Tuple{Vector{Tuple{LocalMemorySpan,LocalMemorySpan}},Set{ThunkSyncdep}}}() + + # Walk backwards through the history of writes to this target + # other_ainfo is the overlapping ainfo that was written to + # other_space is the memory space of the overlapping ainfo + last_idx = length(state.arg_history[arg_w]) + for idx in length(state.arg_history[arg_w]):-1:0 + if isempty(remainder) + # All done! 
+ last_idx = idx + break + end + + if idx > 0 + other_entry = state.arg_history[arg_w][idx] + other_ainfo = other_entry.ainfo + other_space = other_entry.space + else + # If we've reached the end of the history, evaluate ourselves + other_ainfo = aliasing!(state, owner_space, arg_w) + other_space = owner_space + end + + # Lookup all memory spans for arg_w in these spaces + other_remote_arg_w = state.ainfo_arg[other_ainfo] + other_arg_w = ArgumentWrapper(state.remote_arg_to_original[other_remote_arg_w.arg], other_remote_arg_w.dep_mod) + other_ainfos = Vector{Vector{LocalMemorySpan}}() + for space in spaces + other_space_ainfo = aliasing!(state, space, other_arg_w) + spans = memory_spans(other_space_ainfo) + push!(other_ainfos, LocalMemorySpan.(spans)) + end + nspans = length(first(other_ainfos)) + other_many_spans = [ManyMemorySpan{N}(ntuple(i -> other_ainfos[i][j], N)) for j in 1:nspans] + + if other_space == target_space + # Only subtract, this data is already up-to-date in target_space + # N.B. We don't add to syncdeps here, because we'll see this ainfo + # in get_write_deps! 
+ @opcounter :compute_remainder_for_arg_subtract + subtract_spans!(remainder, other_many_spans) + continue + end + + # Subtract from remainder and schedule copy in tracker + other_space_idx = something(findfirst(==(other_space), spaces)) + target_space_idx = something(findfirst(==(target_space), spaces)) + tracker_other_space = get!(tracker, other_space) do + (Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}(), Set{ThunkSyncdep}()) + end + @opcounter :compute_remainder_for_arg_schedule + schedule_remainder!(tracker_other_space[1], other_space_idx, target_space_idx, remainder, other_many_spans) + if compute_syncdeps + @assert haskey(state.ainfos_owner, other_ainfo) "[idx $idx] ainfo $(typeof(other_ainfo)) has no owner" + get_read_deps!(state, other_space, other_ainfo, write_num, tracker_other_space[2]) + end + end + + if isempty(tracker) + return NoAliasing(), 0 + end + + # Return scheduled copies and the index of the last ainfo we considered + mra = MultiRemainderAliasing() + for space in spaces + if haskey(tracker, space) + spans, syncdeps = tracker[space] + if !isempty(spans) + push!(mra.remainders, RemainderAliasing(space, spans, syncdeps)) + end + end + end + return mra, last_idx +end + +### Memory Span Set Operations for Remainder Computation + +""" + schedule_remainder!(tracker::Vector, source_space_idx::Int, dest_space_idx::Int, remainder::IntervalTree, other_many_spans::Vector{ManyMemorySpan{N}}) + +Calculates the difference between `remainder` and `other_many_spans`, subtracts +it from `remainder`, and then adds that difference to `tracker` as a scheduled +copy from `other_many_spans` to the subtraced portion of `remainder`. 
+""" +function schedule_remainder!(tracker::Vector, source_space_idx::Int, dest_space_idx::Int, remainder::IntervalTree, other_many_spans::Vector{ManyMemorySpan{N}}) where N + diff = Vector{ManyMemorySpan{N}}() + subtract_spans!(remainder, other_many_spans, diff) + + for span in diff + source_span = span.spans[source_space_idx] + dest_span = span.spans[dest_space_idx] + push!(tracker, (source_span, dest_span)) + end +end + +### Remainder copy functions + +""" + enqueue_remainder_copy_to!(state::DataDepsState, f, target_ainfo::AliasingWrapper, remainder_aliasing, dep_mod, arg, idx, + our_space::MemorySpace, our_scope, task::DTask, write_num::Int) + +Enqueues a copy operation to update the remainder regions of an object before a task runs. +""" +function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::MultiRemainderAliasing, + f, idx, dest_scope, task, write_num::Int) + for remainder in remainder_aliasing.remainders + @assert !isempty(remainder.spans) + enqueue_remainder_copy_to!(state, dest_space, arg_w, remainder, f, idx, dest_scope, task, write_num) + end +end +function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::RemainderAliasing, + f, idx, dest_scope, task, write_num::Int) + dep_mod = arg_w.dep_mod + + # Find the source space for the remainder data + # We need to find where the best version of the target data lives that hasn't been + # overwritten by more recent partial updates + source_space = remainder_aliasing.space + + @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing remainder copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" + + # Get the source and destination arguments + arg_dest = state.remote_args[dest_space][arg_w.arg] + arg_source = get_or_generate_slot!(state, source_space, arg_w.arg) + + # Create a copy task for the remainder + remainder_syncdeps 
= Set{Any}() + target_ainfo = aliasing!(state, dest_space, arg_w) + for syncdep in remainder_aliasing.syncdeps + push!(remainder_syncdeps, syncdep) + end + empty!(remainder_aliasing.syncdeps) # We can't bring these to move! + get_write_deps!(state, dest_space, target_ainfo, write_num, remainder_syncdeps) + + @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Remainder copy-to has $(length(remainder_syncdeps)) syncdeps" + + # Launch the remainder copy task + copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) + + # This copy task becomes a new writer for the target region + add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) +end +""" + enqueue_remainder_copy_from!(state::DataDepsState, target_ainfo::AliasingWrapper, arg, remainder_aliasing, + origin_space::MemorySpace, origin_scope, write_num::Int) + +Enqueues a copy operation to update the remainder regions of an object back to the original space. 
+""" +function enqueue_remainder_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::MultiRemainderAliasing, + dest_scope, write_num::Int) + for remainder in remainder_aliasing.remainders + @assert !isempty(remainder.spans) + enqueue_remainder_copy_from!(state, dest_space, arg_w, remainder, dest_scope, write_num) + end +end +function enqueue_remainder_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::RemainderAliasing, + dest_scope, write_num::Int) + dep_mod = arg_w.dep_mod + + # Find the source space for the remainder data + # We need to find where the best version of the target data lives that hasn't been + # overwritten by more recent partial updates + source_space = remainder_aliasing.space + + @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Enqueueing remainder copy-from for: $source_space => $dest_space" + + # Get the source and destination arguments + arg_dest = state.remote_args[dest_space][arg_w.arg] + arg_source = get_or_generate_slot!(state, source_space, arg_w.arg) + + # Create a copy task for the remainder + remainder_syncdeps = Set{Any}() + target_ainfo = aliasing!(state, dest_space, arg_w) + for syncdep in remainder_aliasing.syncdeps + push!(remainder_syncdeps, syncdep) + end + empty!(remainder_aliasing.syncdeps) # We can't bring these to move! 
+ get_write_deps!(state, dest_space, target_ainfo, write_num, remainder_syncdeps) + + @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Remainder copy-from has $(length(remainder_syncdeps)) syncdeps" + + # Launch the remainder copy task + copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) + + # This copy task becomes a new writer for the target region + add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) +end + +# FIXME: Document me +function enqueue_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, + f, idx, dest_scope, task, write_num::Int) + dep_mod = arg_w.dep_mod + source_space = state.arg_owner[arg_w] + target_ainfo = aliasing!(state, dest_space, arg_w) + + @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing full copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" + + # Get the source and destination arguments + arg_dest = state.remote_args[dest_space][arg_w.arg] + arg_source = get_or_generate_slot!(state, source_space, arg_w.arg) + + # Create a copy task for the remainder + copy_syncdeps = Set{Any}() + source_ainfo = aliasing!(state, source_space, arg_w) + target_ainfo = aliasing!(state, dest_space, arg_w) + get_read_deps!(state, source_space, source_ainfo, write_num, copy_syncdeps) + get_write_deps!(state, dest_space, target_ainfo, write_num, copy_syncdeps) + + @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Full copy-to has $(length(copy_syncdeps)) syncdeps" + + # Launch the remainder copy task + copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) + + # This copy task becomes a new writer for the target region + add_writer!(state, arg_w, dest_space, target_ainfo, 
copy_task, write_num) +end +function enqueue_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, + dest_scope, write_num::Int) + dep_mod = arg_w.dep_mod + source_space = state.arg_owner[arg_w] + target_ainfo = aliasing!(state, dest_space, arg_w) + + @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Enqueueing full copy-from: $source_space => $dest_space" + + # Get the source and destination arguments + arg_dest = state.remote_args[dest_space][arg_w.arg] + arg_source = get_or_generate_slot!(state, source_space, arg_w.arg) + + # Create a copy task for the remainder + copy_syncdeps = Set{Any}() + source_ainfo = aliasing!(state, source_space, arg_w) + target_ainfo = aliasing!(state, dest_space, arg_w) + get_read_deps!(state, source_space, source_ainfo, write_num, copy_syncdeps) + get_write_deps!(state, dest_space, target_ainfo, write_num, copy_syncdeps) + + @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Full copy-from has $(length(copy_syncdeps)) syncdeps" + + # Launch the remainder copy task + copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) + + # This copy task becomes a new writer for the target region + add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) +end + +# Main copy function for RemainderAliasing +function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space::MemorySpace, to::Chunk, from::Chunk) where S + # Get the source data for each span + copies = remotecall_fetch(root_worker_id(from_space), dep_mod) do dep_mod + copies = Vector{UInt8}[] + for (from_span, _) in dep_mod.spans + copy = Vector{UInt8}(undef, from_span.len) + GC.@preserve copy begin + from_ptr = Ptr{UInt8}(from_span.ptr) + to_ptr = Ptr{UInt8}(pointer(copy)) + unsafe_copyto!(to_ptr, from_ptr, from_span.len) + end + push!(copies, copy) + end + return copies + end + + # 
Copy the data into the destination object + for (copy, (_, to_span)) in zip(copies, dep_mod.spans) + GC.@preserve copy begin + from_ptr = Ptr{UInt8}(pointer(copy)) + to_ptr = Ptr{UInt8}(to_span.ptr) + unsafe_copyto!(to_ptr, from_ptr, to_span.len) + end + end + + # Ensure that the data is visible + Core.Intrinsics.atomic_fence(:release) + + return +end diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index 9f65a1a21..b1ff40d8f 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -30,7 +30,7 @@ processors(space::CPURAMMemorySpace) = ### In-place Data Movement function unwrap(x::Chunk) - @assert root_worker_id(x.processor) == myid() + @assert x.handle.owner == myid() MemPool.poolget(x.handle) end move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::T, from::F) where {T,F} = @@ -99,10 +99,13 @@ end RemotePtr{T}(addr::UInt, space::S) where {T,S} = RemotePtr{T,S}(addr, space) RemotePtr{T}(ptr::Ptr{V}, space::S) where {T,V,S} = RemotePtr{T,S}(UInt(ptr), space) RemotePtr{T}(ptr::Ptr{V}) where {T,V} = RemotePtr{T}(UInt(ptr), CPURAMMemorySpace(myid())) +# FIXME: Don't hardcode CPURAMMemorySpace +RemotePtr(addr::UInt) = RemotePtr{Cvoid}(addr, CPURAMMemorySpace(myid())) Base.convert(::Type{RemotePtr}, x::Ptr{T}) where T = RemotePtr(UInt(x), CPURAMMemorySpace(myid())) Base.convert(::Type{<:RemotePtr{V}}, x::Ptr{T}) where {V,T} = RemotePtr{V}(UInt(x), CPURAMMemorySpace(myid())) +Base.convert(::Type{UInt}, ptr::RemotePtr) = ptr.addr Base.:+(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr + offset, ptr.space) Base.:-(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr - offset, ptr.space) function Base.isless(ptr1::RemotePtr, ptr2::RemotePtr) @@ -116,13 +119,15 @@ struct MemorySpan{S} end MemorySpan(ptr::RemotePtr{Cvoid,S}, len::Integer) where S = MemorySpan{S}(ptr, UInt(len)) - +MemorySpan{S}(addr::UInt, len::Integer) where S = + MemorySpan{S}(RemotePtr{Cvoid,S}(addr), UInt(len)) +Base.isless(a::MemorySpan, 
b::MemorySpan) = a.ptr < b.ptr +Base.isempty(x::MemorySpan) = x.len == 0 abstract type AbstractAliasing end memory_spans(::T) where T<:AbstractAliasing = throw(ArgumentError("Must define `memory_spans` for `$T`")) memory_spans(x) = memory_spans(aliasing(x)) memory_spans(x, T) = memory_spans(aliasing(x, T)) - struct AliasingWrapper <: AbstractAliasing inner::AbstractAliasing hash::UInt64 @@ -279,7 +284,7 @@ function aliasing(x::SubArray{T,N,A}) where {T,N,A<:Array} return StridedAliasing{T,ndims(x),S}(RemotePtr{Cvoid}(pointer(parent(x))), RemotePtr{Cvoid}(pointer(x)), parentindices(x), - size(x), strides(parent(x))) + size(x), strides(x)) else # FIXME: Also ContiguousAliasing of container #return IteratedAliasing(x) @@ -402,70 +407,35 @@ function will_alias(x_span::MemorySpan, y_span::MemorySpan) return x_span.ptr <= y_end && y_span.ptr <= x_end end -struct ChunkView{N} - chunk::Chunk - slices::NTuple{N, Union{Int, AbstractRange{Int}, Colon}} -end - -function Base.view(c::Chunk, slices...) - if c.domain isa ArrayDomain - nd, sz = ndims(c.domain), size(c.domain) - nd == length(slices) || throw(DimensionMismatch("Expected $nd slices, got $(length(slices))")) - - for (i, s) in enumerate(slices) - if s isa Int - 1 ≤ s ≤ sz[i] || throw(ArgumentError("Index $s out of bounds for dimension $i (size $(sz[i]))")) - elseif s isa AbstractRange - isempty(s) && continue - 1 ≤ first(s) ≤ last(s) ≤ sz[i] || throw(ArgumentError("Range $s out of bounds for dimension $i (size $(sz[i]))")) - elseif s === Colon() - continue - else - throw(ArgumentError("Invalid slice type $(typeof(s)) at dimension $i, Expected Type of Int, AbstractRange, or Colon")) - end - end - end +### More space-efficient memory spans - return ChunkView(c, slices) +struct LocalMemorySpan + ptr::UInt + len::UInt end +LocalMemorySpan(span::MemorySpan) = LocalMemorySpan(span.ptr.addr, span.len) +Base.isempty(x::LocalMemorySpan) = x.len == 0 -Base.view(c::DTask, slices...) = view(fetch(c; raw=true), slices...) 
- -function aliasing(x::ChunkView{N}) where N - remotecall_fetch(root_worker_id(x.chunk.processor), x.chunk, x.slices) do x, slices - x = unwrap(x) - v = view(x, slices...) - return aliasing(v) - end +# FIXME: Store the length separately, since it's shared by all spans +struct ManyMemorySpan{N} + spans::NTuple{N,LocalMemorySpan} end -memory_space(x::ChunkView) = memory_space(x.chunk) -isremotehandle(x::ChunkView) = true +Base.isempty(x::ManyMemorySpan) = all(isempty, x.spans) -#= -function move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::ChunkView, from::ChunkView) - to_w = root_worker_id(to_space) - @assert to_w == myid() - to_raw = unwrap(to.chunk) - from_w = root_worker_id(from_space) - from_raw = to_w == from_w ? unwrap(from.chunk) : remotecall_fetch(f->copy(unwrap(f)), from_w, from.chunk) - from_view = view(from_raw, from.slices...) - to_view = view(to_raw, to.slices...) - move!(dep_mod, to_space, from_space, to_view, from_view) - return +struct ManyPair{N} <: Unsigned + pairs::NTuple{N,UInt} end -=# +Base.promote_rule(::Type{ManyPair}, ::Type{T}) where {T<:Integer} = ManyPair +Base.convert(::Type{ManyPair{N}}, x::T) where {T<:Integer,N} = ManyPair(ntuple(i -> x, N)) +Base.convert(::Type{ManyPair}, x::ManyPair) = x +Base.:+(x::ManyPair{N}, y::ManyPair{N}) where N = ManyPair(ntuple(i -> x.pairs[i] + y.pairs[i], N)) +Base.:-(x::ManyPair{N}, y::ManyPair{N}) where N = ManyPair(ntuple(i -> x.pairs[i] - y.pairs[i], N)) +Base.:-(x::ManyPair) = error("Can't negate a ManyPair") +Base.:(==)(x::ManyPair, y::ManyPair) = x.pairs == y.pairs +Base.isless(x::ManyPair, y::ManyPair) = x.pairs[1] < y.pairs[1] +Base.:(<)(x::ManyPair, y::ManyPair) = x.pairs[1] < y.pairs[1] +Base.string(x::ManyPair) = "ManyPair($(x.pairs))" -function move(from_proc::Processor, to_proc::Processor, slice::ChunkView) - if from_proc == to_proc - return view(unwrap(slice.chunk), slice.slices...) 
- else - # Need to copy the underlying data, so collapse the view - from_w = root_worker_id(from_proc) - data = remotecall_fetch(from_w, slice.chunk, slice.slices) do chunk, slices - copy(view(unwrap(chunk), slices...)) - end - return move(from_proc, to_proc, data) - end -end +ManyMemorySpan{N}(start::ManyPair{N}, len::ManyPair{N}) where N = + ManyMemorySpan{N}(ntuple(i -> LocalMemorySpan(start.pairs[i], len.pairs[i]), N)) -Base.fetch(slice::ChunkView) = view(fetch(slice.chunk), slice.slices...) \ No newline at end of file diff --git a/src/utils/dagdebug.jl b/src/utils/dagdebug.jl index 615030400..873e47e79 100644 --- a/src/utils/dagdebug.jl +++ b/src/utils/dagdebug.jl @@ -35,3 +35,28 @@ macro dagdebug(thunk, category, msg, args...) end end) end + +# FIXME: Calculate fast-growth based on clock time, not iteration +const OPCOUNTER_CATEGORIES = Symbol[] +const OPCOUNTER_FAST_GROWTH_THRESHOLD = Ref(10_000_000) +struct OpCounter + value::Threads.Atomic{Int} +end +OpCounter() = OpCounter(Threads.Atomic{Int}(0)) +macro opcounter(category, count=1) + cat_sym = category.value + @gensym old + opcounter_sym = Symbol(:OPCOUNTER_, cat_sym) + if !isdefined(__module__, opcounter_sym) + __module__.eval(:(#=const=# $opcounter_sym = OpCounter())) + end + esc(quote + if $(QuoteNode(cat_sym)) in $OPCOUNTER_CATEGORIES + $old = Threads.atomic_add!($opcounter_sym.value, Int($count)) + if $old > 1 && (mod1($old, $OPCOUNTER_FAST_GROWTH_THRESHOLD[]) == 1 || $count > $OPCOUNTER_FAST_GROWTH_THRESHOLD[]) + println("Fast-growing counter: $($(QuoteNode(cat_sym))) = $($old)") + end + end + end) +end +opcounter(mod::Module, category::Symbol) = getfield(mod, Symbol(:OPCOUNTER_, category)).value[] \ No newline at end of file From 3733dcbe34989e5ad67cb14a55f458bc70655e2c Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 15 Oct 2025 19:49:44 -0700 Subject: [PATCH 02/28] Add type-stable spawn code paths --- src/argument.jl | 32 +- src/datadeps/aliasing.jl | 74 +++-- src/datadeps/queue.jl | 
613 +++++++++++++++++++++------------------ src/queue.jl | 97 ++++--- src/stream.jl | 26 +- src/submission.jl | 71 +++-- src/thunk.jl | 74 +++-- 7 files changed, 585 insertions(+), 402 deletions(-) diff --git a/src/argument.jl b/src/argument.jl index 94246a75e..849486e03 100644 --- a/src/argument.jl +++ b/src/argument.jl @@ -20,6 +20,7 @@ function pos_kw(pos::ArgPosition) @assert pos.kw != :NULL return pos.kw end + mutable struct Argument pos::ArgPosition value @@ -41,6 +42,35 @@ function Base.iterate(arg::Argument, state::Bool) return nothing end end - Base.copy(arg::Argument) = Argument(ArgPosition(arg.pos), arg.value) chunktype(arg::Argument) = chunktype(value(arg)) + +mutable struct TypedArgument{T} + pos::ArgPosition + value::T +end +TypedArgument(pos::Integer, value::T) where T = TypedArgument{T}(ArgPosition(true, pos, :NULL), value) +TypedArgument(kw::Symbol, value::T) where T = TypedArgument{T}(ArgPosition(false, 0, kw), value) +Base.setproperty!(arg::TypedArgument, name::Symbol, value::T) where T = + throw(ArgumentError("Cannot set properties of TypedArgument")) +ispositional(arg::TypedArgument) = ispositional(arg.pos) +iskw(arg::TypedArgument) = iskw(arg.pos) +pos_idx(arg::TypedArgument) = pos_idx(arg.pos) +pos_kw(arg::TypedArgument) = pos_kw(arg.pos) +raw_position(arg::TypedArgument) = raw_position(arg.pos) +value(arg::TypedArgument) = arg.value +valuetype(arg::TypedArgument{T}) where T = T +Base.iterate(arg::TypedArgument) = (arg.pos, true) +function Base.iterate(arg::TypedArgument, state::Bool) + if state + return (arg.value, false) + else + return nothing + end +end +Base.copy(arg::TypedArgument{T}) where T = TypedArgument{T}(ArgPosition(arg.pos), arg.value) +chunktype(arg::TypedArgument) = chunktype(value(arg)) + +Argument(arg::TypedArgument) = Argument(arg.pos, arg.value) + +const AnyArgument = Union{Argument, TypedArgument} \ No newline at end of file diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index 833e3fa99..164944f13 100644 
--- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -181,6 +181,11 @@ struct Deps{T,DT<:Tuple} end Deps(x, deps...) = Deps(x, deps) +chunktype(::In{T}) where T = T +chunktype(::Out{T}) where T = T +chunktype(::InOut{T}) where T = T +chunktype(::Deps{T,DT}) where {T,DT} = T + function unwrap_inout(arg) readdep = false writedep = false @@ -361,48 +366,69 @@ function is_writedep(arg, deps, task::DTask) end # Aliasing state setup -function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask) +function populate_task_info!(state::DataDepsState, task_args, spec::DTaskSpec, task::DTask) # Track the task's arguments and access patterns - for (idx, _arg) in enumerate(spec.fargs) - arg = value(_arg) + return map_or_ntuple(task_args) do idx + _arg = task_args[idx] + + # Unwrap the argument + _arg_with_deps = value(_arg) + pos = _arg.pos # Unwrap In/InOut/Out wrappers and record dependencies - arg, deps = unwrap_inout(arg) + arg_pre_unwrap, deps = unwrap_inout(_arg_with_deps) # Unwrap the Chunk underlying any DTask arguments - arg = arg isa DTask ? fetch(arg; raw=true) : arg - - # Skip non-aliasing arguments - type_may_alias(typeof(arg)) || continue - - # Skip arguments not supporting in-place move - supports_inplace_move(state, arg) || continue + arg = arg_pre_unwrap isa DTask ? 
fetch(arg_pre_unwrap; raw=true) : arg_pre_unwrap + + # Skip non-aliasing arguments or arguments that don't support in-place move + may_alias = type_may_alias(typeof(arg)) + inplace_move = may_alias && supports_inplace_move(state, arg) + if !may_alias || !inplace_move + arg_w = ArgumentWrapper(arg, identity) + if is_typed(spec) + return TypedDataDepsTaskArgument(arg, pos, may_alias, inplace_move, (DataDepsTaskDependency(arg_w, false, false),)) + else + return DataDepsTaskArgument(arg, pos, may_alias, inplace_move, [DataDepsTaskDependency(arg_w, false, false)]) + end + end # Generate a Chunk for the argument if necessary if haskey(state.raw_arg_to_chunk, arg) - arg = state.raw_arg_to_chunk[arg] + arg_chunk = state.raw_arg_to_chunk[arg] else if !(arg isa Chunk) - new_arg = tochunk(arg) - state.raw_arg_to_chunk[arg] = new_arg - arg = new_arg + arg_chunk = tochunk(arg) + state.raw_arg_to_chunk[arg] = arg_chunk else state.raw_arg_to_chunk[arg] = arg + arg_chunk = arg end end # Track the origin space of the argument - origin_space = memory_space(arg) - state.arg_origin[arg] = origin_space - state.remote_arg_to_original[arg] = arg + origin_space = memory_space(arg_chunk) + state.arg_origin[arg_chunk] = origin_space + state.remote_arg_to_original[arg_chunk] = arg_chunk # Populate argument info for all aliasing dependencies - for (dep_mod, _, _) in deps - # Generate an ArgumentWrapper for the argument - aw = ArgumentWrapper(arg, dep_mod) - - # Populate argument info - populate_argument_info!(state, aw, origin_space) + # And return the argument, dependencies, and ArgumentWrappers + if is_typed(spec) + deps = Tuple(DataDepsTaskDependency(arg_chunk, dep) for dep in deps) + map_or_ntuple(deps) do dep_idx + dep = deps[dep_idx] + # Populate argument info + populate_argument_info!(state, dep.arg_w, origin_space) + end + return TypedDataDepsTaskArgument(arg_chunk, pos, may_alias, inplace_move, deps) + else + deps = [DataDepsTaskDependency(arg_chunk, dep) for dep in deps] + 
map_or_ntuple(deps) do dep_idx + dep = deps[dep_idx] + # Populate argument info + populate_argument_info!(state, dep.arg_w, origin_space) + end + return DataDepsTaskArgument(arg_chunk, pos, may_alias, inplace_move, deps) end end end diff --git a/src/datadeps/queue.jl b/src/datadeps/queue.jl index 6fc85bd22..f8f907741 100644 --- a/src/datadeps/queue.jl +++ b/src/datadeps/queue.jl @@ -2,7 +2,7 @@ struct DataDepsTaskQueue <: AbstractTaskQueue # The queue above us upper_queue::AbstractTaskQueue # The set of tasks that have already been seen - seen_tasks::Union{Vector{Pair{DTaskSpec,DTask}},Nothing} + seen_tasks::Union{Vector{DTaskPair},Nothing} # The data-dependency graph of all tasks g::Union{SimpleDiGraph{Int},Nothing} # The mapping from task to graph ID @@ -20,7 +20,7 @@ struct DataDepsTaskQueue <: AbstractTaskQueue traversal::Symbol=:inorder, scheduler::Symbol=:naive, aliasing::Bool=true) - seen_tasks = Pair{DTaskSpec,DTask}[] + seen_tasks = DTaskPair[] g = SimpleDiGraph() task_to_id = Dict{DTask,Int}() return new(upper_queue, seen_tasks, g, task_to_id, traversal, scheduler, @@ -28,11 +28,11 @@ struct DataDepsTaskQueue <: AbstractTaskQueue end end -function enqueue!(queue::DataDepsTaskQueue, spec::Pair{DTaskSpec,DTask}) - push!(queue.seen_tasks, spec) +function enqueue!(queue::DataDepsTaskQueue, pair::DTaskPair) + push!(queue.seen_tasks, pair) end -function enqueue!(queue::DataDepsTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) - append!(queue.seen_tasks, specs) +function enqueue!(queue::DataDepsTaskQueue, pairs::Vector{DTaskPair}) + append!(queue.seen_tasks, pairs) end """ @@ -116,12 +116,11 @@ function distribute_tasks!(queue::DataDepsTaskQueue) for w in procs() append!(all_procs, get_processors(OSProc(w))) end - filter!(proc->!isa(constrain(ExactScope(proc), scope), - InvalidScope), - all_procs) + filter!(proc->proc_in_scope(proc, scope), all_procs) if isempty(all_procs) throw(Sch.SchedulingException("No processors available, try widening scope")) end + scope = 
UnionScope(map(ExactScope, all_procs)) exec_spaces = unique(vcat(map(proc->collect(memory_spaces(proc)), all_procs)...)) if !all(space->space isa CPURAMMemorySpace, exec_spaces) && !all(space->root_worker_id(space) == myid(), exec_spaces) @warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1 @@ -184,317 +183,367 @@ function distribute_tasks!(queue::DataDepsTaskQueue) # Start launching tasks and necessary copies write_num = 1 proc_idx = 1 - pressures = Dict{Processor,Int}() + #pressures = Dict{Processor,Int}() proc_to_scope_lfu = BasicLFUCache{Processor,AbstractScope}(1024) - for (spec, task) in queue.seen_tasks[task_order] - # Populate all task dependencies - populate_task_info!(state, spec, task) - - task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) - scheduler = queue.scheduler - if scheduler == :naive - raw_args = map(arg->tochunk(value(arg)), spec.fargs) - our_proc = remotecall_fetch(1, all_procs, raw_args) do all_procs, raw_args - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - - @lock sch_state.lock begin - # Calculate costs per processor and select the most optimal - # FIXME: This should consider any already-allocated slots, - # whether they are up-to-date, and if not, the cost of moving - # data to them - procs, costs = Sch.estimate_task_costs(sch_state, all_procs, nothing, raw_args) - return first(procs) - end - end - elseif scheduler == :smart - raw_args = map(filter(arg->haskey(state.data_locality, value(arg)), spec.fargs)) do arg - arg_chunk = tochunk(value(arg)) - # Only the owned slot is valid - # FIXME: Track up-to-date copies and pass all of those - return arg_chunk => data_locality[arg] + for pair in queue.seen_tasks[task_order] + spec = pair.spec + task = pair.task + write_num, proc_idx = distribute_task!(queue, state, all_procs, scope, spec, task, spec.fargs, proc_to_scope_lfu, write_num, proc_idx) + end + + # Copy args from 
remote to local + # N.B. We sort the keys to ensure a deterministic order for uniformity + for arg_w in sort(collect(keys(state.arg_owner)); by=arg_w->arg_w.hash) + arg = arg_w.arg + origin_space = state.arg_origin[arg] + remainder, _ = compute_remainder_for_arg!(state, origin_space, arg_w, write_num) + if remainder isa MultiRemainderAliasing + origin_scope = UnionScope(map(ExactScope, collect(processors(origin_space)))...) + enqueue_remainder_copy_from!(state, origin_space, arg_w, remainder, origin_scope, write_num) + elseif remainder isa FullCopy + origin_scope = UnionScope(map(ExactScope, collect(processors(origin_space)))...) + enqueue_copy_from!(state, origin_space, arg_w, origin_scope, write_num) + else + @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" + @dagdebug nothing :spawn_datadeps "Skipped copy-from (up-to-date): $origin_space" + end + end +end +struct DataDepsTaskDependency + arg_w::ArgumentWrapper + readdep::Bool + writedep::Bool +end +DataDepsTaskDependency(arg, dep) = + DataDepsTaskDependency(ArgumentWrapper(arg, dep[1]), dep[2], dep[3]) +struct DataDepsTaskArgument + arg + pos::ArgPosition + may_alias::Bool + inplace_move::Bool + deps::Vector{DataDepsTaskDependency} +end +struct TypedDataDepsTaskArgument{T,N} + arg::T + pos::ArgPosition + may_alias::Bool + inplace_move::Bool + deps::NTuple{N,DataDepsTaskDependency} +end +map_or_ntuple(f, xs::Vector) = map(f, 1:length(xs)) +@inline map_or_ntuple(@specialize(f), xs::NTuple{N,T}) where {N,T} = ntuple(f, Val(N)) +function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_procs, scope, spec::DTaskSpec{typed}, task::DTask, fargs, proc_to_scope_lfu, write_num::Int, proc_idx::Int) where typed + @specialize spec fargs + + if typed + fargs::Tuple + else + fargs::Vector{Argument} + end + + task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) + scheduler = queue.scheduler + if scheduler == :naive + raw_args = 
map(arg->tochunk(value(arg)), spec.fargs) + our_proc = remotecall_fetch(1, all_procs, raw_args) do all_procs, raw_args + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + + @lock sch_state.lock begin + # Calculate costs per processor and select the most optimal + # FIXME: This should consider any already-allocated slots, + # whether they are up-to-date, and if not, the cost of moving + # data to them + procs, costs = Sch.estimate_task_costs(sch_state, all_procs, nothing, raw_args) + return first(procs) end - f_chunk = tochunk(value(spec.fargs[1])) - our_proc, task_pressure = remotecall_fetch(1, all_procs, pressures, f_chunk, raw_args) do all_procs, pressures, f, chunks_locality - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - - @lock sch_state.lock begin - tx_rate = sch_state.transfer_rate[] - - costs = Dict{Processor,Float64}() - for proc in all_procs - # Filter out chunks that are already local - chunks_filt = Iterators.filter(((chunk, space)=chunk_locality)->!(proc in processors(space)), chunks_locality) - - # Estimate network transfer costs based on data size - # N.B. `affinity(x)` really means "data size of `x`" - # N.B. 
We treat same-worker transfers as having zero transfer cost - tx_cost = Sch.impute_sum(affinity(chunk)[2] for chunk in chunks_filt) - - # Estimate total cost to move data and get task running after currently-scheduled tasks - est_time_util = get(pressures, proc, UInt64(0)) - costs[proc] = est_time_util + (tx_cost/tx_rate) - end - - # Look up estimated task cost - sig = Sch.signature(sch_state, f, map(first, chunks_locality)) - task_pressure = get(sch_state.signature_time_cost, sig, 1000^3) - - # Shuffle procs around, so equally-costly procs are equally considered - P = randperm(length(all_procs)) - procs = getindex.(Ref(all_procs), P) - - # Sort by lowest cost first - sort!(procs, by=p->costs[p]) - - best_proc = first(procs) - return best_proc, task_pressure + end + elseif scheduler == :smart + raw_args = map(filter(arg->haskey(state.data_locality, value(arg)), spec.fargs)) do arg + arg_chunk = tochunk(value(arg)) + # Only the owned slot is valid + # FIXME: Track up-to-date copies and pass all of those + return arg_chunk => data_locality[arg] + end + f_chunk = tochunk(value(spec.fargs[1])) + our_proc, task_pressure = remotecall_fetch(1, all_procs, pressures, f_chunk, raw_args) do all_procs, pressures, f, chunks_locality + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + + @lock sch_state.lock begin + tx_rate = sch_state.transfer_rate[] + + costs = Dict{Processor,Float64}() + for proc in all_procs + # Filter out chunks that are already local + chunks_filt = Iterators.filter(((chunk, space)=chunk_locality)->!(proc in processors(space)), chunks_locality) + + # Estimate network transfer costs based on data size + # N.B. `affinity(x)` really means "data size of `x`" + # N.B. 
We treat same-worker transfers as having zero transfer cost + tx_cost = Sch.impute_sum(affinity(chunk)[2] for chunk in chunks_filt) + + # Estimate total cost to move data and get task running after currently-scheduled tasks + est_time_util = get(pressures, proc, UInt64(0)) + costs[proc] = est_time_util + (tx_cost/tx_rate) end + + # Look up estimated task cost + sig = Sch.signature(sch_state, f, map(first, chunks_locality)) + task_pressure = get(sch_state.signature_time_cost, sig, 1000^3) + + # Shuffle procs around, so equally-costly procs are equally considered + P = randperm(length(all_procs)) + procs = getindex.(Ref(all_procs), P) + + # Sort by lowest cost first + sort!(procs, by=p->costs[p]) + + best_proc = first(procs) + return best_proc, task_pressure end - # FIXME: Pressure should be decreased by pressure of syncdeps on same processor - pressures[our_proc] = get(pressures, our_proc, UInt64(0)) + task_pressure - elseif scheduler == :ultra - args = Base.mapany(spec.fargs) do arg - pos, data = arg - data, _ = unwrap_inout(data) - if data isa DTask - data = fetch(data; raw=true) - end - return pos => tochunk(data) + end + # FIXME: Pressure should be decreased by pressure of syncdeps on same processor + pressures[our_proc] = get(pressures, our_proc, UInt64(0)) + task_pressure + elseif scheduler == :ultra + args = Base.mapany(spec.fargs) do arg + pos, data = arg + data, _ = unwrap_inout(data) + if data isa DTask + data = fetch(data; move_value=false, unwrap=false) end - f_chunk = tochunk(value(spec.fargs[1])) - task_time = remotecall_fetch(1, f_chunk, args) do f, args - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - return @lock sch_state.lock begin - sig = Sch.signature(sch_state, f, args) - return get(sch_state.signature_time_cost, sig, 1000^3) - end + return pos => tochunk(data) + end + f_chunk = tochunk(value(spec.fargs[1])) + task_time = remotecall_fetch(1, f_chunk, args) do f, args + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + return @lock 
sch_state.lock begin + sig = Sch.signature(sch_state, f, args) + return get(sch_state.signature_time_cost, sig, 1000^3) end + end - # FIXME: Copy deps are computed eagerly - deps = @something(spec.options.syncdeps, Set{Any}()) - - # Find latest time-to-completion of all syncdeps - deps_completed = UInt64(0) - for dep in deps - haskey(sstate.task_completions, dep) || continue # copy deps aren't recorded - deps_completed = max(deps_completed, sstate.task_completions[dep]) - end + # FIXME: Copy deps are computed eagerly + deps = @something(spec.options.syncdeps, Set{Any}()) - # Find latest time-to-completion of each memory space - # FIXME: Figure out space completions based on optimal packing - spaces_completed = Dict{MemorySpace,UInt64}() - for space in exec_spaces - completed = UInt64(0) - for (task, other_space) in sstate.assignments - space == other_space || continue - completed = max(completed, sstate.task_completions[task]) - end - spaces_completed[space] = completed - end + # Find latest time-to-completion of all syncdeps + deps_completed = UInt64(0) + for dep in deps + haskey(sstate.task_completions, dep) || continue # copy deps aren't recorded + deps_completed = max(deps_completed, sstate.task_completions[dep]) + end - # Choose the earliest-available memory space and processor - # FIXME: Consider move time - move_time = UInt64(0) - local our_space_completed - while true - our_space_completed, our_space = findmin(spaces_completed) - our_space_procs = filter(proc->proc in all_procs, processors(our_space)) - if isempty(our_space_procs) - delete!(spaces_completed, our_space) - continue - end - our_proc = rand(our_space_procs) - break + # Find latest time-to-completion of each memory space + # FIXME: Figure out space completions based on optimal packing + spaces_completed = Dict{MemorySpace,UInt64}() + for space in exec_spaces + completed = UInt64(0) + for (task, other_space) in sstate.assignments + space == other_space || continue + completed = max(completed, 
sstate.task_completions[task]) end + spaces_completed[space] = completed + end - sstate.task_to_spec[task] = spec - sstate.assignments[task] = our_space - sstate.task_completions[task] = our_space_completed + move_time + task_time - elseif scheduler == :roundrobin - our_proc = all_procs[proc_idx] - if task_scope == scope - # all_procs is already limited to scope - else - if isa(constrain(task_scope, scope), InvalidScope) - throw(Sch.SchedulingException("Scopes are not compatible: $(scope), $(task_scope)")) - end - while !proc_in_scope(our_proc, task_scope) - proc_idx = mod1(proc_idx + 1, length(all_procs)) - our_proc = all_procs[proc_idx] - end + # Choose the earliest-available memory space and processor + # FIXME: Consider move time + move_time = UInt64(0) + local our_space_completed + while true + our_space_completed, our_space = findmin(spaces_completed) + our_space_procs = filter(proc->proc in all_procs, processors(our_space)) + if isempty(our_space_procs) + delete!(spaces_completed, our_space) + continue end - else - error("Invalid scheduler: $sched") + our_proc = rand(our_space_procs) + break end - @assert our_proc in all_procs - our_space = only(memory_spaces(our_proc)) - # Find the scope for this task (and its copies) + sstate.task_to_spec[task] = spec + sstate.assignments[task] = our_space + sstate.task_completions[task] = our_space_completed + move_time + task_time + elseif scheduler == :roundrobin + our_proc = all_procs[proc_idx] if task_scope == scope - # Optimize for the common case, cache the proc=>scope mapping - our_scope = get!(proc_to_scope_lfu, our_proc) do - our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) - return constrain(UnionScope(map(ExactScope, our_procs)...), scope) - end + # all_procs is already limited to scope else - # Use the provided scope and constrain it to the available processors - our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) - our_scope = 
constrain(UnionScope(map(ExactScope, our_procs)...), task_scope) + if isa(constrain(task_scope, scope), InvalidScope) + throw(Sch.SchedulingException("Scopes are not compatible: $(scope), $(task_scope)")) + end + while !proc_in_scope(our_proc, task_scope) + proc_idx = mod1(proc_idx + 1, length(all_procs)) + our_proc = all_procs[proc_idx] + end end - if our_scope isa InvalidScope - throw(Sch.SchedulingException("Scopes are not compatible: $(our_scope.x), $(our_scope.y)")) + else + error("Invalid scheduler: $sched") + end + @assert our_proc in all_procs + our_space = only(memory_spaces(our_proc)) + + # Find the scope for this task (and its copies) + task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) + if task_scope == scope + # Optimize for the common case, cache the proc=>scope mapping + our_scope = get!(proc_to_scope_lfu, our_proc) do + our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) + return constrain(UnionScope(map(ExactScope, our_procs)...), scope) end + else + # Use the provided scope and constrain it to the available processors + our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) + our_scope = constrain(UnionScope(map(ExactScope, our_procs)...), task_scope) + end + if our_scope isa InvalidScope + throw(Sch.SchedulingException("Scopes are not compatible: $(our_scope.x), $(our_scope.y)")) + end - f = spec.fargs[1] - f.value = move(ThreadProc(myid(), 1), our_proc, value(f)) - @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" - - # Copy raw task arguments for analysis - task_args = map(copy, spec.fargs) + f = spec.fargs[1] + # FIXME: May not be correct to move this under uniformity + #f.value = move(default_processor(), our_proc, value(f)) + @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" - # Generate a list of ArgumentWrappers for each task argument - task_arg_ws = map(task_args) do _arg - 
arg = value(_arg) - arg, deps = unwrap_inout(arg) - arg = arg isa DTask ? fetch(arg; raw=true) : arg - if !type_may_alias(typeof(arg)) || !supports_inplace_move(state, arg) - return [(ArgumentWrapper(arg, identity), false, false)] - end + # Copy raw task arguments for analysis + # N.B. Used later for checking dependencies + task_args = map_or_ntuple(idx->copy(spec.fargs[idx]), spec.fargs) - # Get the Chunk for the argument - arg = state.raw_arg_to_chunk[arg] + # Populate all task dependencies + task_arg_ws = populate_task_info!(state, task_args, spec, task) - arg_ws = Tuple{ArgumentWrapper,Bool,Bool}[] - for (dep_mod, readdep, writedep) in deps - push!(arg_ws, (ArgumentWrapper(arg, dep_mod), readdep, writedep)) - end - return arg_ws + # Truncate the history for each argument + map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + map_or_ntuple(arg_ws.deps) do dep_idx + dep = arg_ws.deps[dep_idx] + truncate_history!(state, dep.arg_w) end - task_arg_ws = task_arg_ws::Vector{Vector{Tuple{ArgumentWrapper,Bool,Bool}}} + return + end - # Truncate the history for each argument - for arg_ws in task_arg_ws - for (arg_w, _, _) in arg_ws - truncate_history!(state, arg_w) - end + # Copy args from local to remote + remote_args = map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + arg = arg_ws.arg + pos = raw_position(arg_ws.pos) + + # Is the data written previously or now? + if !arg_ws.may_alias + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (immutable)" + return arg end - # Copy args from local to remote - for (idx, arg_ws) in enumerate(task_arg_ws) - arg = first(arg_ws)[1].arg - pos = raw_position(task_args[idx]) + # Is the data writeable? + if !arg_ws.inplace_move + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (non-writeable)" + return arg + end - # Is the data written previously or now? 
- if !type_may_alias(typeof(arg)) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (immutable)" - spec.fargs[idx].value = arg - continue + # Is the source of truth elsewhere? + arg_remote = get_or_generate_slot!(state, our_space, arg) + map_or_ntuple(arg_ws.deps) do dep_idx + dep = arg_ws.deps[dep_idx] + arg_w = dep.arg_w + dep_mod = arg_w.dep_mod + remainder, _ = compute_remainder_for_arg!(state, our_space, arg_w, write_num) + if remainder isa MultiRemainderAliasing + enqueue_remainder_copy_to!(state, our_space, arg_w, remainder, value(f), idx, our_scope, task, write_num) + elseif remainder isa FullCopy + enqueue_copy_to!(state, our_space, arg_w, value(f), idx, our_scope, task, write_num) + else + @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Skipped copy-to (up-to-date): $our_space" end + end + return arg_remote + end + write_num += 1 - # Is the data writeable? - if !supports_inplace_move(state, arg) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (non-writeable)" - spec.fargs[idx].value = arg - continue - end + # Validate that we're not accidentally performing a copy + map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + arg = remote_args[idx] - # Is the source of truth elsewhere? 
- arg_remote = get_or_generate_slot!(state, our_space, arg) - for (arg_w, _, _) in arg_ws - dep_mod = arg_w.dep_mod - remainder, _ = compute_remainder_for_arg!(state, our_space, arg_w, write_num) - if remainder isa MultiRemainderAliasing - enqueue_remainder_copy_to!(state, our_space, arg_w, remainder, value(f), idx, our_scope, task, write_num) - elseif remainder isa FullCopy - enqueue_copy_to!(state, our_space, arg_w, value(f), idx, our_scope, task, write_num) - else - @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Skipped copy-to (up-to-date): $our_space" - end - end - spec.fargs[idx].value = arg_remote - end - write_num += 1 - - # Validate that we're not accidentally performing a copy - for (idx, _arg) in enumerate(spec.fargs) - arg = value(_arg) - _, deps = unwrap_inout(value(task_args[idx])) - # N.B. We only do this check when the argument supports in-place - # moves, because for the moment, we are not guaranteeing updates or - # write-back of results - if is_writedep(arg, deps, task) && supports_inplace_move(state, arg) - arg_space = memory_space(arg) - @assert arg_space == our_space "($(repr(value(f))))[$(idx-1)] Tried to pass $(typeof(arg)) from $arg_space to $our_space" - end + # Get the dependencies again as (dep_mod, readdep, writedep) + deps = map_or_ntuple(arg_ws.deps) do dep_idx + dep = arg_ws.deps[dep_idx] + (dep.arg_w.dep_mod, dep.readdep, dep.writedep) end - # Calculate this task's syncdeps - if spec.options.syncdeps === nothing - spec.options.syncdeps = Set{Any}() - end - syncdeps = spec.options.syncdeps - for (idx, arg_ws) in enumerate(task_arg_ws) - arg = first(arg_ws)[1].arg - type_may_alias(typeof(arg)) || continue - supports_inplace_move(state, arg) || continue - for (arg_w, _, writedep) in arg_ws - ainfo = aliasing!(state, our_space, arg_w) - dep_mod = arg_w.dep_mod - if writedep - @dagdebug nothing :spawn_datadeps 
"($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as writer" - get_write_deps!(state, our_space, ainfo, write_num, syncdeps) - else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as reader" - get_read_deps!(state, our_space, ainfo, write_num, syncdeps) - end - end + # Check that any mutable and written arguments are already in the correct space + # N.B. We only do this check when the argument supports in-place + # moves, because for the moment, we are not guaranteeing updates or + # write-back of results + if is_writedep(arg, deps, task) && arg_ws.may_alias && arg_ws.inplace_move + arg_space = memory_space(arg) + @assert arg_space == our_space "($(repr(value(f))))[$(idx-1)] Tried to pass $(typeof(arg)) from $arg_space to $our_space" end - @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Task has $(length(syncdeps)) syncdeps" - - # Launch user's task - spec.options.scope = our_scope - spec.options.exec_scope = our_scope - enqueue!(upper_queue, spec=>task) - - # Update read/write tracking for arguments - for (idx, arg_ws) in enumerate(task_arg_ws) - arg = first(arg_ws)[1].arg - type_may_alias(typeof(arg)) || continue - for (arg_w, _, writedep) in arg_ws - ainfo = aliasing!(state, our_space, arg_w) - dep_mod = arg_w.dep_mod - if writedep - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Task set as writer" - add_writer!(state, arg_w, our_space, ainfo, task, write_num) - else - add_reader!(state, arg_w, our_space, ainfo, task, write_num) - end + end + + # Calculate this task's syncdeps + if spec.options.syncdeps === nothing + spec.options.syncdeps = Set{ThunkSyncdep}() + end + syncdeps = spec.options.syncdeps + map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + arg = arg_ws.arg + arg_ws.may_alias || return + arg_ws.inplace_move || return + map_or_ntuple(arg_ws.deps) do dep_idx + dep = arg_ws.deps[dep_idx] + arg_w = dep.arg_w + ainfo = aliasing!(state, our_space, arg_w) + dep_mod = 
arg_w.dep_mod + if dep.writedep + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as writer" + get_write_deps!(state, our_space, ainfo, write_num, syncdeps) + else + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as reader" + get_read_deps!(state, our_space, ainfo, write_num, syncdeps) end end - - write_num += 1 - proc_idx = mod1(proc_idx + 1, length(all_procs)) + return end + @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Task has $(length(syncdeps)) syncdeps" - # Copy args from remote to local - # N.B. We sort the keys to ensure a deterministic order for uniformity - for arg_w in sort(collect(keys(state.arg_owner)); by=arg_w->arg_w.hash) - arg = arg_w.arg - origin_space = state.arg_origin[arg] - remainder, _ = compute_remainder_for_arg!(state, origin_space, arg_w, write_num) - if remainder isa MultiRemainderAliasing - origin_scope = UnionScope(map(ExactScope, collect(processors(origin_space)))...) - enqueue_remainder_copy_from!(state, origin_space, arg_w, remainder, origin_scope, write_num) - elseif remainder isa FullCopy - origin_scope = UnionScope(map(ExactScope, collect(processors(origin_space)))...) 
- enqueue_copy_from!(state, origin_space, arg_w, origin_scope, write_num) + # Launch user's task + new_fargs = map_or_ntuple(task_arg_ws) do idx + if is_typed(spec) + return TypedArgument(task_arg_ws[idx].pos, remote_args[idx]) else - @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" - @dagdebug nothing :spawn_datadeps "Skipped copy-from (up-to-date): $origin_space" + return Argument(task_arg_ws[idx].pos, remote_args[idx]) end end + new_spec = DTaskSpec(new_fargs, spec.options) + new_spec.options.scope = our_scope + new_spec.options.exec_scope = our_scope + new_spec.options.occupancy = Dict(Any=>0) + enqueue!(queue.upper_queue, DTaskPair(new_spec, task)) + + # Update read/write tracking for arguments + map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + arg = arg_ws.arg + arg_ws.may_alias || return + arg_ws.inplace_move || return + for dep in arg_ws.deps + arg_w = dep.arg_w + ainfo = aliasing!(state, our_space, arg_w) + dep_mod = arg_w.dep_mod + if dep.writedep + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Task set as writer" + add_writer!(state, arg_w, our_space, ainfo, task, write_num) + else + add_reader!(state, arg_w, our_space, ainfo, task, write_num) + end + end + return + end + + write_num += 1 + proc_idx = mod1(proc_idx + 1, length(all_procs)) + + return write_num, proc_idx end diff --git a/src/queue.jl b/src/queue.jl index c8c6007ec..37947a0ac 100644 --- a/src/queue.jl +++ b/src/queue.jl @@ -1,32 +1,63 @@ -mutable struct DTaskSpec - fargs::Vector{Argument} +mutable struct DTaskSpec{typed,FA<:Tuple} + _fargs::Vector{Argument} + _typed_fargs::FA options::Options end +DTaskSpec(fargs::Vector{Argument}, options::Options) = + DTaskSpec{false, Tuple{}}(fargs, (), options) +DTaskSpec(fargs::FA, options::Options) where FA = + DTaskSpec{true, FA}(Argument[], fargs, options) +is_typed(spec::DTaskSpec{typed}) where typed = typed +function Base.getproperty(spec::DTaskSpec{typed}, field::Symbol) 
where typed + if field === :fargs + if typed + return getfield(spec, :_typed_fargs) + else + return getfield(spec, :_fargs) + end + else + return getfield(spec, field) + end +end + +struct DTaskPair + spec::DTaskSpec + task::DTask +end +is_typed(pair::DTaskPair) = is_typed(pair.spec) +Base.iterate(pair::DTaskPair) = (pair.spec, true) +function Base.iterate(pair::DTaskPair, state::Bool) + if state + return (pair.task, false) + else + return nothing + end +end abstract type AbstractTaskQueue end function enqueue! end struct DefaultTaskQueue <: AbstractTaskQueue end -enqueue!(::DefaultTaskQueue, spec::Pair{DTaskSpec,DTask}) = - eager_launch!(spec) -enqueue!(::DefaultTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) = - eager_launch!(specs) +enqueue!(::DefaultTaskQueue, pair::DTaskPair) = + eager_launch!(pair) +enqueue!(::DefaultTaskQueue, pairs::Vector{DTaskPair}) = + eager_launch!(pairs) -enqueue!(spec::Pair{DTaskSpec,DTask}) = - enqueue!(get_options(:task_queue, DefaultTaskQueue()), spec) -enqueue!(specs::Vector{Pair{DTaskSpec,DTask}}) = - enqueue!(get_options(:task_queue, DefaultTaskQueue()), specs) +enqueue!(pair::DTaskPair) = + enqueue!(get_options(:task_queue, DefaultTaskQueue()), pair) +enqueue!(pairs::Vector{DTaskPair}) = + enqueue!(get_options(:task_queue, DefaultTaskQueue()), pairs) struct LazyTaskQueue <: AbstractTaskQueue - tasks::Vector{Pair{DTaskSpec,DTask}} - LazyTaskQueue() = new(Pair{DTaskSpec,DTask}[]) + tasks::Vector{DTaskPair} + LazyTaskQueue() = new(DTaskPair[]) end -function enqueue!(queue::LazyTaskQueue, spec::Pair{DTaskSpec,DTask}) - push!(queue.tasks, spec) +function enqueue!(queue::LazyTaskQueue, pair::DTaskPair) + push!(queue.tasks, pair) end -function enqueue!(queue::LazyTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) - append!(queue.tasks, specs) +function enqueue!(queue::LazyTaskQueue, pairs::Vector{DTaskPair}) + append!(queue.tasks, pairs) end function spawn_bulk(f::Base.Callable) queue = LazyTaskQueue() @@ -50,25 +81,25 @@ function 
_add_prev_deps!(queue::InOrderTaskQueue, spec::DTaskSpec) push!(syncdeps, ThunkSyncdep(task)) end end -function enqueue!(queue::InOrderTaskQueue, spec::Pair{DTaskSpec,DTask}) +function enqueue!(queue::InOrderTaskQueue, pair::DTaskPair) if length(queue.prev_tasks) > 0 - _add_prev_deps!(queue, first(spec)) + _add_prev_deps!(queue, pair.spec) empty!(queue.prev_tasks) end - push!(queue.prev_tasks, last(spec)) - enqueue!(queue.upper_queue, spec) + push!(queue.prev_tasks, pair.task) + enqueue!(queue.upper_queue, pair) end -function enqueue!(queue::InOrderTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) +function enqueue!(queue::InOrderTaskQueue, pairs::Vector{DTaskPair}) if length(queue.prev_tasks) > 0 - for (spec, task) in specs - _add_prev_deps!(queue, spec) + for pair in pairs + _add_prev_deps!(queue, pair.spec) end empty!(queue.prev_tasks) end - for (spec, task) in specs - push!(queue.prev_tasks, task) + for pair in pairs + push!(queue.prev_tasks, pair.task) end - enqueue!(queue.upper_queue, specs) + enqueue!(queue.upper_queue, pairs) end function spawn_sequential(f::Base.Callable) queue = InOrderTaskQueue(get_options(:task_queue, DefaultTaskQueue())) @@ -79,15 +110,15 @@ struct WaitAllQueue <: AbstractTaskQueue upper_queue::AbstractTaskQueue tasks::Vector{DTask} end -function enqueue!(queue::WaitAllQueue, spec::Pair{DTaskSpec,DTask}) - push!(queue.tasks, spec[2]) - enqueue!(queue.upper_queue, spec) +function enqueue!(queue::WaitAllQueue, pair::DTaskPair) + push!(queue.tasks, pair.task) + enqueue!(queue.upper_queue, pair) end -function enqueue!(queue::WaitAllQueue, specs::Vector{Pair{DTaskSpec,DTask}}) - for (_, task) in specs - push!(queue.tasks, task) +function enqueue!(queue::WaitAllQueue, pairs::Vector{DTaskPair}) + for pair in pairs + push!(queue.tasks, pair.task) end - enqueue!(queue.upper_queue, specs) + enqueue!(queue.upper_queue, pairs) end function wait_all(f; check_errors::Bool=false) queue = WaitAllQueue(get_options(:task_queue, DefaultTaskQueue()), 
DTask[]) diff --git a/src/stream.jl b/src/stream.jl index bf1ea4537..c5e6641f6 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -372,21 +372,21 @@ function migrate_stream!(stream::Stream, w::Integer=myid()) end struct StreamingTaskQueue <: AbstractTaskQueue - tasks::Vector{Pair{DTaskSpec,DTask}} + tasks::Vector{DTaskPair} self_streams::Dict{UInt,Any} - StreamingTaskQueue() = new(Pair{DTaskSpec,DTask}[], + StreamingTaskQueue() = new(DTaskPair[], Dict{UInt,Any}()) end -function enqueue!(queue::StreamingTaskQueue, spec::Pair{DTaskSpec,DTask}) - push!(queue.tasks, spec) - initialize_streaming!(queue.self_streams, spec...) +function enqueue!(queue::StreamingTaskQueue, pair::DTaskPair) + push!(queue.tasks, pair) + initialize_streaming!(queue.self_streams, pair.spec, pair.task) end -function enqueue!(queue::StreamingTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) - append!(queue.tasks, specs) - for (spec, task) in specs - initialize_streaming!(queue.self_streams, spec, task) +function enqueue!(queue::StreamingTaskQueue, pairs::Vector{DTaskPair}) + append!(queue.tasks, pairs) + for pair in pairs + initialize_streaming!(queue.self_streams, pair.spec, pair.task) end end @@ -458,7 +458,7 @@ function spawn_streaming(f::Base.Callable; teardown::Bool=true) if teardown # Start teardown monitor - dtasks = map(last, queue.tasks)::Vector{DTask} + dtasks = map(pair->pair.task, queue.tasks)::Vector{DTask} Sch.errormonitor_tracked("streaming teardown", Threads.@spawn begin # Wait for any task to finish waitany(dtasks) @@ -663,10 +663,12 @@ function task_to_stream(uid::UInt) end end -function finalize_streaming!(tasks::Vector{Pair{DTaskSpec,DTask}}, self_streams) +function finalize_streaming!(tasks::Vector{DTaskPair}, self_streams) stream_waiter_changes = Dict{UInt,Vector{Pair{UInt,Any}}}() - for (spec, task) in tasks + for pair in tasks + spec = pair.spec + task = pair.task @assert haskey(self_streams, task.uid) our_stream = self_streams[task.uid] diff --git a/src/submission.jl 
b/src/submission.jl index 2e7b1c836..4ff4f2294 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -268,24 +268,29 @@ function eager_process_elem_submission_to_local!(id_map, arg::Argument) arg.value = Sch.ThunkID(id_map[value(arg).uid], value(arg).thunk_ref) end end -function eager_process_args_submission_to_local!(id_map, spec_pair::Pair{DTaskSpec,DTask}) - spec, task = spec_pair +function eager_process_elem_submission_to_local(id_map, arg::TypedArgument{T}) where T + @assert !(T <: Thunk) "Cannot use `Thunk`s in `@spawn`/`spawn`" + if T <: DTask && haskey(id_map, (value(arg)::DTask).uid) + #=FIXME:UNIQUE=# + return Sch.ThunkID(id_map[value(arg).uid], value(arg).thunk_ref) + end + return arg +end +function eager_process_args_submission_to_local!(id_map, spec::DTaskSpec{false}) for arg in spec.fargs eager_process_elem_submission_to_local!(id_map, arg) end end -function eager_process_args_submission_to_local!(id_map, spec_pairs::Vector{Pair{DTaskSpec,DTask}}) - for spec_pair in spec_pairs - eager_process_args_submission_to_local!(id_map, spec_pair) - end +function eager_process_args_submission_to_local(id_map, spec::DTaskSpec{true}) + return ntuple(i->eager_process_elem_submission_to_local(id_map, spec.fargs[i]), length(spec.fargs)) end -function DTaskMetadata(spec::DTaskSpec) - f = value(spec.fargs[1]) +DTaskMetadata(spec::DTaskSpec) = DTaskMetadata(eager_metadata(spec.fargs)) +function eager_metadata(fargs) + f = value(fargs[1]) f = f isa StreamingFunction ? f.f : f - arg_types = ntuple(i->chunktype(value(spec.fargs[i+1])), length(spec.fargs)-1) - return_type = Base.promote_op(f, arg_types...) - return DTaskMetadata(return_type) + arg_types = ntuple(i->chunktype(value(fargs[i+1])), length(fargs)-1) + return Base.promote_op(f, arg_types...) 
end function eager_spawn(spec::DTaskSpec) @@ -298,48 +303,64 @@ end chunktype(t::DTask) = t.metadata.return_type -function eager_launch!((spec, task)::Pair{DTaskSpec,DTask}) +function eager_launch!(pair::DTaskPair) + spec = pair.spec + task = pair.task + # Assign a name, if specified eager_assign_name!(spec, task) # Lookup DTask -> ThunkID - lock(Sch.EAGER_ID_MAP) do id_map - eager_process_args_submission_to_local!(id_map, spec=>task) + fargs = lock(Sch.EAGER_ID_MAP) do id_map + if is_typed(spec) + return Argument[map(Argument, eager_process_args_submission_to_local(id_map, spec))...] + else + eager_process_args_submission_to_local!(id_map, spec) + return spec.fargs + end end # Submit the task #=FIXME:REALLOC=# thunk_id = eager_submit!(PayloadOne(task.uid, task.future, - spec.fargs, spec.options, true)) + fargs, spec.options, true)) task.thunk_ref = thunk_id.ref end -function eager_launch!(specs::Vector{Pair{DTaskSpec,DTask}}) - ntasks = length(specs) +# FIXME: Don't convert Tuple to Vector{Argument} +function eager_launch!(pairs::Vector{DTaskPair}) + ntasks = length(pairs) # Assign a name, if specified - for (spec, task) in specs - eager_assign_name!(spec, task) + for pair in pairs + eager_assign_name!(pair.spec, pair.task) end #=FIXME:REALLOC_N=# - uids = [task.uid for (_, task) in specs] - futures = [task.future for (_, task) in specs] + uids = [pair.task.uid for pair in pairs] + futures = [pair.task.future for pair in pairs] # Get all functions, args/kwargs, and options #=FIXME:REALLOC_N=# all_fargs = lock(Sch.EAGER_ID_MAP) do id_map # Lookup DTask -> ThunkID - eager_process_args_submission_to_local!(id_map, specs) - [spec.fargs for (spec, _) in specs] + return map(pairs) do pair + spec = pair.spec + if is_typed(spec) + return Argument[map(Argument, eager_process_args_submission_to_local(id_map, spec))...] 
+ else + eager_process_args_submission_to_local!(id_map, spec) + return spec.fargs + end + end end - all_options = Options[spec.options for (spec, _) in specs] + all_options = Options[pair.spec.options for pair in pairs] # Submit the tasks #=FIXME:REALLOC=# thunk_ids = eager_submit!(PayloadMulti(ntasks, uids, futures, all_fargs, all_options, true)) for i in 1:ntasks - task = specs[i][2] + task = pairs[i].task task.thunk_ref = thunk_ids[i].ref end end diff --git a/src/thunk.jl b/src/thunk.jl index 482d66209..d1701e3ef 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -186,21 +186,19 @@ function args_kwargs_to_arguments(f, args, kwargs) end return args_kwargs end -function args_kwargs_to_arguments(f, args) - @nospecialize f args - args_kwargs = Argument[] - push!(args_kwargs, Argument(ArgPosition(true, 0, :NULL), f)) - pos_ctr = 1 - for idx in 1:length(args) - pos, arg = args[idx]::Pair - if pos === nothing - push!(args_kwargs, Argument(pos_ctr, arg)) - pos_ctr += 1 +function args_kwargs_to_typedarguments(f, args, kwargs) + nargs = 1 + length(args) + length(kwargs) + return ntuple(nargs) do idx + if idx == 1 + return TypedArgument(ArgPosition(true, 0, :NULL), f) + elseif idx in 2:(1+length(args)) + arg = args[idx-1] + return TypedArgument(idx, arg) else - push!(args_kwargs, Argument(pos, arg)) + kw, value = kwargs[idx-length(args)-1] + return TypedArgument(kw, value) end end - return args_kwargs end """ @@ -491,7 +489,11 @@ function _par(mod, ex::Expr; lazy=true, recur=true, opts=()) @gensym result return quote let - $result = $spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...)) + $result = if $get_task_typed() + $typed_spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...)) + else + $spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...)) + end if $(Expr(:islocal, sync_var)) put!($sync_var, schedule(Task(()->fetch($result; raw=true)))) end @@ -516,6 +518,9 @@ function _setindex!_return_value(A, value, idxs...) 
return value end +const TASK_TYPED = ScopedValue{Bool}(false) +get_task_typed() = TASK_TYPED[] + """ Dagger.spawn(f, args...; kwargs...) -> DTask @@ -526,6 +531,36 @@ Spawns a `DTask` that will call `f(args...; kwargs...)`. Also supports passing a function spawn(f, args...; kwargs...) @nospecialize f args kwargs + # Merge all passed options + if length(args) >= 1 && first(args) isa Options + # N.B. Make a defensive copy in case user aliases Options struct + task_options = copy(first(args)::Options) + args = args[2:end] + else + task_options = Options() + end + + # Process the args and kwargs into Argument form + args_kwargs = args_kwargs_to_arguments(f, args, kwargs) + + return _spawn(args_kwargs, task_options) +end +function typed_spawn(f, args...; kwargs...) + # Merge all passed options + if length(args) >= 1 && first(args) isa Options + # N.B. Make a defensive copy in case user aliases Options struct + task_options = copy(first(args)::Options) + args = args[2:end] + else + task_options = Options() + end + + # Process the args and kwargs into Tuple of TypedArgument form + args_kwargs = args_kwargs_to_typedarguments(f, args, kwargs) + + return _spawn(args_kwargs, task_options) +end +function _spawn(args_kwargs, task_options) # Get all scoped options and determine which propagate beyond this task scoped_options = get_options()::NamedTuple if haskey(scoped_options, :propagates) @@ -539,20 +574,9 @@ function spawn(f, args...; kwargs...) end append!(propagates, keys(scoped_options)::NTuple{N,Symbol} where N) - # Merge all passed options - if length(args) >= 1 && first(args) isa Options - # N.B. Make a defensive copy in case user aliases Options struct - task_options = copy(first(args)::Options) - args = args[2:end] - else - task_options = Options() - end # N.B. 
Merges into task_options options_merge!(task_options, scoped_options; override=false) - # Process the args and kwargs into Pair form - args_kwargs = args_kwargs_to_arguments(f, args, kwargs) - # Get task queue, and don't let it propagate task_queue = get(scoped_options, :task_queue, DefaultTaskQueue())::AbstractTaskQueue filter!(prop -> prop != :task_queue, propagates) @@ -568,7 +592,7 @@ function spawn(f, args...; kwargs...) task = eager_spawn(spec) # Enqueue the task into the task queue - enqueue!(task_queue, spec=>task) + enqueue!(task_queue, DTaskPair(spec, task)) return task end From 7c13c85c39ebee510ecafc40e9a8813e64d97c02 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 15 Oct 2025 19:50:51 -0700 Subject: [PATCH 03/28] datadeps: Optimize ainfo aliasing lookups --- src/Dagger.jl | 4 +- src/datadeps/aliasing.jl | 40 ++-- src/memory-spaces.jl | 271 +++++++++++++++++------ src/{datadeps => utils}/interval_tree.jl | 96 ++++---- src/utils/memory-span.jl | 98 ++++++++ 5 files changed, 381 insertions(+), 128 deletions(-) rename src/{datadeps => utils}/interval_tree.jl (81%) create mode 100644 src/utils/memory-span.jl diff --git a/src/Dagger.jl b/src/Dagger.jl index 987963b34..102a76149 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -73,6 +73,9 @@ include("utils/fetch.jl") include("utils/chunks.jl") include("utils/logging.jl") include("submission.jl") +abstract type MemorySpace end +include("utils/memory-span.jl") +include("utils/interval_tree.jl") include("memory-spaces.jl") # Task scheduling @@ -85,7 +88,6 @@ include("sch/Sch.jl"); using .Sch # Data dependency task queue include("datadeps/aliasing.jl") include("datadeps/chunkview.jl") -include("datadeps/interval_tree.jl") include("datadeps/remainders.jl") include("datadeps/queue.jl") diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index 164944f13..6482492a9 100644 --- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -288,6 +288,10 @@ struct DataDepsState # N.B. 
This is a mapping for remote argument copies ainfo_cache::Dict{ArgumentWrapper,AliasingWrapper} + # The oracle for aliasing lookups + # Used to populate ainfos_overlaps efficiently + ainfos_lookup::AliasingLookup + # The overlapping ainfos for each ainfo # Incrementally updated as new ainfos are created # Used for fast will_alias lookups @@ -317,13 +321,14 @@ struct DataDepsState supports_inplace_cache = IdDict{Any,Bool}() ainfo_cache = Dict{ArgumentWrapper,AliasingWrapper}() + ainfos_lookup = AliasingLookup() ainfos_overlaps = Dict{AliasingWrapper,Set{AliasingWrapper}}() ainfos_owner = Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}() ainfos_readers = Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}() return new(arg_to_chunk, arg_origin, remote_args, remote_arg_to_original, ainfo_arg, arg_owner, arg_overlaps, ainfo_backing_chunk, arg_history, - supports_inplace_cache, ainfo_cache, ainfos_overlaps, ainfos_owner, ainfos_readers) + supports_inplace_cache, ainfo_cache, ainfos_lookup, ainfos_overlaps, ainfos_owner, ainfos_readers) end end @@ -452,27 +457,30 @@ function populate_argument_info!(state::DataDepsState, arg_w::ArgumentWrapper, o aliasing!(state, origin_space, arg_w) end function populate_ainfo!(state::DataDepsState, original_arg_w::ArgumentWrapper, target_ainfo::AliasingWrapper, target_space::MemorySpace) - # Initialize owner and readers if !haskey(state.ainfos_owner, target_ainfo) + # Add ourselves to the lookup oracle + ainfo_idx = push!(state.ainfos_lookup, target_ainfo) + + # Find overlapping ainfos overlaps = Set{AliasingWrapper}() push!(overlaps, target_ainfo) - for other_ainfo in keys(state.ainfos_owner) + for other_ainfo in intersect(state.ainfos_lookup, target_ainfo; ainfo_idx) target_ainfo == other_ainfo && continue - if will_alias(target_ainfo, other_ainfo) - # Mark us and them as overlapping - push!(overlaps, other_ainfo) - push!(state.ainfos_overlaps[other_ainfo], target_ainfo) - - # Add overlapping history to our own - other_remote_arg_w = 
state.ainfo_arg[other_ainfo] - other_arg = state.remote_arg_to_original[other_remote_arg_w.arg] - other_arg_w = ArgumentWrapper(other_arg, other_remote_arg_w.dep_mod) - push!(state.arg_overlaps[original_arg_w], other_arg_w) - push!(state.arg_overlaps[other_arg_w], original_arg_w) - merge_history!(state, original_arg_w, other_arg_w) - end + # Mark us and them as overlapping + push!(overlaps, other_ainfo) + push!(state.ainfos_overlaps[other_ainfo], target_ainfo) + + # Add overlapping history to our own + other_remote_arg_w = state.ainfo_arg[other_ainfo] + other_arg = state.remote_arg_to_original[other_remote_arg_w.arg] + other_arg_w = ArgumentWrapper(other_arg, other_remote_arg_w.dep_mod) + push!(state.arg_overlaps[original_arg_w], other_arg_w) + push!(state.arg_overlaps[other_arg_w], original_arg_w) + merge_history!(state, original_arg_w, other_arg_w) end state.ainfos_overlaps[target_ainfo] = overlaps + + # Initialize owner and readers state.ainfos_owner[target_ainfo] = nothing state.ainfos_readers[target_ainfo] = Pair{DTask,Int}[] end diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index b1ff40d8f..fcc3dbf0b 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -1,5 +1,3 @@ -abstract type MemorySpace end - struct CPURAMMemorySpace <: MemorySpace owner::Int end @@ -92,46 +90,16 @@ end may_alias(::MemorySpace, ::MemorySpace) = true may_alias(space1::CPURAMMemorySpace, space2::CPURAMMemorySpace) = space1.owner == space2.owner -struct RemotePtr{T,S<:MemorySpace} <: Ref{T} - addr::UInt - space::S -end -RemotePtr{T}(addr::UInt, space::S) where {T,S} = RemotePtr{T,S}(addr, space) -RemotePtr{T}(ptr::Ptr{V}, space::S) where {T,V,S} = RemotePtr{T,S}(UInt(ptr), space) -RemotePtr{T}(ptr::Ptr{V}) where {T,V} = RemotePtr{T}(UInt(ptr), CPURAMMemorySpace(myid())) -# FIXME: Don't hardcode CPURAMMemorySpace -RemotePtr(addr::UInt) = RemotePtr{Cvoid}(addr, CPURAMMemorySpace(myid())) -Base.convert(::Type{RemotePtr}, x::Ptr{T}) where T = - RemotePtr(UInt(x), 
CPURAMMemorySpace(myid())) -Base.convert(::Type{<:RemotePtr{V}}, x::Ptr{T}) where {V,T} = - RemotePtr{V}(UInt(x), CPURAMMemorySpace(myid())) -Base.convert(::Type{UInt}, ptr::RemotePtr) = ptr.addr -Base.:+(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr + offset, ptr.space) -Base.:-(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr - offset, ptr.space) -function Base.isless(ptr1::RemotePtr, ptr2::RemotePtr) - @assert ptr1.space == ptr2.space - return ptr1.addr < ptr2.addr -end - -struct MemorySpan{S} - ptr::RemotePtr{Cvoid,S} - len::UInt -end -MemorySpan(ptr::RemotePtr{Cvoid,S}, len::Integer) where S = - MemorySpan{S}(ptr, UInt(len)) -MemorySpan{S}(addr::UInt, len::Integer) where S = - MemorySpan{S}(RemotePtr{Cvoid,S}(addr), UInt(len)) -Base.isless(a::MemorySpan, b::MemorySpan) = a.ptr < b.ptr -Base.isempty(x::MemorySpan) = x.len == 0 abstract type AbstractAliasing end memory_spans(::T) where T<:AbstractAliasing = throw(ArgumentError("Must define `memory_spans` for `$T`")) memory_spans(x) = memory_spans(aliasing(x)) memory_spans(x, T) = memory_spans(aliasing(x, T)) -struct AliasingWrapper <: AbstractAliasing +### Type-generic aliasing info wrapper + +mutable struct AliasingWrapper <: AbstractAliasing inner::AbstractAliasing hash::UInt64 - AliasingWrapper(inner::AbstractAliasing) = new(inner, hash(inner)) end memory_spans(x::AliasingWrapper) = memory_spans(x.inner) @@ -140,8 +108,204 @@ equivalent_structure(x::AliasingWrapper, y::AliasingWrapper) = Base.hash(x::AliasingWrapper, h::UInt64) = hash(x.hash, h) Base.isequal(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash Base.:(==)(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash -will_alias(x::AliasingWrapper, y::AliasingWrapper) = - will_alias(x.inner, y.inner) +will_alias(x::AliasingWrapper, y::AliasingWrapper) = will_alias(x.inner, y.inner) + +### Small dictionary type + +struct SmallDict{K,V} <: AbstractDict{K,V} + keys::Vector{K} + vals::Vector{V} +end 
+SmallDict{K,V}() where {K,V} = SmallDict{K,V}(Vector{K}(), Vector{V}()) +function Base.getindex(d::SmallDict{K,V}, key) where {K,V} + key_idx = findfirst(==(convert(K, key)), d.keys) + if key_idx === nothing + throw(KeyError(key)) + end + return @inbounds d.vals[key_idx] +end +function Base.setindex!(d::SmallDict{K,V}, val, key) where {K,V} + key_conv = convert(K, key) + key_idx = findfirst(==(key_conv), d.keys) + if key_idx === nothing + push!(d.keys, key_conv) + push!(d.vals, convert(V, val)) + else + d.vals[key_idx] = convert(V, val) + end + return val +end +Base.haskey(d::SmallDict{K,V}, key) where {K,V} = in(convert(K, key), d.keys) +Base.keys(d::SmallDict) = d.keys +Base.length(d::SmallDict) = length(d.keys) +Base.iterate(d::SmallDict) = iterate(d, 1) +Base.iterate(d::SmallDict, state) = state > length(d.keys) ? nothing : (d.keys[state] => d.vals[state], state+1) + +### Type-stable lookup structure for AliasingWrappers + +struct AliasingLookup + # The set of memory spaces that are being tracked + spaces::Vector{MemorySpace} + # The set of AliasingWrappers that are being tracked + # One entry for each AliasingWrapper + ainfos::Vector{AliasingWrapper} + # The memory spaces for each AliasingWrapper + # One entry for each AliasingWrapper + ainfos_spaces::Vector{Vector{Int}} + # The spans for each AliasingWrapper in each memory space + # One entry for each AliasingWrapper + spans::Vector{SmallDict{Int,Vector{LocalMemorySpan}}} + # The set of AliasingWrappers that only exist in a single memory space + # One entry for each AliasingWrapper + ainfos_only_space::Vector{Int} + # The bounding span for each AliasingWrapper in each memory space + # One entry for each AliasingWrapper + bounding_spans::Vector{SmallDict{Int,LocalMemorySpan}} + # The interval tree of the bounding spans for each AliasingWrapper + # One entry for each MemorySpace + bounding_spans_tree::Vector{IntervalTree{LocatorMemorySpan{Int},UInt64}} + + AliasingLookup() = new(MemorySpace[], + 
AliasingWrapper[], + Vector{Int}[], + SmallDict{Int,Vector{LocalMemorySpan}}[], + Int[], + SmallDict{Int,LocalMemorySpan}[], + IntervalTree{LocatorMemorySpan{Int},UInt64}[]) +end +function Base.push!(lookup::AliasingLookup, ainfo::AliasingWrapper) + # Update the set of memory spaces and spans, + # and find the bounding spans for this AliasingWrapper + spaces_set = Set{MemorySpace}(lookup.spaces) + self_spaces_set = Set{Int}() + spans = SmallDict{Int,Vector{LocalMemorySpan}}() + for span in memory_spans(ainfo) + space = span.ptr.space + if !in(space, spaces_set) + push!(spaces_set, space) + push!(lookup.spaces, space) + push!(lookup.bounding_spans_tree, IntervalTree{LocatorMemorySpan{Int}}()) + end + space_idx = findfirst(==(space), lookup.spaces) + push!(self_spaces_set, space_idx) + spans_in_space = get!(Vector{LocalMemorySpan}, spans, space_idx) + push!(spans_in_space, LocalMemorySpan(span)) + end + push!(lookup.ainfos_spaces, collect(self_spaces_set)) + push!(lookup.spans, spans) + + # Update the set of AliasingWrappers + push!(lookup.ainfos, ainfo) + ainfo_idx = length(lookup.ainfos) + + # Check if the AliasingWrapper only exists in a single memory space + if length(self_spaces_set) == 1 + space_idx = only(self_spaces_set) + push!(lookup.ainfos_only_space, space_idx) + else + push!(lookup.ainfos_only_space, 0) + end + + # Add the bounding spans for this AliasingWrapper + bounding_spans = SmallDict{Int,LocalMemorySpan}() + for space_idx in keys(spans) + space_spans = spans[space_idx] + bound_start = minimum(span_start, space_spans) + bound_end = maximum(span_end, space_spans) + bounding_span = LocalMemorySpan(bound_start, bound_end - bound_start) + bounding_spans[space_idx] = bounding_span + insert!(lookup.bounding_spans_tree[space_idx], LocatorMemorySpan(bounding_span, ainfo_idx)) + end + push!(lookup.bounding_spans, bounding_spans) + + return ainfo_idx +end +struct AliasingLookupFinder + lookup::AliasingLookup + ainfo::AliasingWrapper + ainfo_idx::Int + 
spaces_idx::Vector{Int} + to_consider::Vector{Int} +end +Base.eltype(::AliasingLookupFinder) = AliasingWrapper +Base.IteratorSize(::AliasingLookupFinder) = Base.SizeUnknown() +# FIXME: We should use a Dict{UInt,Int} to find the ainfo_idx instead of linear search +function Base.intersect(lookup::AliasingLookup, ainfo::AliasingWrapper; ainfo_idx=nothing) + if ainfo_idx === nothing + ainfo_idx = something(findfirst(==(ainfo), lookup.ainfos)) + end + spaces_idx = lookup.ainfos_spaces[ainfo_idx] + to_consider_spans = LocatorMemorySpan{Int}[] + for space_idx in spaces_idx + bounding_spans_tree = lookup.bounding_spans_tree[space_idx] + self_bounding_span = LocatorMemorySpan(lookup.bounding_spans[ainfo_idx][space_idx], 0) + find_overlapping!(bounding_spans_tree, self_bounding_span, to_consider_spans; exact=false) + end + to_consider = Int[locator.owner for locator in to_consider_spans] + @assert all(to_consider .> 0) + return AliasingLookupFinder(lookup, ainfo, ainfo_idx, spaces_idx, to_consider) +end +Base.iterate(finder::AliasingLookupFinder) = iterate(finder, 1) +function Base.iterate(finder::AliasingLookupFinder, cursor_ainfo_idx) + ainfo_spaces = nothing + cursor_space_idx = 1 + + # New ainfos enter here + @label ainfo_restart + + # Check if we've exhausted all ainfos + if cursor_ainfo_idx > length(finder.to_consider) + return nothing + end + ainfo_idx = finder.to_consider[cursor_ainfo_idx] + + # Find the appropriate memory spaces for this ainfo + if ainfo_spaces === nothing + ainfo_spaces = finder.lookup.ainfos_spaces[ainfo_idx] + end + + # New memory spaces (for the same ainfo) enter here + @label space_restart + + # Check if we've exhausted all memory spaces for this ainfo, and need to move to the next ainfo + if cursor_space_idx > length(ainfo_spaces) + cursor_ainfo_idx += 1 + ainfo_spaces = nothing + cursor_space_idx = 1 + @goto ainfo_restart + end + + # Find the currently considered memory space for this ainfo + space_idx = ainfo_spaces[cursor_space_idx] + + # 
Check if this memory space is part of our target ainfo's spaces + if !(space_idx in finder.spaces_idx) + cursor_space_idx += 1 + @goto space_restart + end + + # Check if this ainfo's bounding span is part of our target ainfo's bounding span in this memory space + other_ainfo_bounding_span = finder.lookup.bounding_spans[ainfo_idx][space_idx] + self_bounding_span = finder.lookup.bounding_spans[finder.ainfo_idx][space_idx] + if !spans_overlap(other_ainfo_bounding_span, self_bounding_span) + cursor_space_idx += 1 + @goto space_restart + end + + # We have overlapping bounds in the same memory space, so check if the ainfos are aliasing + # This is the slow path! + other_ainfo = finder.lookup.ainfos[ainfo_idx] + aliasing = will_alias(finder.ainfo, other_ainfo) + if !aliasing + cursor_ainfo_idx += 1 + ainfo_spaces = nothing + cursor_space_idx = 1 + @goto ainfo_restart + end + + # We overlap, so return the ainfo and the next ainfo index + return other_ainfo, cursor_ainfo_idx+1 +end struct NoAliasing <: AbstractAliasing end memory_spans(::NoAliasing) = MemorySpan{CPURAMMemorySpace}[] @@ -406,36 +570,3 @@ function will_alias(x_span::MemorySpan, y_span::MemorySpan) y_end = y_span.ptr + y_span.len - 1 return x_span.ptr <= y_end && y_span.ptr <= x_end end - -### More space-efficient memory spans - -struct LocalMemorySpan - ptr::UInt - len::UInt -end -LocalMemorySpan(span::MemorySpan) = LocalMemorySpan(span.ptr.addr, span.len) -Base.isempty(x::LocalMemorySpan) = x.len == 0 - -# FIXME: Store the length separately, since it's shared by all spans -struct ManyMemorySpan{N} - spans::NTuple{N,LocalMemorySpan} -end -Base.isempty(x::ManyMemorySpan) = all(isempty, x.spans) - -struct ManyPair{N} <: Unsigned - pairs::NTuple{N,UInt} -end -Base.promote_rule(::Type{ManyPair}, ::Type{T}) where {T<:Integer} = ManyPair -Base.convert(::Type{ManyPair{N}}, x::T) where {T<:Integer,N} = ManyPair(ntuple(i -> x, N)) -Base.convert(::Type{ManyPair}, x::ManyPair) = x -Base.:+(x::ManyPair{N}, 
y::ManyPair{N}) where N = ManyPair(ntuple(i -> x.pairs[i] + y.pairs[i], N)) -Base.:-(x::ManyPair{N}, y::ManyPair{N}) where N = ManyPair(ntuple(i -> x.pairs[i] - y.pairs[i], N)) -Base.:-(x::ManyPair) = error("Can't negate a ManyPair") -Base.:(==)(x::ManyPair, y::ManyPair) = x.pairs == y.pairs -Base.isless(x::ManyPair, y::ManyPair) = x.pairs[1] < y.pairs[1] -Base.:(<)(x::ManyPair, y::ManyPair) = x.pairs[1] < y.pairs[1] -Base.string(x::ManyPair) = "ManyPair($(x.pairs))" - -ManyMemorySpan{N}(start::ManyPair{N}, len::ManyPair{N}) where N = - ManyMemorySpan{N}(ntuple(i -> LocalMemorySpan(start.pairs[i], len.pairs[i]), N)) - diff --git a/src/datadeps/interval_tree.jl b/src/utils/interval_tree.jl similarity index 81% rename from src/datadeps/interval_tree.jl rename to src/utils/interval_tree.jl index 1075f5912..e67f66b24 100644 --- a/src/datadeps/interval_tree.jl +++ b/src/utils/interval_tree.jl @@ -1,16 +1,3 @@ -# Get the start address of a span -span_start(span::MemorySpan) = span.ptr.addr -span_start(span::LocalMemorySpan) = span.ptr -span_start(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_start(span.spans[i]), N)) -# Get the length of a span -span_len(span::MemorySpan) = span.len -span_len(span::LocalMemorySpan) = span.len -span_len(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_len(span.spans[i]), N)) - -# Get the end address of a span -span_end(span::MemorySpan) = span.ptr.addr + span.len -span_end(span::LocalMemorySpan) = span.ptr + span.len -span_end(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_end(span.spans[i]), N)) mutable struct IntervalNode{M,E} span::M max_end::E # Maximum end value in this subtree @@ -20,6 +7,7 @@ mutable struct IntervalNode{M,E} IntervalNode(span::M) where M <: MemorySpan = new{M,UInt64}(span, span_end(span), nothing, nothing) IntervalNode(span::LocalMemorySpan) = new{LocalMemorySpan,UInt64}(span, span_end(span), nothing, nothing) IntervalNode(span::ManyMemorySpan{N}) where N = 
new{ManyMemorySpan{N},ManyPair{N}}(span, span_end(span), nothing, nothing) + IntervalNode(span::LocatorMemorySpan{T}) where T = new{LocatorMemorySpan{T},UInt64}(span, span_end(span), nothing, nothing) end mutable struct IntervalTree{M,E} @@ -28,6 +16,7 @@ mutable struct IntervalTree{M,E} IntervalTree{M}() where M<:MemorySpan = new{M,UInt64}(nothing) IntervalTree{LocalMemorySpan}() = new{LocalMemorySpan,UInt64}(nothing) IntervalTree{ManyMemorySpan{N}}() where N = new{ManyMemorySpan{N},ManyPair{N}}(nothing) + IntervalTree{LocatorMemorySpan{T}}() where T = new{LocatorMemorySpan{T},UInt64}(nothing) end # Construct interval tree from unsorted set of spans @@ -94,19 +83,48 @@ end # Update max_end value for a node based on its children function update_max_end!(node::IntervalNode) - node.max_end = span_end(node.span) + max_end = span_end(node.span) if node.left !== nothing - node.max_end = max(node.max_end, node.left.max_end) + max_end = max(max_end, node.left.max_end) end if node.right !== nothing - node.max_end = max(node.max_end, node.right.max_end) + max_end = max(max_end, node.right.max_end) end + node.max_end = max_end end # Insert a span into the interval tree -function Base.insert!(tree::IntervalTree{M}, span::M) where M +function Base.insert!(tree::IntervalTree{M,E}, span::M) where {M,E} if !isempty(span) - tree.root = insert_node!(tree.root, span) + if tree.root === nothing + tree.root = IntervalNode(span) + update_max_end!(tree.root) + return span + end + #tree.root = insert_node!(tree.root, span) + to_update = Vector{IntervalNode{M,E}}() + prev_node = tree.root + cur_node = tree.root + while cur_node !== nothing + if span_start(span) <= span_start(cur_node.span) + cur_node = cur_node.left + else + cur_node = cur_node.right + end + if cur_node !== nothing + prev_node = cur_node + push!(to_update, cur_node) + end + end + if prev_node.left === nothing + prev_node.left = IntervalNode(span) + else + prev_node.right = IntervalNode(span) + end + for node_idx in 
eachindex(to_update) + node = to_update[node_idx] + update_max_end!(node) + end end return span end @@ -221,46 +239,42 @@ function find_min(node::IntervalNode) return node end -# Check if two spans overlap -function spans_overlap(span1::MemorySpan, span2::MemorySpan) - return span_start(span1) < span_end(span2) && span_start(span2) < span_end(span1) -end -function spans_overlap(span1::LocalMemorySpan, span2::LocalMemorySpan) - return span_start(span1) < span_end(span2) && span_start(span2) < span_end(span1) -end -function spans_overlap(span1::ManyMemorySpan{N}, span2::ManyMemorySpan{N}) where N - # N.B. The spans are assumed to be the same length and relative offset - return spans_overlap(span1.spans[1], span2.spans[1]) -end - # Find all spans that overlap with the given query span -function find_overlapping(tree::IntervalTree{M}, query::M) where M +function find_overlapping(tree::IntervalTree{M}, query::M; exact::Bool=true) where M result = M[] - find_overlapping!(tree.root, query, result) + find_overlapping!(tree.root, query, result; exact) + return result +end +function find_overlapping!(tree::IntervalTree{M}, query::M, result::Vector{M}; exact::Bool=true) where M + find_overlapping!(tree.root, query, result; exact) return result end -function find_overlapping!(::Nothing, query::M, result::Vector{M}) where M +function find_overlapping!(::Nothing, query::M, result::Vector{M}; exact::Bool=true) where M return end -function find_overlapping!(node::IntervalNode{M,E}, query::M, result::Vector{M}) where {M,E} +function find_overlapping!(node::IntervalNode{M,E}, query::M, result::Vector{M}; exact::Bool=true) where {M,E} # Check if current node overlaps with query if spans_overlap(node.span, query) - # Get the overlapping portion of the span - overlap_start = max(span_start(node.span), span_start(query)) - overlap_end = min(span_end(node.span), span_end(query)) - overlap = M(overlap_start, overlap_end - overlap_start) - push!(result, overlap) + if exact + # Get the 
overlapping portion of the span + overlap_start = max(span_start(node.span), span_start(query)) + overlap_end = min(span_end(node.span), span_end(query)) + overlap = M(overlap_start, overlap_end - overlap_start) + push!(result, overlap) + else + push!(result, node.span) + end end # Recursively search left subtree if it might contain overlapping intervals if node.left !== nothing && node.left.max_end > span_start(query) - find_overlapping!(node.left, query, result) + find_overlapping!(node.left, query, result; exact) end # Recursively search right subtree if query extends beyond current node's start if node.right !== nothing && span_end(query) > span_start(node.span) - find_overlapping!(node.right, query, result) + find_overlapping!(node.right, query, result; exact) end end diff --git a/src/utils/memory-span.jl b/src/utils/memory-span.jl new file mode 100644 index 000000000..91f291cbe --- /dev/null +++ b/src/utils/memory-span.jl @@ -0,0 +1,98 @@ +### Remote pointer type + +struct RemotePtr{T,S<:MemorySpace} <: Ref{T} + addr::UInt + space::S +end +RemotePtr{T}(addr::UInt, space::S) where {T,S} = RemotePtr{T,S}(addr, space) +RemotePtr{T}(ptr::Ptr{V}, space::S) where {T,V,S} = RemotePtr{T,S}(UInt(ptr), space) +RemotePtr{T}(ptr::Ptr{V}) where {T,V} = RemotePtr{T}(UInt(ptr), CPURAMMemorySpace(myid())) +# FIXME: Don't hardcode CPURAMMemorySpace +RemotePtr(addr::UInt) = RemotePtr{Cvoid}(addr, CPURAMMemorySpace(myid())) +Base.convert(::Type{RemotePtr}, x::Ptr{T}) where T = + RemotePtr(UInt(x), CPURAMMemorySpace(myid())) +Base.convert(::Type{<:RemotePtr{V}}, x::Ptr{T}) where {V,T} = + RemotePtr{V}(UInt(x), CPURAMMemorySpace(myid())) +Base.convert(::Type{UInt}, ptr::RemotePtr) = ptr.addr +Base.:+(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr + offset, ptr.space) +Base.:-(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr - offset, ptr.space) +function Base.isless(ptr1::RemotePtr, ptr2::RemotePtr) + @assert ptr1.space == ptr2.space + 
return ptr1.addr < ptr2.addr +end + +### Generic memory spans + +struct MemorySpan{S} + ptr::RemotePtr{Cvoid,S} + len::UInt +end +MemorySpan(ptr::RemotePtr{Cvoid,S}, len::Integer) where S = + MemorySpan{S}(ptr, UInt(len)) +MemorySpan{S}(addr::UInt, len::Integer) where S = + MemorySpan{S}(RemotePtr{Cvoid,S}(addr), UInt(len)) +Base.isless(a::MemorySpan, b::MemorySpan) = a.ptr < b.ptr +Base.isempty(x::MemorySpan) = x.len == 0 +span_start(span::MemorySpan) = span.ptr.addr +span_len(span::MemorySpan) = span.len +span_end(span::MemorySpan) = span.ptr.addr + span.len +spans_overlap(span1::MemorySpan, span2::MemorySpan) = + span_start(span1) < span_end(span2) && span_start(span2) < span_end(span1) + +### More space-efficient memory spans + +struct LocalMemorySpan + ptr::UInt + len::UInt +end +LocalMemorySpan(span::MemorySpan) = LocalMemorySpan(span.ptr.addr, span.len) +Base.isempty(x::LocalMemorySpan) = x.len == 0 +span_start(span::LocalMemorySpan) = span.ptr +span_len(span::LocalMemorySpan) = span.len +span_end(span::LocalMemorySpan) = span.ptr + span.len +spans_overlap(span1::LocalMemorySpan, span2::LocalMemorySpan) = + span_start(span1) < span_end(span2) && span_start(span2) < span_end(span1) + +# FIXME: Store the length separately, since it's shared by all spans +struct ManyMemorySpan{N} + spans::NTuple{N,LocalMemorySpan} +end +Base.isempty(x::ManyMemorySpan) = all(isempty, x.spans) +span_start(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_start(span.spans[i]), N)) +span_len(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_len(span.spans[i]), N)) +span_end(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_end(span.spans[i]), N)) +spans_overlap(span1::ManyMemorySpan{N}, span2::ManyMemorySpan{N}) where N = + # N.B. 
The spans are assumed to be the same length and relative offset + spans_overlap(span1.spans[1], span2.spans[1]) + +struct ManyPair{N} <: Unsigned + pairs::NTuple{N,UInt} +end +Base.promote_rule(::Type{ManyPair}, ::Type{T}) where {T<:Integer} = ManyPair +Base.convert(::Type{ManyPair{N}}, x::T) where {T<:Integer,N} = ManyPair(ntuple(i -> x, N)) +Base.convert(::Type{ManyPair}, x::ManyPair) = x +Base.:+(x::ManyPair{N}, y::ManyPair{N}) where N = ManyPair(ntuple(i -> x.pairs[i] + y.pairs[i], N)) +Base.:-(x::ManyPair{N}, y::ManyPair{N}) where N = ManyPair(ntuple(i -> x.pairs[i] - y.pairs[i], N)) +Base.:-(x::ManyPair) = error("Can't negate a ManyPair") +Base.:(==)(x::ManyPair, y::ManyPair) = x.pairs == y.pairs +Base.isless(x::ManyPair, y::ManyPair) = x.pairs[1] < y.pairs[1] +Base.:(<)(x::ManyPair, y::ManyPair) = x.pairs[1] < y.pairs[1] +Base.string(x::ManyPair) = "ManyPair($(x.pairs))" + +ManyMemorySpan{N}(start::ManyPair{N}, len::ManyPair{N}) where N = + ManyMemorySpan{N}(ntuple(i -> LocalMemorySpan(start.pairs[i], len.pairs[i]), N)) + +### Memory spans with ownership info + +struct LocatorMemorySpan{T} + span::LocalMemorySpan + owner::T +end +LocatorMemorySpan{T}(start::UInt64, len::UInt64) where T = # For interval tree + LocatorMemorySpan{T}(LocalMemorySpan(start, len), 0) +Base.isempty(x::LocatorMemorySpan) = span_len(x.span) == 0 +span_start(x::LocatorMemorySpan) = span_start(x.span) +span_end(x::LocatorMemorySpan) = span_end(x.span) +span_len(x::LocatorMemorySpan) = span_len(x.span) +spans_overlap(span1::LocatorMemorySpan{T}, span2::LocatorMemorySpan{T}) where T = + spans_overlap(span1.span, span2.span) \ No newline at end of file From 1333ecf7603d695cb60cced428b1fe5662085310 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 15 Oct 2025 17:51:50 -0700 Subject: [PATCH 04/28] datadeps: Optimize remote ArgumentWrapper lookup --- src/datadeps/aliasing.jl | 84 +++++++++++++++++++++++++--------------- src/memory-spaces.jl | 10 ++++- 2 files changed, 60 
insertions(+), 34 deletions(-) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index 6482492a9..39b07cb28 100644 --- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -258,6 +258,9 @@ struct DataDepsState # The mapping of remote argument to original argument remote_arg_to_original::IdDict{Any,Any} + # The mapping of original argument wrapper to remote argument wrapper + remote_arg_w::Dict{ArgumentWrapper,Dict{MemorySpace,ArgumentWrapper}} + # The mapping of ainfo to argument and dep_mod # Used to lookup which argument and dep_mod a given ainfo is generated from # N.B. This is a mapping for remote argument copies @@ -312,6 +315,7 @@ struct DataDepsState arg_origin = IdDict{Any,MemorySpace}() remote_args = Dict{MemorySpace,IdDict{Any,Any}}() remote_arg_to_original = IdDict{Any,Any}() + remote_arg_w = Dict{ArgumentWrapper,Dict{MemorySpace,ArgumentWrapper}}() ainfo_arg = Dict{AliasingWrapper,ArgumentWrapper}() arg_owner = Dict{ArgumentWrapper,MemorySpace}() arg_overlaps = Dict{ArgumentWrapper,Set{ArgumentWrapper}}() @@ -327,37 +331,11 @@ struct DataDepsState ainfos_owner = Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}() ainfos_readers = Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}() - return new(arg_to_chunk, arg_origin, remote_args, remote_arg_to_original, ainfo_arg, arg_owner, arg_overlaps, ainfo_backing_chunk, arg_history, + return new(arg_to_chunk, arg_origin, remote_args, remote_arg_to_original, remote_arg_w, ainfo_arg, arg_owner, arg_overlaps, ainfo_backing_chunk, arg_history, supports_inplace_cache, ainfo_cache, ainfos_lookup, ainfos_overlaps, ainfos_owner, ainfos_readers) end end -# N.B. 
arg_w must be the original argument wrapper, not a remote copy -function aliasing!(state::DataDepsState, target_space::MemorySpace, arg_w::ArgumentWrapper) - # Grab the remote copy of the argument, and calculate the ainfo - remote_arg = get_or_generate_slot!(state, target_space, arg_w.arg) - remote_arg_w = ArgumentWrapper(remote_arg, arg_w.dep_mod) - - # Check if we already have the result cached - if haskey(state.ainfo_cache, remote_arg_w) - return state.ainfo_cache[remote_arg_w] - end - - # Calculate the ainfo - ainfo = AliasingWrapper(aliasing(remote_arg, arg_w.dep_mod)) - - # Cache the result - state.ainfo_cache[remote_arg_w] = ainfo - - # Update the mapping of ainfo to argument and dep_mod - state.ainfo_arg[ainfo] = remote_arg_w - - # Populate info for the new ainfo - populate_ainfo!(state, arg_w, ainfo, target_space) - - return ainfo -end - function supports_inplace_move(state::DataDepsState, arg) return get!(state.supports_inplace_cache, arg) do return supports_inplace_move(arg) @@ -456,6 +434,41 @@ function populate_argument_info!(state::DataDepsState, arg_w::ArgumentWrapper, o # Calculate the ainfo (which will populate ainfo structures and merge history) aliasing!(state, origin_space, arg_w) end +# N.B. 
arg_w must be the original argument wrapper, not a remote copy +function aliasing!(state::DataDepsState, target_space::MemorySpace, arg_w::ArgumentWrapper) + if haskey(state.remote_arg_w, arg_w) && haskey(state.remote_arg_w[arg_w], target_space) + remote_arg_w = @inbounds state.remote_arg_w[arg_w][target_space] + remote_arg = remote_arg_w.arg + else + # Grab the remote copy of the argument, and calculate the ainfo + remote_arg = get_or_generate_slot!(state, target_space, arg_w.arg) + remote_arg_w = ArgumentWrapper(remote_arg, arg_w.dep_mod) + get!(Dict{MemorySpace,ArgumentWrapper}, state.remote_arg_w, arg_w)[target_space] = remote_arg_w + end + + # Check if we already have the result cached + if haskey(state.ainfo_cache, remote_arg_w) + return state.ainfo_cache[remote_arg_w] + end + + # Calculate the ainfo + ainfo = AliasingWrapper(aliasing(remote_arg, arg_w.dep_mod)) + + # Cache the result + state.ainfo_cache[remote_arg_w] = ainfo + + # Update the mapping of ainfo to argument and dep_mod + if !haskey(state.ainfo_arg, ainfo) + state.ainfo_arg[ainfo] = remote_arg_w + else + @assert state.ainfo_arg[ainfo] == remote_arg_w + end + + # Populate info for the new ainfo + populate_ainfo!(state, arg_w, ainfo, target_space) + + return ainfo +end function populate_ainfo!(state::DataDepsState, original_arg_w::ArgumentWrapper, target_ainfo::AliasingWrapper, target_space::MemorySpace) if !haskey(state.ainfos_owner, target_ainfo) # Add ourselves to the lookup oracle @@ -662,11 +675,18 @@ function get_or_generate_slot!(state, dest_space, data) end function move_rewrap(from_proc::Processor, to_proc::Processor, data) return aliased_object!(data) do data - to_w = root_worker_id(to_proc) - return remotecall_fetch(to_w, from_proc, to_proc, data) do from_proc, to_proc, data - data_converted = move(from_proc, to_proc, data) - return tochunk(data_converted, to_proc) - end + return remotecall_endpoint(identity, from_proc, to_proc, memory_space(from_proc), memory_space(to_proc), data) + end +end +function
remotecall_endpoint(f, from_proc, to_proc, orig_space, dest_space, data) + to_w = root_worker_id(to_proc) + if to_w == myid() + data_converted = f(move(from_proc, to_proc, data)) + return tochunk(data_converted, to_proc, dest_space) + end + return remotecall_fetch(to_w, from_proc, to_proc, dest_space, data) do from_proc, to_proc, dest_space, data + data_converted = f(move(from_proc, to_proc, data)) + return tochunk(data_converted, to_proc, dest_space) end end const ALIASED_OBJECT_CACHE = TaskLocalValue{Union{Dict{AbstractAliasing,Chunk}, Nothing}}(()->nothing) diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index fcc3dbf0b..4124bbba6 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -382,8 +382,14 @@ end aliasing(::String) = NoAliasing() # FIXME: Not necessarily true aliasing(::Symbol) = NoAliasing() aliasing(::Type) = NoAliasing() -aliasing(x::Chunk, T) = remotecall_fetch(root_worker_id(x.processor), x, T) do x, T - aliasing(unwrap(x), T) +function aliasing(x::Chunk, T) + @assert x.handle isa DRef + if root_worker_id(x.processor) == myid() + return aliasing(unwrap(x), T) + end + return remotecall_fetch(root_worker_id(x.processor), x, T) do x, T + aliasing(unwrap(x), T) + end end aliasing(x::Chunk) = remotecall_fetch(root_worker_id(x.processor), x) do x aliasing(unwrap(x)) From 6cfba3fff83facc387a35828f24ff88014532a84 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 15 Oct 2025 17:52:46 -0700 Subject: [PATCH 05/28] thunk: Remove unnecessary scope allocations --- src/thunk.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/thunk.jl b/src/thunk.jl index d1701e3ef..e13e299f0 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -17,8 +17,6 @@ function unset!(spec::ThunkSpec, _) spec.id = 0 spec.cache_ref = nothing spec.affinity = nothing - compute_scope = DefaultScope() - result_scope = AnyScope() spec.options = nothing end From 2f7ca29e2df1d52f772acfa58357aafa94fd57a4 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 11 Nov 
2025 15:30:36 -0700 Subject: [PATCH 06/28] test/datadeps: Remove aliasing=false tests --- test/datadeps.jl | 411 +++++++++++++++++++++++------------------------ 1 file changed, 203 insertions(+), 208 deletions(-) diff --git a/test/datadeps.jl b/test/datadeps.jl index cd83be95f..4fb873454 100644 --- a/test/datadeps.jl +++ b/test/datadeps.jl @@ -177,16 +177,15 @@ end @everywhere mut_V!(V) = (V .= 1;) function test_datadeps(;args_chunks::Bool, args_thunks::Bool, - args_loc::Int, - aliasing::Bool) + args_loc::Int) # Returns last value - @test Dagger.spawn_datadeps(;aliasing) do + @test Dagger.spawn_datadeps() do 42 end == 42 # Tasks are started and finished as spawn_datadeps returns ts = [] - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do for i in 1:5 t = Dagger.@spawn sleep(0.1) @test !istaskstarted(t) @@ -195,7 +194,7 @@ function test_datadeps(;args_chunks::Bool, @test all(istaskdone, ts) # Rethrows any task exceptions - @test_throws Exception Dagger.spawn_datadeps(;aliasing) do + @test_throws Exception Dagger.spawn_datadeps() do Dagger.@spawn error("Test") end @@ -209,7 +208,7 @@ function test_datadeps(;args_chunks::Bool, # Task return values can be tracked ts = [] logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do t1 = Dagger.@spawn fill(42, 1) push!(ts, t1) push!(ts, Dagger.@spawn copyto!(Out(A), In(t1))) @@ -224,7 +223,7 @@ function test_datadeps(;args_chunks::Bool, # R->R Non-Aliasing ts = [] logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do push!(ts, Dagger.@spawn do_nothing(In(A))) push!(ts, Dagger.@spawn do_nothing(In(A))) end @@ -236,7 +235,7 @@ function test_datadeps(;args_chunks::Bool, # R->W Aliasing ts = [] logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do push!(ts, Dagger.@spawn do_nothing(In(A))) push!(ts, Dagger.@spawn do_nothing(Out(A))) end @@ -248,7 +247,7 @@ function test_datadeps(;args_chunks::Bool, # W->W Aliasing 
ts = [] logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do push!(ts, Dagger.@spawn do_nothing(Out(A))) push!(ts, Dagger.@spawn do_nothing(Out(A))) end @@ -260,7 +259,7 @@ function test_datadeps(;args_chunks::Bool, # R->R Non-Self-Aliasing ts = [] logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do push!(ts, Dagger.@spawn do_nothing(In(A), In(A))) push!(ts, Dagger.@spawn do_nothing(In(A), In(A))) end @@ -272,7 +271,7 @@ function test_datadeps(;args_chunks::Bool, # R->W Self-Aliasing ts = [] logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do push!(ts, Dagger.@spawn do_nothing(In(A), In(A))) push!(ts, Dagger.@spawn do_nothing(Out(A), Out(A))) end @@ -284,7 +283,7 @@ function test_datadeps(;args_chunks::Bool, # W->W Self-Aliasing ts = [] logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do push!(ts, Dagger.@spawn do_nothing(Out(A), Out(A))) push!(ts, Dagger.@spawn do_nothing(Out(A), Out(A))) end @@ -293,197 +292,195 @@ function test_datadeps(;args_chunks::Bool, test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) test_task_dominators(logs, tid_2, [tid_1]; all_tids=[tid_1, tid_2]) - if aliasing - function wrap_chunk_thunk(f, args...) - if args_thunks || args_chunks - result = Dagger.@spawn scope=Dagger.scope(worker=args_loc) f(args...) - if args_thunks - return result - elseif args_chunks - return fetch(result; raw=true) - end - else - # N.B. We don't allocate remotely for raw data - return f(args...) 
- end - end - B = wrap_chunk_thunk(rand, 4, 4) - - # Views - B_ul = wrap_chunk_thunk(view, B, 1:2, 1:2) - B_ur = wrap_chunk_thunk(view, B, 1:2, 3:4) - B_ll = wrap_chunk_thunk(view, B, 3:4, 1:2) - B_lr = wrap_chunk_thunk(view, B, 3:4, 3:4) - B_mid = wrap_chunk_thunk(view, B, 2:3, 2:3) - for (B_name, B_view) in ( - (:B_ul, B_ul), - (:B_ur, B_ur), - (:B_ll, B_ll), - (:B_lr, B_lr), - (:B_mid, B_mid)) - @test Dagger.will_alias(Dagger.aliasing(B), Dagger.aliasing(B_view)) - B_view === B_mid && continue - @test Dagger.will_alias(Dagger.aliasing(B_mid), Dagger.aliasing(B_view)) - end - local t_A, t_B, t_ul, t_ur, t_ll, t_lr, t_mid - local t_ul2, t_ur2, t_ll2, t_lr2 - logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do - t_A = Dagger.@spawn do_nothing(InOut(A)) - t_B = Dagger.@spawn do_nothing(InOut(B)) - t_ul = Dagger.@spawn do_nothing(InOut(B_ul)) - t_ur = Dagger.@spawn do_nothing(InOut(B_ur)) - t_ll = Dagger.@spawn do_nothing(InOut(B_ll)) - t_lr = Dagger.@spawn do_nothing(InOut(B_lr)) - t_mid = Dagger.@spawn do_nothing(InOut(B_mid)) - t_ul2 = Dagger.@spawn do_nothing(InOut(B_ul)) - t_ur2 = Dagger.@spawn do_nothing(InOut(B_ur)) - t_ll2 = Dagger.@spawn do_nothing(InOut(B_ll)) - t_lr2 = Dagger.@spawn do_nothing(InOut(B_lr)) + function wrap_chunk_thunk(f, args...) + if args_thunks || args_chunks + result = Dagger.@spawn scope=Dagger.scope(worker=args_loc) f(args...) + if args_thunks + return result + elseif args_chunks + return fetch(result; raw=true) end + else + # N.B. We don't allocate remotely for raw data + return f(args...) 
end - tid_A, tid_B, tid_ul, tid_ur, tid_ll, tid_lr, tid_mid = - task_id.([t_A, t_B, t_ul, t_ur, t_ll, t_lr, t_mid]) - tid_ul2, tid_ur2, tid_ll2, tid_lr2 = - task_id.([t_ul2, t_ur2, t_ll2, t_lr2]) - tids_all = [tid_A, tid_B, tid_ul, tid_ur, tid_ll, tid_lr, tid_mid, - tid_ul2, tid_ur2, tid_ll2, tid_lr2] - test_task_dominators(logs, tid_A, []; all_tids=tids_all) - test_task_dominators(logs, tid_B, []; all_tids=tids_all) - test_task_dominators(logs, tid_ul, [tid_B]; all_tids=tids_all) - test_task_dominators(logs, tid_ur, [tid_B]; all_tids=tids_all) - test_task_dominators(logs, tid_ll, [tid_B]; all_tids=tids_all) - test_task_dominators(logs, tid_lr, [tid_B]; all_tids=tids_all) - test_task_dominators(logs, tid_mid, [tid_B, tid_ul, tid_ur, tid_ll, tid_lr]; all_tids=tids_all) - test_task_dominators(logs, tid_ul2, [tid_B, tid_mid, tid_ul]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_ur2, [tid_B, tid_mid, tid_ur]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_ll2, [tid_B, tid_mid, tid_ll]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_lr2, [tid_B, tid_mid, tid_lr]; all_tids=tids_all, nondom_check=false) - - # (Unit)Upper/LowerTriangular and Diagonal - B_upper = wrap_chunk_thunk(UpperTriangular, B) - B_unitupper = wrap_chunk_thunk(UnitUpperTriangular, B) - B_lower = wrap_chunk_thunk(LowerTriangular, B) - B_unitlower = wrap_chunk_thunk(UnitLowerTriangular, B) - for (B_name, B_view) in ( - (:B_upper, B_upper), - (:B_unitupper, B_unitupper), - (:B_lower, B_lower), - (:B_unitlower, B_unitlower)) - @test Dagger.will_alias(Dagger.aliasing(B), Dagger.aliasing(B_view)) - end - @test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B_lower)) - @test !Dagger.will_alias(Dagger.aliasing(B_unitupper), Dagger.aliasing(B_unitlower)) - @test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B_unitupper)) - @test Dagger.will_alias(Dagger.aliasing(B_lower), Dagger.aliasing(B_unitlower)) - - 
@test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B, Diagonal)) - @test Dagger.will_alias(Dagger.aliasing(B_lower), Dagger.aliasing(B, Diagonal)) - @test !Dagger.will_alias(Dagger.aliasing(B_unitupper), Dagger.aliasing(B, Diagonal)) - @test !Dagger.will_alias(Dagger.aliasing(B_unitlower), Dagger.aliasing(B, Diagonal)) - - local t_A, t_B, t_upper, t_unitupper, t_lower, t_unitlower, t_diag - local t_upper2, t_unitupper2, t_lower2, t_unitlower2 - logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do - t_A = Dagger.@spawn do_nothing(InOut(A)) - t_B = Dagger.@spawn do_nothing(InOut(B)) - t_upper = Dagger.@spawn do_nothing(InOut(B_upper)) - t_unitupper = Dagger.@spawn do_nothing(InOut(B_unitupper)) - t_lower = Dagger.@spawn do_nothing(InOut(B_lower)) - t_unitlower = Dagger.@spawn do_nothing(InOut(B_unitlower)) - t_diag = Dagger.@spawn do_nothing(Deps(B, InOut(Diagonal))) - t_unitlower2 = Dagger.@spawn do_nothing(InOut(B_unitlower)) - t_lower2 = Dagger.@spawn do_nothing(InOut(B_lower)) - t_unitupper2 = Dagger.@spawn do_nothing(InOut(B_unitupper)) - t_upper2 = Dagger.@spawn do_nothing(InOut(B_upper)) - end + end + B = wrap_chunk_thunk(rand, 4, 4) + + # Views + B_ul = wrap_chunk_thunk(view, B, 1:2, 1:2) + B_ur = wrap_chunk_thunk(view, B, 1:2, 3:4) + B_ll = wrap_chunk_thunk(view, B, 3:4, 1:2) + B_lr = wrap_chunk_thunk(view, B, 3:4, 3:4) + B_mid = wrap_chunk_thunk(view, B, 2:3, 2:3) + for (B_name, B_view) in ( + (:B_ul, B_ul), + (:B_ur, B_ur), + (:B_ll, B_ll), + (:B_lr, B_lr), + (:B_mid, B_mid)) + @test Dagger.will_alias(Dagger.aliasing(B), Dagger.aliasing(B_view)) + B_view === B_mid && continue + @test Dagger.will_alias(Dagger.aliasing(B_mid), Dagger.aliasing(B_view)) + end + local t_A, t_B, t_ul, t_ur, t_ll, t_lr, t_mid + local t_ul2, t_ur2, t_ll2, t_lr2 + logs = with_logs() do + Dagger.spawn_datadeps() do + t_A = Dagger.@spawn do_nothing(InOut(A)) + t_B = Dagger.@spawn do_nothing(InOut(B)) + t_ul = Dagger.@spawn do_nothing(InOut(B_ul)) + t_ur = 
Dagger.@spawn do_nothing(InOut(B_ur)) + t_ll = Dagger.@spawn do_nothing(InOut(B_ll)) + t_lr = Dagger.@spawn do_nothing(InOut(B_lr)) + t_mid = Dagger.@spawn do_nothing(InOut(B_mid)) + t_ul2 = Dagger.@spawn do_nothing(InOut(B_ul)) + t_ur2 = Dagger.@spawn do_nothing(InOut(B_ur)) + t_ll2 = Dagger.@spawn do_nothing(InOut(B_ll)) + t_lr2 = Dagger.@spawn do_nothing(InOut(B_lr)) end - tid_A, tid_B, tid_upper, tid_unitupper, tid_lower, tid_unitlower, tid_diag = - task_id.([t_A, t_B, t_upper, t_unitupper, t_lower, t_unitlower, t_diag]) - tid_upper2, tid_unitupper2, tid_lower2, tid_unitlower2 = - task_id.([t_upper2, t_unitupper2, t_lower2, t_unitlower2]) - tids_all = [tid_A, tid_B, tid_upper, tid_unitupper, tid_lower, tid_unitlower, tid_diag, - tid_upper2, tid_unitupper2, tid_lower2, tid_unitlower2] - test_task_dominators(logs, tid_A, []; all_tids=tids_all) - test_task_dominators(logs, tid_B, []; all_tids=tids_all) - # FIXME: Proper non-dominance checks - test_task_dominators(logs, tid_upper, [tid_B]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_unitupper, [tid_B, tid_upper]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_lower, [tid_B, tid_upper]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_unitlower, [tid_B, tid_lower]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_diag, [tid_B, tid_upper, tid_lower]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_unitlower2, [tid_B, tid_lower, tid_unitlower]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_lower2, [tid_B, tid_lower, tid_unitlower, tid_diag, tid_unitlower2]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_unitupper2, [tid_B, tid_upper, tid_unitupper]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_upper2, [tid_B, tid_upper, tid_unitupper, tid_diag, tid_unitupper2]; all_tids=tids_all, nondom_check=false) - - # Additional 
aliasing tests - views_overlap(x, y) = Dagger.will_alias(Dagger.aliasing(x), Dagger.aliasing(y)) - - A = wrap_chunk_thunk(identity, B) - - A_r1 = wrap_chunk_thunk(view, A, 1:1, 1:4) - A_r2 = wrap_chunk_thunk(view, A, 2:2, 1:4) - B_r1 = wrap_chunk_thunk(view, B, 1:1, 1:4) - B_r2 = wrap_chunk_thunk(view, B, 2:2, 1:4) - - A_c1 = wrap_chunk_thunk(view, A, 1:4, 1:1) - A_c2 = wrap_chunk_thunk(view, A, 1:4, 2:2) - B_c1 = wrap_chunk_thunk(view, B, 1:4, 1:1) - B_c2 = wrap_chunk_thunk(view, B, 1:4, 2:2) - - A_mid = wrap_chunk_thunk(view, A, 2:3, 2:3) - B_mid = wrap_chunk_thunk(view, B, 2:3, 2:3) - - @test views_overlap(A_r1, A_r1) - @test views_overlap(B_r1, B_r1) - @test views_overlap(A_c1, A_c1) - @test views_overlap(B_c1, B_c1) - - @test views_overlap(A_r1, B_r1) - @test views_overlap(A_r2, B_r2) - @test views_overlap(A_c1, B_c1) - @test views_overlap(A_c2, B_c2) - - @test !views_overlap(A_r1, A_r2) - @test !views_overlap(B_r1, B_r2) - @test !views_overlap(A_c1, A_c2) - @test !views_overlap(B_c1, B_c2) - - @test views_overlap(A_r1, A_c1) - @test views_overlap(A_r1, B_c1) - @test views_overlap(A_r2, A_c2) - @test views_overlap(A_r2, B_c2) - - for (name, mid) in ((:A_mid, A_mid), (:B_mid, B_mid)) - @test !views_overlap(A_r1, mid) - @test !views_overlap(B_r1, mid) - @test !views_overlap(A_c1, mid) - @test !views_overlap(B_c1, mid) - - @test views_overlap(A_r2, mid) - @test views_overlap(B_r2, mid) - @test views_overlap(A_c2, mid) - @test views_overlap(B_c2, mid) + end + tid_A, tid_B, tid_ul, tid_ur, tid_ll, tid_lr, tid_mid = + task_id.([t_A, t_B, t_ul, t_ur, t_ll, t_lr, t_mid]) + tid_ul2, tid_ur2, tid_ll2, tid_lr2 = + task_id.([t_ul2, t_ur2, t_ll2, t_lr2]) + tids_all = [tid_A, tid_B, tid_ul, tid_ur, tid_ll, tid_lr, tid_mid, + tid_ul2, tid_ur2, tid_ll2, tid_lr2] + test_task_dominators(logs, tid_A, []; all_tids=tids_all) + test_task_dominators(logs, tid_B, []; all_tids=tids_all) + test_task_dominators(logs, tid_ul, [tid_B]; all_tids=tids_all) + test_task_dominators(logs, 
tid_ur, [tid_B]; all_tids=tids_all) + test_task_dominators(logs, tid_ll, [tid_B]; all_tids=tids_all) + test_task_dominators(logs, tid_lr, [tid_B]; all_tids=tids_all) + test_task_dominators(logs, tid_mid, [tid_B, tid_ul, tid_ur, tid_ll, tid_lr]; all_tids=tids_all) + test_task_dominators(logs, tid_ul2, [tid_B, tid_mid, tid_ul]; all_tids=tids_all, nondom_check=false) + test_task_dominators(logs, tid_ur2, [tid_B, tid_mid, tid_ur]; all_tids=tids_all, nondom_check=false) + test_task_dominators(logs, tid_ll2, [tid_B, tid_mid, tid_ll]; all_tids=tids_all, nondom_check=false) + test_task_dominators(logs, tid_lr2, [tid_B, tid_mid, tid_lr]; all_tids=tids_all, nondom_check=false) + + # (Unit)Upper/LowerTriangular and Diagonal + B_upper = wrap_chunk_thunk(UpperTriangular, B) + B_unitupper = wrap_chunk_thunk(UnitUpperTriangular, B) + B_lower = wrap_chunk_thunk(LowerTriangular, B) + B_unitlower = wrap_chunk_thunk(UnitLowerTriangular, B) + for (B_name, B_view) in ( + (:B_upper, B_upper), + (:B_unitupper, B_unitupper), + (:B_lower, B_lower), + (:B_unitlower, B_unitlower)) + @test Dagger.will_alias(Dagger.aliasing(B), Dagger.aliasing(B_view)) + end + @test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B_lower)) + @test !Dagger.will_alias(Dagger.aliasing(B_unitupper), Dagger.aliasing(B_unitlower)) + @test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B_unitupper)) + @test Dagger.will_alias(Dagger.aliasing(B_lower), Dagger.aliasing(B_unitlower)) + + @test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B, Diagonal)) + @test Dagger.will_alias(Dagger.aliasing(B_lower), Dagger.aliasing(B, Diagonal)) + @test !Dagger.will_alias(Dagger.aliasing(B_unitupper), Dagger.aliasing(B, Diagonal)) + @test !Dagger.will_alias(Dagger.aliasing(B_unitlower), Dagger.aliasing(B, Diagonal)) + + local t_A, t_B, t_upper, t_unitupper, t_lower, t_unitlower, t_diag + local t_upper2, t_unitupper2, t_lower2, t_unitlower2 + logs = with_logs() do + Dagger.spawn_datadeps() do + 
t_A = Dagger.@spawn do_nothing(InOut(A)) + t_B = Dagger.@spawn do_nothing(InOut(B)) + t_upper = Dagger.@spawn do_nothing(InOut(B_upper)) + t_unitupper = Dagger.@spawn do_nothing(InOut(B_unitupper)) + t_lower = Dagger.@spawn do_nothing(InOut(B_lower)) + t_unitlower = Dagger.@spawn do_nothing(InOut(B_unitlower)) + t_diag = Dagger.@spawn do_nothing(Deps(B, InOut(Diagonal))) + t_unitlower2 = Dagger.@spawn do_nothing(InOut(B_unitlower)) + t_lower2 = Dagger.@spawn do_nothing(InOut(B_lower)) + t_unitupper2 = Dagger.@spawn do_nothing(InOut(B_unitupper)) + t_upper2 = Dagger.@spawn do_nothing(InOut(B_upper)) end + end + tid_A, tid_B, tid_upper, tid_unitupper, tid_lower, tid_unitlower, tid_diag = + task_id.([t_A, t_B, t_upper, t_unitupper, t_lower, t_unitlower, t_diag]) + tid_upper2, tid_unitupper2, tid_lower2, tid_unitlower2 = + task_id.([t_upper2, t_unitupper2, t_lower2, t_unitlower2]) + tids_all = [tid_A, tid_B, tid_upper, tid_unitupper, tid_lower, tid_unitlower, tid_diag, + tid_upper2, tid_unitupper2, tid_lower2, tid_unitlower2] + test_task_dominators(logs, tid_A, []; all_tids=tids_all) + test_task_dominators(logs, tid_B, []; all_tids=tids_all) + # FIXME: Proper non-dominance checks + test_task_dominators(logs, tid_upper, [tid_B]; all_tids=tids_all, nondom_check=false) + test_task_dominators(logs, tid_unitupper, [tid_B, tid_upper]; all_tids=tids_all, nondom_check=false) + test_task_dominators(logs, tid_lower, [tid_B, tid_upper]; all_tids=tids_all, nondom_check=false) + test_task_dominators(logs, tid_unitlower, [tid_B, tid_lower]; all_tids=tids_all, nondom_check=false) + test_task_dominators(logs, tid_diag, [tid_B, tid_upper, tid_lower]; all_tids=tids_all, nondom_check=false) + test_task_dominators(logs, tid_unitlower2, [tid_B, tid_lower, tid_unitlower]; all_tids=tids_all, nondom_check=false) + test_task_dominators(logs, tid_lower2, [tid_B, tid_lower, tid_unitlower, tid_diag, tid_unitlower2]; all_tids=tids_all, nondom_check=false) + test_task_dominators(logs, 
tid_unitupper2, [tid_B, tid_upper, tid_unitupper]; all_tids=tids_all, nondom_check=false) + test_task_dominators(logs, tid_upper2, [tid_B, tid_upper, tid_unitupper, tid_diag, tid_unitupper2]; all_tids=tids_all, nondom_check=false) + + # Additional aliasing tests + views_overlap(x, y) = Dagger.will_alias(Dagger.aliasing(x), Dagger.aliasing(y)) + + A = wrap_chunk_thunk(identity, B) + + A_r1 = wrap_chunk_thunk(view, A, 1:1, 1:4) + A_r2 = wrap_chunk_thunk(view, A, 2:2, 1:4) + B_r1 = wrap_chunk_thunk(view, B, 1:1, 1:4) + B_r2 = wrap_chunk_thunk(view, B, 2:2, 1:4) + + A_c1 = wrap_chunk_thunk(view, A, 1:4, 1:1) + A_c2 = wrap_chunk_thunk(view, A, 1:4, 2:2) + B_c1 = wrap_chunk_thunk(view, B, 1:4, 1:1) + B_c2 = wrap_chunk_thunk(view, B, 1:4, 2:2) + + A_mid = wrap_chunk_thunk(view, A, 2:3, 2:3) + B_mid = wrap_chunk_thunk(view, B, 2:3, 2:3) + + @test views_overlap(A_r1, A_r1) + @test views_overlap(B_r1, B_r1) + @test views_overlap(A_c1, A_c1) + @test views_overlap(B_c1, B_c1) + + @test views_overlap(A_r1, B_r1) + @test views_overlap(A_r2, B_r2) + @test views_overlap(A_c1, B_c1) + @test views_overlap(A_c2, B_c2) + + @test !views_overlap(A_r1, A_r2) + @test !views_overlap(B_r1, B_r2) + @test !views_overlap(A_c1, A_c2) + @test !views_overlap(B_c1, B_c2) + + @test views_overlap(A_r1, A_c1) + @test views_overlap(A_r1, B_c1) + @test views_overlap(A_r2, A_c2) + @test views_overlap(A_r2, B_c2) + + for (name, mid) in ((:A_mid, A_mid), (:B_mid, B_mid)) + @test !views_overlap(A_r1, mid) + @test !views_overlap(B_r1, mid) + @test !views_overlap(A_c1, mid) + @test !views_overlap(B_c1, mid) + + @test views_overlap(A_r2, mid) + @test views_overlap(B_r2, mid) + @test views_overlap(A_c2, mid) + @test views_overlap(B_c2, mid) + end - @test views_overlap(A_mid, A_mid) - @test views_overlap(A_mid, B_mid) + @test views_overlap(A_mid, A_mid) + @test views_overlap(A_mid, B_mid) - # SubArray hashing - V = zeros(3) - Dagger.spawn_datadeps(;aliasing) do - Dagger.@spawn mut_V!(InOut(view(V, 1:2))) - 
Dagger.@spawn mut_V!(InOut(view(V, 2:3))) - end - @test fetch(V) == [1, 1, 1] + # SubArray hashing + V = zeros(3) + Dagger.spawn_datadeps() do + Dagger.@spawn mut_V!(InOut(view(V, 1:2))) + Dagger.@spawn mut_V!(InOut(view(V, 2:3))) end + @test fetch(V) == [1, 1, 1] # FIXME: Deps # Outer Scope - exec_procs = fetch.(Dagger.spawn_datadeps(;aliasing) do + exec_procs = fetch.(Dagger.spawn_datadeps() do [Dagger.@spawn Dagger.task_processor() for i in 1:10] end) unique!(exec_procs) @@ -499,7 +496,7 @@ function test_datadeps(;args_chunks::Bool, end # Inner Scope - @test_throws Dagger.Sch.SchedulingException Dagger.spawn_datadeps(;aliasing) do + @test_throws Dagger.Sch.SchedulingException Dagger.spawn_datadeps() do Dagger.@spawn scope=Dagger.ExactScope(Dagger.ThreadProc(1, 5000)) 1+1 end @@ -528,7 +525,7 @@ function test_datadeps(;args_chunks::Bool, C = Dagger.@spawn scope=Dagger.scope(worker=args_loc) copy(C) D = Dagger.@spawn scope=Dagger.scope(worker=args_loc) copy(D) end - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do Dagger.@spawn add!(InOut(B), In(A)) Dagger.@spawn add!(InOut(C), In(A)) Dagger.@spawn add!(InOut(C), In(B)) @@ -545,7 +542,7 @@ function test_datadeps(;args_chunks::Bool, elseif args_thunks As = map(A->(Dagger.@spawn scope=Dagger.scope(worker=args_loc) copy(A)), As) end - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do to_reduce = Vector[] push!(to_reduce, As) while !isempty(to_reduce) @@ -576,7 +573,7 @@ function test_datadeps(;args_chunks::Bool, elseif args_thunks M = map(m->(Dagger.@spawn scope=Dagger.scope(worker=args_loc) copy(m)), M) end - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do for k in range(1, mt) Dagger.@spawn LAPACK.potrf!('L', InOut(M[k, k])) for _m in range(k+1, mt) @@ -596,18 +593,16 @@ function test_datadeps(;args_chunks::Bool, @test isapprox(M_dense, expected) end -@testset "$(aliasing ? 
"With" : "Without") Aliasing Support" for aliasing in (true, false) - @testset "$args_mode Data" for args_mode in (:Raw, :Chunk, :Thunk) - args_chunks = args_mode == :Chunk - args_thunks = args_mode == :Thunk - for nw in (1, 2) - args_loc = nw == 2 ? 2 : 1 - for nt in (1, 2) - if nprocs() >= nw && Threads.nthreads() >= nt - @testset "$nw Workers, $nt Threads" begin - Dagger.with_options(;scope=Dagger.scope(workers=1:nw, threads=1:nt)) do - test_datadeps(;args_chunks, args_thunks, args_loc, aliasing) - end +@testset "$args_mode Data" for args_mode in (:Raw, :Chunk, :Thunk) + args_chunks = args_mode == :Chunk + args_thunks = args_mode == :Thunk + for nw in (1, 2) + args_loc = nw == 2 ? 2 : 1 + for nt in (1, 2) + if nprocs() >= nw && Threads.nthreads() >= nt + @testset "$nw Workers, $nt Threads" begin + Dagger.with_options(;scope=Dagger.scope(workers=1:nw, threads=1:nt)) do + test_datadeps(;args_chunks, args_thunks, args_loc) end end end From f580d5ff39d7eb67f5823be821a503acc19ead86 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Sat, 15 Nov 2025 10:47:40 -0700 Subject: [PATCH 07/28] datadeps: ainfo_arg must track ainfo -> multiple arg_w --- src/datadeps/aliasing.jl | 22 +++++++++++----------- src/datadeps/remainders.jl | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index 39b07cb28..3760b3a26 100644 --- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -264,7 +264,7 @@ struct DataDepsState # The mapping of ainfo to argument and dep_mod # Used to lookup which argument and dep_mod a given ainfo is generated from # N.B.
This is a mapping for remote argument copies - ainfo_arg::Dict{AliasingWrapper,ArgumentWrapper} + ainfo_arg::Dict{AliasingWrapper,Set{ArgumentWrapper}} # The history of writes (direct or indirect) to each argument and dep_mod, in terms of ainfos directly written to, and the memory space they were written to # Updated when a new write happens on an overlapping ainfo @@ -316,7 +316,7 @@ struct DataDepsState remote_args = Dict{MemorySpace,IdDict{Any,Any}}() remote_arg_to_original = IdDict{Any,Any}() remote_arg_w = Dict{ArgumentWrapper,Dict{MemorySpace,ArgumentWrapper}}() - ainfo_arg = Dict{AliasingWrapper,ArgumentWrapper}() + ainfo_arg = Dict{AliasingWrapper,Set{ArgumentWrapper}}() arg_owner = Dict{ArgumentWrapper,MemorySpace}() arg_overlaps = Dict{ArgumentWrapper,Set{ArgumentWrapper}}() ainfo_backing_chunk = Dict{MemorySpace,Dict{AbstractAliasing,Chunk}}() @@ -459,10 +459,9 @@ function aliasing!(state::DataDepsState, target_space::MemorySpace, arg_w::Argum # Update the mapping of ainfo to argument and dep_mod if !haskey(state.ainfo_arg, ainfo) - state.ainfo_arg[ainfo] = remote_arg_w - else - @assert state.ainfo_arg[ainfo] == remote_arg_w + state.ainfo_arg[ainfo] = Set{ArgumentWrapper}([remote_arg_w]) end + push!(state.ainfo_arg[ainfo], remote_arg_w) # Populate info for the new ainfo populate_ainfo!(state, arg_w, ainfo, target_space) @@ -484,12 +483,13 @@ function populate_ainfo!(state::DataDepsState, original_arg_w::ArgumentWrapper, push!(state.ainfos_overlaps[other_ainfo], target_ainfo) # Add overlapping history to our own - other_remote_arg_w = state.ainfo_arg[other_ainfo] - other_arg = state.remote_arg_to_original[other_remote_arg_w.arg] - other_arg_w = ArgumentWrapper(other_arg, other_remote_arg_w.dep_mod) - push!(state.arg_overlaps[original_arg_w], other_arg_w) - push!(state.arg_overlaps[other_arg_w], original_arg_w) - merge_history!(state, original_arg_w, other_arg_w) + for other_remote_arg_w in state.ainfo_arg[other_ainfo] + other_arg = 
state.remote_arg_to_original[other_remote_arg_w.arg] + other_arg_w = ArgumentWrapper(other_arg, other_remote_arg_w.dep_mod) + push!(state.arg_overlaps[original_arg_w], other_arg_w) + push!(state.arg_overlaps[other_arg_w], original_arg_w) + merge_history!(state, original_arg_w, other_arg_w) + end end state.ainfos_overlaps[target_ainfo] = overlaps diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl index b6f3d3c51..5312d0935 100644 --- a/src/datadeps/remainders.jl +++ b/src/datadeps/remainders.jl @@ -154,7 +154,7 @@ function compute_remainder_for_arg!(state::DataDepsState, end # Lookup all memory spans for arg_w in these spaces - other_remote_arg_w = state.ainfo_arg[other_ainfo] + other_remote_arg_w = first(collect(state.ainfo_arg[other_ainfo])) other_arg_w = ArgumentWrapper(state.remote_arg_to_original[other_remote_arg_w.arg], other_remote_arg_w.dep_mod) other_ainfos = Vector{Vector{LocalMemorySpan}}() for space in spaces From 894171b2a662c70a73be8565e722bdf3826165d1 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Sat, 15 Nov 2025 10:49:06 -0700 Subject: [PATCH 08/28] datadeps: Fix broken ChunkView unwrapping --- src/datadeps/aliasing.jl | 18 +++++------------- src/datadeps/chunkview.jl | 29 +++++++++++++++++++++-------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index 3760b3a26..5a3913644 100644 --- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -634,9 +634,6 @@ isremotehandle(x) = false isremotehandle(x::DTask) = true isremotehandle(x::Chunk) = true function generate_slot!(state::DataDepsState, dest_space, data) - if data isa DTask - data = fetch(data; raw=true) - end # N.B. We do not perform any sync/copy with the current owner of the data, # because all we want here is to make a copy of some version of the data, # even if the data is not up to date. 
@@ -645,16 +642,11 @@ function generate_slot!(state::DataDepsState, dest_space, data) from_proc = first(processors(orig_space)) dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space) ALIASED_OBJECT_CACHE[] = get!(Dict{AbstractAliasing,Chunk}, state.ainfo_backing_chunk, dest_space) - if orig_space == dest_space && (data isa Chunk || !isremotehandle(data)) - # Fast path for local data that's already in a Chunk or not a remote handle needing rewrapping - data_chunk = tochunk(data, from_proc) - else - ctx = Sch.eager_context() - id = rand(Int) - @maybelog ctx timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) - data_chunk = move_rewrap(from_proc, to_proc, data) - @maybelog ctx timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=data_chunk)) - end + ctx = Sch.eager_context() + id = rand(Int) + @maybelog ctx timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) + data_chunk = move_rewrap(from_proc, to_proc, orig_space, dest_space, data) + @maybelog ctx timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=data_chunk)) @assert memory_space(data_chunk) == dest_space "space mismatch! $dest_space (dest) != $(memory_space(data_chunk)) (actual) ($(typeof(data)) (data) vs. $(typeof(data_chunk)) (chunk)), spaces ($orig_space -> $dest_space)" dest_space_args[data] = data_chunk state.remote_arg_to_original[data_chunk] = data diff --git a/src/datadeps/chunkview.jl b/src/datadeps/chunkview.jl index 04b581c17..e6e1d4840 100644 --- a/src/datadeps/chunkview.jl +++ b/src/datadeps/chunkview.jl @@ -27,38 +27,51 @@ end Base.view(c::DTask, slices...) = view(fetch(c; raw=true), slices...) 
-aliasing(x::ChunkView) = - throw(ConcurrencyViolationError("Cannot query aliasing of a ChunkView directly")) +function aliasing(x::ChunkView{N}) where N + return remotecall_fetch(root_worker_id(x.chunk.processor), x.chunk, x.slices) do x, slices + x = unwrap(x) + v = view(x, slices...) + return aliasing(v) + end +end memory_space(x::ChunkView) = memory_space(x.chunk) isremotehandle(x::ChunkView) = true # This definition is here because it's so similar to ChunkView -function move_rewrap(from_proc::Processor, to_proc::Processor, v::SubArray) +function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::SubArray) to_w = root_worker_id(to_proc) p_chunk = aliased_object!(parent(v)) do p - return remotecall_fetch(to_w, from_proc, to_proc, p) do from_proc, to_proc, p + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p) do from_proc, to_proc, from_space, to_space, p return tochunk(move(from_proc, to_proc, p), to_proc) end end inds = parentindices(v) - return remotecall_fetch(to_w, from_proc, to_proc, p_chunk, inds) do from_proc, to_proc, p_chunk, inds + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk, inds) do from_proc, to_proc, from_space, to_space, p_chunk, inds p_new = move(from_proc, to_proc, p_chunk) v_new = view(p_new, inds...) 
return tochunk(v_new, to_proc) end end -function move_rewrap(from_proc::Processor, to_proc::Processor, slice::ChunkView) +function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, slice::ChunkView) to_w = root_worker_id(to_proc) p_chunk = aliased_object!(slice.chunk) do p_chunk - return remotecall_fetch(to_w, from_proc, to_proc, p_chunk) do from_proc, to_proc, p_chunk + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk return tochunk(move(from_proc, to_proc, p_chunk), to_proc) end end - return remotecall_fetch(to_w, from_proc, to_proc, p_chunk, slice.slices) do from_proc, to_proc, p_chunk, inds + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk, slice.slices) do from_proc, to_proc, from_space, to_space, p_chunk, inds p_new = move(from_proc, to_proc, p_chunk) v_new = view(p_new, inds...) return tochunk(v_new, to_proc) end end +function move(from_proc::Processor, to_proc::Processor, slice::ChunkView) + to_w = root_worker_id(to_proc) + return remotecall_fetch(to_w, from_proc, to_proc, slice.chunk, slice.slices) do from_proc, to_proc, chunk, slices + chunk_new = move(from_proc, to_proc, chunk) + v_new = view(chunk_new, slices...) + return tochunk(v_new, to_proc) + end +end Base.fetch(slice::ChunkView) = view(fetch(slice.chunk), slice.slices...) 
\ No newline at end of file From 44e203226b0e6332c1e9fe2c34894201d78ce37a Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Sat, 15 Nov 2025 10:49:48 -0700 Subject: [PATCH 09/28] datadeps: Signature fixups and small cleanups --- src/datadeps/aliasing.jl | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index 5a3913644..55363101d 100644 --- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -502,7 +502,6 @@ function merge_history!(state::DataDepsState, arg_w::ArgumentWrapper, other_arg_ history = state.arg_history[arg_w] @opcounter :merge_history @opcounter :merge_history_complexity length(history) - largest_value_update!(length(history)) origin_space = state.arg_origin[other_arg_w.arg] for other_entry in state.arg_history[other_arg_w] write_num_tuple = HistoryEntry(AliasingWrapper(NoAliasing()), origin_space, other_entry.write_num) @@ -665,29 +664,25 @@ function get_or_generate_slot!(state, dest_space, data) end return state.remote_args[dest_space][data] end -function move_rewrap(from_proc::Processor, to_proc::Processor, data) +function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data) return aliased_object!(data) do data return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, data) end end -function remotecall_endpoint(f, from_proc, to_proc, orig_space, dest_space, data) +function remotecall_endpoint(f, from_proc, to_proc, from_space, to_space, data) to_w = root_worker_id(to_proc) if to_w == myid() data_converted = f(move(from_proc, to_proc, data)) - return tochunk(data_converted, to_proc, dest_space) + return tochunk(data_converted, to_proc) end - return remotecall_fetch(to_w, from_proc, to_proc, dest_space, data) do from_proc, to_proc, dest_space, data + return remotecall_fetch(to_w, from_proc, to_proc, to_space, data) do from_proc, to_proc, to_space, data data_converted = 
f(move(from_proc, to_proc, data)) - return tochunk(data_converted, to_proc, dest_space) + return tochunk(data_converted, to_proc) end end const ALIASED_OBJECT_CACHE = TaskLocalValue{Union{Dict{AbstractAliasing,Chunk}, Nothing}}(()->nothing) @warn "Document these public methods" maxlog=1 # TODO: Use state to cache aliasing() results -function declare_aliased_object!(x; ainfo=aliasing(x, identity)) - cache = ALIASED_OBJECT_CACHE[] - cache[ainfo] = x -end function aliased_object!(x; ainfo=aliasing(x, identity)) cache = ALIASED_OBJECT_CACHE[] if haskey(cache, ainfo) @@ -710,11 +705,6 @@ function aliased_object!(f, x; ainfo=aliasing(x, identity)) end return y end -function aliased_object_unwrap!(x::Chunk) - y = unwrap(x) - ainfo = aliasing(y, identity) - return unwrap(aliased_object!(x; ainfo)) -end struct DataDepsSchedulerState task_to_spec::Dict{DTask,DTaskSpec} From 3681444a08fe9c8645a847b672eb063026219a48 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 9 Dec 2025 16:22:21 -0700 Subject: [PATCH 10/28] datadeps: Fix aliased object detection around Chunks --- src/datadeps/aliasing.jl | 74 ++++++++++++++++++++++++++------------- src/datadeps/chunkview.jl | 53 +++++++++++++++++++++++++--- src/utils/chunks.jl | 3 ++ 3 files changed, 102 insertions(+), 28 deletions(-) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index 55363101d..172b4c177 100644 --- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -640,18 +640,20 @@ function generate_slot!(state::DataDepsState, dest_space, data) to_proc = first(processors(dest_space)) from_proc = first(processors(orig_space)) dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space) - ALIASED_OBJECT_CACHE[] = get!(Dict{AbstractAliasing,Chunk}, state.ainfo_backing_chunk, dest_space) + if !haskey(state.ainfo_backing_chunk, dest_space) + state.ainfo_backing_chunk[dest_space] = Dict{AbstractAliasing,Chunk}() + end + # FIXME: tochunk the cache just once per space + aliased_object_cache = 
AliasedObjectCache(tochunk(state.ainfo_backing_chunk[dest_space])) ctx = Sch.eager_context() id = rand(Int) @maybelog ctx timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) - data_chunk = move_rewrap(from_proc, to_proc, orig_space, dest_space, data) + data_chunk = move_rewrap(aliased_object_cache, from_proc, to_proc, orig_space, dest_space, data) @maybelog ctx timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=data_chunk)) @assert memory_space(data_chunk) == dest_space "space mismatch! $dest_space (dest) != $(memory_space(data_chunk)) (actual) ($(typeof(data)) (data) vs. $(typeof(data_chunk)) (chunk)), spaces ($orig_space -> $dest_space)" dest_space_args[data] = data_chunk state.remote_arg_to_original[data_chunk] = data - ALIASED_OBJECT_CACHE[] = nothing - return dest_space_args[data] end function get_or_generate_slot!(state, dest_space, data) @@ -664,8 +666,47 @@ function get_or_generate_slot!(state, dest_space, data) end return state.remote_args[dest_space][data] end -function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data) - return aliased_object!(data) do data +struct AliasedObjectCache + chunk::Chunk +end +@warn "Document these public methods" maxlog=1 +function Base.haskey(cache::AliasedObjectCache, ainfo::AbstractAliasing) + wid = root_worker_id(cache.chunk) + if wid != myid() + return remotecall_fetch(haskey, wid, cache, ainfo) + end + cache_raw = unwrap(cache.chunk)::Dict{AbstractAliasing,Chunk} + return haskey(cache_raw, ainfo) +end +function Base.getindex(cache::AliasedObjectCache, ainfo::AbstractAliasing) + wid = root_worker_id(cache.chunk) + if wid != myid() + return remotecall_fetch(getindex, wid, cache, ainfo) + end + cache_raw = unwrap(cache.chunk)::Dict{AbstractAliasing,Chunk} + return getindex(cache_raw, ainfo) +end +function Base.setindex!(cache::AliasedObjectCache, 
value::Chunk, ainfo::AbstractAliasing) + wid = root_worker_id(cache.chunk) + if wid != myid() + return remotecall_fetch(setindex!, wid, cache, value, ainfo) + end + cache_raw = unwrap(cache.chunk)::Dict{AbstractAliasing,Chunk} + cache_raw[ainfo] = value + return +end +function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data::Chunk) + # Unwrap so that we hit the right dispatch + wid = root_worker_id(data) + if wid != myid() + return remotecall_fetch(move_rewrap, wid, cache, from_proc, to_proc, from_space, to_space, data) + end + data_raw = unwrap(data) + return move_rewrap(cache, from_proc, to_proc, from_space, to_space, data_raw) +end +function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data) + # For generic data + return aliased_object!(cache, data) do data return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, data) end end @@ -680,30 +721,15 @@ function remotecall_endpoint(f, from_proc, to_proc, from_space, to_space, data) return tochunk(data_converted, to_proc) end end -const ALIASED_OBJECT_CACHE = TaskLocalValue{Union{Dict{AbstractAliasing,Chunk}, Nothing}}(()->nothing) -@warn "Document these public methods" maxlog=1 -# TODO: Use state to cache aliasing() results -function aliased_object!(x; ainfo=aliasing(x, identity)) - cache = ALIASED_OBJECT_CACHE[] - if haskey(cache, ainfo) - y = cache[ainfo] - else - @assert x isa Chunk "x must be a Chunk\nUse functor form of aliased_object!" 
- cache[ainfo] = x - y = x - end - return y -end -function aliased_object!(f, x; ainfo=aliasing(x, identity)) - cache = ALIASED_OBJECT_CACHE[] +function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(x, identity)) if haskey(cache, ainfo) - y = cache[ainfo] + return cache[ainfo] else y = f(x) @assert y isa Chunk "Didn't get a Chunk from functor" cache[ainfo] = y + return y end - return y end struct DataDepsSchedulerState diff --git a/src/datadeps/chunkview.jl b/src/datadeps/chunkview.jl index e6e1d4840..60ded6151 100644 --- a/src/datadeps/chunkview.jl +++ b/src/datadeps/chunkview.jl @@ -38,9 +38,9 @@ memory_space(x::ChunkView) = memory_space(x.chunk) isremotehandle(x::ChunkView) = true # This definition is here because it's so similar to ChunkView -function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::SubArray) +function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::SubArray) to_w = root_worker_id(to_proc) - p_chunk = aliased_object!(parent(v)) do p + p_chunk = aliased_object!(cache, parent(v)) do p return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p) do from_proc, to_proc, from_space, to_space, p return tochunk(move(from_proc, to_proc, p), to_proc) end @@ -52,9 +52,54 @@ function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::Memor return tochunk(v_new, to_proc) end end -function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, slice::ChunkView) +# FIXME: Do this programmatically via recursive dispatch +for wrapper in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular) + @eval function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::$(wrapper)) + to_w = root_worker_id(to_proc) + p_chunk = 
aliased_object!(cache, parent(v)) do p + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p) do from_proc, to_proc, from_space, to_space, p + return tochunk(move(from_proc, to_proc, p), to_proc) + end + end + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk + p_new = move(from_proc, to_proc, p_chunk) + v_new = $(wrapper)(p_new) + return tochunk(v_new, to_proc) + end + end +end +function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::Base.RefValue) + to_w = root_worker_id(to_proc) + return aliased_object!(cache, v[]) do p + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p) do from_proc, to_proc, from_space, to_space, p + return tochunk(Ref(move(from_proc, to_proc, p)), to_proc) + end + end +end +#= +function move_rewrap_recursive(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::T) where T + if isstructtype(T) + # Check all object fields (recursive) + for field in fieldnames(T) + value = getfield(x, field) + new_value = aliased_object!(cache, value) do value + return move_rewrap_recursive(cache, from_proc, to_proc, from_space, to_space, value) + end + setfield!(x, field, new_value) + end + return x + else + @warn "Cannot move-rewrap object of type $T" + return x + end +end +move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::String) = x # FIXME: Not necessarily true +move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Symbol) = x +move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Type) = x +=# +function move_rewrap(cache::AliasedObjectCache, 
from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, slice::ChunkView) to_w = root_worker_id(to_proc) - p_chunk = aliased_object!(slice.chunk) do p_chunk + p_chunk = aliased_object!(cache, slice.chunk) do p_chunk return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk return tochunk(move(from_proc, to_proc, p_chunk), to_proc) end diff --git a/src/utils/chunks.jl b/src/utils/chunks.jl index 400b49332..9f0c3b487 100644 --- a/src/utils/chunks.jl +++ b/src/utils/chunks.jl @@ -174,6 +174,9 @@ function tochunk(x::Chunk, proc=nothing, scope=nothing; rewrap=false, kwargs...) end tochunk(x::Thunk, proc=nothing, scope=nothing; kwargs...) = x +root_worker_id(chunk::Chunk) = root_worker_id(chunk.handle) +root_worker_id(dref::DRef) = dref.owner # FIXME: Migration + function savechunk(data, dir, f) sz = open(joinpath(dir, f), "w") do io serialize(io, MemPool.MMWrap(data)) From 655313644aabafdb72be7cecba3eb27460ddfbb0 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 9 Dec 2025 16:27:03 -0700 Subject: [PATCH 11/28] datadeps: Validate ManyMemorySpan inner span lengths --- src/datadeps/remainders.jl | 10 +++++++- src/utils/interval_tree.jl | 19 +++++++++++---- src/utils/memory-span.jl | 47 +++++++++++++++++++++++++++++++++++++- 3 files changed, 69 insertions(+), 7 deletions(-) diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl index 5312d0935..3d90b4423 100644 --- a/src/datadeps/remainders.jl +++ b/src/datadeps/remainders.jl @@ -127,7 +127,11 @@ function compute_remainder_for_arg!(state::DataDepsState, end # Create our remainder as an interval tree over all target ainfos + VERIFY_SPAN_CURRENT_OBJECT[] = arg_w.arg remainder = IntervalTree{ManyMemorySpan{N}}(ManyMemorySpan{N}(ntuple(i -> target_ainfos[i][j], N)) for j in 1:nspans) + for span in remainder + verify_span(span) + end # Create our tracker tracker = 
Dict{MemorySpace,Tuple{Vector{Tuple{LocalMemorySpan,LocalMemorySpan}},Set{ThunkSyncdep}}}() @@ -164,6 +168,9 @@ function compute_remainder_for_arg!(state::DataDepsState, end nspans = length(first(other_ainfos)) other_many_spans = [ManyMemorySpan{N}(ntuple(i -> other_ainfos[i][j], N)) for j in 1:nspans] + foreach(other_many_spans) do span + verify_span(span) + end if other_space == target_space # Only subtract, this data is already up-to-date in target_space @@ -187,6 +194,7 @@ function compute_remainder_for_arg!(state::DataDepsState, get_read_deps!(state, other_space, other_ainfo, write_num, tracker_other_space[2]) end end + VERIFY_SPAN_CURRENT_OBJECT[] = nothing if isempty(tracker) return NoAliasing(), 0 @@ -217,10 +225,10 @@ copy from `other_many_spans` to the subtraced portion of `remainder`. function schedule_remainder!(tracker::Vector, source_space_idx::Int, dest_space_idx::Int, remainder::IntervalTree, other_many_spans::Vector{ManyMemorySpan{N}}) where N diff = Vector{ManyMemorySpan{N}}() subtract_spans!(remainder, other_many_spans, diff) - for span in diff source_span = span.spans[source_space_idx] dest_span = span.spans[dest_space_idx] + @assert span_len(source_span) == span_len(dest_span) "Source and dest spans are not the same size: $(span_len(source_span)) != $(span_len(dest_span))" push!(tracker, (source_span, dest_span)) end end diff --git a/src/utils/interval_tree.jl b/src/utils/interval_tree.jl index e67f66b24..73c0ee4d4 100644 --- a/src/utils/interval_tree.jl +++ b/src/utils/interval_tree.jl @@ -25,6 +25,7 @@ function IntervalTree{M}(spans) where M for span in spans insert!(tree, span) end + verify_spans(tree) return tree end IntervalTree(spans::Vector{M}) where M = IntervalTree{M}(spans) @@ -44,6 +45,13 @@ function Base.collect(tree::IntervalTree{M}) where M return result end +# Useful for debugging when spans get misaligned +function verify_spans(tree::IntervalTree{ManyMemorySpan{N}}) where N + for span in tree + verify_span(span) + end +end + 
function Base.iterate(tree::IntervalTree{M}) where M state = Vector{M}() if tree.root === nothing @@ -196,6 +204,7 @@ function delete_node!(node::IntervalNode{M,E}, span::M) where {M,E} original_end = span_end(original_span) del_start = span_start(span) del_end = span_end(span) + verify_span(span) # Left portion: exists if original starts before deleted span if original_start < del_start @@ -258,10 +267,10 @@ function find_overlapping!(node::IntervalNode{M,E}, query::M, result::Vector{M}; if spans_overlap(node.span, query) if exact # Get the overlapping portion of the span - overlap_start = max(span_start(node.span), span_start(query)) - overlap_end = min(span_end(node.span), span_end(query)) - overlap = M(overlap_start, overlap_end - overlap_start) - push!(result, overlap) + overlap = span_diff(node.span, query) + if !isempty(overlap) + push!(result, overlap) + end else push!(result, node.span) end @@ -360,4 +369,4 @@ function add_remaining_portions!(tree::IntervalTree{M}, original::M, subtracted: end end end -end \ No newline at end of file +end diff --git a/src/utils/memory-span.jl b/src/utils/memory-span.jl index 91f291cbe..0664139d0 100644 --- a/src/utils/memory-span.jl +++ b/src/utils/memory-span.jl @@ -38,6 +38,18 @@ span_len(span::MemorySpan) = span.len span_end(span::MemorySpan) = span.ptr.addr + span.len spans_overlap(span1::MemorySpan, span2::MemorySpan) = span_start(span1) < span_end(span2) && span_start(span2) < span_end(span1) +function span_diff(span1::MemorySpan, span2::MemorySpan) + @assert span1.ptr.space == span2.ptr.space + start = max(span_start(span1), span_start(span2)) + stop = min(span_end(span1), span_end(span2)) + start_ptr = RemotePtr(start, span1.ptr.space) + if start < stop + len = stop - start + return MemorySpan(start_ptr, len) + else + return MemorySpan(start_ptr, 0) + end +end ### More space-efficient memory spans @@ -52,6 +64,16 @@ span_len(span::LocalMemorySpan) = span.len span_end(span::LocalMemorySpan) = span.ptr + span.len 
spans_overlap(span1::LocalMemorySpan, span2::LocalMemorySpan) = span_start(span1) < span_end(span2) && span_start(span2) < span_end(span1) +function span_diff(span1::LocalMemorySpan, span2::LocalMemorySpan) + start = max(span_start(span1), span_start(span2)) + stop = min(span_end(span1), span_end(span2)) + if start < stop + len = stop - start + return LocalMemorySpan(start, len) + else + return LocalMemorySpan(start, 0) + end +end # FIXME: Store the length separately, since it's shared by all spans struct ManyMemorySpan{N} @@ -64,6 +86,20 @@ span_end(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_end(span.s spans_overlap(span1::ManyMemorySpan{N}, span2::ManyMemorySpan{N}) where N = # N.B. The spans are assumed to be the same length and relative offset spans_overlap(span1.spans[1], span2.spans[1]) +function span_diff(span1::ManyMemorySpan{N}, span2::ManyMemorySpan{N}) where N + verify_span(span1) + verify_span(span2) + span = ManyMemorySpan(ntuple(i -> span_diff(span1.spans[i], span2.spans[i]), N)) + matches = ntuple(i->span1.spans[i].ptr == span2.spans[i].ptr, Val(N)) + @assert !(any(matches) && !all(matches)) "Spans only partially match:\n Span1: $span1\n Span2: $span2\n Result: $span" + @assert allequal(span_len, span.spans) "Uneven span_diff result:\n Span1: $span1\n Span2: $span2\n Result: $span" + verify_span(span) + return span +end +const VERIFY_SPAN_CURRENT_OBJECT = TaskLocalValue{Any}(()->nothing) +function verify_span(span::ManyMemorySpan{N}) where N + @assert allequal(span_len, span.spans) "All spans must be the same: $(map(span_len, span.spans))\nWhile processing $(typeof(VERIFY_SPAN_CURRENT_OBJECT[]))" +end struct ManyPair{N} <: Unsigned pairs::NTuple{N,UInt} @@ -78,6 +114,7 @@ Base.:(==)(x::ManyPair, y::ManyPair) = x.pairs == y.pairs Base.isless(x::ManyPair, y::ManyPair) = x.pairs[1] < y.pairs[1] Base.:(<)(x::ManyPair, y::ManyPair) = x.pairs[1] < y.pairs[1] Base.string(x::ManyPair) = "ManyPair($(x.pairs))" +Base.show(io::IO, x::ManyPair) 
= print(io, string(x)) ManyMemorySpan{N}(start::ManyPair{N}, len::ManyPair{N}) where N = ManyMemorySpan{N}(ntuple(i -> LocalMemorySpan(start.pairs[i], len.pairs[i]), N)) @@ -95,4 +132,12 @@ span_start(x::LocatorMemorySpan) = span_start(x.span) span_end(x::LocatorMemorySpan) = span_end(x.span) span_len(x::LocatorMemorySpan) = span_len(x.span) spans_overlap(span1::LocatorMemorySpan{T}, span2::LocatorMemorySpan{T}) where T = - spans_overlap(span1.span, span2.span) \ No newline at end of file + spans_overlap(span1.span, span2.span) +function span_diff(span1::LocatorMemorySpan{T}, span2::LocatorMemorySpan{T}) where T + span = LocatorMemorySpan(span_diff(span1.span, span2.span), 0) + verify_span(span) + return span +end +function verify_span(span::LocatorMemorySpan{T}) where T + verify_span(span.span) +end \ No newline at end of file From 1f85aa1d3acb654ed2c390d9eaf73edae523f40a Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 9 Dec 2025 16:28:00 -0700 Subject: [PATCH 12/28] datadeps: Optimize RemainderAliasing move! 
copies --- src/datadeps/remainders.jl | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl index 3d90b4423..de9d6ab73 100644 --- a/src/datadeps/remainders.jl +++ b/src/datadeps/remainders.jl @@ -384,28 +384,33 @@ end # Main copy function for RemainderAliasing function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space::MemorySpace, to::Chunk, from::Chunk) where S - # Get the source data for each span + # Copy the data from the source object copies = remotecall_fetch(root_worker_id(from_space), dep_mod) do dep_mod - copies = Vector{UInt8}[] - for (from_span, _) in dep_mod.spans - copy = Vector{UInt8}(undef, from_span.len) - GC.@preserve copy begin + len = sum(span_tuple->span_len(span_tuple[1]), dep_mod.spans) + copies = Vector{UInt8}(undef, len) + offset = 1 + GC.@preserve copies begin + for (from_span, _) in dep_mod.spans from_ptr = Ptr{UInt8}(from_span.ptr) - to_ptr = Ptr{UInt8}(pointer(copy)) + to_ptr = Ptr{UInt8}(pointer(copies, offset)) unsafe_copyto!(to_ptr, from_ptr, from_span.len) + offset += from_span.len end - push!(copies, copy) end + @assert offset == len+1 return copies end # Copy the data into the destination object - for (copy, (_, to_span)) in zip(copies, dep_mod.spans) - GC.@preserve copy begin - from_ptr = Ptr{UInt8}(pointer(copy)) + offset = 1 + GC.@preserve copies begin + for (_, to_span) in dep_mod.spans + from_ptr = Ptr{UInt8}(pointer(copies, offset)) to_ptr = Ptr{UInt8}(to_span.ptr) unsafe_copyto!(to_ptr, from_ptr, to_span.len) + offset += to_span.len end + @assert offset == length(copies)+1 end # Ensure that the data is visible From 6ce4dcf7a79be3d96fae396e2c157b18f55cd96c Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 9 Dec 2025 16:29:34 -0700 Subject: [PATCH 13/28] datadeps: Overhaul Datadeps tests --- src/datadeps/queue.jl | 7 + src/datadeps/remainders.jl | 16 ++ test/datadeps.jl | 402 
+++++++++++++++++++++++++++++-------- 3 files changed, 343 insertions(+), 82 deletions(-) diff --git a/src/datadeps/queue.jl b/src/datadeps/queue.jl index f8f907741..8b92f3087 100644 --- a/src/datadeps/queue.jl +++ b/src/datadeps/queue.jl @@ -206,6 +206,10 @@ function distribute_tasks!(queue::DataDepsTaskQueue) else @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" @dagdebug nothing :spawn_datadeps "Skipped copy-from (up-to-date): $origin_space" + ctx = Sch.eager_context() + id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy_skip, (;id), (;)) + @maybelog ctx timespan_finish(ctx, :datadeps_copy_skip, (;id), (;thunk_id=0, from_space=origin_space, to_space=origin_space, arg_w, from_arg=arg, to_arg=arg)) end end end @@ -520,7 +524,10 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr new_spec.options.scope = our_scope new_spec.options.exec_scope = our_scope new_spec.options.occupancy = Dict(Any=>0) + ctx = Sch.eager_context() + @maybelog ctx timespan_start(ctx, :datadeps_execute, (;thunk_id=task.uid), (;)) enqueue!(queue.upper_queue, DTaskPair(new_spec, task)) + @maybelog ctx timespan_finish(ctx, :datadeps_execute, (;thunk_id=task.uid), (;space=our_space, deps=task_arg_ws, args=remote_args)) # Update read/write tracking for arguments map_or_ntuple(task_arg_ws) do idx diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl index de9d6ab73..68354f55c 100644 --- a/src/datadeps/remainders.jl +++ b/src/datadeps/remainders.jl @@ -275,7 +275,11 @@ function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpac @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Remainder copy-to has $(length(remainder_syncdeps)) syncdeps" # Launch the remainder copy task + ctx = Sch.eager_context() + id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope 
syncdeps=remainder_syncdeps meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) + @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) # This copy task becomes a new writer for the target region add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) @@ -320,7 +324,11 @@ function enqueue_remainder_copy_from!(state::DataDepsState, dest_space::MemorySp @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Remainder copy-from has $(length(remainder_syncdeps)) syncdeps" # Launch the remainder copy task + ctx = Sch.eager_context() + id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) + @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) # This copy task becomes a new writer for the target region add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) @@ -349,7 +357,11 @@ function enqueue_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w:: @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Full copy-to has $(length(copy_syncdeps)) syncdeps" # Launch the remainder copy task + ctx = Sch.eager_context() + id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) + @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, 
to_arg=arg_dest)) # This copy task becomes a new writer for the target region add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) @@ -376,7 +388,11 @@ function enqueue_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Full copy-from has $(length(copy_syncdeps)) syncdeps" # Launch the remainder copy task + ctx = Sch.eager_context() + id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) + @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) # This copy task becomes a new writer for the target region add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) diff --git a/test/datadeps.jl b/test/datadeps.jl index 4fb873454..faf33b9b9 100644 --- a/test/datadeps.jl +++ b/test/datadeps.jl @@ -1,4 +1,5 @@ -import Dagger: ChunkView, Chunk +import Dagger: ChunkView, Chunk, AbstractAliasing, MemorySpace, ArgumentWrapper +import Dagger: aliasing, memory_space using LinearAlgebra, Graphs @testset "Memory Aliasing" begin @@ -82,7 +83,7 @@ end end function with_logs(f) - Dagger.enable_logging!(;taskdeps=true, taskargs=true) + Dagger.enable_logging!(;taskdeps=true, taskargs=true, timeline=true) try f() return Dagger.fetch_logs!() @@ -108,68 +109,296 @@ function taskdeps_for_task(logs::Dict{Int,<:Dict}, tid::Int) end error("Task $tid not found in logs") end -function test_task_dominators(logs::Dict, tid::Int, doms::Vector; all_tids::Vector=[], nondom_check::Bool=false) - g = SimpleDiGraph() - tid_to_v = Dict{Int,Int}() +function all_tasks_in_logs(logs::Dict) + all_tids = Int[] + for w in keys(logs) + _logs = logs[w] + for idx in 1:length(_logs[:core]) + 
core_log = _logs[:core][idx] + id_log = _logs[:id][idx] + if core_log.category == :add_thunk && core_log.kind == :finish + tid = id_log.thunk_id::Int + push!(all_tids, tid) + end + end + end + return all_tids +end +mutable struct FlowEntry + kind::Symbol + tid::Int + ainfo::AbstractAliasing + to_ainfo::AbstractAliasing + from_space::MemorySpace + to_space::MemorySpace + read::Bool + write::Bool +end +struct FlowCheck + read::Bool + write::Bool + arg_w::ArgumentWrapper + orig_ainfo::AbstractAliasing + orig_space::MemorySpace + function FlowCheck(kind, arg, dep_mod=identity) + if kind == :read + read = true + write = false + elseif kind == :write + read = false + write = true + elseif kind == :readwrite + read = true + write = true + else + error("Invalid kind: $kind") + end + arg_w = maybe_rewrap_arg_w(ArgumentWrapper(arg, dep_mod)) + return new(read, write, arg_w, aliasing(arg, dep_mod), memory_space(arg)) + end +end +struct FlowGraph + g::SimpleDiGraph + tid_to_v::Dict{Int,Int} + FlowGraph() = new(SimpleDiGraph(), Dict{Int,Int}()) +end +struct FlowState + flows::Dict{ArgumentWrapper,Vector{FlowEntry}} + graph::FlowGraph + FlowState() = new(Dict{ArgumentWrapper,Vector{FlowEntry}}(), FlowGraph()) +end +function maybe_rewrap_arg_w(arg_w::ArgumentWrapper) + arg = arg_w.arg + if arg isa DTask + arg = fetch(arg; raw=true) + end + if arg isa Chunk && Dagger.root_worker_id(arg) == myid() + arg = Dagger.unwrap(arg) + end + return ArgumentWrapper(arg, arg_w.dep_mod) +end +function build_dataflow(logs::Dict; verbose::Bool=false) + state = FlowState() + orig_ainfos = Dict{AbstractAliasing,AbstractAliasing}() + ainfo_arg_w = Dict{AbstractAliasing,ArgumentWrapper}() + + function add_execute!(arg_w, orig_ainfo, ainfo, tid, space, read, write) + ainfo_flows = get!(Vector{FlowEntry}, state.flows, arg_w) + # Skip duplicates (same arg 2+ times to same task) + dup_idx = findfirst(flow->flow.tid == tid, ainfo_flows) + if dup_idx === nothing + if !haskey(orig_ainfos, ainfo) + 
orig_ainfos[ainfo] = orig_ainfo + end + if !haskey(ainfo_arg_w, ainfo) + ainfo_arg_w[ainfo] = arg_w + end + verbose && println("Adding execute flow (tid $tid, space $space, read $read, write $write):\n $orig_ainfo ->\n $ainfo") + verbose && println(" $(arg_w.dep_mod), $(arg_w.arg)") + push!(ainfo_flows, FlowEntry(:execute, tid, ainfo, ainfo, space, space, read, write)) + else + # Union read and write fields + ainfo_flows[dup_idx].read |= read + ainfo_flows[dup_idx].write |= write + end + end + function add_copy!(arg_w, from_arg, to_arg, tid, from_space, to_space) + dep_mod = arg_w.dep_mod + from_ainfo = aliasing(from_arg, dep_mod) + to_ainfo = aliasing(to_arg, dep_mod) + if !haskey(orig_ainfos, from_ainfo) + orig_ainfos[from_ainfo] = from_ainfo + end + if !haskey(ainfo_arg_w, from_ainfo) + ainfo_arg_w[from_ainfo] = arg_w + end + if !haskey(ainfo_arg_w, to_ainfo) + ainfo_arg_w[to_ainfo] = arg_w + end + orig_ainfo = orig_ainfos[from_ainfo] + orig_ainfos[to_ainfo] = orig_ainfo + arg_flows = get!(Vector{FlowEntry}, state.flows, arg_w) + verbose && println("Adding copy flow (tid $tid, from_space $from_space, to_space $to_space):\n $orig_ainfo ->\n $to_ainfo") + verbose && println(" $(arg_w.dep_mod), $(arg_w.arg)") + push!(arg_flows, FlowEntry(:copy, tid, from_ainfo, to_ainfo, from_space, to_space, true, true)) + end + + # Populate graph from syncdeps seen = Set{Int}() - to_visit = copy(all_tids) + to_visit = all_tasks_in_logs(logs) while !isempty(to_visit) this_tid = popfirst!(to_visit) this_tid in seen && continue push!(seen, this_tid) - if !(this_tid in keys(tid_to_v)) - add_vertex!(g); tid_to_v[this_tid] = nv(g) + if !(this_tid in keys(state.graph.tid_to_v)) + add_vertex!(state.graph.g); state.graph.tid_to_v[this_tid] = nv(state.graph.g) end # Add syncdeps deps = taskdeps_for_task(logs, this_tid) for dep in deps - if !(dep in keys(tid_to_v)) - add_vertex!(g); tid_to_v[dep] = nv(g) + if !(dep in keys(state.graph.tid_to_v)) + add_vertex!(state.graph.g); 
state.graph.tid_to_v[dep] = nv(state.graph.g) end - add_edge!(g, tid_to_v[this_tid], tid_to_v[dep]) + add_edge!(state.graph.g, state.graph.tid_to_v[this_tid], state.graph.tid_to_v[dep]) push!(to_visit, dep) end end - state = dijkstra_shortest_paths(g, tid_to_v[tid]) - any_failed = false - @test !has_edge(g, tid_to_v[tid], tid_to_v[tid]) - any_failed |= has_edge(g, tid_to_v[tid], tid_to_v[tid]) - for dom in doms - @test state.pathcounts[tid_to_v[dom]] > 0 - if state.pathcounts[tid_to_v[dom]] == 0 - println("Expected dominance for $dom of $tid") - any_failed = true - end - end - if nondom_check - for nondom in all_tids - nondom == tid && continue - nondom in doms && continue - @test state.pathcounts[tid_to_v[nondom]] == 0 - if state.pathcounts[tid_to_v[nondom]] > 0 - println("Expected non-dominance for $nondom of $tid") - any_failed = true + + # Populate flows and graphs from datadeps logs + for w in keys(logs) + _logs = logs[w] + for idx in 1:length(_logs[:core]) + core_log = _logs[:core][idx] + id_log = _logs[:id][idx] + tl_log = _logs[:timeline][idx] + if core_log.category == :datadeps_execute && core_log.kind == :finish + tid = id_log.thunk_id + for (remote_arg, depset) in zip(tl_log.args, tl_log.deps) + for dep in depset.deps + arg_w = maybe_rewrap_arg_w(dep.arg_w) + orig_ainfo = aliasing(arg_w.arg, arg_w.dep_mod) + remote_ainfo = aliasing(remote_arg, arg_w.dep_mod) + space = memory_space(remote_arg) + add_execute!(arg_w, orig_ainfo, remote_ainfo, tid, space, dep.readdep, dep.writedep) + end + end + elseif (core_log.category == :datadeps_copy || core_log.category == :datadeps_copy_skip) && core_log.kind == :finish + tid = tl_log.thunk_id + from_space = tl_log.from_space + to_space = tl_log.to_space + from_arg = tl_log.from_arg + to_arg = tl_log.to_arg + arg_w = maybe_rewrap_arg_w(tl_log.arg_w) + add_copy!(arg_w, from_arg, to_arg, tid, from_space, to_space) + end + end + end + + return state +end +function test_dataflow(state::FlowState, checks...; 
verbose::Bool=true) + # Check that each ainfo starts and ends in the same space + for arg_w in keys(state.flows) + ainfo = aliasing(arg_w.arg, arg_w.dep_mod) + arg_flows = state.flows[arg_w] + orig_space = memory_space(arg_w.arg) #arg_flows[1].from_space + #=if ainfo != arg_flows[1].ainfo + verbose && println("Ainfo key $(ainfo) is not the same as the first flow's ainfo $(ainfo_flows[1].ainfo)") + return false + end=# + final_space = arg_flows[end].to_space + # FIXME: will_alias doesn't check across spaces + any_writes = any(flows->Dagger.will_alias(flows[1], ainfo) && any(flow->flow.write, flows[2]), state.flows) + if orig_space != final_space + if verbose + println("Arg ($(arg_w.dep_mod), $(arg_w.arg)) starts in $(orig_space) but ends in $(final_space)") + for flow in arg_flows + println(" $(flow.kind) $(flow.tid) $(flow.from_space) -> $(flow.to_space)") + end + end + return false + end + end + + # Check each flow against the previous flow, ensuring that the previous flow is a dominator of the current flow + # FIXME: Validate non-dominance when unnecessary? 
+ for arg_w in keys(state.flows) + arg_flows = state.flows[arg_w] + for (idx, flow) in enumerate(arg_flows) + if idx > 1 + prev_flow = arg_flows[idx-1] + if !prev_flow.write && !flow.write + # R->R don't depend on each other + continue + end + if !prev_flow.write && flow.write && prev_flow.kind == :execute && flow.kind == :copy && prev_flow.ainfo != flow.to_ainfo + # Copy only writes to a different ainfo, so don't depend on each other + continue + end + if flow.tid == 0 + # Ignore copy skip flows + continue + end + v = state.graph.tid_to_v[flow.tid] + prev_v = state.graph.tid_to_v[prev_flow.tid] + path_state = dijkstra_shortest_paths(state.graph.g, v; allpaths=true) + if path_state.pathcounts[prev_v] == 0 + if verbose + println("Flow $(idx-1) (tid $(prev_flow.tid), $(prev_flow.kind), R:$(prev_flow.read), W:$(prev_flow.write)) is not a dominator of flow $(idx) (tid $(flow.tid), $(flow.kind), R:$(flow.read), W:$(flow.write))") + @show length(state.flows[arg_w]) + for flow in state.flows[arg_w] + println(" $(flow.kind) $(flow.tid) $(flow.from_space) -> $(flow.to_space) (R:$(flow.read), W:$(flow.write))") + end + for flow in state.flows[arg_w] + println(" May write to: $(flow.to_ainfo)") + end + e_vs = collect(edges(state.graph.g)) + e_tids = map(e->Edge(only(filter(tv->tv[2]==src(e), state.graph.tid_to_v))[1], + only(filter(tv->tv[2]==dst(e), state.graph.tid_to_v))[1]), + e_vs) + sort!(e_tids) + for e in e_tids + s_tid, d_tid = src(e), dst(e) + println("Edge: $s_tid -(up)> $d_tid") + end + end + return false + end end end end - # For debugging purposes - if any_failed - println("Failure detected!") - println("Root: $tid") - println("Exp. 
doms: $doms") - println("All: $all_tids") - e_vs = collect(edges(g)) - e_tids = map(e->Edge(only(filter(tv->tv[2]==src(e), tid_to_v))[1], - only(filter(tv->tv[2]==dst(e), tid_to_v))[1]), - e_vs) - sort!(e_tids) - for e in e_tids - s_tid, d_tid = src(e), dst(e) - println("Edge: $s_tid -(up)> $d_tid") + # Walk through each check, ensuring that the current state of the flow matches the check + arg_locations = Dict{ArgumentWrapper,MemorySpace}() + flow_idxs = Dict{ArgumentWrapper,Int}(arg_w=>1 for arg_w in keys(state.flows)) + for (idx, check) in enumerate(checks) + # Record the original location of the ainfo + if !haskey(arg_locations, check.arg_w) + arg_locations[check.arg_w] = check.orig_space + end + + # Try to advance a flow + if !haskey(flow_idxs, check.arg_w) + if verbose + @warn "Didn't encounter argument ($(check.arg_w.dep_mod), $(check.arg_w.arg))" + println("Seen arguments:") + for arg_w in keys(state.flows) + println(" ($(arg_w.dep_mod), $(arg_w.arg))") + end + return false + end + end + flow_idx = flow_idxs[check.arg_w] + while true + if flow_idx > length(state.flows[check.arg_w]) + verbose && println("Exhausted all tasks while trying to find $(check.arg_w)") + return false + end + flow = state.flows[check.arg_w][flow_idx] + if flow.kind == :execute + # The current flow state must match the check + if flow.read == check.read && flow.write == check.write + # Match, move on to next check + flow_idx += 1 + break + else + verbose && println("Expected ($(check.read), $(check.write)), got ($(flow.read), $(flow.write))") + return false + end + elseif flow.kind == :copy + # We need to advance our ainfo location + # FIXME: Assert proper data progression (requires more complex tracking of other arguments) + #@assert flow.from_space == arg_locations[check.arg_w] + arg_locations[check.arg_w] = flow.to_space + flow_idx += 1 + end end + + flow_idxs[check.arg_w] = flow_idx end + + return true end @everywhere do_nothing(Xs...) 
= nothing @@ -205,8 +434,11 @@ function test_datadeps(;args_chunks::Bool, A = Dagger.@spawn scope=Dagger.scope(worker=args_loc) copy(A) end + @warn "Negative-test the test_dataflow helper" + # Task return values can be tracked ts = [] + local t1 logs = with_logs() do Dagger.spawn_datadeps() do t1 = Dagger.@spawn fill(42, 1) @@ -216,9 +448,12 @@ function test_datadeps(;args_chunks::Bool, end tid_1, tid_2 = task_id.(ts) @test fetch(A)[1] == 42.0 - test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) + state = build_dataflow(logs) + # FIXME: We don't record the task as a syncdep, but instead internally `fetch` the chunk - test_task_dominators(logs, tid_2, [#=tid_1=#]; all_tids=[tid_1, tid_2]) + # We don't see the :readwrite because we don't see the use of t1 + #@test test_dataflow(state, FlowCheck(:readwrite, t1)) + @test test_dataflow(state, FlowCheck(:read, t1), FlowCheck(:write, A)) # R->R Non-Aliasing ts = [] @@ -229,8 +464,8 @@ function test_datadeps(;args_chunks::Bool, end end tid_1, tid_2 = task_id.(ts) - test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) - test_task_dominators(logs, tid_2, []; all_tids=[tid_1, tid_2], nondom_check=false) + state = build_dataflow(logs) + test_dataflow(state, FlowCheck(:read, A), FlowCheck(:read, A)) # R->W Aliasing ts = [] @@ -241,8 +476,8 @@ function test_datadeps(;args_chunks::Bool, end end tid_1, tid_2 = task_id.(ts) - test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) - test_task_dominators(logs, tid_2, [tid_1]; all_tids=[tid_1, tid_2]) + state = build_dataflow(logs) + @test test_dataflow(state, FlowCheck(:read, A), FlowCheck(:write, A)) # W->W Aliasing ts = [] @@ -253,8 +488,8 @@ function test_datadeps(;args_chunks::Bool, end end tid_1, tid_2 = task_id.(ts) - test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) - test_task_dominators(logs, tid_2, [tid_1]; all_tids=[tid_1, tid_2]) + state = build_dataflow(logs) + @test test_dataflow(state, FlowCheck(:write, A), FlowCheck(:write, 
A)) # R->R Non-Self-Aliasing ts = [] @@ -265,8 +500,8 @@ function test_datadeps(;args_chunks::Bool, end end tid_1, tid_2 = task_id.(ts) - test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) - test_task_dominators(logs, tid_2, []; all_tids=[tid_1, tid_2]) + state = build_dataflow(logs) + @test test_dataflow(state, FlowCheck(:read, A), FlowCheck(:read, A)) # R->W Self-Aliasing ts = [] @@ -277,8 +512,8 @@ function test_datadeps(;args_chunks::Bool, end end tid_1, tid_2 = task_id.(ts) - test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) - test_task_dominators(logs, tid_2, [tid_1]; all_tids=[tid_1, tid_2]) + state = build_dataflow(logs) + @test test_dataflow(state, FlowCheck(:read, A), FlowCheck(:write, A)) # W->W Self-Aliasing ts = [] @@ -289,8 +524,8 @@ function test_datadeps(;args_chunks::Bool, end end tid_1, tid_2 = task_id.(ts) - test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) - test_task_dominators(logs, tid_2, [tid_1]; all_tids=[tid_1, tid_2]) + state = build_dataflow(logs) + @test test_dataflow(state, FlowCheck(:write, A), FlowCheck(:write, A)) function wrap_chunk_thunk(f, args...) 
if args_thunks || args_chunks @@ -346,17 +581,16 @@ function test_datadeps(;args_chunks::Bool, task_id.([t_ul2, t_ur2, t_ll2, t_lr2]) tids_all = [tid_A, tid_B, tid_ul, tid_ur, tid_ll, tid_lr, tid_mid, tid_ul2, tid_ur2, tid_ll2, tid_lr2] - test_task_dominators(logs, tid_A, []; all_tids=tids_all) - test_task_dominators(logs, tid_B, []; all_tids=tids_all) - test_task_dominators(logs, tid_ul, [tid_B]; all_tids=tids_all) - test_task_dominators(logs, tid_ur, [tid_B]; all_tids=tids_all) - test_task_dominators(logs, tid_ll, [tid_B]; all_tids=tids_all) - test_task_dominators(logs, tid_lr, [tid_B]; all_tids=tids_all) - test_task_dominators(logs, tid_mid, [tid_B, tid_ul, tid_ur, tid_ll, tid_lr]; all_tids=tids_all) - test_task_dominators(logs, tid_ul2, [tid_B, tid_mid, tid_ul]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_ur2, [tid_B, tid_mid, tid_ur]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_ll2, [tid_B, tid_mid, tid_ll]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_lr2, [tid_B, tid_mid, tid_lr]; all_tids=tids_all, nondom_check=false) + state = build_dataflow(logs) + @test test_dataflow(state, FlowCheck(:readwrite, A)) + @test test_dataflow(state, FlowCheck(:readwrite, B)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_ul)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_ur)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_ll)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_lr)) + for arg in [B_ul, B_ur, B_ll, B_lr] + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, arg), FlowCheck(:readwrite, B_mid), FlowCheck(:readwrite, arg)) + end # (Unit)Upper/LowerTriangular and Diagonal B_upper = wrap_chunk_thunk(UpperTriangular, B) @@ -403,18 +637,22 @@ function test_datadeps(;args_chunks::Bool, task_id.([t_upper2, t_unitupper2, t_lower2, 
t_unitlower2]) tids_all = [tid_A, tid_B, tid_upper, tid_unitupper, tid_lower, tid_unitlower, tid_diag, tid_upper2, tid_unitupper2, tid_lower2, tid_unitlower2] - test_task_dominators(logs, tid_A, []; all_tids=tids_all) - test_task_dominators(logs, tid_B, []; all_tids=tids_all) - # FIXME: Proper non-dominance checks - test_task_dominators(logs, tid_upper, [tid_B]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_unitupper, [tid_B, tid_upper]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_lower, [tid_B, tid_upper]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_unitlower, [tid_B, tid_lower]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_diag, [tid_B, tid_upper, tid_lower]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_unitlower2, [tid_B, tid_lower, tid_unitlower]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_lower2, [tid_B, tid_lower, tid_unitlower, tid_diag, tid_unitlower2]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_unitupper2, [tid_B, tid_upper, tid_unitupper]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_upper2, [tid_B, tid_upper, tid_unitupper, tid_diag, tid_unitupper2]; all_tids=tids_all, nondom_check=false) + state = build_dataflow(logs) + @test test_dataflow(state, FlowCheck(:readwrite, A)) + @test test_dataflow(state, FlowCheck(:readwrite, B)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_upper), FlowCheck(:readwrite, B_unitupper)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_upper), FlowCheck(:readwrite, B_lower)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_lower), FlowCheck(:readwrite, B_unitlower)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_upper), FlowCheck(:readwrite, B_lower), + FlowCheck(:readwrite, B, 
Diagonal)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_lower), FlowCheck(:readwrite, B_unitlower), + FlowCheck(:readwrite, B, Diagonal), FlowCheck(:readwrite, B_unitlower)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_lower), FlowCheck(:readwrite, B_unitlower), + FlowCheck(:readwrite, B, Diagonal), FlowCheck(:readwrite, B_unitlower), FlowCheck(:readwrite, B_lower)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_upper), FlowCheck(:readwrite, B_unitupper), + FlowCheck(:readwrite, B_unitupper)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_upper), FlowCheck(:readwrite, B_unitupper), + FlowCheck(:readwrite, B, Diagonal), FlowCheck(:readwrite, B_unitupper), FlowCheck(:readwrite, B_upper)) # Additional aliasing tests views_overlap(x, y) = Dagger.will_alias(Dagger.aliasing(x), Dagger.aliasing(y)) @@ -593,7 +831,7 @@ function test_datadeps(;args_chunks::Bool, @test isapprox(M_dense, expected) end -@testset @testset "$args_mode Data" for args_mode in (:Raw, :Chunk, :Thunk) +@testset "$args_mode Data" for args_mode in (:Raw, :Chunk, :Thunk) args_chunks = args_mode == :Chunk args_thunks = args_mode == :Thunk for nw in (1, 2) From 8559a87ef3d2872261056b9838314eb727ac2084 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 10 Dec 2025 14:19:05 -0700 Subject: [PATCH 14/28] datadeps: Validate further that RemainderAliasing is not empty --- src/datadeps/remainders.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl index 68354f55c..fd76f07a8 100644 --- a/src/datadeps/remainders.jl +++ b/src/datadeps/remainders.jl @@ -188,15 +188,15 @@ function compute_remainder_for_arg!(state::DataDepsState, (Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}(), Set{ThunkSyncdep}()) end @opcounter :compute_remainder_for_arg_schedule - schedule_remainder!(tracker_other_space[1], 
other_space_idx, target_space_idx, remainder, other_many_spans) - if compute_syncdeps + has_overlap = schedule_remainder!(tracker_other_space[1], other_space_idx, target_space_idx, remainder, other_many_spans) + if compute_syncdeps && has_overlap @assert haskey(state.ainfos_owner, other_ainfo) "[idx $idx] ainfo $(typeof(other_ainfo)) has no owner" get_read_deps!(state, other_space, other_ainfo, write_num, tracker_other_space[2]) end end VERIFY_SPAN_CURRENT_OBJECT[] = nothing - if isempty(tracker) + if isempty(tracker) || all(tracked->isempty(tracked[1]), values(tracker)) return NoAliasing(), 0 end @@ -210,6 +210,7 @@ function compute_remainder_for_arg!(state::DataDepsState, end end end + @assert !isempty(mra.remainders) "Expected at least one remainder (spaces: $spaces, tracker spaces: $(collect(keys(tracker))))" return mra, last_idx end @@ -231,6 +232,7 @@ function schedule_remainder!(tracker::Vector, source_space_idx::Int, dest_space_ @assert span_len(source_span) == span_len(dest_span) "Source and dest spans are not the same size: $(span_len(source_span)) != $(span_len(dest_span))" push!(tracker, (source_span, dest_span)) end + return !isempty(diff) end ### Remainder copy functions From 5db12d78c8482a8c7c31c9d8f37e0ecfa020acff Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Fri, 12 Dec 2025 12:01:55 -0700 Subject: [PATCH 15/28] datadeps: Fix aliasing for degenerate views --- src/memory-spaces.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index 4124bbba6..fcce572c4 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -451,10 +451,16 @@ end function aliasing(x::SubArray{T,N,A}) where {T,N,A<:Array} if isbitstype(T) S = CPURAMMemorySpace - return StridedAliasing{T,ndims(x),S}(RemotePtr{Cvoid}(pointer(parent(x))), + p = parent(x) + NA = ndims(p) + raw_inds = parentindices(x) + inds = ntuple(i->raw_inds[i] isa Integer ? 
(raw_inds[i]:raw_inds[i]) : UnitRange(raw_inds[i]), NA) + sz = ntuple(i->length(inds[i]), NA) + return StridedAliasing{T,NA,S}(RemotePtr{Cvoid}(pointer(p)), RemotePtr{Cvoid}(pointer(x)), - parentindices(x), - size(x), strides(x)) + inds, + sz, + strides(p)) else # FIXME: Also ContiguousAliasing of container #return IteratedAliasing(x) From 3c9b54182216514b22aba0ff08880a59efecde0d Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 23 Sep 2025 22:01:54 +0000 Subject: [PATCH 16/28] datadeps: Fix GPU execution --- ext/CUDAExt.jl | 9 +++++++ ext/IntelExt.jl | 9 +++++++ ext/MetalExt.jl | 9 +++++++ ext/OpenCLExt.jl | 9 +++++++ ext/ROCExt.jl | 9 +++++++ src/datadeps/remainders.jl | 54 ++++++++++++++++++++++++++++++-------- src/gpu.jl | 5 +++- src/memory-spaces.jl | 23 +++++++++------- 8 files changed, 106 insertions(+), 21 deletions(-) diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index 6b8c61f9a..9f9b8df4d 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -48,6 +48,13 @@ function Dagger.memory_space(x::CuArray) device_uuid = CUDA.uuid(dev) return CUDAVRAMMemorySpace(myid(), device_id, device_uuid) end +function Dagger.aliasing(x::CuArray{T}) where T + space = Dagger.memory_space(x) + S = typeof(space) + cuptr = pointer(x) + rptr = Dagger.RemotePtr{Cvoid}(UInt64(cuptr), space) + return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) +end Dagger.memory_spaces(proc::CuArrayDeviceProc) = Set([CUDAVRAMMemorySpace(proc.owner, proc.device, proc.device_uuid)]) Dagger.processors(space::CUDAVRAMMemorySpace) = Set([CuArrayDeviceProc(space.owner, space.device, space.device_uuid)]) @@ -75,6 +82,8 @@ function with_context!(space::CUDAVRAMMemorySpace) @assert Dagger.root_worker_id(space) == myid() with_context!(space.device) end +Dagger.with_context!(proc::CuArrayDeviceProc) = with_context!(proc) +Dagger.with_context!(space::CUDAVRAMMemorySpace) = with_context!(space) function with_context(f, x) old_ctx = context() old_stream = stream() diff --git 
a/ext/IntelExt.jl b/ext/IntelExt.jl index 74253007d..08d54ee81 100644 --- a/ext/IntelExt.jl +++ b/ext/IntelExt.jl @@ -46,6 +46,13 @@ function Dagger.memory_space(x::oneArray) return IntelVRAMMemorySpace(myid(), device_id) end _device_id(dev::ZeDevice) = findfirst(other_dev->other_dev === dev, collect(oneAPI.devices())) +function Dagger.aliasing(x::oneArray{T}) where T + space = Dagger.memory_space(x) + S = typeof(space) + gpu_ptr = pointer(x) + rptr = Dagger.RemotePtr{Cvoid}(UInt64(gpu_ptr), space) + return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) +end Dagger.memory_spaces(proc::oneArrayDeviceProc) = Set([IntelVRAMMemorySpace(proc.owner, proc.device_id)]) Dagger.processors(space::IntelVRAMMemorySpace) = Set([oneArrayDeviceProc(space.owner, space.device_id)]) @@ -68,6 +75,8 @@ function with_context!(space::IntelVRAMMemorySpace) @assert Dagger.root_worker_id(space) == myid() with_context!(space.device_id) end +Dagger.with_context!(proc::oneArrayDeviceProc) = with_context!(proc) +Dagger.with_context!(space::IntelVRAMMemorySpace) = with_context!(space) function with_context(f, x) old_drv = driver() old_dev = device() diff --git a/ext/MetalExt.jl b/ext/MetalExt.jl index 50cfc8905..21cea360a 100644 --- a/ext/MetalExt.jl +++ b/ext/MetalExt.jl @@ -43,6 +43,13 @@ function Dagger.memory_space(x::MtlArray) return MetalVRAMMemorySpace(myid(), device_id) end _device_id(dev::MtlDevice) = findfirst(other_dev->other_dev === dev, Metal.devices()) +function Dagger.aliasing(x::MtlArray{T}) where T + space = Dagger.memory_space(x) + S = typeof(space) + gpu_ptr = pointer(x) + rptr = Dagger.RemotePtr{Cvoid}(UInt64(gpu_ptr), space) + return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) +end Dagger.memory_spaces(proc::MtlArrayDeviceProc) = Set([MetalVRAMMemorySpace(proc.owner, proc.device_id)]) Dagger.processors(space::MetalVRAMMemorySpace) = Set([MtlArrayDeviceProc(space.owner, space.device_id)]) @@ -66,6 +73,8 @@ end 
function with_context!(space::MetalVRAMMemorySpace) @assert Dagger.root_worker_id(space) == myid() end +Dagger.with_context!(proc::MtlArrayDeviceProc) = with_context!(proc) +Dagger.with_context!(space::MetalVRAMMemorySpace) = with_context!(space) function with_context(f, x) with_context!(x) return f() diff --git a/ext/OpenCLExt.jl b/ext/OpenCLExt.jl index fbf73de72..f8eac930c 100644 --- a/ext/OpenCLExt.jl +++ b/ext/OpenCLExt.jl @@ -44,6 +44,13 @@ function Dagger.memory_space(x::CLArray) idx = findfirst(==(queue), QUEUES) return CLMemorySpace(myid(), idx) end +function Dagger.aliasing(x::CLArray{T}) where T + space = Dagger.memory_space(x) + S = typeof(space) + gpu_ptr = pointer(x) + rptr = Dagger.RemotePtr{Cvoid}(UInt64(gpu_ptr), space) + return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) +end Dagger.memory_spaces(proc::CLArrayDeviceProc) = Set([CLMemorySpace(proc.owner, proc.device)]) Dagger.processors(space::CLMemorySpace) = Set([CLArrayDeviceProc(space.owner, space.device)]) @@ -71,6 +78,8 @@ function with_context!(space::CLMemorySpace) @assert Dagger.root_worker_id(space) == myid() with_context!(space.device) end +Dagger.with_context!(proc::CLArrayDeviceProc) = with_context!(proc) +Dagger.with_context!(space::CLMemorySpace) = with_context!(space) function with_context(f, x) old_ctx = cl.context() old_queue = cl.queue() diff --git a/ext/ROCExt.jl b/ext/ROCExt.jl index 288c4744f..773c2bb95 100644 --- a/ext/ROCExt.jl +++ b/ext/ROCExt.jl @@ -39,6 +39,13 @@ end Dagger.root_worker_id(space::ROCVRAMMemorySpace) = space.owner Dagger.memory_space(x::ROCArray) = ROCVRAMMemorySpace(myid(), AMDGPU.device(x).device_id) +function Dagger.aliasing(x::ROCArray{T}) where T + space = Dagger.memory_space(x) + S = typeof(space) + gpu_ptr = pointer(x) + rptr = Dagger.RemotePtr{Cvoid}(UInt64(gpu_ptr), space) + return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) +end Dagger.memory_spaces(proc::ROCArrayDeviceProc) = 
Set([ROCVRAMMemorySpace(proc.owner, proc.device_id)]) Dagger.processors(space::ROCVRAMMemorySpace) = Set([ROCArrayDeviceProc(space.owner, space.device_id)]) @@ -67,6 +74,8 @@ function with_context!(space::ROCVRAMMemorySpace) @assert Dagger.root_worker_id(space) == myid() with_context!(space.device_id) end +Dagger.with_context!(proc::ROCArrayDeviceProc) = with_context!(proc) +Dagger.with_context!(space::ROCVRAMMemorySpace) = with_context!(space) function with_context(f, x) old_ctx = context() old_device = AMDGPU.device() diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl index fd76f07a8..66c3b051a 100644 --- a/src/datadeps/remainders.jl +++ b/src/datadeps/remainders.jl @@ -107,6 +107,7 @@ function compute_remainder_for_arg!(state::DataDepsState, push!(target_ainfos, LocalMemorySpan.(spans)) end nspans = length(first(target_ainfos)) + @assert all(==(nspans), length.(target_ainfos)) "Aliasing info for $(typeof(arg_w.arg))[$(arg_w.dep_mod)] has different number of spans in different memory spaces" # FIXME: This is a hack to ensure that we don't miss any history generated by aliasing(...) 
for entry in state.arg_history[arg_w] @@ -402,33 +403,42 @@ end # Main copy function for RemainderAliasing function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space::MemorySpace, to::Chunk, from::Chunk) where S + # TODO: Support direct copy between GPU memory spaces + + @assert sizeof(eltype(chunktype(from))) == sizeof(eltype(chunktype(to))) "Source and destination chunks have different element sizes: $(sizeof(eltype(chunktype(from)))) != $(sizeof(eltype(chunktype(to))))" + # Copy the data from the source object - copies = remotecall_fetch(root_worker_id(from_space), dep_mod) do dep_mod + copies = remotecall_fetch(root_worker_id(from_space), from_space, dep_mod, from) do from_space, dep_mod, from len = sum(span_tuple->span_len(span_tuple[1]), dep_mod.spans) copies = Vector{UInt8}(undef, len) - offset = 1 + from_raw = unwrap(from) + offset = UInt64(1) + with_context!(from_space) GC.@preserve copies begin for (from_span, _) in dep_mod.spans - from_ptr = Ptr{UInt8}(from_span.ptr) - to_ptr = Ptr{UInt8}(pointer(copies, offset)) - unsafe_copyto!(to_ptr, from_ptr, from_span.len) + elsize = sizeof(eltype(from_raw)) + offset_n = UInt64((offset-1) / elsize) + UInt64(1) + n = UInt64(from_span.len / elsize) + read_remainder!(copies, offset_n, from_raw, from_span.ptr, n) offset += from_span.len end end - @assert offset == len+1 + @assert offset == len+UInt64(1) return copies end # Copy the data into the destination object - offset = 1 + offset = UInt64(1) + to_raw = unwrap(to) GC.@preserve copies begin for (_, to_span) in dep_mod.spans - from_ptr = Ptr{UInt8}(pointer(copies, offset)) - to_ptr = Ptr{UInt8}(to_span.ptr) - unsafe_copyto!(to_ptr, from_ptr, to_span.len) + elsize = sizeof(eltype(to_raw)) + offset_n = UInt64((offset-1) / elsize) + UInt64(1) + n = UInt64(to_span.len / elsize) + write_remainder!(copies, offset_n, to_raw, to_span.ptr, n) offset += to_span.len end - @assert offset == length(copies)+1 + @assert offset == length(copies)+UInt64(1) end # 
Ensure that the data is visible @@ -436,3 +446,25 @@ function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space: return end + +function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::DenseArray, from_ptr::UInt64, n::UInt64) + elsize = sizeof(eltype(from)) + from_offset = UInt64((from_ptr - UInt64(pointer(from))) / elsize) + UInt64(1) + from_vec = reshape(from, prod(size(from)))::DenseVector{eltype(from)} + copies_typed = unsafe_wrap(Vector{eltype(from)}, Ptr{eltype(from)}(pointer(copies, copies_offset)), n) + copyto!(copies_typed, 1, from_vec, Int(from_offset), Int(n)) +end +function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::SubArray, from_ptr::UInt64, n::UInt64) + read_remainder!(copies, copies_offset, parent(from), from_ptr, n) +end + +function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::DenseArray, to_ptr::UInt64, n::UInt64) + elsize = sizeof(eltype(to)) + to_offset = UInt64((to_ptr - UInt64(pointer(to))) / elsize) + UInt64(1) + to_vec = reshape(to, prod(size(to)))::DenseVector{eltype(to)} + copies_typed = unsafe_wrap(Vector{eltype(to)}, Ptr{eltype(to)}(pointer(copies, copies_offset)), n) + copyto!(to_vec, Int(to_offset), copies_typed, 1, Int(n)) +end +function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::SubArray, to_ptr::UInt64, n::UInt64) + write_remainder!(copies, copies_offset, parent(to), to_ptr, n) +end diff --git a/src/gpu.jl b/src/gpu.jl index 06d749543..fa93f8076 100644 --- a/src/gpu.jl +++ b/src/gpu.jl @@ -100,4 +100,7 @@ function gpu_synchronize(kind::Symbol) gpu_synchronize(Val(kind)) end end -gpu_synchronize(::Val{:CPU}) = nothing \ No newline at end of file +gpu_synchronize(::Val{:CPU}) = nothing + +with_context!(proc::Processor) = nothing +with_context!(space::MemorySpace) = nothing \ No newline at end of file diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index fcce572c4..bd980a81d 100644 --- a/src/memory-spaces.jl +++ 
b/src/memory-spaces.jl @@ -1,6 +1,7 @@ struct CPURAMMemorySpace <: MemorySpace owner::Int end +CPURAMMemorySpace() = CPURAMMemorySpace(myid()) root_worker_id(space::CPURAMMemorySpace) = space.owner memory_space(x) = CPURAMMemorySpace(myid()) @@ -87,7 +88,8 @@ function type_may_alias(::Type{T}) where T return false end -may_alias(::MemorySpace, ::MemorySpace) = true +may_alias(::MemorySpace, ::MemorySpace) = false +may_alias(space1::M, space2::M) where M<:MemorySpace = space1 == space2 may_alias(space1::CPURAMMemorySpace, space2::CPURAMMemorySpace) = space1.owner == space2.owner abstract type AbstractAliasing end @@ -448,19 +450,22 @@ function _memory_spans(a::StridedAliasing{T,N,S}, spans, ptr, dim) where {T,N,S} return spans end -function aliasing(x::SubArray{T,N,A}) where {T,N,A<:Array} +function aliasing(x::SubArray{T,N}) where {T,N} if isbitstype(T) - S = CPURAMMemorySpace p = parent(x) + space = memory_space(p) + S = typeof(space) + parent_ptr = RemotePtr{Cvoid}(UInt64(pointer(p)), space) + ptr = RemotePtr{Cvoid}(UInt64(pointer(x)), space) NA = ndims(p) raw_inds = parentindices(x) inds = ntuple(i->raw_inds[i] isa Integer ? (raw_inds[i]:raw_inds[i]) : UnitRange(raw_inds[i]), NA) sz = ntuple(i->length(inds[i]), NA) - return StridedAliasing{T,NA,S}(RemotePtr{Cvoid}(pointer(p)), - RemotePtr{Cvoid}(pointer(x)), - inds, - sz, - strides(p)) + return StridedAliasing{T,NA,S}(parent_ptr, + ptr, + inds, + sz, + strides(p)) else # FIXME: Also ContiguousAliasing of container #return IteratedAliasing(x) @@ -577,7 +582,7 @@ end function will_alias(x_span::MemorySpan, y_span::MemorySpan) may_alias(x_span.ptr.space, y_span.ptr.space) || return false # FIXME: Allow pointer conversion instead of just failing - @assert x_span.ptr.space == y_span.ptr.space + @assert x_span.ptr.space == y_span.ptr.space "Memory spans are in different spaces: $(x_span.ptr.space) vs. 
$(y_span.ptr.space)" x_end = x_span.ptr + x_span.len - 1 y_end = y_span.ptr + y_span.len - 1 return x_span.ptr <= y_end && y_span.ptr <= x_end From 18d2234054cd569f71a51ffc3b6b45ab80adb8ed Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Sun, 14 Dec 2025 13:08:32 -0500 Subject: [PATCH 17/28] Sch: Skip set_failed! store when result already set --- src/sch/util.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sch/util.jl b/src/sch/util.jl index 3f9d7b2f6..d3b7a4804 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -238,7 +238,7 @@ function set_failed!(state, origin, thunk=origin) @dagdebug thunk :finish "Setting as failed" filter!(x -> x !== thunk, state.ready) # N.B. If origin === thunk, we assume that the caller has already set the error - if origin !== thunk + if origin !== thunk && !has_result(state, thunk) origin_ex = load_result(state, origin) if origin_ex isa RemoteException origin_ex = origin_ex.captured From f147fa2f776a36f8df762f65aa6f8a54285fc9b1 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Sun, 14 Dec 2025 13:09:22 -0500 Subject: [PATCH 18/28] scopes: Disallow constructing empty UnionScope --- src/scopes.jl | 3 +++ test/scopes.jl | 6 +++--- test/task-affinity.jl | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/scopes.jl b/src/scopes.jl index ba291bc2b..79190c292 100644 --- a/src/scopes.jl +++ b/src/scopes.jl @@ -40,6 +40,9 @@ struct UnionScope <: AbstractScope push!(scope_set, scope) end end + if isempty(scope_set) + throw(ArgumentError("Cannot construct UnionScope with no inner scopes")) + end return new((collect(scope_set)...,)) end end diff --git a/test/scopes.jl b/test/scopes.jl index fa5bf1135..55d15b349 100644 --- a/test/scopes.jl +++ b/test/scopes.jl @@ -123,8 +123,8 @@ us_es1_multi_ch = Dagger.tochunk(nothing, OSProc(), UnionScope(es1, es1)) @test fetch(Dagger.@spawn exact_scope_test(us_es1_multi_ch)) == es1.processor - # No inner scopes - @test UnionScope() isa UnionScope + # No 
inner scopes (disallowed) + @test_throws ArgumentError UnionScope() # Same inner scope @test fetch(Dagger.@spawn exact_scope_test(us_es1_ch, us_es1_ch)) == es1.processor @@ -165,7 +165,7 @@ @test Dagger.scope(:any) isa AnyScope @test Dagger.scope(:default) == DefaultScope() @test_throws ArgumentError Dagger.scope(:blah) - @test Dagger.scope(()) == UnionScope() + @test_throws ArgumentError Dagger.scope(()) @test Dagger.scope(worker=wid1) == Dagger.scope(workers=[wid1]) diff --git a/test/task-affinity.jl b/test/task-affinity.jl index f1e26295a..ce898b476 100644 --- a/test/task-affinity.jl +++ b/test/task-affinity.jl @@ -135,7 +135,7 @@ @testset "Chunk function, scope, compute_scope and result_scope" begin @everywhere g(x, y) = x * 2 + y * 3 - n = cld(numscopes, 3) + n = fld(numscopes, 3) shuffle!(availscopes) scope_a = availscopes[1:n] From 7e750940ad222d703e702ff2989665c399e090c1 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Mon, 15 Dec 2025 13:52:07 -0500 Subject: [PATCH 19/28] datadeps: Consolidate aliasing rewrap code --- src/datadeps/aliasing.jl | 97 ++++++++++++++++++++++++++++++--------- src/datadeps/chunkview.jl | 66 +------------------------- 2 files changed, 77 insertions(+), 86 deletions(-) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index 172b4c177..b475c3f9a 100644 --- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -25,19 +25,19 @@ KEY CONCEPTS: 1. ALIASING ANALYSIS: - Every mutable argument is analyzed for its memory access pattern - Memory spans are computed to determine which bytes in memory are accessed - - Objects that access overlapping memory spans are considered "aliasing" + - Arguments that access overlapping memory spans are considered "aliasing" - Examples: An array A and view(A, 2:3, 2:3) alias each other 2. 
DATA LOCALITY TRACKING: - The system tracks where the "source of truth" for each piece of data lives - As tasks execute and modify data, the source of truth may move between workers - - Each aliasing region can have its own independent source of truth location + - Each argument can have its own independent source of truth location 3. ALIASED OBJECT MANAGEMENT: - When copying arguments between workers, the system tracks "aliased objects" - This ensures that if both an array and its view need to be copied to a worker, only one copy of the underlying array is made, with the view pointing to it - - The aliased_object!() functions manage this sharing + - The aliased_object!() and move_rewrap() functions manage this sharing ALIASING INFO: -------------- @@ -96,11 +96,9 @@ MULTITHREADED BEHAVIOR (WORKS): - Task dependencies ensure correct ordering (e.g., Task 1 then Task 2) DISTRIBUTED BEHAVIOR (THE PROBLEM): -- Tasks may be scheduled on different workers - Each argument must be copied to the destination worker -- Without special handling, we would copy A to worker1 and vA to worker2 -- This creates two separate arrays, breaking the aliasing relationship -- Updates to the view on worker2 don't affect the array on worker1 +- Without special handling, we would copy A and vA independently to another worker +- This creates two separate arrays, breaking the aliasing relationship between A and vA THE SOLUTION - PARTIAL DATA MOVEMENT: ------------------------------------- @@ -695,6 +693,32 @@ function Base.setindex!(cache::AliasedObjectCache, value::Chunk, ainfo::Abstract cache_raw[ainfo] = value return end +function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(x, identity)) + if haskey(cache, ainfo) + return cache[ainfo] + else + y = f(x) + @assert y isa Chunk "Didn't get a Chunk from functor" + cache[ainfo] = y + return y + end +end +function remotecall_endpoint(f, from_proc, to_proc, from_space, to_space, data) + to_w = root_worker_id(to_proc) + if to_w == 
myid() + data_converted = f(move(from_proc, to_proc, data)) + return tochunk(data_converted, to_proc) + end + return remotecall_fetch(to_w, from_proc, to_proc, to_space, data) do from_proc, to_proc, to_space, data + data_converted = f(move(from_proc, to_proc, data)) + return tochunk(data_converted, to_proc) + end +end +function rewrap_aliased_object!(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x) + return aliased_object!(cache, x) do x + return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, x) + end +end function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data::Chunk) # Unwrap so that we hit the right dispatch wid = root_worker_id(data) @@ -710,27 +734,58 @@ function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::P return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, data) end end -function remotecall_endpoint(f, from_proc, to_proc, from_space, to_space, data) +function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::SubArray) to_w = root_worker_id(to_proc) - if to_w == myid() - data_converted = f(move(from_proc, to_proc, data)) - return tochunk(data_converted, to_proc) + p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, parent(v)) + inds = parentindices(v) + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk, inds) do from_proc, to_proc, from_space, to_space, p_chunk, inds + p_new = move(from_proc, to_proc, p_chunk) + v_new = view(p_new, inds...) 
+ return tochunk(v_new, to_proc) end - return remotecall_fetch(to_w, from_proc, to_proc, to_space, data) do from_proc, to_proc, to_space, data - data_converted = f(move(from_proc, to_proc, data)) - return tochunk(data_converted, to_proc) +end +# FIXME: Do this programmatically via recursive dispatch +for wrapper in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular) + @eval function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::$(wrapper)) + to_w = root_worker_id(to_proc) + p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, parent(v)) + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk + p_new = move(from_proc, to_proc, p_chunk) + v_new = $(wrapper)(p_new) + return tochunk(v_new, to_proc) + end end end -function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(x, identity)) - if haskey(cache, ainfo) - return cache[ainfo] +function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::Base.RefValue) + to_w = root_worker_id(to_proc) + p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, v[]) + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk + p_new = move(from_proc, to_proc, p_chunk) + v_new = Ref(p_new) + return tochunk(v_new, to_proc) + end +end +#= +function move_rewrap_recursive(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::T) where T + if isstructtype(T) + # Check all object fields (recursive) + for field in fieldnames(T) + value = getfield(x, field) + new_value = aliased_object!(cache, value) do value + return move_rewrap_recursive(cache, from_proc, to_proc, 
from_space, to_space, value) + end + setfield!(x, field, new_value) + end + return x else - y = f(x) - @assert y isa Chunk "Didn't get a Chunk from functor" - cache[ainfo] = y - return y + @warn "Cannot move-rewrap object of type $T" + return x end end +move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::String) = x # FIXME: Not necessarily true +move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Symbol) = x +move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Type) = x +=# struct DataDepsSchedulerState task_to_spec::Dict{DTask,DTaskSpec} diff --git a/src/datadeps/chunkview.jl b/src/datadeps/chunkview.jl index 60ded6151..42f32cca9 100644 --- a/src/datadeps/chunkview.jl +++ b/src/datadeps/chunkview.jl @@ -37,73 +37,9 @@ end memory_space(x::ChunkView) = memory_space(x.chunk) isremotehandle(x::ChunkView) = true -# This definition is here because it's so similar to ChunkView -function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::SubArray) - to_w = root_worker_id(to_proc) - p_chunk = aliased_object!(cache, parent(v)) do p - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p) do from_proc, to_proc, from_space, to_space, p - return tochunk(move(from_proc, to_proc, p), to_proc) - end - end - inds = parentindices(v) - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk, inds) do from_proc, to_proc, from_space, to_space, p_chunk, inds - p_new = move(from_proc, to_proc, p_chunk) - v_new = view(p_new, inds...) 
- return tochunk(v_new, to_proc) - end -end -# FIXME: Do this programmatically via recursive dispatch -for wrapper in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular) - @eval function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::$(wrapper)) - to_w = root_worker_id(to_proc) - p_chunk = aliased_object!(cache, parent(v)) do p - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p) do from_proc, to_proc, from_space, to_space, p - return tochunk(move(from_proc, to_proc, p), to_proc) - end - end - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk - p_new = move(from_proc, to_proc, p_chunk) - v_new = $(wrapper)(p_new) - return tochunk(v_new, to_proc) - end - end -end -function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::Base.RefValue) - to_w = root_worker_id(to_proc) - return aliased_object!(cache, v[]) do p - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p) do from_proc, to_proc, from_space, to_space, p - return tochunk(Ref(move(from_proc, to_proc, p)), to_proc) - end - end -end -#= -function move_rewrap_recursive(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::T) where T - if isstructtype(T) - # Check all object fields (recursive) - for field in fieldnames(T) - value = getfield(x, field) - new_value = aliased_object!(cache, value) do value - return move_rewrap_recursive(cache, from_proc, to_proc, from_space, to_space, value) - end - setfield!(x, field, new_value) - end - return x - else - @warn "Cannot move-rewrap object of type $T" - return x - end -end -move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, 
to_space::MemorySpace, x::String) = x # FIXME: Not necessarily true -move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Symbol) = x -move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Type) = x -=# function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, slice::ChunkView) to_w = root_worker_id(to_proc) - p_chunk = aliased_object!(cache, slice.chunk) do p_chunk - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk - return tochunk(move(from_proc, to_proc, p_chunk), to_proc) - end - end + p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, slice.chunk) return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk, slice.slices) do from_proc, to_proc, from_space, to_space, p_chunk, inds p_new = move(from_proc, to_proc, p_chunk) v_new = view(p_new, inds...) 
From c6649160ecb6fea38c84d685e11ba36e18fe6490 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Mon, 15 Dec 2025 13:52:35 -0500 Subject: [PATCH 20/28] HaloArray: Add aliasing methods --- src/memory-spaces.jl | 7 +++++-- src/utils/haloarray.jl | 20 +++++++++++++++++++- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index bd980a81d..d39e665cc 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -322,8 +322,11 @@ struct CombinedAliasing <: AbstractAliasing end function memory_spans(ca::CombinedAliasing) # FIXME: Don't hardcode CPURAMMemorySpace - all_spans = MemorySpan{CPURAMMemorySpace}[] - for sub_a in ca.sub_ainfos + if length(ca.sub_ainfos) == 0 + return MemorySpan{CPURAMMemorySpace}[] + end + all_spans = memory_spans(ca.sub_ainfos[1]) + for sub_a in ca.sub_ainfos[2:end] append!(all_spans, memory_spans(sub_a)) end return all_spans diff --git a/src/utils/haloarray.jl b/src/utils/haloarray.jl index 1fadbeeb6..e47b70c8b 100644 --- a/src/utils/haloarray.jl +++ b/src/utils/haloarray.jl @@ -99,4 +99,22 @@ Adapt.adapt_structure(to, H::Dagger.HaloArray) = HaloArray(Adapt.adapt(to, H.center), Adapt.adapt.(Ref(to), H.edges), Adapt.adapt.(Ref(to), H.corners), - H.halo_width) \ No newline at end of file + H.halo_width) + +function aliasing(A::HaloArray) + return CombinedAliasing([aliasing(A.center), map(aliasing, A.edges)..., map(aliasing, A.corners)...]) +end +memory_space(A::HaloArray) = memory_space(A.center) +function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, A::HaloArray) + center_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, A.center) + edge_chunks = ntuple(i->rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, A.edges[i]), length(A.edges)) + corner_chunks = ntuple(i->rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, A.corners[i]), 
length(A.corners)) + halo_width = A.halo_width + to_w = root_worker_id(to_proc) + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, center_chunk, edge_chunks, corner_chunks, halo_width) do from_proc, to_proc, from_space, to_space, center_chunk, edge_chunks, corner_chunks, halo_width + center_new = move(from_proc, to_proc, center_chunk) + edges_new = ntuple(i->move(from_proc, to_proc, edge_chunks[i]), length(edge_chunks)) + corners_new = ntuple(i->move(from_proc, to_proc, corner_chunks[i]), length(corner_chunks)) + return tochunk(HaloArray(center_new, edges_new, corners_new, halo_width), to_proc) + end +end \ No newline at end of file From 5f0e57676983d0d1b29918ef7bd4a5f5b6a814cc Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Mon, 15 Dec 2025 14:26:07 -0500 Subject: [PATCH 21/28] CI: Extend CUDA job time --- .buildkite/pipeline.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index e9177b9a9..ad94a7b6d 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -91,7 +91,7 @@ steps: codecov: true - label: Julia 1.11 (CUDA) - timeout_in_minutes: 20 + timeout_in_minutes: 30 <<: *gputest plugins: - JuliaCI/julia#v1: From 09873364761c1b67d629ff30d52ed0b41dc04570 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 16 Dec 2025 11:12:12 -0500 Subject: [PATCH 22/28] datadeps: Make IntervalTree algorithms non-recursive --- src/utils/interval_tree.jl | 212 +++++++++++++++++++++++++------------ 1 file changed, 146 insertions(+), 66 deletions(-) diff --git a/src/utils/interval_tree.jl b/src/utils/interval_tree.jl index 73c0ee4d4..8960dd69d 100644 --- a/src/utils/interval_tree.jl +++ b/src/utils/interval_tree.jl @@ -140,15 +140,36 @@ end function insert_node!(::Nothing, span::M) where M return IntervalNode(span) end -function insert_node!(node::IntervalNode{M,E}, span::M) where {M,E} - if span_start(span) <= span_start(node.span) - node.left = 
insert_node!(node.left, span) - else - node.right = insert_node!(node.right, span) +function insert_node!(root::IntervalNode{M,E}, span::M) where {M,E} + # Use a queue to track the path for updating max_end after insertion + path = Vector{IntervalNode{M,E}}() + current = root + + # Traverse to find the insertion point + while current !== nothing + push!(path, current) + if span_start(span) <= span_start(current.span) + if current.left === nothing + current.left = IntervalNode(span) + break + end + current = current.left + else + if current.right === nothing + current.right = IntervalNode(span) + break + end + current = current.right + end end - update_max_end!(node) - return node + # Update max_end for all ancestors (process in reverse order) + while !isempty(path) + node = pop!(path) + update_max_end!(node) + end + + return root end # Remove a specific span from the tree (split as needed) @@ -162,44 +183,78 @@ end function delete_node!(::Nothing, span::M) where M return nothing end -function delete_node!(node::IntervalNode{M,E}, span::M) where {M,E} - # Check for exact match first - if span_start(node.span) == span_start(span) && span_len(node.span) == span_len(span) - # Exact match, remove the node - if node.left === nothing && node.right === nothing - return nothing - elseif node.left === nothing - return node.right - elseif node.right === nothing - return node.left +function delete_node!(root::IntervalNode{M,E}, span::M) where {M,E} + # Track the path to the target node: (node, direction_to_child) + path = Vector{Tuple{IntervalNode{M,E}, Symbol}}() + current = root + target = nothing + target_type = :none # :exact or :overlap + + # Phase 1: Search for target node + while current !== nothing + is_exact = span_start(current.span) == span_start(span) && span_len(current.span) == span_len(span) + is_overlap = !is_exact && spans_overlap(current.span, span) + + if is_exact + target = current + target_type = :exact + break + elseif is_overlap + target = current + 
target_type = :overlap + break + elseif span_start(span) <= span_start(current.span) + push!(path, (current, :left)) + current = current.left else - # Node has two children - replace with inorder successor - successor = find_min(node.right) - node.span = successor.span - node.right = delete_node!(node.right, successor.span) + push!(path, (current, :right)) + current = current.right end - # Check for overlap - elseif spans_overlap(node.span, span) - # Handle overlapping spans by removing current node and adding remainders - original_span = node.span - - # Remove the current node first (same logic as exact match) - if node.left === nothing && node.right === nothing - # Leaf node - remove it and create a new subtree with remainders - remaining_node = nothing - elseif node.left === nothing - remaining_node = node.right - elseif node.right === nothing - remaining_node = node.left + end + + if target === nothing + return root + end + + # Phase 2: Compute replacement for target node + original_span = target.span + succ_path = Vector{IntervalNode{M,E}}() # Path to successor (for max_end updates) + local replacement::Union{IntervalNode{M,E}, Nothing} + + if target.left === nothing && target.right === nothing + # Leaf node + replacement = nothing + elseif target.left === nothing + # Only right child + replacement = target.right + elseif target.right === nothing + # Only left child + replacement = target.left + else + # Two children - find and remove inorder successor + successor = find_min(target.right) + + if target.right === successor + # Successor is direct right child + target.right = successor.right else - # Node has two children - replace with inorder successor - successor = find_min(node.right) - node.span = successor.span - node.right = delete_node!(node.right, successor.span) - remaining_node = node + # Track path to successor for max_end updates + succ_parent = target.right + push!(succ_path, succ_parent) + while succ_parent.left !== successor + succ_parent = 
succ_parent.left + push!(succ_path, succ_parent) + end + # Remove successor by replacing with its right child + succ_parent.left = successor.right end - # Calculate and insert the remaining portions + target.span = successor.span + replacement = target + end + + # Phase 3: Handle overlap case - add remaining portions + if target_type == :overlap original_start = span_start(original_span) original_end = span_end(original_span) del_start = span_start(span) @@ -212,7 +267,7 @@ function delete_node!(node::IntervalNode{M,E}, span::M) where {M,E} if left_end > original_start left_span = M(original_start, left_end - original_start) if !isempty(left_span) - remaining_node = insert_node!(remaining_node, left_span) + replacement = insert_node!(replacement, left_span) end end end @@ -223,22 +278,39 @@ function delete_node!(node::IntervalNode{M,E}, span::M) where {M,E} if original_end > right_start right_span = M(right_start, original_end - right_start) if !isempty(right_span) - remaining_node = insert_node!(remaining_node, right_span) + replacement = insert_node!(replacement, right_span) end end end + end - return remaining_node - elseif span_start(span) <= span_start(node.span) - node.left = delete_node!(node.left, span) + # Phase 4: Update parent's child pointer + if isempty(path) + root = replacement else - node.right = delete_node!(node.right, span) + parent, dir = path[end] + if dir == :left + parent.left = replacement + else + parent.right = replacement + end end - if node !== nothing - update_max_end!(node) + # Phase 5: Update max_end in correct order (bottom-up) + # First: successor path (if any) + for i in length(succ_path):-1:1 + update_max_end!(succ_path[i]) end - return node + # Second: target node (if it wasn't removed) + if replacement === target + update_max_end!(target) + end + # Third: main path (ancestors of target) + for i in length(path):-1:1 + update_max_end!(path[i][1]) + end + + return root end function find_min(node::IntervalNode) @@ -263,27 +335,35 @@ 
function find_overlapping!(::Nothing, query::M, result::Vector{M}; exact::Bool=t return end function find_overlapping!(node::IntervalNode{M,E}, query::M, result::Vector{M}; exact::Bool=true) where {M,E} - # Check if current node overlaps with query - if spans_overlap(node.span, query) - if exact - # Get the overlapping portion of the span - overlap = span_diff(node.span, query) - if !isempty(overlap) - push!(result, overlap) + # Use a queue for breadth-first traversal + queue = Vector{IntervalNode{M,E}}() + push!(queue, node) + + while !isempty(queue) + current = popfirst!(queue) + + # Check if current node overlaps with query + if spans_overlap(current.span, query) + if exact + # Get the overlapping portion of the span + overlap = span_diff(current.span, query) + if !isempty(overlap) + push!(result, overlap) + end + else + push!(result, current.span) end - else - push!(result, node.span) end - end - # Recursively search left subtree if it might contain overlapping intervals - if node.left !== nothing && node.left.max_end > span_start(query) - find_overlapping!(node.left, query, result; exact) - end + # Enqueue left subtree if it might contain overlapping intervals + if current.left !== nothing && current.left.max_end > span_start(query) + push!(queue, current.left) + end - # Recursively search right subtree if query extends beyond current node's start - if node.right !== nothing && span_end(query) > span_start(node.span) - find_overlapping!(node.right, query, result; exact) + # Enqueue right subtree if query extends beyond current node's start + if current.right !== nothing && span_end(query) > span_start(current.span) + push!(queue, current.right) + end end end From dfed258f9f719f986722b3300c2659e36188fe92 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 16 Dec 2025 12:37:51 -0500 Subject: [PATCH 23/28] datadeps: Add TID to dagdebug statements --- src/datadeps/queue.jl | 17 +++++++++-------- src/datadeps/remainders.jl | 8 ++++---- 2 files changed, 13 
insertions(+), 12 deletions(-) diff --git a/src/datadeps/queue.jl b/src/datadeps/queue.jl index 8b92f3087..c7b5e2bc1 100644 --- a/src/datadeps/queue.jl +++ b/src/datadeps/queue.jl @@ -405,9 +405,10 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr end f = spec.fargs[1] + tid = task.uid # FIXME: May not be correct to move this under uniformity #f.value = move(default_processor(), our_proc, value(f)) - @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" + @dagdebug tid :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" # Copy raw task arguments for analysis # N.B. Used later for checking dependencies @@ -434,13 +435,13 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr # Is the data written previously or now? if !arg_ws.may_alias - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (immutable)" + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (immutable)" return arg end # Is the data writeable? 
if !arg_ws.inplace_move - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (non-writeable)" + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (non-writeable)" return arg end @@ -457,7 +458,7 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr enqueue_copy_to!(state, our_space, arg_w, value(f), idx, our_scope, task, write_num) else @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Skipped copy-to (up-to-date): $our_space" + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Skipped copy-to (up-to-date): $our_space" end end return arg_remote @@ -501,16 +502,16 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr ainfo = aliasing!(state, our_space, arg_w) dep_mod = arg_w.dep_mod if dep.writedep - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as writer" + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as writer" get_write_deps!(state, our_space, ainfo, write_num, syncdeps) else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as reader" + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as reader" get_read_deps!(state, our_space, ainfo, write_num, syncdeps) end end return end - @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Task has $(length(syncdeps)) syncdeps" + @dagdebug tid :spawn_datadeps "($(repr(value(f)))) Task has $(length(syncdeps)) syncdeps" # Launch user's task new_fargs = map_or_ntuple(task_arg_ws) do idx @@ -540,7 +541,7 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr ainfo = aliasing!(state, our_space, arg_w) dep_mod = arg_w.dep_mod if dep.writedep - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Task 
set as writer" + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Task set as writer" add_writer!(state, arg_w, our_space, ainfo, task, write_num) else add_reader!(state, arg_w, our_space, ainfo, task, write_num) diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl index 66c3b051a..93a1e3222 100644 --- a/src/datadeps/remainders.jl +++ b/src/datadeps/remainders.jl @@ -260,7 +260,7 @@ function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpac # overwritten by more recent partial updates source_space = remainder_aliasing.space - @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing remainder copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" + @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing remainder copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" # Get the source and destination arguments arg_dest = state.remote_args[dest_space][arg_w.arg] @@ -275,7 +275,7 @@ function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpac empty!(remainder_aliasing.syncdeps) # We can't bring these to move! 
get_write_deps!(state, dest_space, target_ainfo, write_num, remainder_syncdeps) - @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Remainder copy-to has $(length(remainder_syncdeps)) syncdeps" + @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Remainder copy-to has $(length(remainder_syncdeps)) syncdeps" # Launch the remainder copy task ctx = Sch.eager_context() @@ -344,7 +344,7 @@ function enqueue_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w:: source_space = state.arg_owner[arg_w] target_ainfo = aliasing!(state, dest_space, arg_w) - @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing full copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" + @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing full copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" # Get the source and destination arguments arg_dest = state.remote_args[dest_space][arg_w.arg] @@ -357,7 +357,7 @@ function enqueue_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w:: get_read_deps!(state, source_space, source_ainfo, write_num, copy_syncdeps) get_write_deps!(state, dest_space, target_ainfo, write_num, copy_syncdeps) - @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Full copy-to has $(length(copy_syncdeps)) syncdeps" + @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Full copy-to has $(length(copy_syncdeps)) syncdeps" # Launch the remainder copy task ctx = Sch.eager_context() From 9a832e513a62176c17938752b297fc69303ec39e Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 31 Dec 2025 12:04:30 -0500 Subject: [PATCH 24/28] datadeps: Fix split-brain in aliasing object cache --- src/datadeps/aliasing.jl | 144 ++++++++++++++++++++++++++------------- 1 file changed, 95 insertions(+), 49 deletions(-) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index 
b475c3f9a..260af39ce 100644 --- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -241,6 +241,95 @@ struct HistoryEntry write_num::Int end +struct AliasedObjectCacheStore + keys::Vector{AbstractAliasing} + derived::Dict{AbstractAliasing,AbstractAliasing} + stored::Dict{MemorySpace,Set{AbstractAliasing}} + values::Dict{MemorySpace,Dict{AbstractAliasing,Chunk}} +end +AliasedObjectCacheStore() = + AliasedObjectCacheStore(Vector{AbstractAliasing}(), + Dict{AbstractAliasing,AbstractAliasing}(), + Dict{MemorySpace,Set{AbstractAliasing}}(), + Dict{MemorySpace,Dict{AbstractAliasing,Chunk}}()) + +function is_stored(cache::AliasedObjectCacheStore, space::MemorySpace, ainfo::AbstractAliasing) + if !haskey(cache.stored, space) + return false + end + if !haskey(cache.derived, ainfo) + return false + end + key = cache.derived[ainfo] + return key in cache.stored[space] +end +function get_stored(cache::AliasedObjectCacheStore, space::MemorySpace, ainfo::AbstractAliasing) + @assert is_stored(cache, space, ainfo) "Cache does not have key $ainfo" + key = cache.derived[ainfo] + return cache.values[space][key] +end +function set_stored!(cache::AliasedObjectCacheStore, dest_space::MemorySpace, value::Chunk, ainfo::AbstractAliasing, orig_space::MemorySpace) + @assert !is_stored(cache, dest_space, ainfo) "Cache already has key $ainfo" + if !haskey(cache.derived, ainfo) + push!(cache.keys, ainfo) + cache.derived[ainfo] = ainfo + push!(get!(Set{AbstractAliasing}, cache.stored, orig_space), ainfo) + key = ainfo + else + key = cache.derived[ainfo] + end + value_ainfo = aliasing(value, identity) + cache.derived[value_ainfo] = key + push!(get!(Set{AbstractAliasing}, cache.stored, dest_space), key) + values_dict = get!(Dict{AbstractAliasing,Chunk}, cache.values, dest_space) + values_dict[key] = value + return +end + +struct AliasedObjectCache + space::MemorySpace + chunk::Chunk +end +function is_stored(cache::AliasedObjectCache, ainfo::AbstractAliasing) + wid = 
root_worker_id(cache.chunk) + if wid != myid() + return remotecall_fetch(is_stored, wid, cache, ainfo) + end + cache_raw = unwrap(cache.chunk)::AliasedObjectCacheStore + return is_stored(cache_raw, cache.space, ainfo) +end +function get_stored(cache::AliasedObjectCache, ainfo::AbstractAliasing) + wid = root_worker_id(cache.chunk) + if wid != myid() + return remotecall_fetch(get_stored, wid, cache, ainfo) + end + cache_raw = unwrap(cache.chunk)::AliasedObjectCacheStore + return get_stored(cache_raw, cache.space, ainfo) +end +function set_stored!(cache::AliasedObjectCache, value::Chunk, ainfo::AbstractAliasing, orig_space::MemorySpace) + wid = root_worker_id(cache.chunk) + if wid != myid() + return remotecall_fetch(set_stored!, wid, cache, value, ainfo, orig_space) + end + cache_raw = unwrap(cache.chunk)::AliasedObjectCacheStore + set_stored!(cache_raw, cache.space, value, ainfo, orig_space) + return +end +function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(x, identity)) + if is_stored(cache, ainfo) + return get_stored(cache, ainfo) + else + y = f(x) + @assert y isa Chunk "Didn't get a Chunk from functor" + @assert memory_space(y) == cache.space "Space mismatch! $(memory_space(y)) != $(cache.space)" + if memory_space(x) != cache.space + @assert ainfo != aliasing(y, identity) "Aliasing mismatch! 
$ainfo == $(aliasing(y, identity))" + end + set_stored!(cache, y, ainfo, memory_space(x)) + return y + end +end + struct DataDepsState # The mapping of original raw argument to its Chunk raw_arg_to_chunk::IdDict{Any,Chunk} @@ -280,7 +369,7 @@ struct DataDepsState # The mapping of, for a given memory space, the backing Chunks that an ainfo references # Used by slot generation to replace the backing Chunks during move - ainfo_backing_chunk::Dict{MemorySpace,Dict{AbstractAliasing,Chunk}} + ainfo_backing_chunk::Chunk{AliasedObjectCacheStore} # Cache of argument's supports_inplace_move query result supports_inplace_cache::IdDict{Any,Bool} @@ -315,10 +404,10 @@ struct DataDepsState remote_arg_to_original = IdDict{Any,Any}() remote_arg_w = Dict{ArgumentWrapper,Dict{MemorySpace,ArgumentWrapper}}() ainfo_arg = Dict{AliasingWrapper,Set{ArgumentWrapper}}() + arg_history = Dict{ArgumentWrapper,Vector{HistoryEntry}}() arg_owner = Dict{ArgumentWrapper,MemorySpace}() arg_overlaps = Dict{ArgumentWrapper,Set{ArgumentWrapper}}() - ainfo_backing_chunk = Dict{MemorySpace,Dict{AbstractAliasing,Chunk}}() - arg_history = Dict{ArgumentWrapper,Vector{HistoryEntry}}() + ainfo_backing_chunk = tochunk(AliasedObjectCacheStore()) supports_inplace_cache = IdDict{Any,Bool}() ainfo_cache = Dict{ArgumentWrapper,AliasingWrapper}() @@ -329,7 +418,7 @@ struct DataDepsState ainfos_owner = Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}() ainfos_readers = Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}() - return new(arg_to_chunk, arg_origin, remote_args, remote_arg_to_original, remote_arg_w, ainfo_arg, arg_owner, arg_overlaps, ainfo_backing_chunk, arg_history, + return new(arg_to_chunk, arg_origin, remote_args, remote_arg_to_original, remote_arg_w, ainfo_arg, arg_history, arg_owner, arg_overlaps, ainfo_backing_chunk, supports_inplace_cache, ainfo_cache, ainfos_lookup, ainfos_overlaps, ainfos_owner, ainfos_readers) end end @@ -638,11 +727,7 @@ function generate_slot!(state::DataDepsState, dest_space, 
data) to_proc = first(processors(dest_space)) from_proc = first(processors(orig_space)) dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space) - if !haskey(state.ainfo_backing_chunk, dest_space) - state.ainfo_backing_chunk[dest_space] = Dict{AbstractAliasing,Chunk}() - end - # FIXME: tochunk the cache just once per space - aliased_object_cache = AliasedObjectCache(tochunk(state.ainfo_backing_chunk[dest_space])) + aliased_object_cache = AliasedObjectCache(dest_space, state.ainfo_backing_chunk) ctx = Sch.eager_context() id = rand(Int) @maybelog ctx timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) @@ -664,45 +749,6 @@ function get_or_generate_slot!(state, dest_space, data) end return state.remote_args[dest_space][data] end -struct AliasedObjectCache - chunk::Chunk -end -@warn "Document these public methods" maxlog=1 -function Base.haskey(cache::AliasedObjectCache, ainfo::AbstractAliasing) - wid = root_worker_id(cache.chunk) - if wid != myid() - return remotecall_fetch(haskey, wid, cache, ainfo) - end - cache_raw = unwrap(cache.chunk)::Dict{AbstractAliasing,Chunk} - return haskey(cache_raw, ainfo) -end -function Base.getindex(cache::AliasedObjectCache, ainfo::AbstractAliasing) - wid = root_worker_id(cache.chunk) - if wid != myid() - return remotecall_fetch(getindex, wid, cache, ainfo) - end - cache_raw = unwrap(cache.chunk)::Dict{AbstractAliasing,Chunk} - return getindex(cache_raw, ainfo) -end -function Base.setindex!(cache::AliasedObjectCache, value::Chunk, ainfo::AbstractAliasing) - wid = root_worker_id(cache.chunk) - if wid != myid() - return remotecall_fetch(setindex!, wid, cache, value, ainfo) - end - cache_raw = unwrap(cache.chunk)::Dict{AbstractAliasing,Chunk} - cache_raw[ainfo] = value - return -end -function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(x, identity)) - if haskey(cache, ainfo) - return cache[ainfo] - else - y = f(x) - @assert y isa Chunk "Didn't get a 
Chunk from functor" - cache[ainfo] = y - return y - end -end function remotecall_endpoint(f, from_proc, to_proc, from_space, to_space, data) to_w = root_worker_id(to_proc) if to_w == myid() @@ -765,7 +811,7 @@ function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::P return tochunk(v_new, to_proc) end end -#= +#= FIXME: Make this work so we can automatically move-rewrap recursive objects function move_rewrap_recursive(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::T) where T if isstructtype(T) # Check all object fields (recursive) From 3be30908b84153d5de297b2be483dde5071cdf14 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 31 Dec 2025 13:20:28 -0500 Subject: [PATCH 25/28] datadeps: Reduce remainder restart distance --- src/datadeps/remainders.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl index 93a1e3222..bb42b513f 100644 --- a/src/datadeps/remainders.jl +++ b/src/datadeps/remainders.jl @@ -86,13 +86,14 @@ function compute_remainder_for_arg!(state::DataDepsState, target_space::MemorySpace, arg_w::ArgumentWrapper, write_num::Int; compute_syncdeps::Bool=true) - @label restart - - # Determine all memory spaces of the history spaces_set = Set{MemorySpace}() push!(spaces_set, target_space) owner_space = state.arg_owner[arg_w] push!(spaces_set, owner_space) + + @label restart + + # Determine all memory spaces of the history for entry in state.arg_history[arg_w] push!(spaces_set, entry.space) end From cefbc2150ac43a92a254bbf91defe878c86cb50e Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 31 Dec 2025 13:21:50 -0500 Subject: [PATCH 26/28] datadeps: Properly handle nested structures for remainder copies --- src/datadeps/remainders.jl | 87 +++++++++++++++++++++++++++++--------- src/utils/haloarray.jl | 20 +++++++++ 2 files changed, 87 insertions(+), 20 deletions(-) diff --git 
a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl index bb42b513f..ed0a9fc1c 100644 --- a/src/datadeps/remainders.jl +++ b/src/datadeps/remainders.jl @@ -406,8 +406,6 @@ end function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space::MemorySpace, to::Chunk, from::Chunk) where S # TODO: Support direct copy between GPU memory spaces - @assert sizeof(eltype(chunktype(from))) == sizeof(eltype(chunktype(to))) "Source and destination chunks have different element sizes: $(sizeof(eltype(chunktype(from)))) != $(sizeof(eltype(chunktype(to))))" - # Copy the data from the source object copies = remotecall_fetch(root_worker_id(from_space), from_space, dep_mod, from) do from_space, dep_mod, from len = sum(span_tuple->span_len(span_tuple[1]), dep_mod.spans) @@ -417,10 +415,7 @@ function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space: with_context!(from_space) GC.@preserve copies begin for (from_span, _) in dep_mod.spans - elsize = sizeof(eltype(from_raw)) - offset_n = UInt64((offset-1) / elsize) + UInt64(1) - n = UInt64(from_span.len / elsize) - read_remainder!(copies, offset_n, from_raw, from_span.ptr, n) + read_remainder!(copies, offset, from_raw, from_span.ptr, from_span.len) offset += from_span.len end end @@ -433,10 +428,7 @@ function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space: to_raw = unwrap(to) GC.@preserve copies begin for (_, to_span) in dep_mod.spans - elsize = sizeof(eltype(to_raw)) - offset_n = UInt64((offset-1) / elsize) + UInt64(1) - n = UInt64(to_span.len / elsize) - write_remainder!(copies, offset_n, to_raw, to_span.ptr, n) + write_remainder!(copies, offset, to_raw, to_span.ptr, to_span.len) offset += to_span.len end @assert offset == length(copies)+UInt64(1) @@ -448,24 +440,79 @@ function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space: return end -function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::DenseArray, from_ptr::UInt64, 
n::UInt64) +function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::Array, from_ptr::UInt64, len::UInt64) + elsize = sizeof(eltype(from)) + @assert len / elsize == round(UInt64, len / elsize) "Span length is not an integer multiple of the element size: $(len) / $(elsize) = $(len / elsize) (elsize: $elsize)" + n = UInt64(len / elsize) + from_offset_n = UInt64((from_ptr - UInt64(pointer(from))) / elsize) + UInt64(1) + from_vec = reshape(from, prod(size(from)))::DenseVector{eltype(from)} + # unsafe_wrap(Array, ...) doesn't like unaligned memory + unsafe_copyto!(Ptr{eltype(from)}(pointer(copies, copies_offset)), pointer(from_vec, from_offset_n), n) +end +function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::DenseArray, from_ptr::UInt64, len::UInt64) elsize = sizeof(eltype(from)) - from_offset = UInt64((from_ptr - UInt64(pointer(from))) / elsize) + UInt64(1) + @assert len / elsize == round(UInt64, len / elsize) "Span length is not an integer multiple of the element size: $(len) / $(elsize) = $(len / elsize) (elsize: $elsize)" + n = UInt64(len / elsize) + from_offset_n = UInt64((from_ptr - UInt64(pointer(from))) / elsize) + UInt64(1) from_vec = reshape(from, prod(size(from)))::DenseVector{eltype(from)} copies_typed = unsafe_wrap(Vector{eltype(from)}, Ptr{eltype(from)}(pointer(copies, copies_offset)), n) - copyto!(copies_typed, 1, from_vec, Int(from_offset), Int(n)) + copyto!(copies_typed, 1, from_vec, Int(from_offset_n), Int(n)) end -function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::SubArray, from_ptr::UInt64, n::UInt64) - read_remainder!(copies, copies_offset, parent(from), from_ptr, n) +function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from, from_ptr::UInt64, n::UInt64) + real_from = find_object_holding_ptr(from, from_ptr) + return read_remainder!(copies, copies_offset, real_from, from_ptr, n) end -function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, 
to::DenseArray, to_ptr::UInt64, n::UInt64) +function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::Array, to_ptr::UInt64, len::UInt64) elsize = sizeof(eltype(to)) - to_offset = UInt64((to_ptr - UInt64(pointer(to))) / elsize) + UInt64(1) + @assert len / elsize == round(UInt64, len / elsize) "Span length is not an integer multiple of the element size: $(len) / $(elsize) = $(len / elsize) (elsize: $elsize)" + n = UInt64(len / elsize) + to_offset_n = UInt64((to_ptr - UInt64(pointer(to))) / elsize) + UInt64(1) + to_vec = reshape(to, prod(size(to)))::DenseVector{eltype(to)} + # unsafe_wrap(Array, ...) doesn't like unaligned memory + unsafe_copyto!(pointer(to_vec, to_offset_n), Ptr{eltype(to)}(pointer(copies, copies_offset)), n) +end +function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::DenseArray, to_ptr::UInt64, len::UInt64) + elsize = sizeof(eltype(to)) + @assert len / elsize == round(UInt64, len / elsize) "Span length is not an integer multiple of the element size: $(len) / $(elsize) = $(len / elsize) (elsize: $elsize)" + n = UInt64(len / elsize) + to_offset_n = UInt64((to_ptr - UInt64(pointer(to))) / elsize) + UInt64(1) to_vec = reshape(to, prod(size(to)))::DenseVector{eltype(to)} copies_typed = unsafe_wrap(Vector{eltype(to)}, Ptr{eltype(to)}(pointer(copies, copies_offset)), n) - copyto!(to_vec, Int(to_offset), copies_typed, 1, Int(n)) + copyto!(to_vec, Int(to_offset_n), copies_typed, 1, Int(n)) +end +function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to, to_ptr::UInt64, n::UInt64) + real_to = find_object_holding_ptr(to, to_ptr) + return write_remainder!(copies, copies_offset, real_to, to_ptr, n) end -function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::SubArray, to_ptr::UInt64, n::UInt64) - write_remainder!(copies, copies_offset, parent(to), to_ptr, n) + +# Remainder copies for common objects +for wrapper in (UpperTriangular, LowerTriangular, UnitUpperTriangular, 
UnitLowerTriangular, SubArray) + @eval function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::$wrapper, from_ptr::UInt64, n::UInt64) + read_remainder!(copies, copies_offset, parent(from), from_ptr, n) + end + @eval function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::$wrapper, to_ptr::UInt64, n::UInt64) + write_remainder!(copies, copies_offset, parent(to), to_ptr, n) + end +end +# N.B. We don't handle pointer aliasing in remainder copies +function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::Base.RefValue, from_ptr::UInt64, n::UInt64) + read_remainder!(copies, copies_offset, from[], from_ptr, n) end +function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::Base.RefValue, to_ptr::UInt64, n::UInt64) + write_remainder!(copies, copies_offset, to[], to_ptr, n) +end + +function find_object_holding_ptr(A::SparseMatrixCSC, ptr::UInt64) + span = LocalMemorySpan(pointer(A.nzval), length(A.nzval)*sizeof(eltype(A.nzval))) + if span_start(span) <= ptr <= span_end(span) + return A.nzval + end + span = LocalMemorySpan(pointer(A.colptr), length(A.colptr)*sizeof(eltype(A.colptr))) + if span_start(span) <= ptr <= span_end(span) + return A.colptr + end + span = LocalMemorySpan(pointer(A.rowval), length(A.rowval)*sizeof(eltype(A.rowval))) + @assert span_start(span) <= ptr <= span_end(span) "Pointer $ptr not found in SparseMatrixCSC" + return A.rowval +end \ No newline at end of file diff --git a/src/utils/haloarray.jl b/src/utils/haloarray.jl index e47b70c8b..e34cf2b05 100644 --- a/src/utils/haloarray.jl +++ b/src/utils/haloarray.jl @@ -117,4 +117,24 @@ function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::P corners_new = ntuple(i->move(from_proc, to_proc, corner_chunks[i]), length(corner_chunks)) return tochunk(HaloArray(center_new, edges_new, corners_new, halo_width), to_proc) end +end +function find_object_holding_ptr(object::HaloArray, ptr::UInt64) + for i in 
1:length(object.edges) + edge = object.edges[i] + span = LocalMemorySpan(pointer(edge), length(edge)*sizeof(eltype(edge))) + if span_start(span) <= ptr <= span_end(span) + return edge + end + end + for i in 1:length(object.corners) + corner = object.corners[i] + span = LocalMemorySpan(pointer(corner), length(corner)*sizeof(eltype(corner))) + if span_start(span) <= ptr <= span_end(span) + return corner + end + end + center = object.center + span = LocalMemorySpan(pointer(center), length(center)*sizeof(eltype(center))) + @assert span_start(span) <= ptr <= span_end(span) "Pointer $ptr not found in HaloArray" + return center end \ No newline at end of file From 13bc0a39985919d5c44af9332a03ac5e7cb9a23e Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Thu, 1 Jan 2026 13:12:09 -0500 Subject: [PATCH 27/28] ManyPair: Add missing convert rule --- src/utils/memory-span.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/utils/memory-span.jl b/src/utils/memory-span.jl index 0664139d0..e6140c726 100644 --- a/src/utils/memory-span.jl +++ b/src/utils/memory-span.jl @@ -105,6 +105,7 @@ struct ManyPair{N} <: Unsigned pairs::NTuple{N,UInt} end Base.promote_rule(::Type{ManyPair}, ::Type{T}) where {T<:Integer} = ManyPair +Base.convert(::Type{ManyPair{N}}, pair::ManyPair{N}) where N = pair Base.convert(::Type{ManyPair{N}}, x::T) where {T<:Integer,N} = ManyPair(ntuple(i -> x, N)) Base.convert(::Type{ManyPair}, x::ManyPair) = x Base.:+(x::ManyPair{N}, y::ManyPair{N}) where N = ManyPair(ntuple(i -> x.pairs[i] + y.pairs[i], N)) From 188edec8ff397669614ac0a2f2b5343b459255af Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Thu, 8 Jan 2026 22:20:02 -0500 Subject: [PATCH 28/28] datadeps: Disallow aliasing=false --- src/datadeps/queue.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/datadeps/queue.jl b/src/datadeps/queue.jl index c7b5e2bc1..12e773d7a 100644 --- a/src/datadeps/queue.jl +++ b/src/datadeps/queue.jl @@ -79,6 +79,9 @@ function 
spawn_datadeps(f::Base.Callable; static::Bool=true, if !static throw(ArgumentError("Dynamic scheduling is no longer available")) end + if !aliasing + throw(ArgumentError("Aliasing analysis is no longer optional")) + end wait_all(; check_errors=true) do scheduler = something(scheduler, DATADEPS_SCHEDULER[], :roundrobin)::Symbol launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool