Skip to content

Commit cee5fed

Browse files
authored
Add graph capture (#895)
- Add graph capture. - Update profiling docs to rocprofv3.
1 parent c02ef8d commit cee5fed

10 files changed

Lines changed: 322 additions & 10 deletions

File tree

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
33
DocumenterVitepress = "4710194d-e776-4893-9690-8d956a29c365"
44
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
5+
LiveServer = "16fef848-5104-11e9-1b77-fb7a48bbb589"
56
SIMD = "fdea26ae-647d-5447-a871-4b548cad5224"
67

78
[compat]

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ function main()
4141
"Devices" => "api/devices.md",
4242
"Streams" => "api/streams.md",
4343
"Kernel Programming" => "api/kernel_programming.md",
44+
"Graphs" => "api/graphs.md",
4445
"Exceptions" => "api/exceptions.md",
4546
"Memory" => "api/memory.md",
4647
"Host-Call" => "api/hostcall.md",

docs/src/api/graphs.md

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Graphs
2+
3+
[Graphs](https://rocm.docs.amd.com/projects/HIP/en/latest/how-to/hip_runtime_api/hipgraph.html)
4+
allow capturing GPU kernels and executing them as one unit, reducing host overhead.
5+
6+
Simple operations can be captured as is:
7+
8+
```@example graph-1
9+
using AMDGPU
10+
11+
f!(o) = o .+= one(eltype(o))
12+
13+
z = AMDGPU.zeros(Int, 4, 4)
14+
graph = AMDGPU.@captured f!(z)
15+
@assert sum(z) == 16
16+
17+
AMDGPU.launch(graph)
18+
@assert sum(z) == 16 * 2
19+
```
20+
21+
However, if your code contains more complex flow, it requires more preparations:
22+
- code **must not** result in hostcall invokation.
23+
- if code contains malloc and respective frees, then it can be captured and relaunched as is.
24+
- if code contains **only** allocations (without freeing), allocations must be cached with `GPUArrays.@cached` beforehand (see example below).
25+
- other unsupported operations (e.g. RNG init) must be done beforehand as well.
26+
- updating graph, does not update allocated pointers, only instantiation is supported in such cases.
27+
28+
```@example graph-2
29+
using AMDGPU, GPUArrays
30+
31+
function f(o)
32+
x = AMDGPU.rand(Float32, size(o))
33+
y = AMDGPU.rand(Float32, size(o))
34+
o .+= sin.(x) * cos.(y) .+ 1f0
35+
return
36+
end
37+
38+
cache = GPUArrays.AllocCache()
39+
z = AMDGPU.zeros(Float32, 256, 256)
40+
N = 10
41+
42+
# Execute function normally and cache all allocations.
43+
GPUArrays.@cached cache f(z)
44+
45+
# Capture graph using AllocCache to avoid capturing malloc/free calls.
46+
graph = GPUArrays.@cached cache AMDGPU.@captured f(z)
47+
48+
# Allocations cache must be kept alive while executing graph.
49+
for i in 1:N
50+
AMDGPU.launch(graph)
51+
end
52+
AMDGPU.synchronize()
53+
```
54+
55+
```@docs
56+
AMDGPU.capture
57+
AMDGPU.@captured
58+
AMDGPU.instantiate
59+
AMDGPU.update
60+
AMDGPU.is_capturing
61+
AMDGPU.launch
62+
```

docs/src/tutorials/profiling.md

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
## rocprof
44

5-
[rocprofv2](https://github.com/ROCm/rocprofiler?tab=readme-ov-file#rocprofiler-v2)
6-
allows profiling both HSA & HIP API calls (rocprof being deprecated).
5+
[rocprofv3](https://rocm.docs.amd.com/projects/rocprofiler-sdk/en/latest/how-to/using-rocprofv3.html)
6+
allows profiling both HSA & HIP API calls.
77

88
Let's profile simple copying kernel saved in `profile.jl` file:
99
```julia
@@ -39,11 +39,10 @@ main(2^24)
3939
### Profiling problematic code
4040

4141
```bash
42-
ENABLE_JITPROFILING=1 rocprofv2 --plugin perfetto --hip-trace --hsa-trace --kernel-trace -o prof julia ./profile.jl
42+
ENABLE_JITPROFILING=1 rocprofv3 --output-directory ./profiling --output-format pftrace --hip-trace --hsa-trace --kernel-trace -- julia ./profile.jl
4343
```
4444

45-
This will produce `prof_output.pftrace` file which can be visualized
46-
using [Perfetto UI](https://ui.perfetto.dev/).
45+
This will produce `.pftrace` file which can be visualized using [Perfetto UI](https://ui.perfetto.dev/).
4746

4847
![image](../assets/profile_1.png)
4948

src/hip/HIP.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module HIP
22
export HIPError, devices, device_synchronize, default_stream
3+
export HIPGraph, HIPGraphExec, @captured, capture, instantiate, update, is_capturing, launch
34

45
using CEnum
56

@@ -90,6 +91,7 @@ include("stream.jl")
9091
include("event.jl")
9192
include("pool.jl")
9293
include("module.jl")
94+
include("graph.jl")
9395

9496
"""
9597
Blocks until all kernels on all streams have completed.

src/hip/graph.jl

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
"""
2+
instantiate(graph::HIPGraph)::HIPGraphExec
3+
4+
Instantiate captured graph making it executable with [`launch`](@ref).
5+
"""
6+
instantiate
7+
8+
"""
9+
capture(f::Function; flags = hipStreamCaptureModeGlobal, throw_error::Bool = true)::Union{Nothing, HIPGraph}
10+
11+
Capture fiven function `f` to a graph.
12+
If successful, returns a captured graph that needs to be [`instantiate`](@ref)'d to obtain executable graph.
13+
"""
14+
capture
15+
16+
function unchecked_hipStreamEndCapture(stream, pGraph)
17+
AMDGPU.prepare_state()
18+
@gcsafe_ccall(libhip.hipStreamEndCapture(stream::hipStream_t, pGraph::Ptr{hipGraph_t})::hipError_t)
19+
end
20+
21+
mutable struct HIPGraph
22+
handle::hipGraph_t
23+
24+
function HIPGraph(flags = hipStreamCaptureModeGlobal)
25+
handle_ref = Ref{hipGraph_t}()
26+
hipGraphCreate(handle_ref, flags)
27+
28+
obj = new(handle_ref[])
29+
finalizer(obj) do obj
30+
hipGraphDestroy(obj)
31+
end
32+
return obj
33+
end
34+
35+
global function capture(f::Function; flags = hipStreamCaptureModeGlobal, throw_error::Bool = true)::Union{Nothing, HIPGraph}
36+
gc_state = GC.enable(false)
37+
stream = AMDGPU.stream()
38+
try
39+
hipStreamBeginCapture(stream, flags)
40+
f()
41+
finally
42+
handle_ref = Ref{hipGraph_t}()
43+
st = unchecked_hipStreamEndCapture(stream, handle_ref)
44+
GC.enable(gc_state)
45+
46+
if st == hipErrorStreamCaptureInvalidated && !throw_error
47+
return nothing
48+
elseif st != hipSuccess
49+
throw(HIPError(st))
50+
end
51+
52+
obj = new(handle_ref[])
53+
finalizer(hipGraphDestroy, obj)
54+
return obj
55+
end
56+
return nothing
57+
end
58+
end
59+
60+
Base.unsafe_convert(::Type{hipGraph_t}, graph::HIPGraph) = graph.handle
61+
62+
mutable struct HIPGraphExec
63+
handle::hipGraphExec_t
64+
65+
global function instantiate(graph::HIPGraph)
66+
handle_ref = Ref{hipGraphExec_t}()
67+
hipGraphInstantiateWithFlags(handle_ref, graph, 0)
68+
obj = new(handle_ref[])
69+
70+
finalizer(obj) do obj
71+
hipGraphExecDestroy(obj)
72+
end
73+
return obj
74+
end
75+
end
76+
77+
Base.unsafe_convert(::Type{hipGraphExec_t}, exec::HIPGraphExec) = exec.handle
78+
79+
"""
80+
launch(exec::HIPGraphExec, stream::HIPStream = AMDGPU.stream())
81+
82+
Launch executable graph on a given stream.
83+
"""
84+
function launch(exec::HIPGraphExec, stream::HIPStream = AMDGPU.stream())
85+
hipGraphLaunch(exec, stream)
86+
end
87+
88+
"""
89+
update(exec::HIPGraphExec, graph::HIPGraph; throw_error::Bool = true)::Bool
90+
91+
Given executable graph, perform update with graph.
92+
Return `true` if successful, `false` otherwise.
93+
94+
If `throw_error=false` allows avoiding throwing an exception if update was not successful.
95+
"""
96+
function update(exec::HIPGraphExec, graph::HIPGraph; throw_error::Bool = true)::Bool
97+
error_node = Ref{hipGraphNode_t}()
98+
update_res_ref = Ref{hipGraphExecUpdateResult}()
99+
hipGraphExecUpdate(exec, graph, error_node, update_res_ref)
100+
101+
update_res = update_res_ref[]
102+
if update_res != hipGraphExecUpdateSuccess
103+
throw_error && error("Failed to update HIPGraphExec: `$(update_res)`.")
104+
return false
105+
end
106+
return true
107+
end
108+
109+
function capture_status(stream::HIPStream)
110+
status_ref = Ref{hipStreamCaptureStatus}()
111+
id_ref = Ref{Culonglong}()
112+
hipStreamGetCaptureInfo(stream, status_ref, id_ref)
113+
status = status_ref[]
114+
return (; status, id=(status == hipStreamCaptureStatusActive) ? id_ref[] : nothing)
115+
end
116+
117+
"""
118+
is_capturing(stream::HIPStream = AMDGPU.stream())::Bool
119+
120+
For a given `stream` check if capturing for a graph is performed.
121+
"""
122+
function is_capturing(stream::HIPStream = AMDGPU.stream())::Bool
123+
capture_status(stream).status == hipStreamCaptureStatusActive
124+
end
125+
126+
"""
127+
graph = AMDGPU.@captured begin
128+
# code to capture in a graph.
129+
end
130+
131+
Macro to capture a given expression in a graph & execute it.
132+
Returns captured graph, that can be relaunched with [`launch`](@ref) or updated with [`update`](@ref).
133+
134+
If capture fails (e.g. due to JIT), attempts recovery, compilation and re-capture.
135+
"""
136+
macro captured(ex)
137+
quote
138+
executed = false
139+
GC.enable(false)
140+
graph = try
141+
capture(; throw_error=false) do
142+
$(esc(ex))
143+
end
144+
finally
145+
GC.enable(true)
146+
end
147+
148+
if graph === nothing
149+
# If the capture failed, this may have been due to JIT compilation.
150+
# execute the body out of capture, and try capturing again.
151+
$(esc(ex))
152+
153+
# Don't tolerate capture failures now so that the user will be informed.
154+
GC.enable(false)
155+
graph = try
156+
capture() do
157+
$(esc(ex))
158+
end
159+
catch
160+
rethrow()
161+
finally
162+
GC.enable(true)
163+
end
164+
executed = true
165+
end
166+
167+
exec = instantiate(graph)
168+
executed || launch(exec)
169+
exec
170+
end
171+
end

src/hip/module.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ mutable struct HIPModule
22
handle::hipModule_t
33

44
function HIPModule(data)
5-
device_synchronize()
5+
# During stream capture no GPU work is actually executing, so syncing
6+
# would call hipStreamQuery on a capturing stream, which returns
7+
# hipErrorStreamCaptureUnsupported and invalidates the capture.
8+
is_capturing() || device_synchronize()
69

710
mod_ref = Ref{hipModule_t}()
811
hipModuleLoadData(mod_ref, data)

src/memory.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -409,9 +409,10 @@ mutable struct Managed{M}
409409
const mem::M
410410
stream::HIPStream
411411
dirty::Bool
412+
captured::Bool
412413

413-
function Managed(mem; stream=AMDGPU.stream(), dirty=true)
414-
new{typeof(mem)}(mem, stream, dirty)
414+
function Managed(mem; stream=AMDGPU.stream(), dirty=true, captured=false)
415+
new{typeof(mem)}(mem, stream, dirty, captured)
415416
end
416417
end
417418

@@ -472,7 +473,7 @@ function pool_alloc(::Type{B}, bytesize) where B
472473
maybe_collect()
473474
time = Base.@elapsed begin
474475
s = AMDGPU.stream()
475-
managed = Managed(B(bytesize; stream=s); stream=s)
476+
managed = Managed(B(bytesize; stream=s); stream=s, captured=AMDGPU.is_capturing())
476477
end
477478

478479
Base.@atomic alloc_stats.alloc_count += 1

test/core/graph_tests.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
using Test
2+
using AMDGPU
3+
using GPUArrays
4+
5+
@testset "HIP Graphs" begin
6+
@testset "+1" begin
7+
f!(o) = o .+= one(eltype(o))
8+
9+
z = AMDGPU.zeros(Int, 4, 4)
10+
graph = AMDGPU.@captured f!(z)
11+
@test sum(z) == 16
12+
13+
AMDGPU.launch(graph)
14+
@test sum(z) == 16 * 2
15+
AMDGPU.launch(graph)
16+
@test sum(z) == 16 * 3
17+
end
18+
19+
@testset "malloc/free" begin
20+
function f!(o)
21+
x = AMDGPU.ones(eltype(o), size(o))
22+
o .+= x .+ one(eltype(o))
23+
AMDGPU.unsafe_free!(x)
24+
end
25+
26+
z = AMDGPU.zeros(Int, 4, 4)
27+
graph = AMDGPU.@captured f!(z)
28+
@test sum(z) == 32
29+
30+
AMDGPU.launch(graph)
31+
@test sum(z) == 32 * 2
32+
AMDGPU.launch(graph)
33+
@test sum(z) == 32 * 3
34+
end
35+
36+
@testset "only malloc + alloc cache" begin
37+
function f!(o)
38+
x = AMDGPU.ones(eltype(o), size(o))
39+
y = AMDGPU.ones(eltype(o), size(o))
40+
o .+= (x * y) .+ one(eltype(o))
41+
end
42+
43+
z = AMDGPU.zeros(Int, 4, 4)
44+
cache = GPUArrays.AllocCache()
45+
# Pre-populate alloc cache, to avoid malloc calls during capture.
46+
GPUArrays.@cached cache f!(z)
47+
# Capture with alloc cache.
48+
graph = GPUArrays.@cached cache AMDGPU.@captured f!(z)
49+
@test sum(z) == length(z) * 5 * 2
50+
51+
AMDGPU.launch(graph)
52+
@test sum(z) == length(z) * 5 * 3
53+
AMDGPU.launch(graph)
54+
@test sum(z) == length(z) * 5 * 4
55+
end
56+
57+
@testset "Update graph" begin
58+
f1!(o) = o .+= one(eltype(o))
59+
f2!(o) = o .+= eltype(o)(2)
60+
61+
z = AMDGPU.zeros(Int, 4, 4)
62+
graph = AMDGPU.@captured f1!(z)
63+
@test sum(z) == 16
64+
65+
g_new = AMDGPU.capture() do
66+
f2!(z)
67+
end
68+
@test AMDGPU.update(graph, g_new)
69+
AMDGPU.launch(graph)
70+
@test sum(z) == 16 * 3
71+
end
72+
end

0 commit comments

Comments
 (0)