77#include " unary-ops.h"
88#include " vec.h"
99
10- #include < cfloat>
1110#include < algorithm>
11+ #include < cfloat>
1212#include < cmath>
13- #include < functional>
1413
1514// ggml_compute_forward_dup
1615
@@ -7110,12 +7109,13 @@ void ggml_compute_forward_conv_2d_dw(
71107109 }
71117110}
71127111
7113- // ggml_compute_forward_pool_1d_sk_p0
7114-
7115- static void ggml_compute_forward_pool_1d_sk_p0 (
7112+ // ggml_compute_forward_pool_1d_ksp
7113+ static void ggml_compute_forward_pool_1d_ksp (
71167114 const ggml_compute_params * params,
71177115 const ggml_op_pool op,
71187116 const int k,
7117+ const int s,
7118+ const int p,
71197119 ggml_tensor * dst) {
71207120
71217121 const ggml_tensor * src = dst->src [0 ];
@@ -7126,39 +7126,56 @@ static void ggml_compute_forward_pool_1d_sk_p0(
71267126 return ;
71277127 }
71287128
7129- const char * cdata = (const char *)src->data ;
7130- const char * const data_end = cdata + ggml_nbytes (src);
7131- float * drow = (float *)dst->data ;
7129+ const int64_t IW = src->ne [0 ];
7130+ const int64_t OW = dst->ne [0 ];
71327131
7133- const int64_t rs = dst-> ne [ 0 ] ;
7132+ const int64_t nr = ggml_nrows (src) ;
71347133
7135- while (cdata < data_end) {
7136- const void * srow = (const void *)cdata;
7137- int j = 0 ;
7138- for (int64_t i = 0 ; i < rs; ++i) {
7134+ for (int64_t ir = 0 ; ir < nr; ++ir) {
7135+ const char * srow_bytes = (const char *) src->data + ir * src->nb [1 ];
7136+ float * drow = (float *) (( char *) dst->data + ir * dst->nb [1 ]);
7137+
7138+ for (int64_t ow = 0 ; ow < OW; ++ow) {
7139+ float res = 0 ;
71397140 switch (op) {
7140- case GGML_OP_POOL_AVG: drow[i] = 0 ; break ;
7141- case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break ;
7141+ case GGML_OP_POOL_AVG: res = 0 . 0f ; break ;
7142+ case GGML_OP_POOL_MAX: res = -FLT_MAX; break ;
71427143 case GGML_OP_POOL_COUNT: GGML_ABORT (" fatal error" );
71437144 }
7145+
7146+ int count = 0 ;
7147+ const int base = (int ) ow * s - p;
7148+
71447149 for (int ki = 0 ; ki < k; ++ki) {
7145- const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float *)srow)[j] : GGML_CPU_FP16_TO_FP32 (((const ggml_fp16_t *)srow)[j]);
7150+ const int j = base + ki;
7151+ if (j < 0 || j >= (int ) IW) {
7152+ continue ;
7153+ }
7154+
7155+ float v;
7156+ if (src->type == GGML_TYPE_F32) {
7157+ v = ((const float *) srow_bytes)[j];
7158+ } else {
7159+ v = GGML_CPU_FP16_TO_FP32 (((const ggml_fp16_t *) srow_bytes)[j]);
7160+ }
7161+
71467162 switch (op) {
7147- case GGML_OP_POOL_AVG: drow[i] += srow_j; break ;
7148- case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j ; break ;
7149- case GGML_OP_POOL_COUNT: GGML_ABORT (" fatal error" );
7163+ case GGML_OP_POOL_AVG: res += v; break ;
7164+ case GGML_OP_POOL_MAX: res = std::max (v, res) ; break ;
7165+ case GGML_OP_POOL_COUNT: GGML_ABORT (" fatal error" );
71507166 }
7151- ++j;
7167+
7168+ ++count;
71527169 }
7170+
71537171 switch (op) {
7154- case GGML_OP_POOL_AVG: drow[i] /= k ; break ;
7155- case GGML_OP_POOL_MAX: break ;
7172+ case GGML_OP_POOL_AVG: res = (count > 0 ) ? (res / count) : 0 . 0f ; break ;
7173+ case GGML_OP_POOL_MAX: break ;
71567174 case GGML_OP_POOL_COUNT: GGML_ABORT (" fatal error" );
71577175 }
7158- }
71597176
7160- cdata += src-> nb [ 1 ] ;
7161- drow += rs;
7177+ drow[ow] = res ;
7178+ }
71627179 }
71637180}
71647181
@@ -7173,10 +7190,8 @@ void ggml_compute_forward_pool_1d(
71737190 const int k0 = opts[1 ];
71747191 const int s0 = opts[2 ];
71757192 const int p0 = opts[3 ];
7176- GGML_ASSERT (p0 == 0 ); // padding not supported
7177- GGML_ASSERT (k0 == s0); // only s = k supported
71787193
7179- ggml_compute_forward_pool_1d_sk_p0 (params, op, k0, dst);
7194+ ggml_compute_forward_pool_1d_ksp (params, op, k0, s0, p0 , dst);
71807195}
71817196
71827197// ggml_compute_forward_pool_2d
@@ -7194,6 +7209,7 @@ void ggml_compute_forward_pool_2d(
71947209 }
71957210
71967211 const int32_t * opts = (const int32_t *)dst->op_params ;
7212+
71977213 ggml_op_pool op = static_cast <ggml_op_pool>(opts[0 ]);
71987214 const int k0 = opts[1 ];
71997215 const int k1 = opts[2 ];
@@ -7217,36 +7233,46 @@ void ggml_compute_forward_pool_2d(
72177233 while (cdata < data_end) {
72187234 for (int oy = 0 ; oy < py; ++oy) {
72197235 float * const drow = dplane + oy * px;
7236+ float * const out = drow;
7237+
72207238 for (int ox = 0 ; ox < px; ++ox) {
7221- float * const out = drow + ox ;
7239+ float res = 0 ;
72227240 switch (op) {
7223- case GGML_OP_POOL_AVG: *out = 0 ; break ;
7224- case GGML_OP_POOL_MAX: *out = -FLT_MAX; break ;
7241+ case GGML_OP_POOL_AVG: res = 0 ; break ;
7242+ case GGML_OP_POOL_MAX: res = -FLT_MAX; break ;
72257243 case GGML_OP_POOL_COUNT: GGML_ABORT (" fatal error" );
72267244 }
72277245
72287246 const int ix = offset0 + ox * s0;
72297247 const int iy = offset1 + oy * s1;
72307248
72317249 for (int ky = 0 ; ky < k1; ++ky) {
7232- if (iy + ky < 0 || iy + ky >= src->ne [1 ]) continue ;
7250+ if (iy + ky < 0 || iy + ky >= src->ne [1 ]) {
7251+ continue ;
7252+ }
7253+
72337254 const void * srow = (const void *)(cdata + src->nb [1 ] * (iy + ky));
72347255 for (int kx = 0 ; kx < k0; ++kx) {
72357256 int j = ix + kx;
7236- if (j < 0 || j >= src->ne [0 ]) continue ;
7257+ if (j < 0 || j >= src->ne [0 ]) {
7258+ continue ;
7259+ }
7260+
72377261 const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float *)srow)[j] : GGML_CPU_FP16_TO_FP32 (((const ggml_fp16_t *)srow)[j]);
72387262 switch (op) {
7239- case GGML_OP_POOL_AVG: *out += srow_j; break ;
7240- case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j ; break ;
7263+ case GGML_OP_POOL_AVG: res += srow_j; break ;
7264+ case GGML_OP_POOL_MAX: res = std::max (srow_j, res) ; break ;
72417265 case GGML_OP_POOL_COUNT: GGML_ABORT (" fatal error" );
72427266 }
72437267 }
72447268 }
72457269 switch (op) {
7246- case GGML_OP_POOL_AVG: *out /= ka; break ;
7247- case GGML_OP_POOL_MAX: break ;
7270+ case GGML_OP_POOL_AVG: res /= ka; break ;
7271+ case GGML_OP_POOL_MAX: break ;
72487272 case GGML_OP_POOL_COUNT: GGML_ABORT (" fatal error" );
72497273 }
7274+
7275+ out[ox] = res;
72507276 }
72517277 }
72527278
0 commit comments