From 3f01be84e03d8ae0339713f3f9cd98308e629a50 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 29 Nov 2025 11:34:12 -0500 Subject: [PATCH 1/2] add rrule for transpose --- ext/TensorKitChainRulesCoreExt/linalg.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ext/TensorKitChainRulesCoreExt/linalg.jl b/ext/TensorKitChainRulesCoreExt/linalg.jl index 9db147a92..d4510c2f1 100644 --- a/ext/TensorKitChainRulesCoreExt/linalg.jl +++ b/ext/TensorKitChainRulesCoreExt/linalg.jl @@ -69,6 +69,16 @@ function ChainRulesCore.rrule( return permute(tsrc, p; copy = true), permute_pullback end +function ChainRulesCore.rrule( + ::typeof(transpose), tsrc::AbstractTensorMap, p::Index2Tuple; copy::Bool = false + ) + function transpose_pullback(Δtdst) + invp = TensorKit._canonicalize(TupleTools.invperm(linearize(p)), tsrc) + return NoTangent(), transpose(unthunk(Δtdst), invp; copy = true), NoTangent() + end + return transpose(tsrc, p; copy = true), transpose_pullback +end + function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap) tr_pullback(Δtr) = NoTangent(), Δtr * id(domain(A)) return tr(A), tr_pullback From af4c611fa6de1ec737b8c6c14dcc3936d2efbdd1 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 29 Nov 2025 11:36:38 -0500 Subject: [PATCH 2/2] add test rrule transpose --- test/autodiff/ad.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/autodiff/ad.jl b/test/autodiff/ad.jl index 0230936f7..15c8c387e 100644 --- a/test/autodiff/ad.jl +++ b/test/autodiff/ad.jl @@ -293,6 +293,7 @@ for V in spacelist C = randn(T, domain(A), codomain(A)) test_rrule(*, A, C) + test_rrule(transpose, A, ((2, 5, 4), (1, 3))) symmetricbraiding && test_rrule(permute, A, ((1, 3, 2), (5, 4))) test_rrule(twist, A, 1) test_rrule(twist, A, [1, 3])