-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathkernels.h
More file actions
261 lines (208 loc) · 22.1 KB
/
kernels.h
File metadata and controls
261 lines (208 loc) · 22.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
// Copyright (c) 2025, IST Austria, developed by Erik Schultheis
// SPDX-License-Identifier: Apache-2.0
//
#ifndef LLMQ_SRC_KERNELS_KERNELS_H
#define LLMQ_SRC_KERNELS_KERNELS_H
#include <cstddef>
#include <cstdint>
#include <optional>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
// Forward declarations: the signatures below only need these names, not the full definitions.
struct cudaDeviceProp;
// Opaque library handle types, declared here directly.
// NOTE(review): presumably done so this header does not have to include the cuDNN / cuBLASLt headers;
// the typedefs must stay identical to the ones in those library headers.
typedef struct cudnnContext* cudnnHandle_t;
typedef struct cublasLtContext* cublasLtHandle_t;
struct Tensor;
enum class ETensorDType: int;
//! Transpose mode passed to matmul().
//! NOTE(review): presumably the first letter refers to `a`, the second to `b`, with T = transposed
//! and N = not transposed -- confirm against the implementation.
enum class EMMTranspose { TT, TN, NT, NN };
//! Backend selector passed to matmul(): the cuBLAS-based path or a custom implementation.
enum class EMatmulBackend {CuBLAS, Custom};
//! \brief Encoder forward pass. Overloads: all-float, all-bf16, and type-erased `Tensor` arguments.
//! \details `inp` is an integer buffer in the typed overloads; `out`, `wte` and `wpe` share one element type.
//! NOTE(review): `wte` / `wpe` are presumably the token / position embedding tables, `inp` the token ids,
//! and B, T, C, V the batch size, sequence length, channel count and vocabulary size -- confirm in the implementation.
//! NOTE(review): the `Tensor` overload presumably dispatches on the tensor dtype to one of the typed overloads -- confirm.
void encoder_forward(float* out, const int* inp, const float* wte, const float* wpe, int B, int T, int C, int V, cudaStream_t stream);
void encoder_forward(nv_bfloat16* out, const int* inp, const nv_bfloat16* wte, const nv_bfloat16* wpe, int B, int T, int C, int V, cudaStream_t stream);
void encoder_forward(Tensor& out, const Tensor& inp, const Tensor& wte, const Tensor& wpe, int B, int T, int C, int V, cudaStream_t stream);
//! \brief Encoder backward pass. Overloads: all-float, all-bf16, and type-erased `Tensor` arguments.
//! \details Besides the gradient buffers, the function takes integer scratch memory (`scratch`, `workload_indices`),
//! an `int4` buffer (`bucket_info`), two streams and an event.
//! NOTE(review): `dwte` is presumably the gradient of the token-embedding table and `dout` the incoming gradient;
//! `inputs_cpu` is presumably a host-side copy of `inp`; `seed` presumably drives stochastic rounding -- confirm.
//! NOTE(review): `sync_event` is presumably used to order work on `copy_stream` against `stream` -- confirm.
void encoder_backward(float* dwte, int* scratch,
int* workload_indices, int4* bucket_info,
const float* dout, const int* inp, const int* inputs_cpu,
int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream);
void encoder_backward(nv_bfloat16* dwte, int* scratch,
int* workload_indices, int4* bucket_info,
const nv_bfloat16* dout, const int* inp, const int* inputs_cpu,
int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream);
// The kernel runs on `stream`, but the bucket info that gets generated on the CPU to enable efficient determinism
// can be copied using `copy_stream`, so the kernel launch does not have to wait for that copy.
void encoder_backward(Tensor& dwte, Tensor& scratch,
Tensor& workload_indices, Tensor& bucket_info,
const Tensor& dout, const Tensor& inp, const Tensor& inputs_cpu,
int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream);
//! \brief RMSNorm forward pass. Overloads: all-float, all-bf16, and type-erased `Tensor` arguments.
//! \details `rms` is a non-const `float*` in both typed overloads, independent of the activation type.
//! NOTE(review): the matching argument is called `rstd` in rmsnorm_backward and `rrms` in
//! fused_residual_rmsnorm_forward, so `rms` presumably holds a reciprocal RMS per row -- confirm.
//! NOTE(review): `abs_max_ptr` presumably receives the absolute maximum of `out` for later quantization;
//! whether nullptr is accepted is not visible here -- confirm.
void rmsnorm_forward(float* out, float* rms, const float* inp, const float* weight, float* abs_max_ptr, float epsilon, int B, int T, int C, cudaStream_t stream);
void rmsnorm_forward(nv_bfloat16* out, float* rms, const nv_bfloat16* inp, const nv_bfloat16* weight, float* abs_max_ptr, float epsilon, int B, int T, int C, cudaStream_t stream);
void rmsnorm_forward(Tensor& out, Tensor& rms, const Tensor& inp, const Tensor& weight, float* abs_max_ptr, float epsilon, int B, int T, int C, cudaStream_t stream);
//! \brief Scratch-size query for rmsnorm_backward, as a function of `C` and the device properties.
//! NOTE(review): the unit is presumably bytes, since rmsnorm_backward takes its scratch as `std::byte*` -- confirm.
int get_rmsnorm_backward_scratch_size(int C, const cudaDeviceProp& dp);
//! \brief RMSNorm backward pass. Overloads: all-float, all-bf16, and type-erased `Tensor` arguments.
//! \details Writable arguments are `dinp`, `dweight`, `scratch` and `abs_max_ptr`; `rstd` is a `const float*`
//! in both typed overloads.
//! NOTE(review): `scratch` presumably has to be at least get_rmsnorm_backward_scratch_size(C, dp) large -- confirm.
//! NOTE(review): `dresidual` is presumably an additional gradient that is added into `dinp`, and `dweight`
//! is presumably accumulated rather than overwritten -- neither is visible from the declaration; confirm.
void rmsnorm_backward(float* dinp, float* dweight, std::byte* scratch, const float* dresidual, const float* dout, const float* inp, const float* weight, const float* rstd, float* abs_max_ptr,
int B, int T, int C, const cudaDeviceProp& dp, cudaStream_t stream);
void rmsnorm_backward(nv_bfloat16* dinp, nv_bfloat16* dweight, std::byte* scratch, const nv_bfloat16* dresidual, const nv_bfloat16* dout, const nv_bfloat16* inp, const nv_bfloat16* weight, const float* rstd, float* abs_max_ptr,
int B, int T, int C, const cudaDeviceProp& dp, cudaStream_t stream);
void rmsnorm_backward(Tensor& dinp, Tensor& dweight, Tensor& scratch, const Tensor& dresidual, const Tensor& dout, const Tensor& inp, const Tensor& weight, const Tensor& rstd, float* abs_max_ptr,
int B, int T, int C, const cudaDeviceProp& dp, cudaStream_t stream);
//! \brief Fused residual + RMSNorm forward pass. Overloads: all-float, all-bf16, and type-erased `Tensor` arguments.
//! \details Three writable outputs (`residual`, `normed`, `rrms`) are produced from two inputs (`inp1`, `inp2`)
//! and `weight`. `rrms` is `float*` in both typed overloads.
//! NOTE(review): presumably `residual = inp1 + inp2` and `normed = rmsnorm(residual) * weight`, with `N` the
//! number of rows (B*T) and `C` the row length -- confirm in the implementation.
void fused_residual_rmsnorm_forward(float* residual, float* normed, float* rrms, const float* inp1, const float* inp2, const float* weight, float* abs_max_ptr,
float epsilon, int N, int C, cudaStream_t stream);
void fused_residual_rmsnorm_forward(nv_bfloat16* residual, nv_bfloat16* normed, float* rrms, const nv_bfloat16* inp1, const nv_bfloat16* inp2, const nv_bfloat16* weight, float* abs_max_ptr,
float epsilon, int N, int C, cudaStream_t stream);
void fused_residual_rmsnorm_forward(Tensor& residual, Tensor& normed, Tensor& rrms, const Tensor& inp1, const Tensor& inp2, const Tensor& weight, float* abs_max_ptr,
float epsilon, int N, int C, cudaStream_t stream);
//! \brief Matrix multiplication with optional bias, in several mixed-precision combinations.
//! \details Typed overloads, in declaration order (c <- a, b; bias):
//!   - float <- float,    float;    float bias
//!   - float <- bf16,     bf16;     float bias
//!   - float <- fp8 e4m3, fp8 e4m3; float bias
//!   - float <- fp8 e4m3, fp8 e4m3; bf16 bias
//!   - float <- fp8 e4m3, fp8 e5m2; bf16 bias
//!   - bf16  <- bf16,     bf16;     bf16 bias
//!   - bf16  <- fp8 e4m3, fp8 e4m3; bf16 bias
//!   - bf16  <- fp8 e4m3, fp8 e5m2; bf16 bias
//! The final overload takes type-erased `Tensor` arguments, including a `Tensor` workspace instead of a
//! pointer/size pair.
//! All overloads take a cuBLASLt handle, a transpose `mode`, an `accumulate` flag and a `backend` selector.
//! NOTE(review): `scale_a` / `scale_b` are presumably scaling factors for quantized operands;
//! whether they may be nullptr for non-fp8 inputs is not visible here -- confirm.
//! NOTE(review): `accumulate = true` presumably adds the product into `c` instead of overwriting it -- confirm.
//! NOTE(review): which of the two operands M, N and K refer to (and the storage order) depends on `mode` and
//! the implementation -- confirm before calling.
void matmul(float* c, const float* a, const float* b, const float* bias, const float* scale_a, const float* scale_b,
cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size,
int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend);
void matmul(float* c, const nv_bfloat16* a, const nv_bfloat16* b, const float* bias, const float* scale_a, const float* scale_b,
cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size,
int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend);
void matmul(float* c, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, const float* bias, const float* scale_a, const float* scale_b,
cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size,
int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend);
void matmul(float* c, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b,
cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size,
int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend);
void matmul(float* c, const __nv_fp8_e4m3* a, const __nv_fp8_e5m2* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b,
cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size,
int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend);
void matmul(nv_bfloat16* c, const nv_bfloat16* a, const nv_bfloat16* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b,
cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size,
int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend);
void matmul(nv_bfloat16* c, const __nv_fp8_e4m3* a, const __nv_fp8_e4m3* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b,
cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size,
int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend);
void matmul(nv_bfloat16* c, const __nv_fp8_e4m3* a, const __nv_fp8_e5m2* b, const nv_bfloat16* bias, const float* scale_a, const float* scale_b,
cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size,
int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend);
void matmul(Tensor& c, const Tensor& a, const Tensor& b, const Tensor& bias, const float* scale_a, const float* scale_b,
cublasLtHandle_t handle, Tensor& workspace,
int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream, EMatmulBackend backend);
//! \brief Bias addition for float and bf16 buffers. `out` is the only writable argument.
//! NOTE(review): presumably adds `bias` (length OC) in place to each of the B*T rows of `out` -- confirm.
//! There is no `Tensor` overload for this function in this header.
void add_bias(float* out, const float* bias, int B, int T, int OC, cudaStream_t stream);
void add_bias(nv_bfloat16* out, const nv_bfloat16* bias, int B, int T, int OC, cudaStream_t stream);
//! \brief Scratch-size query for backward_bias, as a function of the data type, `OC` and the device properties.
//! NOTE(review): the unit (bytes vs. number of floats) is not visible here; backward_bias takes its buffer
//! as `float*` -- confirm before allocating.
int get_bias_backward_scratch_size(ETensorDType dtype, int OC, const cudaDeviceProp& dp);
//! \brief Bias gradient computation.
//! \details Typed overloads (dbias <- dout): float <- float, bf16 <- bf16, bf16 <- fp8 e4m3, bf16 <- fp8 e5m2;
//! plus a type-erased `Tensor` overload. `dbias_buffer` is a writable `float*` in every typed overload.
//! NOTE(review): presumably reduces `dout` over the B*T rows into `dbias` (length OC), using `dbias_buffer`
//! as scratch sized via get_bias_backward_scratch_size -- confirm.
//! NOTE(review): `scale_a` / `scale_b` are presumably scaling factors applied when `dout` is quantized;
//! whether they may be nullptr is not visible here -- confirm.
//! NOTE(review): whether `dbias` is overwritten or accumulated into is not visible from the declaration -- confirm.
void backward_bias(float* dbias, const float* dout, const float* scale_a, const float* scale_b, float* dbias_buffer, int B, int T, int OC, const cudaDeviceProp& dp, cudaStream_t stream);
void backward_bias(nv_bfloat16* dbias, const nv_bfloat16* dout, const float* scale_a, const float* scale_b, float* dbias_buffer, int B, int T, int OC, const cudaDeviceProp& dp, cudaStream_t stream);
void backward_bias(nv_bfloat16* dbias, const __nv_fp8_e4m3* dout, const float* scale_a, const float* scale_b, float* dbias_buffer, int B, int T, int OC, const cudaDeviceProp& dp, cudaStream_t stream);
void backward_bias(nv_bfloat16* dbias, const __nv_fp8_e5m2* dout, const float* scale_a, const float* scale_b, float* dbias_buffer, int B, int T, int OC, const cudaDeviceProp& dp, cudaStream_t stream);
void backward_bias(Tensor& dbias, const Tensor& dout, const float* scale_a, const float* scale_b, Tensor& dbias_buffer, int B, int T, int OC, const cudaDeviceProp& dp, cudaStream_t stream);
//! \brief Fills `freqs_cis` with precomputed rotary-embedding frequencies, as float or half.
//! \details Unlike most functions in this header, this one takes no stream argument.
//! NOTE(review): the missing stream suggests this runs on the host and `freqs_cis` is a host pointer -- confirm.
//! NOTE(review): `dim` is presumably the head dimension, `end` the maximum position, and `theta` the RoPE base -- confirm.
void precompute_freqs_cis(float *freqs_cis, int dim, int end, float theta);
void precompute_freqs_cis(half *freqs_cis, int dim, int end, float theta);
//! \brief Rotary position embedding, forward pass. Overloads: float, bf16, and type-erased `Tensor` arguments.
//! \details The float overload reads `freqs_cis` as float; the bf16 overload reads it as half.
//! NOTE(review): `Nq` / `Nkv` are presumably the number of query and key/value heads, and `in` / `out` a packed
//! QKV buffer -- confirm in the implementation.
//! NOTE(review): `abs_max_ptr` presumably receives the absolute maximum of the output -- confirm.
void rope_forward(float* out, const float* in, const float *freqs_cis, float* abs_max_ptr, int B, int T, int Nq, int Nkv, int head_dim, cudaStream_t stream);
void rope_forward(nv_bfloat16* out, const nv_bfloat16* in, const half *freqs_cis, float* abs_max_ptr, int B, int T, int Nq, int Nkv, int head_dim, cudaStream_t stream);
void rope_forward(Tensor& out, const Tensor& in, const Tensor& freqs_cis, float* abs_max_ptr, int B, int T, int Nq, int Nkv, int head_dim, cudaStream_t stream);
//! \brief Rotary position embedding, backward pass. Overloads: float, bf16, and type-erased `Tensor` arguments.
//! \details Same argument layout as rope_forward, with `dinp` as the writable gradient output and `dout` as the
//! incoming gradient. The float overload reads `freqs_cis` as float; the bf16 overload reads it as half.
//! NOTE(review): `abs_max_ptr` presumably receives the absolute maximum of `dinp` -- confirm.
void rope_backward(float* dinp, const float* dout, const float *freqs_cis, float* abs_max_ptr, int B, int T, int Nq, int Nkv, int head_dim, cudaStream_t stream);
void rope_backward(nv_bfloat16* dinp, const nv_bfloat16* dout, const half *freqs_cis, float* abs_max_ptr, int B, int T, int Nq, int Nkv, int head_dim, cudaStream_t stream);
void rope_backward(Tensor& dinp, const Tensor& dout, const Tensor& freqs_cis, float* abs_max_ptr, int B, int T, int Nq, int Nkv, int head_dim, cudaStream_t stream);
//! \brief SwiGLU activation, forward pass. Overloads: bf16, float, and type-erased `Tensor` arguments.
//! \details SwiGLU assumes that the input is the concatenation of the gate and up projections.
//! NOTE(review): the order of gate and up within `inp`, and whether `C` counts the output or the (twice as
//! wide) input channels, is not visible from the declaration -- confirm.
//! NOTE(review): `abs_max_ptr` is writable here, so it presumably receives the absolute maximum of `out` -- confirm.
void swiglu_forward(nv_bfloat16* out, const nv_bfloat16* inp, float* abs_max_ptr, int B, int T, int C, cudaStream_t stream);
void swiglu_forward(float* out, const float* inp, float* abs_max_ptr, int B, int T, int C, cudaStream_t stream);
void swiglu_forward(Tensor& out, const Tensor& inp, float* abs_max_ptr, int B, int T, int C, cudaStream_t stream);
//! \brief SwiGLU forward pass that writes its result as fp8 (e4m3). The only typed overload reads bf16 input.
//! \details In contrast to swiglu_forward, `abs_max_ptr` is `const` here (an input), and there is an additional
//! writable `scale_ptr`.
//! NOTE(review): presumably `abs_max_ptr` holds a previously computed absolute maximum that determines the
//! quantization scale, which is written to `scale_ptr` -- confirm.
void swiglu_forward_quant(__nv_fp8_e4m3* out, float* scale_ptr,const nv_bfloat16* inp, const float* abs_max_ptr, int B, int T, int C, cudaStream_t stream);
void swiglu_forward_quant(Tensor& out, float* scale_ptr, const Tensor& inp, const float* abs_max_ptr, int B, int T, int C, cudaStream_t stream);
//! \brief SwiGLU activation, backward pass. Overloads: bf16, float, and type-erased `Tensor` arguments.
//! \details `dinp` is the writable gradient output; `dout` and the forward input `inp` are read-only.
//! NOTE(review): `abs_max` is writable, so it presumably receives the absolute maximum of `dinp` -- confirm.
void swiglu_backward(nv_bfloat16* dinp, const nv_bfloat16* dout, const nv_bfloat16* inp, float* abs_max, int B, int T, int C, cudaStream_t stream);
void swiglu_backward(float* dinp, const float* dout, const float* inp, float* abs_max, int B, int T, int C, cudaStream_t stream);
void swiglu_backward(Tensor& dinp, const Tensor& dout, const Tensor& inp, float* abs_max, int B, int T, int C, cudaStream_t stream);
//! \brief Attention forward pass taking a cuDNN handle. Overloads: bf16, float, and type-erased `Tensor` arguments.
//! \details `stats` is `float*` in both typed overloads. Shapes are given in the per-argument comments, with
//! Hq / Hkv / HS being the integer parameters of the same name.
//! NOTE(review): Hq / Hkv are presumably the number of query and key/value heads and HS the head size -- confirm.
//! NOTE(review): `workspace` presumably has to be at least cudnn_get_workspace_size(...) bytes -- confirm.
void attention_forward_cudnn(nv_bfloat16* out, // output: (B, T, Hq, HS)
float* stats, // output for backward pass: (B, Hq, T)
const nv_bfloat16* inp, // input: (B, T, Hq + 2Hkv, HS) QKV
std::byte* workspace, cudnnHandle_t handle,
int B, int T, int Hq, int Hkv, int HS, cudaStream_t stream);
void attention_forward_cudnn(float* out, // output: (B, T, Hq, HS)
float* stats, // output for backward pass: (B, Hq, T)
const float* inp, // input: (B, T, Hq + 2Hkv, HS) QKV
std::byte* workspace, cudnnHandle_t handle,
int B, int T, int Hq, int Hkv, int HS, cudaStream_t stream);
void attention_forward_cudnn(Tensor& out, // output: (B, T, Hq, HS)
Tensor& stats, // output for backward pass: (B, Hq, T)
const Tensor& inp, // input: (B, T, Hq + 2Hkv, HS) QKV
Tensor& workspace, cudnnHandle_t handle,
int B, int T, int Hq, int Hkv, int HS, cudaStream_t stream);
//! \brief Workspace-size query for the cuDNN attention functions, for the given problem dimensions and handle.
//! NOTE(review): presumably returns a size in bytes that covers both the forward and the backward pass -- confirm.
std::size_t cudnn_get_workspace_size(int B, int T, int Hq, int Hkv, int HS, cudnnHandle_t handle);
//! \brief Attention backward pass taking a cuDNN handle. Overloads: bf16 and type-erased `Tensor` arguments.
//! \details Unlike attention_forward_cudnn, there is no float overload declared in this header.
//! `dqkv` is the writable gradient output; `stats`, `out`, `dout` and `qkv` are read-only.
//! NOTE(review): `stats` and `out` are presumably the buffers produced by attention_forward_cudnn for the same
//! `qkv` input, and `dqkv` presumably has the same (B, T, Hq + 2Hkv, HS) layout as `qkv` -- confirm.
void attention_backward_cudnn(nv_bfloat16* dqkv, const float* stats,
const nv_bfloat16* out, const nv_bfloat16* dout, const nv_bfloat16* qkv,
std::byte* workspace, cudnnHandle_t handle,
int B, int T, int Hq, int Hkv, int HS, cudaStream_t stream);
void attention_backward_cudnn(Tensor& dqkv, const Tensor& stats,
const Tensor& out, const Tensor& dout, const Tensor& qkv,
Tensor& workspace, cudnnHandle_t handle,
int B, int T, int Hq, int Hkv, int HS, cudaStream_t stream);
//! \brief Fused classifier / loss computation. Overloads: float logits, bf16 logits, and type-erased `Tensor` arguments.
//! \details `logits`, `losses` and `lse` are all writable; `losses` and `lse` are `float*` in both typed overloads.
//! NOTE(review): since `logits` is non-const and there is a `write_dlogits` flag, the logit gradients are
//! presumably written in place over `logits` when the flag is set -- confirm.
//! NOTE(review): `BT` is presumably the number of rows (B*T), `V` the vocabulary size and `P` the padded row
//! stride; `z_reg` presumably weights a z-loss term based on `lse`; `dloss` presumably scales the gradient -- confirm.
//! NOTE(review): whether `losses` is overwritten or accumulated into is not visible from the declaration -- confirm.
void fused_classifier(float* logits, float* losses, float* lse,
float dloss, const int* targets, float z_reg,
int BT, int V, int P, bool write_dlogits, cudaStream_t stream);
void fused_classifier(nv_bfloat16* logits, float* losses, float* lse,
float dloss, const int* targets, float z_reg,
int BT, int V, int P, bool write_dlogits, cudaStream_t stream);
void fused_classifier(Tensor& logits, Tensor& losses, Tensor& lse,
float dloss, const Tensor& targets, float z_reg,
int BT, int V, int P, bool write_dlogits, cudaStream_t stream);
//! \brief Sums per-token losses into `out`. Overloads: float pointers and type-erased `Tensor` arguments.
//! NOTE(review): the grouping is not visible from the declaration; presumably one sum per batch entry
//! (B outputs, each over T tokens) -- confirm.
void grouped_loss_sum(float* out, const float* per_token_loss, int B, int T, cudaStream_t stream);
void grouped_loss_sum(Tensor& out, const Tensor& per_token_loss, int B, int T, cudaStream_t stream);
//! \brief Reduces `N` float values from `in` into `result`.
//! NOTE(review): the name suggests `in` holds log-sum-exp values (e.g. the `lse` output of fused_classifier);
//! which statistics are produced, and how many entries `result` needs, is not visible here -- confirm.
//! NOTE(review): `first_step` presumably selects between initializing `result` and combining with its
//! previous contents -- confirm.
void reduce_lse_stats(float* result, const float* in, long N, bool first_step, cudaStream_t stream);
//! \brief Returns a device-dependent upper bound on the number of block sums.
//! NOTE(review): presumably the number of partial results global_norm_squared may write, i.e. the required
//! size of its `out` buffer -- confirm.
int get_max_num_block_sums(const cudaDeviceProp& dp);
//! \brief Squared-norm computation over `count` values. Overloads: float input, bf16 input, and type-erased `Tensor` arguments.
//! \details The result buffer `out` is `float*` for both typed overloads.
//! NOTE(review): whether `out` is overwritten or accumulated into, and how many entries it must hold
//! (see get_max_num_block_sums), is not visible from the declaration -- confirm.
void global_norm_squared(float* out, const float* values, size_t count, const cudaDeviceProp& dp, cudaStream_t stream);
void global_norm_squared(float* out, const nv_bfloat16* values, size_t count, const cudaDeviceProp& dp, cudaStream_t stream);
void global_norm_squared(Tensor& out, const Tensor& values, size_t count, const cudaDeviceProp& dp, cudaStream_t stream);
/// Puts the squared norm in out[0], the norm in `out_cpu`, and the gradient scale factor in out[1].
/// NOTE(review): `out` presumably holds the partial results of global_norm_squared on entry, and `out_cpu`
/// is presumably host-accessible memory -- confirm.
/// NOTE(review): the scale factor is presumably derived from `grad_clip` and the norm -- confirm the exact formula.
void global_norm_sqrt(float* out, float* out_cpu, float grad_clip, const cudaDeviceProp& dp, cudaStream_t stream);
//! \brief Sums `count` values into the float buffer `out`. Overloads: float input and bf16 input.
//! NOTE(review): the name indicates a reduction with a reproducible result (fixed summation order);
//! how that is achieved is not visible here -- confirm.
void deterministic_sum(float* out, const float* values, std::size_t count, cudaStream_t stream);
void deterministic_sum(float* out, const nv_bfloat16* values, std::size_t count, cudaStream_t stream);
//! \brief AdamW parameter update with several storage precisions for parameters and optimizer state.
//! \details Typed overloads, in declaration order (params / grads / m / v):
//!   - float / float / float    / float
//!   - bf16  / bf16  / float    / float
//!   - bf16  / bf16  / bf16     / float
//!   - bf16  / bf16  / bf16     / bf16
//!   - bf16  / bf16  / fp8 e4m3 / bf16   (takes an additional writable `m_scales` buffer)
//! The final overload takes type-erased `Tensor` arguments and always has an `m_scales` parameter.
//! `params_memory`, `m_memory`, `v_memory` and `abs_max` are writable; `grads_memory` and `grad_scale` are read-only.
//! NOTE(review): `grad_scale` is presumably a device pointer to the gradient scale factor (e.g. out[1] of
//! global_norm_sqrt), `t` the step number used for bias correction, and `seed` the stochastic-rounding seed -- confirm.
//! NOTE(review): the granularity of `m_scales` (per tensor vs. per block) is not visible here -- confirm.
void adamw_update(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,
float learning_rate, float beta1, float beta2, int t, float eps, float weight_decay,
const float* grad_scale, float* abs_max, unsigned int seed, cudaStream_t stream);
void adamw_update(nv_bfloat16* params_memory, const nv_bfloat16* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,
float learning_rate, float beta1, float beta2, int t, float eps, float weight_decay,
const float* grad_scale, float* abs_max, unsigned int seed, cudaStream_t stream);
void adamw_update(nv_bfloat16* params_memory, const nv_bfloat16* grads_memory, nv_bfloat16* m_memory, float* v_memory, size_t num_parameters,
float learning_rate, float beta1, float beta2, int t, float eps, float weight_decay,
const float* grad_scale, float* abs_max, unsigned int seed, cudaStream_t stream);
void adamw_update(nv_bfloat16* params_memory, const nv_bfloat16* grads_memory, nv_bfloat16* m_memory, nv_bfloat16* v_memory, size_t num_parameters,
float learning_rate, float beta1, float beta2, int t, float eps, float weight_decay,
const float* grad_scale, float* abs_max, unsigned int seed, cudaStream_t stream);
void adamw_update(nv_bfloat16* params_memory, const nv_bfloat16* grads_memory, __nv_fp8_e4m3* m_memory, nv_bfloat16* v_memory, size_t num_parameters,
float learning_rate, float beta1, float beta2, int t, float eps, float weight_decay,
const float* grad_scale, float* m_scales, float* abs_max, unsigned int seed, cudaStream_t stream);
void adamw_update(Tensor& params_memory, const Tensor& grads_memory, Tensor& m_memory, Tensor& v_memory, size_t num_parameters,
float learning_rate, float beta1, float beta2, int t, float eps, float weight_decay,
const float* grad_scale, Tensor& m_scales, float* abs_max, unsigned int seed, cudaStream_t stream);
// quantization
//! \brief Absolute-maximum computation over `N` elements. Overloads: float input, bf16 input, and `Tensor` input.
//! \details The result pointer is `float*` in all three overloads (also for the `Tensor` input).
//! NOTE(review): despite its name, `scale` presumably receives the absolute maximum itself rather than a derived
//! quantization scale -- confirm in the implementation.
void abs_max(float* scale, const float* in, long N, const cudaDeviceProp& dp, cudaStream_t stream);
void abs_max(float* scale, const nv_bfloat16* in, long N, const cudaDeviceProp& dp, cudaStream_t stream);
void abs_max(float* scale, const Tensor& in, long N, const cudaDeviceProp& dp, cudaStream_t stream);
//! \brief Quantizes `N` elements of `in` into a lower-precision `out`, given a precomputed absolute maximum.
//! \details Typed overloads (out <- in):
//!   - from float: bf16, int8, fp8 e4m3, fp8 e5m2
//!   - from bf16:  int8, fp8 e4m3, fp8 e5m2
//! plus a type-erased `Tensor` overload. `abs_max` is read-only; `scale_ptr` is writable.
//! NOTE(review): presumably the scale derived from `*abs_max` is written to `scale_ptr`; whether that is the
//! quantization or the dequantization factor is not visible here -- confirm.
void quantize_with_abs_max(nv_bfloat16* out, float* scale_ptr, const float* in, const float* abs_max, long N, const cudaDeviceProp& dp, cudaStream_t stream);
void quantize_with_abs_max(std::int8_t* out, float* scale_ptr, const float* in, const float* abs_max, long N, const cudaDeviceProp& dp, cudaStream_t stream);
void quantize_with_abs_max(__nv_fp8_e4m3* out, float* scale_ptr, const float* in, const float* abs_max, long N, const cudaDeviceProp& dp, cudaStream_t stream);
void quantize_with_abs_max(__nv_fp8_e5m2* out, float* scale_ptr, const float* in, const float* abs_max, long N, const cudaDeviceProp& dp, cudaStream_t stream);
void quantize_with_abs_max(std::int8_t* out, float* scale_ptr, const nv_bfloat16* in, const float* abs_max, long N, const cudaDeviceProp& dp, cudaStream_t stream);
void quantize_with_abs_max(__nv_fp8_e4m3* out, float* scale_ptr, const nv_bfloat16* in, const float* abs_max, long N, const cudaDeviceProp& dp, cudaStream_t stream);
void quantize_with_abs_max(__nv_fp8_e5m2* out, float* scale_ptr, const nv_bfloat16* in, const float* abs_max, long N, const cudaDeviceProp& dp, cudaStream_t stream);
void quantize_with_abs_max(Tensor& out, float* scale_ptr, const Tensor& in, const float* abs_max, long N, const cudaDeviceProp& dp, cudaStream_t stream);
//! \brief Quantizes a `rows` x `cols` matrix and writes it transposed, given a precomputed absolute maximum.
//! \details Typed overloads (out <- in):
//!   - from float: bf16, int8, fp8 e4m3
//!   - from bf16:  int8, fp8 e4m3
//! plus a type-erased `Tensor` overload. Unlike quantize_with_abs_max, no fp8 e5m2 overloads are declared here.
//! NOTE(review): `rows` / `cols` presumably describe the shape of `in`, so that `out` is `cols` x `rows` -- confirm.
void quantize_and_transpose_with_abs_max(nv_bfloat16* out, float* scale_ptr, const float* in, const float* abs_max, int rows, int cols, const cudaDeviceProp& dp, cudaStream_t stream);
void quantize_and_transpose_with_abs_max(std::int8_t* out, float* scale_ptr, const float* in, const float* abs_max, int rows, int cols, const cudaDeviceProp& dp, cudaStream_t stream);
void quantize_and_transpose_with_abs_max(__nv_fp8_e4m3* out, float* scale_ptr, const float* in, const float* abs_max, int rows, int cols, const cudaDeviceProp& dp, cudaStream_t stream);
void quantize_and_transpose_with_abs_max(std::int8_t* out, float* scale_ptr, const nv_bfloat16* in, const float* abs_max, int rows, int cols, const cudaDeviceProp& dp, cudaStream_t stream);
void quantize_and_transpose_with_abs_max(__nv_fp8_e4m3* out, float* scale_ptr, const nv_bfloat16* in, const float* abs_max, int rows, int cols, const cudaDeviceProp& dp, cudaStream_t stream);
void quantize_and_transpose_with_abs_max(Tensor& out, float* scale_ptr, const Tensor& in, const float* abs_max, int rows, int cols, const cudaDeviceProp& dp, cudaStream_t stream);
//! \brief Matrix transpose with identical source and destination element types.
//! \details Typed overloads: float, fp8 e4m3, fp8 e5m2, bf16; plus a type-erased `Tensor` overload.
//! NOTE(review): `rows` / `cols` presumably describe the shape of `src`; whether `dst` may alias `src`
//! is not visible from the declaration -- confirm.
void transpose(float* dst, const float* src, int rows, int cols, cudaStream_t stream);
void transpose(__nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src, int rows, int cols, cudaStream_t stream);
void transpose(__nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src, int rows, int cols, cudaStream_t stream);
void transpose(nv_bfloat16* dst, const nv_bfloat16* src, int rows, int cols, cudaStream_t stream);
void transpose(Tensor& dst, const Tensor& src, int rows, int cols, cudaStream_t stream);
//! \brief Element-wise combination of two vectors of `nelem` elements into `dest`.
//! Overloads: float, bf16, and type-erased `Tensor` arguments.
//! NOTE(review): the `_sr` suffix and the `seed` parameter suggest the result is stochastically rounded, as in
//! vector_reduce_sr. Presumably `dest = scale * (left + right)`; the exact placement of `scale` is not
//! visible here -- confirm.
void vector_add_sr(float* dest, const float* left, const float* right, float scale, long nelem, unsigned seed, cudaStream_t stream);
void vector_add_sr(nv_bfloat16* dest, const nv_bfloat16* left, const nv_bfloat16* right, float scale, long nelem, unsigned seed, cudaStream_t stream);
void vector_add_sr(Tensor& dest, const Tensor& left, const Tensor& right, float scale, long nelem, unsigned seed, cudaStream_t stream);
//! \fn void vector_reduce_sr(Tensor& dest, const Tensor& src, float scale, int n_shards, int skip, long nelem, bool accumulate, unsigned seed, cudaStream_t stream);
//! \brief Reduce shards of a tensor using stochastic rounding.
//! \details Interprets `src` as a tensor of `n_shards` shards of size `nelem` each. The shards are summed together, and the result is either written to (`accumulate = false`)
//! or added into (`accumulate = true`) `dest`, after being scaled by `scale`. All intermediate calculations are done in float precision, and stochastic rounding using the
//! provided `seed` is applied before writing to `dest`. The `skip` parameter allows skipping one of the shards; set it to -1 to disable skipping.
//! Typed overloads exist for float and bf16 buffers, in addition to the type-erased `Tensor` overload.
void vector_reduce_sr(float* dest, const float* src, float scale, int n_shards, int skip, long nelem, bool accumulate, unsigned seed, cudaStream_t stream);
void vector_reduce_sr(nv_bfloat16* dest, const nv_bfloat16* src, float scale, int n_shards, int skip, long nelem, bool accumulate, unsigned seed, cudaStream_t stream);
void vector_reduce_sr(Tensor& dest, const Tensor& src, float scale, int n_shards, int skip, long nelem, bool accumulate, unsigned seed, cudaStream_t stream);
//! \brief Fills `count` elements with normally distributed random values of the given `mean` and `std`.
//! Overloads: float, bf16, and type-erased `Tensor` destination.
//! NOTE(review): `seed` / `subsequence` presumably map to the seed and subsequence of a counter-based
//! generator, so that the same pair reproduces the same values -- confirm in the implementation.
void fill_normal(float* dst, std::size_t count, float mean, float std, unsigned long long seed, unsigned long long subsequence, cudaStream_t stream);
void fill_normal(nv_bfloat16* dst, std::size_t count, float mean, float std, unsigned long long seed, unsigned long long subsequence, cudaStream_t stream);
void fill_normal(Tensor& dest, std::size_t count, float mean, float std, unsigned long long seed, unsigned long long subsequence, cudaStream_t stream);
//! \brief Fills `count` elements with a constant `value`.
//! \details The typed overloads take `value` in the destination's element type (float or bf16); the `Tensor`
//! overload takes it as float.
//! NOTE(review): the `Tensor` overload presumably converts `value` to the tensor's dtype -- confirm.
void fill_constant(float* dst, float value, std::size_t count, cudaStream_t stream);
void fill_constant(nv_bfloat16* dst, nv_bfloat16 value, std::size_t count, cudaStream_t stream);
void fill_constant(Tensor& dest, float value, std::size_t count, cudaStream_t stream);
//! \brief Element-type conversion of `size` elements from `source` to `target`.
//! \details Declared conversions: bf16 -> float, float -> bf16, half -> bf16.
//! Unlike most functions in this header, these take no stream argument.
//! NOTE(review): without a stream parameter it is not visible whether these operate on host or device
//! pointers, or on which stream any device work would run -- confirm before use.
void convert_dtype(float* target, const nv_bfloat16* source, std::size_t size);
void convert_dtype(nv_bfloat16* target, const float* source, std::size_t size);
void convert_dtype(nv_bfloat16* target, const half* source, std::size_t size);
#endif //LLMQ_SRC_KERNELS_KERNELS_H