From d6667df5cdca13b85b8b9bcd657079864e1217ac Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sun, 28 Dec 2025 18:05:55 -0500 Subject: [PATCH 1/3] permuteddims --- Project.toml | 15 +++++++++++---- .../FunctionImplementationsLinearAlgebraExt.jl | 15 +++++++++++++++ src/FunctionImplementations.jl | 1 + src/permuteddims.jl | 11 +++++++++++ src/style.jl | 2 ++ test/Project.toml | 2 ++ test/test_permuteddims.jl | 18 ++++++++++++++++++ 7 files changed, 60 insertions(+), 4 deletions(-) create mode 100644 ext/FunctionImplementationsLinearAlgebraExt/FunctionImplementationsLinearAlgebraExt.jl create mode 100644 src/permuteddims.jl create mode 100644 test/test_permuteddims.jl diff --git a/Project.toml b/Project.toml index 17a1cbe..d64a42a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,17 @@ name = "FunctionImplementations" uuid = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c" authors = ["ITensor developers and contributors"] -version = "0.2.0" - -[compat] -julia = "1.10" +version = "0.2.1" [workspace] projects = ["benchmark", "dev", "docs", "examples", "test"] + +[weakdeps] +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[extensions] +FunctionImplementationsLinearAlgebraExt = "LinearAlgebra" + +[compat] +LinearAlgebra = "1.10" +julia = "1.10" diff --git a/ext/FunctionImplementationsLinearAlgebraExt/FunctionImplementationsLinearAlgebraExt.jl b/ext/FunctionImplementationsLinearAlgebraExt/FunctionImplementationsLinearAlgebraExt.jl new file mode 100644 index 0000000..249bc7f --- /dev/null +++ b/ext/FunctionImplementationsLinearAlgebraExt/FunctionImplementationsLinearAlgebraExt.jl @@ -0,0 +1,15 @@ +module FunctionImplementationsLinearAlgebraExt + +import FunctionImplementations as FI +import LinearAlgebra as LA + +struct DiagonalStyle <: FI.AbstractMatrixStyle end +FI.Style(::Type{<:LA.Diagonal}) = DiagonalStyle() +const permuteddims_diag = FI.Implementation(FI.permuteddims, DiagonalStyle()) +function permuteddims_diag(a::AbstractArray, perm) + (ndims(a) == length(perm) && isperm(perm)) || + throw(ArgumentError("no valid permutation of dimensions")) + return a +end + +end diff --git a/src/FunctionImplementations.jl b/src/FunctionImplementations.jl index 186c693..5581f9e 100644 --- a/src/FunctionImplementations.jl +++ b/src/FunctionImplementations.jl @@ -2,5 +2,6 @@ module FunctionImplementations include("implementation.jl") include("style.jl") +include("permuteddims.jl") end diff --git a/src/permuteddims.jl b/src/permuteddims.jl new file mode 100644 index 0000000..95758a9 --- /dev/null +++ b/src/permuteddims.jl @@ -0,0 +1,11 @@ +# See: https://github.com/JuliaLang/julia/issues/53188 +""" + permuteddims(a::AbstractArray, perm) + +Lazy version of `permutedims`. Defaults to constructing a `Base.PermutedDimsArray` +but can be customized to output a different type of array. +""" +permuteddims(a::AbstractArray, perm) = style(a)(permuteddims)(a, perm) +function (::Implementation{typeof(permuteddims)})(a::AbstractArray, perm) + return PermutedDimsArray(a, perm) +end diff --git a/src/style.jl b/src/style.jl index 8324030..c05eefd 100644 --- a/src/style.jl +++ b/src/style.jl @@ -54,6 +54,8 @@ define binary [`Style`](@ref) rules to control the output type. See also [`FunctionImplementations.DefaultArrayStyle`](@ref). """ abstract type AbstractArrayStyle{N} <: Style end +abstract type AbstractVectorStyle <: AbstractArrayStyle{1} end +abstract type AbstractMatrixStyle <: AbstractArrayStyle{2} end """ `FunctionImplementations.DefaultArrayStyle{N}()` is a [`FunctionImplementations.Style`](@ref) indicating that an object diff --git a/test/Project.toml b/test/Project.toml index 8f9d369..00f25b1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,7 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -11,6 +12,7 @@ FunctionImplementations = {path = ".."} [compat] Aqua = "0.8" FunctionImplementations = "0.2" +LinearAlgebra = "1.10" SafeTestsets = "0.1" Suppressor = "0.2" Test = "1.10" diff --git a/test/test_permuteddims.jl b/test/test_permuteddims.jl new file mode 100644 index 0000000..14b0480 --- /dev/null +++ b/test/test_permuteddims.jl @@ -0,0 +1,18 @@ +import FunctionImplementations as FI +import LinearAlgebra as LA +using Test: @test, @testset + +@testset "permuteddims" begin + @testset "Array" begin + a = randn(2, 3) + b = FI.permuteddims(a, (2, 1)) + @test b ≡ PermutedDimsArray(a, (2, 1)) + @test size(b) == (3, 2) + @test b == permutedims(a, (2, 1)) + end + @testset "Diagonal" begin + a = LA.Diagonal(randn(3)) + b = FI.permuteddims(a, (2, 1)) + @test b ≡ a + end +end From b009797abbdf5d84b9ccfaea89050675dea0c924 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sun, 28 Dec 2025 18:14:40 -0500 Subject: [PATCH 2/3] Reorganize Project.toml --- Project.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index d64a42a..be0ea81 100644 --- a/Project.toml +++ b/Project.toml @@ -3,9 +3,6 @@ uuid = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c" authors = ["ITensor developers and contributors"] version = "0.2.1" -[workspace] -projects = ["benchmark", "dev", "docs", "examples", "test"] - [weakdeps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -15,3 +12,6 @@ FunctionImplementationsLinearAlgebraExt = "LinearAlgebra" [compat] LinearAlgebra = "1.10" julia = "1.10" + +[workspace] +projects = ["benchmark", "dev", "docs", "examples", "test"] From 5621b488d203265237ca81b85233d86a95add226 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sun, 28 Dec 2025 18:23:08 -0500 Subject: [PATCH 3/3] Code style --- .../FunctionImplementationsLinearAlgebraExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/FunctionImplementationsLinearAlgebraExt/FunctionImplementationsLinearAlgebraExt.jl b/ext/FunctionImplementationsLinearAlgebraExt/FunctionImplementationsLinearAlgebraExt.jl index 249bc7f..6182553 100644 --- a/ext/FunctionImplementationsLinearAlgebraExt/FunctionImplementationsLinearAlgebraExt.jl +++ b/ext/FunctionImplementationsLinearAlgebraExt/FunctionImplementationsLinearAlgebraExt.jl @@ -5,7 +5,7 @@ import LinearAlgebra as LA struct DiagonalStyle <: FI.AbstractMatrixStyle end FI.Style(::Type{<:LA.Diagonal}) = DiagonalStyle() -const permuteddims_diag = FI.Implementation(FI.permuteddims, DiagonalStyle()) +const permuteddims_diag = DiagonalStyle()(FI.permuteddims) function permuteddims_diag(a::AbstractArray, perm) (ndims(a) == length(perm) && isperm(perm)) || throw(ArgumentError("no valid permutation of dimensions"))