52 changes: 43 additions & 9 deletions cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu
@@ -43,6 +43,8 @@ struct Causal_conv1d_fwd_kernel_traits
static_assert(kWidth <= kNElts);
static constexpr bool kIsVecLoad = kIsVecLoad_;
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
static_assert(kNThreads_ % 32 == 0, "kNThreads must be a multiple of 32 for warp shuffle");
static_assert(sizeof(vec_t) == 16, "vec_t must be 16 bytes for warp shuffle optimization");
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
@@ -123,7 +125,7 @@ __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel(C
#pragma unroll
for (int i = 0; i < kWidth; ++i)
{
weight_vals[i] = float(weight[i * params.weight_width_stride]);
weight_vals[i] = float(__ldg(&weight[i * params.weight_width_stride]));
}

constexpr int kChunkSize = kNThreads * kNElts;
@@ -144,20 +146,41 @@ __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel(C
x, *reinterpret_cast<input_t(*)[kNElts]>(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize);
}
x += kChunkSize;

int const lane_id = tidx & 31;
vec_t high_val = reinterpret_cast<vec_t*>(x_vals_load)[1];

__syncthreads();
// Thread kNThreads - 1 doesn't write yet, so that thread 0 can read
// the last elements of the previous chunk.
if (tidx < kNThreads - 1)
{
smem_exchange[tidx] = reinterpret_cast<vec_t*>(x_vals_load)[1];
smem_exchange[tidx] = high_val;
}
__syncthreads();
reinterpret_cast<vec_t*>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];

// Get neighbor data: use warp shuffle for most threads, shared memory for warp boundaries
vec_t neighbor;
uint32_t* high_val_p = reinterpret_cast<uint32_t*>(&high_val);
uint32_t* nbr_p = reinterpret_cast<uint32_t*>(&neighbor);
nbr_p[0] = __shfl_up_sync(0xFFFFFFFF, high_val_p[0], 1);
nbr_p[1] = __shfl_up_sync(0xFFFFFFFF, high_val_p[1], 1);
nbr_p[2] = __shfl_up_sync(0xFFFFFFFF, high_val_p[2], 1);
nbr_p[3] = __shfl_up_sync(0xFFFFFFFF, high_val_p[3], 1);

// Lane 0 must use shared memory to handle the cross-warp boundary;
// thread 0 reads the last elements of the previous chunk.
if (lane_id == 0)
{
neighbor = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
}
reinterpret_cast<vec_t*>(x_vals_load)[0] = neighbor;

__syncthreads();
// Now thread kNThreads - 1 can write the last elements of the current chunk.
if (tidx == kNThreads - 1)
{
smem_exchange[tidx] = reinterpret_cast<vec_t*>(x_vals_load)[1];
smem_exchange[tidx] = high_val;
}

float x_vals[2 * kNElts];
@@ -169,22 +192,33 @@ __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel(C

float out_vals[kNElts];
#pragma unroll
for (int i = 0; i < kNElts; ++i)
// Process 2 outputs at a time for better ILP (instruction level parallelism).
for (int i = 0; i < kNElts; i += 2)
{
out_vals[i] = bias_val;
float acc0 = bias_val;
float acc1 = bias_val;
#pragma unroll
for (int w = 0; w < kWidth; ++w)
{
out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
float wt = weight_vals[w];
acc0 = __fmaf_rn(wt, x_vals[kNElts + i - (kWidth - w - 1)], acc0);
acc1 = __fmaf_rn(wt, x_vals[kNElts + i + 1 - (kWidth - w - 1)], acc1);
}
out_vals[i] = acc0;
out_vals[i + 1] = acc1;
}

if (params.silu_activation)
{
#pragma unroll
for (int i = 0; i < kNElts; ++i)
for (int i = 0; i < kNElts; i += 2)
{
out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
// SiLU: x * sigmoid(x) = x / (1 + exp(-x))
// Using fast math: __expf and __frcp_rn
float v0 = out_vals[i];
float v1 = out_vals[i + 1];
out_vals[i] = v0 * __frcp_rn(1.0f + __expf(-v0));
out_vals[i + 1] = v1 * __frcp_rn(1.0f + __expf(-v1));
}
}

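For reference, the chunk-boundary exchange above replaces a full shared-memory round trip with __shfl_up_sync for intra-warp lanes, keeping shared memory only for lane 0 of each warp. Below is a minimal, standalone sketch of that pattern with a plain float payload; it is not part of this diff, and the block size, names, and the wrap-around for thread 0 are illustrative simplifications.

#include <cstdio>
#include <cuda_runtime.h>

__global__ void neighborExchange(float const* in, float* out)
{
    extern __shared__ float smem[];
    int const tid = threadIdx.x;
    int const lane = tid & 31;

    float const v = in[tid];

    // Stage every value in shared memory so warp-boundary lanes can read it.
    smem[tid] = v;
    __syncthreads();

    // Fast path: pull the value from the lane one below within the warp.
    float nbr = __shfl_up_sync(0xFFFFFFFFu, v, 1);

    // Slow path: lane 0 has no lower lane in its warp, so it reads the
    // previous thread's value (wrapping to the last thread) from shared memory.
    if (lane == 0)
    {
        nbr = smem[tid > 0 ? tid - 1 : blockDim.x - 1];
    }
    out[tid] = nbr;
}

int main()
{
    constexpr int kThreads = 128;
    float h_in[kThreads], h_out[kThreads];
    for (int i = 0; i < kThreads; ++i)
    {
        h_in[i] = static_cast<float>(i);
    }

    float *d_in = nullptr, *d_out = nullptr;
    cudaMalloc(reinterpret_cast<void**>(&d_in), kThreads * sizeof(float));
    cudaMalloc(reinterpret_cast<void**>(&d_out), kThreads * sizeof(float));
    cudaMemcpy(d_in, h_in, kThreads * sizeof(float), cudaMemcpyHostToDevice);

    neighborExchange<<<1, kThreads, kThreads * sizeof(float)>>>(d_in, d_out);
    cudaMemcpy(h_out, d_out, kThreads * sizeof(float), cudaMemcpyDeviceToHost);

    // Expect out[i] == in[i - 1], with out[0] wrapping to in[kThreads - 1].
    printf("out[0]=%.0f out[1]=%.0f out[32]=%.0f\n", h_out[0], h_out[1], h_out[32]);

    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}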
187 changes: 187 additions & 0 deletions cpp/tensorrt_llm/kernels/fusedActivationQuant.cu
@@ -0,0 +1,187 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/fusedActivationQuant.h"
#include "tensorrt_llm/kernels/quantization.cuh"
#include "tensorrt_llm/kernels/quantization.h"

#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>

TRTLLM_NAMESPACE_BEGIN

namespace kernels
{

constexpr int kEltsPerThread = 8;

__device__ __forceinline__ float relu2_f32(float x)
{
float r = fmaxf(0.0f, x);
return r * r;
}

// Fused relu2 + NVFP4 quantization kernel.
//
// To match the unfused path (PyTorch relu2 -> cvt_warp_fp16_to_fp4), relu2 is
// computed in f32 then rounded back to native precision (bf16/fp16) before
// quantization. Absmax and scale-factor math follow cvt_warp_fp16_to_fp4 exactly.
// Column padding to a multiple of (4 * kSfVecSize) matches quantize_with_block_size
// for the swizzled SF layout.
template <typename T>
__global__ void fusedRelu2QuantizeKernel(T const* __restrict__ input, float const* __restrict__ sfScale,
uint32_t* __restrict__ outputFp4, uint32_t* __restrict__ outputSf, int m, int n)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr int kSfVecSize = 16;
constexpr int kNumThreadsPerSf = kSfVecSize / kEltsPerThread;
constexpr int kPackedPerThread = kEltsPerThread / 2;

using PackedType = std::conditional_t<std::is_same_v<T, half>, __half2, __nv_bfloat162>;

float const SFScaleVal = sfScale[0];
int const numColThreads = n / kEltsPerThread;
int const numColVecs = n / kSfVecSize;
int const numColThreadsPadded = ((n + 4 * kSfVecSize - 1) / (4 * kSfVecSize)) * (4 * kSfVecSize) / kEltsPerThread;
int const rowIdx = blockIdx.x;

if (rowIdx >= m)
return;

for (int colIdx = threadIdx.x; colIdx < numColThreadsPadded; colIdx += blockDim.x)
{
bool const isValidCol = colIdx < numColThreads;
PackedType packedVals[kPackedPerThread];

if (isValidCol)
{
int const inputOffset = rowIdx * n + colIdx * kEltsPerThread;
#pragma unroll
for (int i = 0; i < kPackedPerThread; i++)
{
float f0 = relu2_f32(static_cast<float>(input[inputOffset + i * 2]));
float f1 = relu2_f32(static_cast<float>(input[inputOffset + i * 2 + 1]));
if constexpr (std::is_same_v<T, half>)
{
packedVals[i] = __floats2half2_rn(f0, f1);
}
else
{
packedVals[i] = __floats2bfloat162_rn(f0, f1);
}
}
}
else
{
#pragma unroll
for (int i = 0; i < kPackedPerThread; i++)
{
if constexpr (std::is_same_v<T, half>)
{
packedVals[i] = __float2half2_rn(0.0f);
}
else
{
packedVals[i] = __float2bfloat162_rn(0.0f);
}
}
}

// Absmax in native precision, then reduce across the SF group (2 threads).
auto localMax = cuda_abs(packedVals[0]);
#pragma unroll
for (int i = 1; i < kPackedPerThread; i++)
{
localMax = cuda_max(localMax, cuda_abs(packedVals[i]));
}
localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
float vecMax = float(cuda_max(localMax.x, localMax.y));

// Scale-factor computation (identical to cvt_warp_fp16_to_fp4).
float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
__nv_fp8_e4m3 fp8SF = __nv_fp8_e4m3(SFValue);
uint8_t fp8SFVal = fp8SF.__x;
SFValue = static_cast<float>(fp8SF);

float outputScale
= vecMax != 0.0f ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f;

if (colIdx % kNumThreadsPerSf == 0)
{
auto sfOutPtr = cvt_quant_get_sf_out_offset<uint32_t, kNumThreadsPerSf>(std::nullopt, rowIdx, colIdx,
std::optional<int>(m), numColVecs, outputSf, QuantizationSFLayout::SWIZZLED);
if (sfOutPtr != nullptr)
{
*sfOutPtr = fp8SFVal;
}
}

if (isValidCol)
{
float2 fp2Vals[kPackedPerThread];
#pragma unroll
for (int i = 0; i < kPackedPerThread; i++)
{
if constexpr (std::is_same_v<T, half>)
{
fp2Vals[i] = __half22float2(packedVals[i]);
}
else
{
fp2Vals[i] = __bfloat1622float2(packedVals[i]);
}
fp2Vals[i].x *= outputScale;
fp2Vals[i].y *= outputScale;
}

outputFp4[rowIdx * numColThreads + colIdx] = fp32_vec_to_e2m1(fp2Vals);
}
}
#else
if (threadIdx.x == 0 && blockIdx.x == 0)
{
printf("FP4 quantization requires SM100 (Blackwell) or later!\n");
}
#endif
}

template <typename T>
void invokeFusedRelu2Quantize(T const* input, float const* sfScale, std::uint8_t* outputFp4, std::uint8_t* outputSf,
int m, int n, int sfVecSize, cudaStream_t stream)
{
constexpr int kSfVecSize = 16;
int const numColThreadsPadded = ((n + 4 * kSfVecSize - 1) / (4 * kSfVecSize)) * (4 * kSfVecSize) / kEltsPerThread;
int threadsPerBlock = min(512, numColThreadsPadded);
threadsPerBlock = max(32, ((threadsPerBlock + 31) / 32) * 32);

fusedRelu2QuantizeKernel<T><<<m, threadsPerBlock, 0, stream>>>(
input, sfScale, reinterpret_cast<uint32_t*>(outputFp4), reinterpret_cast<uint32_t*>(outputSf), m, n);
}

template void invokeFusedRelu2Quantize<half>(
half const*, float const*, std::uint8_t*, std::uint8_t*, int, int, int, cudaStream_t);

#ifdef ENABLE_BF16
template void invokeFusedRelu2Quantize<__nv_bfloat16>(
__nv_bfloat16 const*, float const*, std::uint8_t*, std::uint8_t*, int, int, int, cudaStream_t);
#endif

} // namespace kernels

TRTLLM_NAMESPACE_END
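For reference, here is the padding and block-size arithmetic from invokeFusedRelu2Quantize worked through on a hypothetical n = 5000 columns. This host-only C++ sketch is not part of the diff; it only restates the math above.

#include <algorithm>
#include <cstdio>

int main()
{
    constexpr int kSfVecSize = 16;
    constexpr int kEltsPerThread = 8;
    int const n = 5000; // hypothetical column count

    // Columns are padded up to a multiple of 4 * 16 = 64 elements, then divided
    // by 8 elements per thread: 5000 -> 5056 padded elements -> 632 column threads.
    int const numColThreadsPadded
        = ((n + 4 * kSfVecSize - 1) / (4 * kSfVecSize)) * (4 * kSfVecSize) / kEltsPerThread;

    // Cap at 512 threads, then round up to a full warp: min(512, 632) = 512.
    int threadsPerBlock = std::min(512, numColThreadsPadded);
    threadsPerBlock = std::max(32, ((threadsPerBlock + 31) / 32) * 32);

    // Each thread then strides over columns until numColThreadsPadded is covered.
    printf("numColThreadsPadded=%d threadsPerBlock=%d\n", numColThreadsPadded, threadsPerBlock);
    return 0;
}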
33 changes: 33 additions & 0 deletions cpp/tensorrt_llm/kernels/fusedActivationQuant.h
@@ -0,0 +1,33 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "tensorrt_llm/common/config.h"
#include <cstdint>
#include <cuda_runtime.h>

TRTLLM_NAMESPACE_BEGIN

namespace kernels
{

template <typename T>
void invokeFusedRelu2Quantize(T const* input, float const* sfScale, std::uint8_t* outputFp4, std::uint8_t* outputSf,
int m, int n, int sfVecSize, cudaStream_t stream);

} // namespace kernels

TRTLLM_NAMESPACE_END
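A hedged host-side sketch of driving the new entry point for an m x n half-precision activation follows. It is not part of this diff; the wrapper name is illustrative, the namespace is assumed to resolve to tensorrt_llm::kernels, and the padded scale-factor allocation (rows to a multiple of 128, 16-element column vectors to a multiple of 4) is an assumption about the swizzled SF layout that should be checked against quantize_with_block_size.

#include <cstdint>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include "tensorrt_llm/kernels/fusedActivationQuant.h"

void runFusedRelu2Quantize(half const* dInput, float const* dSfScale, int m, int n, cudaStream_t stream)
{
    // FP4 output: two 4-bit values per byte (n is assumed to be a multiple of 16).
    size_t const fp4Bytes = static_cast<size_t>(m) * n / 2;

    // One FP8 scale factor per 16-element vector, padded for the swizzled
    // layout (128-row x 4-vector tiles; this padding is an assumption).
    size_t const paddedRows = (static_cast<size_t>(m) + 127) / 128 * 128;
    size_t const paddedColVecs = (static_cast<size_t>(n) / 16 + 3) / 4 * 4;
    size_t const sfBytes = paddedRows * paddedColVecs;

    std::uint8_t* dFp4 = nullptr;
    std::uint8_t* dSf = nullptr;
    cudaMalloc(reinterpret_cast<void**>(&dFp4), fp4Bytes);
    cudaMalloc(reinterpret_cast<void**>(&dSf), sfBytes);

    tensorrt_llm::kernels::invokeFusedRelu2Quantize<half>(
        dInput, dSfScale, dFp4, dSf, m, n, /*sfVecSize=*/16, stream);

    // ... consume dFp4 / dSf (e.g. feed an NVFP4 GEMM), then release.
    cudaFree(dFp4);
    cudaFree(dSf);
}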
@@ -37,6 +37,7 @@ struct GeneralFP4AddBiasResidualPreLayerNormParam
T const* bias = nullptr;
T const* gamma = nullptr;
T const* beta = nullptr;
T* high_precision_normed_output = nullptr;

int m;
int n;
@@ -276,7 +276,7 @@ struct LowLatencyLayerNorm
}

typename PackType<typename Traits::OutputType, Traits::PACKED_ELEMS_PER_COMPUTE>::type normed_output;
typename PackType<typename Traits::AccumulatorType, Traits::PACKED_ELEMS_PER_COMPUTE>::type
typename PackType<typename Traits::InputType, Traits::PACKED_ELEMS_PER_COMPUTE>::type
high_precision_normed_output;
for (int j = 0; j < Traits::PACKED_ELEMS_PER_COMPUTE; j++)
{
Expand All @@ -300,7 +300,7 @@ struct LowLatencyLayerNorm
}
if constexpr (Traits::HIGH_PRECISION_NORMED_OUTPUT)
{
high_precision_normed_output.array[j] = normed_out;
high_precision_normed_output.array[j] = (typename Traits::InputType) normed_out;
}
if constexpr (Traits::OUTPUT_SCALE == SCALE_TYPE::SCALAR)
{
12 changes: 6 additions & 6 deletions cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.cuh
@@ -690,7 +690,7 @@ struct WarpSpecializedLayerNorm
typename PackType<typename Traits::OutputType, Traits::PACKED_ELEMS_PER_COMPUTE>::type
normed_output;
typename PackType<typename Traits::InputType, Traits::PACKED_ELEMS_PER_COMPUTE>::type output;
typename PackType<typename Traits::AccumulatorType, Traits::PACKED_ELEMS_PER_COMPUTE>::type
typename PackType<typename Traits::InputType, Traits::PACKED_ELEMS_PER_COMPUTE>::type
high_precision_normed_output;

#pragma unroll Traits::PACKED_ELEMS_PER_COMPUTE
@@ -719,6 +719,11 @@ struct WarpSpecializedLayerNorm
normed_out += beta[j];
}

if constexpr (Traits::HIGH_PRECISION_NORMED_OUTPUT)
{
high_precision_normed_output.array[j] = (typename Traits::InputType) normed_out;
}

if constexpr (Traits::OUTPUT_SCALE != SCALE_TYPE::NONE)
{
static_assert(Traits::OUTPUT_SCALE == SCALE_TYPE::SCALAR);
@@ -730,11 +735,6 @@ struct WarpSpecializedLayerNorm
output.array[j] = (typename Traits::InputType) data[m_offset][i][j];
}

if constexpr (Traits::HIGH_PRECISION_NORMED_OUTPUT)
{
high_precision_normed_output.array[j] = normed_out;
}

normed_output.array[j] = (typename Traits::OutputType) normed_out;
}

@@ -44,7 +44,7 @@ enum class SCALE_TYPE
};

template <typename T>
void invokeWSLayerNorm(WarpSpecializedParam<T> param, bool use_rms_norm, int ctas);
void invokeWSLayerNorm(WarpSpecializedParam<T> param, bool use_rms_norm, int ctas, bool output_hp_norm = false);

} // namespace kernels

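The two layer-norm changes above store the optional high-precision normed output in the input type (for example bf16) rather than the float accumulator type, which is why the explicit cast now appears at the store. A minimal sketch of that narrowing step, with illustrative names and a simplified Pack struct standing in for the internal PackType, is:

#include <cuda_bf16.h>

template <typename InputT, int kPack>
struct Pack
{
    InputT array[kPack];
};

template <typename InputT, int kPack>
__device__ void storeHighPrecisionNormed(float const (&normed)[kPack], Pack<InputT, kPack>& out)
{
#pragma unroll
    for (int j = 0; j < kPack; ++j)
    {
        // Narrow the float accumulator to the input type once, at store time,
        // mirroring the (typename Traits::InputType) cast in the diff.
        out.array[j] = static_cast<InputT>(normed[j]);
    }
}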