Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
Expand All @@ -29,6 +30,7 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
TensorKitAdaptExt = "Adapt"
TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
TensorKitChainRulesCoreExt = "ChainRulesCore"
TensorKitEnzymeExt = "Enzyme"
TensorKitFiniteDifferencesExt = "FiniteDifferences"
TensorKitMooncakeExt = "Mooncake"

Expand All @@ -41,6 +43,7 @@ CUDA = "5.9"
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
Combinatorics = "1"
Enzyme = "0.13.118"
FiniteDifferences = "0.12"
GPUArrays = "11.3.1"
JET = "0.9, 0.10, 0.11"
Expand Down Expand Up @@ -73,6 +76,8 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Expand All @@ -86,4 +91,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[targets]
test = ["ArgParse", "Adapt", "Aqua", "AllocCheck", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake", "JET"]
test = ["ArgParse", "Adapt", "Aqua", "AllocCheck", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake", "Enzyme", "EnzymeTestUtils", "JET"]
21 changes: 21 additions & 0 deletions ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module TensorKitEnzymeExt

using Enzyme
using TensorKit
import TensorKit as TK
using VectorInterface
using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize
import TensorOperations as TO
using MatrixAlgebraKit
using TupleTools
using Random: AbstractRNG

include("utility.jl")
include("linalg.jl")
include("vectorinterface.jl")
include("tensoroperations.jl")
include("factorizations.jl")
include("indexmanipulations.jl")
#include("planaroperations.jl")

end
134 changes: 134 additions & 0 deletions ext/TensorKitEnzymeExt/factorizations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(MatrixAlgebraKit.copy_input)},
::Type{RT},
cache,
f::Annotation,
A::Annotation{<:AbstractTensorMap}
) where {RT}
copy_shadow = cache
if !isa(A, Const) && !isnothing(copy_shadow)
add!(A.dval, copy_shadow)
end
return (nothing, nothing)
end

for (f, pb) in (
(:eig_full, :(MatrixAlgebraKit.eig_pullback!)),
(:eigh_full, :(MatrixAlgebraKit.eigh_pullback!)),
(:lq_compact, :(MatrixAlgebraKit.lq_pullback!)),
(:qr_compact, :(MatrixAlgebraKit.qr_pullback!)),
)
@eval begin
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof($f)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}
ret = $f(A.val, alg.val)
dret = make_zero(ret)
cache = (ret, dret)
return EnzymeRules.AugmentedReturn(ret, dret, cache)
end
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof($f)},
::Type{RT},
cache,
A::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}
ret, dret = cache
$pb(A.dval, A.val, ret, dret)
return (nothing, nothing)
end
end
end

for f in (:svd_compact, :svd_full)
@eval begin
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof($f)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}
USVᴴ = $f(A.val, alg.val)
dUSVᴴ = make_zero(USVᴴ)
cache = (USVᴴ, dUSVᴴ)
return EnzymeRules.AugmentedReturn(USVᴴ, dUSVᴴ, cache)
end
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof($f)},
::Type{RT},
cache,
A::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}
USVᴴ, dUSVᴴ = cache
MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴ, dUSVᴴ)
return (nothing, nothing)
end
end

# mutating version is not guaranteed to actually mutate
# so we can simply use the non-mutating version instead
f! = Symbol(f, :!)
#=@eval begin
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof($f!)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
USVᴴ::Annotation,
alg::Const,
) where {RT}
EnzymeRules.augmented_primal(func, RT, A, alg)
end
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof($f!)},
::Type{RT},
cache,
A::Annotation{<:AbstractTensorMap},
USVᴴ::Annotation,
alg::Const,
) where {RT}
EnzymeRules.reverse(func, RT, A, alg)
end
end=# #hmmmm
end

# TODO
#=
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(svd_trunc)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}

USVᴴ = svd_compact(A.val, alg.val.alg)
USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.val.trunc)
ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴ[2]), ind)
dUSVᴴtrunc = make_zero(USVᴴtrunc)
cache = (USVᴴtrunc, dUSVᴴtrunc)
return EnzymeRules.AugmentedReturn(USVᴴtrunc, dUSVᴴtrunc, cache)
end
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(svd_trunc)},
::Type{RT},
cache,
A::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}
USVᴴ, dUSVᴴ = cache
MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴ, dUSVᴴ)
return (nothing, nothing)
end=#
Loading
Loading