From 50148408b5b26db6cd3754c46ae80db7f36311eb Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Sat, 6 May 2023 12:17:45 +0200 Subject: [PATCH 1/4] More GPU threads for CUDA kernels --- ggml-cuda.cu | 207 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 121 insertions(+), 86 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index e8a1e77cb06fc..7c856e9eeef46 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -80,192 +80,227 @@ typedef struct { } block_q8_0; static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); -static __global__ void dequantize_block_q4_0(const void * vx, float * y) { +static __global__ void dequantize_block_q4_0(const void * vx, float * y, int k) { const block_q4_0 * x = (const block_q4_0 *) vx; - const int i = blockIdx.x; + const int i = blockIdx.x*blockDim.x + threadIdx.x; - const float d = x[i].d; + if (i < k) { + const float d = x[i].d; - const uint8_t * pp = x[i].qs; + const uint8_t * pp = x[i].qs; - for (int l = 0; l < QK4_0; l += 2) { - const uint8_t vi = pp[l/2]; + for (int l = 0; l < QK4_0; l += 2) { + const uint8_t vi = pp[l/2]; - const int8_t vi0 = vi & 0xf; - const int8_t vi1 = vi >> 4; + const int8_t vi0 = vi & 0xf; + const int8_t vi1 = vi >> 4; - const float v0 = (vi0 - 8)*d; - const float v1 = (vi1 - 8)*d; + const float v0 = (vi0 - 8)*d; + const float v1 = (vi1 - 8)*d; - y[i*QK4_0 + l + 0] = v0; - y[i*QK4_0 + l + 1] = v1; + y[i*QK4_0 + l + 0] = v0; + y[i*QK4_0 + l + 1] = v1; + } } } -static __global__ void dequantize_block_q4_1(const void * vx, float * y) { +static __global__ void dequantize_block_q4_1(const void * vx, float * y, int k) { const block_q4_1 * x = (const block_q4_1 *) vx; - const int i = blockIdx.x; + const int i = blockIdx.x*blockDim.x + threadIdx.x; - const float d = x[i].d; - const float m = x[i].m; + if (i < k) { + const float d = x[i].d; + const float m = x[i].m; - const uint8_t * pp = x[i].qs; + const uint8_t * pp = x[i].qs; - for (int l = 0; l < QK4_1; l += 2) { - const uint8_t vi = pp[l/2]; + for (int l = 0; l < QK4_1; l += 2) { + const uint8_t vi = pp[l/2]; - const int8_t vi0 = vi & 0xf; - const int8_t vi1 = vi >> 4; + const int8_t vi0 = vi & 0xf; + const int8_t vi1 = vi >> 4; - const float v0 = vi0*d + m; - const float v1 = vi1*d + m; + const float v0 = vi0*d + m; + const float v1 = vi1*d + m; - y[i*QK4_1 + l + 0] = v0; - y[i*QK4_1 + l + 1] = v1; + y[i*QK4_1 + l + 0] = v0; + y[i*QK4_1 + l + 1] = v1; + } } } -static __global__ void dequantize_block_q4_2(const void * vx, float * y) { +static __global__ void dequantize_block_q4_2(const void * vx, float * y, int k) { const block_q4_2 * x = (const block_q4_2 *) vx; - const int i = blockIdx.x; + const int i = blockIdx.x*blockDim.x + threadIdx.x; - const float d = x[i].d; + if (i < k) { + const float d = x[i].d; - const uint8_t * pp = x[i].qs; + const uint8_t * pp = x[i].qs; - for (int l = 0; l < QK4_2; l += 2) { - const uint8_t vi = pp[l/2]; + for (int l = 0; l < QK4_2; l += 2) { + const uint8_t vi = pp[l/2]; - const int8_t vi0 = vi & 0xf; - const int8_t vi1 = vi >> 4; + const int8_t vi0 = vi & 0xf; + const int8_t vi1 = vi >> 4; - const float v0 = (vi0 - 8)*d; - const float v1 = (vi1 - 8)*d; + const float v0 = (vi0 - 8)*d; + const float v1 = (vi1 - 8)*d; - y[i*QK4_2 + l + 0] = v0; - y[i*QK4_2 + l + 1] = v1; + y[i*QK4_2 + l + 0] = v0; + y[i*QK4_2 + l + 1] = v1; + } } } -static __global__ void dequantize_block_q5_0(const void * vx, float * y) { +static __global__ void dequantize_block_q5_0(const void * vx, float * y, int k) { const block_q5_0 * x = (const block_q5_0 *) vx; - const int i = blockIdx.x; + const int i = blockIdx.x*blockDim.x + threadIdx.x; - const float d = x[i].d; + if (i < k) { + const float d = x[i].d; - const uint8_t * pp = x[i].qs; + const uint8_t * pp = x[i].qs; - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); - for (int l = 0; l < QK5_0; l += 2) { - const uint8_t vi = pp[l/2]; + for (int l = 0; l < QK5_0; l += 2) { + const uint8_t vi = pp[l/2]; - const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; - const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; - const int8_t vi0 = ((vi & 0xf) | vh0); - const int8_t vi1 = ((vi >> 4) | vh1); + const int8_t vi0 = ((vi & 0xf) | vh0); + const int8_t vi1 = ((vi >> 4) | vh1); - const float v0 = (vi0 - 16)*d; - const float v1 = (vi1 - 16)*d; + const float v0 = (vi0 - 16)*d; + const float v1 = (vi1 - 16)*d; - y[i*QK5_0 + l + 0] = v0; - y[i*QK5_0 + l + 1] = v1; + y[i*QK5_0 + l + 0] = v0; + y[i*QK5_0 + l + 1] = v1; + } } } -static __global__ void dequantize_block_q5_1(const void * vx, float * y) { +static __global__ void dequantize_block_q5_1(const void * vx, float * y, int k) { const block_q5_1 * x = (const block_q5_1 *) vx; - const int i = blockIdx.x; + const int i = blockIdx.x*blockDim.x + threadIdx.x; - const float d = x[i].d; - const float m = x[i].m; + if (i < k) { + const float d = x[i].d; + const float m = x[i].m; - const uint8_t * pp = x[i].qs; + const uint8_t * pp = x[i].qs; - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); - for (int l = 0; l < QK5_1; l += 2) { - const uint8_t vi = pp[l/2]; + for (int l = 0; l < QK5_1; l += 2) { + const uint8_t vi = pp[l/2]; - const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; - const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; - const int8_t vi0 = (vi & 0xf) | vh0; - const int8_t vi1 = (vi >> 4) | vh1; + const int8_t vi0 = (vi & 0xf) | vh0; + const int8_t vi1 = (vi >> 4) | vh1; - const float v0 = vi0*d + m; - const float v1 = vi1*d + m; + const float v0 = vi0*d + m; + const float v1 = vi1*d + m; - y[i*QK5_1 + l + 0] = v0; - y[i*QK5_1 + l + 1] = v1; + y[i*QK5_1 + l + 0] = v0; + y[i*QK5_1 + l + 1] = v1; + } } } -static __global__ void dequantize_block_q8_0(const void * vx, float * y) { +static __global__ void dequantize_block_q8_0(const void * vx, float * y, int k) { const block_q8_0 * x = (const block_q8_0 *) vx; - const int i = blockIdx.x; + const int i = blockIdx.x*blockDim.x + threadIdx.x; - const float d = x[i].d; + if (i < k) { + const float d = x[i].d; - const int8_t * pp = x[i].qs; + const int8_t * pp = x[i].qs; - for (int l = 0; l < QK8_0; l++) { - const int8_t vi = pp[l]; + for (int l = 0; l < QK8_0; l++) { + const int8_t vi = pp[l]; - y[i*QK8_0 + l] = vi*d; + y[i*QK8_0 + l] = vi*d; + } } } static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK4_0; - dequantize_block_q4_0<<>>(vx, y); + int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. + CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_0, 0, 0)); + int grid_size = (nb + block_size - 1) / block_size; // Round up. + dequantize_block_q4_0<<>>(vx, y, k); } static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK4_1; - dequantize_block_q4_1<<>>(vx, y); + int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. + CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_1, 0, 0)); + int grid_size = (nb + block_size - 1) / block_size; // Round up. + dequantize_block_q4_1<<>>(vx, y, k); } static void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK4_2; - dequantize_block_q4_2<<>>(vx, y); + int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. + CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_2, 0, 0)); + int grid_size = (nb + block_size - 1) / block_size; // Round up. + dequantize_block_q4_2<<>>(vx, y, k); } static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK5_0; - dequantize_block_q5_0<<>>(vx, y); + int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. + CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q5_0, 0, 0)); + int grid_size = (nb + block_size - 1) / block_size; // Round up. + dequantize_block_q5_0<<>>(vx, y, k); } static void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK5_1; - dequantize_block_q5_1<<>>(vx, y); + int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. + CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q5_1, 0, 0)); + int grid_size = (nb + block_size - 1) / block_size; // Round up. + dequantize_block_q5_1<<>>(vx, y, k); } static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK8_0; - dequantize_block_q8_0<<>>(vx, y); + int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. + CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q8_0, 0, 0)); + int grid_size = (nb + block_size - 1) / block_size; // Round up. + dequantize_block_q8_0<<>>(vx, y, k); } // TODO: optimize -static __global__ void convert_fp16_to_fp32(const void * vx, float * y) { +static __global__ void convert_fp16_to_fp32(const void * vx, float * y, int k) { const half * x = (const half *) vx; const int i = blockIdx.x; - y[i] = __half2float(x[i]); + if (i < k) { + y[i] = __half2float(x[i]); + } } static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStream_t stream) { - convert_fp16_to_fp32<<>>(x, y); + int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. + CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, convert_fp16_to_fp32, 0, 0)); + int grid_size = (k + block_size - 1) / block_size; // Round up. + convert_fp16_to_fp32<<>>(x, y, k); } static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { From 8d8de07a4ebce80620da080eb152e8c671b631b3 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Sun, 7 May 2023 18:34:04 +0200 Subject: [PATCH 2/4] fixup! More GPU threads for CUDA kernels --- ggml-cuda.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 7c856e9eeef46..2f7f02162b7a9 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -242,7 +242,7 @@ static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStre int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_0, 0, 0)); int grid_size = (nb + block_size - 1) / block_size; // Round up. - dequantize_block_q4_0<<>>(vx, y, k); + dequantize_block_q4_0<<>>(vx, y, nb); } static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { @@ -250,7 +250,7 @@ static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStre int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_1, 0, 0)); int grid_size = (nb + block_size - 1) / block_size; // Round up. - dequantize_block_q4_1<<>>(vx, y, k); + dequantize_block_q4_1<<>>(vx, y, nb); } static void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) { @@ -258,7 +258,7 @@ static void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStre int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_2, 0, 0)); int grid_size = (nb + block_size - 1) / block_size; // Round up. - dequantize_block_q4_2<<>>(vx, y, k); + dequantize_block_q4_2<<>>(vx, y, nb); } static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { @@ -266,7 +266,7 @@ static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStre int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q5_0, 0, 0)); int grid_size = (nb + block_size - 1) / block_size; // Round up. - dequantize_block_q5_0<<>>(vx, y, k); + dequantize_block_q5_0<<>>(vx, y, nb); } static void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { @@ -274,7 +274,7 @@ static void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStre int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q5_1, 0, 0)); int grid_size = (nb + block_size - 1) / block_size; // Round up. - dequantize_block_q5_1<<>>(vx, y, k); + dequantize_block_q5_1<<>>(vx, y, nb); } static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { @@ -282,7 +282,7 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q8_0, 0, 0)); int grid_size = (nb + block_size - 1) / block_size; // Round up. - dequantize_block_q8_0<<>>(vx, y, k); + dequantize_block_q8_0<<>>(vx, y, nb); } // TODO: optimize From d0199b3bc3acdd2ea2caae75992e64bd669ebec2 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Mon, 8 May 2023 12:56:32 +0200 Subject: [PATCH 3/4] fixup! More GPU threads for CUDA kernels --- ggml-cuda.cu | 246 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 145 insertions(+), 101 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 2f7f02162b7a9..1ce3c01eb9fa0 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -34,6 +34,8 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream); +#define GGML_CUDA_MAX_BLOCK_SIZE 256 + #define QK4_0 32 typedef struct { float d; // delta @@ -85,23 +87,25 @@ static __global__ void dequantize_block_q4_0(const void * vx, float * y, int k) const int i = blockIdx.x*blockDim.x + threadIdx.x; - if (i < k) { - const float d = x[i].d; + if (i >= k) { + return; + } + + const float d = x[i].d; - const uint8_t * pp = x[i].qs; + const uint8_t * pp = x[i].qs; - for (int l = 0; l < QK4_0; l += 2) { - const uint8_t vi = pp[l/2]; + for (int l = 0; l < QK4_0; l += 2) { + const uint8_t vi = pp[l/2]; - const int8_t vi0 = vi & 0xf; - const int8_t vi1 = vi >> 4; + const int8_t vi0 = vi & 0xf; + const int8_t vi1 = vi >> 4; - const float v0 = (vi0 - 8)*d; - const float v1 = (vi1 - 8)*d; + const float v0 = (vi0 - 8)*d; + const float v1 = (vi1 - 8)*d; - y[i*QK4_0 + l + 0] = v0; - y[i*QK4_0 + l + 1] = v1; - } + y[i*QK4_0 + l + 0] = v0; + y[i*QK4_0 + l + 1] = v1; } } @@ -110,24 +114,26 @@ static __global__ void dequantize_block_q4_1(const void * vx, float * y, int k) const int i = blockIdx.x*blockDim.x + threadIdx.x; - if (i < k) { - const float d = x[i].d; - const float m = x[i].m; + if (i >= k) { + return; + } - const uint8_t * pp = x[i].qs; + const float d = x[i].d; + const float m = x[i].m; - for (int l = 0; l < QK4_1; l += 2) { - const uint8_t vi = pp[l/2]; + const uint8_t * pp = x[i].qs; - const int8_t vi0 = vi & 0xf; - const int8_t vi1 = vi >> 4; + for (int l = 0; l < QK4_1; l += 2) { + const uint8_t vi = pp[l/2]; - const float v0 = vi0*d + m; - const float v1 = vi1*d + m; + const int8_t vi0 = vi & 0xf; + const int8_t vi1 = vi >> 4; - y[i*QK4_1 + l + 0] = v0; - y[i*QK4_1 + l + 1] = v1; - } + const float v0 = vi0*d + m; + const float v1 = vi1*d + m; + + y[i*QK4_1 + l + 0] = v0; + y[i*QK4_1 + l + 1] = v1; } } @@ -136,23 +142,25 @@ static __global__ void dequantize_block_q4_2(const void * vx, float * y, int k) const int i = blockIdx.x*blockDim.x + threadIdx.x; - if (i < k) { - const float d = x[i].d; + if (i >= k) { + return; + } - const uint8_t * pp = x[i].qs; + const float d = x[i].d; - for (int l = 0; l < QK4_2; l += 2) { - const uint8_t vi = pp[l/2]; + const uint8_t * pp = x[i].qs; - const int8_t vi0 = vi & 0xf; - const int8_t vi1 = vi >> 4; + for (int l = 0; l < QK4_2; l += 2) { + const uint8_t vi = pp[l/2]; - const float v0 = (vi0 - 8)*d; - const float v1 = (vi1 - 8)*d; + const int8_t vi0 = vi & 0xf; + const int8_t vi1 = vi >> 4; - y[i*QK4_2 + l + 0] = v0; - y[i*QK4_2 + l + 1] = v1; - } + const float v0 = (vi0 - 8)*d; + const float v1 = (vi1 - 8)*d; + + y[i*QK4_2 + l + 0] = v0; + y[i*QK4_2 + l + 1] = v1; } } @@ -161,29 +169,31 @@ static __global__ void dequantize_block_q5_0(const void * vx, float * y, int k) const int i = blockIdx.x*blockDim.x + threadIdx.x; - if (i < k) { - const float d = x[i].d; + if (i >= k) { + return; + } - const uint8_t * pp = x[i].qs; + const float d = x[i].d; - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); + const uint8_t * pp = x[i].qs; - for (int l = 0; l < QK5_0; l += 2) { - const uint8_t vi = pp[l/2]; + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); - const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; - const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + for (int l = 0; l < QK5_0; l += 2) { + const uint8_t vi = pp[l/2]; - const int8_t vi0 = ((vi & 0xf) | vh0); - const int8_t vi1 = ((vi >> 4) | vh1); + const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; - const float v0 = (vi0 - 16)*d; - const float v1 = (vi1 - 16)*d; + const int8_t vi0 = ((vi & 0xf) | vh0); + const int8_t vi1 = ((vi >> 4) | vh1); - y[i*QK5_0 + l + 0] = v0; - y[i*QK5_0 + l + 1] = v1; - } + const float v0 = (vi0 - 16)*d; + const float v1 = (vi1 - 16)*d; + + y[i*QK5_0 + l + 0] = v0; + y[i*QK5_0 + l + 1] = v1; } } @@ -192,30 +202,32 @@ static __global__ void dequantize_block_q5_1(const void * vx, float * y, int k) const int i = blockIdx.x*blockDim.x + threadIdx.x; - if (i < k) { - const float d = x[i].d; - const float m = x[i].m; + if (i >= k) { + return; + } + + const float d = x[i].d; + const float m = x[i].m; - const uint8_t * pp = x[i].qs; + const uint8_t * pp = x[i].qs; - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); - for (int l = 0; l < QK5_1; l += 2) { - const uint8_t vi = pp[l/2]; + for (int l = 0; l < QK5_1; l += 2) { + const uint8_t vi = pp[l/2]; - const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; - const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; - const int8_t vi0 = (vi & 0xf) | vh0; - const int8_t vi1 = (vi >> 4) | vh1; + const int8_t vi0 = (vi & 0xf) | vh0; + const int8_t vi1 = (vi >> 4) | vh1; - const float v0 = vi0*d + m; - const float v1 = vi1*d + m; + const float v0 = vi0*d + m; + const float v1 = vi1*d + m; - y[i*QK5_1 + l + 0] = v0; - y[i*QK5_1 + l + 1] = v1; - } + y[i*QK5_1 + l + 0] = v0; + y[i*QK5_1 + l + 1] = v1; } } @@ -224,64 +236,90 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y, int k) const int i = blockIdx.x*blockDim.x + threadIdx.x; - if (i < k) { - const float d = x[i].d; + if (i >= k) { + return; + } - const int8_t * pp = x[i].qs; + const float d = x[i].d; - for (int l = 0; l < QK8_0; l++) { - const int8_t vi = pp[l]; + const int8_t * pp = x[i].qs; - y[i*QK8_0 + l] = vi*d; - } + for (int l = 0; l < QK8_0; l++) { + const int8_t vi = pp[l]; + + y[i*QK8_0 + l] = vi*d; } } static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK4_0; - int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. - CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_0, 0, 0)); - int grid_size = (nb + block_size - 1) / block_size; // Round up. + static int grid_size, block_size = -1; + if (block_size == -1) { + int min_grid_size; + CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_0, 0, 0)); + block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE); + grid_size = (nb + block_size - 1) / block_size; // Round up. + } dequantize_block_q4_0<<>>(vx, y, nb); } static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK4_1; - int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. - CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_1, 0, 0)); - int grid_size = (nb + block_size - 1) / block_size; // Round up. + static int grid_size, block_size = -1; + if (block_size == -1) { + int min_grid_size; + CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_1, 0, 0)); + block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE); + grid_size = (nb + block_size - 1) / block_size; // Round up. + } dequantize_block_q4_1<<>>(vx, y, nb); } static void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK4_2; - int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. - CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_2, 0, 0)); - int grid_size = (nb + block_size - 1) / block_size; // Round up. + static int grid_size, block_size = -1; + if (block_size == -1) { + int min_grid_size; + CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_2, 0, 0)); + block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE); + grid_size = (nb + block_size - 1) / block_size; // Round up. + } dequantize_block_q4_2<<>>(vx, y, nb); } static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK5_0; - int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. - CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q5_0, 0, 0)); - int grid_size = (nb + block_size - 1) / block_size; // Round up. + static int grid_size, block_size = -1; + if (block_size == -1) { + int min_grid_size; + CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q5_0, 0, 0)); + block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE); + grid_size = (nb + block_size - 1) / block_size; // Round up. + } dequantize_block_q5_0<<>>(vx, y, nb); } static void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK5_1; - int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. - CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q5_1, 0, 0)); - int grid_size = (nb + block_size - 1) / block_size; // Round up. + static int grid_size, block_size = -1; + if (block_size == -1) { + int min_grid_size; + CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q5_1, 0, 0)); + block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE); + grid_size = (nb + block_size - 1) / block_size; // Round up. + } dequantize_block_q5_1<<>>(vx, y, nb); } static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK8_0; - int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. - CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q8_0, 0, 0)); - int grid_size = (nb + block_size - 1) / block_size; // Round up. + static int grid_size, block_size = -1; + if (block_size == -1) { + int min_grid_size; + CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q8_0, 0, 0)); + block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE); + grid_size = (nb + block_size - 1) / block_size; // Round up. + } dequantize_block_q8_0<<>>(vx, y, nb); } @@ -289,17 +327,23 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre static __global__ void convert_fp16_to_fp32(const void * vx, float * y, int k) { const half * x = (const half *) vx; - const int i = blockIdx.x; + const int i = blockIdx.x*blockDim.x + threadIdx.x; - if (i < k) { - y[i] = __half2float(x[i]); + if (i >= k) { + return; } + + y[i] = __half2float(x[i]); } static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStream_t stream) { - int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. - CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, convert_fp16_to_fp32, 0, 0)); - int grid_size = (k + block_size - 1) / block_size; // Round up. + static int grid_size, block_size = -1; + if (block_size == -1) { + int min_grid_size; + CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, convert_fp16_to_fp32, 0, 0)); + block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE); + grid_size = (k + block_size - 1) / block_size; // Round up. + } convert_fp16_to_fp32<<>>(x, y, k); } From 006db8e0bbb826e7096234503dbac0a0f8c5f930 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Mon, 8 May 2023 19:45:02 +0200 Subject: [PATCH 4/4] fixup! More GPU threads for CUDA kernels --- ggml-cuda.cu | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 1ce3c01eb9fa0..0469852770e35 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -253,73 +253,73 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y, int k) static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK4_0; - static int grid_size, block_size = -1; + static int block_size = -1; if (block_size == -1) { int min_grid_size; CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_0, 0, 0)); block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE); - grid_size = (nb + block_size - 1) / block_size; // Round up. } + const int grid_size = (nb + block_size - 1) / block_size; // Round up. dequantize_block_q4_0<<>>(vx, y, nb); } static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK4_1; - static int grid_size, block_size = -1; + static int block_size = -1; if (block_size == -1) { int min_grid_size; CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_1, 0, 0)); block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE); - grid_size = (nb + block_size - 1) / block_size; // Round up. } + const int grid_size = (nb + block_size - 1) / block_size; // Round up. dequantize_block_q4_1<<>>(vx, y, nb); } static void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK4_2; - static int grid_size, block_size = -1; + static int block_size = -1; if (block_size == -1) { int min_grid_size; CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_2, 0, 0)); block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE); - grid_size = (nb + block_size - 1) / block_size; // Round up. } + const int grid_size = (nb + block_size - 1) / block_size; // Round up. dequantize_block_q4_2<<>>(vx, y, nb); } static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK5_0; - static int grid_size, block_size = -1; + static int block_size = -1; if (block_size == -1) { int min_grid_size; CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q5_0, 0, 0)); block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE); - grid_size = (nb + block_size - 1) / block_size; // Round up. } + const int grid_size = (nb + block_size - 1) / block_size; // Round up. dequantize_block_q5_0<<>>(vx, y, nb); } static void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK5_1; - static int grid_size, block_size = -1; + static int block_size = -1; if (block_size == -1) { int min_grid_size; CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q5_1, 0, 0)); block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE); - grid_size = (nb + block_size - 1) / block_size; // Round up. } + const int grid_size = (nb + block_size - 1) / block_size; // Round up. dequantize_block_q5_1<<>>(vx, y, nb); } static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK8_0; - static int grid_size, block_size = -1; + static int block_size = -1; if (block_size == -1) { int min_grid_size; CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q8_0, 0, 0)); block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE); - grid_size = (nb + block_size - 1) / block_size; // Round up. } + const int grid_size = (nb + block_size - 1) / block_size; // Round up. dequantize_block_q8_0<<>>(vx, y, nb); } @@ -337,13 +337,13 @@ static __global__ void convert_fp16_to_fp32(const void * vx, float * y, int k) { } static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStream_t stream) { - static int grid_size, block_size = -1; + static int block_size = -1; if (block_size == -1) { int min_grid_size; CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, convert_fp16_to_fp32, 0, 0)); block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE); - grid_size = (k + block_size - 1) / block_size; // Round up. } + const int grid_size = (k + block_size - 1) / block_size; // Round up. convert_fp16_to_fp32<<>>(x, y, k); }