Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 388ce82

Browse files
ggml : extend ggml_pool_1d + metal (ggml-org#16429)
* chore: resolve conflicts * feat: ggml metal impl * fix: ggml_metal_kargs_pool_1d struct * fix: require contiguous input * chore: test pool_1d * chore: limit pool1d test cases to p0=0 and s0=k0 to conform with asserts * chore: add p0 and s0 to testing * fix: allow padding for cpu and metal * Update ggml/src/ggml-metal/ggml-metal.metal * fix: correct single-threaded loop * ggml : cleanup * tests : add ne[1] != 1 tests * fix: ne[1] handling in np * cont : fixes --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 6ba6a3c commit 388ce82

File tree

10 files changed

+271
-39
lines changed

10 files changed

+271
-39
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 63 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77
#include "unary-ops.h"
88
#include "vec.h"
99

10-
#include <cfloat>
1110
#include <algorithm>
11+
#include <cfloat>
1212
#include <cmath>
13-
#include <functional>
1413

1514
// ggml_compute_forward_dup
1615

@@ -7110,12 +7109,13 @@ void ggml_compute_forward_conv_2d_dw(
71107109
}
71117110
}
71127111

7113-
// ggml_compute_forward_pool_1d_sk_p0
7114-
7115-
static void ggml_compute_forward_pool_1d_sk_p0(
7112+
// ggml_compute_forward_pool_1d_ksp
7113+
static void ggml_compute_forward_pool_1d_ksp(
71167114
const ggml_compute_params * params,
71177115
const ggml_op_pool op,
71187116
const int k,
7117+
const int s,
7118+
const int p,
71197119
ggml_tensor * dst) {
71207120

71217121
const ggml_tensor * src = dst->src[0];
@@ -7126,39 +7126,56 @@ static void ggml_compute_forward_pool_1d_sk_p0(
71267126
return;
71277127
}
71287128

7129-
const char * cdata = (const char *)src->data;
7130-
const char * const data_end = cdata + ggml_nbytes(src);
7131-
float * drow = (float *)dst->data;
7129+
const int64_t IW = src->ne[0];
7130+
const int64_t OW = dst->ne[0];
71327131

7133-
const int64_t rs = dst->ne[0];
7132+
const int64_t nr = ggml_nrows(src);
71347133

7135-
while (cdata < data_end) {
7136-
const void * srow = (const void *)cdata;
7137-
int j = 0;
7138-
for (int64_t i = 0; i < rs; ++i) {
7134+
for (int64_t ir = 0; ir < nr; ++ir) {
7135+
const char * srow_bytes = (const char *) src->data + ir * src->nb[1];
7136+
float * drow = (float *) (( char *) dst->data + ir * dst->nb[1]);
7137+
7138+
for (int64_t ow = 0; ow < OW; ++ow) {
7139+
float res = 0;
71397140
switch (op) {
7140-
case GGML_OP_POOL_AVG: drow[i] = 0; break;
7141-
case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
7141+
case GGML_OP_POOL_AVG: res = 0.0f; break;
7142+
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
71427143
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
71437144
}
7145+
7146+
int count = 0;
7147+
const int base = (int) ow * s - p;
7148+
71447149
for (int ki = 0; ki < k; ++ki) {
7145-
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
7150+
const int j = base + ki;
7151+
if (j < 0 || j >= (int) IW) {
7152+
continue;
7153+
}
7154+
7155+
float v;
7156+
if (src->type == GGML_TYPE_F32) {
7157+
v = ((const float *) srow_bytes)[j];
7158+
} else {
7159+
v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]);
7160+
}
7161+
71467162
switch (op) {
7147-
case GGML_OP_POOL_AVG: drow[i] += srow_j; break;
7148-
case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
7149-
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7163+
case GGML_OP_POOL_AVG: res += v; break;
7164+
case GGML_OP_POOL_MAX: res = std::max(v, res); break;
7165+
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
71507166
}
7151-
++j;
7167+
7168+
++count;
71527169
}
7170+
71537171
switch (op) {
7154-
case GGML_OP_POOL_AVG: drow[i] /= k; break;
7155-
case GGML_OP_POOL_MAX: break;
7172+
case GGML_OP_POOL_AVG: res = (count > 0) ? (res / count) : 0.0f; break;
7173+
case GGML_OP_POOL_MAX: break;
71567174
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
71577175
}
7158-
}
71597176

7160-
cdata += src->nb[1];
7161-
drow += rs;
7177+
drow[ow] = res;
7178+
}
71627179
}
71637180
}
71647181

