From a432f38e5899d7d7623890170983aef5a0b264a6 Mon Sep 17 00:00:00 2001 From: sentseven Date: Wed, 1 Apr 2026 13:27:03 -0500 Subject: [PATCH] feat: port TQ3_0 KV cache quantization from llama-turboquant TurboQuant 3-bit (3.5 bpw) KV cache compression: - Per-block WHT rotation with 4-centroid MSE codebook - QJL residual signs for error correction - GPU kernels: vec_dot, MMVQ, convert, set-rows, cpy - CPU: quantize/dequantize with WHT butterfly transform - Flash attention auto-disabled for TQ3_0 K cache Combined with PrismML's Q1_0 GPU inference, this enables 1-bit weights + 3-bit KV cache on a single build. --- common/arg.cpp | 1 + ggml/include/ggml.h | 3 +- ggml/src/ggml-common.h | 15 +++ ggml/src/ggml-cpu/ggml-cpu.c | 4 + ggml/src/ggml-cpu/ggml-cpu.cpp | 9 +- ggml/src/ggml-cpu/ops.cpp | 7 ++ ggml/src/ggml-cpu/quants.c | 6 ++ ggml/src/ggml-cpu/quants.h | 1 + ggml/src/ggml-cuda/common.cuh | 7 ++ ggml/src/ggml-cuda/convert.cu | 54 ++++++++++ ggml/src/ggml-cuda/cpy-utils.cuh | 73 +++++++++++++ ggml/src/ggml-cuda/ggml-cuda.cu | 4 +- ggml/src/ggml-cuda/mmvq.cu | 8 ++ ggml/src/ggml-cuda/set-rows.cu | 10 ++ ggml/src/ggml-cuda/vecdotq.cuh | 56 ++++++++++ ggml/src/ggml-quants.c | 169 ++++++++++++++++++++++++++++++ ggml/src/ggml-quants.h | 3 + ggml/src/ggml.c | 9 ++ src/llama-context.cpp | 6 ++ tools/llama-bench/llama-bench.cpp | 3 + 20 files changed, 445 insertions(+), 3 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 05f4a5244e7..93f0584f81b 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -398,6 +398,7 @@ const std::vector kv_cache_types = { GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, + GGML_TYPE_TQ3_0, }; static ggml_type kv_cache_type_from_str(const std::string & s) { diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index a64ca675439..1949a15c442 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -429,7 +429,8 @@ extern "C" { GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block) GGML_TYPE_Q1_0 = 40, GGML_TYPE_Q1_0_g128 = 41, - GGML_TYPE_COUNT 
= 42,
+    GGML_TYPE_TQ3_0  = 42, // TurboQuant 3-bit polar + QJL (FP16 per-block scale)
+    GGML_TYPE_COUNT  = 43,
 };

 // precision
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index cf8216e39d4..8e123b9cb71 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -276,6 +276,21 @@ typedef struct {
 } block_tq2_0;
 static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0 block size/padding");

+// TurboQuant 3-bit quantization (3.5 bpw)
+// Per TurboQuant paper (Algorithm 2: TurboQuant_prod), ICLR 2026
+// Each block of 32 values is quantized as:
+// - 2-bit MSE codebook indices (after a fixed sign-flipped WHT32 rotation)
+// - 1-bit QJL residual signs (sign(r) where r = rotated - d * centroid[idx])
+// - FP16 per-block scale d = amax / 1.510 (outermost centroid magnitude)
+// Rotation is a fixed per-block Walsh-Hadamard transform; no per-model rotation matrices are needed
+#define QK_TQ3_0 32
+typedef struct {
+    uint8_t qs[QK_TQ3_0 / 4]; // 2-bit codebook indices, 32 × 2 bits = 8 bytes
+    uint8_t qr[QK_TQ3_0 / 8]; // QJL residual signs, 32 × 1 bit = 4 bytes
+    ggml_half gamma;          // per-block codebook scale (amax / 1.510), NOT the residual norm
+} block_tq3_0;
+static_assert(sizeof(block_tq3_0) == QK_TQ3_0/4 + QK_TQ3_0/8 + sizeof(ggml_half), "wrong tq3_0 block size/padding");
+
 //
 // Super-block quantization structures
 //
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 48fbddf74f5..dcebc2bdbf4 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -396,6 +396,10 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_K,
         .nrows                    = 1,
     },
