
Conversation

@lkdvos (Member) commented Dec 30, 2025

This is an attempt to work around the performance issue reported in #235.
The main point is that the @thunk construction does not seem to be respected, and the resulting tensorscalar calls cause latency when computing derivatives on GPUs, because reading out a scalar forces a device synchronization.
However, in quite a lot of cases these scalar parameters are fixed to 0 or 1, in which case we know that the contributions to the derivatives should simply be ZeroTangent rather than some @thunk.
Here I try to catch some of these cases (Integer, and therefore also Bool, as well as One and Zero), which end up in these calls quite often since they are generated by @tensor expressions.

I think this should alleviate the synchronization issues, but it might also speed up the computations overall, since it avoids computing anything to begin with.
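
For illustration, a minimal sketch of the dispatch idea (the helper name `_dα` and `ConstScalar` are hypothetical; the actual change lives in the pullbacks in ext/TensorOperationsChainRulesCoreExt.jl):

```julia
import VectorInterface
using ChainRulesCore: ZeroTangent, @thunk
using LinearAlgebra: dot

# Scalar types that are guaranteed to be fixed constants; Integer also covers Bool.
const ConstScalar = Union{Integer,VectorInterface.One,VectorInterface.Zero}

# Known constant: the cotangent is exactly zero, so skip the work entirely
# (and with it the device synchronization a scalar reduction would force).
_dα(α::ConstScalar, ΔC, AB) = ZeroTangent()

# Generic scalar: defer the reduction to a thunk, as before.
_dα(α::Number, ΔC, AB) = @thunk dot(AB, ΔC)
```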

@XingyuZhang2018, I don't have access to a GPU to run the benchmarks right now, would you mind verifying if this actually does resolve the issue?

codecov bot commented Dec 30, 2025

Codecov Report

❌ Patch coverage is 0% with 43 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ext/TensorOperationsChainRulesCoreExt.jl | 0.00% | 43 Missing ⚠️ |

| Files with missing lines | Coverage Δ |
| --- | --- |
| ext/TensorOperationsChainRulesCoreExt.jl | 0.00% <0.00%> (ø) |

@lkdvos changed the title from "Avoid computing derivatives with respect to non-derivable α, β" to "Avoid computing derivatives with respect to non-differentiable α, β" on Dec 31, 2025
@XingyuZhang2018 commented:

I reran the test from #235, and the results are much more reasonable now:

```
foo1 (@tensor):      108.371 ms (19059 allocations: 534.84 KiB)
foo2 (OMEinsum):      73.656 ms (24499 allocations: 736.59 KiB)
foo3 (reshape+*):    107.698 ms (15417 allocations: 408.08 KiB)
```

Thanks!

@XingyuZhang2018 commented:

I found a side effect of the current implementation: it can introduce a CPU Array into an otherwise pure CuArray computation.

A minimal example is shown below. In foo2, the contraction produces a 0-dimensional result, and during AD this ends up materializing as an Array{ComplexF64,0}. This CPU scalar is then fed back into tensorcontract! together with CuArrays, which causes a backend mismatch and ultimately fails:

```julia
using TensorOperations
using CUDA
using cuTENSOR
using OMEinsum
using Zygote
using Test
using LinearAlgebra
using BenchmarkTools

@testset "ad" begin
    D = 2^5
    A = [CUDA.rand(ComplexF64, D,D,D) for _ in 1:10]
    B = [CUDA.rand(ComplexF64, D,D) for _ in 1:10]

    function foo1(A)
        C = Zygote.Buffer(A)
        for i in 1:length(A)
            @tensor C[i][1,2,4] := A[i][1,2,3] * B[i][3,4]
        end
        return real(dot(C, C))
    end

    function foo2(A)
        s = 0.0
        for i in 1:length(A)
            @tensor C[1,2,4] := A[i][1,2,3] * B[i][3,4]
            s += @tensor conj(C[1,2,3]) * C[1,2,3]
        end
        
        return real(s)
    end

    g1 = Zygote.gradient(foo1, A)[1]
    g2 = Zygote.gradient(foo2, A)[1]
end
```

```
ad: Error During Test at d:\1 - research\1.18 - precondition_SU_AD\TeneT_demo\test\testTensorop.jl:10
  Got exception outside of a @test
  ArgumentError: No suitable backend found for tensorcontract! and tensor types CuArray{ComplexF64, 3, CUDA.DeviceMemory}, Array{ComplexF64, 0} and CuArray{ComplexF64, 3, CUDA.DeviceMemory}
  Stacktrace:
    [1] tensorcontract!(C::CuArray{ComplexF64, 3, CUDA.DeviceMemory}, A::Array{ComplexF64, 0}, pA::Tuple{Tuple{}, Tuple{}}, conjA::Bool, B::CuArray{ComplexF64, 3, CUDA.DeviceMemory}, pB::Tuple{Tuple{}, Tuple{Int64, Int64, Int64}}, conjB::Bool, pAB::Tuple{Tuple{Int64, Int64, Int64}, Tuple{}}, α::VectorInterface.One, β::VectorInterface.Zero, backend::TensorOperations.NoBackend, allocator::TensorOperations.DefaultAllocator)
      @ TensorOperations d:\1 - research\1.18 - precondition_SU_AD\TensorOperations.jl\src\interface.jl:187
    [2] tensorcontract!(C::CuArray{ComplexF64, 3, CUDA.DeviceMemory}, A::Array{ComplexF64, 0}, pA::Tuple{Tuple{}, Tuple{}}, conjA::Bool, B::CuArray{ComplexF64, 3, CUDA.DeviceMemory}, pB::Tuple{Tuple{}, Tuple{Int64, Int64, Int64}}, conjB::Bool, pAB::Tuple{Tuple{Int64, Int64, Int64}, Tuple{}}, α::VectorInterface.One, β::VectorInterface.Zero, backend::TensorOperations.DefaultBackend, allocator::TensorOperations.DefaultAllocator)
      @ TensorOperations d:\1 - research\1.18 - precondition_SU_AD\TensorOperations.jl\src\interface.jl:179
    [3] tensorcontract!(C::CuArray{ComplexF64, 3, CUDA.DeviceMemory}, A::Array{ComplexF64, 0}, pA::Tuple{Tuple{}, Tuple{}}, conjA::Bool, B::CuArray{ComplexF64, 3, CUDA.DeviceMemory}, pB::Tuple{Tuple{}, Tuple{Int64, Int64, Int64}}, conjB::Bool, pAB::Tuple{Tuple{Int64, Int64, Int64}, Tuple{}}, α::VectorInterface.One, β::VectorInterface.Zero, backend::TensorOperations.DefaultBackend)    
      @ TensorOperations d:\1 - research\1.18 - precondition_SU_AD\TensorOperations.jl\src\interface.jl:166
    [4] tensorcontract!(C::CuArray{ComplexF64, 3, CUDA.DeviceMemory}, A::Array{ComplexF64, 0}, pA::Tuple{Tuple{}, Tuple{}}, conjA::Bool, B::CuArray{ComplexF64, 3, CUDA.DeviceMemory}, pB::Tuple{Tuple{}, Tuple{Int64, Int64, Int64}}, conjB::Bool, pAB::Tuple{Tuple{Int64, Int64, Int64}, Tuple{}}, α::VectorInterface.One, β::VectorInterface.Zero)
      @ TensorOperations d:\1 - research\1.18 - precondition_SU_AD\TensorOperations.jl\src\interface.jl:155
    [5] (::TensorOperationsChainRulesCoreExt.var"#52#59"{Tuple{Tuple{}, Tuple{}}, Array{ComplexF64, 0}, CuArray{ComplexF64, 3, CUDA.DeviceMemory}, Tuple{Tuple{}, Tuple{Int64, Int64, Int64}}, Bool, CuArray{ComplexF64, 3, CUDA.DeviceMemory}, Tuple{Tuple{Int64, Int64, Int64}, Tuple{}}, Bool, VectorInterface.One, Tuple{}, ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{element::ChainRulesCore.ProjectTo{ComplexF64, @NamedTuple{}}, axes::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}}}})()        
      @ TensorOperationsChainRulesCoreExt d:\1 - research\1.18 - precondition_SU_AD\TensorOperations.jl\ext\TensorOperationsChainRulesCoreExt.jl:207
    [6] unthunk
      @ C:\Users\xingzhan\.julia\packages\ChainRulesCore\Vsbj9\src\tangent_types\thunks.jl:213 [inlined]
    [7] wrap_chainrules_output
      @ C:\Users\xingzhan\.julia\packages\Zygote\zowwZ\src\compiler\chainrules.jl:110 [inlined]
    [8] map (repeats 3 times)
      @ .\tuple.jl:358 [inlined]
    [9] wrap_chainrules_output
      @ C:\Users\xingzhan\.julia\packages\Zygote\zowwZ\src\compiler\chainrules.jl:111 [inlined]
   [10] (::Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#pullback#57"{CuArray{ComplexF64, 0, CUDA.DeviceMemory}, CuArray{ComplexF64, 3, CUDA.DeviceMemory}, Tuple{Tuple{}, Tuple{Int64, Int64, Int64}}, Bool, CuArray{ComplexF64, 3, CUDA.DeviceMemory}, Tuple{Tuple{Int64, Int64, Int64}, Tuple{}}, Bool, Tuple{Tuple{}, Tuple{}}, VectorInterface.One, VectorInterface.Zero, Tuple{}, ChainRulesCore.ProjectTo{Number, @NamedTuple{}}, ChainRulesCore.ProjectTo{Number, @NamedTuple{}}, ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{element::ChainRulesCore.ProjectTo{ComplexF64, @NamedTuple{}}, axes::Tuple{}}}, ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{element::ChainRulesCore.ProjectTo{ComplexF64, @NamedTuple{}}, axes::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}}}, ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{element::ChainRulesCore.ProjectTo{ComplexF64, @NamedTuple{}}, axes::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}}}}})(dy::Array{ComplexF64, 0})
      @ Zygote C:\Users\xingzhan\.julia\packages\Zygote\zowwZ\src\compiler\chainrules.jl:212
   [11] foo2
      @ d:\1 - research\1.18 - precondition_SU_AD\TeneT_demo\test\testTensorop.jl:27 [inlined]
   [12] (::Zygote.Pullback{Tuple{var"#foo2#107"{Vector{CuArray{ComplexF64, 2, CUDA.DeviceMemory}}}, Vector{CuArray{ComplexF64, 3, CUDA.DeviceMemory}}}, Any})(Δ::Float64)
      @ Zygote C:\Users\xingzhan\.julia\packages\Zygote\zowwZ\src\compiler\interface2.jl:0
   [13] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{var"#foo2#107"{Vector{CuArray{ComplexF64, 2, CUDA.DeviceMemory}}}, Vector{CuArray{ComplexF64, 3, CUDA.DeviceMemory}}}, Any}})(Δ::Float64)
      @ Zygote C:\Users\xingzhan\.julia\packages\Zygote\zowwZ\src\compiler\interface.jl:91
   [14] gradient(f::Function, args::Vector{CuArray{ComplexF64, 3, CUDA.DeviceMemory}})
      @ Zygote C:\Users\xingzhan\.julia\packages\Zygote\zowwZ\src\compiler\interface.jl:148
   [15] macro expansion
      @ d:\1 - research\1.18 - precondition_SU_AD\TeneT_demo\test\testTensorop.jl:34 [inlined]
   [16] macro expansion
      @ C:\Users\xingzhan\AppData\Local\Programs\Julia-1.11.1\share\julia\stdlib\v1.11\Test\src\Test.jl:1700 [inlined]
   [17] top-level scope
      @ d:\1 - research\1.18 - precondition_SU_AD\TeneT_demo\test\testTensorop.jl:11
   [18] eval
      @ .\boot.jl:430 [inlined]
   [19] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
      @ Base .\loading.jl:2643
   [20] invokelatest(::Any, ::Any, ::Vararg{Any}; kwargs::@Kwargs{})
      @ Base .\essentials.jl:1055
   [21] invokelatest(::Any, ::Any, ::Vararg{Any})
      @ Base .\essentials.jl:1052
   [22] inlineeval(m::Module, code::String, code_line::Int64, code_column::Int64, file::String; softscope::Bool)
      @ VSCodeServer c:\Users\xingzhan\.vscode\extensions\julialang.language-julia-1.79.2\scripts\packages\VSCodeServer\src\eval.jl:271
   [23] (::VSCodeServer.var"#69#74"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
      @ VSCodeServer c:\Users\xingzhan\.vscode\extensions\julialang.language-julia-1.79.2\scripts\packages\VSCodeServer\src\eval.jl:181
   [24] withpath(f::VSCodeServer.var"#69#74"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams}, path::String)
      @ VSCodeServer c:\Users\xingzhan\.vscode\extensions\julialang.language-julia-1.79.2\scripts\packages\VSCodeServer\src\repl.jl:276
   [25] (::VSCodeServer.var"#68#73"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
      @ VSCodeServer c:\Users\xingzhan\.vscode\extensions\julialang.language-julia-1.79.2\scripts\packages\VSCodeServer\src\eval.jl:179
   [26] hideprompt(f::VSCodeServer.var"#68#73"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})
      @ VSCodeServer c:\Users\xingzhan\.vscode\extensions\julialang.language-julia-1.79.2\scripts\packages\VSCodeServer\src\repl.jl:38
   [27] #67
      @ c:\Users\xingzhan\.vscode\extensions\julialang.language-julia-1.79.2\scripts\packages\VSCodeServer\src\eval.jl:150 [inlined]
   [28] with_logstate(f::VSCodeServer.var"#67#72"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams}, logstate::Base.CoreLogging.LogState)
      @ Base.CoreLogging .\logging\logging.jl:522
   [29] with_logger
      @ .\logging\logging.jl:632 [inlined]
   [30] (::VSCodeServer.var"#66#71"{VSCodeServer.ReplRunCodeRequestParams})()
      @ VSCodeServer c:\Users\xingzhan\.vscode\extensions\julialang.language-julia-1.79.2\scripts\packages\VSCodeServer\src\eval.jl:263
   [31] #invokelatest#2
      @ .\essentials.jl:1055 [inlined]
   [32] invokelatest(::Any)
      @ Base .\essentials.jl:1052
   [33] (::VSCodeServer.var"#64#65")()
      @ VSCodeServer c:\Users\xingzhan\.vscode\extensions\julialang.language-julia-1.79.2\scripts\packages\VSCodeServer\src\eval.jl:34
```

@lkdvos (Member, Author) commented Jan 1, 2026

Thanks for the report!

I think I traced this back to the change made in #233, specifically in the implementation of the tensorscalar reverse rule.
I had to hack around a bit to make sure those changes can be reverted while still allowing for higher-order derivatives, which I think should now work.
I think this PR should be ready now, unless further issues pop up.

```julia
    return fill!(x′, y)
end
function ChainRulesCore.rrule(::typeof(similar_and_fill), x, y)
    similar_and_fill_pullback(Δx) = NoTangent(), ZeroTangent(), tensorscalar(unthunk(Δx))
```
@Jutho (Member) commented on this hunk:

I am confused by this rule, in particular the adjoint of y: I think I can reinterpret the output of similar_and_fill(x, y) as just y * similar_and_fill(x, 1).

To avoid confusion, let's say x = y * similar_and_fill(some_other_x, 1). Then clearly the forward derivatives satisfy ẋ = ẏ * similar_and_fill(some_other_x, 1), where the last factor is completely constant.

Equating dot(Δx, ẋ) = ẏ * dot(Δx, similar_and_fill(some_other_x, 1)) with Δy' * ẏ, I then obtain

Δy = dot(similar_and_fill(some_other_x, 1), Δx)

Maybe I have to read further first, and similar_and_fill is only ever called on tensor arguments x that are equivalent to scalars, and thus have only a single entry. In principle, though, the definition makes sense for general tensors, and then the reverse rule clearly cannot be correct, since tensorscalar(Δx) would fail.
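
A quick numerical check of the scalar case, using a standalone stand-in for `similar_and_fill`:

```julia
using LinearAlgebra: dot

similar_and_fill(x, y) = fill!(similar(x), y)   # stand-in for the helper above

x = zeros(ComplexF64)      # 0-dimensional array: exactly one entry
Δx = fill(0.3 + 0.4im)     # some cotangent for the output

# rule in the diff:        Δy = tensorscalar(Δx), i.e. Δx[]
# general formula above:   Δy = dot(similar_and_fill(x, 1), Δx)
@assert Δx[] == dot(similar_and_fill(x, 1), Δx)   # agree iff x has a single entry
```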

```julia
    _Δc = unthunk(Δc)
    return NoTangent(), projectC(_Δc)
end
tensorscalar_pullback(Δc) = NoTangent(), similar_and_fill(C, unthunk(Δc))
```
@Jutho (Member) commented on this hunk:

Ok, I see, so similar_and_fill is indeed only called on tensors C for which tensorscalar makes sense.
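
For reference, pieced together from the two hunks above, the helper and its reverse rule presumably read as follows (a reconstruction, not the verbatim diff):

```julia
using ChainRulesCore
using TensorOperations: tensorscalar

# Allocate an uninitialized container like x and fill it with the scalar y.
function similar_and_fill(x, y)
    x′ = similar(x)
    return fill!(x′, y)
end

# Reverse rule: x gets no tangent contribution, and the cotangent of the
# scalar y is the single entry of Δx, extracted via tensorscalar.
function ChainRulesCore.rrule(::typeof(similar_and_fill), x, y)
    similar_and_fill_pullback(Δx) = NoTangent(), ZeroTangent(), tensorscalar(unthunk(Δx))
    return similar_and_fill(x, y), similar_and_fill_pullback
end
```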

@Jutho (Member) left a review:

Thanks, this is a very clean fix.

@lkdvos merged commit f1fd025 into master on Jan 1, 2026
15 of 16 checks passed
@lkdvos deleted the ld-sync branch on January 1, 2026 at 21:44


Development

Successfully merging this pull request may close this issue:

Performance issue in Zygote AD for TensorOperations.@tensor on CUDA: costly CPU copies in rrule
