diff --git a/.gitignore b/.gitignore index 2effaff2f..45fec281a 100644 --- a/.gitignore +++ b/.gitignore @@ -207,7 +207,7 @@ venv.bak/ .spyderproject .spyproject -# Rope project settings +# Rope project settings .ropeproject # mkdocs documentation diff --git a/src/base/rotary_embedding.h b/src/base/rotary_embedding.h index 10426ee86..3a64905e4 100644 --- a/src/base/rotary_embedding.h +++ b/src/base/rotary_embedding.h @@ -1,87 +1,94 @@ #ifndef INFINI_OPS_BASE_ROTARY_EMBEDDING_H_ #define INFINI_OPS_BASE_ROTARY_EMBEDDING_H_ +#include #include -#include +#include "data_type.h" #include "operator.h" +#include "tensor.h" namespace infini::ops { -// Rotary position embedding (RoPE) applied in-place to Q and K. -// -// Interface follows vLLM's `RotaryEmbedding.forward_oot()`: -// `vllm.model_executor.layers.rotary_embedding.RotaryEmbedding` -// -// `positions`: `[T]` token position indices. -// `cos_sin_cache`: precomputed `[max_seq_len, rotary_dim]` table. -// `query` / `key`: `[T, N, D]` (TND layout), mutated in-place into -// `query_out` / `key_out`. 
class RotaryEmbedding : public Operator { public: - RotaryEmbedding(const Tensor positions, const Tensor query, const Tensor key, - const Tensor cos_sin_cache, int64_t head_size, - int64_t rotary_dim, bool is_neox_style, Tensor query_out, - Tensor key_out) - : num_tokens_{query.size(0)}, - num_heads_{static_cast(query.size(1))}, - num_kv_heads_{static_cast(key.size(1))}, - head_size_{head_size}, - rotary_dim_{rotary_dim}, - is_neox_style_{is_neox_style}, - query_shape_{query.shape()}, - key_shape_{key.shape()}, - cos_sin_cache_shape_{cos_sin_cache.shape()}, - query_out_shape_{query_out.shape()}, - key_out_shape_{key_out.shape()}, - query_strides_{query.strides()}, - key_strides_{key.strides()}, - query_out_strides_{query_out.strides()}, - key_out_strides_{key_out.strides()} { - assert(query.ndim() == 3 && - "`RotaryEmbedding` requires query to be 3D [T, N, D]"); - assert(key.ndim() == 3 && - "`RotaryEmbedding` requires key to be 3D [T, N_kv, D]"); - assert(rotary_dim <= head_size && - "`RotaryEmbedding` requires rotary_dim <= head_size"); + RotaryEmbedding(const Tensor input, const Tensor pos_ids, + const Tensor sin_table, const Tensor cos_table, Tensor out, + bool is_neox = false) + : ndim_{out.ndim()}, + batch_size_{ndim_ == 4 ? 
out.size(-4) : 1}, + seq_len_{out.size(-3)}, + nhead_{out.size(-2)}, + table_dim_{sin_table.size(1)}, + has_batch_dim_{ndim_ == 4}, + pos_has_batch_dim_{pos_ids.ndim() == 2}, + input_strides_{input.strides()}, + out_strides_{out.strides()}, + pos_strides_{pos_ids.strides()} { + const auto head_dim = out.size(-1); + const auto table_len = sin_table.size(0); + const auto angle_dtype = sin_table.dtype(); + const auto pos_dtype = pos_ids.dtype(); + + assert(input.shape() == out.shape() && + "`RotaryEmbedding` requires `input` and `out` same shape"); + assert(input.dtype() == out.dtype() && + "`RotaryEmbedding` requires `input` and `out` same dtype"); + assert((ndim_ == 3 || ndim_ == 4) && + "`RotaryEmbedding` requires 3D or 4D tensor"); + assert(head_dim % 2 == 0 && + "`RotaryEmbedding` requires head dimension to be even"); + assert(head_dim == table_dim_ * 2 && + "`RotaryEmbedding` requires table dim to be half of head dim"); + assert(pos_ids.ndim() == 1 || pos_ids.ndim() == 2); + assert((pos_dtype == DataType::kInt32 || pos_dtype == DataType::kInt64) && + "`RotaryEmbedding` requires int32 or int64 position ids"); + assert(sin_table.shape() == cos_table.shape() && + "`RotaryEmbedding` requires sin_table and cos_table same shape"); + assert(sin_table.dtype() == cos_table.dtype() && + "`RotaryEmbedding` requires sin_table and cos_table same dtype"); + assert((angle_dtype == DataType::kFloat16 || + angle_dtype == DataType::kBFloat16 || + angle_dtype == DataType::kFloat32) && + "`RotaryEmbedding` requires float sin/cos tables"); + assert(sin_table.ndim() == 2 && cos_table.ndim() == 2 && + "`RotaryEmbedding` requires 2D sin/cos tables"); + assert(table_len >= seq_len_ && + "`RotaryEmbedding` requires table length >= sequence length"); + assert((pos_has_batch_dim_ ? 
(pos_ids.size(0) == batch_size_ && + pos_ids.size(1) == seq_len_) + : (pos_ids.size(0) == seq_len_)) && + "`RotaryEmbedding` requires pos_ids shape [seq] or [batch, seq]"); + assert(out_strides_[ndim_ - 1] == 1 && input_strides_[ndim_ - 1] == 1 && + "`RotaryEmbedding` requires contiguous head dimension"); + assert(sin_table.strides()[1] == 1 && cos_table.strides()[1] == 1 && + "`RotaryEmbedding` requires contiguous table dimension"); } - virtual void operator()(const Tensor positions, const Tensor query, - const Tensor key, const Tensor cos_sin_cache, - int64_t head_size, int64_t rotary_dim, - bool is_neox_style, Tensor query_out, - Tensor key_out) const = 0; + virtual void operator()(const Tensor input, const Tensor pos_ids, + const Tensor sin_table, const Tensor cos_table, + Tensor out, bool is_neox = false) const = 0; protected: - Tensor::Size num_tokens_{0}; + Tensor::Size ndim_{0}; - int64_t num_heads_{0}; + Tensor::Size batch_size_{0}; - int64_t num_kv_heads_{0}; + Tensor::Size seq_len_{0}; - int64_t head_size_{0}; + Tensor::Size nhead_{0}; - int64_t rotary_dim_{0}; + Tensor::Size table_dim_{0}; - bool is_neox_style_{true}; + bool has_batch_dim_{false}; - Tensor::Shape query_shape_; + bool pos_has_batch_dim_{false}; - Tensor::Shape key_shape_; + Tensor::Strides input_strides_; - Tensor::Shape cos_sin_cache_shape_; + Tensor::Strides out_strides_; - Tensor::Shape query_out_shape_; - - Tensor::Shape key_out_shape_; - - Tensor::Strides query_strides_; - - Tensor::Strides key_strides_; - - Tensor::Strides query_out_strides_; - - Tensor::Strides key_out_strides_; + Tensor::Strides pos_strides_; }; } // namespace infini::ops diff --git a/src/native/cuda/iluvatar/ops/rotary_embedding/kernel.h b/src/native/cuda/iluvatar/ops/rotary_embedding/kernel.h new file mode 100644 index 000000000..23ba4c3ba --- /dev/null +++ b/src/native/cuda/iluvatar/ops/rotary_embedding/kernel.h @@ -0,0 +1,22 @@ +#ifndef INFINI_OPS_ILUVATAR_ROTARY_EMBEDDING_KERNEL_H_ +#define 
INFINI_OPS_ILUVATAR_ROTARY_EMBEDDING_KERNEL_H_ + +#include + +#include "native/cuda/iluvatar/caster.cuh" +#include "native/cuda/iluvatar/runtime_.h" +#include "native/cuda/ops/rotary_embedding/kernel.h" + +namespace infini::ops { + +template <> +class Operator + : public CudaRotaryEmbedding> { + public: + using CudaRotaryEmbedding< + Runtime>::CudaRotaryEmbedding; +}; + +} // namespace infini::ops + +#endif diff --git a/src/native/cuda/metax/ops/rotary_embedding/kernel.h b/src/native/cuda/metax/ops/rotary_embedding/kernel.h new file mode 100644 index 000000000..2b643a50a --- /dev/null +++ b/src/native/cuda/metax/ops/rotary_embedding/kernel.h @@ -0,0 +1,21 @@ +#ifndef INFINI_OPS_METAX_ROTARY_EMBEDDING_KERNEL_H_ +#define INFINI_OPS_METAX_ROTARY_EMBEDDING_KERNEL_H_ + +#include + +#include "native/cuda/metax/caster.cuh" +#include "native/cuda/metax/runtime_.h" +#include "native/cuda/ops/rotary_embedding/kernel.h" + +namespace infini::ops { + +template <> +class Operator + : public CudaRotaryEmbedding> { + public: + using CudaRotaryEmbedding>::CudaRotaryEmbedding; +}; + +} // namespace infini::ops + +#endif diff --git a/src/native/cuda/moore/ops/rotary_embedding/kernel.h b/src/native/cuda/moore/ops/rotary_embedding/kernel.h new file mode 100644 index 000000000..ffee912a1 --- /dev/null +++ b/src/native/cuda/moore/ops/rotary_embedding/kernel.h @@ -0,0 +1,25 @@ +#ifndef INFINI_OPS_MOORE_ROTARY_EMBEDDING_KERNEL_H_ +#define INFINI_OPS_MOORE_ROTARY_EMBEDDING_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "native/cuda/moore/caster.cuh" +#include "native/cuda/moore/runtime_.h" +#include "native/cuda/ops/rotary_embedding/kernel.h" + +namespace infini::ops { + +template <> +class Operator + : public CudaRotaryEmbedding> { + public: + using CudaRotaryEmbedding>::CudaRotaryEmbedding; +}; + +} // namespace infini::ops + +#endif diff --git a/src/native/cuda/nvidia/ops/rotary_embedding/kernel.h 
b/src/native/cuda/nvidia/ops/rotary_embedding/kernel.h new file mode 100644 index 000000000..ce7583419 --- /dev/null +++ b/src/native/cuda/nvidia/ops/rotary_embedding/kernel.h @@ -0,0 +1,22 @@ +#ifndef INFINI_OPS_NVIDIA_ROTARY_EMBEDDING_KERNEL_H_ +#define INFINI_OPS_NVIDIA_ROTARY_EMBEDDING_KERNEL_H_ + +#include + +#include "native/cuda/nvidia/caster.cuh" +#include "native/cuda/nvidia/runtime_.h" +#include "native/cuda/ops/rotary_embedding/kernel.h" + +namespace infini::ops { + +template <> +class Operator + : public CudaRotaryEmbedding> { + public: + using CudaRotaryEmbedding< + Runtime>::CudaRotaryEmbedding; +}; + +} // namespace infini::ops + +#endif diff --git a/src/native/cuda/ops/rotary_embedding/kernel.cuh b/src/native/cuda/ops/rotary_embedding/kernel.cuh new file mode 100644 index 000000000..803533dfe --- /dev/null +++ b/src/native/cuda/ops/rotary_embedding/kernel.cuh @@ -0,0 +1,127 @@ +#ifndef INFINI_OPS_CUDA_ROPE_KERNEL_CUH_ +#define INFINI_OPS_CUDA_ROPE_KERNEL_CUH_ + +#include +#include +#include + +#include "native/cuda/caster.cuh" + +namespace infini::ops { + +template +struct VecTypeHelper {}; + +template <> +struct VecTypeHelper { + using type2 = half2; + static __device__ __forceinline__ half2 make_half2(float x, float y) { + return __floats2half2_rn(x, y); + } + static __device__ __forceinline__ float low(half2 v) { + return __low2float(v); + } + static __device__ __forceinline__ float high(half2 v) { + return __high2float(v); + } +}; + +template <> +struct VecTypeHelper { + using type2 = cuda_bfloat162; + static __device__ __forceinline__ cuda_bfloat162 make_half2(float x, + float y) { + return __floats2bfloat162_rn(x, y); + } + static __device__ __forceinline__ float low(cuda_bfloat162 v) { + return __low2float(v); + } + static __device__ __forceinline__ float high(cuda_bfloat162 v) { + return __high2float(v); + } +}; + +template +__global__ void RoPEKernel( + TData* __restrict__ out_ptr, const TData* __restrict__ input_ptr, + const TIndex* 
__restrict__ pos_ids_ptr, const TAngle* __restrict__ sin_ptr, + const TAngle* __restrict__ cos_ptr, size_t table_dim, + ptrdiff_t out_stride_batch, ptrdiff_t out_stride_seqlen, + ptrdiff_t out_stride_nhead, ptrdiff_t input_stride_batch, + ptrdiff_t input_stride_seqlen, ptrdiff_t input_stride_nhead, + ptrdiff_t pos_stride_batch, bool pos_has_batch_dim, bool has_batch_dim) { + const size_t batch_idx = has_batch_dim ? blockIdx.z : 0; + const size_t seq_idx = blockIdx.x; + const size_t head_idx = blockIdx.y; + + auto out_offset = (has_batch_dim ? batch_idx * out_stride_batch : 0) + + seq_idx * out_stride_seqlen + head_idx * out_stride_nhead; + auto input_offset = (has_batch_dim ? batch_idx * input_stride_batch : 0) + + seq_idx * input_stride_seqlen + + head_idx * input_stride_nhead; + + size_t pos_offset; + if (pos_has_batch_dim) { + pos_offset = batch_idx * pos_stride_batch + seq_idx; + } else { + pos_offset = seq_idx; + } + + size_t pos_id = static_cast(pos_ids_ptr[pos_offset]); + size_t table_offset = pos_id * table_dim; + + using VecHelper = VecTypeHelper; + + for (size_t i = threadIdx.x; i < table_dim; i += blockDim.x) { + float sin_val = + Caster::template Cast(sin_ptr[table_offset + i]); + float cos_val = + Caster::template Cast(cos_ptr[table_offset + i]); + + if constexpr (IsNeox) { + if constexpr (std::is_same::value || + std::is_same::value) { + auto& y = reinterpret_cast( + out_ptr[out_offset + 2 * i]); + auto& x = reinterpret_cast( + input_ptr[input_offset + 2 * i]); + + float x0 = VecHelper::low(x); + float x1 = VecHelper::high(x); + + float y0 = x0 * cos_val - x1 * sin_val; + float y1 = x0 * sin_val + x1 * cos_val; + + y = VecHelper::make_half2(y0, y1); + } else { + float x0 = + Caster::template Cast(input_ptr[input_offset + 2 * i]); + float x1 = Caster::template Cast( + input_ptr[input_offset + 2 * i + 1]); + out_ptr[out_offset + 2 * i] = + Caster::template Cast(x0 * cos_val - x1 * sin_val); + out_ptr[out_offset + 2 * i + 1] = + Caster::template Cast(x0 
* sin_val + x1 * cos_val); + } + } else { + size_t pos0 = i; + size_t pos1 = i + table_dim; + + float x0 = + Caster::template Cast(input_ptr[input_offset + pos0]); + float x1 = + Caster::template Cast(input_ptr[input_offset + pos1]); + + float y0 = x0 * cos_val - x1 * sin_val; + float y1 = x0 * sin_val + x1 * cos_val; + + out_ptr[out_offset + pos0] = Caster::template Cast(y0); + out_ptr[out_offset + pos1] = Caster::template Cast(y1); + } + } +} + +} // namespace infini::ops + +#endif diff --git a/src/native/cuda/ops/rotary_embedding/kernel.h b/src/native/cuda/ops/rotary_embedding/kernel.h new file mode 100644 index 000000000..c93697d62 --- /dev/null +++ b/src/native/cuda/ops/rotary_embedding/kernel.h @@ -0,0 +1,132 @@ +#ifndef INFINI_OPS_CUDA_ROPE_KERNEL_H_ +#define INFINI_OPS_CUDA_ROPE_KERNEL_H_ + +#include +#include +#include + +#include "base/rotary_embedding.h" +#include "data_type.h" +#include "dispatcher.h" +#include "native/cuda/kernel_commons.cuh" +#include "native/cuda/ops/rotary_embedding/kernel.cuh" +#include "native/cuda/runtime_utils.h" + +namespace infini::ops { + +namespace { + +template +void LaunchRoPEKernel(dim3 grid, int block_size, + typename Backend::Stream cuda_stream, T* out_ptr, + const T* input_ptr, const void* pos_ids_ptr, + DataType pos_dtype, const TAngle* sin_ptr, + const TAngle* cos_ptr, size_t table_dim, + ptrdiff_t out_stride_batch, ptrdiff_t out_stride_seqlen, + ptrdiff_t out_stride_nhead, ptrdiff_t input_stride_batch, + ptrdiff_t input_stride_seqlen, + ptrdiff_t input_stride_nhead, ptrdiff_t pos_stride_batch, + bool pos_has_batch_dim, bool has_batch_dim) { + if (pos_dtype == DataType::kInt64) { + RoPEKernel + <<>>( + out_ptr, input_ptr, reinterpret_cast(pos_ids_ptr), + sin_ptr, cos_ptr, table_dim, out_stride_batch, out_stride_seqlen, + out_stride_nhead, input_stride_batch, input_stride_seqlen, + input_stride_nhead, pos_stride_batch, pos_has_batch_dim, + has_batch_dim); + } else { + RoPEKernel + <<>>( + out_ptr, input_ptr, 
reinterpret_cast(pos_ids_ptr), + sin_ptr, cos_ptr, table_dim, out_stride_batch, out_stride_seqlen, + out_stride_nhead, input_stride_batch, input_stride_seqlen, + input_stride_nhead, pos_stride_batch, pos_has_batch_dim, + has_batch_dim); + } +} + +} // namespace + +template +class CudaRotaryEmbedding : public RotaryEmbedding { + public: + using RotaryEmbedding::RotaryEmbedding; + + void operator()(const Tensor input, const Tensor pos_ids, + const Tensor sin_table, const Tensor cos_table, Tensor out, + bool is_neox = false) const override { + auto cuda_stream = + static_cast(stream_ ? stream_ : 0); + + dim3 grid(static_cast(seq_len_), static_cast(nhead_), + static_cast(batch_size_)); + + assert(out.dtype() == input.dtype()); + assert(pos_ids.dtype() == DataType::kInt64 || + pos_ids.dtype() == DataType::kInt32); + + ptrdiff_t out_stride_batch = ndim_ == 4 ? out_strides_[0] : 0; + ptrdiff_t out_stride_seqlen = out_strides_[ndim_ - 3]; + ptrdiff_t out_stride_nhead = out_strides_[ndim_ - 2]; + + ptrdiff_t input_stride_batch = ndim_ == 4 ? input_strides_[0] : 0; + ptrdiff_t input_stride_seqlen = input_strides_[ndim_ - 3]; + ptrdiff_t input_stride_nhead = input_strides_[ndim_ - 2]; + + ptrdiff_t pos_stride_batch = pos_has_batch_dim_ ? 
pos_strides_[0] : 0; + + int block_size = RuntimeUtils::GetOptimalBlockSize(); + + using DataTypes = ConcatType, ReducedFloatTypes>; + using AngleTypes = ConcatType, ReducedFloatTypes>; + + if (is_neox) { + DispatchFunc( + {static_cast(out.dtype()), + static_cast(sin_table.dtype()), block_size}, + [&](auto list_tag) { + using T = TypeMapType(list_tag)>; + using TAngle = + TypeMapType(list_tag)>; + constexpr int kBlockSize = ListGet<2>(list_tag); + + LaunchRoPEKernel( + grid, kBlockSize, cuda_stream, reinterpret_cast(out.data()), + reinterpret_cast(input.data()), pos_ids.data(), + pos_ids.dtype(), + reinterpret_cast(sin_table.data()), + reinterpret_cast(cos_table.data()), table_dim_, + out_stride_batch, out_stride_seqlen, out_stride_nhead, + input_stride_batch, input_stride_seqlen, input_stride_nhead, + pos_stride_batch, pos_has_batch_dim_, has_batch_dim_); + }, + "CudaRotaryEmbedding::operator() (Neox)"); + } else { + DispatchFunc( + {static_cast(out.dtype()), + static_cast(sin_table.dtype()), block_size}, + [&](auto list_tag) { + using T = TypeMapType(list_tag)>; + using TAngle = + TypeMapType(list_tag)>; + constexpr int kBlockSize = ListGet<2>(list_tag); + + LaunchRoPEKernel( + grid, kBlockSize, cuda_stream, reinterpret_cast(out.data()), + reinterpret_cast(input.data()), pos_ids.data(), + pos_ids.dtype(), + reinterpret_cast(sin_table.data()), + reinterpret_cast(cos_table.data()), table_dim_, + out_stride_batch, out_stride_seqlen, out_stride_nhead, + input_stride_batch, input_stride_seqlen, input_stride_nhead, + pos_stride_batch, pos_has_batch_dim_, has_batch_dim_); + }, + "CudaRotaryEmbedding::operator() (Standard)"); + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/tests/test_rotary_embedding.py b/tests/test_rotary_embedding.py new file mode 100644 index 000000000..51574c2ce --- /dev/null +++ b/tests/test_rotary_embedding.py @@ -0,0 +1,297 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, 
get_stream, randn_strided + + +def _compute_sin_cos_table(pos_ids, head_dim, theta=10000.0, dtype=torch.float32): + """Compute sin and cos tables for RoPE.""" + assert head_dim % 2 == 0, "Head dimension must be even" + freqs = 1.0 / ( + theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim) + ) + angles = torch.outer(pos_ids.float(), freqs) + sin_table = torch.sin(angles).to(dtype) + cos_table = torch.cos(angles).to(dtype) + return sin_table, cos_table + + +def _torch_rope_standard(input_tensor, sin, cos): + """Standard RoPE style: first half and second half""" + dh = input_tensor.shape[-1] + half_dim = dh // 2 + + t_first = input_tensor[..., :half_dim] + t_second = input_tensor[..., half_dim:] + + t_out_first = t_first * cos - t_second * sin + t_out_second = t_first * sin + t_second * cos + + return torch.cat([t_out_first, t_out_second], dim=-1) + + +def _torch_rope_neox_interleaved(input_tensor, sin, cos): + """GPT-NeoX style RoPE: interleaved [even, odd] pairs""" + dh = input_tensor.shape[-1] + assert dh % 2 == 0, "Embedding dimension must be even." 
+ + t_even = input_tensor[..., 0::2] + t_odd = input_tensor[..., 1::2] + + t_out_even = t_even * cos - t_odd * sin + t_out_odd = t_even * sin + t_odd * cos + + output = torch.empty_like(input_tensor) + output[..., 0::2] = t_out_even + output[..., 1::2] = t_out_odd + + return output + + +def _torch_rotary_embedding(input_tensor, pos_ids, sin_table, cos_table, is_neox=False): + """PyTorch reference implementation of RoPE.""" + target_dtype = input_tensor.dtype + + has_batch_dim = input_tensor.dim() == 4 + + # Gather sin/cos values + sin_gathered = sin_table[pos_ids] + cos_gathered = cos_table[pos_ids] + + # Expand dimensions for broadcasting + if has_batch_dim and pos_ids.dim() == 2: + # pos_ids: [batch, seq_len] + # sin_gathered: [batch, seq_len, half_dim] + # Need: [batch, seq_len, 1, half_dim] + sin_expanded = sin_gathered.unsqueeze(2) + cos_expanded = cos_gathered.unsqueeze(2) + elif has_batch_dim: + # pos_ids: [seq_len] + # sin_gathered: [seq_len, half_dim] + # Need: [1, seq_len, 1, half_dim] + sin_expanded = sin_gathered.unsqueeze(0).unsqueeze(2) + cos_expanded = cos_gathered.unsqueeze(0).unsqueeze(2) + else: + # pos_ids: [seq_len] + # sin_gathered: [seq_len, half_dim] + # Need: [seq_len, 1, half_dim] + sin_expanded = sin_gathered.unsqueeze(1) + cos_expanded = cos_gathered.unsqueeze(1) + + # Apply RoPE + if is_neox: + output = _torch_rope_neox_interleaved(input_tensor, sin_expanded, cos_expanded) + else: + output = _torch_rope_standard(input_tensor, sin_expanded, cos_expanded) + + return output.to(dtype=target_dtype) + + +def _rotary_embedding( + input_tensor, pos_ids, sin_table, cos_table, *, out=None, stream=None, is_neox=False +): + """Wrapper for calling infini.ops.rotary_embedding.""" + infini.ops.rotary_embedding( + input_tensor, + pos_ids, + sin_table, + cos_table, + out, + is_neox, + stream=get_stream(input_tensor.device) if stream is None else stream, + ) + return out + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, 
input_strides, out_strides, is_neox", + ( + # 3D tensors (no batch dimension) - NeoX style (interleaved) + ((32, 8, 64), None, None, True), + ((128, 12, 128), None, None, True), + # 3D tensors - Standard style (first/second half) + ((32, 8, 64), None, None, False), + ((128, 12, 128), None, None, False), + # 4D tensors (with batch dimension) - NeoX style + ((4, 32, 8, 64), None, None, True), + ((8, 128, 12, 128), None, None, True), + # 4D tensors - Standard style + ((4, 32, 8, 64), None, None, False), + ((8, 128, 12, 128), None, None, False), + # With custom strides + ((4, 32, 8, 64), (16384, 512, 64, 1), (16384, 512, 64, 1), True), + ((32, 8, 64), (2048, 64, 1), (2048, 64, 1), False), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-5, 1e-5), + (torch.float16, 1e-2, 1e-2), + (torch.bfloat16, 5e-2, 2e-2), + ), +) +def test_rotary_embedding( + shape, input_strides, out_strides, is_neox, dtype, device, rtol, atol +): + """Test Rotary Positional Embedding operator.""" + input_tensor = randn_strided(shape, input_strides, dtype=dtype, device=device) + out = empty_strided(shape, out_strides, dtype=dtype, device=device) + + seq_len = shape[-3] + pos_ids = torch.arange(seq_len, dtype=torch.int32, device=device) + + head_dim = shape[-1] + sin_table, cos_table = _compute_sin_cos_table( + pos_ids.cpu(), head_dim, theta=10000.0, dtype=dtype + ) + sin_table = sin_table.to(device) + cos_table = cos_table.to(device) + + ref_output = _torch_rotary_embedding( + input_tensor, pos_ids, sin_table, cos_table, is_neox + ) + + return Payload( + lambda *args, **kwargs: _rotary_embedding(*args, **kwargs, is_neox=is_neox), + lambda *args, **kwargs: ref_output, + (input_tensor, pos_ids, sin_table, cos_table), + {"out": out}, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, is_neox, theta", + ( + ((1, 32, 12, 64), True, 10000.0), + ((1, 32, 12, 64), False, 50000.0), + ((32, 12, 64), True, 100000.0), 
+ ((32, 12, 64), False, 10000.0), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-5, 1e-5), + (torch.float16, 1e-2, 1e-2), + (torch.bfloat16, 5e-2, 2e-2), + ), +) +def test_rotary_embedding_different_theta( + shape, is_neox, theta, dtype, device, rtol, atol +): + """Test RoPE with different theta values.""" + input_tensor = randn_strided(shape, None, dtype=dtype, device=device) + out = empty_strided(shape, None, dtype=dtype, device=device) + + seq_len = shape[-3] + pos_ids = torch.arange(seq_len, dtype=torch.int32, device=device) + + head_dim = shape[-1] + sin_table, cos_table = _compute_sin_cos_table( + pos_ids.cpu(), head_dim, theta=theta, dtype=dtype + ) + sin_table = sin_table.to(device) + cos_table = cos_table.to(device) + + ref_output = _torch_rotary_embedding( + input_tensor, pos_ids, sin_table, cos_table, is_neox + ) + + return Payload( + lambda *args, **kwargs: _rotary_embedding(*args, **kwargs, is_neox=is_neox), + lambda *args, **kwargs: ref_output, + (input_tensor, pos_ids, sin_table, cos_table), + {"out": out}, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-5, 1e-5), + (torch.float16, 1e-2, 1e-2), + (torch.bfloat16, 5e-2, 2e-2), + ), +) +def test_rotary_embedding_inplace(device, dtype, rtol, atol): + """Test in-place RoPE operation.""" + shape = (4, 32, 8, 64) + + input_tensor = randn_strided(shape, None, dtype=dtype, device=device) + out = input_tensor.clone() + + pos_ids = torch.arange(shape[-3], dtype=torch.int32, device=device) + + head_dim = shape[-1] + sin_table, cos_table = _compute_sin_cos_table(pos_ids.cpu(), head_dim, dtype=dtype) + sin_table = sin_table.to(device) + cos_table = cos_table.to(device) + + input_copy = input_tensor.clone() + ref_output = _torch_rotary_embedding( + input_copy, pos_ids, sin_table, cos_table, is_neox=False + ) + + return Payload( + lambda *args, **kwargs: 
_rotary_embedding(*args, **kwargs, is_neox=False), + lambda *args, **kwargs: ref_output, + (input_tensor, pos_ids, sin_table, cos_table), + {"out": out}, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "seq_len, nhead, head_dim", + ( + (1, 1, 2), + (1, 4, 64), + (2, 1, 128), + (1, 1, 128), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-5, 1e-5), + (torch.float16, 1e-2, 1e-2), + (torch.bfloat16, 5e-2, 2e-2), + ), +) +def test_rotary_embedding_edge_cases( + seq_len, nhead, head_dim, dtype, device, rtol, atol +): + """Test edge cases: minimum dimensions.""" + shape = (seq_len, nhead, head_dim) + + input_tensor = randn_strided(shape, None, dtype=dtype, device=device) + out = empty_strided(shape, None, dtype=dtype, device=device) + + pos_ids = torch.arange(seq_len, dtype=torch.int32, device=device) + sin_table, cos_table = _compute_sin_cos_table(pos_ids.cpu(), head_dim, dtype=dtype) + sin_table = sin_table.to(device) + cos_table = cos_table.to(device) + + ref_output = _torch_rotary_embedding( + input_tensor, pos_ids, sin_table, cos_table, is_neox=False + ) + + return Payload( + lambda *args, **kwargs: _rotary_embedding(*args, **kwargs, is_neox=False), + lambda *args, **kwargs: ref_output, + (input_tensor, pos_ids, sin_table, cos_table), + {"out": out}, + rtol=rtol, + atol=atol, + )