Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/infiniop/ops/swiglu/moore/swiglu_moore_musa.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __SWIGLU_CUDA_MOORE_H__
#define __SWIGLU_CUDA_MOORE_H__

#include "../swiglu_cuda.h"

// Instantiates the SwiGLU descriptor for the Moore backend via the DESCRIPTOR
// macro pulled in from "../swiglu_cuda.h".
// NOTE(review): presumably this declares op::swiglu_cuda::moore::Descriptor,
// whose members are defined in swiglu_moore_musa.mu — confirm against the
// macro definition in swiglu_cuda.h.
DESCRIPTOR(moore)

#endif
123 changes: 123 additions & 0 deletions src/infiniop/ops/swiglu/moore/swiglu_moore_musa.mu
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_kernel_common.h"

#include "swiglu_moore_musa.h"
#include "swiglu_moore_musa_kernel.h"


// Kernel entry point for SwiGLU on the Moore backend.
//
// This is a thin wrapper: every argument is forwarded unchanged to the device
// function SwiGLUCudaKernel<T, BLOCK_SIZE>, which holds the actual per-element
// logic.
//
// Template parameters:
//   T          - element type of the c/a/b buffers.
//   BLOCK_SIZE - forwarded to SwiGLUCudaKernel as a template argument.
//
// Parameters:
//   c                     - output buffer.
//   a, b                  - input buffers.
//   length                - total number of elements to process.
//   batch/seq_len/hidden_dim - the three logical dimensions.
//   {c,a,b}_strides_{0,1,2} - per-dimension strides for each buffer.
//
// NOTE(review): INFINIOP_MOORE_KERNEL is defined outside this file; presumably
// it expands to the kernel qualifier plus the `void` return type — confirm in
// moore_kernel_common.h.
template <typename T, unsigned int BLOCK_SIZE>
INFINIOP_MOORE_KERNEL SwiGLUCuda(
T *c,
const T *a,
const T *b,
int length,
size_t batch, size_t seq_len, size_t hidden_dim,
ptrdiff_t c_strides_0, ptrdiff_t c_strides_1, ptrdiff_t c_strides_2,
ptrdiff_t a_strides_0, ptrdiff_t a_strides_1, ptrdiff_t a_strides_2,
ptrdiff_t b_strides_0, ptrdiff_t b_strides_1, ptrdiff_t b_strides_2) {
// Forward everything as-is; no work is done in the wrapper itself.
SwiGLUCudaKernel<T, BLOCK_SIZE>(c, a, b, length, batch, seq_len, hidden_dim,
c_strides_0, c_strides_1, c_strides_2,
a_strides_0, a_strides_1, a_strides_2,
b_strides_0, b_strides_1, b_strides_2);
}

namespace op::swiglu_cuda::moore {

// Backend-private state attached to a Descriptor.
struct Descriptor::Opaque {
// Shared reference to the Moore handle's internal state. It is filled in by
// Descriptor::create() from the handle's internal() accessor and queried in
// Descriptor::calculate() for maxThreadsPerBlock().
std::shared_ptr<device::moore::Handle::Internal> internal;
};

// Destroys the descriptor and frees the Opaque object that
// Descriptor::create() allocated with `new`.
// NOTE(review): _opaque is presumably a member declared by the DESCRIPTOR
// macro — confirm in swiglu_cuda.h.
Descriptor::~Descriptor() {
delete _opaque;
}

// Builds a SwiGLU descriptor for the Moore backend.
//
// Parameters:
//   handle   - operator handle; reinterpreted as a device::moore::Handle to
//              obtain the shared internal state.
//   desc_ptr - receives the newly allocated Descriptor on success.
//   c_desc   - tensor descriptor of the output.
//   a_desc   - tensor descriptor of the first input.
//   b_desc   - tensor descriptor of the second input.
//
// Returns INFINI_STATUS_SUCCESS once *desc_ptr has been written. If building
// the SwiGLUCudaInfo fails, CHECK_RESULT handles the failure before any
// allocation takes place.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t c_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc) {

    // Validate the tensor descriptors and collect shape/stride metadata.
    auto info = SwiGLUCudaInfo::createSwiGLUCudaInfo(c_desc, a_desc, b_desc);
    CHECK_RESULT(info);

    // Grab the backend handle's shared internal state for later launches.
    auto *moore_handle = reinterpret_cast<device::moore::Handle *>(handle);
    auto *opaque_state = new Opaque{moore_handle->internal()};

