From 5bebc0a6e21c7ffc05006ff69b47c5c91cd0db31 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 26 Apr 2023 10:33:57 +0300 Subject: [PATCH 01/10] ggml : add Q5_0 quantization (cuBLAS only) --- .gitignore | 1 + ggml-cuda.cu | 43 ++++++++ ggml-cuda.h | 1 + ggml.c | 277 ++++++++++++++++++++++++++++++++++++++++++++++++++- ggml.h | 6 +- llama.cpp | 4 + llama.h | 1 + 7 files changed, 327 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index e52d479eeafa8..c7573bb3b93c4 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ build-em/ build-debug/ build-release/ build-static/ +build-cublas/ build-no-accel/ build-sanitize-addr/ build-sanitize-thread/ diff --git a/ggml-cuda.cu b/ggml-cuda.cu index f104ed5ac42dd..ac34fdd364102 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -37,6 +37,15 @@ typedef struct { } block_q4_3; static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); +#define QK5_0 32 +typedef struct { + __half d; // delta + __half m; // min + int32_t qh; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == 2 * sizeof(ggml_fp16_t) + sizeof(int32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + #define QK8_0 32 typedef struct { float d; // delta @@ -138,6 +147,35 @@ static __global__ void dequantize_block_q4_3(const void * vx, float * y) { } } +static __global__ void dequantize_block_q5_0(const void * vx, float * y) { + const block_q5_0 * x = (const block_q5_0 *) vx; + + const int i = blockIdx.x; + + const float d = x[i].d; + const float m = x[i].m; + + const uint8_t * pp = x[i].qs; + + const uint32_t qh = x[i].qh; + + for (int l = 0; l < QK5_0; l += 2) { + const uint8_t vi = pp[l/2]; + + const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + + const int8_t vi0 = (vi & 0xf) | vh0; + const int8_t vi1 = (vi >> 4) | vh1; + + const float v0 = vi0*d + m; + const float v1 = vi1*d + m; + + y[i*QK5_0 + l + 0] = v0; + y[i*QK5_0 + l + 1] = v1; + } +} + static __global__ void dequantize_block_q8_0(const void * vx, float * y) { const block_q8_0 * x = (const block_q8_0 *) vx; @@ -174,6 +212,11 @@ void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t st dequantize_block_q4_3<<>>(vx, y); } +void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { + const int nb = k / QK5_0; + dequantize_block_q5_0<<>>(vx, y); +} + void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK8_0; dequantize_block_q8_0<<>>(vx, y); diff --git a/ggml-cuda.h b/ggml-cuda.h index 4048ea4919321..eb5f23a9dc558 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -35,6 +35,7 @@ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t st void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream); +void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); #ifdef __cplusplus diff --git a/ggml.c b/ggml.c index 064510edaa798..897cd5d588889 100644 --- a/ggml.c +++ b/ggml.c @@ -673,6 +673,15 @@ typedef struct { } block_q4_3; static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); +#define QK5_0 32 +typedef struct { + ggml_fp16_t d; // delta + ggml_fp16_t m; // min + int32_t qh; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == 2 * sizeof(ggml_fp16_t) + sizeof(int32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + #define QK8_0 32 typedef struct { float d; // delta @@ -1288,6 +1297,50 @@ static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int quantize_row_q4_3_reference(x, y, k); } +static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) { + assert(k % QK5_0 == 0); + const int nb = k / QK5_0; + + for (int i = 0; i < nb; i++) { + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int l = 0; l < QK5_0; l++) { + const float v = x[i*QK5_0 + l]; + if (v < min) min = v; + if (v > max) max = v; + } + + const float d = (max - min) / ((1 << 5) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + y[i].m = GGML_FP32_TO_FP16(min); + y[i].qh = 0; + + for (int l = 0; l < QK5_0; l += 2) { + const float v0 = (x[i*QK5_0 + l + 0] - min)*id; + const float v1 = (x[i*QK5_0 + l + 1] - min)*id; + + const uint32_t vi0 = (int) (v0 + 0.5f); + const uint32_t vi1 = (int) (v1 + 0.5f); + + y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4); + + y[i].qh |= ((vi0 & 0x10) >> 4) << (l + 0); + y[i].qh |= ((vi1 & 0x10) >> 4) << (l + 1); + } + } +} + +static void quantize_row_q5_0(const float * restrict x, void * restrict vy, int k) { + assert(k % QK5_0 == 0); + + block_q5_0 * restrict y = vy; + + quantize_row_q5_0_reference(x, y, k); +} + // reference implementation for deterministic creation of model files static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) { assert(k % QK8_0 == 0); @@ -1804,6 +1857,41 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in } } +static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, int k) { + assert(k % QK5_0 == 0); + const int nb = k / QK5_0; + + const block_q5_0 * restrict x = vx; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + const float m = GGML_FP16_TO_FP32(x[i].m); + + const uint8_t * restrict pp = x[i].qs; + + const uint32_t qh = x[i].qh; + + for (int l = 0; l < QK5_0; l += 2) { + const uint8_t vi = pp[l/2]; + + const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + + const int8_t vi0 = (vi & 0xf) | vh0; + const int8_t vi1 = (vi >> 4) | vh1; + + const float v0 = vi0*d + m; + const float v1 = vi1*d + m; + + y[i*QK5_0 + l + 0] = v0; + y[i*QK5_0 + l + 1] = v1; + + assert(!isnan(y[i*QK5_0 + l + 0])); + assert(!isnan(y[i*QK5_0 + l + 1])); + } + } +} + static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) { assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -1825,6 +1913,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { @@ -1860,6 +1949,14 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .vec_dot_q = ggml_vec_dot_q4_3_q8_1, .vec_dot_type = GGML_TYPE_Q8_1, }, + [GGML_TYPE_Q5_0] = { + .dequantize_row_q = dequantize_row_q5_0, + .quantize_row_q = quantize_row_q5_0, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference, + .quantize_row_q_dot = quantize_row_q8_1, + .vec_dot_q = ggml_vec_dot_q5_0_q8_1, + .vec_dot_type = GGML_TYPE_Q8_1, + }, [GGML_TYPE_Q8_0] = { .dequantize_row_q = dequantize_row_q8_0, .quantize_row_q = quantize_row_q8_0, @@ -3067,6 +3164,138 @@ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * #endif } +static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + GGML_ASSERT(false); // TODO xxxxxxxxx + + const int nb = n / QK8_1; + + assert(n % QK8_1 == 0); + assert(nb % 2 == 0); + assert(QK8_1 == 2*QK5_0); + + const block_q5_0 * restrict x = vx; + const block_q8_1 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + float summs0 = 0.0f; + float summs1 = 0.0f; + + for (int i = 0; i < nb; ++i) { + const block_q5_0 * restrict x0_0 = &x[2*(i + 0) + 0]; + const block_q5_0 * restrict x0_1 = &x[2*(i + 0) + 1]; + + const block_q8_1 * restrict y0 = &y[i + 0]; + + summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0; + summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1; + + const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs)); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0xf))); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + + // interleave + const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h); + const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + + const float x0_0d = GGML_FP16_TO_FP32(x0_0->d); + const float x0_1d = GGML_FP16_TO_FP32(x0_1->d); + +#if defined(__ARM_FEATURE_DOTPROD) + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d); +#endif + } + + *s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1; +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + float summs = 0.0f; + + // Main loop + for (int i = 0; i < nb; i++) { + const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d)); + const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d)); + const __m256 dx = _mm256_set_m128(d1, d0); + + summs += GGML_FP16_TO_FP32(x[2*i + 0].m) * y[i].s0 + + GGML_FP16_TO_FP32(x[2*i + 1].m) * y[i].s1; + + const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs); + const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs); + const __m256i bx = _mm256_set_m128i(bx1, bx0); + + const __m256 dy = _mm256_broadcast_ss(&y[i].d); + const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); + } + + *s = hsum_float_8(acc) + summs; +#else + // scalar + float sumf = 0.0; + for (int i = 0; i < nb; i++) { + const uint8_t * restrict x0 = x[2*i + 0].qs; + const uint8_t * restrict x1 = x[2*i + 1].qs; + const int8_t * restrict y0 = y[i].qs; + + const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d); + const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m); + const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d); + const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m); + + int sxy_0 = 0; + int sxy_1 = 0; + + for (int j = 0; j < QK8_1/4; j++) { + const uint8_t v0 = x0[j]; + const uint8_t v1 = x1[j]; + + const int x0_0 = v0 & 0xf; + const int x1_0 = v0 >> 4; + + const int x0_1 = v1 & 0xf; + const int x1_1 = v1 >> 4; + + const int y0_0 = y0[2*j + 0]; + const int y1_0 = y0[2*j + 1]; + + const int y0_1 = y0[2*(j + QK8_1/4) + 0]; + const int y1_1 = y0[2*(j + QK8_1/4) + 1]; + + sxy_0 += x0_0*y0_0 + x1_0*y1_0; + sxy_1 += x0_1*y0_1 + x1_1*y1_1; + } + + sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1; + } + *s = sumf; +#endif +} + static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const int nb = n / QK8_0; @@ -3409,13 +3638,14 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = QK4_1, [GGML_TYPE_Q4_2] = QK4_2, [GGML_TYPE_Q4_3] = QK4_3, + [GGML_TYPE_Q5_0] = QK5_0, [GGML_TYPE_Q8_0] = QK8_0, [GGML_TYPE_Q8_1] = QK8_1, [GGML_TYPE_I8] = 1, [GGML_TYPE_I16] = 1, [GGML_TYPE_I32] = 1, }; -static_assert(GGML_TYPE_COUNT == 11, "GGML_BLCK_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 12, "GGML_BLCK_SIZE is outdated"); static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = sizeof(float), @@ -3424,13 +3654,14 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = sizeof(block_q4_1), [GGML_TYPE_Q4_2] = sizeof(block_q4_2), [GGML_TYPE_Q4_3] = sizeof(block_q4_3), + [GGML_TYPE_Q5_0] = sizeof(block_q5_0), [GGML_TYPE_Q8_0] = sizeof(block_q8_0), [GGML_TYPE_Q8_1] = sizeof(block_q8_1), [GGML_TYPE_I8] = sizeof(int8_t), [GGML_TYPE_I16] = sizeof(int16_t), [GGML_TYPE_I32] = sizeof(int32_t), }; -static_assert(GGML_TYPE_COUNT == 11, "GGML_TYPE_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 12, "GGML_TYPE_SIZE is outdated"); static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { @@ -3440,13 +3671,14 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = "q4_1", [GGML_TYPE_Q4_2] = "q4_2", [GGML_TYPE_Q4_3] = "q4_3", + [GGML_TYPE_Q5_0] = "q5_0", [GGML_TYPE_Q8_0] = "q8_0", [GGML_TYPE_Q8_1] = "q8_1", [GGML_TYPE_I8] = "i8", [GGML_TYPE_I16] = "i16", [GGML_TYPE_I32] = "i32", }; -static_assert(GGML_TYPE_COUNT == 11, "GGML_TYPE_NAME is outdated"); +static_assert(GGML_TYPE_COUNT == 12, "GGML_TYPE_NAME is outdated"); static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = false, @@ -3455,13 +3687,14 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = true, [GGML_TYPE_Q4_2] = true, [GGML_TYPE_Q4_3] = true, + [GGML_TYPE_Q5_0] = true, [GGML_TYPE_Q8_0] = true, [GGML_TYPE_Q8_1] = true, [GGML_TYPE_I8] = false, [GGML_TYPE_I16] = false, [GGML_TYPE_I32] = false, }; -static_assert(GGML_TYPE_COUNT == 11, "GGML_IS_QUANTIZED is outdated"); +static_assert(GGML_TYPE_COUNT == 12, "GGML_IS_QUANTIZED is outdated"); static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "NONE", @@ -6673,6 +6906,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: case GGML_TYPE_Q8_0: { ggml_compute_forward_add_q_f32(params, src0, src1, dst); @@ -8161,6 +8395,9 @@ static void ggml_compute_forward_mul_mat_q_f32( else if (type == GGML_TYPE_Q4_3) { dequantize_row_q_cuda = dequantize_row_q4_3_cuda; } + else if (type == GGML_TYPE_Q5_0) { + dequantize_row_q_cuda = dequantize_row_q5_0_cuda; + } else if (type == GGML_TYPE_Q8_0) { dequantize_row_q_cuda = dequantize_row_q8_0_cuda; } @@ -8319,6 +8556,7 @@ static void ggml_compute_forward_mul_mat( case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: { @@ -8549,6 +8787,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: { @@ -12342,6 +12581,30 @@ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * return (n/QK4_3*sizeof(block_q4_3)); } +size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK5_0 == 0); + const int nb = k / QK5_0; + + for (int j = 0; j < n; j += k) { + block_q5_0 * restrict y = (block_q5_0 *)dst + j/QK5_0; + + quantize_row_q5_0_reference(src + j, y, k); + + // TODO: this is wrong + for (int i = 0; i < nb; i++) { + for (int l = 0; l < QK5_0; l += 2) { + const uint8_t vi0 = y[i].qs[l/2] & 0xF; + const uint8_t vi1 = y[i].qs[l/2] >> 4; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n/QK5_0*sizeof(block_q5_0)); +} + size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) { assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -12390,6 +12653,12 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i block_q4_3 * block = (block_q4_3*)dst + start / QK4_3; result = ggml_quantize_q4_3(src + start, block, n, n, hist); } break; + case GGML_TYPE_Q5_0: + { + GGML_ASSERT(start % QK5_0 == 0); + block_q5_0 * block = (block_q5_0*)dst + start / QK5_0; + result = ggml_quantize_q5_0(src + start, block, n, n, hist); + } break; case GGML_TYPE_Q8_0: { GGML_ASSERT(start % QK8_0 == 0); diff --git a/ggml.h b/ggml.h index 8300a0c62db9b..154c0eb4b2941 100644 --- a/ggml.h +++ b/ggml.h @@ -222,8 +222,9 @@ extern "C" { GGML_TYPE_Q4_1 = 3, GGML_TYPE_Q4_2 = 4, GGML_TYPE_Q4_3 = 5, - GGML_TYPE_Q8_0 = 6, - GGML_TYPE_Q8_1 = 7, + GGML_TYPE_Q5_0 = 6, + GGML_TYPE_Q8_0 = 7, + GGML_TYPE_Q8_1 = 8, GGML_TYPE_I8, GGML_TYPE_I16, GGML_TYPE_I32, @@ -833,6 +834,7 @@ extern "C" { GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist); diff --git a/llama.cpp b/llama.cpp index 25203c9e90b28..a82ce30978cbc 100644 --- a/llama.cpp +++ b/llama.cpp @@ -484,6 +484,7 @@ struct llama_file_loader { case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: case GGML_TYPE_Q8_0: break; default: { @@ -559,6 +560,7 @@ struct llama_file_saver { case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: case GGML_TYPE_Q8_0: break; default: LLAMA_ASSERT(false); @@ -850,6 +852,7 @@ static const char *llama_ftype_name(enum llama_ftype ftype) { return "mostly Q4_1, some F16"; case LLAMA_FTYPE_MOSTLY_Q4_2: return "mostly Q4_2"; case LLAMA_FTYPE_MOSTLY_Q4_3: return "mostly Q4_3"; + case LLAMA_FTYPE_MOSTLY_Q5_0: return "mostly Q5_0"; case LLAMA_FTYPE_MOSTLY_Q8_0: return "mostly Q8_0"; default: return "unknown, may not work"; } @@ -1588,6 +1591,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q4_1: quantized_type = GGML_TYPE_Q4_1; break; case LLAMA_FTYPE_MOSTLY_Q4_2: quantized_type = GGML_TYPE_Q4_2; break; case LLAMA_FTYPE_MOSTLY_Q4_3: quantized_type = GGML_TYPE_Q4_3; break; + case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break; case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break; default: throw format("invalid output file type %d\n", ftype); }; diff --git a/llama.h b/llama.h index ab41798d8b712..ea3fe18d54cb6 100644 --- a/llama.h +++ b/llama.h @@ -75,6 +75,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors }; LLAMA_API struct llama_context_params llama_context_default_params(); From 2576c16f00b5a1d22405aa6b35e86e9bbe078f7f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 26 Apr 2023 10:43:26 +0300 Subject: [PATCH 02/10] ggml : fix Q5_0 qh -> uint32_t --- ggml-cuda.cu | 4 ++-- ggml.c | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index ac34fdd364102..e066cc9d07dc8 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -41,10 +41,10 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong typedef struct { __half d; // delta __half m; // min - int32_t qh; // 5-th bit of quants + uint32_t qh; // 5-th bit of quants uint8_t qs[QK5_0 / 2]; // nibbles / quants } block_q5_0; -static_assert(sizeof(block_q5_0) == 2 * sizeof(ggml_fp16_t) + sizeof(int32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); +static_assert(sizeof(block_q5_0) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); #define QK8_0 32 typedef struct { diff --git a/ggml.c b/ggml.c index 897cd5d588889..1b1fa717a2d25 100644 --- a/ggml.c +++ b/ggml.c @@ -677,10 +677,10 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong typedef struct { ggml_fp16_t d; // delta ggml_fp16_t m; // min - int32_t qh; // 5-th bit of quants + uint32_t qh; // 5-th bit of quants uint8_t qs[QK5_0 / 2]; // nibbles / quants } block_q5_0; -static_assert(sizeof(block_q5_0) == 2 * sizeof(ggml_fp16_t) + sizeof(int32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); +static_assert(sizeof(block_q5_0) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); #define QK8_0 32 typedef struct { From 99238e4c28addaa7dfe18f004129037424313cf1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 26 Apr 2023 13:37:57 +0300 Subject: [PATCH 03/10] ggml : fix q5_0 histogram stats --- ggml.c | 73 +++++++++++++++++++++++++++++++--------------------------- 1 file changed, 39 insertions(+), 34 deletions(-) diff --git a/ggml.c b/ggml.c index 1b1fa717a2d25..423b95952009b 100644 --- a/ggml.c +++ b/ggml.c @@ -1327,6 +1327,7 @@ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * r y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4); + // get the 5-th bit and store it in qh at the right position y[i].qh |= ((vi0 & 0x10) >> 4) << (l + 0); y[i].qh |= ((vi1 & 0x10) >> 4) << (l + 1); } @@ -1624,7 +1625,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in const uint8x8_t v8 = vld1_u8(pp + l/2); // Expand 4-bit qs to 8-bit bytes - const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f)); + const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F)); const uint8x8_t v1 = vshr_n_u8(v8, 4); // Convert to signed 8-bit integers @@ -1674,7 +1675,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in for (int l = 0; l < QK4_0; l += 2) { const uint8_t vi = pp[l/2]; - const int8_t vi0 = vi & 0xf; + const int8_t vi0 = vi & 0x0F; const int8_t vi1 = vi >> 4; const float v0 = (vi0 - 8)*d; @@ -1740,7 +1741,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in const uint8x8_t v8 = vld1_u8(pp + l/2); // Expand 4-bit qs to 8-bit bytes - const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f)); + const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F)); const uint8x8_t v1 = vshr_n_u8(v8, 4); // Interleave and combine @@ -1782,7 +1783,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in for (int l = 0; l < QK4_1; l += 2) { const uint8_t vi = pp[l/2]; - const int8_t vi0 = vi & 0xf; + const int8_t vi0 = vi & 0x0F; const int8_t vi1 = vi >> 4; const float v0 = vi0*d + m; @@ -1812,7 +1813,7 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in for (int l = 0; l < QK4_2; l += 2) { const uint8_t vi = pp[l/2]; - const int8_t vi0 = vi & 0xf; + const int8_t vi0 = vi & 0x0F; const int8_t vi1 = vi >> 4; const float v0 = (vi0 - 8)*d; @@ -1842,7 +1843,7 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in for (int l = 0; l < QK4_3; l += 2) { const uint8_t vi = pp[l/2]; - const int8_t vi0 = vi & 0xf; + const int8_t vi0 = vi & 0x0F; const int8_t vi1 = vi >> 4; const float v0 = vi0*d + m; @@ -1874,11 +1875,12 @@ static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, in for (int l = 0; l < QK5_0; l += 2) { const uint8_t vi = pp[l/2]; - const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; - const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + // extract the 5-th bit from qh + const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; - const int8_t vi0 = (vi & 0xf) | vh0; - const int8_t vi1 = (vi >> 4) | vh1; + const uint8_t vi0 = (vi & 0x0F) | vh0; + const uint8_t vi1 = (vi >> 4) | vh1; const float v0 = vi0*d + m; const float v1 = vi1*d + m; @@ -2593,7 +2595,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const block_q8_0 * restrict y0 = &y[i + 0]; const block_q8_0 * restrict y1 = &y[i + 1]; - const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint8x16_t m4b = vdupq_n_u8(0x0F); const int8x16_t s8b = vdupq_n_s8(0x8); const uint8x16_t v0_0 = vld1q_u8(x0->qs); @@ -2729,8 +2731,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * for (int j = 0; j < QK8_0/2; j++) { const uint8_t v0 = p0[j]; - const int i0 = (int8_t) (v0 & 0xf) - 8; - const int i1 = (int8_t) (v0 >> 4) - 8; + const int i0 = (int8_t) (v0 & 0x0F) - 8; + const int i1 = (int8_t) (v0 >> 4) - 8; const int i2 = p1[2*j + 0]; const int i3 = p1[2*j + 1]; @@ -2767,7 +2769,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1); - const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint8x16_t m4b = vdupq_n_u8(0x0F); const uint8x16_t v0_0 = vld1q_u8(x0->qs); const uint8x16_t v0_1 = vld1q_u8(x1->qs); @@ -2864,8 +2866,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * for (int j = 0; j < QK8_1/2; j++) { const uint8_t v0 = p0[j]; - const float f0 = d0*(v0 & 0xf) + m0; - const float f1 = d0*(v0 >> 4) + m0; + const float f0 = d0*(v0 & 0x0F) + m0; + const float f1 = d0*(v0 >> 4) + m0; const float f2 = d1*p1[2*j + 0]; const float f3 = d1*p1[2*j + 1]; @@ -2900,7 +2902,7 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * const block_q8_0 * restrict y0 = &y[i + 0]; const block_q8_0 * restrict y1 = &y[i + 1]; - const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint8x16_t m4b = vdupq_n_u8(0x0F); const int8x16_t s8b = vdupq_n_s8(0x8); const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs)); @@ -3011,11 +3013,11 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * const uint8_t v0 = x0[j]; const uint8_t v1 = x1[j]; - const int i0_0 = (int8_t) (v0 & 0xf) - 8; - const int i1_0 = (int8_t) (v0 >> 4) - 8; + const int i0_0 = (int8_t) (v0 & 0x0F) - 8; + const int i1_0 = (int8_t) (v0 >> 4) - 8; - const int i0_1 = (int8_t) (v1 & 0xf) - 8; - const int i1_1 = (int8_t) (v1 >> 4) - 8; + const int i0_1 = (int8_t) (v1 & 0x0F) - 8; + const int i1_1 = (int8_t) (v1 >> 4) - 8; const int i2_0 = y0[2*j + 0]; const int i3_0 = y0[2*j + 1]; @@ -3063,7 +3065,7 @@ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs)); // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0xf))); + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0x0F))); const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); // interleave @@ -3142,10 +3144,10 @@ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * const uint8_t v0 = x0[j]; const uint8_t v1 = x1[j]; - const int x0_0 = v0 & 0xf; + const int x0_0 = v0 & 0x0F; const int x1_0 = v0 >> 4; - const int x0_1 = v1 & 0xf; + const int x0_1 = v1 & 0x0F; const int x1_1 = v1 >> 4; const int y0_0 = y0[2*j + 0]; @@ -3195,7 +3197,7 @@ static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void * const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs)); // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0xf))); + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0x0F))); const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); // interleave @@ -3274,10 +3276,10 @@ static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void * const uint8_t v0 = x0[j]; const uint8_t v1 = x1[j]; - const int x0_0 = v0 & 0xf; + const int x0_0 = v0 & 0x0F; const int x1_0 = v0 >> 4; - const int x0_1 = v1 & 0xf; + const int x0_1 = v1 & 0x0F; const int x1_1 = v1 >> 4; const int y0_0 = y0[2*j + 0]; @@ -12500,7 +12502,7 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * for (int i = 0; i < nb; i++) { for (int l = 0; l < QK4_0; l += 2) { - const uint8_t vi0 = y[i].qs[l/2] & 0xF; + const uint8_t vi0 = y[i].qs[l/2] & 0x0F; const uint8_t vi1 = y[i].qs[l/2] >> 4; hist[vi0]++; @@ -12523,7 +12525,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * for (int i = 0; i < nb; i++) { for (int l = 0; l < QK4_1; l += 2) { - const uint8_t vi0 = y[i].qs[l/2] & 0xF; + const uint8_t vi0 = y[i].qs[l/2] & 0x0F; const uint8_t vi1 = y[i].qs[l/2] >> 4; hist[vi0]++; @@ -12546,7 +12548,7 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * for (int i = 0; i < nb; i++) { for (int l = 0; l < QK4_2; l += 2) { - const uint8_t vi0 = y[i].qs[l/2] & 0xF; + const uint8_t vi0 = y[i].qs[l/2] & 0x0F; const uint8_t vi1 = y[i].qs[l/2] >> 4; hist[vi0]++; @@ -12569,7 +12571,7 @@ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * for (int i = 0; i < nb; i++) { for (int l = 0; l < QK4_3; l += 2) { - const uint8_t vi0 = y[i].qs[l/2] & 0xF; + const uint8_t vi0 = y[i].qs[l/2] & 0x0F; const uint8_t vi1 = y[i].qs[l/2] >> 4; hist[vi0]++; @@ -12590,11 +12592,14 @@ size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * quantize_row_q5_0_reference(src + j, y, k); - // TODO: this is wrong for (int i = 0; i < nb; i++) { for (int l = 0; l < QK5_0; l += 2) { - const uint8_t vi0 = y[i].qs[l/2] & 0xF; - const uint8_t vi1 = y[i].qs[l/2] >> 4; + const uint8_t vh0 = ((y[i].qh & (1 << (l + 0))) >> (l + 0)) << 4; + const uint8_t vh1 = ((y[i].qh & (1 << (l + 1))) >> (l + 1)) << 4; + + // cast to 16 bins + const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2; + const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2; hist[vi0]++; hist[vi1]++; From ef8e3ee6f5efa8067486a0c2b5ffa35d5900a6b7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 26 Apr 2023 13:58:47 +0300 Subject: [PATCH 04/10] ggml : q5_0 scalar dot product --- ggml.c | 42 +++++++++++++++++------------------------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/ggml.c b/ggml.c index 423b95952009b..f4bc9db571704 100644 --- a/ggml.c +++ b/ggml.c @@ -3167,18 +3167,16 @@ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * } static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - GGML_ASSERT(false); // TODO xxxxxxxxx - const int nb = n / QK8_1; assert(n % QK8_1 == 0); assert(nb % 2 == 0); - assert(QK8_1 == 2*QK5_0); + assert(QK8_1 == QK5_0); const block_q5_0 * restrict x = vx; const block_q8_1 * restrict y = vy; -#if defined(__ARM_NEON) +#if defined(__ARM_NEON_XXX) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); @@ -3257,43 +3255,37 @@ static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void * *s = hsum_float_8(acc) + summs; #else - // scalar float sumf = 0.0; + for (int i = 0; i < nb; i++) { - const uint8_t * restrict x0 = x[2*i + 0].qs; - const uint8_t * restrict x1 = x[2*i + 1].qs; + const uint8_t * restrict x0 = x[i].qs; const int8_t * restrict y0 = y[i].qs; - const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d); - const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m); - const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d); - const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m); + const uint32_t qh = x[i].qh; - int sxy_0 = 0; - int sxy_1 = 0; + const float d = GGML_FP16_TO_FP32(x[i].d); + const float m = GGML_FP16_TO_FP32(x[i].m); - for (int j = 0; j < QK8_1/4; j++) { + int sxy = 0; + + for (int j = 0; j < QK8_1/2; j++) { const uint8_t v0 = x0[j]; - const uint8_t v1 = x1[j]; - const int x0_0 = v0 & 0x0F; - const int x1_0 = v0 >> 4; + const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4; + const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4; - const int x0_1 = v1 & 0x0F; - const int x1_1 = v1 >> 4; + const int x0_0 = (v0 & 0x0F) | x0_0h; + const int x1_0 = (v0 >> 4) | x1_0h; const int y0_0 = y0[2*j + 0]; const int y1_0 = y0[2*j + 1]; - const int y0_1 = y0[2*(j + QK8_1/4) + 0]; - const int y1_1 = y0[2*(j + QK8_1/4) + 1]; - - sxy_0 += x0_0*y0_0 + x1_0*y1_0; - sxy_1 += x0_1*y0_1 + x1_1*y1_1; + sxy += x0_0*y0_0 + x1_0*y1_0; } - sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1; + sumf += (d*sxy)*y[i].d + m*(y[i].s0 + y[i].s1); } + *s = sumf; #endif } From b294b7fdc03c28d08bd28eb35a710422719c830b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 26 Apr 2023 16:24:27 +0300 Subject: [PATCH 05/10] ggml : q5_0 ARM NEON dot --- ggml.c | 78 +++++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 50 insertions(+), 28 deletions(-) diff --git a/ggml.c b/ggml.c index f4bc9db571704..6abd1cf90e433 100644 --- a/ggml.c +++ b/ggml.c @@ -3176,57 +3176,79 @@ static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void * const block_q5_0 * restrict x = vx; const block_q8_1 * restrict y = vy; -#if defined(__ARM_NEON_XXX) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); +#if defined(__ARM_NEON) + float32x4_t sumv = vdupq_n_f32(0.0f); - float summs0 = 0.0f; - float summs1 = 0.0f; + float summs = 0.0f; + + uint32_t tmp[8]; + + static const uint32_t k_mask[16] = { + 0x00000000, 0x00000010, 0x00001000, 0x00001010, + 0x00100000, 0x00100010, 0x00101000, 0x00101010, + 0x10000000, 0x10000010, 0x10001000, 0x10001010, + 0x10100000, 0x10100010, 0x10101000, 0x10101010, + }; for (int i = 0; i < nb; ++i) { - const block_q5_0 * restrict x0_0 = &x[2*(i + 0) + 0]; - const block_q5_0 * restrict x0_1 = &x[2*(i + 0) + 1]; + const block_q5_0 * restrict x0 = &x[i]; + const block_q8_1 * restrict y0 = &y[i]; - const block_q8_1 * restrict y0 = &y[i + 0]; + summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1); - summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0; - summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1; + // extract the 5th bit + const uint32_t qh = x0->qh; - const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs)); + tmp[0] = k_mask[(qh >> 0) & 0x0F]; + tmp[1] = k_mask[(qh >> 4) & 0x0F]; + tmp[2] = k_mask[(qh >> 8) & 0x0F]; + tmp[3] = k_mask[(qh >> 12) & 0x0F]; + tmp[4] = k_mask[(qh >> 16) & 0x0F]; + tmp[5] = k_mask[(qh >> 20) & 0x0F]; + tmp[6] = k_mask[(qh >> 24) & 0x0F]; + tmp[7] = k_mask[(qh >> 28)]; + + const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0)); + const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 4)); + + const uint8x16_t v0 = vld1q_u8(x0->qs); // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0x0F))); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, vdupq_n_u8(0x0F))); + const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4)); // interleave - const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h); - const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h); + const int8x16_t v0lz = vzip1q_s8(v0l, v0h); + const int8x16_t v0hz = vzip2q_s8(v0l, v0h); + + // add + const int8x16_t v0lf = vorrq_s8(v0lz, qhl); + const int8x16_t v0hf = vorrq_s8(v0hz, qhh); // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1l = vld1q_s8(y0->qs); + const int8x16_t v1h = vld1q_s8(y0->qs + 16); - const float x0_0d = GGML_FP16_TO_FP32(x0_0->d); - const float x0_1d = GGML_FP16_TO_FP32(x0_1->d); + const float x0d = GGML_FP16_TO_FP32(x0->d); #if defined(__ARM_FEATURE_DOTPROD) - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d); + sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0lf, v1l), + vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d); #else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h)); + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h)); const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d); + sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d); #endif } - *s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1; + *s = vaddvq_f32(sumv) + summs; #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); From d390f4f7dd93c58b7e77a4e34aa447626a273d4a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 26 Apr 2023 16:32:33 +0300 Subject: [PATCH 06/10] ggml : q5_0 more efficient ARM NEON using uint64_t masks --- ggml.c | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/ggml.c b/ggml.c index 6abd1cf90e433..09f63deac4a03 100644 --- a/ggml.c +++ b/ggml.c @@ -328,6 +328,9 @@ static ggml_fp16_t table_exp_f16[1 << 16]; // precomputed f32 table for f16 (256 KB) static float table_f32_f16[1 << 16]; +// precomputed table for expanding 8bits to 8 bytes (shl 4) +static uint64_t table_b2b[1 << 8]; + // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON. // This is also true for POWER9. @@ -3181,14 +3184,7 @@ static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void * float summs = 0.0f; - uint32_t tmp[8]; - - static const uint32_t k_mask[16] = { - 0x00000000, 0x00000010, 0x00001000, 0x00001010, - 0x00100000, 0x00100010, 0x00101000, 0x00101010, - 0x10000000, 0x10000010, 0x10001000, 0x10001010, - 0x10100000, 0x10100010, 0x10101000, 0x10101010, - }; + uint64_t tmp[4]; for (int i = 0; i < nb; ++i) { const block_q5_0 * restrict x0 = &x[i]; @@ -3199,17 +3195,13 @@ static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void * // extract the 5th bit const uint32_t qh = x0->qh; - tmp[0] = k_mask[(qh >> 0) & 0x0F]; - tmp[1] = k_mask[(qh >> 4) & 0x0F]; - tmp[2] = k_mask[(qh >> 8) & 0x0F]; - tmp[3] = k_mask[(qh >> 12) & 0x0F]; - tmp[4] = k_mask[(qh >> 16) & 0x0F]; - tmp[5] = k_mask[(qh >> 20) & 0x0F]; - tmp[6] = k_mask[(qh >> 24) & 0x0F]; - tmp[7] = k_mask[(qh >> 28)]; + tmp[0] = table_b2b[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b[(qh >> 24) ]; const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0)); - const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 4)); + const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2)); const uint8x16_t v0 = vld1q_u8(x0->qs); @@ -4064,6 +4056,15 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f)); } + for (int i = 0; i < 256; ++i) { + table_b2b[i] = 0; + for (int b = 0; b < 8; ++b) { + table_b2b[i] |= ((uint64_t)(((i >> b) & 0x01) << 4)) << (8*b); + } + + //printf("%3d %016llx\n", i, table_b2b[i]); + } + const uint64_t t_end = ggml_time_us(); UNUSED(t_end); GGML_PRINT_DEBUG("%s: GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); From b9c43584f6b9366211c1179d8115499ca22b7bdb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 26 Apr 2023 17:57:26 +0300 Subject: [PATCH 07/10] ggml : rename Q5_0 -> Q5_1 --- ggml-cuda.cu | 24 +++++------ ggml-cuda.h | 2 +- ggml.c | 110 +++++++++++++++++++++++++-------------------------- ggml.h | 4 +- llama.cpp | 8 ++-- llama.h | 2 +- 6 files changed, 75 insertions(+), 75 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index e066cc9d07dc8..6d1cc70083910 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -37,14 +37,14 @@ typedef struct { } block_q4_3; static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); -#define QK5_0 32 +#define QK5_1 32 typedef struct { __half d; // delta __half m; // min uint32_t qh; // 5-th bit of quants - uint8_t qs[QK5_0 / 2]; // nibbles / quants -} block_q5_0; -static_assert(sizeof(block_q5_0) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; +static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); #define QK8_0 32 typedef struct { @@ -147,8 +147,8 @@ static __global__ void dequantize_block_q4_3(const void * vx, float * y) { } } -static __global__ void dequantize_block_q5_0(const void * vx, float * y) { - const block_q5_0 * x = (const block_q5_0 *) vx; +static __global__ void dequantize_block_q5_1(const void * vx, float * y) { + const block_q5_1 * x = (const block_q5_1 *) vx; const int i = blockIdx.x; @@ -159,7 +159,7 @@ static __global__ void dequantize_block_q5_0(const void * vx, float * y) { const uint32_t qh = x[i].qh; - for (int l = 0; l < QK5_0; l += 2) { + for (int l = 0; l < QK5_1; l += 2) { const uint8_t vi = pp[l/2]; const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; @@ -171,8 +171,8 @@ static __global__ void dequantize_block_q5_0(const void * vx, float * y) { const float v0 = vi0*d + m; const float v1 = vi1*d + m; - y[i*QK5_0 + l + 0] = v0; - y[i*QK5_0 + l + 1] = v1; + y[i*QK5_1 + l + 0] = v0; + y[i*QK5_1 + l + 1] = v1; } } @@ -212,9 +212,9 @@ void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t st dequantize_block_q4_3<<>>(vx, y); } -void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { - const int nb = k / QK5_0; - dequantize_block_q5_0<<>>(vx, y); +void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { + const int nb = k / QK5_1; + dequantize_block_q5_1<<>>(vx, y); } void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { diff --git a/ggml-cuda.h b/ggml-cuda.h index eb5f23a9dc558..348d9e907c837 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -35,7 +35,7 @@ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t st void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream); -void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); +void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); #ifdef __cplusplus diff --git a/ggml.c b/ggml.c index 09f63deac4a03..91afe62edaf89 100644 --- a/ggml.c +++ b/ggml.c @@ -676,14 +676,14 @@ typedef struct { } block_q4_3; static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); -#define QK5_0 32 +#define QK5_1 32 typedef struct { ggml_fp16_t d; // delta ggml_fp16_t m; // min uint32_t qh; // 5-th bit of quants - uint8_t qs[QK5_0 / 2]; // nibbles / quants -} block_q5_0; -static_assert(sizeof(block_q5_0) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; +static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); #define QK8_0 32 typedef struct { @@ -1300,16 +1300,16 @@ static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int quantize_row_q4_3_reference(x, y, k); } -static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) { - assert(k % QK5_0 == 0); - const int nb = k / QK5_0; +static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) { + assert(k % QK5_1 == 0); + const int nb = k / QK5_1; for (int i = 0; i < nb; i++) { float min = FLT_MAX; float max = -FLT_MAX; - for (int l = 0; l < QK5_0; l++) { - const float v = x[i*QK5_0 + l]; + for (int l = 0; l < QK5_1; l++) { + const float v = x[i*QK5_1 + l]; if (v < min) min = v; if (v > max) max = v; } @@ -1321,9 +1321,9 @@ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * r y[i].m = GGML_FP32_TO_FP16(min); y[i].qh = 0; - for (int l = 0; l < QK5_0; l += 2) { - const float v0 = (x[i*QK5_0 + l + 0] - min)*id; - const float v1 = (x[i*QK5_0 + l + 1] - min)*id; + for (int l = 0; l < QK5_1; l += 2) { + const float v0 = (x[i*QK5_1 + l + 0] - min)*id; + const float v1 = (x[i*QK5_1 + l + 1] - min)*id; const uint32_t vi0 = (int) (v0 + 0.5f); const uint32_t vi1 = (int) (v1 + 0.5f); @@ -1337,12 +1337,12 @@ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * r } } -static void quantize_row_q5_0(const float * restrict x, void * restrict vy, int k) { - assert(k % QK5_0 == 0); +static void quantize_row_q5_1(const float * restrict x, void * restrict vy, int k) { + assert(k % QK5_1 == 0); - block_q5_0 * restrict y = vy; + block_q5_1 * restrict y = vy; - quantize_row_q5_0_reference(x, y, k); + quantize_row_q5_1_reference(x, y, k); } // reference implementation for deterministic creation of model files @@ -1861,11 +1861,11 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in } } -static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, int k) { - assert(k % QK5_0 == 0); - const int nb = k / QK5_0; +static void dequantize_row_q5_1(const void * restrict vx, float * restrict y, int k) { + assert(k % QK5_1 == 0); + const int nb = k / QK5_1; - const block_q5_0 * restrict x = vx; + const block_q5_1 * restrict x = vx; for (int i = 0; i < nb; i++) { const float d = GGML_FP16_TO_FP32(x[i].d); @@ -1875,7 +1875,7 @@ static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, in const uint32_t qh = x[i].qh; - for (int l = 0; l < QK5_0; l += 2) { + for (int l = 0; l < QK5_1; l += 2) { const uint8_t vi = pp[l/2]; // extract the 5-th bit from qh @@ -1888,11 +1888,11 @@ static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, in const float v0 = vi0*d + m; const float v1 = vi1*d + m; - y[i*QK5_0 + l + 0] = v0; - y[i*QK5_0 + l + 1] = v1; + y[i*QK5_1 + l + 0] = v0; + y[i*QK5_1 + l + 1] = v1; - assert(!isnan(y[i*QK5_0 + l + 0])); - assert(!isnan(y[i*QK5_0 + l + 1])); + assert(!isnan(y[i*QK5_1 + l + 0])); + assert(!isnan(y[i*QK5_1 + l + 1])); } } } @@ -1918,7 +1918,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); -static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { @@ -1954,12 +1954,12 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .vec_dot_q = ggml_vec_dot_q4_3_q8_1, .vec_dot_type = GGML_TYPE_Q8_1, }, - [GGML_TYPE_Q5_0] = { - .dequantize_row_q = dequantize_row_q5_0, - .quantize_row_q = quantize_row_q5_0, - .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference, + [GGML_TYPE_Q5_1] = { + .dequantize_row_q = dequantize_row_q5_1, + .quantize_row_q = quantize_row_q5_1, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_1_reference, .quantize_row_q_dot = quantize_row_q8_1, - .vec_dot_q = ggml_vec_dot_q5_0_q8_1, + .vec_dot_q = ggml_vec_dot_q5_1_q8_1, .vec_dot_type = GGML_TYPE_Q8_1, }, [GGML_TYPE_Q8_0] = { @@ -3169,14 +3169,14 @@ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * #endif } -static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const int nb = n / QK8_1; assert(n % QK8_1 == 0); assert(nb % 2 == 0); - assert(QK8_1 == QK5_0); + assert(QK8_1 == QK5_1); - const block_q5_0 * restrict x = vx; + const block_q5_1 * restrict x = vx; const block_q8_1 * restrict y = vy; #if defined(__ARM_NEON) @@ -3187,7 +3187,7 @@ static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void * uint64_t tmp[4]; for (int i = 0; i < nb; ++i) { - const block_q5_0 * restrict x0 = &x[i]; + const block_q5_1 * restrict x0 = &x[i]; const block_q8_1 * restrict y0 = &y[i]; summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1); @@ -3646,7 +3646,7 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = QK4_1, [GGML_TYPE_Q4_2] = QK4_2, [GGML_TYPE_Q4_3] = QK4_3, - [GGML_TYPE_Q5_0] = QK5_0, + [GGML_TYPE_Q5_1] = QK5_1, [GGML_TYPE_Q8_0] = QK8_0, [GGML_TYPE_Q8_1] = QK8_1, [GGML_TYPE_I8] = 1, @@ -3662,7 +3662,7 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = sizeof(block_q4_1), [GGML_TYPE_Q4_2] = sizeof(block_q4_2), [GGML_TYPE_Q4_3] = sizeof(block_q4_3), - [GGML_TYPE_Q5_0] = sizeof(block_q5_0), + [GGML_TYPE_Q5_1] = sizeof(block_q5_1), [GGML_TYPE_Q8_0] = sizeof(block_q8_0), [GGML_TYPE_Q8_1] = sizeof(block_q8_1), [GGML_TYPE_I8] = sizeof(int8_t), @@ -3679,7 +3679,7 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = "q4_1", [GGML_TYPE_Q4_2] = "q4_2", [GGML_TYPE_Q4_3] = "q4_3", - [GGML_TYPE_Q5_0] = "q5_0", + [GGML_TYPE_Q5_1] = "q5_1", [GGML_TYPE_Q8_0] = "q8_0", [GGML_TYPE_Q8_1] = "q8_1", [GGML_TYPE_I8] = "i8", @@ -3695,7 +3695,7 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = true, [GGML_TYPE_Q4_2] = true, [GGML_TYPE_Q4_3] = true, - [GGML_TYPE_Q5_0] = true, + [GGML_TYPE_Q5_1] = true, [GGML_TYPE_Q8_0] = true, [GGML_TYPE_Q8_1] = true, [GGML_TYPE_I8] = false, @@ -6923,7 +6923,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: - case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: { ggml_compute_forward_add_q_f32(params, src0, src1, dst); @@ -8412,8 +8412,8 @@ static void ggml_compute_forward_mul_mat_q_f32( else if (type == GGML_TYPE_Q4_3) { dequantize_row_q_cuda = dequantize_row_q4_3_cuda; } - else if (type == GGML_TYPE_Q5_0) { - dequantize_row_q_cuda = dequantize_row_q5_0_cuda; + else if (type == GGML_TYPE_Q5_1) { + dequantize_row_q_cuda = dequantize_row_q5_1_cuda; } else if (type == GGML_TYPE_Q8_0) { dequantize_row_q_cuda = dequantize_row_q8_0_cuda; @@ -8573,7 +8573,7 @@ static void ggml_compute_forward_mul_mat( case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: - case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: { @@ -8804,7 +8804,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: - case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: { @@ -12598,17 +12598,17 @@ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * return (n/QK4_3*sizeof(block_q4_3)); } -size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) { - assert(k % QK5_0 == 0); - const int nb = k / QK5_0; +size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK5_1 == 0); + const int nb = k / QK5_1; for (int j = 0; j < n; j += k) { - block_q5_0 * restrict y = (block_q5_0 *)dst + j/QK5_0; + block_q5_1 * restrict y = (block_q5_1 *)dst + j/QK5_1; - quantize_row_q5_0_reference(src + j, y, k); + quantize_row_q5_1_reference(src + j, y, k); for (int i = 0; i < nb; i++) { - for (int l = 0; l < QK5_0; l += 2) { + for (int l = 0; l < QK5_1; l += 2) { const uint8_t vh0 = ((y[i].qh & (1 << (l + 0))) >> (l + 0)) << 4; const uint8_t vh1 = ((y[i].qh & (1 << (l + 1))) >> (l + 1)) << 4; @@ -12622,7 +12622,7 @@ size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * } } - return (n/QK5_0*sizeof(block_q5_0)); + return (n/QK5_1*sizeof(block_q5_1)); } size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) { @@ -12673,11 +12673,11 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i block_q4_3 * block = (block_q4_3*)dst + start / QK4_3; result = ggml_quantize_q4_3(src + start, block, n, n, hist); } break; - case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: { - GGML_ASSERT(start % QK5_0 == 0); - block_q5_0 * block = (block_q5_0*)dst + start / QK5_0; - result = ggml_quantize_q5_0(src + start, block, n, n, hist); + GGML_ASSERT(start % QK5_1 == 0); + block_q5_1 * block = (block_q5_1*)dst + start / QK5_1; + result = ggml_quantize_q5_1(src + start, block, n, n, hist); } break; case GGML_TYPE_Q8_0: { diff --git a/ggml.h b/ggml.h index 154c0eb4b2941..2784afc3d70b6 100644 --- a/ggml.h +++ b/ggml.h @@ -222,7 +222,7 @@ extern "C" { GGML_TYPE_Q4_1 = 3, GGML_TYPE_Q4_2 = 4, GGML_TYPE_Q4_3 = 5, - GGML_TYPE_Q5_0 = 6, + GGML_TYPE_Q5_1 = 6, GGML_TYPE_Q8_0 = 7, GGML_TYPE_Q8_1 = 8, GGML_TYPE_I8, @@ -834,7 +834,7 @@ extern "C" { GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist); - GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist); diff --git a/llama.cpp b/llama.cpp index a82ce30978cbc..9b167f9712e18 100644 --- a/llama.cpp +++ b/llama.cpp @@ -484,7 +484,7 @@ struct llama_file_loader { case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: - case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: break; default: { @@ -560,7 +560,7 @@ struct llama_file_saver { case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: - case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: break; default: LLAMA_ASSERT(false); @@ -852,7 +852,7 @@ static const char *llama_ftype_name(enum llama_ftype ftype) { return "mostly Q4_1, some F16"; case LLAMA_FTYPE_MOSTLY_Q4_2: return "mostly Q4_2"; case LLAMA_FTYPE_MOSTLY_Q4_3: return "mostly Q4_3"; - case LLAMA_FTYPE_MOSTLY_Q5_0: return "mostly Q5_0"; + case LLAMA_FTYPE_MOSTLY_Q5_1: return "mostly Q5_1"; case LLAMA_FTYPE_MOSTLY_Q8_0: return "mostly Q8_0"; default: return "unknown, may not work"; } @@ -1591,7 +1591,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q4_1: quantized_type = GGML_TYPE_Q4_1; break; case LLAMA_FTYPE_MOSTLY_Q4_2: quantized_type = GGML_TYPE_Q4_2; break; case LLAMA_FTYPE_MOSTLY_Q4_3: quantized_type = GGML_TYPE_Q4_3; break; - case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break; + case LLAMA_FTYPE_MOSTLY_Q5_1: quantized_type = GGML_TYPE_Q5_1; break; case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break; default: throw format("invalid output file type %d\n", ftype); }; diff --git a/llama.h b/llama.h index ea3fe18d54cb6..ef5e7a7f54b29 100644 --- a/llama.h +++ b/llama.h @@ -75,7 +75,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors }; LLAMA_API struct llama_context_params llama_context_default_params(); From 8e936ad0cd08b6783a65cfafaf033ca2a1195a08 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 26 Apr 2023 18:30:56 +0300 Subject: [PATCH 08/10] ggml : adding Q5_0 mode --- ggml-cuda.cu | 42 ++++++++ ggml-cuda.h | 1 + ggml.c | 291 ++++++++++++++++++++++++++++++++++++++++++++++++++- ggml.h | 8 +- llama.cpp | 4 + llama.h | 1 + 6 files changed, 340 insertions(+), 7 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 6d1cc70083910..b1bd29b100d81 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -37,6 +37,14 @@ typedef struct { } block_q4_3; static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); +#define QK5_0 32 +typedef struct { + __half d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + #define QK5_1 32 typedef struct { __half d; // delta @@ -147,6 +155,35 @@ static __global__ void dequantize_block_q4_3(const void * vx, float * y) { } } +static __global__ void dequantize_block_q5_0(const void * vx, float * y) { + const block_q5_0 * x = (const block_q5_0 *) vx; + + const int i = blockIdx.x; + + const float d = x[i].d; + + const uint8_t * pp = x[i].qs; + + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + for (int l = 0; l < QK5_0; l += 2) { + const uint8_t vi = pp[l/2]; + + const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + + const int8_t vi0 = ((vi & 0xf) | vh0); + const int8_t vi1 = ((vi >> 4) | vh1); + + const float v0 = (vi0 - 16)*d; + const float v1 = (vi1 - 16)*d; + + y[i*QK5_0 + l + 0] = v0; + y[i*QK5_0 + l + 1] = v1; + } +} + static __global__ void dequantize_block_q5_1(const void * vx, float * y) { const block_q5_1 * x = (const block_q5_1 *) vx; @@ -212,6 +249,11 @@ void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t st dequantize_block_q4_3<<>>(vx, y); } +void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { + const int nb = k / QK5_0; + dequantize_block_q5_0<<>>(vx, y); +} + void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK5_1; dequantize_block_q5_1<<>>(vx, y); diff --git a/ggml-cuda.h b/ggml-cuda.h index 348d9e907c837..ed9b44184bf56 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -35,6 +35,7 @@ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t st void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream); +void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); diff --git a/ggml.c b/ggml.c index 91afe62edaf89..90eb48fd70324 100644 --- a/ggml.c +++ b/ggml.c @@ -676,6 +676,14 @@ typedef struct { } block_q4_3; static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); +#define QK5_0 32 +typedef struct { + ggml_fp16_t d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + #define QK5_1 32 typedef struct { ggml_fp16_t d; // delta @@ -1300,6 +1308,55 @@ static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int quantize_row_q4_3_reference(x, y, k); } +static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) { + assert(k % QK5_0 == 0); + const int nb = k / QK5_0; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int l = 0; l < QK5_0; l++) { + const float v = x[i*QK5_0 + l]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float d = max / -16; + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + uint32_t qh = 0; + + for (int l = 0; l < QK5_0; l += 2) { + const float v0 = x[i*QK5_0 + l + 0]*id; + const float v1 = x[i*QK5_0 + l + 1]*id; + + const uint32_t vi0 = MIN(31, (int) (v0 + 16.5f)); + const uint32_t vi1 = MIN(31, (int) (v1 + 16.5f)); + + y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4); + + // get the 5-th bit and store it in qh at the right position + qh |= ((vi0 & 0x10) >> 4) << (l + 0); + qh |= ((vi1 & 0x10) >> 4) << (l + 1); + } + + memcpy(&y[i].qh, &qh, sizeof(y[i].qh)); + } +} + +static void quantize_row_q5_0(const float * restrict x, void * restrict vy, int k) { + assert(k % QK5_0 == 0); + + block_q5_0 * restrict y = vy; + + quantize_row_q5_0_reference(x, y, k); +} + static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) { assert(k % QK5_1 == 0); const int nb = k / QK5_1; @@ -1861,6 +1918,42 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in } } +static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, int k) { + assert(k % QK5_0 == 0); + const int nb = k / QK5_0; + + const block_q5_0 * restrict x = vx; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict pp = x[i].qs; + + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + for (int l = 0; l < QK5_0; l += 2) { + const uint8_t vi = pp[l/2]; + + // extract the 5-th bit from qh + const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + + const int8_t vi0 = (vi & 0x0F) | vh0; + const int8_t vi1 = (vi >> 4) | vh1; + + const float v0 = (vi0 - 16)*d; + const float v1 = (vi1 - 16)*d; + + y[i*QK5_0 + l + 0] = v0; + y[i*QK5_0 + l + 1] = v1; + + assert(!isnan(y[i*QK5_0 + l + 0])); + assert(!isnan(y[i*QK5_0 + l + 1])); + } + } +} + static void dequantize_row_q5_1(const void * restrict vx, float * restrict y, int k) { assert(k % QK5_1 == 0); const int nb = k / QK5_1; @@ -1918,6 +2011,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); @@ -1954,6 +2048,14 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .vec_dot_q = ggml_vec_dot_q4_3_q8_1, .vec_dot_type = GGML_TYPE_Q8_1, }, + [GGML_TYPE_Q5_0] = { + .dequantize_row_q = dequantize_row_q5_0, + .quantize_row_q = quantize_row_q5_0, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference, + .quantize_row_q_dot = quantize_row_q8_0, + .vec_dot_q = ggml_vec_dot_q5_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + }, [GGML_TYPE_Q5_1] = { .dequantize_row_q = dequantize_row_q5_1, .quantize_row_q = quantize_row_q5_1, @@ -3169,6 +3271,141 @@ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * #endif } +static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int nb = n / QK8_0; + + assert(n % QK8_0 == 0); + assert(nb % 2 == 0); + assert(QK8_0 == QK5_0); + + const block_q5_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv = vdupq_n_f32(0.0f); + + uint64_t tmp[4]; + + for (int i = 0; i < nb; ++i) { + const block_q5_0 * restrict x0 = &x[i]; + const block_q8_0 * restrict y0 = &y[i]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s16b = vdupq_n_s8(0x10); + + // extract the 5th bit + uint32_t qh; + memcpy(&qh, x0->qh, sizeof(qh)); + + tmp[0] = table_b2b[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b[(qh >> 24) ]; + + const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0)); + const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2)); + + const uint8x16_t v0 = vld1q_u8(x0->qs); + + // 4-bit -> 8-bit + const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, m4b)); + const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4)); + + // interleave + const int8x16_t v0lz = vzip1q_s8(v0l, v0h); + const int8x16_t v0hz = vzip2q_s8(v0l, v0h); + + // add high bit and sub 16 + const int8x16_t v0lf = vsubq_s8(vorrq_s8(v0lz, qhl), s16b); + const int8x16_t v0hf = vsubq_s8(vorrq_s8(v0hz, qhh), s16b); + + // load y + const int8x16_t v1l = vld1q_s8(y0->qs); + const int8x16_t v1h = vld1q_s8(y0->qs + 16); + + const float x0d = GGML_FP16_TO_FP32(x0->d); + +#if defined(__ARM_FEATURE_DOTPROD) + sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0lf, v1l), + vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + + sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d); +#endif + } + + *s = vaddvq_f32(sumv); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; i++) { + /* Compute combined scale for the block */ + const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d)); + const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d)); + const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d)); + + __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs); + __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs); + __m256i bx = _mm256_set_m128i(bx1, bx0); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. + const __m256i off = _mm256_set1_epi8(8); + bx = _mm256_sub_epi8(bx, off); + + __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps(d, q, acc); + } + + *s = hsum_float_8(acc); +#else + // scalar + float sumf = 0.0; + for (int i = 0; i < nb; i++) { + const uint8_t * restrict x0 = x[i].qs; + const int8_t * restrict y0 = y[i].qs; + + uint32_t qh; + memcpy(&qh, x0->qh, sizeof(qh)); + + const float d = GGML_FP16_TO_FP32(x[i].d); + + int sxy = 0; + + for (int j = 0; j < QK8_0/2; j++) { + const uint8_t v0 = x0[j]; + + const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4; + const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4; + + const int x0_0 = ((v0 & 0x0F) | x0_0h) - 16; + const int x1_0 = ((v0 >> 4) | x1_0h) - 16; + + const int y0_0 = y0[2*j + 0]; + const int y1_0 = y0[2*j + 1]; + + sxy += x0_0*y0_0 + x1_0*y1_0; + } + + sumf += (d*sxy)*y[i].d; + } + *s = sumf; +#endif +} + static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const int nb = n / QK8_1; @@ -3646,6 +3883,7 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = QK4_1, [GGML_TYPE_Q4_2] = QK4_2, [GGML_TYPE_Q4_3] = QK4_3, + [GGML_TYPE_Q5_0] = QK5_0, [GGML_TYPE_Q5_1] = QK5_1, [GGML_TYPE_Q8_0] = QK8_0, [GGML_TYPE_Q8_1] = QK8_1, @@ -3653,7 +3891,7 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_I16] = 1, [GGML_TYPE_I32] = 1, }; -static_assert(GGML_TYPE_COUNT == 12, "GGML_BLCK_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated"); static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = sizeof(float), @@ -3662,6 +3900,7 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = sizeof(block_q4_1), [GGML_TYPE_Q4_2] = sizeof(block_q4_2), [GGML_TYPE_Q4_3] = sizeof(block_q4_3), + [GGML_TYPE_Q5_0] = sizeof(block_q5_0), [GGML_TYPE_Q5_1] = sizeof(block_q5_1), [GGML_TYPE_Q8_0] = sizeof(block_q8_0), [GGML_TYPE_Q8_1] = sizeof(block_q8_1), @@ -3669,7 +3908,7 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_I16] = sizeof(int16_t), [GGML_TYPE_I32] = sizeof(int32_t), }; -static_assert(GGML_TYPE_COUNT == 12, "GGML_TYPE_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated"); static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { @@ -3679,6 +3918,7 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = "q4_1", [GGML_TYPE_Q4_2] = "q4_2", [GGML_TYPE_Q4_3] = "q4_3", + [GGML_TYPE_Q5_0] = "q5_0", [GGML_TYPE_Q5_1] = "q5_1", [GGML_TYPE_Q8_0] = "q8_0", [GGML_TYPE_Q8_1] = "q8_1", @@ -3686,7 +3926,7 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { [GGML_TYPE_I16] = "i16", [GGML_TYPE_I32] = "i32", }; -static_assert(GGML_TYPE_COUNT == 12, "GGML_TYPE_NAME is outdated"); +static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated"); static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = false, @@ -3695,6 +3935,7 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = true, [GGML_TYPE_Q4_2] = true, [GGML_TYPE_Q4_3] = true, + [GGML_TYPE_Q5_0] = true, [GGML_TYPE_Q5_1] = true, [GGML_TYPE_Q8_0] = true, [GGML_TYPE_Q8_1] = true, @@ -3702,7 +3943,7 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_I16] = false, [GGML_TYPE_I32] = false, }; -static_assert(GGML_TYPE_COUNT == 12, "GGML_IS_QUANTIZED is outdated"); +static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated"); static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "NONE", @@ -6923,6 +7164,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: { @@ -8412,6 +8654,9 @@ static void ggml_compute_forward_mul_mat_q_f32( else if (type == GGML_TYPE_Q4_3) { dequantize_row_q_cuda = dequantize_row_q4_3_cuda; } + else if (type == GGML_TYPE_Q5_0) { + dequantize_row_q_cuda = dequantize_row_q5_0_cuda; + } else if (type == GGML_TYPE_Q5_1) { dequantize_row_q_cuda = dequantize_row_q5_1_cuda; } @@ -8573,6 +8818,7 @@ static void ggml_compute_forward_mul_mat( case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: @@ -8804,6 +9050,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: @@ -12598,6 +12845,36 @@ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * return (n/QK4_3*sizeof(block_q4_3)); } +size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK5_0 == 0); + const int nb = k / QK5_0; + + for (int j = 0; j < n; j += k) { + block_q5_0 * restrict y = (block_q5_0 *)dst + j/QK5_0; + + quantize_row_q5_0_reference(src + j, y, k); + + for (int i = 0; i < nb; i++) { + for (int l = 0; l < QK5_0; l += 2) { + uint32_t qh; + memcpy(&qh, &y[i].qh, sizeof(qh)); + + const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + + // cast to 16 bins + const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2; + const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n/QK5_0*sizeof(block_q5_0)); +} + size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) { assert(k % QK5_1 == 0); const int nb = k / QK5_1; @@ -12673,6 +12950,12 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i block_q4_3 * block = (block_q4_3*)dst + start / QK4_3; result = ggml_quantize_q4_3(src + start, block, n, n, hist); } break; + case GGML_TYPE_Q5_0: + { + GGML_ASSERT(start % QK5_0 == 0); + block_q5_0 * block = (block_q5_0*)dst + start / QK5_0; + result = ggml_quantize_q5_0(src + start, block, n, n, hist); + } break; case GGML_TYPE_Q5_1: { GGML_ASSERT(start % QK5_1 == 0); diff --git a/ggml.h b/ggml.h index 2784afc3d70b6..d9d3d214e84e7 100644 --- a/ggml.h +++ b/ggml.h @@ -222,9 +222,10 @@ extern "C" { GGML_TYPE_Q4_1 = 3, GGML_TYPE_Q4_2 = 4, GGML_TYPE_Q4_3 = 5, - GGML_TYPE_Q5_1 = 6, - GGML_TYPE_Q8_0 = 7, - GGML_TYPE_Q8_1 = 8, + GGML_TYPE_Q5_0 = 6, + GGML_TYPE_Q5_1 = 7, + GGML_TYPE_Q8_0 = 8, + GGML_TYPE_Q8_1 = 9, GGML_TYPE_I8, GGML_TYPE_I16, GGML_TYPE_I32, @@ -834,6 +835,7 @@ extern "C" { GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist); diff --git a/llama.cpp b/llama.cpp index 9b167f9712e18..2ae6cedb25085 100644 --- a/llama.cpp +++ b/llama.cpp @@ -484,6 +484,7 @@ struct llama_file_loader { case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: break; @@ -560,6 +561,7 @@ struct llama_file_saver { case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: break; @@ -852,6 +854,7 @@ static const char *llama_ftype_name(enum llama_ftype ftype) { return "mostly Q4_1, some F16"; case LLAMA_FTYPE_MOSTLY_Q4_2: return "mostly Q4_2"; case LLAMA_FTYPE_MOSTLY_Q4_3: return "mostly Q4_3"; + case LLAMA_FTYPE_MOSTLY_Q5_0: return "mostly Q5_0"; case LLAMA_FTYPE_MOSTLY_Q5_1: return "mostly Q5_1"; case LLAMA_FTYPE_MOSTLY_Q8_0: return "mostly Q8_0"; default: return "unknown, may not work"; @@ -1591,6 +1594,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q4_1: quantized_type = GGML_TYPE_Q4_1; break; case LLAMA_FTYPE_MOSTLY_Q4_2: quantized_type = GGML_TYPE_Q4_2; break; case LLAMA_FTYPE_MOSTLY_Q4_3: quantized_type = GGML_TYPE_Q4_3; break; + case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break; case LLAMA_FTYPE_MOSTLY_Q5_1: quantized_type = GGML_TYPE_Q5_1; break; case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break; default: throw format("invalid output file type %d\n", ftype); diff --git a/llama.h b/llama.h index ef5e7a7f54b29..3b6c6cd62dcf8 100644 --- a/llama.h +++ b/llama.h @@ -75,6 +75,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors }; From 982bfce67895162b15f8ab48c824787bac123eec Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 26 Apr 2023 20:45:03 +0300 Subject: [PATCH 09/10] quantize : add Q5_0 and Q5_1 to map --- examples/quantize/quantize.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index ec7f91aae6bf0..60966595e9561 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -10,6 +10,8 @@ static const std::map LLAMA_FTYPE_MAP = { {"q4_1", LLAMA_FTYPE_MOSTLY_Q4_1}, {"q4_2", LLAMA_FTYPE_MOSTLY_Q4_2}, {"q4_3", LLAMA_FTYPE_MOSTLY_Q4_3}, + {"q5_0", LLAMA_FTYPE_MOSTLY_Q5_0}, + {"q5_1", LLAMA_FTYPE_MOSTLY_Q5_1}, {"q8_0", LLAMA_FTYPE_MOSTLY_Q8_0}, }; From 2bfa1fe8e76ccd214b47a0ed6d77c8ad2b276605 Mon Sep 17 00:00:00 2001 From: Stephan Walter Date: Wed, 26 Apr 2023 19:38:15 +0000 Subject: [PATCH 10/10] ggml : AVX2 optimizations for Q5_0, Q5_1 (#1195) --- ggml.c | 105 ++++++++++++++++++++++++++++++--------------------------- 1 file changed, 55 insertions(+), 50 deletions(-) diff --git a/ggml.c b/ggml.c index 90eb48fd70324..03b4bd439f299 100644 --- a/ggml.c +++ b/ggml.c @@ -328,8 +328,18 @@ static ggml_fp16_t table_exp_f16[1 << 16]; // precomputed f32 table for f16 (256 KB) static float table_f32_f16[1 << 16]; -// precomputed table for expanding 8bits to 8 bytes (shl 4) -static uint64_t table_b2b[1 << 8]; +#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s +#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) +#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) +#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s) +#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s) +#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s) +#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s) +#define B8(c,s ) B7(c,s, c), B7(c,s, s) + +// precomputed tables for expanding 8bits to 8 bytes (shl 4) +static const uint64_t table_b2b_u[1 << 8] = { B8(00, 10) }; +static const uint64_t table_b2b_i[1 << 8] = { B8(F0, 00) }; // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON. @@ -688,7 +698,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5 typedef struct { ggml_fp16_t d; // delta ggml_fp16_t m; // min - uint32_t qh; // 5-th bit of quants + uint8_t qh[4]; // 5-th bit of quants uint8_t qs[QK5_1 / 2]; // nibbles / quants } block_q5_1; static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); @@ -1376,7 +1386,8 @@ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * r y[i].d = GGML_FP32_TO_FP16(d); y[i].m = GGML_FP32_TO_FP16(min); - y[i].qh = 0; + + uint32_t qh = 0; for (int l = 0; l < QK5_1; l += 2) { const float v0 = (x[i*QK5_1 + l + 0] - min)*id; @@ -1388,9 +1399,11 @@ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * r y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4); // get the 5-th bit and store it in qh at the right position - y[i].qh |= ((vi0 & 0x10) >> 4) << (l + 0); - y[i].qh |= ((vi1 & 0x10) >> 4) << (l + 1); + qh |= ((vi0 & 0x10) >> 4) << (l + 0); + qh |= ((vi1 & 0x10) >> 4) << (l + 1); } + + memcpy(&y[i].qh, &qh, sizeof(y[i].qh)); } } @@ -1966,7 +1979,8 @@ static void dequantize_row_q5_1(const void * restrict vx, float * restrict y, in const uint8_t * restrict pp = x[i].qs; - const uint32_t qh = x[i].qh; + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); for (int l = 0; l < QK5_1; l += 2) { const uint8_t vi = pp[l/2]; @@ -3297,10 +3311,10 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * uint32_t qh; memcpy(&qh, x0->qh, sizeof(qh)); - tmp[0] = table_b2b[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b[(qh >> 24) ]; + tmp[0] = table_b2b_u[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_u[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_u[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_u[(qh >> 24) ]; const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0)); const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2)); @@ -3350,17 +3364,13 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * // Main loop for (int i = 0; i < nb; i++) { /* Compute combined scale for the block */ - const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d)); - const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d)); - const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d)); - - __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs); - __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs); - __m256i bx = _mm256_set_m128i(bx1, bx0); + const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d)); - // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. - const __m256i off = _mm256_set1_epi8(8); - bx = _mm256_sub_epi8(bx, off); + __m256i bx = bytes_from_nibbles_32(x[i].qs); + const __m256i bxhi = _mm256_set_epi64x( + table_b2b_i[x[i].qh[3]], table_b2b_i[x[i].qh[2]], + table_b2b_i[x[i].qh[1]], table_b2b_i[x[i].qh[0]]); + bx = _mm256_or_si256(bx, bxhi); __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); @@ -3379,7 +3389,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * const int8_t * restrict y0 = y[i].qs; uint32_t qh; - memcpy(&qh, x0->qh, sizeof(qh)); + memcpy(&qh, x[i].qh, sizeof(qh)); const float d = GGML_FP16_TO_FP32(x[i].d); @@ -3430,12 +3440,13 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1); // extract the 5th bit - const uint32_t qh = x0->qh; + uint32_t qh; + memcpy(&qh, x0->qh, sizeof(qh)); - tmp[0] = table_b2b[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b[(qh >> 24) ]; + tmp[0] = table_b2b_u[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_u[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_u[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_u[(qh >> 24) ]; const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0)); const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2)); @@ -3485,16 +3496,15 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * // Main loop for (int i = 0; i < nb; i++) { - const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d)); - const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d)); - const __m256 dx = _mm256_set_m128(d1, d0); + const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); - summs += GGML_FP16_TO_FP32(x[2*i + 0].m) * y[i].s0 - + GGML_FP16_TO_FP32(x[2*i + 1].m) * y[i].s1; + summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1); - const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs); - const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs); - const __m256i bx = _mm256_set_m128i(bx1, bx0); + __m256i bx = bytes_from_nibbles_32(x[i].qs); + const __m256i bxhi = _mm256_set_epi64x( + table_b2b_u[x[i].qh[3]], table_b2b_u[x[i].qh[2]], + table_b2b_u[x[i].qh[1]], table_b2b_u[x[i].qh[0]]); + bx = _mm256_or_si256(bx, bxhi); const __m256 dy = _mm256_broadcast_ss(&y[i].d); const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); @@ -3512,7 +3522,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * const uint8_t * restrict x0 = x[i].qs; const int8_t * restrict y0 = y[i].qs; - const uint32_t qh = x[i].qh; + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); const float d = GGML_FP16_TO_FP32(x[i].d); const float m = GGML_FP16_TO_FP32(x[i].m); @@ -4297,15 +4308,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f)); } - for (int i = 0; i < 256; ++i) { - table_b2b[i] = 0; - for (int b = 0; b < 8; ++b) { - table_b2b[i] |= ((uint64_t)(((i >> b) & 0x01) << 4)) << (8*b); - } - - //printf("%3d %016llx\n", i, table_b2b[i]); - } - const uint64_t t_end = ggml_time_us(); UNUSED(t_end); GGML_PRINT_DEBUG("%s: GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); @@ -12855,10 +12857,10 @@ size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * quantize_row_q5_0_reference(src + j, y, k); for (int i = 0; i < nb; i++) { - for (int l = 0; l < QK5_0; l += 2) { - uint32_t qh; - memcpy(&qh, &y[i].qh, sizeof(qh)); + uint32_t qh; + memcpy(&qh, &y[i].qh, sizeof(qh)); + for (int l = 0; l < QK5_0; l += 2) { const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; @@ -12885,9 +12887,12 @@ size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * quantize_row_q5_1_reference(src + j, y, k); for (int i = 0; i < nb; i++) { + uint32_t qh; + memcpy(&qh, &y[i].qh, sizeof(qh)); + for (int l = 0; l < QK5_1; l += 2) { - const uint8_t vh0 = ((y[i].qh & (1 << (l + 0))) >> (l + 0)) << 4; - const uint8_t vh1 = ((y[i].qh & (1 << (l + 1))) >> (l + 1)) << 4; + const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; // cast to 16 bins const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;