Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,11 @@ if(ARCH STREQUAL "x86_64")
# CUDA 12.8 is the first toolkit that knows both sm_120a (RTX 5090) and
# sm_100a (B200); a single version check covers both, instead of two
# back-to-back if() blocks with the identical condition.
if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "12.8")
list(APPEND CMAKE_CUDA_ARCHITECTURES 120a-real) # 5090
list(APPEND CMAKE_CUDA_ARCHITECTURES 100a-real) # B200
endif ()
if (MSVC)
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES 80-real 90a-real)
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES 80-real 90a-real 100a-real)
endif ()
Comment thread
windreamer marked this conversation as resolved.
endif ()
elseif(ARCH STREQUAL "aarch64")
Expand Down
8 changes: 8 additions & 0 deletions src/turbomind/core/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ const auto& GetCopyAPI()
void* fpn{};
TM_CHECK_EQ(cudaGetDriverEntryPoint(symbol, &fpn, cudaEnableDefault, &status), 0);
if (fpn && status == cudaDriverEntryPointSuccess) {
// cuMemcpyBatchAsync crashes on sm_100 (Blackwell); force monostate -> serialized path.
int device = 0;
(void)cudaGetDevice(&device);
int major = 0;
(void)cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
if (major >= 10) {
return {};
}
return (PFN_cuMemcpyBatchAsync_v12080)fpn;
}
else {
Expand Down
87 changes: 71 additions & 16 deletions src/turbomind/kernels/gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.

set(GEMM2_KERNELS_SM70
kernel/sm70_884_4.cu
kernel/sm70_884_8.cu
kernel/sm70_884_16.cu
)
set(GEMM2_KERNELS_SM75
kernel/sm75_16816_4.cu
kernel/sm75_16816_8.cu
kernel/sm75_16816_16.cu
)
set(GEMM2_KERNELS_SM80
kernel/sm80_16816_4.cu
kernel/sm80_16816_8.cu
kernel/sm80_16816_16.cu
)
set(GEMM2_KERNELS_SM90
tma.cu
kernel/sm90_16816_4.cu
kernel/sm90_16816_8.cu
kernel/sm90_16816_16.cu
kernel/sm90_64n32_8.cu
)

set(GEMM2_ARCH_90_ENABLED FALSE)
set(_sm90_archs "${CMAKE_CUDA_ARCHITECTURES}")
list(FILTER _sm90_archs INCLUDE REGEX "^90")
if(_sm90_archs)
set(GEMM2_ARCH_90_ENABLED TRUE)
else()
# When building for SM100+ without explicit SM90, still compile SM90 CUTLASS
# kernels so the fat binary can run MoE models on H100 (CUTLASS fused path).
set(_sm100_archs "${CMAKE_CUDA_ARCHITECTURES}")
list(FILTER _sm100_archs INCLUDE REGEX "^100")
if(_sm100_archs)
set(GEMM2_ARCH_90_ENABLED TRUE)
set(_sm90_archs "90")
message(STATUS "GEMM: auto-enabling SM90 CUTLASS kernels for H100 backward compatibility")
endif()
endif()

add_library(gemm2
gemm.cu
kernel.cu
Expand All @@ -10,34 +50,30 @@ add_library(gemm2
cast.cu
unpack.cu
context.cu
tma.cu
tuner/cache_utils.cu
tuner/measurer.cu
tuner/sampler.cu
tuner/stopping_criterion.cc
tuner/params.cc
kernel/sm90_16816_4.cu
kernel/sm90_16816_8.cu
kernel/sm90_16816_16.cu
kernel/sm80_16816_4.cu
kernel/sm80_16816_8.cu
kernel/sm80_16816_16.cu
kernel/sm75_16816_4.cu
kernel/sm75_16816_8.cu
kernel/sm75_16816_16.cu
kernel/sm70_884_4.cu
kernel/sm70_884_8.cu
kernel/sm70_884_16.cu
kernel/sm90_64n32_8.cu
${GEMM2_KERNELS_SM70}
${GEMM2_KERNELS_SM75}
${GEMM2_KERNELS_SM80}
cublas.cu
moe_utils_v2.cu
test/test_utils.cu
)

target_link_libraries(gemm2 PRIVATE parser nvidia::cutlass::cutlass CUDA::cuda_driver)


target_compile_definitions(gemm2 PRIVATE -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
# cublasGemmGroupedBatchedEx (CUDA 12.5+): grouped batched GEMM for MoE on SM100
# Enable cublasGemmGroupedBatchedEx (CUDA 12.5+) for MoE grouped GEMM when
# targeting SM100. The if() condition gates the compile definition directly;
# the previous `_has_sm100` flag was set but never read, so it is removed.
set(_archs_100 "${CMAKE_CUDA_ARCHITECTURES}")
list(FILTER _archs_100 INCLUDE REGEX "^100")
if(_archs_100 AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.5")
Comment on lines +69 to +73
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_has_sm100 is set but never used. If it’s not needed, remove it; if it is meant to drive later logic, wire it up so the intent is clear (unused variables in CMake can hide configuration bugs).

Suggested change
set(_has_sm100 FALSE)
set(_archs_100 "${CMAKE_CUDA_ARCHITECTURES}")
list(FILTER _archs_100 INCLUDE REGEX "^100")
if(_archs_100 AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.5")
set(_has_sm100 TRUE)
set(_archs_100 "${CMAKE_CUDA_ARCHITECTURES}")
list(FILTER _archs_100 INCLUDE REGEX "^100")
if(_archs_100 AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.5")

Copilot uses AI. Check for mistakes.
target_compile_definitions(gemm2 PRIVATE ENABLE_CUBLAS_GROUPED=1)
message(STATUS "GEMM: ENABLE_CUBLAS_GROUPED=1 (cublasGemmGroupedBatchedEx for MoE on SM100)")
endif()

target_compile_options(gemm2 PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:
Expand All @@ -48,7 +84,26 @@ target_compile_options(gemm2 PRIVATE
set_property(TARGET gemm2 PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET gemm2 PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)

if(GEMM2_ARCH_90_ENABLED)
# SM90 kernels only compile for 90/90a; avoid building them for sm_100.
# Build the SM90-only sources as a separate static library restricted to the
# SM90 architecture list (`_sm90_archs` is computed earlier in this file).
add_library(gemm2_sm90 STATIC ${GEMM2_KERNELS_SM90})
set_target_properties(gemm2_sm90 PROPERTIES
CUDA_ARCHITECTURES "${_sm90_archs}"
POSITION_INDEPENDENT_CODE ON
CUDA_RESOLVE_DEVICE_SYMBOLS ON
)
target_compile_definitions(gemm2_sm90 PRIVATE -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
# Mirror the compile options of the main gemm2 target so both halves of the
# fat binary are built consistently.
target_compile_options(gemm2_sm90 PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:
-Xptxas=-v
--generate-line-info
--threads 16>
)
target_link_libraries(gemm2_sm90 PRIVATE parser nvidia::cutlass::cutlass CUDA::cuda_driver)
target_link_libraries(gemm2 PRIVATE gemm2_sm90)

# Let gemm2 sources know at compile time that the SM90 kernel library exists.
target_compile_definitions(gemm2 PRIVATE GEMM2_ARCH_90_ENABLED)
endif()

if (BUILD_TEST)
add_executable(test_gemm_v2
Expand Down
18 changes: 16 additions & 2 deletions src/turbomind/kernels/gemm/arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,20 @@ struct Sm80: Arch<800, 900> {
static constexpr int value = 800;
};

struct Sm90: Arch<900> {
// Hopper SM 9.x. The upper bound (1000, exclusive) keeps SM90 kernels from
// matching SM 10.x devices; Blackwell parts are covered by Sm100/Sm120 below.
struct Sm90: Arch<900, 1000> {
static constexpr int value = 900;
};
Comment thread
windreamer marked this conversation as resolved.

// Blackwell data-center SM 10.x (e.g. B200): half-open range [1000, 1200).
struct Sm100: Arch<1000, 1200> {
static constexpr int value = 1000;
};

// SM 12.x (e.g. sm_120): half-open range [1200, 1300). Dispatched through the
// same CUTLASS SM90 kernel family that the pre-change open-ended Sm90 covered.
struct Sm120: Arch<1200, 1300> {
static constexpr int value = 1200;
};

inline bool is_arch_compatible(int karch, int darch)
{
switch (karch) {
Expand All @@ -42,7 +52,11 @@ inline bool is_arch_compatible(int karch, int darch)
case 800:
return Sm80::is_compatible(darch);
case 900:
return Sm90::is_compatible(darch);
return Sm90::is_compatible(darch) || Sm120::is_compatible(darch);
case 1000:
return Sm100::is_compatible(darch);
case 1200:
return Sm120::is_compatible(darch);
default:
return false;
}
Expand Down
3 changes: 3 additions & 0 deletions src/turbomind/kernels/gemm/convert_v3.cu
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ std::array<const LayoutConverter*, 2> GetConverters(DataType data_type,
if (weight_type == kHalf || weight_type == kBfloat16) {
constexpr Cvt<uint16_t, uint16_t> W;
if (grouped) {
// SM10.x only: CublasGroupedKernel (cublasGemmGroupedBatchedEx) expects standard (K,N)
if (sm >= 100 && sm < 120)
return {};
// clang-format off
if (sm >= 80) return {W(sm8_, kRow, s16816h | B | _1), {}};
if (sm == 75) return {W(sm75, kRow, s16816h | B | _1), {}};
Expand Down
Loading
Loading