    // The third argument (0) is passed through exactly as before.
    *desc_ptr = new Descriptor(opaque_state,
                               info.take(),
                               0,
                               handle->device,
                               handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

// Launches the SwiGLU kernel for one element type and block size.
//
// Template parameters:
//   BLOCK_SIZE - number of threads per block used for the launch.
//   T          - element type of the c/a/b buffers.
//
// Parameters:
//   info      - shape/stride metadata (length, batch, seq_len, hidden_dim and
//               the three strides of each tensor).
//   c         - output buffer.
//   a, b      - input buffers.
//   stream    - MUSA stream the kernel is enqueued on.
//   workspace - unused by this operator; kept for interface compatibility.
//
// Returns INFINI_STATUS_SUCCESS. The launch is asynchronous and no launch
// error is queried here.
template <unsigned int BLOCK_SIZE, typename T>
infiniStatus_t calculate_swiglu_cuda(
    const SwiGLUCudaInfo &info,
    T *c,
    const T *a,
    const T *b,
    musaStream_t stream,
    void *workspace) {

    (void)workspace; // not needed by this operator

    // An empty tensor would yield a grid of zero blocks, which is an invalid
    // launch configuration. There is nothing to compute, so succeed early.
    if (info.length == 0) {
        return INFINI_STATUS_SUCCESS;
    }

    // NOTE(review): the kernel interface takes the element count as `int`, so
    // tensors with more than INT_MAX elements are truncated here — consider
    // rejecting them in SwiGLUCudaInfo or widening the kernel parameter.
    const int length = (int)info.length;

    // One thread per output element, rounded up to whole blocks.
    const int num_blocks = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;

    SwiGLUCuda<T, BLOCK_SIZE>
        <<<num_blocks, BLOCK_SIZE, 0, stream>>>(
            c, a, b, length,
            info.batch, info.seq_len, info.hidden_dim,
            info.c_strides_0, info.c_strides_1, info.c_strides_2,
            info.a_strides_0, info.a_strides_1, info.a_strides_2,
            info.b_strides_0, info.b_strides_1, info.b_strides_2);

    return INFINI_STATUS_SUCCESS;
}

// Runs SwiGLU on the given buffers using the metadata stored in this
// descriptor.
//
// Parameters:
//   workspace      - scratch buffer forwarded to the launcher.
//   workspace_size - size of `workspace`; must be at least _workspace_size.
//   c              - output buffer.
//   a, b           - input buffers.
//   stream_        - opaque stream pointer, cast to musaStream_t.
//
// Returns:
//   INFINI_STATUS_INSUFFICIENT_WORKSPACE             if workspace_size is too small.
//   INFINI_STATUS_BAD_TENSOR_DTYPE                   if _info.dtype is not F16, F32 or BF16.
//   INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED  if the device's max threads per
//                                                    block is not 512, 1024 or 2048.
//   Otherwise, the status returned by calculate_swiglu_cuda.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *c,
    const void *a,
    const void *b,
    void *stream_) const {

    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    musaStream_t stream = (musaStream_t)stream_;

// Invokes the launcher for one (block size, element type) combination.
#define CALCULATE_SWIGLU_CUDA(BLOCK_SIZE, TDATA) \
    calculate_swiglu_cuda<BLOCK_SIZE, TDATA>(_info, (TDATA *)c, (const TDATA *)a, (const TDATA *)b, stream, workspace)
// Dispatches on the runtime dtype; every branch returns from calculate().
#define CALCULATE_SWIGLU_CUDA_WITH_BLOCK_SIZE(BLOCK_SIZE)                 \
    {                                                                     \
        if (_info.dtype == INFINI_DTYPE_F16)                              \
            return CALCULATE_SWIGLU_CUDA(BLOCK_SIZE, half);               \
        else if (_info.dtype == INFINI_DTYPE_F32)                         \
            return CALCULATE_SWIGLU_CUDA(BLOCK_SIZE, float);              \
        else if (_info.dtype == INFINI_DTYPE_BF16)                        \
            return CALCULATE_SWIGLU_CUDA(BLOCK_SIZE, __mt_bfloat16);      \
        else                                                              \
            return INFINI_STATUS_BAD_TENSOR_DTYPE;                        \
    }

    // Query the device limit once instead of once per comparison.
    const auto max_threads = _opaque->internal->maxThreadsPerBlock();

    if (max_threads == MOORE_BLOCK_SIZE_512) {
        CALCULATE_SWIGLU_CUDA_WITH_BLOCK_SIZE(MOORE_BLOCK_SIZE_512)
    } else if (max_threads == MOORE_BLOCK_SIZE_1024) {
        CALCULATE_SWIGLU_CUDA_WITH_BLOCK_SIZE(MOORE_BLOCK_SIZE_1024)
    } else if (max_threads == MOORE_BLOCK_SIZE_2048) {
        CALCULATE_SWIGLU_CUDA_WITH_BLOCK_SIZE(MOORE_BLOCK_SIZE_2048)
    }

// Keep the helper macros local to this function.
#undef CALCULATE_SWIGLU_CUDA_WITH_BLOCK_SIZE
#undef CALCULATE_SWIGLU_CUDA

    // Only reached when no supported block size matched: every dispatch
    // branch above returns.
    return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
} // namespace op::swiglu_cuda::moore
95 changes: 95 additions & 0 deletions src/infiniop/ops/swiglu/moore/swiglu_moore_musa_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#ifndef __SWIGLU_CUDA_KERNEL_CUH__
#define __SWIGLU_CUDA_KERNEL_CUH__

// Computes the logistic sigmoid 1 / (1 + exp(-x)) for one value of type T.
//
// Specialised at compile time for half2, half, cuda_bfloat162, cuda_bfloat16
// and float; any other T falls back to plain arithmetic with std::exp.
// For the packed types (half2, cuda_bfloat162) both lanes are processed and a
// packed result is returned.
template <typename T>
__device__ __forceinline__ T sigmoid(const T &x) {
    if constexpr (std::is_same_v<T, half2>) {
        return h2rcp(__hadd2(make_half2(1, 1), h2exp(__hneg2(x))));
    } else if constexpr (std::is_same_v<T, half>) {
        // The original CUDA implementation relied on the half-precision
        // intrinsic hrcp, which is not supported on the MUSA platform.
        // Instead, convert the input to float, compute the sigmoid in float,
        // and convert the result back to half.
        float xf = __half2float(x);
        float sigf = 1.0f / (1.0f + std::exp(-xf));
        return __float2half(sigf);
    } else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
        // Use __low2float/__high2float to unpack the two lanes directly, the
        // same way SwiGLUCudaKernel does; the two-step CUDA form
        // (__low2bfloat16 followed by __bfloat162float) may not be available
        // on MUSA.
        float x0 = __low2float(x);
        float x1 = __high2float(x);
        float sig0 = __frcp_rn(__fadd_rn(1.0f, __expf(-x0)));
        float sig1 = __frcp_rn(__fadd_rn(1.0f, __expf(-x1)));
        return __floats2bfloat162_rn(sig0, sig1);
    } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
        float xf = __bfloat162float(x);
        return __float2bfloat16_rn(__frcp_rn(__fadd_rn(1.0f, __expf(-xf))));
    } else if constexpr (std::is_same_v<T, float>) {
        return __frcp_rn(__fadd_rn(1.0f, __expf(-x)));
    } else {
        // Generic fallback for any other arithmetic type.
        return 1 / (1 + std::exp(-x));
    }
}

