Commit fc4ceb7

quantize/dequantize, mul_mat_vec kernels

1 parent 8ddc0e9 commit fc4ceb7

File tree

3 files changed: +217 −1 lines changed

Lines changed: 92 additions & 1 deletion

#define QK8_0 32
#define QR8_0 1
#define WARP_SIZE 32

typedef struct {
    half d;           // scale (delta)
    int8_t qs[QK8_0]; // 32 quantized values
} block_q8_0;

static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v) {
    // uniform symmetric quantization
    // dequantize is int8 * scale
    const block_q8_0 * x = (const block_q8_0 *) vx;
    const dfloat d = x[ib].d; // scale

    v.x = x[ib].qs[iqs + 0];
    v.y = x[ib].qs[iqs + 1];

#ifdef GGML_CUDA_FP16
    // FP16
    v = __hmul2(v, {d, d});
#else
    // FP32
    v.x *= d;
    v.y *= d;
#endif
}

template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y,
                                              float * __restrict__ dst, const int ncols, const int nrows)
{
    // qk = quantized weights per x block
    // qr = number of quantized weights per data value in x block
    const int row = blockIdx.y * blockDim.y + threadIdx.y;

    if (row >= nrows) return;

    const int tid = threadIdx.x;
    const int iter_stride = 2 * GGML_CUDA_DMMV_X; // 2*32
    const int vals_per_iter = iter_stride / WARP_SIZE; // values each thread handles per iteration
    const int y_offset = qr == 1 ? 1 : qk/2;

#ifdef GGML_CUDA_FP16
    half2 tmp = {0.0f, 0.0f};
#else
    float tmp = 0.0f;
#endif
    // one warp of 32 threads accumulates one row (e.g. ncols = 4096)
    for (int i = 0; i < ncols; i += iter_stride) {
        const int col = i + vals_per_iter * tid;
        const int ib = (row * ncols + col) / qk; // x block index
        const int iqs = (col % qk) / qr;         // x quant index
        const int iybs = col - col % qk;         // y block start index

        for (int j = 0; j < vals_per_iter; j += 2) {
            // process 2 vals per j iter

            // dequantize
            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
            dfloat2 v;
            dequantize_kernel(vx, ib, iqs + j/qr, v);

#ifdef GGML_CUDA_FP16
            tmp += __hmul2(v, {
                y[iybs + iqs + j/qr + 0],
                y[iybs + iqs + j/qr + y_offset]
            });
#else
            tmp += v.x * y[iybs + iqs + j/qr + 0];
            tmp += v.y * y[iybs + iqs + j/qr + y_offset];
#endif
        }
    }

    // sum up partial sums within the warp
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }

    if (tid == 0) {
#ifdef GGML_CUDA_FP16
        dst[row] = tmp.x + tmp.y;
#else
        dst[row] = tmp;
#endif
    }
}

static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y,
                                             float * dst, const int ncols, const int nrows, cudaStream_t stream)
{
    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(1, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
    dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
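
The code above leans on a few typedefs and tuning constants that this commit does not define. A minimal sketch of the assumed supporting definitions, modeled on llama.cpp's ggml-cuda (the exact values are assumptions, not part of this commit):

// Assumed supporting definitions (sketch, not part of this commit)
#ifdef GGML_CUDA_FP16
typedef half  dfloat; // dequantize float
typedef half2 dfloat2;
#else
typedef float  dfloat;
typedef float2 dfloat2;
#endif

// signature of per-type dequantize callbacks such as dequantize_q8_0
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);

#define GGML_CUDA_DMMV_X 32 // x values processed per thread-block iteration
#define GGML_CUDA_MMV_Y  1  // rows processed per thread block

With GGML_CUDA_MMV_Y = 1, each thread block is a single 32-thread warp accumulating one output row, which is why the final reduction is a pure warp shuffle.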

kernels/mul_mat_vec_q8_0_q8_1.cu

Lines changed: 82 additions & 0 deletions

#define VDR_Q8_0_Q8_1_MMVQ 2

template <int vdr>
static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl(
    const int * v, const int * u, const float & d8_0, const float & d8_1)
{
#if __CUDA_ARCH__ >= MIN_CC_DP4A
    int sumi = 0;
#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        // SIMD dot product: 4 int8 pairs per __dp4a
        sumi = __dp4a(v[i], u[i], sumi);
    }
    return d8_0 * d8_1 * sumi;
#else
    assert(false);
    return 0.0f; // only to satisfy the compiler
#endif
}

static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs)
{
    const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;

    int v[VDR_Q8_0_Q8_1_MMVQ];
    int u[VDR_Q8_0_Q8_1_MMVQ];

#pragma unroll
    for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) {
        v[i] = get_int_from_int8(bq8_0->qs, iqs + i);
        u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
    }

    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, bq8_1->ds.x);
}

template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy,
                                     float * __restrict__ dst, const int ncols, const int nrows)
{
    const int row = blockIdx.y * blockDim.y + threadIdx.y;

    if (row >= nrows) return;

    const int blocks_per_row = ncols / qk;
    const int blocks_per_warp = vdr * WARP_SIZE / qi;

    float tmp = 0.0f; // partial sum for this thread

    const block_q_t  * x = (const block_q_t  *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    for (int i = 0; i < blocks_per_row; i += blocks_per_warp) {
        const int ibx = row * blocks_per_row + i + threadIdx.x / (qi / vdr); // x block index

        const int iby = (i + threadIdx.x / (qi / vdr)) * (qk / QK8_1); // y block index matching ibx
        const int iqs = vdr * (threadIdx.x % (qi / vdr)); // quant index when casting the quants to int

        tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
    }

    // sum up partial sums within the warp
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }

    if (threadIdx.x == 0) {
        dst[row] = tmp;
    }
}

static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst,
                                       const int ncols, const int nrows, cudaStream_t stream)
{
    GGML_ASSERT(ncols % QK8_0 == 0);
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(1, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
    // QK8_0 = 32, QI8_0 = 8
    mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
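
vec_dot_q8_0_q8_1 depends on two int-load helpers, the MIN_CC_DP4A cutoff, and the vec_dot_q_cuda_t callback type, none of which appear in this commit. A hedged sketch modeled on llama.cpp's ggml-cuda:

// Assumed helpers (sketch, not part of this commit)
#define MIN_CC_DP4A 610 // __dp4a requires compute capability >= 6.1
#define QI8_0 (QK8_0 / (4 * QR8_0)) // = 8 ints of quants per q8_0 block

typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq,
    const block_q8_1 * __restrict__ bq8_1, const int & iqs);

// block_q8_0.qs follows a half, so it may be only 2-byte aligned:
// assemble the int from two 16-bit loads
static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
    const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32);
    int x32 = 0;
    x32 |= x16[0] <<  0;
    x32 |= x16[1] << 16;
    return x32;
}

// block_q8_1.qs follows a half2, so a 4-byte aligned int load is safe
static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) {
    return *((const int *) (x8 + sizeof(int) * i32));
}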

kernels/quantize_q8_1.cu

Lines changed: 43 additions & 0 deletions

#define QK8_1 32

typedef struct {
    half2 ds;         // ds.x = delta (scale), ds.y = sum of the block's values
    int8_t qs[QK8_1]; // 32 quantized values
} block_q8_1;

static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx,
                                     const int kx_padded)
{
    const int ix = blockDim.x * blockIdx.x + threadIdx.x; // column index, 0 .. kx_padded-1

    if (ix >= kx_padded) return;

    const int iy = blockDim.y * blockIdx.y + threadIdx.y; // row index
    const int i_padded = iy * kx_padded + ix;
    block_q8_1 * y = (block_q8_1 *) vy;

    const int ib = i_padded / QK8_1;  // block index
    const int iqs = i_padded % QK8_1; // quant index

    const float xi = ix < kx ? x[iy * kx + ix] : 0.0f;
    float amax = fabsf(xi);
    float sum = xi;

    // warp-wide reduction: each warp covers exactly one QK8_1 = 32 block
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
        sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
    }

    // q = round(clip(r_i / scale, q_min, q_max))
    // scale = (f_max - f_min) / (q_max - q_min) = amax / 127 for the symmetric int8 range
    const float d = amax / 127;
    const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);

    y[ib].qs[iqs] = q;

    if (iqs > 0) return;

    // lane 0 writes the block's scale and sum
    y[ib].ds.x = d;
    y[ib].ds.y = sum;
}
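
The commit ships only the kernel; a host-side launcher in the same style as the other two files might look like the sketch below. CUDA_QUANTIZE_BLOCK_SIZE and the name quantize_row_q8_1_cuda are assumptions modeled on llama.cpp, not part of this commit. The kernel is correct as long as each 32-thread warp lands on exactly one QK8_1 = 32 block, which holds for any block size that is a multiple of 32.

// Hypothetical launcher (sketch, not part of this commit)
#define CUDA_QUANTIZE_BLOCK_SIZE 256

static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx,
                                   const int kx_padded, const int ky, cudaStream_t stream)
{
    // one thread per padded element; rows are spread over the y grid dimension
    const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
    const dim3 num_blocks(block_num_x, ky, 1);
    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
}

As a worked example of the scale math: with amax = 2.54 over a block, d = 2.54/127 = 0.02, and a value 1.27 quantizes to roundf(1.27/0.02) = 64.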