@@ -7173,10 +7190,8 @@ void ggml_compute_forward_pool_1d(
71737190
const int k0 = opts[1];
71747191
const int s0 = opts[2];
71757192
const int p0 = opts[3];
7176-
GGML_ASSERT(p0 == 0); // padding not supported
7177-
GGML_ASSERT(k0 == s0); // only s = k supported
71787193

7179-
ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
7194+
ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst);
71807195
}
71817196

71827197
// ggml_compute_forward_pool_2d
@@ -7194,6 +7209,7 @@ void ggml_compute_forward_pool_2d(
71947209
}
71957210

71967211
const int32_t * opts = (const int32_t *)dst->op_params;
7212+
71977213
ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
71987214
const int k0 = opts[1];
71997215
const int k1 = opts[2];
@@ -7217,36 +7233,46 @@ void ggml_compute_forward_pool_2d(
72177233
while (cdata < data_end) {
72187234
for (int oy = 0; oy < py; ++oy) {
72197235
float * const drow = dplane + oy * px;
7236+
float * const out = drow;
7237+
72207238
for (int ox = 0; ox < px; ++ox) {
7221-
float * const out = drow + ox;
7239+
float res = 0;
72227240
switch (op) {
7223-
case GGML_OP_POOL_AVG: *out = 0; break;
7224-
case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
7241+
case GGML_OP_POOL_AVG: res = 0; break;
7242+
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
72257243
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
72267244
}
72277245

72287246
const int ix = offset0 + ox * s0;
72297247
const int iy = offset1 + oy * s1;
72307248

72317249
for (int ky = 0; ky < k1; ++ky) {
7232-
if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
7250+
if (iy + ky < 0 || iy + ky >= src->ne[1]) {
7251+
continue;
7252+
}
7253+
72337254
const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
72347255
for (int kx = 0; kx < k0; ++kx) {
72357256
int j = ix + kx;
7236-
if (j < 0 || j >= src->ne[0]) continue;
7257+
if (j < 0 || j >= src->ne[0]) {
7258+
continue;
7259+
}
7260+
72377261
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
72387262
switch (op) {
7239-
case GGML_OP_POOL_AVG: *out += srow_j; break;
7240-
case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
7263+
case GGML_OP_POOL_AVG: res += srow_j; break;
7264+
case GGML_OP_POOL_MAX: res = std::max(srow_j, res); break;
72417265
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
72427266
}
72437267
}
72447268
}
72457269
switch (op) {
7246-
case GGML_OP_POOL_AVG: *out /= ka; break;
7247-
case GGML_OP_POOL_MAX: break;
7270+
case GGML_OP_POOL_AVG: res /= ka; break;
7271+
case GGML_OP_POOL_MAX: break;
72487272
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
72497273
}
7274+
7275+
out[ox] = res;
72507276
}
72517277
}
72527278

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,31 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_l
9494
return res;
9595
}
9696

97+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
98+
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
99+
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
100+
101+
const char * pool_str = "undefined";
102+
switch (op_pool) {
103+
case GGML_OP_POOL_AVG: pool_str = "avg"; break;
104+
case GGML_OP_POOL_MAX: pool_str = "max"; break;
105+
default: GGML_ASSERT(false && "not implemented");
106+
};
107+
108+
char base[256];
109+
char name[256];
110+
111+
snprintf(base, sizeof(base), "kernel_pool_1d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
112+
snprintf(name, sizeof(name), "%s", base);
113+
114+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
115+
if (!res.pipeline) {
116+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
117+
}
118+
119+
return res;
120+
}
121+
97122
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
98123
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
99124
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_
104104

105105
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op);
106106
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
107+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
107108
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
108109
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
109110
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,10 +1044,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
10441044
op->src[1]->type == GGML_TYPE_F32 &&
10451045
op->type == GGML_TYPE_F32 &&
10461046
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
1047-
case GGML_OP_POOL_1D:
1048-
return false;
10491047
case GGML_OP_UPSCALE:
10501048
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
1049+
case GGML_OP_POOL_1D:
1050+
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
10511051
case GGML_OP_POOL_2D:
10521052
return op->src[0]->type == GGML_TYPE_F32;
10531053
case GGML_OP_PAD:

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -928,6 +928,15 @@ typedef struct {
928928
int64_t np;
929929
} ggml_metal_kargs_pool_2d;
930930

