diff --git a/Project.toml b/Project.toml index 7e82876c5..35c8b21af 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" @@ -29,6 +30,7 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" TensorKitAdaptExt = "Adapt" TensorKitCUDAExt = ["CUDA", "cuTENSOR"] TensorKitChainRulesCoreExt = "ChainRulesCore" +TensorKitEnzymeExt = "Enzyme" TensorKitFiniteDifferencesExt = "FiniteDifferences" TensorKitMooncakeExt = "Mooncake" @@ -41,6 +43,7 @@ CUDA = "5.9" ChainRulesCore = "1" ChainRulesTestUtils = "1" Combinatorics = "1" +Enzyme = "0.13.118" FiniteDifferences = "0.12" GPUArrays = "11.3.1" JET = "0.9, 0.10, 0.11" @@ -73,6 +76,8 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" @@ -86,4 +91,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" [targets] -test = ["ArgParse", "Adapt", "Aqua", "AllocCheck", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake", "JET"] +test = ["ArgParse", "Adapt", "Aqua", "AllocCheck", "Combinatorics", 
"CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake", "Enzyme", "EnzymeTestUtils", "JET"] diff --git a/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl b/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl new file mode 100644 index 000000000..ab3061795 --- /dev/null +++ b/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl @@ -0,0 +1,21 @@ +module TensorKitEnzymeExt + +using Enzyme +using TensorKit +import TensorKit as TK +using VectorInterface +using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize +import TensorOperations as TO +using MatrixAlgebraKit +using TupleTools +using Random: AbstractRNG + +include("utility.jl") +include("linalg.jl") +include("vectorinterface.jl") +include("tensoroperations.jl") +include("factorizations.jl") +include("indexmanipulations.jl") +#include("planaroperations.jl") + +end diff --git a/ext/TensorKitEnzymeExt/factorizations.jl b/ext/TensorKitEnzymeExt/factorizations.jl new file mode 100644 index 000000000..4314f7900 --- /dev/null +++ b/ext/TensorKitEnzymeExt/factorizations.jl @@ -0,0 +1,134 @@ +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(MatrixAlgebraKit.copy_input)}, + ::Type{RT}, + cache, + f::Annotation, + A::Annotation{<:AbstractTensorMap} + ) where {RT} + copy_shadow = cache + if !isa(A, Const) && !isnothing(copy_shadow) + add!(A.dval, copy_shadow) + end + return (nothing, nothing) +end + +for (f, pb) in ( + (:eig_full, :(MatrixAlgebraKit.eig_pullback!)), + (:eigh_full, :(MatrixAlgebraKit.eigh_pullback!)), + (:lq_compact, :(MatrixAlgebraKit.lq_pullback!)), + (:qr_compact, :(MatrixAlgebraKit.qr_pullback!)), + ) + @eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + ret = $f(A.val, alg.val) + 
dret = make_zero(ret) + cache = (ret, dret) + return EnzymeRules.AugmentedReturn(ret, dret, cache) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + ret, dret = cache + $pb(A.dval, A.val, ret, dret) + return (nothing, nothing) + end + end +end + +for f in (:svd_compact, :svd_full) + @eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + USVᴴ = $f(A.val, alg.val) + dUSVᴴ = make_zero(USVᴴ) + cache = (USVᴴ, dUSVᴴ) + return EnzymeRules.AugmentedReturn(USVᴴ, dUSVᴴ, cache) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + USVᴴ, dUSVᴴ = cache + MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴ, dUSVᴴ) + return (nothing, nothing) + end + end + + # mutating version is not guaranteed to actually mutate + # so we can simply use the non-mutating version instead + f! = Symbol(f, :!) 
+ #=@eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f!)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + USVᴴ::Annotation, + alg::Const, + ) where {RT} + EnzymeRules.augmented_primal(func, RT, A, alg) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f!)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + USVᴴ::Annotation, + alg::Const, + ) where {RT} + EnzymeRules.reverse(func, RT, A, alg) + end + end=# #hmmmm +end + +# TODO +#= +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(svd_trunc)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + + USVᴴ = svd_compact(A.val, alg.val.alg) + USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.val.trunc) + ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴ[2]), ind) + dUSVᴴtrunc = make_zero(USVᴴtrunc) + cache = (USVᴴtrunc, dUSVᴴtrunc) + return EnzymeRules.AugmentedReturn(USVᴴtrunc, dUSVᴴtrunc, cache) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(svd_trunc)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + alg::Const, +) where {RT} + USVᴴ, dUSVᴴ = cache + MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴ, dUSVᴴ) + return (nothing, nothing) +end=# diff --git a/ext/TensorKitEnzymeExt/indexmanipulations.jl b/ext/TensorKitEnzymeExt/indexmanipulations.jl new file mode 100644 index 000000000..02f52ef79 --- /dev/null +++ b/ext/TensorKitEnzymeExt/indexmanipulations.jl @@ -0,0 +1,435 @@ +for transform in (:permute, :transpose) + add_transform! = Symbol(:add_, transform, :!) 
for transform in (:permute, :transpose)
    add_transform! = Symbol(:add_, transform, :!)
    # Forward pass for `TK.add_permute!` / `TK.add_transpose!`:
    #   C ← β * C + α * transform(A, p)
    @eval function EnzymeRules.augmented_primal(
            config::EnzymeRules.RevConfigWidth{1},
            func::Const{typeof(TK.$add_transform!)},
            ::Type{RT},
            C::Annotation{<:AbstractTensorMap},
            A::Annotation{<:AbstractTensorMap},
            p::Const{<:Index2Tuple},
            α::Annotation{<:Number},
            β::Annotation{<:Number},
            ba::Const...
        ) where {RT}
        # cache the pre-call C only when Δβ is needed and C may be overwritten
        C_cache = !isa(β, Const) && EnzymeRules.overwritten(config)[2] ? copy(C.val) : nothing
        # if we need to compute Δα, it is faster to allocate an intermediate permuted A
        # and store that instead of repeating the permutation in the pullback each time.
        # effectively, we replace `add_permute` by `add ∘ permute`.
        Ap = if !isa(α, Const)
            Ap = $transform(A.val, p.val)
            # NOTE: `Ap` is a plain tensor, not an `Annotation` — use it directly
            # (was `Ap.val`, which would throw; cf. the `add_braid!` rule below)
            add!(C.val, Ap, α.val, β.val)
            Ap
        else
            TK.$add_transform!(C.val, A.val, p.val, α.val, β.val, ba.val...)
            nothing
        end
        cache = (C_cache, Ap)
        primal = EnzymeRules.needs_primal(config) ? C.val : nothing
        shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing
        return EnzymeRules.AugmentedReturn(primal, shadow, cache)
    end
    @eval function EnzymeRules.reverse(
            config::EnzymeRules.RevConfigWidth{1},
            func::Const{typeof(TK.$add_transform!)},
            ::Type{RT},
            cache,
            C::Annotation{<:AbstractTensorMap},
            A::Annotation{<:AbstractTensorMap},
            p::Const{<:Index2Tuple},
            α::Annotation{<:Number},
            β::Annotation{<:Number},
            ba::Const...
        ) where {RT}
        C_cache, Ap = cache
        Cval = something(C_cache, C.val)

        # ΔA: pull ΔC back through the inverse index permutation
        ip = invperm(linearize(p.val))
        pΔA = _repartition(ip, A.val)
        TC = VectorInterface.promote_scale(C.dval, α.val)
        if scalartype(A.dval) <: Real && !(TC <: Real)
            # complex cotangent onto real primal: accumulate only the real part
            ΔAc = TO.tensoralloc_add(TC, C.dval, pΔA, false, Val(false))
            TK.$add_transform!(ΔAc, C.dval, pΔA, conj(α.val), Zero(), ba.val...)
            add!(A.dval, real(ΔAc))
        else
            TK.$add_transform!(A.dval, C.dval, pΔA, conj(α.val), One(), ba.val...)
        end
        Δαr = isnothing(Ap) ? nothing : project_scalar(α.val, inner(Ap, C.dval))
        Δβr = pullback_dβ(C.dval, Cval, β)
        pullback_dC!(C.dval, β.val) # this typically returns nothing

        # exactly one return entry per argument: C, A, p, α, β, ba...
        # (was one spurious extra `nothing`; cf. the `mul!`/`trace_permute!` rules)
        return nothing, nothing, nothing, Δαr, Δβr, map(Returns(nothing), ba)...
    end
end
# Forward pass for `TK.add_braid!`: C ← β * C + α * braid(A, p, levels)
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(TK.add_braid!)},
        ::Type{RT},
        C::Annotation{<:AbstractTensorMap},
        A::Annotation{<:AbstractTensorMap},
        p::Const{<:Index2Tuple},
        levels::Const{<:Index2Tuple},
        α::Annotation{<:Number},
        β::Annotation{<:Number},
        ba::Const...
    ) where {RT}
    # cache the pre-call C only when Δβ is needed and C may be overwritten
    C_cache = !isa(β, Const) && EnzymeRules.overwritten(config)[2] ? copy(C.val) : nothing
    # if we need to compute Δα, it is faster to allocate an intermediate braided A
    # and store that instead of repeating the braiding in the pullback each time.
    # effectively, we replace `add_braid` by `add ∘ braid`.
    Ap = if !isa(α, Const)
        Ap = braid(A.val, p.val, levels.val)
        add!(C.val, Ap, α.val, β.val)
        Ap
    else
        TK.add_braid!(C.val, A.val, p.val, levels.val, α.val, β.val, ba.val...)
        nothing
    end
    cache = (C_cache, Ap)
    primal = EnzymeRules.needs_primal(config) ? C.val : nothing
    shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing
    return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end
