From d0183986dfbd58d3119712394814dcd2b02f0529 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Wed, 10 Jun 2026 16:04:45 +0800 Subject: [PATCH] issue/1230 - cambricon add rms norm --- src/infinicore/nn/rmsnorm.cc | 3 +- .../ops/add_rms_norm/bang/add_rms_norm_bang.h | 8 + .../add_rms_norm/bang/add_rms_norm_bang.mlu | 480 ++++++++++++++++++ src/infiniop/ops/add_rms_norm/operator.cc | 15 +- 4 files changed, 503 insertions(+), 3 deletions(-) create mode 100644 src/infiniop/ops/add_rms_norm/bang/add_rms_norm_bang.h create mode 100644 src/infiniop/ops/add_rms_norm/bang/add_rms_norm_bang.mlu diff --git a/src/infinicore/nn/rmsnorm.cc b/src/infinicore/nn/rmsnorm.cc index bc703300f..24d090049 100644 --- a/src/infinicore/nn/rmsnorm.cc +++ b/src/infinicore/nn/rmsnorm.cc @@ -31,7 +31,8 @@ void RMSNorm::forward_inplace(Tensor &x, Tensor &residual) const { || device_.getType() == Device::Type::ILUVATAR || device_.getType() == Device::Type::METAX || device_.getType() == Device::Type::MOORE - || device_.getType() == Device::Type::ALI) { + || device_.getType() == Device::Type::ALI + || device_.getType() == Device::Type::CAMBRICON) { op::add_rms_norm_inplace(x, residual, weight_, static_cast(eps_)); } else { op::add_(residual, x, residual); diff --git a/src/infiniop/ops/add_rms_norm/bang/add_rms_norm_bang.h b/src/infiniop/ops/add_rms_norm/bang/add_rms_norm_bang.h new file mode 100644 index 000000000..0fdc89718 --- /dev/null +++ b/src/infiniop/ops/add_rms_norm/bang/add_rms_norm_bang.h @@ -0,0 +1,8 @@ +#ifndef __ADD_RMS_NORM_BANG_H__ +#define __ADD_RMS_NORM_BANG_H__ + +#include "../add_rms_norm.h" + +DESCRIPTOR(bang) + +#endif // __ADD_RMS_NORM_BANG_H__ diff --git a/src/infiniop/ops/add_rms_norm/bang/add_rms_norm_bang.mlu b/src/infiniop/ops/add_rms_norm/bang/add_rms_norm_bang.mlu new file mode 100644 index 000000000..aa36a29bf --- /dev/null +++ b/src/infiniop/ops/add_rms_norm/bang/add_rms_norm_bang.mlu @@ -0,0 +1,480 @@ +#include "../../../devices/bang/common_bang.h" +#include
"../../../reduce/bang/reduce_bang.h" +#include "add_rms_norm_bang.h" + +#include +#include + +__nram__ char nram_buffer[NRAM_MAX_SIZE]; + /* Chooses how many segments each row is split into. Returns 1 when rows == 0 or task_dim <= 0; otherwise the smaller of (task_dim / rows) and ceil(dim / 1024), never less than 1. */ +static size_t getPartsPerRow(size_t rows, size_t dim, int task_dim) { + if (rows == 0 || task_dim <= 0) { + return 1; + } + size_t max_parts_by_tasks = std::max(1, static_cast(task_dim) / rows); + size_t parts_by_dim = std::max(1, (dim + 1023) / 1024); + return std::max(1, std::min(max_parts_by_tasks, parts_by_dim)); +} + /* Copies n elements of T from src into cache with the GDRAM2NRAM direction, then fills dst with their float values: __bang_half2float or __bang_bfloat162float depending on T, otherwise an NRAM2NRAM copy of n floats. NOTE(review): the std::is_same type arguments are not visible in this text (presumably half and bfloat16_t) - confirm against the real file. */ +template +__mlu_device__ void loadToFloat(float *dst, T *cache, const T *src, size_t n) { + __memcpy(cache, src, n * sizeof(T), GDRAM2NRAM); + if constexpr (std::is_same::value) { + __bang_half2float(dst, cache, n); + } else if constexpr (std::is_same::value) { + __bang_bfloat162float(dst, cache, n); + } else { + __memcpy(dst, cache, n * sizeof(float), NRAM2NRAM); + } +} + /* Counterpart of loadToFloat: converts n floats from src into cache (__bang_float2half, __bang_float2bfloat16, or a plain float copy), then copies n * sizeof(T) bytes from cache to dst with the NRAM2GDRAM direction. */ +template +__mlu_device__ void storeFromFloat(T *dst, T *cache, float *src, size_t n) { + if constexpr (std::is_same::value) { + __bang_float2half(cache, src, n); + } else if constexpr (std::is_same::value) { + __bang_float2bfloat16(cache, src, n); + } else { + __memcpy(cache, src, n * sizeof(float), NRAM2NRAM); + } + __memcpy(dst, cache, n * sizeof(T), NRAM2GDRAM); +} + /* Returns the number of elements handled per chunk. The budget is NRAM_MAX_SIZE minus 256 bytes minus the 128-byte reduction area, divided by the per-element cost of three data buffers, one weight buffer and three float buffers. The result is capped at dim, rounded down to a multiple of 64 when it exceeds 64, and is at least 1. */ +__mlu_device__ int getMaxBatchSize(size_t dim, size_t data_size, size_t weight_size) { + constexpr int reduce_buffer_size = 128 / sizeof(float); + int max_batch_size = (NRAM_MAX_SIZE - 256 - reduce_buffer_size * sizeof(float)) / (3 * data_size + weight_size + 3 * sizeof(float)); + max_batch_size = std::min(max_batch_size, static_cast(dim)); + if (max_batch_size > 64) { + max_batch_size = (max_batch_size / 64) * 64; + } + return std::max(max_batch_size, 1); +} + /* Carves nram_buffer into consecutive regions, in this order: reduction buffer (128 bytes), a_cache, b_cache, out_cache (max_batch_size Tdata each), weight_cache (max_batch_size Tweight), then a_float, b_float, weight_float (max_batch_size floats each). NOTE(review): the byte offset of the later regions depends on max_batch_size and the element sizes - confirm the resulting alignment is acceptable when max_batch_size is small or odd. */ +template +__mlu_device__ void bindNramBuffers( + int max_batch_size, + float **reduction_buffer, + Tdata **a_cache, + Tdata **b_cache, + Tdata **out_cache, + Tweight **weight_cache, + float **a_float, + float **b_float, + float **weight_float) { + constexpr int reduce_buffer_size = 128 / sizeof(float); +
*reduction_buffer = reinterpret_cast(nram_buffer); + *a_cache = reinterpret_cast(*reduction_buffer + reduce_buffer_size); + *b_cache = *a_cache + max_batch_size; + *out_cache = *b_cache + max_batch_size; + *weight_cache = reinterpret_cast(*out_cache + max_batch_size); + *a_float = reinterpret_cast(*weight_cache + max_batch_size); + *b_float = *a_float + max_batch_size; + *weight_float = *b_float + max_batch_size; +} + /* Single-kernel path. Rows (batch_size * nhead) are visited starting at taskId with a step of taskDim. For each row, the first loop writes a + b to residual_out chunk by chunk and accumulates the sum of squares; the second loop reloads residual_out, multiplies by weight and by inv_rms = 1 / sqrt(sum / dim + epsilon), and writes y. Chunks shorter than 128 elements are summed with a scalar loop instead of sumInternal. */ +template +__mlu_global__ void addRMSNormKernel( + Tdata *__restrict__ y, + Tdata *__restrict__ residual_out, + const Tdata *__restrict__ a, + const Tdata *__restrict__ b, + const Tweight *__restrict__ weight, + size_t batch_size, + size_t nhead, + size_t dim, + ptrdiff_t stride_y_batch, + ptrdiff_t stride_y_nhead, + ptrdiff_t stride_residual_batch, + ptrdiff_t stride_residual_nhead, + ptrdiff_t stride_a_batch, + ptrdiff_t stride_a_nhead, + ptrdiff_t stride_b_batch, + ptrdiff_t stride_b_nhead, + float epsilon) { + + int max_batch_size = getMaxBatchSize(dim, sizeof(Tdata), sizeof(Tweight)); + float *reduction_buffer; + Tdata *a_cache; + Tdata *b_cache; + Tdata *out_cache; + Tweight *weight_cache; + float *a_float; + float *b_float; + float *weight_float; + bindNramBuffers(max_batch_size, &reduction_buffer, &a_cache, &b_cache, &out_cache, &weight_cache, &a_float, &b_float, &weight_float); + + size_t rows = batch_size * nhead; + for (size_t row = taskId; row < rows; row += taskDim) { + size_t batch_idx = row / nhead; + size_t head_idx = row - batch_idx * nhead; + + Tdata *y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead; + Tdata *residual_ptr = residual_out + batch_idx * stride_residual_batch + head_idx * stride_residual_nhead; + const Tdata *a_ptr = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead; + const Tdata *b_ptr = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead; + + float sum_squared = 0.0f; + size_t processed = 0; + while (processed < dim) { + size_t current = std::min(static_cast(max_batch_size), dim -
processed); + + loadToFloat(a_float, a_cache, a_ptr + processed, current); + loadToFloat(b_float, b_cache, b_ptr + processed, current); + __bang_add(a_float, a_float, b_float, current); + storeFromFloat(residual_ptr + processed, out_cache, a_float, current); + + __bang_mul(a_float, a_float, a_float, current); + if (current >= 128) { + op::common_bang::reduce_op::sumInternal(reduction_buffer, a_float, current); + sum_squared += reduction_buffer[0]; + } else { + for (size_t i = 0; i < current; ++i) { + sum_squared += a_float[i]; + } + } + processed += current; + } + + float inv_rms = 1.0f / sqrtf(sum_squared / static_cast(dim) + epsilon); + + processed = 0; + while (processed < dim) { + size_t current = std::min(static_cast(max_batch_size), dim - processed); + + loadToFloat(a_float, a_cache, residual_ptr + processed, current); + loadToFloat(weight_float, weight_cache, weight + processed, current); + __bang_mul(a_float, a_float, weight_float, current); + __bang_mul_scalar(a_float, a_float, inv_rms, current); + storeFromFloat(y_ptr + processed, out_cache, a_float, current); + + processed += current; + } + } +} + /* First stage of the split path. Each logical task is one (row, part) pair covering elements [start, end) of the row, with part_size = ceil(dim / parts_per_row). It writes a + b to residual_out for that segment and stores the segment's sum of squares in partial_sums[logical]; an empty segment stores 0. weight_cache and weight_float are bound but not used in this kernel. */ +template +__mlu_global__ void addRMSNormPartialSumKernel( + Tdata *__restrict__ residual_out, + float *__restrict__ partial_sums, + const Tdata *__restrict__ a, + const Tdata *__restrict__ b, + size_t batch_size, + size_t nhead, + size_t dim, + size_t parts_per_row, + ptrdiff_t stride_residual_batch, + ptrdiff_t stride_residual_nhead, + ptrdiff_t stride_a_batch, + ptrdiff_t stride_a_nhead, + ptrdiff_t stride_b_batch, + ptrdiff_t stride_b_nhead) { + + int max_batch_size = getMaxBatchSize(dim, sizeof(Tdata), sizeof(Tweight)); + float *reduction_buffer; + Tdata *a_cache; + Tdata *b_cache; + Tdata *out_cache; + Tweight *weight_cache; + float *a_float; + float *b_float; + float *weight_float; + bindNramBuffers(max_batch_size, &reduction_buffer, &a_cache, &b_cache, &out_cache, &weight_cache, &a_float, &b_float, &weight_float); + + size_t rows = batch_size * nhead; +
size_t logical_tasks = rows * parts_per_row; + size_t part_size = (dim + parts_per_row - 1) / parts_per_row; + + for (size_t logical = taskId; logical < logical_tasks; logical += taskDim) { + size_t row = logical / parts_per_row; + size_t part = logical - row * parts_per_row; + size_t start = part * part_size; + size_t end = std::min(dim, start + part_size); + if (start >= end) { + partial_sums[logical] = 0.0f; + continue; + } + + size_t batch_idx = row / nhead; + size_t head_idx = row - batch_idx * nhead; + Tdata *residual_ptr = residual_out + batch_idx * stride_residual_batch + head_idx * stride_residual_nhead; + const Tdata *a_ptr = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead; + const Tdata *b_ptr = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead; + + float sum_squared = 0.0f; + size_t processed = start; + while (processed < end) { + size_t current = std::min(static_cast(max_batch_size), end - processed); + + loadToFloat(a_float, a_cache, a_ptr + processed, current); + loadToFloat(b_float, b_cache, b_ptr + processed, current); + __bang_add(a_float, a_float, b_float, current); + storeFromFloat(residual_ptr + processed, out_cache, a_float, current); + + __bang_mul(a_float, a_float, a_float, current); + if (current >= 128) { + op::common_bang::reduce_op::sumInternal(reduction_buffer, a_float, current); + sum_squared += reduction_buffer[0]; + } else { + for (size_t i = 0; i < current; ++i) { + sum_squared += a_float[i]; + } + } + processed += current; + } + partial_sums[logical] = sum_squared; + } +} + /* Second stage of the split path. Each (row, part) task adds up the row's parts_per_row entries of partial_sums, derives inv_rms = 1 / sqrt(sum / dim + epsilon), and writes residual_out * weight * inv_rms to y for its own segment. Every task of a row repeats the same partial-sum addition. reduction_buffer, b_cache and b_float are bound but not used in this kernel. */ +template +__mlu_global__ void addRMSNormApplyKernel( + Tdata *__restrict__ y, + const Tdata *__restrict__ residual_out, + const Tweight *__restrict__ weight, + const float *__restrict__ partial_sums, + size_t batch_size, + size_t nhead, + size_t dim, + size_t parts_per_row, + ptrdiff_t stride_y_batch, + ptrdiff_t stride_y_nhead, + ptrdiff_t stride_residual_batch, + ptrdiff_t stride_residual_nhead, + float epsilon) { + + int max_batch_size =
getMaxBatchSize(dim, sizeof(Tdata), sizeof(Tweight)); + float *reduction_buffer; + Tdata *a_cache; + Tdata *b_cache; + Tdata *out_cache; + Tweight *weight_cache; + float *a_float; + float *b_float; + float *weight_float; + bindNramBuffers(max_batch_size, &reduction_buffer, &a_cache, &b_cache, &out_cache, &weight_cache, &a_float, &b_float, &weight_float); + + size_t rows = batch_size * nhead; + size_t logical_tasks = rows * parts_per_row; + size_t part_size = (dim + parts_per_row - 1) / parts_per_row; + + for (size_t logical = taskId; logical < logical_tasks; logical += taskDim) { + size_t row = logical / parts_per_row; + size_t part = logical - row * parts_per_row; + size_t start = part * part_size; + size_t end = std::min(dim, start + part_size); + if (start >= end) { + continue; + } + + float sum_squared = 0.0f; + const float *row_sums = partial_sums + row * parts_per_row; + for (size_t i = 0; i < parts_per_row; ++i) { + sum_squared += row_sums[i]; + } + float inv_rms = 1.0f / sqrtf(sum_squared / static_cast(dim) + epsilon); + + size_t batch_idx = row / nhead; + size_t head_idx = row - batch_idx * nhead; + Tdata *y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead; + const Tdata *residual_ptr = residual_out + batch_idx * stride_residual_batch + head_idx * stride_residual_nhead; + + size_t processed = start; + while (processed < end) { + size_t current = std::min(static_cast(max_batch_size), end - processed); + + loadToFloat(a_float, a_cache, residual_ptr + processed, current); + loadToFloat(weight_float, weight_cache, weight + processed, current); + __bang_mul(a_float, a_float, weight_float, current); + __bang_mul_scalar(a_float, a_float, inv_rms, current); + storeFromFloat(y_ptr + processed, out_cache, a_float, current); + + processed += current; + } + } +} + /* Host-side launcher. Sets the kernel dimension to core_per_cluster x cluster_count x 1. When parts_per_row > 1 it treats workspace as the partial-sum array and enqueues the partial-sum kernel followed by the apply kernel; otherwise it enqueues the single kernel. It then calls cnrtQueueSync on the queue and returns INFINI_STATUS_SUCCESS. NOTE(review): the value returned by cnrtQueueSync is not checked, so a failed launch is still reported as success. */ +template +static infiniStatus_t launchAddRMSNorm( + void *workspace, + int core_per_cluster, + int cluster_count, + cnrtQueue_t queue, + size_t batch_size, + size_t nhead, + size_t dim, +
size_t parts_per_row, + void *y, + void *residual_out, + const void *a, + const void *b, + const void *weight, + ptrdiff_t stride_y_batch, + ptrdiff_t stride_y_nhead, + ptrdiff_t stride_residual_batch, + ptrdiff_t stride_residual_nhead, + ptrdiff_t stride_a_batch, + ptrdiff_t stride_a_nhead, + ptrdiff_t stride_b_batch, + ptrdiff_t stride_b_nhead, + float epsilon) { + + cnrtDim3_t kernel_dim; + kernel_dim.x = core_per_cluster; + kernel_dim.y = cluster_count; + kernel_dim.z = 1; + + if (parts_per_row > 1) { + auto partial_sums = reinterpret_cast(workspace); + addRMSNormPartialSumKernel<<>>( + reinterpret_cast(residual_out), + partial_sums, + reinterpret_cast(a), + reinterpret_cast(b), + batch_size, + nhead, + dim, + parts_per_row, + stride_residual_batch, + stride_residual_nhead, + stride_a_batch, + stride_a_nhead, + stride_b_batch, + stride_b_nhead); + addRMSNormApplyKernel<<>>( + reinterpret_cast(y), + reinterpret_cast(residual_out), + reinterpret_cast(weight), + partial_sums, + batch_size, + nhead, + dim, + parts_per_row, + stride_y_batch, + stride_y_nhead, + stride_residual_batch, + stride_residual_nhead, + epsilon); + } else { + addRMSNormKernel<<>>( + reinterpret_cast(y), + reinterpret_cast(residual_out), + reinterpret_cast(a), + reinterpret_cast(b), + reinterpret_cast(weight), + batch_size, + nhead, + dim, + stride_y_batch, + stride_y_nhead, + stride_residual_batch, + stride_residual_nhead, + stride_a_batch, + stride_a_nhead, + stride_b_batch, + stride_b_nhead, + epsilon); + } + cnrtQueueSync(queue); + return INFINI_STATUS_SUCCESS; +} + +namespace op::add_rms_norm::bang { + /* Per-descriptor state: the object returned by the handle's internal() and the parts_per_row value chosen in Descriptor::create. */ +struct Descriptor::Opaque { + std::shared_ptr internal; + size_t parts_per_row; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + /* Builds a descriptor. Validates the tensor descriptors through AddRMSNormInfo::create, derives rows as shape[0] (times shape[1] when the shape has more than two dimensions), picks parts_per_row from core_per_cluster * cluster_count, and requests rows * parts_per_row floats of workspace only when parts_per_row > 1 (otherwise 0). */ +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t residual_out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, +
infiniopTensorDescriptor_t weight_desc, + float epsilon) { + auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon); + CHECK_RESULT(result); + auto info = result.take(); + + auto internal = reinterpret_cast(handle)->internal(); + size_t rows = info.shape[0] * (info.shape.size() > 2 ? info.shape[1] : 1); + int task_dim = internal->getCorePerCluster() * internal->getClusterCount(); + size_t parts_per_row = getPartsPerRow(rows, info.dim(), task_dim); + size_t workspace_size = parts_per_row > 1 ? rows * parts_per_row * sizeof(float) : 0; + + *desc_ptr = new Descriptor( + new Opaque{internal, parts_per_row}, + std::move(info), + workspace_size, + handle->device, + handle->device_id); + return INFINI_STATUS_SUCCESS; +} + /* Runs the operator. Returns INFINI_STATUS_INSUFFICIENT_WORKSPACE when workspace_size is below the descriptor's requirement. Otherwise it dispatches on (atype, wtype): F16 data with F16, BF16 or F32 weight; BF16 data with BF16, F16 or F32 weight; F32 data with F32 weight. Any other combination returns INFINI_STATUS_BAD_TENSOR_DTYPE. The nhead count is 1 and the nhead strides are 0 when the shape has at most two dimensions. */ +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *y, void *residual_out, const void *a, const void *b, const void *weight, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + auto queue = reinterpret_cast(stream); + int core_per_cluster = _opaque->internal->getCorePerCluster(); + int cluster_count = _opaque->internal->getClusterCount(); + + size_t batch_size = _info.shape[0]; + size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1; + size_t dim = _info.dim(); + + ptrdiff_t stride_y_batch = _info.y_strides[0]; + ptrdiff_t stride_y_nhead = _info.shape.size() > 2 ? _info.y_strides[1] : 0; + ptrdiff_t stride_residual_batch = _info.residual_out_strides[0]; + ptrdiff_t stride_residual_nhead = _info.shape.size() > 2 ? _info.residual_out_strides[1] : 0; + ptrdiff_t stride_a_batch = _info.a_strides[0]; + ptrdiff_t stride_a_nhead = _info.shape.size() > 2 ? _info.a_strides[1] : 0; + ptrdiff_t stride_b_batch = _info.b_strides[0]; + ptrdiff_t stride_b_nhead = _info.shape.size() > 2 ?
_info.b_strides[1] : 0; + +#define DISPATCH(Tdata, Tweight) \ + return launchAddRMSNorm( \ + workspace, core_per_cluster, cluster_count, queue, batch_size, nhead, dim, \ + _opaque->parts_per_row, y, residual_out, a, b, weight, \ + stride_y_batch, stride_y_nhead, \ + stride_residual_batch, stride_residual_nhead, \ + stride_a_batch, stride_a_nhead, \ + stride_b_batch, stride_b_nhead, \ + _info.epsilon) + + if (_info.atype == INFINI_DTYPE_F16) { + if (_info.wtype == INFINI_DTYPE_F16) { + DISPATCH(half, half); + } else if (_info.wtype == INFINI_DTYPE_BF16) { + DISPATCH(half, bfloat16_t); + } else if (_info.wtype == INFINI_DTYPE_F32) { + DISPATCH(half, float); + } + } else if (_info.atype == INFINI_DTYPE_BF16) { + if (_info.wtype == INFINI_DTYPE_BF16) { + DISPATCH(bfloat16_t, bfloat16_t); + } else if (_info.wtype == INFINI_DTYPE_F16) { + DISPATCH(bfloat16_t, half); + } else if (_info.wtype == INFINI_DTYPE_F32) { + DISPATCH(bfloat16_t, float); + } + } else if (_info.atype == INFINI_DTYPE_F32 && _info.wtype == INFINI_DTYPE_F32) { + DISPATCH(float, float); + } + +#undef DISPATCH + + return INFINI_STATUS_BAD_TENSOR_DTYPE; +} + +} // namespace op::add_rms_norm::bang diff --git a/src/infiniop/ops/add_rms_norm/operator.cc b/src/infiniop/ops/add_rms_norm/operator.cc index 8434aa9ef..7d502770d 100644 --- a/src/infiniop/ops/add_rms_norm/operator.cc +++ b/src/infiniop/ops/add_rms_norm/operator.cc @@ -13,8 +13,7 @@ // #include "ascend/add_rms_norm_aclnn.h" #endif #ifdef ENABLE_CAMBRICON_API -// TODO: Add Cambricon implementation -// #include "bang/add_rms_norm_bang.h" +#include "bang/add_rms_norm_bang.h" #endif #ifdef ENABLE_METAX_API #include "metax/add_rms_norm_metax.cuh" @@ -68,6 +67,9 @@ __INFINI_C infiniStatus_t infiniopCreateAddRMSNormDescriptor( #ifdef ENABLE_METAX_API CREATE(INFINI_DEVICE_METAX, metax); #endif +#ifdef ENABLE_CAMBRICON_API + CREATE(INFINI_DEVICE_CAMBRICON, bang); +#endif #ifdef ENABLE_QY_API CREATE(INFINI_DEVICE_QY, nvidia); #endif @@ -110,6 +112,9 @@
__INFINI_C infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormD #ifdef ENABLE_METAX_API GET(INFINI_DEVICE_METAX, metax); #endif +#ifdef ENABLE_CAMBRICON_API + GET(INFINI_DEVICE_CAMBRICON, bang); +#endif #ifdef ENABLE_QY_API GET(INFINI_DEVICE_QY, nvidia); #endif @@ -163,6 +168,9 @@ __INFINI_C infiniStatus_t infiniopAddRMSNorm( #ifdef ENABLE_METAX_API CALCULATE(INFINI_DEVICE_METAX, metax); #endif +#ifdef ENABLE_CAMBRICON_API + CALCULATE(INFINI_DEVICE_CAMBRICON, bang); +#endif #ifdef ENABLE_QY_API CALCULATE(INFINI_DEVICE_QY, nvidia); #endif @@ -207,6 +215,9 @@ __INFINI_C infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNorm #ifdef ENABLE_METAX_API DESTROY(INFINI_DEVICE_METAX, metax); #endif +#ifdef ENABLE_CAMBRICON_API + DESTROY(INFINI_DEVICE_CAMBRICON, bang); +#endif #ifdef ENABLE_QY_API DESTROY(INFINI_DEVICE_QY, nvidia); #endif