Skip to content

Commit 41e9512

Browse files
authored
permuteddims (#4)
1 parent 41fbf4b commit 41e9512

File tree

7 files changed

+57
-1
lines changed

7 files changed

+57
-1
lines changed

Project.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
name = "FunctionImplementations"
22
uuid = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
33
authors = ["ITensor developers <support@itensor.org> and contributors"]
4-
version = "0.2.0"
4+
version = "0.2.1"
5+
6+
[weakdeps]
7+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
9+
[extensions]
10+
FunctionImplementationsLinearAlgebraExt = "LinearAlgebra"
511

612
[compat]
13+
LinearAlgebra = "1.10"
714
julia = "1.10"
815

916
[workspace]
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
module FunctionImplementationsLinearAlgebraExt
2+
3+
import FunctionImplementations as FI
4+
import LinearAlgebra as LA
5+
6+
struct DiagonalStyle <: FI.AbstractMatrixStyle end
7+
FI.Style(::Type{<:LA.Diagonal}) = DiagonalStyle()
8+
const permuteddims_diag = DiagonalStyle()(FI.permuteddims)
9+
function permuteddims_diag(a::AbstractArray, perm)
10+
(ndims(a) == length(perm) && isperm(perm)) ||
11+
throw(ArgumentError("no valid permutation of dimensions"))
12+
return a
13+
end
14+
15+
end

src/FunctionImplementations.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@ module FunctionImplementations
22

33
include("implementation.jl")
44
include("style.jl")
5+
include("permuteddims.jl")
56

67
end

src/permuteddims.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# See: https://github.com/JuliaLang/julia/issues/53188
2+
"""
3+
permuteddims(a::AbstractArray, perm)
4+
5+
Lazy version of `permutedims`. Defaults to constructing a `Base.PermutedDimsArray`
6+
but can be customized to output a different type of array.
7+
"""
8+
permuteddims(a::AbstractArray, perm) = style(a)(permuteddims)(a, perm)
9+
function (::Implementation{typeof(permuteddims)})(a::AbstractArray, perm)
10+
return PermutedDimsArray(a, perm)
11+
end

src/style.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ define binary [`Style`](@ref) rules to control the output type.
5454
See also [`FunctionImplementations.DefaultArrayStyle`](@ref).
5555
"""
5656
abstract type AbstractArrayStyle{N} <: Style end
57+
abstract type AbstractVectorStyle <: AbstractArrayStyle{1} end
58+
abstract type AbstractMatrixStyle <: AbstractArrayStyle{2} end
5759

5860
"""
5961
`FunctionImplementations.DefaultArrayStyle{N}()` is a [`FunctionImplementations.Style`](@ref) indicating that an object

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
33
FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
4+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
45
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
56
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
67
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -11,6 +12,7 @@ FunctionImplementations = {path = ".."}
1112
[compat]
1213
Aqua = "0.8"
1314
FunctionImplementations = "0.2"
15+
LinearAlgebra = "1.10"
1416
SafeTestsets = "0.1"
1517
Suppressor = "0.2"
1618
Test = "1.10"

test/test_permuteddims.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import FunctionImplementations as FI
2+
import LinearAlgebra as LA
3+
using Test: @test, @testset
4+
5+
@testset "permuteddims" begin
6+
@testset "Array" begin
7+
a = randn(2, 3)
8+
b = FI.permuteddims(a, (2, 1))
9+
@test b PermutedDimsArray(a, (2, 1))
10+
@test size(b) == (3, 2)
11+
@test b == permutedims(a, (2, 1))
12+
end
13+
@testset "Diagonal" begin
14+
a = LA.Diagonal(randn(3))
15+
b = FI.permuteddims(a, (2, 1))
16+
@test b a
17+
end
18+
end

0 commit comments

Comments
 (0)