diff --git a/Project.toml b/Project.toml index f56b35ba7..8ca217db8 100644 --- a/Project.toml +++ b/Project.toml @@ -18,12 +18,14 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8" [weakdeps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" [extensions] +TensorKitAdaptExt = "Adapt" TensorKitCUDAExt = ["CUDA", "cuTENSOR"] TensorKitChainRulesCoreExt = "ChainRulesCore" TensorKitFiniteDifferencesExt = "FiniteDifferences" diff --git a/ext/TensorKitAdaptExt.jl b/ext/TensorKitAdaptExt.jl new file mode 100644 index 000000000..734f174ad --- /dev/null +++ b/ext/TensorKitAdaptExt.jl @@ -0,0 +1,19 @@ +module TensorKitAdaptExt + +using TensorKit +using TensorKit: AdjointTensorMap +using Adapt + +function Adapt.adapt_structure(to, x::TensorMap) + data′ = adapt(to, x.data) + return TensorMap{eltype(data′)}(data′, space(x)) +end +function Adapt.adapt_structure(to, x::AdjointTensorMap) + return adjoint(adapt(to, parent(x))) +end +function Adapt.adapt_structure(to, x::DiagonalTensorMap) + data′ = adapt(to, x.data) + return DiagonalTensorMap(data′, x.domain) +end + +end diff --git a/src/TensorKit.jl b/src/TensorKit.jl index c35101235..688e7363c 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -142,7 +142,6 @@ using LinearAlgebra: norm, dot, normalize, normalize!, tr, eigen, eigen!, svd, svd!, isposdef, isposdef!, rank, cond, Diagonal, Hermitian -using MatrixAlgebraKit import Base.Meta diff --git a/test/cuda/tensors.jl b/test/cuda/tensors.jl index 8a89a8ad8..e495eef08 100644 --- a/test/cuda/tensors.jl +++ b/test/cuda/tensors.jl @@ -115,6 +115,20 @@ for V in spacelist @test domain(t2) == one(W) end end + @timedtestset "Adapt" begin + W = V1 ⊗ V2 ⊗ V3 ⊗ V4 ⊗ V5 + for T in (Int, Float32, ComplexF64) + t = rand(T, W) + t_gpu = @constinferred adapt(CuArray, t) + @test storagetype(t_gpu) <: CuArray{T} + @test scalartype(t_gpu) === scalartype(t) + @test collect(t_gpu.data) == t.data + + t_cpu = @constinferred adapt(Array, t_gpu) + @test t_cpu == t + @test storagetype(t_cpu) <: Array{T} + end + end @timedtestset "Tensor Dict conversion" begin W = V1 ⊗ V2 ⊗ V3 ← V4 ⊗ V5 for T in (Int, Float32, ComplexF64)