// Per-thread SwiGLU body: each thread produces one output element
//
//     c[idx_c] = gate * sigmoid(gate) * up,   gate = b[idx_b], up = a[idx_a]
//
// The flat global thread id is decomposed into (batch, seq, hidden)
// coordinates, with hidden varying fastest, and each tensor's element offset
// is the dot product of those coordinates with its own strides.
//
// All index arithmetic is done in 64-bit (size_t / ptrdiff_t). The strides
// arrive as ptrdiff_t and the dimensions as size_t; narrowing them to int
// would silently overflow for large or widely strided tensors.
//
// Template parameters:
//   T          - element type.
//   BLOCK_SIZE - not used in the body; kept for interface compatibility.
//
// Parameters:
//   c                        - output buffer.
//   a, b                     - input buffers (a is "up", b is "gate").
//   length                   - number of elements; threads with id >= length do nothing.
//   batch/seq_len/hidden_dim - logical dimensions used to decompose the id.
//   {c,a,b}_strides_{0,1,2}  - strides for the batch/seq/hidden dimensions.
template <typename T, unsigned int BLOCK_SIZE>
__device__ void SwiGLUCudaKernel(
    T *c,
    const T *a,
    const T *b,
    int length,
    size_t batch, size_t seq_len, size_t hidden_dim,
    ptrdiff_t c_strides_0, ptrdiff_t c_strides_1, ptrdiff_t c_strides_2,
    ptrdiff_t a_strides_0, ptrdiff_t a_strides_1, ptrdiff_t a_strides_2,
    ptrdiff_t b_strides_0, ptrdiff_t b_strides_1, ptrdiff_t b_strides_2) {
    // Widen before multiplying so the global id cannot wrap.
    const size_t tid = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (length <= 0 || tid >= (size_t)length) {
        return;
    }

    // Decompose the flat id into coordinates, fastest dimension first.
    size_t rest = tid;
    const ptrdiff_t i_hidden = (ptrdiff_t)(rest % hidden_dim);
    rest /= hidden_dim;
    const ptrdiff_t i_seq = (ptrdiff_t)(rest % seq_len);
    rest /= seq_len;
    const ptrdiff_t i_batch = (ptrdiff_t)(rest % batch);

    // Signed offsets: strides are ptrdiff_t and may be negative.
    const ptrdiff_t ind_c = i_hidden * c_strides_2 + i_seq * c_strides_1 + i_batch * c_strides_0;
    const ptrdiff_t ind_a = i_hidden * a_strides_2 + i_seq * a_strides_1 + i_batch * a_strides_0;
    const ptrdiff_t ind_b = i_hidden * b_strides_2 + i_seq * b_strides_1 + i_batch * b_strides_0;

    T gate = b[ind_b];
    T up = a[ind_a];

    if constexpr (std::is_same_v<T, half2>) {
        c[ind_c] = __hmul2(__hmul2(gate, sigmoid(gate)), up);
    } else if constexpr (std::is_same_v<T, half>) {
        c[ind_c] = __hmul(__hmul(gate, sigmoid(gate)), up);
    } else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
        cuda_bfloat162 sig = sigmoid(gate);
        // On the MUSA platform, __low2float() and __high2float() extract a
        // bfloat16 lane and convert it to float in one step. They replace the
        // two-step CUDA sequence (__low2bfloat16 followed by
        // __bfloat162float), since MUSA may not support __low2bfloat16.
        float gate0 = __low2float(gate);
        float gate1 = __high2float(gate);
        float sig0 = __low2float(sig);
        float sig1 = __high2float(sig);
        float up0 = __low2float(up);
        float up1 = __high2float(up);
        float res0 = __fmul_rn(__fmul_rn(gate0, sig0), up0);
        float res1 = __fmul_rn(__fmul_rn(gate1, sig1), up1);
        c[ind_c] = __floats2bfloat162_rn(res0, res1);
    } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
        // Multiply in float and round once when storing.
        cuda_bfloat16 sig = sigmoid(gate);
        float gatef = __bfloat162float(gate);
        float sigf = __bfloat162float(sig);
        float upf = __bfloat162float(up);
        c[ind_c] = __float2bfloat16_rn(__fmul_rn(__fmul_rn(gatef, sigf), upf));
    } else if constexpr (std::is_same_v<T, float>) {
        c[ind_c] = __fmul_rn(__fmul_rn(gate, sigmoid(gate)), up);
    } else {
        c[ind_c] = gate * sigmoid(gate) * up;
    }
}

