Skip to content

Commit 2f150d3

Browse files
authored
[Refactor] Support different CUDA versions in one single cuda_compat.h (#6770)
* Support different CUDA versions in one single cuda_compat.h * Remove useless nvtx header
1 parent be03870 commit 2f150d3

File tree

3 files changed

+42
-18
lines changed

3 files changed

+42
-18
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/**
2+
* @file cuda_compat.h
3+
* @brief Compatibility layer for CUDA and NVTX headers across different CUDA Toolkit versions.
4+
*
5+
* This header abstracts the differences in NVTX (NVIDIA Tools Extension) header locations
6+
* between CUDA Toolkit versions.
7+
*
8+
* @note Depends on the CUDA_VERSION macro defined in <cuda.h>.
9+
*
10+
*/
11+
12+
#ifndef CUDA_COMPAT_H_
13+
#define CUDA_COMPAT_H_
14+
15+
#include <cuda.h> // defines CUDA_VERSION
16+
17+
// NVTX header for CUDA versions prior to 12.9 vs. 12.9+
18+
// This block ensures the correct NVTX header path is used based on CUDA_VERSION.
19+
// - For CUDA Toolkit < 12.9, the legacy header "nvToolsExt.h" is included.
20+
// - For CUDA Toolkit >= 12.9, the modern header "nvtx3/nvToolsExt.h" is included,
21+
// and NVTX v2 is removed from 12.9.
22+
// This allows NVTX profiling APIs (e.g. nvtxRangePush) to be used consistently
23+
// across different CUDA versions.
24+
// See:
25+
// https://docs.nvidia.com/cuda/archive/12.9.0/cuda-toolkit-release-notes/index.html#id4
26+
#if defined(__CUDA) && defined(__USE_NVTX)
27+
#if CUDA_VERSION < 12090
28+
#include "nvToolsExt.h"
29+
#else
30+
#include "nvtx3/nvToolsExt.h"
31+
#endif
32+
#endif
33+
34+
#endif // CUDA_COMPAT_H_

source/source_base/timer.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,7 @@
1515
#include "source_base/formatter.h"
1616

1717
#if defined(__CUDA) && defined(__USE_NVTX)
18-
#if CUDA_VERSION < 12090
19-
#include "nvToolsExt.h"
20-
#else
21-
#include "nvtx3/nvToolsExt.h"
22-
#endif
18+
#include "source_base/module_device/cuda_compat.h"
2319
#include "source_io/module_parameter/parameter.h"
2420
#endif
2521

source/source_hsolver/kernels/cuda/diag_cusolver.cuh

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,6 @@
33
#include <cuda.h>
44
#include <complex>
55

6-
#if CUDA_VERSION < 12090
7-
#include "nvToolsExt.h"
8-
#else
9-
#include "nvtx3/nvToolsExt.h"
10-
#endif
11-
126
#include <cuda_runtime.h>
137
#include <cusolverDn.h>
148

@@ -39,7 +33,7 @@ class Diag_Cusolver_gvd{
3933
double *d_A = nullptr;
4034
double *d_B = nullptr;
4135
double *d_work = nullptr;
42-
36+
4337
cuDoubleComplex *d_A2 = nullptr;
4438
cuDoubleComplex *d_B2 = nullptr;
4539
cuDoubleComplex *d_work2 = nullptr;
@@ -54,7 +48,7 @@ class Diag_Cusolver_gvd{
5448
// - init_double : initializing relevant double type data structures and gpu apis' handle and memory
5549
// - init_complex : initializing relevant complex type data structures and gpu apis' handle and memory
5650
// Input Parameters
57-
// N: the dimension of the matrix
51+
// N: the dimension of the matrix
5852
void init_double(int N);
5953
void init_complex(int N);
6054

@@ -70,17 +64,17 @@ public:
7064
// - Dngvd_double : dense double type matrix
7165
// - Dngvd_complex : dense complex type matrix
7266
// Input Parameters
73-
// N: the number of rows of the matrix
74-
// M: the number of cols of the matrix
75-
// A: the hermitian matrix A in A x=lambda B (column major)
76-
// B: the SPD matrix B in A x=lambda B (column major)
67+
// N: the number of rows of the matrix
68+
// M: the number of cols of the matrix
69+
// A: the hermitian matrix A in A x=lambda B (column major)
70+
// B: the SPD matrix B in A x=lambda B (column major)
7771
// Output Parameter
7872
// W: generalized eigenvalues
7973
// V: generalized eigenvectors (column major)
8074

8175
void Dngvd_double(int N, int M, double *A, double *B, double *W, double *V);
8276
void Dngvd_complex(int N, int M, std::complex<double> *A, std::complex<double> *B, double *W, std::complex<double> *V);
83-
77+
8478
void Dngvd(int N, int M, double *A, double *B, double *W, double *V)
8579
{
8680
return Dngvd_double(N, M, A, B, W, V);

0 commit comments

Comments
 (0)