
Commit 677f4af

Start on CUDA extension
1 parent d60855e commit 677f4af

9 files changed: +902, -28 lines

.buildkite/pipeline.yml

Lines changed: 4 additions & 4 deletions
@@ -15,12 +15,12 @@ steps:
       queue: "juliagpu"
       cuda: "*"
     if: build.message !~ /\[skip tests\]/
-    timeout_in_minutes: 30
+    timeout_in_minutes: 60
     matrix:
       setup:
         julia:
           - "1.10"
-          - "1.11"
+          - "1.12"
 
   - label: "Julia {{matrix.julia}} -- AMDGPU"
     plugins:
@@ -36,9 +36,9 @@ steps:
       rocm: "*"
       rocmgpu: "*"
     if: build.message !~ /\[skip tests\]/
-    timeout_in_minutes: 30
+    timeout_in_minutes: 60
     matrix:
       setup:
         julia:
           - "1.10"
-          - "1.11"
+          - "1.12"

Project.toml

Lines changed: 17 additions & 2 deletions
@@ -18,20 +18,30 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
 VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
 
 [weakdeps]
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
+cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
+
+[sources]
+GPUArrays = {rev = "master", url = "https://github.com/JuliaGPU/GPUArrays.jl"}
+MatrixAlgebraKit = {rev = "main", url = "https://github.com/QuantumKitHub/MatrixAlgebraKit.jl"}
 
 [extensions]
+TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
 TensorKitChainRulesCoreExt = "ChainRulesCore"
 TensorKitFiniteDifferencesExt = "FiniteDifferences"
 
 [compat]
+Adapt = "4"
 Aqua = "0.6, 0.7, 0.8"
 ArgParse = "1.2.0"
+CUDA = "5.9"
 ChainRulesCore = "1"
 ChainRulesTestUtils = "1"
 Combinatorics = "1"
 FiniteDifferences = "0.12"
+GPUArrays = "11.3.1"
 LRUCache = "1.0.2"
 LinearAlgebra = "1"
 MatrixAlgebraKit = "0.6.0"
@@ -48,21 +58,26 @@ TestExtras = "0.2,0.3"
 TupleTools = "1.1"
 VectorInterface = "0.4.8, 0.5"
 Zygote = "0.7"
+cuTENSOR = "2"
 julia = "1.10"
 
 [extras]
-ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
+Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
+ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
 Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
+GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
 
 [targets]
-test = ["ArgParse", "Aqua", "Combinatorics", "LinearAlgebra", "TensorOperations", "Test", "TestExtras", "SafeTestsets", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote"]
+test = ["ArgParse", "Adapt", "Aqua", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote"]
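
Because CUDA and cuTENSOR are declared as weak dependencies, the new TensorKitCUDAExt extension is only compiled and loaded once both packages are present in the active environment. A rough sketch of what that looks like from the user side (assumes a CUDA-capable machine; `Base.get_extension` is the standard Julia ≥ 1.9 query, everything else is from this commit):

using TensorKit
using CUDA, cuTENSOR   # loading both weak dependencies triggers TensorKitCUDAExt

ext = Base.get_extension(TensorKit, :TensorKitCUDAExt)
ext === nothing && error("TensorKitCUDAExt did not load; check that CUDA and cuTENSOR are installed")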

TensorKitCUDAExt.jl

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+module TensorKitCUDAExt
+
+using CUDA, CUDA.CUBLAS, CUDA.CUSOLVER, LinearAlgebra
+using CUDA: @allowscalar
+using cuTENSOR: cuTENSOR
+import CUDA: rand as curand, rand! as curand!, randn as curandn, randn! as curandn!
+
+using TensorKit
+using TensorKit.Factorizations
+using TensorKit.Strided
+using TensorKit.Factorizations: AbstractAlgorithm
+using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, _project_symmetric_and_check
+import TensorKit: randisometry
+
+using TensorKit.MatrixAlgebraKit
+
+using Random
+
+include("cutensormap.jl")
+
+# TODO
+# add VectorInterface extensions for proper CUDA promotion
+function TensorKit.VectorInterface.promote_add(TA::Type{<:CUDA.StridedCuMatrix{Tx}}, TB::Type{<:CUDA.StridedCuMatrix{Ty}}, α::Tα = TensorKit.VectorInterface.One(), β::Tβ = TensorKit.VectorInterface.One()) where {Tx, Ty, Tα, Tβ}
+    return Base.promote_op(TensorKit.VectorInterface.add, Tx, Ty, Tα, Tβ)
+end
+
+end
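
For context on the promotion machinery this TODO is extending: VectorInterface's `promote_add` computes the scalar type produced by `β * y + α * x`, and the `Base.promote_rule` for `CuTensorMap` in cutensormap.jl below calls it on plain scalar types. A hedged illustration at the scalar level, assuming only the two-argument form that this commit itself uses:

using TensorKit
const VI = TensorKit.VectorInterface

VI.promote_add(Float32, Float64)      # Float64
VI.promote_add(Float64, ComplexF32)   # ComplexF64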

cutensormap.jl

Lines changed: 225 additions & 0 deletions
@@ -0,0 +1,225 @@
+const CuTensorMap{T, S, N₁, N₂} = TensorMap{T, S, N₁, N₂, CuVector{T, CUDA.DeviceMemory}}
+const CuTensor{T, S, N} = CuTensorMap{T, S, N, 0}
+
+const AdjointCuTensorMap{T, S, N₁, N₂} = AdjointTensorMap{T, S, N₁, N₂, CuTensorMap{T, S, N₁, N₂}}
+
+function CuTensorMap{T, S, N₁, N₂}(t::TensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂, A}
+    return CuTensorMap{T, S, N₁, N₂}(CuArray(t.data), t.space)
+end
+
+# project_symmetric! doesn't yet work for GPU types, so do this on the host, then copy
+function TensorKit._project_symmetric_and_check(::Type{T}, ::Type{A}, data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data)))))) where {T, A <: CuVector{T}}
+    h_t = TensorKit.TensorMapWithStorage{T, Vector{T}}(undef, V)
+    h_t = TensorKit.project_symmetric!(h_t, Array(data))
+    # verify result
+    isapprox(Array(reshape(data, dims(h_t))), convert(Array, h_t); atol = tol) ||
+        throw(ArgumentError("Data has non-zero elements at incompatible positions"))
+    return TensorKit.TensorMapWithStorage{T, A}(A(h_t.data), V)
+end
+
+for (fname, felt) in ((:zeros, :zero), (:ones, :one))
+    @eval begin
+        function CUDA.$fname(
+                codomain::TensorSpace{S},
+                domain::TensorSpace{S} = one(codomain)
+            ) where {S <: IndexSpace}
+            return CUDA.$fname(codomain ← domain)
+        end
+        function CUDA.$fname(
+                ::Type{T}, codomain::TensorSpace{S},
+                domain::TensorSpace{S} = one(codomain)
+            ) where {T, S <: IndexSpace}
+            return CUDA.$fname(T, codomain ← domain)
+        end
+        CUDA.$fname(V::TensorMapSpace) = CUDA.$fname(Float64, V)
+        function CUDA.$fname(::Type{T}, V::TensorMapSpace) where {T}
+            t = CuTensorMap{T}(undef, V)
+            fill!(t, $felt(T))
+            return t
+        end
+    end
+end
+
+for randfun in (:curand, :curandn)
+    randfun! = Symbol(randfun, :!)
+    @eval begin
+        # converting `codomain` and `domain` into `HomSpace`
+        function $randfun(
+                codomain::TensorSpace{S},
+                domain::TensorSpace{S} = one(codomain),
+            ) where {S <: IndexSpace}
+            return $randfun(codomain ← domain)
+        end
+        function $randfun(
+                ::Type{T}, codomain::TensorSpace{S},
+                domain::TensorSpace{S} = one(codomain),
+            ) where {T, S <: IndexSpace}
+            return $randfun(T, codomain ← domain)
+        end
+        function $randfun(
+                rng::Random.AbstractRNG, ::Type{T},
+                codomain::TensorSpace{S},
+                domain::TensorSpace{S} = one(codomain),
+            ) where {T, S <: IndexSpace}
+            return $randfun(rng, T, codomain ← domain)
+        end
+
+        # filling in default eltype
+        $randfun(V::TensorMapSpace) = $randfun(Float64, V)
+        function $randfun(rng::Random.AbstractRNG, V::TensorMapSpace)
+            return $randfun(rng, Float64, V)
+        end
+
+        # filling in default rng
+        function $randfun(::Type{T}, V::TensorMapSpace) where {T}
+            return $randfun(Random.default_rng(), T, V)
+        end
+
+        # implementation
+        function $randfun(
+                rng::Random.AbstractRNG, ::Type{T},
+                V::TensorMapSpace
+            ) where {T}
+            t = CuTensorMap{T}(undef, V)
+            $randfun!(rng, t)
+            return t
+        end
+    end
+end
+
+for randfun in (:rand, :randn, :randisometry)
+    randfun! = Symbol(randfun, :!)
+    @eval begin
+        # converting `codomain` and `domain` into `HomSpace`
+        function $randfun(
+                ::Type{A}, codomain::TensorSpace{S},
+                domain::TensorSpace{S}
+            ) where {A <: CuArray, S <: IndexSpace}
+            return $randfun(A, codomain ← domain)
+        end
+        function $randfun(
+                ::Type{T}, ::Type{A}, codomain::TensorSpace{S},
+                domain::TensorSpace{S}
+            ) where {T, S <: IndexSpace, A <: CuArray{T}}
+            return $randfun(T, A, codomain ← domain)
+        end
+        function $randfun(
+                rng::Random.AbstractRNG, ::Type{T}, ::Type{A},
+                codomain::TensorSpace{S},
+                domain::TensorSpace{S}
+            ) where {T, S <: IndexSpace, A <: CuArray{T}}
+            return $randfun(rng, T, A, codomain ← domain)
+        end
+
+        # accepting single `TensorSpace`
+        $randfun(::Type{A}, codomain::TensorSpace) where {A <: CuArray} = $randfun(A, codomain ← one(codomain))
+        function $randfun(::Type{T}, ::Type{A}, codomain::TensorSpace) where {T, A <: CuArray{T}}
+            return $randfun(T, A, codomain ← one(codomain))
+        end
+        function $randfun(
+                rng::Random.AbstractRNG, ::Type{T},
+                ::Type{A}, codomain::TensorSpace
+            ) where {T, A <: CuArray{T}}
+            return $randfun(rng, T, A, codomain ← one(codomain))
+        end
+
+        # filling in default eltype
+        $randfun(::Type{A}, V::TensorMapSpace) where {A <: CuArray} = $randfun(eltype(A), A, V)
+        function $randfun(rng::Random.AbstractRNG, ::Type{A}, V::TensorMapSpace) where {A <: CuArray}
+            return $randfun(rng, eltype(A), A, V)
+        end
+
+        # filling in default rng
+        function $randfun(::Type{T}, ::Type{A}, V::TensorMapSpace) where {T, A <: CuArray{T}}
+            return $randfun(Random.default_rng(), T, A, V)
+        end
+
+        # implementation
+        function $randfun(
+                rng::Random.AbstractRNG, ::Type{T},
+                ::Type{A}, V::TensorMapSpace
+            ) where {T, A <: CuArray{T}}
+            t = CuTensorMap{T}(undef, V)
+            $randfun!(rng, t)
+            return t
+        end
+    end
+end
+
+function Base.convert(::Type{CuTensorMap}, t::AbstractTensorMap)
+    return copy!(CuTensorMap{scalartype(t)}(undef, space(t)), t)
+end
+
+# Scalar implementation
+#-----------------------
+function TensorKit.scalar(t::CuTensorMap)
+    # TODO: should scalar only work if N₁ == N₂ == 0?
+    return @allowscalar dim(codomain(t)) == dim(domain(t)) == 1 ?
+        first(blocks(t))[2][1, 1] : throw(DimensionMismatch())
+end
+
+TensorKit.scalartype(A::StridedCuArray{T}) where {T} = T
+TensorKit.scalartype(::Type{<:CuTensorMap{T}}) where {T} = T
+TensorKit.scalartype(::Type{<:CuArray{T}}) where {T} = T
+
+function Base.convert(
+        TT::Type{CuTensorMap{T, S, N₁, N₂}},
+        t::AbstractTensorMap{<:Any, S, N₁, N₂}
+    ) where {T, S, N₁, N₂}
+    if typeof(t) === TT
+        return t
+    else
+        tnew = TT(undef, space(t))
+        return copy!(tnew, t)
+    end
+end
+
+function LinearAlgebra.isposdef(t::CuTensorMap)
+    domain(t) == codomain(t) ||
+        throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same"))
+    InnerProductStyle(spacetype(t)) === EuclideanInnerProduct() || return false
+    for (c, b) in blocks(t)
+        # do our own hermitian check
+        isherm = TensorKit.MatrixAlgebraKit.ishermitian(b; atol = eps(real(eltype(b))), rtol = eps(real(eltype(b))))
+        isherm || return false
+        isposdef(Hermitian(b)) || return false
+    end
+    return true
+end
+
+function Base.promote_rule(
+        ::Type{<:TT₁},
+        ::Type{<:TT₂}
+    ) where {
+        S, N₁, N₂, TTT₁, TTT₂,
+        TT₁ <: CuTensorMap{TTT₁, S, N₁, N₂},
+        TT₂ <: CuTensorMap{TTT₂, S, N₁, N₂},
+    }
+    T = TensorKit.VectorInterface.promote_add(TTT₁, TTT₂)
+    return CuTensorMap{T, S, N₁, N₂}
+end
+
+# CuTensorMap exponentiation:
+function TensorKit.exp!(t::CuTensorMap)
+    domain(t) == codomain(t) ||
+        error("Exponential of a tensor only exists when domain == codomain.")
+    for (c, b) in blocks(t)
+        copy!(b, parent(Base.exp(Hermitian(b))))
+    end
+    return t
+end
+
+# functions that don't map ℝ to (a subset of) ℝ
+for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
+    sf = string(f)
+    @eval function Base.$f(t::CuTensorMap)
+        domain(t) == codomain(t) ||
+            throw(SpaceMismatch("`$($sf)` of a tensor only exists when domain == codomain"))
+        T = complex(float(scalartype(t)))
+        tf = similar(t, T)
+        for (c, b) in blocks(t)
+            copy!(block(tf, c), parent($f(Hermitian(b))))
+        end
+        return tf
+    end
+end
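
Taken together, these methods let GPU-backed tensor maps be created directly on TensorKit spaces and converted from host tensors. A minimal usage sketch based only on the methods added above (requires a functional CUDA device; the spaces and values are arbitrary, and `CuTensorMap` is reached through the extension module since it is not exported):

using TensorKit, CUDA, cuTENSOR
using LinearAlgebra

V = ℂ^2 ⊗ ℂ^3

t = CUDA.randn(Float32, V, ℂ^2)      # sampled on the device via the curandn methods above
z = CUDA.zeros(ComplexF64, V, ℂ^2)   # zero-filled CuTensorMap

tc = randn(ComplexF64, V, ℂ^2)       # ordinary host TensorMap from TensorKit itself
ext = Base.get_extension(TensorKit, :TensorKitCUDAExt)
tg = convert(ext.CuTensorMap, tc)    # copy to device storage via the convert method above

norm(tg)                             # exercises the GPU-friendly norm path in linalg.jl below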

src/tensors/diagonal.jl

Lines changed: 1 addition & 1 deletion
@@ -273,7 +273,7 @@ function LinearAlgebra.mul!(
         dC::DiagonalTensorMap, dA::DiagonalTensorMap, dB::DiagonalTensorMap, α::Number, β::Number
     )
     dC.domain == dA.domain == dB.domain || throw(SpaceMismatch())
-    mul!(Diagonal(dC.data), Diagonal(dA.data), Diagonal(dB.data), α, β)
+    @. dC.data = (α * dA.data * dB.data) + β * dC.data
     return dC
 end
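
The rewrite replaces the `Diagonal`-wrapped `mul!` with a plain element-wise broadcast over the stored diagonal data, which GPU arrays handle natively without scalar indexing. A quick host-side check of the equivalence, with plain `Vector`s standing in for `DiagonalTensorMap` data (toy values):

using LinearAlgebra

a, b = [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]
c1 = [0.5, 0.5, 0.5]; c2 = copy(c1)
α, β = 2.0, 3.0

mul!(Diagonal(c1), Diagonal(a), Diagonal(b), α, β)   # previous formulation
@. c2 = (α * a * b) + β * c2                         # new broadcast formulation
c1 ≈ c2                                              # true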

src/tensors/linalg.jl

Lines changed: 5 additions & 5 deletions
@@ -272,7 +272,7 @@ function _norm(blockiter, p::Real, init::Real)
         end
     elseif p == 2
         n² = mapreduce(+, blockiter; init = init) do (c, b)
-            return isempty(b) ? init : oftype(init, dim(c) * LinearAlgebra.norm2(b)^2)
+            return isempty(b) ? init : oftype(init, dim(c) * LinearAlgebra.norm(b, 2)^2)
         end
         return sqrt(n²)
     elseif p == 1
@@ -281,7 +281,7 @@ function _norm(blockiter, p::Real, init::Real)
         end
     elseif p > 0
         nᵖ = mapreduce(+, blockiter; init = init) do (c, b)
-            return isempty(b) ? init : oftype(init, dim(c) * LinearAlgebra.normp(b, p)^p)
+            return isempty(b) ? init : oftype(init, dim(c) * LinearAlgebra.norm(b, p)^p)
         end
        return (nᵖ)^inv(oftype(nᵖ, p))
     else
@@ -299,7 +299,7 @@ function LinearAlgebra.rank(
     r = 0 * dim(first(allunits(sectortype(t))))
     dim(t) == 0 && return r
     S = LinearAlgebra.svdvals(t)
-    tol = max(atol, rtol * maximum(first, values(S)))
+    tol = max(atol, rtol * mapreduce(maximum, max, values(S)))
     for (c, b) in pairs(S)
         if !isempty(b)
             r += dim(c) * count(>(tol), b)
@@ -317,8 +317,8 @@ function LinearAlgebra.cond(t::AbstractTensorMap, p::Real = 2)
             return zero(real(float(scalartype(t))))
         end
         S = LinearAlgebra.svdvals(t)
-        maxS = maximum(first, values(S))
-        minS = minimum(last, values(S))
+        maxS = mapreduce(maximum, max, values(S))
+        minS = mapreduce(minimum, min, values(S))
         return iszero(maxS) ? oftype(maxS, Inf) : (maxS / minS)
     else
         throw(ArgumentError("cond currently only defined for p=2"))
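
The old `maximum(first, values(S))` / `minimum(last, values(S))` forms relied on each block of singular values being sorted and on scalar indexing, which GPU arrays disallow by default; the `mapreduce` forms reduce over whole blocks instead. A CPU-side illustration with a plain `Dict` standing in for the `SectorDict` returned by `svdvals` (toy values):

S = Dict(:a => [3.0, 1.0, 0.5], :b => [2.0, 0.1])

maximum(first, values(S))           # 3.0 -- assumes sorted blocks, scalar indexing
mapreduce(maximum, max, values(S))  # 3.0 -- full reduction per block, GPU-friendly
mapreduce(minimum, min, values(S))  # 0.1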