function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(TK.add_braid!)},
        ::Type{RT},
        cache,
        C::Annotation{<:AbstractTensorMap},
        A::Annotation{<:AbstractTensorMap},
        p::Const{<:Index2Tuple},
        levels::Const{<:Index2Tuple},
        α::Annotation{<:Number},
        β::Annotation{<:Number},
        ba::Const...
    ) where {RT}
    C_cache, Ap = cache
    Cval = something(C_cache, C.val)
    # ΔA: braid ΔC back with the inverse permutation and permuted levels
    ip = invperm(linearize(p.val))
    pΔA = _repartition(ip, A.val)
    ilevels = TupleTools.permute(levels.val, linearize(p.val))
    TC = VectorInterface.promote_scale(C.dval, α.val)
    if scalartype(A.dval) <: Real && !(TC <: Real)
        # complex cotangent onto real primal: accumulate only the real part
        ΔAc = TO.tensoralloc_add(TC, C.dval, pΔA, false, Val(false))
        TK.add_braid!(ΔAc, C.dval, pΔA, ilevels, conj(α.val), Zero(), ba.val...)
        add!(A.dval, real(ΔAc))
    else
        TK.add_braid!(A.dval, C.dval, pΔA, ilevels, conj(α.val), One(), ba.val...)
    end
    Δαr = isnothing(Ap) ? nothing : project_scalar(α.val, inner(Ap, C.dval))
    # use the cached, pre-overwrite C for Δβ (was `C.val`, which is wrong once
    # the primal has overwritten C; cf. the permute/transpose rule above)
    Δβr = pullback_dβ(C.dval, Cval, β)
    pullback_dC!(C.dval, β.val) # this typically returns nothing
    # exactly one return entry per argument: C, A, p, levels, α, β, ba...
    # (was one spurious extra `nothing`)
    return nothing, nothing, nothing, nothing, Δαr, Δβr, map(Returns(nothing), ba)...
end
# Forward pass for in-place `twist!(t, inds)`: twists are involutive up to `inv`,
# so no tape is needed — the pullback is just the inverse twist on the shadow.
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(twist!)},
        ::Type{RT},
        t::Annotation{<:AbstractTensorMap},
        inds::Const,
    ) where {RT}
    twist!(t.val, inds.val; inv = false)
    primal = EnzymeRules.needs_primal(config) ? t.val : nothing
    shadow = EnzymeRules.needs_shadow(config) ? t.dval : nothing
    # no cache required (was: returned an undefined variable `cache` → UndefVarError)
    return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end

function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(twist!)},
        ::Type{RT},
        cache,
        t::Annotation{<:AbstractTensorMap},
        inds::Const,
    ) where {RT}
    # pull the cotangent back through the inverse twist, in place
    twist!(t.dval, inds.val; inv = true)
    return (nothing, nothing)
end
# Keyword-call form `twist!(t, inds; inv)`: Julia lowers this to
# `Core.kwcall((; inv), twist!, t, inds)`, so the rule dispatches on `Core.kwcall`.
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(Core.kwcall)},
        ::Type{RT},
        kwargs::Const{@NamedTuple{inv::Bool}},
        ::Const{typeof(twist!)},
        t::Annotation{<:AbstractTensorMap},
        inds::Const,
    ) where {RT}
    inv = kwargs.val.inv
    twist!(t.val, inds.val; inv)
    primal = EnzymeRules.needs_primal(config) ? t.val : nothing
    shadow = EnzymeRules.needs_shadow(config) ? t.dval : nothing
    # no cache required (was: returned an undefined variable `cache` → UndefVarError)
    return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end

function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        # must match the augmented_primal: the called function is `Core.kwcall`
        # (was `Const{typeof(twist!)}`, which never matches the kwcall form)
        func::Const{typeof(Core.kwcall)},
        ::Type{RT},
        cache,
        kwargs::Const{@NamedTuple{inv::Bool}},
        ::Const{typeof(twist!)},
        t::Annotation{<:AbstractTensorMap},
        inds::Const,
    ) where {RT}
    inv = kwargs.val.inv
    # the pullback applies the opposite twist to the shadow
    twist!(t.dval, inds.val; inv = !inv)
    return (nothing, nothing, nothing, nothing)
end
function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(Core.kwcall)},
        ::Type{RT},
        cache,
        kwargs::Const{@NamedTuple{inv::Bool}},
        ::Const{typeof(flip)},
        t::Annotation{<:AbstractTensorMap},
        inds::Const,
    ) where {RT}
    inv = kwargs.val.inv
    dt′ = cache
    # pull the output cotangent back through the opposite flip and accumulate;
    # project onto the real part when the primal is real
    dt′′ = flip(dt′, inds.val; inv = !inv)
    add!(t.dval, scalartype(t.dval) <: Real ? real(dt′′) : dt′′)
    return (nothing, nothing, nothing, nothing)
end

for insertunit in (:insertleftunit, :insertrightunit)
    @eval begin
        # Generic fallback: the output does not share storage with the input,
        # so an explicit zero shadow is allocated and accumulated in reverse.
        function EnzymeRules.augmented_primal(
                config::EnzymeRules.RevConfigWidth{1},
                func::Const{typeof($insertunit)},
                ::Type{RT},
                tsrc::Annotation{<:AbstractTensorMap},
                ival::Const{<:Val},
            ) where {RT}
            tdst = $insertunit(tsrc.val, ival.val)
            Δtdst = make_zero(tdst)
            primal = EnzymeRules.needs_primal(config) ? tdst : nothing
            shadow = EnzymeRules.needs_shadow(config) ? Δtdst : nothing
            cache = (nothing, tdst, Δtdst)
            return EnzymeRules.AugmentedReturn(primal, shadow, cache)
        end
        # tdst shares data with tsrc if <:TensorMap, in this case we have to deal with
        # correctly sharing address spaces: build the shadow from tsrc.dval so it aliases.
        function EnzymeRules.augmented_primal(
                config::EnzymeRules.RevConfigWidth{1},
                func::Const{typeof($insertunit)},
                ::Type{RT},
                tsrc::Annotation{<:TensorMap},
                ival::Const{<:Val},
            ) where {RT}
            tsrc_cache = copy(tsrc.val)
            tdst = $insertunit(tsrc.val, ival.val)
            Δtdst = $insertunit(tsrc.dval, ival.val)
            primal = EnzymeRules.needs_primal(config) ? tdst : nothing
            shadow = EnzymeRules.needs_shadow(config) ? Δtdst : nothing
            cache = (tsrc_cache, tdst, Δtdst)
            # was `return cache = (...)`: must wrap in AugmentedReturn like every other rule
            return EnzymeRules.AugmentedReturn(primal, shadow, cache)
        end
        function EnzymeRules.reverse(
                config::EnzymeRules.RevConfigWidth{1},
                func::Const{typeof($insertunit)},
                ::Type{RT},
                cache,
                tsrc::Annotation{<:AbstractTensorMap},
                ival::Const{<:Val},
            ) where {RT}
            tsrc_cache, tdst, Δtdst = cache
            # note: since data is already shared for <:TensorMap (tsrc_cache set),
            # nothing has to be done in that case!
            if isnothing(tsrc_cache)
                for (c, b) in blocks(Δtdst)
                    add!(block(tsrc.dval, c), b)
                end
            end
            return (nothing, nothing)
        end
        # Keyword-call form: only share storage when the input is a TensorMap and
        # `copy = false` (the default).
        function EnzymeRules.augmented_primal(
                config::EnzymeRules.RevConfigWidth{1},
                func::Const{typeof(Core.kwcall)},
                ::Type{RT},
                kwargs::Const{<:NamedTuple},
                ::Const{typeof($insertunit)},
                tsrc::Annotation{<:AbstractTensorMap},
                ival::Const{<:Val},
            ) where {RT}
            if tsrc.val isa TensorMap && !get(kwargs.val, :copy, false)
                tsrc_cache = copy(tsrc.val)
                tdst = $insertunit(tsrc.val, ival.val; kwargs.val...)
                Δtdst = $insertunit(tsrc.dval, ival.val; kwargs.val...)
            else
                tsrc_cache = nothing
                tdst = $insertunit(tsrc.val, ival.val; kwargs.val...)
                Δtdst = make_zero(tdst)
            end
            primal = EnzymeRules.needs_primal(config) ? tdst : nothing
            shadow = EnzymeRules.needs_shadow(config) ? Δtdst : nothing
            cache = (tsrc_cache, tdst, Δtdst)
            return EnzymeRules.AugmentedReturn(primal, shadow, cache)
        end
        function EnzymeRules.reverse(
                config::EnzymeRules.RevConfigWidth{1},
                func::Const{typeof(Core.kwcall)},
                ::Type{RT},
                cache,
                kwargs::Const{<:NamedTuple},
                ::Const{typeof($insertunit)},
                tsrc::Annotation{<:AbstractTensorMap},
                ival::Const{<:Val},
            ) where {RT}
            tsrc_cache, tdst, Δtdst = cache
            if isnothing(tsrc_cache)
                for (c, b) in blocks(Δtdst)
                    # accumulate into the input shadow
                    # (was `block(Δtsrc, c)` with `Δtsrc` undefined → UndefVarError)
                    add!(block(tsrc.dval, c), b)
                end
            end
            return (nothing, nothing, nothing, nothing)
        end
    end