+    [GGML_TYPE_TQ3_0] = {
+        .from_float               = quantize_row_tq3_0,
+        .nrows                    = 1,
+    },
     [GGML_TYPE_I32] = {
         .from_float               = (ggml_from_float_t) ggml_cpu_fp32_to_i32,
     },
diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp
index ddf1737a317..ac3ccd9c1c2 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.cpp
+++ b/ggml/src/ggml-cpu/ggml-cpu.cpp
@@ -448,7 +448,11 @@ static bool
ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st op->type != GGML_TYPE_IQ1_S && op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float case GGML_OP_MUL_MAT: - return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type; + { + const auto * traits = ggml_get_type_traits_cpu(src0->type); + return traits->vec_dot != NULL && + (src1->type == GGML_TYPE_F32 || src1->type == traits->vec_dot_type); + } case GGML_OP_SOFT_MAX_BACK: { if (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type != GGML_TYPE_F32) { return false; @@ -466,6 +470,9 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st case GGML_OP_OUT_PROD: return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) && src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + case GGML_OP_FLASH_ATTN_EXT: + // K type must have vec_dot for CPU flash attention + return ggml_get_type_traits_cpu(src1->type)->vec_dot != NULL; default: return true; } diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index efe3a8ef757..3b26f09e93e 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -677,6 +677,7 @@ void ggml_compute_forward_add( case GGML_TYPE_Q6_K: case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: + case GGML_TYPE_TQ3_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -1126,6 +1127,7 @@ void ggml_compute_forward_add1( case GGML_TYPE_Q6_K: case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: + case GGML_TYPE_TQ3_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -1254,6 +1256,7 @@ void ggml_compute_forward_acc( case GGML_TYPE_Q6_K: case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: + case GGML_TYPE_TQ3_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -4341,6 +4344,7 @@ void ggml_compute_forward_out_prod( case GGML_TYPE_Q6_K: case GGML_TYPE_TQ1_0: case 
GGML_TYPE_TQ2_0: + case GGML_TYPE_TQ3_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -4616,6 +4620,7 @@ void ggml_compute_forward_set( case GGML_TYPE_Q6_K: case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: + case GGML_TYPE_TQ3_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -4840,6 +4845,7 @@ void ggml_compute_forward_get_rows( case GGML_TYPE_Q6_K: case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: + case GGML_TYPE_TQ3_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -5566,6 +5572,7 @@ void ggml_compute_forward_clamp( case GGML_TYPE_Q6_K: case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: + case GGML_TYPE_TQ3_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index 7f8456a5db8..e705aacb0f6 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -112,6 +112,12 @@ void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, quantize_row_tq2_0_ref(x, y, k); } +void quantize_row_tq3_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK_TQ3_0 == 0); + block_tq3_0 * GGML_RESTRICT y = vy; + quantize_row_tq3_0_ref(x, y, k); +} + //===================================== Q8_K ============================================== void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h index dcf3ebe911e..98884619cdf 100644 --- a/ggml/src/ggml-cpu/quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -32,6 +32,7 @@ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_tq3_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, 
int64_t k); void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 23af1219780..984affaeaaa 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1043,6 +1043,13 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI3_S; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_TQ3_0; // 32 + static constexpr int qr = 1; + static constexpr int qi = QK_TQ3_0 / 4; // 8 +}; + ////////////////////// struct ggml_cuda_device_info { diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 960a3d94ebc..553f6e6be01 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -486,6 +486,50 @@ static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_ } } +// TurboQuant TQ3_0: 2-bit codebook dequantization + inverse WHT +// Dequantize to rotated space, then apply inverse WHT32 cooperatively +template +static __global__ void dequantize_block_tq3_0(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const float centroids[4] = { -1.510f, -0.4528f, 0.4528f, 1.510f }; + const int8_t signs[32] = { + +1, -1, +1, +1, -1, -1, +1, -1, +1, +1, -1, +1, -1, +1, -1, -1, + +1, -1, -1, +1, +1, -1, +1, -1, -1, +1, +1, +1, -1, -1, +1, -1 + }; + + const int64_t i = blockIdx.x; + const block_tq3_0 * x = (const block_tq3_0 *)vx; + const int tid = threadIdx.x; + if (tid >= 32) return; + + const float d = __half2float(x[i].gamma); + + // Step 1: Each thread dequantizes its value (in rotated space) + const int byte_idx = tid / 4; + const int bit_shift = 2 * (tid % 4); + const int idx = (x[i].qs[byte_idx] >> bit_shift) & 3; + + __shared__ float shmem[32]; + shmem[tid] = d * centroids[idx]; + __syncthreads(); + + // Step 2: Cooperative inverse WHT (5 butterfly 
stages) + for (int step = 1; step < 32; step <<= 1) { + int partner = tid ^ step; // butterfly partner + float a = shmem[tid]; + float b = shmem[partner]; + __syncthreads(); + if (tid < partner) { + shmem[tid] = a + b; + shmem[partner] = a - b; + } + __syncthreads(); + } + + // Step 3: Normalize and undo sign flips + const float inv_sqrt32 = 0.17677669529663688f; + yy[i * QK_TQ3_0 + tid] = shmem[tid] * inv_sqrt32 * signs[tid]; +} + template static void dequantize_block_cuda(const void * vx, dst_t * y, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, @@ -617,6 +661,12 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t dequantize_block_mxfp4<<>>(vx, y); } +template +static void dequantize_row_tq3_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = k / QK_TQ3_0; + dequantize_block_tq3_0<<>>(vx, y); +} + template static __global__ void convert_unary( const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, @@ -719,6 +769,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq3_s_cuda; case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_cuda; + case GGML_TYPE_TQ3_0: + return dequantize_row_tq3_0_cuda; case GGML_TYPE_F32: return convert_unary_cont_cuda; case GGML_TYPE_BF16: @@ -774,6 +826,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq3_s_cuda; case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_cuda; + case GGML_TYPE_TQ3_0: + return dequantize_row_tq3_0_cuda; case GGML_TYPE_F16: return convert_unary_cont_cuda; case GGML_TYPE_BF16: diff --git a/ggml/src/ggml-cuda/cpy-utils.cuh b/ggml/src/ggml-cuda/cpy-utils.cuh index 7697c292dd6..02c6aaa6f3d 100644 --- a/ggml/src/ggml-cuda/cpy-utils.cuh +++ b/ggml/src/ggml-cuda/cpy-utils.cuh @@ -211,6 +211,79 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) { quantize_f32_iq4_nl_block((const float *)cxi, 
(block_iq4_nl *)cdsti); } +// TQ3_0: Device-side Walsh-Hadamard Transform (WHT32) for rotation +// Same sign pattern as CPU (must match for consistency) +static __device__ __forceinline__ void tq3_wht32_forward_device(float * x) { + const int8_t signs[32] = { + +1, -1, +1, +1, -1, -1, +1, -1, +1, +1, -1, +1, -1, +1, -1, -1, + +1, -1, -1, +1, +1, -1, +1, -1, -1, +1, +1, +1, -1, -1, +1, -1 + }; + for (int j = 0; j < 32; j++) x[j] *= signs[j]; + for (int step = 1; step < 32; step <<= 1) { + for (int i = 0; i < 32; i += step * 2) { + for (int j = i; j < i + step; j++) { + float a = x[j], b = x[j + step]; + x[j] = a + b; x[j + step] = a - b; + } + } + } + const float s = 0.17677669529663688f; // 1/sqrt(32) + for (int j = 0; j < 32; j++) x[j] *= s; +} + +static __device__ __forceinline__ void tq3_wht32_inverse_device(float * x) { + for (int step = 1; step < 32; step <<= 1) { + for (int i = 0; i < 32; i += step * 2) { + for (int j = i; j < i + step; j++) { + float a = x[j], b = x[j + step]; + x[j] = a + b; x[j + step] = a - b; + } + } + } + const int8_t signs[32] = { + +1, -1, +1, +1, -1, -1, +1, -1, +1, +1, -1, +1, -1, +1, -1, -1, + +1, -1, -1, +1, +1, -1, +1, -1, -1, +1, +1, +1, -1, -1, +1, -1 + }; + const float s = 0.17677669529663688f; + for (int j = 0; j < 32; j++) x[j] *= s * signs[j]; +} + +// TQ3_0: GPU-side 2-bit scalar codebook quantization with WHT rotation +static __device__ void quantize_f32_tq3_0_block(const float * __restrict__ x, block_tq3_0 * __restrict__ y) { + const float centroids[4] = { -1.510f, -0.4528f, 0.4528f, 1.510f }; + + // Copy and apply WHT rotation + float rotated[QK_TQ3_0]; + for (int j = 0; j < QK_TQ3_0; j++) rotated[j] = x[j]; + tq3_wht32_forward_device(rotated); + + memset(y, 0, sizeof(block_tq3_0)); + + float amax = 0.0f; + for (int j = 0; j < QK_TQ3_0; j++) { + float av = fabsf(rotated[j]); + if (av > amax) amax = av; + } + + const float d = amax / 1.510f; + const float id = d > 0.0f ? 
1.0f / d : 0.0f; + y->gamma = __float2half(d); + + for (int j = 0; j < QK_TQ3_0; j++) { + float xn = rotated[j] * id; + int idx; + if (xn < 0.0f) { idx = (xn < -0.9814f) ? 0 : 1; } + else { idx = (xn < 0.9814f) ? 2 : 3; } + y->qs[j / 4] |= (idx << (2 * (j % 4))); + float residual = rotated[j] - d * centroids[idx]; + if (residual >= 0.0f) { y->qr[j / 8] |= (1 << (j % 8)); } + } +} + +static __device__ void cpy_blck_f32_tq3_0(const char * cxi, char * cdsti) { + quantize_f32_tq3_0_block((const float *)cxi, (block_tq3_0 *)cdsti); +} + template static __device__ void cpy_1_scalar(const char * cxi, char * cdsti) { *(dst_t *) cdsti = ggml_cuda_cast(*(const src_t *) cxi); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 2f49305431f..44b6061f105 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4622,6 +4622,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_BF16: + case GGML_TYPE_TQ3_0: return true; default: return false; @@ -4656,7 +4657,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g { return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 || op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 || - op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) && + op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL || + op->type == GGML_TYPE_TQ3_0) && op->src[0]->type == GGML_TYPE_F32 && (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32); } break; diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 8ad36573e1d..ca54764cffd 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -31,6 +31,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type 
type) case GGML_TYPE_IQ4_NL: return vec_dot_iq4_nl_q8_1; case GGML_TYPE_IQ4_XS: return vec_dot_iq4_xs_q8_1; case GGML_TYPE_IQ3_S: return vec_dot_iq3_s_q8_1; + case GGML_TYPE_TQ3_0: return vec_dot_tq3_0_q8_1; default: return nullptr; } } @@ -57,6 +58,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) { case GGML_TYPE_IQ3_S: return VDR_IQ3_S_Q8_1_MMVQ; case GGML_TYPE_IQ4_NL: return VDR_IQ4_NL_Q8_1_MMVQ; case GGML_TYPE_IQ4_XS: return VDR_IQ4_XS_Q8_1_MMVQ; + case GGML_TYPE_TQ3_0: return VDR_TQ3_0_Q8_1_MMVQ; default: return 1; } } @@ -645,6 +647,12 @@ static void mul_mat_vec_q_switch_type( nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; + case GGML_TYPE_TQ3_0: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; default: GGML_ABORT("fatal error"); break; diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index 631de7e8fa5..78e14ebb0b7 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -309,6 +309,16 @@ static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * s nb1, nb2, nb3, stream ); + } else if (dst->type == GGML_TYPE_TQ3_0) { + set_rows_cuda_quant( + src0_d, src1_d, (block_tq3_0*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); } else { GGML_ABORT("unsupported type %s", ggml_type_name(dst->type)); } diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index b17d5409ac7..98846eb1a33 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh 
+++ b/ggml/src/ggml-cuda/vecdotq.cuh
@@ -1352,3 +1352,59 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
     const float d = __half2float(bq4->d) * __low2float(bq8_1[iqs/4].ds);
     return d * sumi;
 }
