ggml : PoC for normalizing weights for better quantization packing #2434

Draft · wants to merge 3 commits into master
97 changes: 66 additions & 31 deletions ggml-cuda.cu
@@ -163,21 +163,31 @@ typedef float2 dfloat2;
#endif //GGML_CUDA_F16

static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
//const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
x8 += sizeof(int) * i32;

int x32 = 0;
x32 |= x16[0] << 0;
x32 |= x16[1] << 16;
//x32 |= x16[0] << 0;
//x32 |= x16[1] << 16;
x32 |= ((uint32_t)(x8[0])) << 0;
x32 |= ((uint32_t)(x8[1])) << 8;
x32 |= ((uint32_t)(x8[2])) << 16;
x32 |= ((uint32_t)(x8[3])) << 24;

return x32;
}

static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) {
const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
//const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
x8 += sizeof(int) * i32;

int x32 = 0;
x32 |= x16[0] << 0;
x32 |= x16[1] << 16;
//x32 |= x16[0] << 0;
//x32 |= x16[1] << 16;
x32 |= ((uint32_t)(x8[0])) << 0;
x32 |= ((uint32_t)(x8[1])) << 8;
x32 |= ((uint32_t)(x8[2])) << 16;
x32 |= ((uint32_t)(x8[3])) << 24;

return x32;
}
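
The two helpers above previously reinterpreted the quantized data as 16-bit words, which bakes in a 2-byte alignment assumption; once block sizes are no longer multiples of two bytes (e.g. the 17-byte q4_0 block below), that assumption breaks, so the 32-bit word is now assembled byte by byte. A rough host-side C sketch of the same idea, not part of the patch, with explicit masking through uint8_t to avoid sign-extension from the signed source pointer:

#include <stdint.h>

// Illustrative only: build a 32-bit word from four possibly unaligned bytes.
static inline uint32_t load_u32_bytewise(const int8_t * p) {
    uint32_t x = 0;
    x |= ((uint32_t)(uint8_t)p[0]) << 0;
    x |= ((uint32_t)(uint8_t)p[1]) << 8;
    x |= ((uint32_t)(uint8_t)p[2]) << 16;
    x |= ((uint32_t)(uint8_t)p[3]) << 24;
    return x;
}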
@@ -204,23 +214,31 @@ typedef void (*ggml_cuda_op_t)(
// QR = QK / number of values before dequantization
// QI = number of 32 bit integers before dequantization

#define Q4_0DM (1.0f/8.0f)
#define Q4_0D(x) (((x)*Q4_0DM) / 127.0f)

#define QK4_0 32
#define QR4_0 2
#define QI4_0 (QK4_0 / (4 * QR4_0))
typedef struct {
half d; // delta
int8_t d; // delta
uint8_t qs[QK4_0 / 2]; // nibbles / quants
} block_q4_0;
static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
static_assert(sizeof(block_q4_0) == sizeof(int8_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
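
With this layout the q4_0 scale is stored as a signed 8-bit value and the block shrinks from 18 to 17 bytes; Q4_0D maps it back to a float on the assumption that, after normalizing the weights, the per-block delta stays within roughly ±1/8. A hypothetical host-side encoder for that mapping (the conversion code is not part of this diff) could look like:

#include <math.h>
#include <stdint.h>

// Hypothetical counterpart to Q4_0D: quantize a normalized delta to int8.
static int8_t q4_0_encode_d(float d) {
    float q = d * 127.0f / (1.0f/8.0f);    // assumes |d| <= 1/8
    q = fminf(fmaxf(q, -127.0f), 127.0f);  // clamp to int8 range
    return (int8_t) roundf(q);
}

// Example: d = 0.0625f encodes to 64, and Q4_0D(64) decodes back to ~0.0630f.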

#define Q4_1DM (2.0f/15.0f)
#define Q4_1MM (2.0f )
#define Q4_1D(x) ( (((x) & 0xFF)*Q4_1DM) / 255.0f)
#define Q4_1M(x) (-1.0f + (((x) >> 8)*Q4_1MM) / 255.0f)

#define QK4_1 32
#define QR4_1 2
#define QI4_1 (QK4_1 / (4 * QR4_1))
typedef struct {
half2 dm; // dm.x = delta, dm.y = min
uint8_t qs[QK4_1 / 2]; // nibbles / quants
uint16_t dm; // 8-bit delta + 8-bit min (can be adjusted easily)
uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;
static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
static_assert(sizeof(block_q4_1) == sizeof(uint16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
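
Here the q4_1 delta and min are packed into one 16-bit field (low byte = delta, high byte = min), dropping the block from 20 to 18 bytes; Q4_1D/Q4_1M undo the packing on the assumption that normalization keeps the delta in [0, 2/15] and the min in [-1, 1]. A hypothetical packer matching those decode macros, purely for illustration:

#include <math.h>
#include <stdint.h>

// Hypothetical counterpart to Q4_1D/Q4_1M (the encoder is not in this diff).
static uint16_t q4_1_pack_dm(float d, float m) {
    uint16_t qd = (uint16_t) roundf(d * 255.0f / (2.0f/15.0f));   // low byte:  0 <= d <= 2/15
    uint16_t qm = (uint16_t) roundf((m + 1.0f) * 255.0f / 2.0f);  // high byte: -1 <= m <= 1
    return (uint16_t)(qd | (qm << 8));
}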

#define QK5_0 32
#define QR5_0 2
@@ -232,15 +250,20 @@ typedef struct {

} block_q5_0;
static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");

#define Q5_1DM (2.0f/31.0f)
#define Q5_1MM (2.0f )
#define Q5_1D(x) ( (((x) & 0x0F)*Q5_1DM) / 15.0f)
#define Q5_1M(x) (-1.0f + (((x) >> 4)*Q5_1MM) / 15.0f)

#define QK5_1 32
#define QR5_1 2
#define QI5_1 (QK5_1 / (4 * QR5_1))
typedef struct {
half2 dm; // dm.x = delta, dm.y = min
uint8_t dm; // 4-bit delta + 4-bit min
uint8_t qh[4]; // 5-th bit of quants
uint8_t qs[QK5_1 / 2]; // nibbles / quants
} block_q5_1;
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
static_assert(sizeof(block_q5_1) == sizeof(uint8_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
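
The q5_1 scale and min get the same treatment with only 4 bits each, packed into a single byte (low nibble = delta, high nibble = min) and decoded by Q5_1D/Q5_1M above. A matching hypothetical packer, again only to illustrate the layout:

#include <math.h>
#include <stdint.h>

// Hypothetical counterpart to Q5_1D/Q5_1M.
static uint8_t q5_1_pack_dm(float d, float m) {
    uint8_t qd = (uint8_t) roundf(d * 15.0f / (2.0f/31.0f));   // low nibble:  0 <= d <= 2/31
    uint8_t qm = (uint8_t) roundf((m + 1.0f) * 15.0f / 2.0f);  // high nibble: -1 <= m <= 1
    return (uint8_t)((qd & 0x0F) | (qm << 4));
}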

#define QK8_0 32
#define QR8_0 1
@@ -506,7 +529,7 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q4_0 * x = (const block_q4_0 *) vx;

const dfloat d = x[ib].d;
const dfloat d = Q4_0D(x[ib].d);

const int vui = x[ib].qs[iqs];

@@ -525,8 +548,8 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q4_1 * x = (const block_q4_1 *) vx;

const dfloat d = __low2half(x[ib].dm);
const dfloat m = __high2half(x[ib].dm);
const dfloat d = Q4_1D(x[ib].dm);
const dfloat m = Q4_1M(x[ib].dm);

const int vui = x[ib].qs[iqs];

@@ -568,8 +591,8 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q5_1 * x = (const block_q5_1 *) vx;

const dfloat d = __low2half(x[ib].dm);
const dfloat m = __high2half(x[ib].dm);
const dfloat d = Q5_1D(x[ib].dm);
const dfloat m = Q5_1M(x[ib].dm);

uint32_t qh;
memcpy(&qh, x[ib].qh, sizeof(qh));
@@ -2041,7 +2064,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);
}

return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, Q4_0D(bq4_0->d), bq8_1->ds);
}

template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
@@ -2080,7 +2103,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;

x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
// x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
//x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = Q4_0D(bxi->d);
}

const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
@@ -2096,7 +2119,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;

x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = Q4_0D(bxi->d);
}
}

@@ -2130,12 +2153,17 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(

#pragma unroll
for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) {
v[i] = get_int_from_uint8_aligned(bq4_1->qs, iqs + i);
v[i] = get_int_from_uint8(bq4_1->qs, iqs + i);
u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1);
}

return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
const float d = Q4_1D(bq4_1->dm);
const float m = Q4_1M(bq4_1->dm);

const half2 dm = {d, m};

return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, dm, bq8_1->ds);
}

template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
@@ -2171,7 +2199,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx;

x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
}

const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
@@ -2187,7 +2215,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd;

x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd].x = Q4_1D(bxi->dm);
x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd].y = Q4_1M(bxi->dm);
}
}

@@ -2335,13 +2364,18 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(

#pragma unroll
for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) {
vl[i] = get_int_from_uint8_aligned(bq5_1->qs, iqs + i);
vh[i] = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i));
vl[i] = get_int_from_uint8(bq5_1->qs, iqs + i);
vh[i] = get_int_from_uint8(bq5_1->qh, 0) >> (4 * (iqs + i));
u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1);
}

return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);
const half d = Q5_1D(bq5_1->dm);
const half m = Q5_1M(bq5_1->dm);

const half2 dm = {d, m};

return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, dm, bq8_1->ds);
}

template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
@@ -2377,8 +2411,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx;

const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1));
const int ql = get_int_from_uint8(bxi->qs, kqsx);
const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_1));

int qs0 = (ql >> 0) & 0x0F0F0F0F;
qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
@@ -2410,7 +2444,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd;

x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd].x = Q5_1D(bxi->dm);
x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd].y = Q5_1M(bxi->dm);
}
}

7 changes: 5 additions & 2 deletions ggml-metal.m
@@ -697,6 +697,9 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_MUL:
{
GGML_ASSERT(ne00 % 4 == 0);
const int64_t nb = ne00/4;

if (ggml_nelements(src1) == ne10) {
// src1 is a row
[encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -706,9 +709,9 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&nb length:sizeof(nb) atIndex:3];

const int64_t n = ggml_nelements(dst);
const int64_t n = ggml_nelements(dst)/4;

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
67 changes: 40 additions & 27 deletions ggml-metal.metal
@@ -4,17 +4,22 @@ using namespace metal;

#define MAX(x, y) ((x) > (y) ? (x) : (y))

#define Q4_0DM (1.0f/8.0f)
#define Q4_0D(x) (((x)*Q4_0DM) / 127.0f)
#define QK4_0 32
#define QR4_0 2
typedef struct {
half d; // delta
int8_t d; // delta
uint8_t qs[QK4_0 / 2]; // nibbles / quants
} block_q4_0;

#define Q4_1DM (2.0f/15.0f)
#define Q4_1MM (2.0f )
#define Q4_1D(x) ( (((x) & 0xFF)*Q4_1DM) / 255.0f)
#define Q4_1M(x) (-1.0f + (((x) >> 8)*Q4_1MM) / 255.0f)
#define QK4_1 32
typedef struct {
half d; // delta
half m; // min
uint16_t dm;
uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;

@@ -44,22 +49,22 @@ kernel void kernel_add_row(
}

kernel void kernel_mul(
device const float * src0,
device const float * src1,
device float * dst,
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * src1[tpig];
}

// assumption: src1 is a row
// broadcast src1 into src0
kernel void kernel_mul_row(
device const float * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
constant int64_t & nb,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * src1[tpig % ne00];
dst[tpig] = src0[tpig] * src1[tpig % nb];
}
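
kernel_mul and kernel_mul_row now operate on float4, which is why the host side asserts ne00 % 4 == 0, passes nb = ne00/4, and dispatches ggml_nelements(dst)/4 threads. In scalar terms the broadcast multiply is unchanged; a plain C reference of what each group of four lanes computes, for illustration only:

#include <stdint.h>

// Scalar reference for the vectorized kernel_mul_row above: src1 is a single
// row of ne00 floats, broadcast across every row of src0.
void mul_row_ref(const float * src0, const float * src1, float * dst,
                 int64_t ne00, int64_t nelements) {
    for (int64_t i = 0; i < nelements; ++i) {
        dst[i] = src0[i] * src1[i % ne00];
    }
}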

kernel void kernel_scale(
@@ -314,14 +319,18 @@ kernel void kernel_rms_norm(
// we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;
float d = Q4_0D(qb_curr->d);
float2 acc = 0.f;
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
device const uint8_t * qs = ((device const uint8_t *)qb_curr->qs + il);
uint16_t qs16;
for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+ yl[i + 9] * (qs[i / 2] & 0xF000);
qs16 = qs[i+1];
qs16 <<= 8;
qs16 |= qs[i];
acc[0] += yl[i + 0] * (qs16 & 0x000F)
+ yl[i + 1] * (qs16 & 0x0F00);
acc[1] += yl[i + 8] * (qs16 & 0x00F0)
+ yl[i + 9] * (qs16 & 0xF000);
}
return d * (sumy * -8.f + acc[0] + acc[1]);
}
@@ -331,9 +340,9 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
// we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;
float m = qb_curr->m;
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
float d = Q4_1D(qb_curr->dm);
float m = Q4_1M(qb_curr->dm);
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
float2 acc = 0.f;
for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
@@ -1686,23 +1695,27 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)

template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
const half d = il ? (xb->d / 16.h) : xb->d;
device const uint8_t * qs = ((device const uint8_t *)xb->qs);
const half d = il ? (Q4_0D(xb->d) / 16.h) : Q4_0D(xb->d);
const half m = il ? ( -8.h * 16.h) : -8.h;
const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = il ? 0xF000 : 0x0F00;

uint16_t qs16;
for (int i=0;i<8;i++) {
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
qs16 = qs[2*i+1];
qs16 <<= 8;
qs16 |= qs[2*i];
reg[i/2][2*(i%2)] = (((qs16 & mask0) ) + m) * d;
reg[i/2][2*(i%2)+1] = (((qs16 & mask1) >> 8) + m) * d;
}
}

template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
const half d = il ? (xb->d / 16.h) : xb->d;
const half m = xb->m;
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
const half d = il ? (Q4_1D(xb->dm) / 16.h) : Q4_1D(xb->dm);
const half m = Q4_1M(xb->dm);
const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = il ? 0xF000 : 0x0F00;