end

# Generic `removeunit`: output does not share storage with the input.
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(removeunit)},
        ::Type{RT},
        tsrc::Annotation{<:AbstractTensorMap},
        ival::Const{<:Val},
    ) where {RT}
    tsrc_cache = nothing
    tdst = removeunit(tsrc.val, ival.val)
    Δtdst = make_zero(tdst)
    primal = EnzymeRules.needs_primal(config) ? tdst : nothing
    shadow = EnzymeRules.needs_shadow(config) ? Δtdst : nothing
    cache = (tsrc_cache, tdst, Δtdst)
    return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end

# tdst shares data with tsrc if <:TensorMap, in this case we have to deal with
# correctly sharing address spaces: build the shadow from tsrc.dval so it aliases.
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(removeunit)},
        ::Type{RT},
        tsrc::Annotation{<:TensorMap},
        ival::Const{<:Val},
    ) where {RT}
    tsrc_cache = copy(tsrc.val)
    tdst = removeunit(tsrc.val, ival.val)
    Δtdst = removeunit(tsrc.dval, ival.val)
    primal = EnzymeRules.needs_primal(config) ? tdst : nothing
    shadow = EnzymeRules.needs_shadow(config) ? Δtdst : nothing
    cache = (tsrc_cache, tdst, Δtdst)
    return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end
function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(removeunit)},
        ::Type{RT},
        cache,
        tsrc::Annotation{<:AbstractTensorMap},
        ival::Const{<:Val},
    ) where {RT}
    tsrc_cache, tdst, Δtdst = cache
    # note: since data for <:TensorMap is already shared (tsrc_cache set),
    # nothing has to be done in that case!
    if isnothing(tsrc_cache)
        for (c, b) in blocks(Δtdst)
            add!(block(tsrc.dval, c), b)
        end
    end
    return (nothing, nothing)
end

# Keyword-call form `removeunit(t, i; copy)`: shares storage only when the input
# is a TensorMap and `copy = false` (the default).
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        ::Const{typeof(Core.kwcall)},
        ::Type{RT},
        kwargs::Const{<:NamedTuple},
        ::Const{typeof(removeunit)},
        tsrc::Annotation{<:AbstractTensorMap},
        ival::Const{<:Val},
    ) where {RT}
    # compute the primal first: the non-sharing branch builds its shadow from it
    # (was: `make_zero(tdst)` evaluated before `tdst` was assigned → UndefVarError)
    tdst = removeunit(tsrc.val, ival.val; kwargs.val...)
    if tsrc.val isa TensorMap && !get(kwargs.val, :copy, false)
        tsrc_cache = copy(tsrc.val)
        Δtdst = removeunit(tsrc.dval, ival.val; kwargs.val...)
    else
        tsrc_cache = nothing
        Δtdst = make_zero(tdst)
    end
    primal = EnzymeRules.needs_primal(config) ? tdst : nothing
    shadow = EnzymeRules.needs_shadow(config) ? Δtdst : nothing
    cache = (tsrc_cache, tdst, Δtdst)
    return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end
