diff --git a/Project.toml b/Project.toml index 2b008d4..debf4cb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,15 +1,21 @@ name = "FunctionImplementations" uuid = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c" authors = ["ITensor developers and contributors"] -version = "0.3.0" +version = "0.3.1" [weakdeps] +BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [extensions] +FunctionImplementationsBlockArraysExt = "BlockArrays" +FunctionImplementationsFillArraysExt = "FillArrays" FunctionImplementationsLinearAlgebraExt = "LinearAlgebra" [compat] +BlockArrays = "1.4" +FillArrays = "1.15" LinearAlgebra = "1.10" julia = "1.10" diff --git a/docs/src/reference.md b/docs/src/reference.md index c377d72..e1ff9b6 100644 --- a/docs/src/reference.md +++ b/docs/src/reference.md @@ -1,5 +1,5 @@ # Reference ```@autodocs -Modules = [FunctionImplementations] +Modules = [FunctionImplementations, FunctionImplementations.Concatenate] ``` diff --git a/ext/FunctionImplementationsBlockArraysExt/FunctionImplementationsBlockArraysExt.jl b/ext/FunctionImplementationsBlockArraysExt/FunctionImplementationsBlockArraysExt.jl new file mode 100644 index 0000000..2be20bf --- /dev/null +++ b/ext/FunctionImplementationsBlockArraysExt/FunctionImplementationsBlockArraysExt.jl @@ -0,0 +1,11 @@ +module FunctionImplementationsBlockArraysExt + +using BlockArrays: AbstractBlockedUnitRange, blockedrange, blocklengths +using FunctionImplementations.Concatenate: Concatenate + +function Concatenate.cat_axis(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange) + first(a1) == first(a2) == 1 || throw(ArgumentError("Concatenated axes must start at 1")) + return blockedrange([blocklengths(a1); blocklengths(a2)]) +end + +end diff --git a/ext/FunctionImplementationsFillArraysExt/FunctionImplementationsFillArraysExt.jl b/ext/FunctionImplementationsFillArraysExt/FunctionImplementationsFillArraysExt.jl new file mode 100644 index 0000000..0a9fc1a --- /dev/null +++ b/ext/FunctionImplementationsFillArraysExt/FunctionImplementationsFillArraysExt.jl @@ -0,0 +1,12 @@ +module FunctionImplementationsFillArraysExt + +using FillArrays: RectDiagonal +using FunctionImplementations: FunctionImplementations + +function FunctionImplementations.permuteddims(a::RectDiagonal, perm) + (ndims(a) == length(perm) && isperm(perm)) || + throw(ArgumentError("no valid permutation of dimensions")) + return RectDiagonal(parent(a), ntuple(d -> axes(a)[perm[d]], ndims(a))) +end + +end diff --git a/src/FunctionImplementations.jl b/src/FunctionImplementations.jl index 5581f9e..178b727 100644 --- a/src/FunctionImplementations.jl +++ b/src/FunctionImplementations.jl @@ -3,5 +3,7 @@ module FunctionImplementations include("implementation.jl") include("style.jl") include("permuteddims.jl") +include("zero.jl") +include("concatenate.jl") end diff --git a/src/concatenate.jl b/src/concatenate.jl new file mode 100644 index 0000000..0b1e644 --- /dev/null +++ b/src/concatenate.jl @@ -0,0 +1,225 @@ +""" + module Concatenate + +Alternative implementation for `Base.cat` through `Concatenate.cat(!)`. + +This is mostly a copy of the Base implementation, with the main difference being +that the destination is chosen based on all inputs instead of just the first. + +Additionally, we have an intermediate representation in terms of a Concatenated object, +reminiscent of how Broadcast works. + +The various entry points for specializing behavior are: + +* Destination selection can be achieved through: + +```julia +Base.similar(concat::Concatenated{Style}, ::Type{T}, axes) where {Style} +``` + +* Custom implementations: + +```julia +Base.copy(concat::Concatenated{Style}) # custom implementation of cat +Base.copyto!(dest, concat::Concatenated{Style}) # custom implementation of cat! based on style +Base.copyto!(dest, concat::Concatenated{Nothing}) # custom implementation of cat! based on typeof(dest) +``` +""" +module Concatenate + +export concatenate +VERSION >= v"1.11.0-DEV.469" && eval(Meta.parse("public Concatenated, cat, cat!, concatenated")) + +using Base: promote_eltypeof +import Base.Broadcast as BC +using ..FunctionImplementations: zero! + +unval(::Val{x}) where {x} = x + +function _Concatenated end + +""" + Concatenated{Style, Dims, Args <: Tuple} + +Lazy representation of the concatenation of various `Args` along `Dims`, in order to provide +hooks to customize the implementation. +""" +struct Concatenated{Style, Dims, Args <: Tuple} + style::Style + dims::Val{Dims} + args::Args + global @inline function _Concatenated( + style::Style, dims::Val{Dims}, args::Args + ) where {Style, Dims, Args <: Tuple} + return new{Style, Dims, Args}(style, dims, args) + end +end + +function Concatenated( + style::Union{BC.AbstractArrayStyle, Nothing}, dims::Val, args::Tuple + ) + return _Concatenated(style, dims, args) +end +function Concatenated(dims::Val, args::Tuple) + return Concatenated(cat_style(dims, args...), dims, args) +end +function Concatenated{Style}( + dims::Val, args::Tuple + ) where {Style <: Union{BC.AbstractArrayStyle, Nothing}} + return Concatenated(Style(), dims, args) +end + +dims(::Concatenated{<:Any, D}) where {D} = D +style(concat::Concatenated) = getfield(concat, :style) + +concatenated(dims, args...) = concatenated(Val(dims), args...) +concatenated(dims::Val, args...) = Concatenated(dims, args) + +function Base.convert( + ::Type{Concatenated{NewStyle}}, concat::Concatenated{<:Any, Dims, Args} + ) where {NewStyle, Dims, Args} + return Concatenated{NewStyle}( + concat.dims, concat.args + )::Concatenated{NewStyle, Dims, Args} +end + +# allocating the destination container +# ------------------------------------ +Base.similar(concat::Concatenated) = similar(concat, eltype(concat)) +Base.similar(concat::Concatenated, ::Type{T}) where {T} = similar(concat, T, axes(concat)) +function Base.similar(concat::Concatenated, ax) + return similar(concat, eltype(concat), ax) +end + +function Base.similar(concat::Concatenated, ::Type{T}, ax) where {T} + # Convert to a broadcasted to leverage its similar implementation. + bc = BC.Broadcasted(style(concat), identity, concat.args, ax) + return similar(bc, T) +end + +function cat_axis( + a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange... + ) + return cat_axis(cat_axis(a1, a2), a_rest...) +end +function cat_axis(a1::AbstractUnitRange, a2::AbstractUnitRange) + first(a1) == first(a2) == 1 || throw(ArgumentError("Concatenated axes must start at 1")) + return Base.OneTo(length(a1) + length(a2)) +end + +function cat_ndims(dims, as::AbstractArray...) + return max(maximum(dims), maximum(ndims, as)) +end +function cat_ndims(dims::Val, as::AbstractArray...) + return cat_ndims(unval(dims), as...) +end + +function cat_axes(dims, a::AbstractArray, as::AbstractArray...) + return ntuple(cat_ndims(dims, a, as...)) do dim + return dim in dims ? cat_axis(map(Base.Fix2(axes, dim), (a, as...))...) : axes(a, dim) + end +end +function cat_axes(dims::Val, as::AbstractArray...) + return cat_axes(unval(dims), as...) +end + +function cat_style(dims, as::AbstractArray...) + N = cat_ndims(dims, as...) + return typeof(BC.combine_styles(as...))(Val(N)) +end + +Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...) +Base.axes(concat::Concatenated) = cat_axes(dims(concat), concat.args...) +Base.size(concat::Concatenated) = length.(axes(concat)) +Base.ndims(concat::Concatenated) = cat_ndims(dims(concat), concat.args...) + +# Main logic +# ---------- +""" + concatenate(dims, args...) + +Concatenate the supplied `args` along dimensions `dims`. + +See also [`cat`](@ref) and [`cat!`](@ref). +""" +concatenate(dims, args...) = Base.materialize(concatenated(dims, args...)) + +""" + Concatenate.cat(args...; dims) + +Concatenate the supplied `args` along dimensions `dims`. + +See also [`concatenate`](@ref) and [`cat!`](@ref). +""" +cat(args...; dims) = concatenate(dims, args...) +Base.materialize(concat::Concatenated) = copy(concat) + +""" + Concatenate.cat!(dest, args...; dims) + +Concatenate the supplied `args` along dimensions `dims`, placing the result into `dest`. +""" +function cat!(dest, args...; dims) + Base.materialize!(dest, concatenated(dims, args...)) + return dest +end +Base.materialize!(dest, concat::Concatenated) = copyto!(dest, concat) + +Base.copy(concat::Concatenated) = copyto!(similar(concat), concat) + +# The following is largely copied from the Base implementation of `Base.cat`, see: +# https://github.com/JuliaLang/julia/blob/885b1cd875f101f227b345f681cc36879124d80d/base/abstractarray.jl#L1778-L1887 +_copy_or_fill!(A, inds, x) = fill!(view(A, inds...), x) +_copy_or_fill!(A, inds, x::AbstractArray) = (A[inds...] = x) + +cat_size(A) = (1,) +cat_size(A::AbstractArray) = size(A) +cat_size(A, d) = 1 +cat_size(A::AbstractArray, d) = size(A, d) + +cat_indices(A, d) = Base.OneTo(1) +cat_indices(A::AbstractArray, d) = axes(A, d) + +function __cat!(A, shape, catdims, X...) + return __cat_offset!(A, shape, catdims, ntuple(zero, length(shape)), X...) +end +function __cat_offset!(A, shape, catdims, offsets, x, X...) + # splitting the "work" on x from X... may reduce latency (fewer costly specializations) + newoffsets = __cat_offset1!(A, shape, catdims, offsets, x) + return __cat_offset!(A, shape, catdims, newoffsets, X...) +end +__cat_offset!(A, shape, catdims, offsets) = A +function __cat_offset1!(A, shape, catdims, offsets, x) + inds = ntuple(length(offsets)) do i + (i <= length(catdims) && catdims[i]) ? offsets[i] .+ cat_indices(x, i) : 1:shape[i] + end + _copy_or_fill!(A, inds, x) + newoffsets = ntuple(length(offsets)) do i + (i <= length(catdims) && catdims[i]) ? offsets[i] + cat_size(x, i) : offsets[i] + end + return newoffsets +end + +dims2cat(dims::Val) = dims2cat(unval(dims)) +function dims2cat(dims) + if any(≤(0), dims) + throw(ArgumentError("All cat dimensions must be positive integers, but got $dims")) + end + return ntuple(in(dims), maximum(dims)) +end + +# default falls back to replacing style with Nothing +# this permits specializing on typeof(dest) without ambiguities +# Note: this needs to be defined for AbstractArray specifically to avoid ambiguities with Base. +@inline function Base.copyto!(dest::AbstractArray, concat::Concatenated) + return copyto!(dest, convert(Concatenated{Nothing}, concat)) +end + +function Base.copyto!(dest::AbstractArray, concat::Concatenated{Nothing}) + catdims = dims2cat(dims(concat)) + shape = size(concat) + count(!iszero, catdims)::Int > 1 && zero!(dest) + return __cat!(dest, shape, catdims, concat.args...) +end + +end diff --git a/src/zero.jl b/src/zero.jl new file mode 100644 index 0000000..a0790db --- /dev/null +++ b/src/zero.jl @@ -0,0 +1,10 @@ +""" + zero!(a::AbstractArray) + +In-place version of `zero(a)`, sets all entries of `a` to zero. +""" +zero!(a::AbstractArray) = style(a)(zero!)(a) +function (::Implementation{typeof(zero!)})(a::AbstractArray) + fill!(a, zero(eltype(a))) + return a +end diff --git a/test/Project.toml b/test/Project.toml index 191438c..c49bbae 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,10 @@ [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" @@ -10,8 +14,12 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" FunctionImplementations = {path = ".."} [compat] +Adapt = "4" Aqua = "0.8" +BlockArrays = "1.4" +FillArrays = "1.15" FunctionImplementations = "0.3" +JLArrays = "0.3" LinearAlgebra = "1.10" SafeTestsets = "0.1" Suppressor = "0.2" diff --git a/test/test_blockarraysext.jl b/test/test_blockarraysext.jl new file mode 100644 index 0000000..2a4256e --- /dev/null +++ b/test/test_blockarraysext.jl @@ -0,0 +1,17 @@ +using BlockArrays: BlockArray, blockedrange, blockisequal +using FunctionImplementations.Concatenate: concatenate +using Test: @test, @testset + +@testset "BlockArraysExt" begin + a = BlockArray(randn(4, 4), [2, 2], [2, 2]) + b = BlockArray(randn(4, 4), [2, 2], [2, 2]) + + concat = concatenate(1, a, b) + @test axes(concat) == (Base.OneTo(8), Base.OneTo(4)) + @test blockisequal(axes(concat, 1), blockedrange([2, 2, 2, 2])) + @test blockisequal(axes(concat, 2), blockedrange([2, 2])) + @test size(concat) == (8, 4) + @test eltype(concat) ≡ Float64 + @test copy(concat) == cat(a, b; dims = 1) + @test copy(concat) isa BlockArray{Float64, 2} +end diff --git a/test/test_concatenate.jl b/test/test_concatenate.jl new file mode 100644 index 0000000..89aa0cd --- /dev/null +++ b/test/test_concatenate.jl @@ -0,0 +1,38 @@ +using Adapt: adapt +using FunctionImplementations.Concatenate: concatenated +using JLArrays: JLArray +using Test: @test, @testset + +@testset "Concatenated" for arrayt in (Array, JLArray) + dev = adapt(arrayt) + a = dev(randn(Float32, 2, 2)) + b = dev(randn(Float64, 2, 2)) + + concat = concatenated((1, 2), a, b) + @test axes(concat) == Base.OneTo.((4, 4)) + @test size(concat) == (4, 4) + @test eltype(concat) === Float64 + @test copy(concat) == cat(a, b; dims = (1, 2)) + @test copy(concat) isa arrayt{promote_type(eltype(a), eltype(b)), 2} + + concat = concatenated(1, a, b) + @test axes(concat) == Base.OneTo.((4, 2)) + @test size(concat) == (4, 2) + @test eltype(concat) === Float64 + @test copy(concat) == cat(a, b; dims = 1) + @test copy(concat) isa arrayt{promote_type(eltype(a), eltype(b)), 2} + + concat = concatenated(3, a, b) + @test axes(concat) == Base.OneTo.((2, 2, 2)) + @test size(concat) == (2, 2, 2) + @test eltype(concat) === Float64 + @test copy(concat) == cat(a, b; dims = 3) + @test copy(concat) isa arrayt{promote_type(eltype(a), eltype(b)), 3} + + concat = concatenated(4, a, b) + @test axes(concat) == Base.OneTo.((2, 2, 1, 2)) + @test size(concat) == (2, 2, 1, 2) + @test eltype(concat) === Float64 + @test copy(concat) == cat(a, b; dims = 4) + @test copy(concat) isa arrayt{promote_type(eltype(a), eltype(b)), 4} +end diff --git a/test/test_permuteddims.jl b/test/test_permuteddims.jl index 14b0480..07e4a4d 100644 --- a/test/test_permuteddims.jl +++ b/test/test_permuteddims.jl @@ -1,3 +1,4 @@ +import FillArrays as FA import FunctionImplementations as FI import LinearAlgebra as LA using Test: @test, @testset @@ -10,9 +11,15 @@ using Test: @test, @testset @test size(b) == (3, 2) @test b == permutedims(a, (2, 1)) end - @testset "Diagonal" begin + @testset "LinearAlgebra.Diagonal" begin a = LA.Diagonal(randn(3)) b = FI.permuteddims(a, (2, 1)) @test b ≡ a end + + @testset "FillArrays.RectDiagonal" begin + a = FA.RectDiagonal(randn(3), (3, 4)) + @test FI.permuteddims(a, (1, 2)) ≡ a + @test FI.permuteddims(a, (2, 1)) ≡ FA.RectDiagonal(parent(a), (4, 3)) + end end diff --git a/test/test_zero.jl b/test/test_zero.jl new file mode 100644 index 0000000..f57a08e --- /dev/null +++ b/test/test_zero.jl @@ -0,0 +1,8 @@ +using FunctionImplementations: zero! +using Test: @test, @testset + +@testset "zero!" begin + a = randn(2, 2) + zero!(a) + @test iszero(a) +end