diff --git a/ext/TensorKitChainRulesCoreExt/linalg.jl b/ext/TensorKitChainRulesCoreExt/linalg.jl index 9db147a92..d4510c2f1 100644 --- a/ext/TensorKitChainRulesCoreExt/linalg.jl +++ b/ext/TensorKitChainRulesCoreExt/linalg.jl @@ -69,6 +69,16 @@ function ChainRulesCore.rrule( return permute(tsrc, p; copy = true), permute_pullback end +function ChainRulesCore.rrule( + ::typeof(transpose), tsrc::AbstractTensorMap, p::Index2Tuple; copy::Bool = false + ) + function transpose_pullback(Δtdst) + invp = TensorKit._canonicalize(TupleTools.invperm(linearize(p)), tsrc) + return NoTangent(), transpose(unthunk(Δtdst), invp; copy = true), NoTangent() + end + return transpose(tsrc, p; copy = true), transpose_pullback +end + function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap) tr_pullback(Δtr) = NoTangent(), Δtr * id(domain(A)) return tr(A), tr_pullback diff --git a/test/autodiff/ad.jl b/test/autodiff/ad.jl index 0230936f7..15c8c387e 100644 --- a/test/autodiff/ad.jl +++ b/test/autodiff/ad.jl @@ -293,6 +293,7 @@ for V in spacelist C = randn(T, domain(A), codomain(A)) test_rrule(*, A, C) + test_rrule(transpose, A, ((2, 5, 4), (1, 3))) symmetricbraiding && test_rrule(permute, A, ((1, 3, 2), (5, 4))) test_rrule(twist, A, 1) test_rrule(twist, A, [1, 3])