+
+// TurboQuant TQ3_0: Fused MMVQ with per-block WHT on query
+// K is stored in WHT-rotated space. We apply WHT to Q inside the kernel.
+// Since WHT is orthogonal: dot(q, k) = dot(WHT(q), WHT(k))
+// K already carries its 1/sqrt(32) normalization from quantize_row_tq3_0; only Q's factor is applied here.
+#define VDR_TQ3_0_Q8_1_MMVQ 8
+#define VDR_TQ3_0_Q8_1_MMQ 8
+
+static __device__ __forceinline__ float vec_dot_tq3_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const float centroids[4] = { -1.510f, -0.4528f, 0.4528f, 1.510f };
+    const int8_t signs[32] = {
+        +1, -1, +1, +1, -1, -1, +1, -1, +1, +1, -1, +1, -1, +1, -1, -1,
+        +1, -1, -1, +1, +1, -1, +1, -1, -1, +1, +1, +1, -1, -1, +1, -1
+    };
+
+    if (iqs != 0) {
+        return 0.0f;
+    }
+
+    const block_tq3_0 * btq = (const block_tq3_0 *) vbq + kbx;
+    const float d = __half2float(btq->gamma);
+
+    // Step 1: Apply WHT to Q8_1 int8 values (sign flip + butterfly in int32)
+    int32_t sq[32];
+    #pragma unroll
+    for (int j = 0; j < 32; j++) {
+        sq[j] = (int32_t)bq8_1[0].qs[j] * signs[j];
+    }
+
+    // 5-stage butterfly transform
+    #pragma unroll
+    for (int step = 1; step < 32; step <<= 1) {
+        #pragma unroll
+        for (int i = 0; i < 32; i += step * 2) {
+            #pragma unroll
+            for (int j = i; j < i + step; j++) {
+                int32_t a = sq[j], b = sq[j + step];
+                sq[j] = a + b; sq[j + step] = a - b;
+            }
+        }
+    }
+
+    // Step 2: Dot product in rotated space
+    float sumf = 0.0f;
+    #pragma unroll
+    for (int j = 0; j < 32; j++) {
+        const int idx = (btq->qs[j / 4] >> (2 * (j % 4))) & 3;
+        sumf += (float)sq[j] * centroids[idx];
+    }
+
+    // Scale: d_tq3 * d_q8 / sqrt(32) — K is stored pre-normalized, so only Q's unnormalized butterfly needs its 1/sqrt(32)
+    const float d_q8 = __low2float(bq8_1[0].ds);
+    return sumf * d * d_q8 * 0.17677669529663688f; // 1/sqrt(32), matches inv_sqrt32 in the dequant kernels
+}
diff --git
a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 75508333523..e51b7a8faae 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -2412,6 +2412,166 @@ void dequantize_row_tq2_0(const block_tq2_0 * GGML_RESTRICT x, float * GGML_REST } } +// ====================== TurboQuant TQ3_0: Scalar Codebook + QJL (de)-quantization ====================== +// +// Per TurboQuant paper (ICLR 2026, Algorithm 2: TurboQuant_prod): +// Quantize: random rotation → per-block 2-bit scalar codebook → QJL residual signs +// Dequant: scale * centroid[idx] → inverse rotation +// +// Rotation: Per-block Walsh-Hadamard Transform (WHT32) with fixed sign flips. +// WHT makes any distribution approximately Gaussian (by CLT), making the +// fixed Max-Lloyd codebook optimal. WHT is self-inverse: WHT(WHT(x)) = 32*x. +// +// Optimal 2-bit codebook centroids for Gaussian N(0,1) via Max-Lloyd algorithm: +// {-1.510, -0.4528, +0.4528, +1.510} +// + +// Codebook centroids (normalized — will be scaled by per-block 'd') +static const float tq3_centroids[4] = { -1.510f, -0.4528f, 0.4528f, 1.510f }; + +// QJL correction constant: sqrt(pi/2) / block_size +static const float TQ3_QJL_SCALE = 0.03921875f; // sqrt(pi/2) / 32 ≈ 1.2533 / 32 + +// Fixed random sign pattern for WHT preconditioning (generated from seed 42) +// Multiplying by random ±1 before WHT ensures the transform is a random rotation, +// not just a fixed permutation. This breaks any structure in the input. 
+static const int8_t tq3_signs[32] = { + +1, -1, +1, +1, -1, -1, +1, -1, +1, +1, -1, +1, -1, +1, -1, -1, + +1, -1, -1, +1, +1, -1, +1, -1, -1, +1, +1, +1, -1, -1, +1, -1 +}; + +// In-place normalized Walsh-Hadamard Transform for 32 values +// After transform, output has same L2 norm as input (due to 1/sqrt(32) normalization) +static void tq3_wht32_forward(float * x) { + // Apply sign flips (preconditioning) + for (int j = 0; j < 32; j++) { + x[j] *= tq3_signs[j]; + } + + // Butterfly stages (log2(32) = 5 stages) + for (int step = 1; step < 32; step <<= 1) { + for (int i = 0; i < 32; i += step * 2) { + for (int j = i; j < i + step; j++) { + float a = x[j]; + float b = x[j + step]; + x[j] = a + b; + x[j + step] = a - b; + } + } + } + + // Normalize to preserve L2 norm: divide by sqrt(32) + const float inv_sqrt32 = 0.17677669529663688f; // 1/sqrt(32) + for (int j = 0; j < 32; j++) { + x[j] *= inv_sqrt32; + } +} + +// In-place inverse normalized Walsh-Hadamard Transform for 32 values +// Inverse of normalized WHT: apply WHT, then undo sign flips, then normalize +static void tq3_wht32_inverse(float * x) { + // Butterfly stages (same as forward — WHT is self-adjoint) + for (int step = 1; step < 32; step <<= 1) { + for (int i = 0; i < 32; i += step * 2) { + for (int j = i; j < i + step; j++) { + float a = x[j]; + float b = x[j + step]; + x[j] = a + b; + x[j + step] = a - b; + } + } + } + + // Normalize by 1/sqrt(32) and undo sign flips + const float inv_sqrt32 = 0.17677669529663688f; + for (int j = 0; j < 32; j++) { + x[j] *= inv_sqrt32 * tq3_signs[j]; + } +} + +void quantize_row_tq3_0_ref(const float * GGML_RESTRICT x, block_tq3_0 * GGML_RESTRICT y, int64_t k) { + assert(k % QK_TQ3_0 == 0); + const int64_t nb = k / QK_TQ3_0; + + for (int64_t i = 0; i < nb; i++) { + const float * xb = x + i * QK_TQ3_0; + + // --- Step 0: Apply WHT rotation (makes distribution ~Gaussian) --- + float rotated[QK_TQ3_0]; + for (int j = 0; j < QK_TQ3_0; j++) rotated[j] = xb[j]; + 
tq3_wht32_forward(rotated); + + // --- Step 1: Find per-block scale (amax / outermost centroid) --- + float amax = 0.0f; + for (int j = 0; j < QK_TQ3_0; j++) { + float av = fabsf(rotated[j]); + if (av > amax) amax = av; + } + + const float d = amax / 1.510f; + const float id = d > 0.0f ? 1.0f / d : 0.0f; + + y[i].gamma = GGML_FP32_TO_FP16(d); + + // --- Step 2: 2-bit scalar quantize each rotated value --- + memset(y[i].qs, 0, sizeof(y[i].qs)); + memset(y[i].qr, 0, sizeof(y[i].qr)); + + float residuals[QK_TQ3_0]; + + for (int j = 0; j < QK_TQ3_0; j++) { + float xn = rotated[j] * id; + + int idx; + if (xn < 0.0f) { + idx = (xn < -0.9814f) ? 0 : 1; + } else { + idx = (xn < 0.9814f) ? 2 : 3; + } + + y[i].qs[j / 4] |= (idx << (2 * (j % 4))); + residuals[j] = rotated[j] - d * tq3_centroids[idx]; + } + + // --- Step 3: QJL signs = sign(residual) --- + for (int j = 0; j < QK_TQ3_0; j++) { + if (residuals[j] >= 0.0f) { + y[i].qr[j / 8] |= (1 << (j % 8)); + } + } + } +} + +void dequantize_row_tq3_0(const block_tq3_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_TQ3_0 == 0); + const int64_t nb = k / QK_TQ3_0; + + for (int64_t i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].gamma); + + // Dequantize to rotated space + float rotated[QK_TQ3_0]; + for (int j = 0; j < QK_TQ3_0; j++) { + const int idx = (x[i].qs[j / 4] >> (2 * (j % 4))) & 3; + rotated[j] = d * tq3_centroids[idx]; + } + + // Apply inverse WHT to get back to original space + tq3_wht32_inverse(rotated); + + for (int j = 0; j < QK_TQ3_0; j++) { + y[i * QK_TQ3_0 + j] = rotated[j]; + } + } +} + +size_t quantize_tq3_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + (void)quant_weights; // not used + const size_t row_size = ggml_row_size(GGML_TYPE_TQ3_0, n_per_row); + quantize_row_tq3_0_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * row_size; +} + // ====================== "True" 2-bit 
(de)-quantization void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { @@ -5412,6 +5572,15 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_tq2_0, data, nb); } break; + case GGML_TYPE_TQ3_0: + { + const block_tq3_0 * q = (const block_tq3_0 *) data; + for (size_t i = 0; i < nb; ++i) { + if (!validate_fp16(q[i].gamma, i)) { + return false; + } + } + } break; case GGML_TYPE_IQ1_S: { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb); diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index ca637667468..ab37134bd0b 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -34,6 +34,7 @@ GGML_API void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_API void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_tq2_0_ref(const float * GGML_RESTRICT x, block_tq2_0 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_tq3_0_ref(const float * GGML_RESTRICT x, block_tq3_0 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_iq4_nl_ref (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k); @@ -62,6 +63,7 @@ GGML_API void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GG GGML_API void dequantize_row_tq1_0(const block_tq1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_tq2_0(const block_tq2_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_tq3_0(const block_tq3_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float 
* GGML_RESTRICT y, int64_t k); @@ -86,6 +88,7 @@ GGML_API size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RE GGML_API size_t quantize_tq1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_tq2_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_tq3_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 0b7edba32f9..e4cd9f143a8 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -912,6 +912,14 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .type_size = 0, .is_quantized = false, }, + [GGML_TYPE_TQ3_0] = { + .type_name = "tq3_0", + .blck_size = QK_TQ3_0, + .type_size = sizeof(block_tq3_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_tq3_0, + .from_float_ref = (ggml_from_float_t) quantize_row_tq3_0_ref, + }, }; const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) { @@ -7617,6 +7625,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_TQ1_0: result = quantize_tq1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_TQ2_0: result = quantize_tq2_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_TQ3_0: result = quantize_tq3_0(src + start, (char *) dst + start_row * 
row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 98d055d34ef..16fa9373fd6 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2838,6 +2838,12 @@ llama_context * llama_init_from_model( } } + // TQ3_0 K cache has no flash attention kernel support - force off + if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && params.type_k == GGML_TYPE_TQ3_0) { + LLAMA_LOG_WARN("%s: flash_attn is not supported with TQ3_0 K cache - forcing off\n", __func__); + params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; + } + if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) { const uint32_t blck_size = ggml_blck_size(params.type_v); if (model->hparams.n_embd_head_v % blck_size != 0) { diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp index 7da6c3957c7..bb385692693 100644 --- a/tools/llama-bench/llama-bench.cpp +++ b/tools/llama-bench/llama-bench.cpp @@ -493,6 +493,9 @@ static ggml_type ggml_type_from_name(const std::string & s) { if (s == "iq4_nl") { return GGML_TYPE_IQ4_NL; } + if (s == "tq3_0") { + return GGML_TYPE_TQ3_0; + } return GGML_TYPE_COUNT; }