Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions custom_ops/gpu_ops/append_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ void AppendAttentionKernel(
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& rope_3d_delta,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const paddle::optional<paddle::Tensor>& sinks,
Expand Down Expand Up @@ -213,6 +214,7 @@ void AppendAttentionKernel(
max_input_length,
use_neox_rotary_style,
rope_3d,
rope_3d_delta,
main_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
Expand Down Expand Up @@ -310,6 +312,7 @@ void AppendAttentionKernel(
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
rope_3d_delta,
max_input_length,
exec_stream,
&qkv_out,
Expand Down Expand Up @@ -337,6 +340,7 @@ void AppendAttentionKernel(
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
rope_3d_delta,
max_input_length,
exec_stream,
&qkv_out,
Expand Down Expand Up @@ -365,6 +369,7 @@ void AppendAttentionKernel(
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
rope_3d_delta,
max_input_length,
exec_stream,
&qkv_out,
Expand All @@ -391,6 +396,7 @@ void AppendAttentionKernel(
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
rope_3d_delta,
max_input_length,
exec_stream,
&qkv_out,
Expand Down Expand Up @@ -485,6 +491,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const paddle::optional<paddle::Tensor>& mask_offset,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& rope_3d_delta,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const paddle::optional<paddle::Tensor>& sinks,
Expand Down Expand Up @@ -619,6 +626,7 @@ std::vector<paddle::Tensor> AppendAttention(
out_linear_shifts,
out_linear_smooths,
kv_signal_data,
rope_3d_delta,
q_norm_weight,
k_norm_weight,
sinks,
Expand Down Expand Up @@ -697,6 +705,7 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const paddle::optional<paddle::Tensor>& mask_offset,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& rope_3d_delta,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const paddle::optional<paddle::Tensor>& sinks,
Expand Down Expand Up @@ -777,6 +786,7 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
out_linear_shifts,
out_linear_smooths,
kv_signal_data,
rope_3d_delta,
q_norm_weight,
k_norm_weight,
sinks,
Expand Down Expand Up @@ -868,6 +878,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
const paddle::optional<std::vector<int64_t>>& mask_offset_shape,
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
const paddle::optional<std::vector<int64_t>>& rope_3d_delta_shape,
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
const paddle::optional<std::vector<int64_t>>& sinks_shape,
Expand Down Expand Up @@ -934,6 +945,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
const paddle::optional<paddle::DataType>& mask_offset_dtype,
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
const paddle::optional<paddle::DataType>& rope_3d_delta_dtype,
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
const paddle::optional<paddle::DataType>& sinks_dtype,
Expand Down Expand Up @@ -1021,6 +1033,7 @@ std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
const paddle::optional<std::vector<int64_t>>& mask_offset_shape,
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
const paddle::optional<std::vector<int64_t>>& rope_3d_delta_shape,
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
const paddle::optional<std::vector<int64_t>>& sinks_shape,
Expand Down Expand Up @@ -1080,6 +1093,7 @@ std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
const paddle::optional<paddle::DataType>& mask_offset_dtype,
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
const paddle::optional<paddle::DataType>& rope_3d_delta_dtype,
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
const paddle::optional<paddle::DataType>& sinks_dtype,
Expand Down Expand Up @@ -1138,6 +1152,7 @@ PD_BUILD_STATIC_OP(append_attention)
paddle::Optional("out_linear_smooths"),
paddle::Optional("mask_offset"),
paddle::Optional("kv_signal_data"),
paddle::Optional("rope_3d_delta"),
paddle::Optional("q_norm_weight"),
paddle::Optional("k_norm_weight"),
paddle::Optional("sinks")})
Expand Down Expand Up @@ -1201,6 +1216,7 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
paddle::Optional("out_linear_smooths"),
paddle::Optional("mask_offset"),
paddle::Optional("kv_signal_data"),
paddle::Optional("rope_3d_delta"),
paddle::Optional("q_norm_weight"),
paddle::Optional("k_norm_weight"),
paddle::Optional("sinks")})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
const uint32_t elem_cnt,
const int kv_num_heads,
const bool rope_3d,
const int* rope_3d_delta,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps) {
Expand Down Expand Up @@ -143,8 +144,15 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
if (hi < num_heads + kv_num_heads) {
// q k rope
const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
uint32_t new_emb_idx =
rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
uint32_t new_emb_idx = emb_idx;
if (rope_3d) {
if (rope_3d_delta) {
const int rope_pos = write_seq_id + rope_3d_delta[ori_bi];
new_emb_idx = rope_pos * half_head_size + h_bias / 2;
} else {
new_emb_idx = emb_idx + ori_bi * max_seq_len * head_size;
}
}
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
}
Expand Down Expand Up @@ -237,7 +245,8 @@ __global__ void append_decode_cache_T_rope_kernel(
const int block_size,
const uint32_t elem_cnt,
const int kv_num_heads,
const bool rope_3d) {
const bool rope_3d,
const int* rope_3d_delta) {
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadKVT = AlignedVector<T, VecSize>;
Expand Down Expand Up @@ -282,8 +291,15 @@ __global__ void append_decode_cache_T_rope_kernel(
if (hi < num_heads + kv_num_heads) {
// q k rope
const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
uint32_t new_emb_idx =
rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
uint32_t new_emb_idx = emb_idx;
if (rope_3d) {
if (rope_3d_delta) {
const int rope_pos = write_seq_id + rope_3d_delta[ori_bi];
new_emb_idx = rope_pos * half_head_size + h_bias / 2;
} else {
new_emb_idx = emb_idx + ori_bi * max_seq_len * head_size;
}
}
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
}
Expand Down Expand Up @@ -1221,6 +1237,7 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
const float min_bound,
const int kv_num_heads,
const bool rope_3d,
const int* rope_3d_delta,
const float rms_norm_eps) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
Expand Down Expand Up @@ -1268,8 +1285,15 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
const uint32_t new_emb_idx =
rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
uint32_t new_emb_idx = emb_idx;
if (rope_3d) {
if (rope_3d_delta) {
const int rope_pos = write_seq_id + rope_3d_delta[bid];
new_emb_idx = rope_pos * half_head_size + head_bias / 2;
} else {
new_emb_idx = emb_idx + bid * max_seq_len * HeadDim;
}
}
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
Expand Down Expand Up @@ -1363,8 +1387,15 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
const int v_head_idx = head_idx - num_heads - kv_num_heads;
if (head_idx < num_heads + kv_num_heads) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
const uint32_t new_emb_idx =
rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
uint32_t new_emb_idx = emb_idx;
if (rope_3d) {
if (rope_3d_delta) {
const int rope_pos = write_seq_id + rope_3d_delta[bid];
new_emb_idx = rope_pos * half_head_size + head_bias / 2;
} else {
new_emb_idx = emb_idx + bid * max_seq_len * HeadDim;
}
}
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Expand Down Expand Up @@ -1533,7 +1564,8 @@ __global__ void append_decode_cache_int8_rope_kernel(
const float max_bound,
const float min_bound,
const int kv_num_heads,
const bool rope_3d) {
const bool rope_3d,
const int* rope_3d_delta) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
Expand Down Expand Up @@ -1564,7 +1596,13 @@ __global__ void append_decode_cache_int8_rope_kernel(
qkv_out + start_token_idx * hidden_size + head_idx * HeadDim;

uint32_t emb_offset = write_seq_id * half_head_size;
emb_offset += rope_3d ? bid * max_seq_len * HeadDim : 0;
if (rope_3d) {
if (rope_3d_delta) {
emb_offset = (write_seq_id + rope_3d_delta[bid]) * half_head_size;
} else {
emb_offset += bid * max_seq_len * HeadDim;
}
}
apply_rope<T, VecSize, HeadDim, 32, EnforceFmulRN>(qkv_now,
cos_emb + emb_offset,
sin_emb + emb_offset,
Expand Down Expand Up @@ -1634,8 +1672,15 @@ __global__ void append_decode_cache_int8_rope_kernel(
cache_v_scale + v_head_idx * HeadDim + head_bias;
if (head_idx < num_heads + kv_num_heads) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx =
rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
uint32_t new_emb_idx = emb_idx;
if (rope_3d) {
if (rope_3d_delta) {
const int rope_pos = write_seq_id + rope_3d_delta[bid];
new_emb_idx = rope_pos * half_head_size + head_bias / 2;
} else {
new_emb_idx = emb_idx + bid * max_seq_len * HeadDim;
}
}
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Expand Down
Loading
Loading