From ccbb454b7552f6fca97f2cc607926b62cd9b4966 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 9 Jan 2026 17:09:40 -0500 Subject: [PATCH 1/2] Add permuteddims definitions for StridedViews and more FillArrays --- Project.toml | 5 +++- .../FunctionImplementationsFillArraysExt.jl | 26 ++++++++++++++++--- .../FunctionImplementationsStridedViewsExt.jl | 11 ++++++++ test/Project.toml | 2 ++ test/test_fillarraysext.jl | 26 +++++++++++++++++++ test/test_linearalgebraext.jl | 9 +++++++ test/test_permuteddims.jl | 25 ++++-------------- test/test_stridedviewsext.jl | 11 ++++++++ 8 files changed, 90 insertions(+), 25 deletions(-) create mode 100644 ext/FunctionImplementationsStridedViewsExt/FunctionImplementationsStridedViewsExt.jl create mode 100644 test/test_fillarraysext.jl create mode 100644 test/test_linearalgebraext.jl create mode 100644 test/test_stridedviewsext.jl diff --git a/Project.toml b/Project.toml index debf4cb..4f9cd4d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,22 +1,25 @@ name = "FunctionImplementations" uuid = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c" authors = ["ITensor developers and contributors"] -version = "0.3.1" +version = "0.3.2" [weakdeps] BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" [extensions] FunctionImplementationsBlockArraysExt = "BlockArrays" FunctionImplementationsFillArraysExt = "FillArrays" FunctionImplementationsLinearAlgebraExt = "LinearAlgebra" +FunctionImplementationsStridedViewsExt = "StridedViews" [compat] BlockArrays = "1.4" FillArrays = "1.15" LinearAlgebra = "1.10" +StridedViews = "0.4.1" julia = "1.10" [workspace] diff --git a/ext/FunctionImplementationsFillArraysExt/FunctionImplementationsFillArraysExt.jl b/ext/FunctionImplementationsFillArraysExt/FunctionImplementationsFillArraysExt.jl index 0a9fc1a..7ea8b8a 100644 --- a/ext/FunctionImplementationsFillArraysExt/FunctionImplementationsFillArraysExt.jl +++ b/ext/FunctionImplementationsFillArraysExt/FunctionImplementationsFillArraysExt.jl @@ -1,12 +1,30 @@ module FunctionImplementationsFillArraysExt -using FillArrays: RectDiagonal -using FunctionImplementations: FunctionImplementations +using FillArrays: FillArrays as FA, AbstractFill, RectDiagonal +import FunctionImplementations as FI -function FunctionImplementations.permuteddims(a::RectDiagonal, perm) +function check_perm(a::AbstractArray, perm) (ndims(a) == length(perm) && isperm(perm)) || throw(ArgumentError("no valid permutation of dimensions")) - return RectDiagonal(parent(a), ntuple(d -> axes(a)[perm[d]], ndims(a))) + return nothing +end + +function perm_axes(a::AbstractArray, perm) + return ntuple(d -> axes(a)[perm[d]], ndims(a)) +end + +# This could call `permutedims` directly after +# https://github.com/JuliaArrays/FillArrays.jl/pull/319 is merged. +function FI.permuteddims(a::AbstractFill, perm) + check_perm(a, perm) + return FA.fillsimilar(parent(a), perm_axes(a, perm)) +end + +# This could call `permutedims` directly after +# https://github.com/JuliaArrays/FillArrays.jl/issues/413 is fixed. +function FI.permuteddims(a::RectDiagonal, perm) + check_perm(a, perm) + return RectDiagonal(parent(a), perm_axes(a, perm)) end end diff --git a/ext/FunctionImplementationsStridedViewsExt/FunctionImplementationsStridedViewsExt.jl b/ext/FunctionImplementationsStridedViewsExt/FunctionImplementationsStridedViewsExt.jl new file mode 100644 index 0000000..251ba38 --- /dev/null +++ b/ext/FunctionImplementationsStridedViewsExt/FunctionImplementationsStridedViewsExt.jl @@ -0,0 +1,11 @@ +module FunctionImplementationsStridedViewsExt + +using FunctionImplementations: FunctionImplementations +using StridedViews: StridedView + +# `permutedims` is lazy for `StridedView` so we can just call it directly. +function FunctionImplementations.permuteddims(a::StridedView, perm) + return permutedims(a, perm) +end + +end diff --git a/test/Project.toml b/test/Project.toml index c49bbae..b74165f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,6 +7,7 @@ FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -22,5 +23,6 @@ FunctionImplementations = "0.3" JLArrays = "0.3" LinearAlgebra = "1.10" SafeTestsets = "0.1" +StridedViews = "0.4" Suppressor = "0.2" Test = "1.10" diff --git a/test/test_fillarraysext.jl b/test/test_fillarraysext.jl new file mode 100644 index 0000000..213d342 --- /dev/null +++ b/test/test_fillarraysext.jl @@ -0,0 +1,26 @@ +import FillArrays as FA +import FunctionImplementations as FI +using Test: @test, @testset + +@testset "FillArraysExt" begin + @testset "Fill" begin + a = FA.Fill(42, (2, 3)) + @test FI.permuteddims(a, (1, 2)) ≡ a + @test FI.permuteddims(a, (2, 1)) ≡ FA.Fill(42, (3, 2)) + end + @testset "Zeros" begin + a = FA.Zeros((2, 3)) + @test FI.permuteddims(a, (1, 2)) ≡ a + @test FI.permuteddims(a, (2, 1)) ≡ FA.Zeros((3, 2)) + end + @testset "Ones" begin + a = FA.Ones((2, 3)) + @test FI.permuteddims(a, (1, 2)) ≡ a + @test FI.permuteddims(a, (2, 1)) ≡ FA.Ones((3, 2)) + end + @testset "RectDiagonal" begin + a = FA.RectDiagonal(randn(3), (3, 4)) + @test FI.permuteddims(a, (1, 2)) ≡ a + @test FI.permuteddims(a, (2, 1)) ≡ FA.RectDiagonal(parent(a), (4, 3)) + end +end diff --git a/test/test_linearalgebraext.jl b/test/test_linearalgebraext.jl new file mode 100644 index 0000000..8fca0fc --- /dev/null +++ b/test/test_linearalgebraext.jl @@ -0,0 +1,9 @@ +import FunctionImplementations as FI +import LinearAlgebra as LA +using Test: @test, @testset + +@testset "LinearAlgebraExt" begin + a = LA.Diagonal(randn(3)) + b = FI.permuteddims(a, (2, 1)) + @test b ≡ a +end diff --git a/test/test_permuteddims.jl b/test/test_permuteddims.jl index 07e4a4d..fe1d777 100644 --- a/test/test_permuteddims.jl +++ b/test/test_permuteddims.jl @@ -1,25 +1,10 @@ -import FillArrays as FA import FunctionImplementations as FI -import LinearAlgebra as LA using Test: @test, @testset @testset "permuteddims" begin - @testset "Array" begin - a = randn(2, 3) - b = FI.permuteddims(a, (2, 1)) - @test b ≡ PermutedDimsArray(a, (2, 1)) - @test size(b) == (3, 2) - @test b == permutedims(a, (2, 1)) - end - @testset "LinearAlgebra.Diagonal" begin - a = LA.Diagonal(randn(3)) - b = FI.permuteddims(a, (2, 1)) - @test b ≡ a - end - - @testset "FillArrays.RectDiagonal" begin - a = FA.RectDiagonal(randn(3), (3, 4)) - @test FI.permuteddims(a, (1, 2)) ≡ a - @test FI.permuteddims(a, (2, 1)) ≡ FA.RectDiagonal(parent(a), (4, 3)) - end + a = randn(2, 3) + b = FI.permuteddims(a, (2, 1)) + @test b ≡ PermutedDimsArray(a, (2, 1)) + @test size(b) == (3, 2) + @test b == permutedims(a, (2, 1)) end diff --git a/test/test_stridedviewsext.jl b/test/test_stridedviewsext.jl new file mode 100644 index 0000000..2814c6f --- /dev/null +++ b/test/test_stridedviewsext.jl @@ -0,0 +1,11 @@ +import FunctionImplementations as FI +import StridedViews as SV +using Test: @test, @testset + +@testset "StridedViewsExt" begin + a = SV.StridedView(randn(2, 3, 4)) + b = FI.permuteddims(a, (3, 2, 1)) + @test b isa SV.StridedView + @test size(b) == (4, 3, 2) + @test b ≡ permutedims(a, (3, 2, 1)) +end \ No newline at end of file From 2618ff500aaf0d492eb2652eaa624e6086dd00a9 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 9 Jan 2026 17:14:15 -0500 Subject: [PATCH 2/2] Format --- test/test_stridedviewsext.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_stridedviewsext.jl b/test/test_stridedviewsext.jl index 2814c6f..b40d6df 100644 --- a/test/test_stridedviewsext.jl +++ b/test/test_stridedviewsext.jl @@ -8,4 +8,4 @@ using Test: @test, @testset @test b isa SV.StridedView @test size(b) == (4, 3, 2) @test b ≡ permutedims(a, (3, 2, 1)) -end \ No newline at end of file +end