From c1103e6ac98c99842e5b26998028bf4417e3aba8 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 5 Jan 2026 21:45:29 +0100 Subject: [PATCH 1/5] add Adapt --- Project.toml | 1 + src/TensorKit.jl | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f56b35ba7..5a525e9ff 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Jutho Haegeman, Lukas Devos"] version = "0.16.0" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" diff --git a/src/TensorKit.jl b/src/TensorKit.jl index c35101235..34b1c3d94 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -121,6 +121,7 @@ using MatrixAlgebraKit using LRUCache using OhMyThreads using ScopedValues +using Adapt using TensorKitSectors import TensorKitSectors: dim, BraidingStyle, FusionStyle, ⊠, ⊗ @@ -142,7 +143,6 @@ using LinearAlgebra: norm, dot, normalize, normalize!, tr, eigen, eigen!, svd, svd!, isposdef, isposdef!, rank, cond, Diagonal, Hermitian -using MatrixAlgebraKit import Base.Meta From 38b5b3e440a96b5134f6a06b9a20cd149b6e9607 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 5 Jan 2026 21:50:23 +0100 Subject: [PATCH 2/5] convert to extension --- Project.toml | 3 ++- ext/TensorKitAdaptExt.jl | 0 src/TensorKit.jl | 1 - 3 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 ext/TensorKitAdaptExt.jl diff --git a/Project.toml b/Project.toml index 5a525e9ff..8ca217db8 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,6 @@ authors = ["Jutho Haegeman, Lukas Devos"] version = "0.16.0" [deps] -Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" @@ -19,12 +18,14 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8" [weakdeps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" [extensions] +TensorKitAdaptExt = "Adapt" TensorKitCUDAExt = ["CUDA", "cuTENSOR"] TensorKitChainRulesCoreExt = "ChainRulesCore" TensorKitFiniteDifferencesExt = "FiniteDifferences" diff --git a/ext/TensorKitAdaptExt.jl b/ext/TensorKitAdaptExt.jl new file mode 100644 index 000000000..e69de29bb diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 34b1c3d94..688e7363c 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -121,7 +121,6 @@ using MatrixAlgebraKit using LRUCache using OhMyThreads using ScopedValues -using Adapt using TensorKitSectors import TensorKitSectors: dim, BraidingStyle, FusionStyle, ⊠, ⊗ From bcebc6d713935139d159839522683d7252231e67 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 5 Jan 2026 22:55:48 +0100 Subject: [PATCH 3/5] add Adapt implementations --- ext/TensorKitAdaptExt.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/ext/TensorKitAdaptExt.jl b/ext/TensorKitAdaptExt.jl index e69de29bb..734f174ad 100644 --- a/ext/TensorKitAdaptExt.jl +++ b/ext/TensorKitAdaptExt.jl @@ -0,0 +1,19 @@ +module TensorKitAdaptExt + +using TensorKit +using TensorKit: AdjointTensorMap +using Adapt + +function Adapt.adapt_structure(to, x::TensorMap) + data′ = adapt(to, x.data) + return TensorMap{eltype(data′)}(data′, space(x)) +end +function Adapt.adapt_structure(to, x::AdjointTensorMap) + return adjoint(adapt(to, parent(x))) +end +function Adapt.adapt_structure(to, x::DiagonalTensorMap) + data′ = adapt(to, x.data) + return DiagonalTensorMap(data′, x.domain) +end + +end From 7fc983a787ff7018672f998e21bcc5f47493efe0 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 5 Jan 2026 23:07:00 +0100 Subject: [PATCH 4/5] add Adapt tests --- test/cuda/tensors.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/cuda/tensors.jl b/test/cuda/tensors.jl index 8a89a8ad8..87a4f0236 100644 --- a/test/cuda/tensors.jl +++ b/test/cuda/tensors.jl @@ -115,6 +115,18 @@ for V in spacelist @test domain(t2) == one(W) end end + @timedtestset "Adapt" begin + W = V1 ⊗ V2 ⊗ V3 ⊗ V4 ⊗ V5 + t = rand(Float64, W) + t_gpu = @constinferred adapt(CuArray, t) + @test storagetype(t_gpu) <: CuArray + @test scalartype(t_gpu) === scalartype(t) + @test collect(t_gpu.data) == t.data + + t_cpu = @constinferred adapt(Array, t_cpu) + @test t_cpu == t + @test storagetype(t_cpu) isa Array + end @timedtestset "Tensor Dict conversion" begin W = V1 ⊗ V2 ⊗ V3 ← V4 ⊗ V5 for T in (Int, Float32, ComplexF64) From ae5ea379254a3170918e5dc14dbb61f9d95ac34c Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 6 Jan 2026 10:27:43 +0100 Subject: [PATCH 5/5] fix tests --- test/cuda/tensors.jl | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/test/cuda/tensors.jl b/test/cuda/tensors.jl index 87a4f0236..e495eef08 100644 --- a/test/cuda/tensors.jl +++ b/test/cuda/tensors.jl @@ -117,15 +117,17 @@ for V in spacelist end @timedtestset "Adapt" begin W = V1 ⊗ V2 ⊗ V3 ⊗ V4 ⊗ V5 - t = rand(Float64, W) - t_gpu = @constinferred adapt(CuArray, t) - @test storagetype(t_gpu) <: CuArray - @test scalartype(t_gpu) === scalartype(t) - @test collect(t_gpu.data) == t.data + for T in (Int, Float32, ComplexF64) + t = rand(T, W) + t_gpu = @constinferred adapt(CuArray, t) + @test storagetype(t_gpu) <: CuArray{T} + @test scalartype(t_gpu) === scalartype(t) + @test collect(t_gpu.data) == t.data - t_cpu = @constinferred adapt(Array, t_cpu) - @test t_cpu == t - @test storagetype(t_cpu) isa Array + t_cpu = @constinferred adapt(Array, t_gpu) + @test t_cpu == t + @test storagetype(t_cpu) <: Array{T} + end end @timedtestset "Tensor Dict conversion" begin W = V1 ⊗ V2 ⊗ V3 ← V4 ⊗ V5