diff --git a/Project.toml b/Project.toml index 17a1cbe..be0ea81 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,16 @@ name = "FunctionImplementations" uuid = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c" authors = ["ITensor developers and contributors"] -version = "0.2.0" +version = "0.2.1" + +[weakdeps] +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[extensions] +FunctionImplementationsLinearAlgebraExt = "LinearAlgebra" [compat] +LinearAlgebra = "1.10" julia = "1.10" [workspace] diff --git a/ext/FunctionImplementationsLinearAlgebraExt/FunctionImplementationsLinearAlgebraExt.jl b/ext/FunctionImplementationsLinearAlgebraExt/FunctionImplementationsLinearAlgebraExt.jl new file mode 100644 index 0000000..6182553 --- /dev/null +++ b/ext/FunctionImplementationsLinearAlgebraExt/FunctionImplementationsLinearAlgebraExt.jl @@ -0,0 +1,15 @@ +module FunctionImplementationsLinearAlgebraExt + +import FunctionImplementations as FI +import LinearAlgebra as LA + +struct DiagonalStyle <: FI.AbstractMatrixStyle end +FI.Style(::Type{<:LA.Diagonal}) = DiagonalStyle() +const permuteddims_diag = DiagonalStyle()(FI.permuteddims) +function permuteddims_diag(a::AbstractArray, perm) + (ndims(a) == length(perm) && isperm(perm)) || + throw(ArgumentError("no valid permutation of dimensions")) + return a +end + +end diff --git a/src/FunctionImplementations.jl b/src/FunctionImplementations.jl index 186c693..5581f9e 100644 --- a/src/FunctionImplementations.jl +++ b/src/FunctionImplementations.jl @@ -2,5 +2,6 @@ module FunctionImplementations include("implementation.jl") include("style.jl") +include("permuteddims.jl") end diff --git a/src/permuteddims.jl b/src/permuteddims.jl new file mode 100644 index 0000000..95758a9 --- /dev/null +++ b/src/permuteddims.jl @@ -0,0 +1,11 @@ +# See: https://github.com/JuliaLang/julia/issues/53188 +""" + permuteddims(a::AbstractArray, perm) + +Lazy version of `permutedims`. Defaults to constructing a `Base.PermutedDimsArray` +but can be customized to output a different type of array. +""" +permuteddims(a::AbstractArray, perm) = style(a)(permuteddims)(a, perm) +function (::Implementation{typeof(permuteddims)})(a::AbstractArray, perm) + return PermutedDimsArray(a, perm) +end diff --git a/src/style.jl b/src/style.jl index 8324030..c05eefd 100644 --- a/src/style.jl +++ b/src/style.jl @@ -54,6 +54,8 @@ define binary [`Style`](@ref) rules to control the output type. See also [`FunctionImplementations.DefaultArrayStyle`](@ref). """ abstract type AbstractArrayStyle{N} <: Style end +abstract type AbstractVectorStyle <: AbstractArrayStyle{1} end +abstract type AbstractMatrixStyle <: AbstractArrayStyle{2} end """ `FunctionImplementations.DefaultArrayStyle{N}()` is a [`FunctionImplementations.Style`](@ref) indicating that an object diff --git a/test/Project.toml b/test/Project.toml index 8f9d369..00f25b1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,7 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -11,6 +12,7 @@ FunctionImplementations = {path = ".."} [compat] Aqua = "0.8" FunctionImplementations = "0.2" +LinearAlgebra = "1.10" SafeTestsets = "0.1" Suppressor = "0.2" Test = "1.10" diff --git a/test/test_permuteddims.jl b/test/test_permuteddims.jl new file mode 100644 index 0000000..14b0480 --- /dev/null +++ b/test/test_permuteddims.jl @@ -0,0 +1,18 @@ +import FunctionImplementations as FI +import LinearAlgebra as LA +using Test: @test, @testset + +@testset "permuteddims" begin + @testset "Array" begin + a = randn(2, 3) + b = FI.permuteddims(a, (2, 1)) + @test b ≡ PermutedDimsArray(a, (2, 1)) + @test size(b) == (3, 2) + @test b == permutedims(a, (2, 1)) + end + @testset "Diagonal" begin + a = LA.Diagonal(randn(3)) + b = FI.permuteddims(a, (2, 1)) + @test b ≡ a + end +end