#endif // __SWIGLU_CUDA_KERNEL_CUH__
10 changes: 5 additions & 5 deletions src/infiniop/ops/swiglu/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#include "ascend/swiglu_ascend.h"
#endif
#ifdef ENABLE_MOORE_API
#include "moore/swiglu_moore.h"
#include "moore/swiglu_moore_musa.h"
#endif

__INFINI_C infiniStatus_t infiniopCreateSwiGLUDescriptor(
Expand Down Expand Up @@ -94,7 +94,7 @@ __INFINI_C infiniStatus_t infiniopCreateSwiGLUDescriptor(
CREATE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore);
CREATE_CUDA(INFINI_DEVICE_MOORE, moore);
#endif

default:
Expand Down Expand Up @@ -158,7 +158,7 @@ __INFINI_C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescripto
GET(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore);
GET_CUDA(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down Expand Up @@ -228,7 +228,7 @@ __INFINI_C infiniStatus_t infiniopSwiGLU(
CALCULATE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore);
CALCULATE_CUDA(INFINI_DEVICE_MOORE, moore);
#endif

default:
Expand Down Expand Up @@ -293,7 +293,7 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
DELETE(INFINI_DEVICE_ASCEND, ascend)
#endif
#ifdef ENABLE_MOORE_API
DELETE(INFINI_DEVICE_MOORE, moore);
DELETE_CUDA(INFINI_DEVICE_MOORE, moore);
#endif

default:
Expand Down
Loading