Skip to content

Commit 9727de3

Browse files
committed
Revert "test: remove prefill-skip to match Phase 6 capture behavior"
This reverts commit 9c11a3b.
1 parent 9c11a3b commit 9727de3

1 file changed

Lines changed: 11 additions & 0 deletions

File tree

graph/cuda_graph.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,18 @@ func (g *CUDAGraphExecutor[T]) Run(ctx context.Context, inputs ...*tensor.Tensor
219219
}
220220

221221
// Phase 2: Capture on first post-warmup call.
222+
// Only capture during decode (seqLen=1). Prefill inputs have larger
223+
// sequence lengths and take different code paths inside composite nodes
224+
// (e.g. GQA uses SDPA instead of FlashAttentionDecode), which may
225+
// trigger allocations incompatible with CUDA stream capture.
222226
if g.graphExec == nil {
227+
if len(inputs) > 0 && inputs[0] != nil {
228+
shape := inputs[0].Shape()
229+
if len(shape) >= 2 && shape[len(shape)-1] > 1 {
230+
// Prefill input (seqLen > 1): skip capture, run normally.
231+
return g.plan.RunInstructions(ctx, inputs...)
232+
}
233+
}
223234
return g.captureAndRun(ctx, inputs...)
224235
}
225236

0 commit comments

Comments
 (0)