931+
typedef struct {
932+
int32_t k0;
933+
int32_t s0;
934+
int32_t p0;
935+
int64_t IW;
936+
int64_t OW;
937+
int64_t np;
938+
} ggml_metal_kargs_pool_1d;
939+
931940
typedef struct {
932941
int64_t ne00;
933942
uint64_t nb01;

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
432432
{
433433
n_fuse = ggml_metal_op_cpy(ctx, idx);
434434
} break;
435+
case GGML_OP_POOL_1D:
436+
{
437+
n_fuse = ggml_metal_op_pool_1d(ctx, idx);
438+
} break;
435439
case GGML_OP_POOL_2D:
436440
{
437441
n_fuse = ggml_metal_op_pool_2d(ctx, idx);
@@ -1622,6 +1626,54 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
16221626
return 1;
16231627
}
16241628

1629+
int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) {
1630+
ggml_tensor * op = ctx->node(idx);
1631+
1632+
ggml_metal_library_t lib = ctx->lib;
1633+
ggml_metal_encoder_t enc = ctx->enc;
1634+
1635+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1636+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1637+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1638+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1639+
1640+
const int32_t * opts = op->op_params;
1641+
ggml_op_pool op_pool = (ggml_op_pool) opts[0];
1642+
1643+
const int32_t k0 = opts[1];
1644+
const int32_t s0 = opts[2];
1645+
const int32_t p0 = opts[3];
1646+
1647+
const int64_t IW = op->src[0]->ne[0];
1648+
const int64_t OW = op->ne[0];
1649+
1650+
const int64_t np = ggml_nelements(op);
1651+
1652+
ggml_metal_kargs_pool_1d args_pool_1d = {
1653+
/* .k0 = */ k0,
1654+
/* .s0 = */ s0,
1655+
/* .p0 = */ p0,
1656+
/* .IW = */ IW,
1657+
/* .OW = */ OW,
1658+
/* .np = */ np
1659+
};
1660+
1661+
auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool);
1662+
1663+
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
1664+
const int ntg = (np + nth - 1) / nth;
1665+
1666+
ggml_metal_encoder_set_pipeline(enc, pipeline);
1667+
ggml_metal_encoder_set_bytes (enc, &args_pool_1d, sizeof(args_pool_1d), 0);
1668+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1669+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
1670+
1671+
ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
1672+
1673+
return 1;
1674+
}
1675+
1676+
16251677
int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
16261678
ggml_tensor * op = ctx->node(idx);
16271679

ggml/src/ggml-metal/ggml-metal-ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
6161
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
6262
int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
6363
int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx);
64+
int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx);
6465
int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx);
6566
int ggml_metal_op_mul_mat (ggml_metal_op_t ctx, int idx);
6667
int ggml_metal_op_mul_mat_id (ggml_metal_op_t ctx, int idx);

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9869,6 +9869,74 @@ kernel void kernel_pool_2d_avg_f32(
98699869
o_ptr[cur_oh * args.OW + cur_ow] = res;
98709870
}
98719871

9872+
9873+
kernel void kernel_pool_1d_max_f32(
9874+
constant ggml_metal_kargs_pool_1d & args,
9875+
device const float * src,
9876+
device float * dst,
9877+
uint gid [[thread_position_in_grid]]
9878+
) {
9879+
9880+
if (gid >= args.np) {
9881+
return;
9882+
}
9883+
9884+
const int ow = (int)gid % args.OW;
9885+
const int row = (int)gid / args.OW;
9886+
9887+
const int base = ow * args.s0 - args.p0;
9888+
9889+
float acc = -INFINITY;
9890+
9891+
const int src_off = row * args.IW;
9892+
const int dst_off = row * args.OW;
9893+
9894+
for (int ki = 0; ki < args.k0; ++ki) {
9895+
int j = base + ki;
9896+
if (j < 0 || j >= args.IW){
9897+
continue;
9898+
}
9899+
float v = src[src_off + j];
9900+
acc = max(acc, v);
9901+
}
9902+
9903+
dst[dst_off + ow] = acc;
9904+
}
9905+
9906+
kernel void kernel_pool_1d_avg_f32(
9907+
constant ggml_metal_kargs_pool_1d & args,
9908+
device const float * src,
9909+
device float * dst,
9910+
uint gid [[thread_position_in_grid]]
9911+
) {
9912+
9913+
if (gid >= args.np) {
9914+
return;
9915+
}
9916+
9917+
const int ow = (int)gid % args.OW;
9918+
const int row = (int)gid / args.OW;
9919+
9920+
const int base = ow * args.s0 - args.p0;
9921+
9922+
float acc = 0.0f;
9923+
int cnt = 0;
9924+
9925+
const int src_off = row * args.IW;
9926+
const int dst_off = row * args.OW;
9927+
9928+
for (int ki = 0; ki < args.k0; ++ki) {
9929+
const int j = base + ki;
9930+
if (j < 0 || j >= args.IW) {
9931+
continue;
9932+
}
9933+
acc += src[src_off + j];
9934+
cnt += 1;
9935+
}
9936+
9937+
dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f;
9938+
}
9939+
98729940
kernel void kernel_opt_step_adamw_f32(
98739941
constant ggml_metal_kargs_opt_step_adamw & args,
98749942
device float * x,

0 commit comments

Comments
 (0)