+ if isnothing(tsrc_cache) + for (c, b) in blocks(Δtdst) + add!(block(tsrc.dval, c), b) + end + end + return (nothing, nothing, nothing, nothing) +end diff --git a/ext/TensorKitEnzymeExt/linalg.jl b/ext/TensorKitEnzymeExt/linalg.jl new file mode 100644 index 000000000..c08f5ba03 --- /dev/null +++ b/ext/TensorKitEnzymeExt/linalg.jl @@ -0,0 +1,140 @@ +# Shared +# ------ +pullback_dC!(ΔC, β) = scale!(ΔC, conj(β)) +pullback_dβ(ΔC, C, β) = !isa(β, Const) ? project_scalar(β.val, inner(C, ΔC)) : nothing + +# Can Enzyme do this itself? Apparently not... +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(mul!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + B::Annotation{<:AbstractTensorMap}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + ) where {RT} + cacheC = copy(C.val) + AB = if !isa(α, Const) + AB = A.val * B.val + add!(C.val, AB, α.val, β.val) + AB + else + mul!(C.val, A.val, B.val, α.val, β.val) + nothing + end + primal = EnzymeRules.needs_primal(config) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing + cacheA = EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing + cacheB = EnzymeRules.overwritten(config)[4] ? copy(B.val) : nothing + + cache = (cacheC, cacheA, cacheB, AB) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(mul!)}, + ::Type{<:Const}, + C::Const{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + B::Annotation{<:AbstractTensorMap}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + ) + mul!(C.val, A.val, B.val, α.val, β.val) + primal = EnzymeRules.needs_primal(config) ? 
C.val : nothing + return EnzymeRules.AugmentedReturn(primal, nothing, nothing) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(mul!)}, + ::Type{RT}, + cache, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + B::Annotation{<:AbstractTensorMap}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + ) where {RT} + if RT <: Const + Δα = isa(α, Const) ? nothing : zero(α.val) + Δβ = isa(β, Const) ? nothing : zero(β.val) + return (nothing, nothing, nothing, Δα, Δβ) + end + cacheC, cacheA, cacheB, AB = cache + Cval = cacheC + Aval = something(cacheA, A.val) + Bval = something(cacheB, B.val) + + !isa(A, Const) && project_mul!(A.dval, C.dval, Bval', conj(α.val)) + !isa(B, Const) && project_mul!(B.dval, Aval', C.dval, conj(α.val)) + Δαr = isnothing(AB) ? nothing : project_scalar(α.val, inner(AB, C.dval)) + Δβr = pullback_dβ(C.dval, Cval, β) + pullback_dC!(C.dval, β.val) + + return (nothing, nothing, nothing, Δαr, Δβr) +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(tr)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + ) where {RT} + ret = func.val(A.val) + primal = EnzymeRules.needs_primal(config) ? ret : nothing + cache = EnzymeRules.overwritten(config)[2] ? 
copy(A.val) : nothing + return EnzymeRules.AugmentedReturn(primal, nothing, cache) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(tr)}, + dret::Active, + cache, + A::Annotation{<:AbstractTensorMap}, + ) + Aval = something(cache, A.val) + Δtrace = dret.val + if !isa(A, Const) + for (_, b) in blocks(A.dval) + TensorKit.diagview(b) .+= Δtrace + end + end + return (nothing,) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(tr)}, + ::Type{<:Const}, + cache, + A::Annotation{<:AbstractTensorMap}, + ) + return (nothing,) +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(inv)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + ) where {RT} + ret = func.val(A.val) + primal = EnzymeRules.needs_primal(config) ? ret : nothing + shadow = EnzymeRules.needs_shadow(config) ? zero(ret) : nothing + cache = (copy(ret), shadow) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(inv)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + ) where {RT} + Ainv, ΔAinv = cache + !isa(A, Const) && mul!(A.dval, Ainv' * ΔAinv, Ainv', -1, One()) + return (nothing,) +end diff --git a/ext/TensorKitEnzymeExt/planaroperations.jl b/ext/TensorKitEnzymeExt/planaroperations.jl new file mode 100644 index 000000000..effb4683c --- /dev/null +++ b/ext/TensorKitEnzymeExt/planaroperations.jl @@ -0,0 +1,80 @@ +# planartrace! +# ------------ +# TODO: Fix planartrace pullback +# This implementation is slightly more involved than its non-planar counterpart +# this is because we lack a general `pAB` argument in `planarcontract`, and need +# to keep things planar along the way. +# In particular, we can't simply tensor product with multiple identities in one go +# if they aren't "contiguous", e.g. 
p = ((1, 4, 5), ()), q = ((2, 6), (3, 7)) + +# function Mooncake.rrule!!( +# ::CoDual{typeof(TensorKit.planartrace!)}, +# C_ΔC::CoDual{<:AbstractTensorMap}, +# A_ΔA::CoDual{<:AbstractTensorMap}, p_Δp::CoDual{<:Index2Tuple}, q_Δq::CoDual{<:Index2Tuple}, +# α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, +# backend_Δbackend::CoDual, allocator_Δallocator::CoDual +# ) +# # prepare arguments +# C, ΔC = arrayify(C_ΔC) +# A, ΔA = arrayify(A_ΔA) +# p = primal(p_Δp) +# q = primal(q_Δq) +# α, β = primal.((α_Δα, β_Δβ)) +# backend, allocator = primal.((backend_Δbackend, allocator_Δallocator)) +# +# # primal call +# C_cache = copy(C) +# TensorKit.planartrace!(C, A, p, q, α, β, backend, allocator) +# +# function planartrace_pullback(::NoRData) +# copy!(C, C_cache) +# +# ΔAr = planartrace_pullback_ΔA!(ΔA, ΔC, A, p, q, α, backend, allocator) # this typically returns NoRData() +# Δαr = planartrace_pullback_Δα(ΔC, A, p, q, α, backend, allocator) +# Δβr = pullback_dβ(ΔC, C, β) +# ΔCr = pullback_dC!(ΔC, β) # this typically returns NoRData() +# +# return NoRData(), +# ΔCr, ΔAr, NoRData(), NoRData(), +# Δαr, Δβr, NoRData(), NoRData() +# end +# +# return C_ΔC, planartrace_pullback +# end + +# function planartrace_pullback_dA!( +# ΔA, ΔC, A, p, q, α, backend, allocator +# ) +# if length(q[1]) == 0 +# ip = invperm(linearize(p)) +# pΔA = _repartition(ip, A) +# TK.add_transpose!(ΔA, ΔC, pΔA, conj(α), One(), backend, allocator) +# return NoRData() +# end +# # if length(q[1]) == 1 +# # ip = invperm((p[1]..., q[2]..., p[2]..., q[1]...)) +# # pdA = _repartition(ip, A) +# # E = one!(TO.tensoralloc_add(scalartype(A), A, q, false)) +# # twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) +# # # pE = ((), trivtuple(TO.numind(q))) +# # # pΔC = (trivtuple(TO.numind(p)), ()) +# # TensorKit.planaradd!(ΔA, ΔC ⊗ E, pdA, conj(α), One(), backend, allocator) +# # return NoRData() +# # end +# error("The reverse rule for `planartrace` is not yet implemented") +# end +# +# function 
planartrace_pullback_dα( +# ΔC, A, p, q, α, backend, allocator +# ) +# Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) +# Tdα === NoRData && return NoRData() +# +# # TODO: this result might be easier to compute as: +# # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α +# At = TO.tensoralloc_add(scalartype(A), A, p, false, Val(true), allocator) +# TensorKit.planartrace!(At, A, p, q, One(), Zero(), backend, allocator) +# Δα = project_scalar(α, inner(At, ΔC)) +# TO.tensorfree!(At, allocator) +# return Δα +# end diff --git a/ext/TensorKitEnzymeExt/tensoroperations.jl b/ext/TensorKitEnzymeExt/tensoroperations.jl new file mode 100644 index 000000000..661fc661e --- /dev/null +++ b/ext/TensorKitEnzymeExt/tensoroperations.jl @@ -0,0 +1,213 @@ +# tensorcontract! +# --------------- +# TODO: it might be beneficial to compare here if it would make sense to simply compute the +# rrule of permute-permute-gemm-permute, rather than using the contractions directly. +# This could possibly out save some permutations being carried out twice, at the cost of having +# to store some more intermediate objects. +# For example, the combination `ΔC, pΔC, false` appears in the pullback for ΔA and ΔB, so effectively +# this permutation is done multiple times. + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TensorKit.blas_contract!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + pA::Const{<:Index2Tuple}, + B::Annotation{<:AbstractTensorMap}, + pB::Const{<:Index2Tuple}, + pAB::Const{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + backend::Const, + allocator::Const + ) where {RT} + Ccache = isa(β, Const) ? nothing : copy(C.val) + A_needs_cache = EnzymeRules.overwritten(config)[3] && !(typeof(B) <: Const) && !(typeof(C) <: Const) + Acache = A_needs_cache ? 
copy(A.val) : nothing + B_needs_cache = EnzymeRules.overwritten(config)[5] && !(typeof(A) <: Const) && !(typeof(C) <: Const) + Bcache = B_needs_cache ? copy(B.val) : nothing + AB = if !isa(α, Const) + AB = TO.tensorcontract(A.val, pA.val, false, B.val, pB.val, false, pAB.val, One(), backend.val, allocator.val) + add!(C.val, AB, α.val, β.val) + AB + else + TensorKit.blas_contract!(C.val, A.val, pA.val, B.val, pB.val, pAB.val, α.val, β.val, backend.val, allocator.val) + nothing + end + primal = EnzymeRules.needs_primal(config) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing + cache = (Ccache, Acache, Bcache, AB) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TensorKit.blas_contract!)}, + ::Type{RT}, + cache, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + pA::Const{<:Index2Tuple}, + B::Annotation{<:AbstractTensorMap}, + pB::Const{<:Index2Tuple}, + pAB::Const{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + backend::Const, + allocator::Const + ) where {RT} + cacheC, cacheA, cacheB, AB = cache + Cval = cacheC + Aval = something(cacheA, A.val) + Bval = something(cacheB, B.val) + + Δα = isnothing(AB) ? nothing : project_scalar(α.val, inner(AB, C.dval)) + Δβ = isa(β, Const) ? 
nothing : pullback_dβ(C.dval, Cval, β) + + if !isa(A, Const) + blas_contract_pullback_ΔA!( + A.dval, C.dval, Aval, pA.val, Bval, pB.val, pAB.val, α.val, backend.val, allocator.val + ) # this typically returns nothing + end + if !isa(B, Const) + blas_contract_pullback_ΔB!( + B.dval, C.dval, Aval, pA.val, Bval, pB.val, pAB.val, α.val, backend.val, allocator.val + ) # this typically returns nothing + end + pullback_dC!(C.dval, β.val) # this typically returns nothing + return nothing, nothing, nothing, nothing, nothing, nothing, Δα, Δβ, nothing, nothing +end + +function blas_contract_pullback_ΔA!( + ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator + ) + ipAB = invperm(linearize(pAB)) + pΔC = _repartition(ipAB, TO.numout(pA)) + ipA = _repartition(invperm(linearize(pA)), A) + + tB = twist( + B, + TupleTools.vcat( + filter(x -> !isdual(space(B, x)), pB[1]), + filter(x -> isdual(space(B, x)), pB[2]) + ); copy = false + ) + + project_contract!( + ΔA, + ΔC, pΔC, false, + tB, reverse(pB), true, + ipA, conj(α), backend, allocator + ) + + return nothing +end + +function blas_contract_pullback_ΔB!( + ΔB, ΔC, A, pA, B, pB, pAB, α, backend, allocator + ) + ipAB = invperm(linearize(pAB)) + pΔC = _repartition(ipAB, TO.numout(pA)) + ipB = _repartition(invperm(linearize(pB)), B) + + tA = twist( + A, + TupleTools.vcat( + filter(x -> isdual(space(A, x)), pA[1]), + filter(x -> !isdual(space(A, x)), pA[2]) + ); copy = false + ) + + project_contract!( + ΔB, + tA, reverse(pA), true, + ΔC, pΔC, false, + ipB, conj(α), backend, allocator + ) + + return nothing +end + + +# tensortrace! +# ------------ + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TensorKit.trace_permute!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + p::Const{<:Index2Tuple}, + q::Const{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + backend::Const, + ) where {RT} + C_cache = !isa(β, Const) ? 
copy(C.val) : nothing + A_cache = EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing + At = if !isa(α, Const) + At = TO.tensortrace(A.val, p.val, q.val, false, One(), backend.val) + add!(C.val, At, α.val, β.val) + At + else + TensorKit.trace_permute!(C.val, A.val, p.val, q.val, α.val, β.val, backend.val) + nothing + end + primal = EnzymeRules.needs_primal(config) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing + cache = (C_cache, A_cache, At) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TensorKit.trace_permute!)}, + ::Type{RT}, + cache, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + p::Const{<:Index2Tuple}, + q::Const{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + backend::Const, + ) where {RT} + C_cache, A_cache, At = cache + Aval = something(A_cache, A.val) + Cval = something(C_cache, C.val) + !isa(A, Const) && !isa(C, Const) && trace_permute_pullback_ΔA!(A.dval, C.dval, Aval, p.val, q.val, α.val, backend.val) + Δαr = if !isa(C, Const) && !isnothing(At) + project_scalar(α.val, inner(At, C.dval)) + elseif !isnothing(At) + zero(α.val) + else + nothing + end + Δβr = if !isa(β, Const) && !isa(C, Const) + pullback_dβ(C.dval, Cval, β) + elseif !isa(β, Const) + zero(β.val) + else + nothing + end + !isa(C, Const) && pullback_dC!(C.dval, β.val) + return nothing, nothing, nothing, nothing, Δαr, Δβr, nothing +end + +function trace_permute_pullback_ΔA!( + ΔA, ΔC, A, p, q, α, backend + ) + ip = invperm((linearize(p)..., q[1]..., q[2]...)) + pdA = _repartition(ip, A) + E = one!(TO.tensoralloc_add(scalartype(A), A, q, false)) + twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) + pE = ((), trivtuple(TO.numind(q))) + pΔC = (trivtuple(TO.numind(p)), ()) + TO.tensorproduct!( + ΔA, ΔC, pΔC, false, E, pE, false, pdA, conj(α), One(), backend + ) + return nothing 
+end diff --git a/ext/TensorKitEnzymeExt/utility.jl b/ext/TensorKitEnzymeExt/utility.jl new file mode 100644 index 000000000..803432d2b --- /dev/null +++ b/ext/TensorKitEnzymeExt/utility.jl @@ -0,0 +1,67 @@ +#_needs_tangent(x) = _needs_tangent(typeof(x)) +#_needs_tangent(::Type{T}) where {T <: Number} = +# Mooncake.rdata_type(Mooncake.tangent_type(T)) !== NoRData + +# Projection +# ---------- +""" + project_scalar(x::Number, dx::Number) + +Project a computed tangent `dx` onto the correct tangent type for `x`. +For example, we might compute a complex `dx` but only require the real part. +""" +project_scalar(x::Number, dx::Number) = oftype(x, dx) +project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx)) + +# in-place multiplication and accumulation which might project to (real) +# TODO: this could probably be done without allocating +function project_mul!(C, A, B, α) + TC = TO.promote_contract(scalartype(A), scalartype(B), scalartype(α)) + return if !(TC <: Real) && scalartype(C) <: Real + add!(C, real(mul!(zerovector(C, TC), A, B, α))) + else + mul!(C, A, B, α, One()) + end +end +function project_contract!(C, A, pA, conjA, B, pB, conjB, pAB, α, backend, allocator) + TA = TensorKit.promote_permute(A) + TB = TensorKit.promote_permute(B) + TC = TO.promote_contract(TA, TB, scalartype(α)) + + return if scalartype(C) <: Real && !(TC <: Real) + add!(C, real(TO.tensorcontract!(zerovector(C, TC), A, pA, conjA, B, pB, conjB, pAB, α, Zero(), backend, allocator))) + else + TO.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, One(), backend, allocator) + end +end + +# IndexTuple utility +# ------------------ +trivtuple(N) = ntuple(identity, N) + +Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int) + length(p) >= N₁ || + throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) + return TupleTools.getindices(p, trivtuple(N₁)), + TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁) +end +Base.@constprop :aggressive 
function _repartition(p::Index2Tuple, N₁::Int) + return _repartition(linearize(p), N₁) +end +function _repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} + return _repartition(p, N₁) +end +function _repartition(p::Union{IndexTuple, Index2Tuple}, t::AbstractTensorMap) + return _repartition(p, TensorKit.numout(t)) +end + +# Ignore derivatives +# ------------------ + +EnzymeRules.inactive(::typeof(TensorKit.fusionblockstructure), arg::Any) = nothing +EnzymeRules.inactive(::typeof(TensorKit.select), s::HomSpace, i::Index2Tuple) = nothing +EnzymeRules.inactive(::typeof(TensorKit.flip), s::HomSpace, i::Any) = nothing +EnzymeRules.inactive(::typeof(TensorKit.permute), s::HomSpace, i::Index2Tuple) = nothing +EnzymeRules.inactive(::typeof(TensorKit.braid), s::HomSpace, i::Index2Tuple, ::IndexTuple) = nothing +EnzymeRules.inactive(::typeof(TensorKit.compose), s1::HomSpace, s2::HomSpace) = nothing +EnzymeRules.inactive(::typeof(TensorOperations.tensorcontract), c::HomSpace, p::Index2Tuple, α::Bool, b::HomSpace, q::Index2Tuple, β::Bool, pq::Index2Tuple) = nothing diff --git a/ext/TensorKitEnzymeExt/vectorinterface.jl b/ext/TensorKitEnzymeExt/vectorinterface.jl new file mode 100644 index 000000000..c300e1209 --- /dev/null +++ b/ext/TensorKitEnzymeExt/vectorinterface.jl @@ -0,0 +1,165 @@ +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(scale!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + α::Annotation{<:Number}, + ) where {RT} + C_cache = !isa(α, Const) ? copy(C.val) : nothing + scale!(C.val, α.val) + primal = EnzymeRules.needs_primal(config) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? 
C.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, C_cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(scale!)}, + ::Type{RT}, + cache, + C::Annotation{<:AbstractTensorMap}, + α::Annotation{<:Number}, + ) where {RT} + Cval = something(cache, C.val) + Δα = if !isa(α, Const) && !isa(C, Const) + project_scalar(α.val, inner(Cval, C.dval)) + elseif !isa(α, Const) + zero(α.val) + else + nothing + end + !isa(C, Const) && scale!(C.dval, conj(α.val)) + return (nothing, Δα) +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(scale!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + α::Annotation{<:Number}, + ) where {RT} + A_cache = !isa(α, Const) && EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing + scale!(C.val, A.val, α.val) + primal = EnzymeRules.needs_primal(config) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? 
C.dval : nothing + cache = A_cache + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(scale!)}, + ::Type{RT}, + cache, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + α::Annotation{<:Number}, + ) where {RT} + Aval = something(cache, A.val) + Δα = if !isa(α, Const) && !isa(C, Const) + project_scalar(α.val, inner(Aval, C.dval)) + elseif !isa(α, Const) + zero(α.val) + else + nothing + end + !isa(A, Const) && !isa(C, Const) && add!(A.dval, C.dval, conj(α.val)) + !isa(C, Const) && zerovector!(C.dval) + return (nothing, nothing, Δα) +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(add!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + ) where {RT} + A_cache = !isa(α, Const) && EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing + C_cache = !isa(β, Const) ? copy(C.val) : nothing + add!(C.val, A.val, α.val, β.val) + primal = EnzymeRules.needs_primal(config) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? 
C.dval : nothing + cache = (A_cache, C_cache) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(add!)}, + ::Type{RT}, + cache, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + ) where {RT} + A_cache, C_cache = cache + Aval = something(A_cache, A.val) + Cval = something(C_cache, C.val) + Δα = if !isa(α, Const) && !isa(C, Const) + project_scalar(α.val, inner(Aval, C.dval)) + elseif !isa(α, Const) + zero(α.val) + else + nothing + end + Δβ = if !isa(β, Const) && !isa(C, Const) + project_scalar(β.val, inner(Cval, C.dval)) + elseif !isa(β, Const) + zero(β.val) + else + nothing + end + !isa(A, Const) && !isa(C, Const) && add!(A.dval, C.dval, conj(α.val)) + !isa(C, Const) && scale!(C.dval, conj(β.val)) + return (nothing, nothing, Δα, Δβ) +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(inner)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + B::Annotation{<:AbstractTensorMap}, + ) where {RT} + A_cache = !isa(B, Const) && EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing + B_cache = !isa(A, Const) && EnzymeRules.overwritten(config)[3] ? copy(B.val) : nothing + ret = inner(A.val, B.val) + primal = EnzymeRules.needs_primal(config) ? ret : nothing + shadow = EnzymeRules.needs_shadow(config) ? 
zero(ret) : nothing + cache = (A_cache, B_cache) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(inner)}, + dret::Active, + cache, + A::Annotation{<:AbstractTensorMap}, + B::Annotation{<:AbstractTensorMap}, + ) + A_cache, B_cache = cache + Aval = something(A_cache, A.val) + Bval = something(B_cache, B.val) + Δs = dret.val + !isa(A, Const) && add!(A.dval, Bval, conj(Δs)) + !isa(B, Const) && add!(B.dval, Aval, Δs) + return (nothing, nothing) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(inner)}, + ::Type{<:Const}, + cache, + A::Annotation{<:AbstractTensorMap}, + B::Annotation{<:AbstractTensorMap}, + ) + return (nothing, nothing) +end diff --git a/src/factorizations/matrixalgebrakit.jl b/src/factorizations/matrixalgebrakit.jl index b44f20653..d742860da 100644 --- a/src/factorizations/matrixalgebrakit.jl +++ b/src/factorizations/matrixalgebrakit.jl @@ -43,6 +43,20 @@ for f! in ( end end +for f! in (:qr_compact!, :qr_full!, :lq_compact!, :lq_full!) + @eval function MAK.$f!(t::AbstractTensorMap, F, alg::MAK.Algorithm{:Householder}) + foreachblock(t, F...) do _, (tblock, Fblocks...) + Fblocks′ = $f!(tblock, Fblocks, alg) + # deal with the case where the output is not in-place + for (b′, b) in zip(Fblocks′, Fblocks) + b === b′ || copy!(b, b′) + end + return nothing + end + return F + end +end + # Handle these separately because single output instead of tuple for f! 
in ( :qr_null!, :lq_null!, diff --git a/test/enzyme/factorizations.jl b/test/enzyme/factorizations.jl new file mode 100644 index 000000000..a20eb8bb0 --- /dev/null +++ b/test/enzyme/factorizations.jl @@ -0,0 +1,200 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using MatrixAlgebraKit +using Enzyme, EnzymeTestUtils +using Random +Enzyme.Compiler.VERBOSE_ERRORS[] = true +@isdefined(TestSetup) || include("../setup.jl") +using .TestSetup + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), # TODO problem in svdgauge_dep + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +function remove_qrgauge_dependence!(ΔQ, t, Q) + for (c, b) in blocks(ΔQ) + m, n = size(block(t, c)) + minmn = min(m, n) + Qc = block(Q, c) + Q1 = view(Qc, 1:m, 1:minmn) + ΔQ2 = view(b, :, (minmn + 1):m) + mul!(ΔQ2, Q1, Q1' * ΔQ2) + end + return ΔQ +end +function remove_lqgauge_dependence!(ΔQ, t, Q) + for (c, b) in blocks(ΔQ) + m, n = size(block(t, c)) + minmn = min(m, n) + Qc = block(Q, c) + Q1 = view(Qc, 1:minmn, 1:n) + ΔQ2 = view(b, (minmn + 1):n, :) + mul!(ΔQ2, ΔQ2 * Q1', Q1) + end + return ΔQ +end +function remove_eiggauge_dependence!( + ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(D) + ) + gaugepart = V' * ΔV + for (c, b) in blocks(gaugepart) + Dc = diagview(block(D, c)) + # for some reason this fails 
only on tests, and I cannot reproduce it in an + # interactive session. + # b[abs.(transpose(diagview(Dc)) .- diagview(Dc)) .>= degeneracy_atol] .= 0 + for j in axes(b, 2), i in axes(b, 1) + abs(Dc[i] - Dc[j]) >= degeneracy_atol && (b[i, j] = 0) + end + end + mul!(ΔV, V / (V' * V), gaugepart, -1, 1) + return ΔV +end +function remove_eighgauge_dependence!( + ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(D) + ) + gaugepart = project_antihermitian!(V' * ΔV) + for (c, b) in blocks(gaugepart) + Dc = diagview(block(D, c)) + # for some reason this fails only on tests, and I cannot reproduce it in an + # interactive session. + # b[abs.(transpose(diagview(Dc)) .- diagview(Dc)) .>= degeneracy_atol] .= 0 + for j in axes(b, 2), i in axes(b, 1) + abs(Dc[i] - Dc[j]) >= degeneracy_atol && (b[i, j] = 0) + end + end + mul!(ΔV, V, gaugepart, -1, 1) + return ΔV +end +function remove_svdgauge_dependence!( + ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(S) + ) + UdU = U' * ΔU + VdV = Vᴴ * ΔVᴴ' + gaugepart = project_antihermitian!(UdU + VdV) + for (c, b) in blocks(gaugepart) + Sd = diagview(block(S, c)) + # for some reason this fails only on tests, and I cannot reproduce it in an + # interactive session. 
+ # b[abs.(transpose(diagview(Sc)) .- diagview(Sc)) .>= degeneracy_atol] .= 0 + for j in axes(b, 2), i in axes(b, 1) + abs(Sd[i] - Sd[j]) >= degeneracy_atol && (b[i, j] = 0) + end + end + mul!(ΔU, U, gaugepart, -1, 1) + return ΔU, ΔVᴴ +end + +@timedtestset "Enzyme - Factorizations: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + + @timedtestset "QR" begin + A = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) + + EnzymeTestUtils.test_reverse(qr_compact, Duplicated, (A, Duplicated); atol, rtol) + + # qr_full/qr_null requires being careful with gauges + QR = qr_full(A) + ΔQR = randn!.(similar.(QR)) + remove_qrgauge_dependence!(ΔQR[1], A, QR[1]) + EnzymeTestUtils.test_reverse(qr_full, Duplicated, (A, Duplicated); output_tangent = ΔQR, atol, rtol) + #EnzymeTestUtils.test_reverse(qr_null, Duplicated, (A, Duplicated); atol, rtol) + + A = randn(T, V[1] ⊗ V[2] ← V[1]) + + EnzymeTestUtils.test_reverse(qr_compact, Duplicated, (A, Duplicated); atol, rtol) + + # qr_full/qr_null requires being careful with gauges + QR = qr_full(A) + ΔQR = randn!.(similar.(QR)) + remove_qrgauge_dependence!(ΔQR[1], A, QR[1]) + EnzymeTestUtils.test_reverse(qr_full, Duplicated, (A, Duplicated); output_tangent = ΔQR, atol, rtol) + #EnzymeTestUtils.test_reverse(qr_null, Duplicated, (A, Duplicated); atol, rtol) + end + + @timedtestset "LQ" begin + A = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) + + EnzymeTestUtils.test_reverse(lq_compact, Duplicated, (A, Duplicated); atol, rtol) + + # lq_full/lq_null requires being careful with gauges + LQ = lq_full(A) + ΔLQ = randn!.(similar.(LQ)) + remove_lqgauge_dependence!(ΔLQ[2], A, LQ[2]) + EnzymeTestUtils.test_reverse(lq_full, Duplicated, (A, Duplicated); output_tangent = ΔLQ, atol, rtol) + #EnzymeTestUtils.test_reverse(lq_null, Duplicated, (A, Duplicated); atol, rtol) + + A = randn(T, V[1] ⊗ V[2] ← V[1]) + + EnzymeTestUtils.test_reverse(lq_compact, Duplicated, (A, Duplicated); atol, rtol) + + # 
lq_full/lq_null requires being careful with gauges + LQ = lq_full(A) + ΔLQ = randn!.(similar.(LQ)) + remove_lqgauge_dependence!(ΔLQ[2], A, LQ[2]) + EnzymeTestUtils.test_reverse(lq_full, Duplicated, (A, Duplicated); output_tangent = ΔLQ, atol, rtol) + #EnzymeTestUtils.test_reverse(lq_null, Duplicated, (A, Duplicated); atol, rtol) + end + + @timedtestset "Eigenvalue decomposition" begin + for t in (randn(T, V[1] ← V[1]), rand(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])) + DV = eig_full(t) + ΔDV = (DiagonalTensorMap(randn!(similar(DV[1].data)), space(DV[1])), randn!(similar(DV[2]))) + remove_eiggauge_dependence!(ΔDV[2], DV...) + EnzymeTestUtils.test_reverse(eig_full, Duplicated, (t, Duplicated); output_tangent = ΔDV, atol, rtol) + + th = project_hermitian(t) + DV = eigh_full(th) + ΔDV = (DiagonalTensorMap(randn!(similar(DV[1].data)), space(DV[1])), randn!(similar(DV[2]))) + remove_eighgauge_dependence!(ΔDV[2], DV...) + EnzymeTestUtils.test_reverse(eigh_full ∘ project_hermitian, Duplicated, (t, Duplicated); output_tangent = ΔDV, atol, rtol) + end + end + + @timedtestset "Singular value decomposition" begin + for t in (randn(T, V[1] ← V[1]), randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4])) + USVᴴ = svd_compact(t) + ΔUSVᴴ = (TensorMap(randn!(similar(USVᴴ[1].data)), space(USVᴴ[1])), DiagonalTensorMap(randn!(similar(USVᴴ[2].data)), space(USVᴴ[2], 1)), TensorMap(randn!(similar(USVᴴ[3].data)), space(USVᴴ[3]))) + remove_svdgauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...) + EnzymeTestUtils.test_reverse(svd_compact, Duplicated, (t, Duplicated); output_tangent = ΔUSVᴴ, atol, rtol) + + USVᴴ = svd_full(t) + ΔUSVᴴ = (TensorMap(randn!(similar(USVᴴ[1].data)), space(USVᴴ[1])), TensorMap(randn!(similar(USVᴴ[2].data)), space(USVᴴ[2])), TensorMap(randn!(similar(USVᴴ[3].data)), space(USVᴴ[3]))) + remove_svdgauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...) + EnzymeTestUtils.test_reverse(svd_full, Duplicated, (t, Duplicated); output_tangent = ΔUSVᴴ, atol, rtol) + + #=V_trunc = spacetype(t)(c => min(size(b)...) 
÷ 2 for (c, b) in blocks(t)) + trunc = truncspace(V_trunc) + alg = MatrixAlgebraKit.select_algorithm(svd_trunc, t, nothing; trunc) + USVᴴtrunc = svd_trunc(t, alg) + ΔUSVᴴtrunc = (Enzyme.randn_tangent(Base.front(USVᴴtrunc))..., zero(last(USVᴴtrunc))) + remove_svdgauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], Base.front(USVᴴtrunc)...) + EnzymeTestUtils.test_reverse(svd_trunc, Duplicated, (t, Duplicated), (alg, Const); output_tangent = ΔUSVᴴtrunc, atol, rtol)=# + # TODO + end + end +end diff --git a/test/enzyme/indexmanipulations.jl b/test/enzyme/indexmanipulations.jl new file mode 100644 index 000000000..75160da16 --- /dev/null +++ b/test/enzyme/indexmanipulations.jl @@ -0,0 +1,153 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using Enzyme, EnzymeTestUtils +using Random + +@isdefined(TestSetup) || include("../setup.jl") +using .TestSetup + +mode = Enzyme.ReverseMode +rng = Random.default_rng() + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - Index Manipulations: $(TensorKit.type_repr(sectortype(eltype(V)))) 
($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding + + symmetricbraiding && @timedtestset "add_permute!" begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + # repeat a couple times to get some distribution of arrows + for _ in 1:5 + for Tα in (Active, Const), Tβ in (Active, Const) + p = randindextuple(numind(A)) + C = randn!(permute(A, p)) + EnzymeTestUtils.test_reverse(TensorKit.add_permute!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (α, Tα), (β, Tβ); atol, rtol) + A = C + end + end + end + + @timedtestset "add_transpose!" begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + # repeat a couple times to get some distribution of arrows + for _ in 1:2 + p = randcircshift(numout(A), numin(A)) + C = randn!(transpose(A, p)) + EnzymeTestUtils.test_reverse(TensorKit.add_transpose!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (One(), Const), (Zero(), Const); atol, rtol) + for Tα in (Const, Active), Tβ in (Const, Active) + EnzymeTestUtils.test_reverse(TensorKit.add_transpose!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (α, Tα), (β, Tβ); atol, rtol) + if !(T <: Real) + EnzymeTestUtils.test_reverse(TensorKit.add_transpose!, Duplicated, (C, Duplicated), (real(A), Duplicated), (p, Const), (α, Tα), (β, Tβ); atol, rtol) + EnzymeTestUtils.test_reverse(TensorKit.add_transpose!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (real(α), Tα), (β, Tβ); atol, rtol) + EnzymeTestUtils.test_reverse(TensorKit.add_transpose!, Duplicated, (C, Duplicated), (real(A), Duplicated), (p, Const), (real(α), Tα), (β, Tβ); atol, rtol) + end + A = C + end + end + end + + @timedtestset "add_braid!" 
begin
+        A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5])
+        α = randn(T)
+        β = randn(T)
+
+        # repeat a couple times to get some distribution of arrows
+        for _ in 1:2
+            p = randcircshift(numout(A), numin(A))
+            levels = Tuple(randperm(numind(A)))
+            C = randn!(transpose(A, p))
+            for Tα in (Active, Const), Tβ in (Active, Const)
+                EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (levels, Const), (α, Tα), (β, Tβ); atol, rtol)
+                if !(T <: Real)
+                    EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (real(A), Duplicated), (p, Const), (levels, Const), (α, Tα), (β, Tβ); atol, rtol)
+                    EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (levels, Const), (real(α), Tα), (β, Tβ); atol, rtol)
+                    EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (levels, Const), (real(α), Tα), (real(β), Tβ); atol, rtol)
+                end
+                A = C
+            end
+        end
+    end
+
+    @timedtestset "flip_n_twist!"
begin
+        A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5])
+        for TA in (Const, Duplicated)
+            if !(T <: Real && !(sectorscalartype(sectortype(A)) <: Real))
+                EnzymeTestUtils.test_reverse(twist!, TA, (A, TA), (1, Const); atol, rtol, fkwargs = (inv = false,))
+                EnzymeTestUtils.test_reverse(twist!, TA, (A, TA), ([1, 3], Const); atol, rtol, fkwargs = (inv = true,))
+                EnzymeTestUtils.test_reverse(twist!, TA, (A, TA), (1, Const); atol, rtol)
+                EnzymeTestUtils.test_reverse(twist!, TA, (A, TA), ([1, 3], Const); atol, rtol)
+            end
+
+            EnzymeTestUtils.test_reverse(flip, TA, (A, TA), (1, Const); atol, rtol, fkwargs = (inv = false,))
+            EnzymeTestUtils.test_reverse(flip, TA, (A, TA), ([1, 3], Const); atol, rtol, fkwargs = (inv = true,))
+            EnzymeTestUtils.test_reverse(flip, TA, (A, TA), (1, Const); atol, rtol)
+            EnzymeTestUtils.test_reverse(flip, TA, (A, TA), ([1, 3], Const); atol, rtol)
+        end
+    end
+
+    @timedtestset "insert and remove units" begin
+        A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5])
+
+        for TA in (Const, Duplicated)
+            for insertunit in (insertleftunit, insertrightunit)
+                EnzymeTestUtils.test_reverse(insertunit, TA, (A, TA), (Val(1), Const); atol, rtol)
+                EnzymeTestUtils.test_reverse(insertunit, TA, (A, TA), (Val(4), Const); atol, rtol)
+                EnzymeTestUtils.test_reverse(insertunit, TA, (A', TA), (Val(2), Const); atol, rtol)
+                EnzymeTestUtils.test_reverse(insertunit, TA, (A, TA), (Val(1), Const); atol, rtol, fkwargs = (copy = false,))
+                EnzymeTestUtils.test_reverse(insertunit, TA, (A, TA), (Val(2), Const); atol, rtol, fkwargs = (copy = true,))
+                EnzymeTestUtils.test_reverse(insertunit, TA, (A, TA), (Val(3), Const); atol, rtol, fkwargs = (copy = false, dual = true, conj = true))
+                EnzymeTestUtils.test_reverse(insertunit, TA, (A', TA), (Val(3), Const); atol, rtol, fkwargs = (copy = false, dual = true, conj = true))
+            end
+        end
+
+        for TB in (Const, Duplicated)
+            for i in 1:2
+                B = insertleftunit(A, i; dual = rand(Bool))
+                EnzymeTestUtils.test_reverse(removeunit, TB, (B, TB), (Val(i), Const); atol,
rtol) + EnzymeTestUtils.test_reverse(removeunit, TB, (B, TB), (Val(i), Const); atol, rtol, fkwargs = (copy = false,)) + EnzymeTestUtils.test_reverse(removeunit, TB, (B, TB), (Val(i), Const); atol, rtol, fkwargs = (copy = true,)) + end + end + end +end diff --git a/test/enzyme/linalg.jl b/test/enzyme/linalg.jl new file mode 100644 index 000000000..7781653fd --- /dev/null +++ b/test/enzyme/linalg.jl @@ -0,0 +1,75 @@ +using Test, TestExtras +using TensorKit +using Enzyme, EnzymeTestUtils +using Random + +@isdefined(TestSetup) || include("../setup.jl") +using .TestSetup + +mode = Enzyme.ReverseMode +rng = Random.default_rng() + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - LinearAlgebra: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + + C = randn(T, V[1] ⊗ V[2] ← V[5]) + A = randn(T, codomain(C) ← V[3] ⊗ V[4]) + B = randn(T, domain(A) ← domain(C)) + α = randn(T) + β = randn(T) + + for TC in (Const, Duplicated), TA in (Const, Duplicated), TB in (Const, Duplicated) + for Tα in (Active, Const), Tβ in (Active, Const) + EnzymeTestUtils.test_reverse(mul!, TC, (C, TC), (A, TA), (B, TB), (α, Tα), (β, Tβ); atol, rtol) + end + EnzymeTestUtils.test_reverse(mul!, TC, (C, TC), (A, TA), (B, TB); atol, rtol) + 
end + + for RT in (Const, Active), TC in (Const, Duplicated) + EnzymeTestUtils.test_reverse(norm, RT, (C, TC), (2, Const); atol, rtol) + EnzymeTestUtils.test_reverse(norm, RT, (C', TC), (2, Const); atol, rtol) + end + + D1 = randn(T, V[1] ← V[1]) + D2 = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) + D3 = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[1] ⊗ V[2] ⊗ V[3]) + + for RT in (Const, Active), TD in (Const, Duplicated) + EnzymeTestUtils.test_reverse(tr, RT, (D1, TD); atol, rtol) + EnzymeTestUtils.test_reverse(tr, RT, (D2, TD); atol, rtol) + #EnzymeTestUtils.test_reverse(tr, RT, (D3, TD); atol, rtol) + end + + for TD in (Const, Duplicated) + EnzymeTestUtils.test_reverse(inv, TD, (D1, TD); atol, rtol) + EnzymeTestUtils.test_reverse(inv, TD, (D2, TD); atol, rtol) + #EnzymeTestUtils.test_reverse(inv, TD, (D3, TD); atol, rtol) # TODO + end +end diff --git a/test/enzyme/planaroperations.jl b/test/enzyme/planaroperations.jl new file mode 100644 index 000000000..fa67ea7e5 --- /dev/null +++ b/test/enzyme/planaroperations.jl @@ -0,0 +1,119 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using Mooncake +using Random + +@isdefined(TestSetup) || include("../setup.jl") +using .TestSetup +using .TestSetup: _repartition + +mode = Mooncake.ReverseMode +rng = Random.default_rng() + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), 
+ ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset "Mooncake - PlanarOperations: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + + @timedtestset "planarcontract!" begin + for _ in 1:5 + d = 0 + local V1, V2, V3, k1, k2, k3 + # retry a couple times to make sure there are at least some nonzero elements + for _ in 1:10 + k1 = rand(0:3) + k2 = rand(0:2) + k3 = rand(0:2) + V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init = one(V[1])) + V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init = one(V[1])) + V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init = one(V[1])) + d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3)) + d > 1 && break + end + k′ = rand(0:(k1 + k2)) + pA = randcircshift(k′, k1 + k2 - k′, k1) + ipA = _repartition(invperm(linearize(pA)), k′) + k′ = rand(0:(k2 + k3)) + pB = randcircshift(k′, k2 + k3 - k′, k2) + ipB = _repartition(invperm(linearize(pB)), k′) + # TODO: primal value already is broken for this? + # pAB = randcircshift(k1, k3) + pAB = _repartition(tuple((1:(k1 + k3))...), k1) + + α = randn(T) + β = randn(T) + + A = randn(T, permute(V1 ← V2, ipA)) + B = randn(T, permute(V2 ← V3, ipB)) + C = randn!( + TensorOperations.tensoralloc_contract( + T, A, pA, false, B, pB, false, pAB, Val(false) + ) + ) + Mooncake.TestUtils.test_rule( + rng, TensorKit.planarcontract!, C, A, pA, B, pB, pAB, One(), Zero(); + atol, rtol, mode, is_primitive = false + ) + Mooncake.TestUtils.test_rule( + rng, TensorKit.planarcontract!, C, A, pA, B, pB, pAB, α, β; + atol, rtol, mode, is_primitive = false + ) + end + end + + # TODO: currently broken + # @timedtestset "planartrace!" begin + # for _ in 1:5 + # k1 = rand(0:2) + # k2 = rand(0:1) + # V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) + # V2 = map(v -> rand(Bool) ? 
v' : v, rand(V, k2)) + # V3 = prod(x -> x ⊗ x', V2[1:k2]; init = one(V[1])) + # V4 = prod(x -> x ⊗ x', V2[(k2 + 1):end]; init = one(V[1])) + # + # k′ = rand(0:(k1 + 2k2)) + # (_p, _q) = randcircshift(k′, k1 + 2k2 - k′, k1) + # p = _repartition(_p, rand(0:k1)) + # q = (tuple(_q[1:2:end]...), tuple(_q[2:2:end]...)) + # ip = _repartition(invperm(linearize((_p, _q))), k′) + # A = randn(T, permute(prod(V1) ⊗ V3 ← V4, ip)) + # + # α = randn(T) + # β = randn(T) + # C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) + # Mooncake.TestUtils.test_rule( + # rng, TensorKit.planartrace!, + # C, A, p, q, α, β, + # TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); + # atol, rtol, mode + # ) + # end + # end +end diff --git a/test/enzyme/tensoroperations.jl b/test/enzyme/tensoroperations.jl new file mode 100644 index 000000000..0efe4d9c8 --- /dev/null +++ b/test/enzyme/tensoroperations.jl @@ -0,0 +1,157 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: One, Zero +using Enzyme, EnzymeTestUtils +using Random + +Enzyme.Compiler.VERBOSE_ERRORS[] = true + +@isdefined(TestSetup) || include("../setup.jl") +using .TestSetup + +mode = Enzyme.ReverseMode +rng = Random.default_rng() + +spacelist = ( + #(ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + 
Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +#eltypes = (Float64, ComplexF64) +eltypes = (ComplexF64,) + +@timedtestset verbose = true "Enzyme - TensorOperations: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding + @show TensorKit.type_repr(sectortype(eltype(V))), T + flush(stdout) + symmetricbraiding && @timedtestset "tensorcontract!" begin + for _ in 1:2 + d = 0 + local V1, V2, V3 + # retry a couple times to make sure there are at least some nonzero elements + for _ in 1:10 + k1 = rand(0:3) + k2 = rand(0:2) + k3 = rand(0:2) + V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init = one(V[1])) + V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init = one(V[1])) + V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init = one(V[1])) + d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3)) + d > 0 && break + end + ipA = randindextuple(length(V1) + length(V2)) + pA = _repartition(invperm(linearize(ipA)), length(V1)) + ipB = randindextuple(length(V2) + length(V3)) + pB = _repartition(invperm(linearize(ipB)), length(V2)) + pAB = randindextuple(length(V1) + length(V3)) + + α = randn(T) + β = randn(T) + V2_conj = prod(conj, V2; init = one(V[1])) + A = randn(T, permute(V1 ← V2, ipA)) + B = randn(T, permute(V2 ← V3, ipB)) + C = randn!( + TensorOperations.tensoralloc_contract( + T, A, pA, false, B, pB, false, pAB, Val(false) + ) + ) + + @testset for α_ in ((One(), Const), (α, Const), (α, Active)), + β_ in ((Zero(), Const), (β, Const), (β, Active)) + EnzymeTestUtils.test_reverse( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (A, Duplicated), (pA, Const), + (B, Duplicated), (pB, Const), (pAB, Const), + α_, β_, + (TensorOperations.DefaultBackend(), Const), + 
(TensorOperations.DefaultAllocator(), Const); + atol, rtol + ) + end + if !(T <: Real) + EnzymeTestUtils.test_reverse( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (A, Duplicated), (pA, Const), + (B, Duplicated), (pB, Const), (pAB, Const), + (real(α), Active), (real(β), Active), + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol + ) + #=EnzymeTestUtils.test_reverse( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (real(A), Duplicated), (pA, Const), + (B, Duplicated), (pB, Const), (pAB, Const), + (real(α), Active), (real(β), Active), + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol + )=# # TODO + EnzymeTestUtils.test_reverse( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (A, Duplicated), (pA, Const), + (real(B), Duplicated), (pB, Const), (pAB, Const), + (real(α), Active), (real(β), Active), + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol + ) + end + end + end + symmetricbraiding && @timedtestset "trace_permute!" begin + for _ in 1:2 + k1 = rand(0:2) + k2 = rand(1:2) + V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) + V2 = map(v -> rand(Bool) ? 
v' : v, rand(V, k2)) + + (_p, _q) = randindextuple(k1 + 2 * k2, k1) + p = _repartition(_p, rand(0:k1)) + q = _repartition(_q, k2) + ip = _repartition(invperm(linearize((_p, _q))), rand(0:(k1 + 2 * k2))) + A = randn(T, permute(prod(V1) ⊗ prod(V2) ← prod(V2), ip)) + + α = randn(T) + β = randn(T) + C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) + for TC in (Const, Duplicated), TA in (Const, Duplicated), Tα in (Const, Active), Tβ in (Const, Active) + EnzymeTestUtils.test_reverse( + TensorKit.trace_permute!, TC, + (copy(C), TC), (A, TA), (p, Const), (q, Const), + (α, Tα), (β, Tβ), (TensorOperations.DefaultBackend(), Const); + atol, rtol + ) + end + end + end +end diff --git a/test/enzyme/vectorinterface.jl b/test/enzyme/vectorinterface.jl new file mode 100644 index 000000000..62a0dd576 --- /dev/null +++ b/test/enzyme/vectorinterface.jl @@ -0,0 +1,75 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using Enzyme, EnzymeTestUtils +using Random + +@isdefined(TestSetup) || include("../setup.jl") +using .TestSetup + +mode = Enzyme.ReverseMode +rng = Random.default_rng() + +# TODO adjoints are broken! 
+ +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - VectorInterface: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + + C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + for TC in (Duplicated, Const), Tα in (Active, Const) + EnzymeTestUtils.test_reverse(scale!, TC, (C, TC), (α, Tα); atol, rtol) + #EnzymeTestUtils.test_reverse(scale!, TC, (C', TC), (α, Tα); atol, rtol) + for TA in (Duplicated, Const) + EnzymeTestUtils.test_reverse(scale!, TC, (C, TC), (A, TA), (α, Tα); atol, rtol) + #EnzymeTestUtils.test_reverse(scale!, TC, (C', TC), (A', TA), (α, Tα); atol, rtol) + #EnzymeTestUtils.test_reverse(scale!, TC, (copy(C'), TC), (A', TA), (α, Tα); atol, rtol) + #EnzymeTestUtils.test_reverse(scale!, TC, (C', TC), (copy(A'), TA), (α, Tα); atol, rtol) + end + end + + for TC in (Duplicated, Const), TA in (Duplicated, Const) + EnzymeTestUtils.test_reverse(add!, TC, (C, TC), (A, TA); atol, rtol) + for Tα in (Active, Const) + EnzymeTestUtils.test_reverse(add!, TC, (C, TC), (A, TA), (α, Tα); atol, rtol) + for Tβ in (Active, Const) + EnzymeTestUtils.test_reverse(add!, TC, (C, TC), (A, TA), (α, Tα), (β, Tβ); atol, rtol) + end + end + 
end + + for RT in (Active,), TC in (Duplicated, Const), TA in (Duplicated, Const) + EnzymeTestUtils.test_reverse(inner, RT, (C, TC), (A, TA); atol, rtol) + #EnzymeTestUtils.test_reverse(inner, RT, (C', TC), (A', TA); atol, rtol) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index ad7b4006e..9ad1f2480 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,11 +57,19 @@ istestfile(fn) = endswith(fn, ".jl") && !contains(fn, "setup") # somehow AD tests are unreasonably slow on Apple CI # and ChainRulesTestUtils doesn't like prereleases - if group == "chainrules" || group == "mooncake" + if group == "chainrules" || group == "mooncake" || group == "enzyme" Sys.isapple() && get(ENV, "CI", "false") == "true" && continue isempty(VERSION.prerelease) || continue end + if group == "enzyme" + include("enzyme/factorizations.jl") + include("enzyme/tensoroperations.jl") + include("enzyme/vectorinterface.jl") + include("enzyme/linalg.jl") + include("enzyme/indexmanipulations.jl") + end + grouppath = joinpath(@__DIR__, group) @time for file in filter(istestfile, readdir(grouppath)) @info "Running test file: $file"