Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ venv.bak/
.spyderproject
.spyproject

# Rope project settings
# Rope project settings
.ropeproject

# mkdocs documentation
Expand Down
127 changes: 67 additions & 60 deletions src/base/rotary_embedding.h
Original file line number Diff line number Diff line change
@@ -1,87 +1,94 @@
#ifndef INFINI_OPS_BASE_ROTARY_EMBEDDING_H_
#define INFINI_OPS_BASE_ROTARY_EMBEDDING_H_

#include <cassert>
#include <cstddef>
#include <vector>

#include "data_type.h"
#include "operator.h"
#include "tensor.h"

namespace infini::ops {

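// Rotary position embedding (RoPE) applied in-place to Q and K.
//
// Interface follows vLLM's `RotaryEmbedding.forward_oot()`:
// `vllm.model_executor.layers.rotary_embedding.RotaryEmbedding`
//
// `positions`: `[T]` token position indices.
// `cos_sin_cache`: precomputed `[max_seq_len, rotary_dim]` table.
// `query` / `key`: `[T, N, D]` (TND layout), mutated in-place into
// `query_out` / `key_out`.
//
// NOTE(review): the parameters described above belong to the
// `positions` / `query` / `key` / `cos_sin_cache` signature. This diff also
// introduces an `input` / `pos_ids` / `sin_table` / `cos_table` / `out`
// signature; if that one is kept, this comment needs rewriting to match it.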
class RotaryEmbedding : public Operator<RotaryEmbedding> {
// Abstract operator base: the constructor stores tensor sizes, shapes and
// strides in the protected members below, and `operator()` is pure virtual.
//
// NOTE(review): this span looks like a rendered diff whose `+`/`-` markers
// were stripped, so two versions of the constructor, of `operator()` and of
// the member list are interleaved. The inline notes mark which lines appear
// to belong to which version — confirm against the real diff.
public:
// (1) Signature taking `positions` / `query` / `key` / `cos_sin_cache` —
// presumably the removed side of the diff.
RotaryEmbedding(const Tensor positions, const Tensor query, const Tensor key,
const Tensor cos_sin_cache, int64_t head_size,
int64_t rotary_dim, bool is_neox_style, Tensor query_out,
Tensor key_out)
: num_tokens_{query.size(0)},
num_heads_{static_cast<int64_t>(query.size(1))},
num_kv_heads_{static_cast<int64_t>(key.size(1))},
head_size_{head_size},
rotary_dim_{rotary_dim},
is_neox_style_{is_neox_style},
query_shape_{query.shape()},
key_shape_{key.shape()},
cos_sin_cache_shape_{cos_sin_cache.shape()},
query_out_shape_{query_out.shape()},
key_out_shape_{key_out.shape()},
query_strides_{query.strides()},
key_strides_{key.strides()},
query_out_strides_{query_out.strides()},
key_out_strides_{key_out.strides()} {
assert(query.ndim() == 3 &&
"`RotaryEmbedding` requires query to be 3D [T, N, D]");
assert(key.ndim() == 3 &&
"`RotaryEmbedding` requires key to be 3D [T, N_kv, D]");
assert(rotary_dim <= head_size &&
"`RotaryEmbedding` requires rotary_dim <= head_size");
// (2) Signature taking `input` / `pos_ids` / `sin_table` / `cos_table` /
// `out` — presumably the added side. No closing brace separates it from (1)
// in this rendering.
RotaryEmbedding(const Tensor input, const Tensor pos_ids,
const Tensor sin_table, const Tensor cos_table, Tensor out,
bool is_neox = false)
: ndim_{out.ndim()},
// `batch_size_` falls back to 1 when `out` is not 4D.
batch_size_{ndim_ == 4 ? out.size(-4) : 1},
seq_len_{out.size(-3)},
nhead_{out.size(-2)},
table_dim_{sin_table.size(1)},
has_batch_dim_{ndim_ == 4},
pos_has_batch_dim_{pos_ids.ndim() == 2},
input_strides_{input.strides()},
out_strides_{out.strides()},
pos_strides_{pos_ids.strides()} {
// NOTE(review): the initializers above already call `out.size(-3)`,
// `out.size(-2)` and `sin_table.size(1)` before the rank assertions below
// have run.
const auto head_dim = out.size(-1);
const auto table_len = sin_table.size(0);
const auto angle_dtype = sin_table.dtype();
const auto pos_dtype = pos_ids.dtype();

// Argument validation. These are `assert`s, so they are compiled out when
// `NDEBUG` is defined.
assert(input.shape() == out.shape() &&
"`RotaryEmbedding` requires `input` and `out` same shape");
assert(input.dtype() == out.dtype() &&
"`RotaryEmbedding` requires `input` and `out` same dtype");
assert((ndim_ == 3 || ndim_ == 4) &&
"`RotaryEmbedding` requires 3D or 4D tensor");
assert(head_dim % 2 == 0 &&
"`RotaryEmbedding` requires head dimension to be even");
assert(head_dim == table_dim_ * 2 &&
"`RotaryEmbedding` requires table dim to be half of head dim");
// NOTE(review): unlike its neighbours, this assertion carries no message.
assert(pos_ids.ndim() == 1 || pos_ids.ndim() == 2);
assert((pos_dtype == DataType::kInt32 || pos_dtype == DataType::kInt64) &&
"`RotaryEmbedding` requires int32 or int64 position ids");
assert(sin_table.shape() == cos_table.shape() &&
"`RotaryEmbedding` requires sin_table and cos_table same shape");
assert(sin_table.dtype() == cos_table.dtype() &&
"`RotaryEmbedding` requires sin_table and cos_table same dtype");
assert((angle_dtype == DataType::kFloat16 ||
angle_dtype == DataType::kBFloat16 ||
angle_dtype == DataType::kFloat32) &&
"`RotaryEmbedding` requires float sin/cos tables");
assert(sin_table.ndim() == 2 && cos_table.ndim() == 2 &&
"`RotaryEmbedding` requires 2D sin/cos tables");
// NOTE(review): this bounds the table by `seq_len_`; the values stored in
// `pos_ids` are not checked against `table_len` anywhere in this block.
assert(table_len >= seq_len_ &&
"`RotaryEmbedding` requires table length >= sequence length");
assert((pos_has_batch_dim_ ? (pos_ids.size(0) == batch_size_ &&
pos_ids.size(1) == seq_len_)
: (pos_ids.size(0) == seq_len_)) &&
"`RotaryEmbedding` requires pos_ids shape [seq] or [batch, seq]");
// Only the innermost stride of the data tensors and of the tables is
// required to be 1; the other strides are not constrained here.
assert(out_strides_[ndim_ - 1] == 1 && input_strides_[ndim_ - 1] == 1 &&
"`RotaryEmbedding` requires contiguous head dimension");
assert(sin_table.strides()[1] == 1 && cos_table.strides()[1] == 1 &&
"`RotaryEmbedding` requires contiguous table dimension");
}

// Call operator matching signature (1).
virtual void operator()(const Tensor positions, const Tensor query,
const Tensor key, const Tensor cos_sin_cache,
int64_t head_size, int64_t rotary_dim,
bool is_neox_style, Tensor query_out,
Tensor key_out) const = 0;
// Call operator matching signature (2).
virtual void operator()(const Tensor input, const Tensor pos_ids,
const Tensor sin_table, const Tensor cos_table,
Tensor out, bool is_neox = false) const = 0;

protected:
// NOTE(review): from here on, members of both versions alternate. Those
// initialized by constructor (2) are `ndim_`, `batch_size_`, `seq_len_`,
// `nhead_`, `table_dim_`, `has_batch_dim_`, `pos_has_batch_dim_`,
// `input_strides_`, `out_strides_` and `pos_strides_`; the rest are
// initialized by constructor (1).
Tensor::Size num_tokens_{0};
Tensor::Size ndim_{0};

int64_t num_heads_{0};
Tensor::Size batch_size_{0};

int64_t num_kv_heads_{0};
Tensor::Size seq_len_{0};

int64_t head_size_{0};
Tensor::Size nhead_{0};

int64_t rotary_dim_{0};
Tensor::Size table_dim_{0};

bool is_neox_style_{true};
bool has_batch_dim_{false};

Tensor::Shape query_shape_;
bool pos_has_batch_dim_{false};

Tensor::Shape key_shape_;
Tensor::Strides input_strides_;

Tensor::Shape cos_sin_cache_shape_;
Tensor::Strides out_strides_;

Tensor::Shape query_out_shape_;

Tensor::Shape key_out_shape_;

Tensor::Strides query_strides_;

Tensor::Strides key_strides_;

Tensor::Strides query_out_strides_;

Tensor::Strides key_out_strides_;
Tensor::Strides pos_strides_;
};

} // namespace infini::ops
Expand Down
22 changes: 22 additions & 0 deletions src/native/cuda/iluvatar/ops/rotary_embedding/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef INFINI_OPS_ILUVATAR_ROTARY_EMBEDDING_KERNEL_H_
#define INFINI_OPS_ILUVATAR_ROTARY_EMBEDDING_KERNEL_H_

#include <utility>

#include "native/cuda/iluvatar/caster.cuh"
#include "native/cuda/iluvatar/runtime_.h"
#include "native/cuda/ops/rotary_embedding/kernel.h"

namespace infini::ops {

// Iluvatar binding of `RotaryEmbedding`.
//
// Adds no behavior of its own: it derives from `CudaRotaryEmbedding`
// parameterized on the Iluvatar `Runtime` and re-exports that base's
// constructors.
template <>
class Operator<RotaryEmbedding, Device::Type::kIluvatar>
    : public CudaRotaryEmbedding<Runtime<Device::Type::kIluvatar>> {
  using Base = CudaRotaryEmbedding<Runtime<Device::Type::kIluvatar>>;

 public:
  using Base::Base;
};

} // namespace infini::ops

#endif
21 changes: 21 additions & 0 deletions src/native/cuda/metax/ops/rotary_embedding/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef INFINI_OPS_METAX_ROTARY_EMBEDDING_KERNEL_H_
#define INFINI_OPS_METAX_ROTARY_EMBEDDING_KERNEL_H_

#include <utility>

#include "native/cuda/metax/caster.cuh"
#include "native/cuda/metax/runtime_.h"
#include "native/cuda/ops/rotary_embedding/kernel.h"

namespace infini::ops {

// Metax binding of `RotaryEmbedding`.
//
// Adds no behavior of its own: it derives from `CudaRotaryEmbedding`
// parameterized on the Metax `Runtime` and re-exports that base's
// constructors.
template <>
class Operator<RotaryEmbedding, Device::Type::kMetax>
    : public CudaRotaryEmbedding<Runtime<Device::Type::kMetax>> {
  using Base = CudaRotaryEmbedding<Runtime<Device::Type::kMetax>>;

 public:
  using Base::Base;
};

} // namespace infini::ops

#endif
25 changes: 25 additions & 0 deletions src/native/cuda/moore/ops/rotary_embedding/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef INFINI_OPS_MOORE_ROTARY_EMBEDDING_KERNEL_H_
#define INFINI_OPS_MOORE_ROTARY_EMBEDDING_KERNEL_H_

#include <utility>

// clang-format off
#include <musa_runtime.h>
// clang-format on

#include "native/cuda/moore/caster.cuh"
#include "native/cuda/moore/runtime_.h"
#include "native/cuda/ops/rotary_embedding/kernel.h"

namespace infini::ops {

// Moore binding of `RotaryEmbedding`.
//
// Adds no behavior of its own: it derives from `CudaRotaryEmbedding`
// parameterized on the Moore `Runtime` and re-exports that base's
// constructors.
template <>
class Operator<RotaryEmbedding, Device::Type::kMoore>
    : public CudaRotaryEmbedding<Runtime<Device::Type::kMoore>> {
  using Base = CudaRotaryEmbedding<Runtime<Device::Type::kMoore>>;

 public:
  using Base::Base;
};

} // namespace infini::ops

#endif
22 changes: 22 additions & 0 deletions src/native/cuda/nvidia/ops/rotary_embedding/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef INFINI_OPS_NVIDIA_ROTARY_EMBEDDING_KERNEL_H_
#define INFINI_OPS_NVIDIA_ROTARY_EMBEDDING_KERNEL_H_

#include <utility>

#include "native/cuda/nvidia/caster.cuh"
#include "native/cuda/nvidia/runtime_.h"
#include "native/cuda/ops/rotary_embedding/kernel.h"

namespace infini::ops {

// NVIDIA binding of `RotaryEmbedding`.
//
// Adds no behavior of its own: it derives from `CudaRotaryEmbedding`
// parameterized on the NVIDIA `Runtime` and re-exports that base's
// constructors.
template <>
class Operator<RotaryEmbedding, Device::Type::kNvidia>
    : public CudaRotaryEmbedding<Runtime<Device::Type::kNvidia>> {
  using Base = CudaRotaryEmbedding<Runtime<Device::Type::kNvidia>>;

 public:
  using Base::Base;
};

} // namespace infini::ops

#endif
127 changes: 127 additions & 0 deletions src/native/cuda/ops/rotary_embedding/kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#ifndef INFINI_OPS_CUDA_ROPE_KERNEL_CUH_
#define INFINI_OPS_CUDA_ROPE_KERNEL_CUH_

#include <cstddef>
#include <cstdint>
#include <type_traits>

#include "native/cuda/caster.cuh"

namespace infini::ops {

// Maps a scalar element type `T` to helpers for its two-element packed form.
//
// The primary template is empty; the members (`type2`, `make_half2`, `low`,
// `high`) exist only in the explicit specializations, so using them with any
// other `T` fails to compile.
template <typename T>
struct VecTypeHelper {};

// `half` specialization: wraps the `half2` intrinsics used to unpack a pair
// to `float`, and to pack two `float`s back into a pair.
template <>
struct VecTypeHelper<half> {
  using type2 = half2;

  // Builds a `half2` from two floats via `__floats2half2_rn`.
  static __device__ __forceinline__ type2 make_half2(float lo, float hi) {
    return __floats2half2_rn(lo, hi);
  }

  // Low element of `pair`, widened to `float`.
  static __device__ __forceinline__ float low(type2 pair) {
    return __low2float(pair);
  }

  // High element of `pair`, widened to `float`.
  static __device__ __forceinline__ float high(type2 pair) {
    return __high2float(pair);
  }
};

// `cuda_bfloat16` specialization: wraps the `cuda_bfloat162` intrinsics used
// to unpack a pair to `float`, and to pack two `float`s back into a pair.
//
// The packing function keeps the name `make_half2` so that generic code can
// call it uniformly for both specializations.
template <>
struct VecTypeHelper<cuda_bfloat16> {
  using type2 = cuda_bfloat162;

  // Builds a `cuda_bfloat162` from two floats via `__floats2bfloat162_rn`.
  static __device__ __forceinline__ type2 make_half2(float lo, float hi) {
    return __floats2bfloat162_rn(lo, hi);
  }

  // Low element of `pair`, widened to `float`.
  static __device__ __forceinline__ float low(type2 pair) {
    return __low2float(pair);
  }

  // High element of `pair`, widened to `float`.
  static __device__ __forceinline__ float high(type2 pair) {
    return __high2float(pair);
  }
};

// Applies rotary position embedding to one `(batch, seq, head)` row per block.
//
// Launch layout, as read from the indexing below: `blockIdx.x` is the sequence
// index, `blockIdx.y` the head index and `blockIdx.z` the batch index (only
// consulted when `has_batch_dim` is true). The threads of a block stride over
// the `table_dim` rotation pairs of that row.
//
// Template parameters:
//   IsNeox  true rotates adjacent elements `(2i, 2i + 1)`; false rotates
//           `(i, i + table_dim)`.
//   kDev    device type selecting the `Caster` specialization.
//   TData, TIndex, TAngle
//           element types of the data, the position ids and the sin/cos tables.
//
// Parameters:
//   out_ptr, input_ptr   destination / source data; the last dimension is
//                        indexed with unit stride.
//   pos_ids_ptr          position ids; the sequence dimension is indexed with
//                        unit stride.
//   sin_ptr, cos_ptr     tables indexed as `pos_id * table_dim + i`, i.e. rows
//                        are assumed to be exactly `table_dim` elements apart.
//   table_dim            number of rotation pairs per row.
//   *_stride_*           strides in elements (not bytes).
//   pos_has_batch_dim    whether `pos_ids_ptr` has a leading batch dimension.
//   has_batch_dim        whether the data tensors have a leading batch
//                        dimension.
//
// NOTE(review): position ids are converted to `size_t` without a range check
// (the table length is not passed in), so a negative or too-large id reads
// outside the tables.
// NOTE(review): `out_ptr` and `input_ptr` are both `__restrict__`; confirm that
// callers never pass the same buffer for both.
// NOTE(review): vLLM's `is_neox_style == true` pairs `(i, i + d/2)` and false
// pairs `(2i, 2i + 1)`, the opposite of `IsNeox` here — confirm the launcher
// maps the flag accordingly.
template <bool IsNeox, Device::Type kDev, typename TData, typename TIndex,
          typename TAngle>
__global__ void RoPEKernel(
    TData* __restrict__ out_ptr, const TData* __restrict__ input_ptr,
    const TIndex* __restrict__ pos_ids_ptr, const TAngle* __restrict__ sin_ptr,
    const TAngle* __restrict__ cos_ptr, size_t table_dim,
    ptrdiff_t out_stride_batch, ptrdiff_t out_stride_seqlen,
    ptrdiff_t out_stride_nhead, ptrdiff_t input_stride_batch,
    ptrdiff_t input_stride_seqlen, ptrdiff_t input_stride_nhead,
    ptrdiff_t pos_stride_batch, bool pos_has_batch_dim, bool has_batch_dim) {
  // Keep the indices signed so that multiplying them by the signed strides
  // does not silently go through unsigned arithmetic.
  const ptrdiff_t batch_idx =
      has_batch_dim ? static_cast<ptrdiff_t>(blockIdx.z) : ptrdiff_t{0};
  const ptrdiff_t seq_idx = static_cast<ptrdiff_t>(blockIdx.x);
  const ptrdiff_t head_idx = static_cast<ptrdiff_t>(blockIdx.y);

  // `batch_idx` is 0 without a batch dimension, so the batch term vanishes.
  TData* const out_row =
      out_ptr + (batch_idx * out_stride_batch + seq_idx * out_stride_seqlen +
                 head_idx * out_stride_nhead);
  const TData* const in_row =
      input_ptr + (batch_idx * input_stride_batch +
                   seq_idx * input_stride_seqlen +
                   head_idx * input_stride_nhead);

  const ptrdiff_t pos_offset =
      pos_has_batch_dim ? batch_idx * pos_stride_batch + seq_idx : seq_idx;
  const size_t table_offset =
      static_cast<size_t>(pos_ids_ptr[pos_offset]) * table_dim;
  const TAngle* const sin_row = sin_ptr + table_offset;
  const TAngle* const cos_row = cos_ptr + table_offset;

  constexpr bool kHasPackedType = std::is_same<TData, half>::value ||
                                  std::is_same<TData, cuda_bfloat16>::value;

  // The packed path reinterprets two adjacent elements as one `type2` value,
  // which is only valid when both row pointers are aligned for `type2`. Pair
  // offsets within a row are always even, so checking the row start suffices.
  // The result is the same for every thread of the block.
  [[maybe_unused]] bool use_packed = false;
  if constexpr (IsNeox && kHasPackedType) {
    using Packed = typename VecTypeHelper<TData>::type2;
    use_packed =
        reinterpret_cast<std::uintptr_t>(out_row) % alignof(Packed) == 0 &&
        reinterpret_cast<std::uintptr_t>(in_row) % alignof(Packed) == 0;
  }

  for (size_t i = threadIdx.x; i < table_dim; i += blockDim.x) {
    const float sin_val = Caster<kDev>::template Cast<float>(sin_row[i]);
    const float cos_val = Caster<kDev>::template Cast<float>(cos_row[i]);

    if constexpr (IsNeox) {
      const size_t even = 2 * i;

      if constexpr (kHasPackedType) {
        if (use_packed) {
          using VecHelper = VecTypeHelper<TData>;
          using Packed = typename VecHelper::type2;

          // One packed load and one packed store per rotation pair.
          const Packed x = *reinterpret_cast<const Packed*>(in_row + even);

          const float x0 = VecHelper::low(x);
          const float x1 = VecHelper::high(x);

          const float y0 = x0 * cos_val - x1 * sin_val;
          const float y1 = x0 * sin_val + x1 * cos_val;

          *reinterpret_cast<Packed*>(out_row + even) =
              VecHelper::make_half2(y0, y1);
          continue;
        }
      }

      // Element-wise path: types without a packed form, or misaligned rows.
      const float x0 = Caster<kDev>::template Cast<float>(in_row[even]);
      const float x1 = Caster<kDev>::template Cast<float>(in_row[even + 1]);

      out_row[even] =
          Caster<kDev>::template Cast<TData>(x0 * cos_val - x1 * sin_val);
      out_row[even + 1] =
          Caster<kDev>::template Cast<TData>(x0 * sin_val + x1 * cos_val);
    } else {
      // Pair element `i` of the first half with element `i` of the second.
      const size_t pos0 = i;
      const size_t pos1 = i + table_dim;

      const float x0 = Caster<kDev>::template Cast<float>(in_row[pos0]);
      const float x1 = Caster<kDev>::template Cast<float>(in_row[pos1]);

      const float y0 = x0 * cos_val - x1 * sin_val;
      const float y1 = x0 * sin_val + x1 * cos_val;

      out_row[pos0] = Caster<kDev>::template Cast<TData>(y0);
      out_row[pos1] = Caster<kDev>::template Cast<TData>(y1);
    }
  }
}

} // namespace infini::ops

#endif
Loading
Loading