Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 9307d6f

Browse files
committed
Fixes to mean, on all platforms
1 parent 82ca3b3 commit 9307d6f

File tree

8 files changed

+766
-222
lines changed

8 files changed

+766
-222
lines changed

src/backend/cpu/kernel/mean.hpp

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,45 @@ struct MeanOp {
3838
}
3939
};
4040

41+
// Running (weighted) mean accumulator that carries a Kahan-style
// compensation term alongside the mean.
//   Ti - type of the values handed to the constructor and to operator()
//   To - type of the running mean and of the correction term
//   Tw - type of the running count / weight
// NOTE(review): operator() adds newCount to runningCount before the guard is
// tested, and the guard is an OR. A non-zero newCount that brings
// runningCount to exactly zero therefore still reaches the division by
// runningCount - confirm that callers never pass such weights.
template<typename Ti, typename To, typename Tw>
42+
struct MeanOpWithCorrection {
43+
common::Transform<Ti, To, af_add_t> transform;
44+
To runningMean;
45+
Tw runningCount;
46+
To correction;
47+
MeanOpWithCorrection(Ti mean, Tw count)
48+
: transform()
49+
, runningMean(transform(mean))
50+
, runningCount(count)
51+
, correction(scalar<To>(0)) {}
52+
53+
void operator()(Ti newMean, Tw newCount) {
54+
runningCount += newCount;
55+
if ((newCount != 0) || (runningCount != 0)) { //
56+
// Only one pass is made over the data, so rounding errors become
57+
// significant: the longer the series, the larger the
58+
// difference between the two numbers being summed. See
59+
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm for the
60+
// compensation scheme used here to reduce this error.
61+
// correction is zero the first time around
62+
To y =
63+
(transform(newMean) - runningMean) * newCount / runningCount -
64+
correction;
65+
// Alas, runningMean is big and y is small, so the low-order digits
66+
// of y are lost in the addition below
67+
To t = runningMean + y;
68+
// (t - runningMean) cancels the high-order part of y;
69+
// subtracting y recovers the negative of the low part of y
70+
correction = (t - runningMean) - y;
71+
// Algebraically, correction should always be zero. Beware of
72+
// overly aggressive optimizing compilers!
73+
runningMean = t;
74+
// Next time around, the lost low part will be added to y in a
75+
// fresh attempt
76+
}
77+
}
78+
};
79+
4180
template<typename T, typename Tw, int D>
4281
struct mean_weighted_dim {
4382
void operator()(Param<T> output, const dim_t outOffset,
@@ -74,7 +113,8 @@ struct mean_weighted_dim<T, Tw, 0> {
74113

75114
dim_t istride = istrides[dim];
76115
dim_t wstride = wstrides[dim];
77-
MeanOp<compute_t<T>, compute_t<T>, compute_t<Tw>> Op(0, 0);
116+
MeanOpWithCorrection<compute_t<T>, compute_t<T>, compute_t<Tw>> Op(0,
117+
0);
78118
for (dim_t i = 0; i < idims[dim]; i++) {
79119
Op(compute_t<T>(in[inOffset + i * istride]),
80120
compute_t<Tw>(wt[wtOffset + i * wstride]));
@@ -113,7 +153,8 @@ struct mean_dim<Ti, Tw, To, 0> {
113153

114154
dim_t istride = istrides[dim];
115155
dim_t end = inOffset + idims[dim] * istride;
116-
MeanOp<compute_t<Ti>, compute_t<To>, compute_t<Tw>> Op(0, 0);
156+
MeanOpWithCorrection<compute_t<Ti>, compute_t<To>, compute_t<Tw>> Op(0,
157+
0);
117158
for (dim_t i = inOffset; i < end; i += istride) {
118159
Op(compute_t<Ti>(in[i]), 1);
119160
}

src/backend/cpu/mean.cpp

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -63,32 +63,35 @@ Array<T> mean(const Array<T> &in, const Array<Tw> &wt, const int dim) {
6363

6464
template<typename T, typename Tw>
6565
T mean(const Array<T> &in, const Array<Tw> &wt) {
66-
using MeanOpT = kernel::MeanOp<compute_t<T>, compute_t<T>, compute_t<Tw>>;
66+
using MeanOpT =
67+
kernel::MeanOpWithCorrection<compute_t<T>, compute_t<T>, compute_t<Tw>>;
6768
in.eval();
6869
wt.eval();
6970
getQueue().sync();
7071

71-
af::dim4 dims = in.dims();
72-
af::dim4 strides = in.strides();
73-
const T *inPtr = in.get();
74-
const Tw *wtPtr = wt.get();
75-
76-
auto input = compute_t<T>(inPtr[0]);
77-
auto weight = compute_t<Tw>(wtPtr[0]);
78-
MeanOpT Op(input, weight);
72+
const af::dim4 &dims = in.dims();
73+
const af::dim4 &istrides = in.strides();
74+
const af::dim4 &wstrides = wt.strides();
75+
const T *inPtr = in.get();
76+
const Tw *wtPtr = wt.get();
7977

78+
// Split workload in equal parts, to improve accuracy for larger arrays
79+
MeanOpT Op(0, 0);
8080
for (dim_t l = 0; l < dims[3]; l++) {
81-
dim_t off3 = l * strides[3];
81+
dim_t ioff3 = l * istrides[3];
82+
dim_t woff3 = l * wstrides[3];
8283

8384
for (dim_t k = 0; k < dims[2]; k++) {
84-
dim_t off2 = k * strides[2];
85+
dim_t ioff2 = k * istrides[2];
86+
dim_t woff2 = k * wstrides[2];
8587

8688
for (dim_t j = 0; j < dims[1]; j++) {
87-
dim_t off1 = j * strides[1];
89+
dim_t ioff1 = j * istrides[1];
90+
dim_t woff1 = j * wstrides[1];
8891

8992
for (dim_t i = 0; i < dims[0]; i++) {
90-
dim_t idx = i + off1 + off2 + off3;
91-
Op(compute_t<T>(inPtr[idx]), compute_t<Tw>(wtPtr[idx]));
93+
Op(compute_t<T>(inPtr[i + ioff1 + ioff2 + ioff3]),
94+
compute_t<Tw>(wtPtr[i + woff1 + woff2 + woff3]));
9295
}
9396
}
9497
}
@@ -99,16 +102,16 @@ T mean(const Array<T> &in, const Array<Tw> &wt) {
99102

100103
template<typename Ti, typename Tw, typename To>
101104
To mean(const Array<Ti> &in) {
102-
using MeanOpT = kernel::MeanOp<compute_t<Ti>, compute_t<To>, compute_t<Tw>>;
105+
using MeanOpT = kernel::MeanOpWithCorrection<compute_t<Ti>, compute_t<To>,
106+
compute_t<Tw>>;
103107
in.eval();
104108
getQueue().sync();
105109

106-
af::dim4 dims = in.dims();
107-
af::dim4 strides = in.strides();
108-
const Ti *inPtr = in.get();
110+
const af::dim4 &dims = in.dims();
111+
const af::dim4 &strides = in.strides();
112+
const Ti *inPtr = in.get();
109113

110114
MeanOpT Op(0, 0);
111-
112115
for (dim_t l = 0; l < dims[3]; l++) {
113116
dim_t off3 = l * strides[3];
114117

@@ -120,7 +123,7 @@ To mean(const Array<Ti> &in) {
120123

121124
for (dim_t i = 0; i < dims[0]; i++) {
122125
dim_t idx = i + off1 + off2 + off3;
123-
Op(compute_t<Ti>(inPtr[idx]), 1);
126+
Op(compute_t<Ti>(inPtr[idx]), 1.);
124127
}
125128
}
126129
}

src/backend/cuda/kernel/mean.hpp

Lines changed: 79 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,43 @@ namespace kernel {
3939

4040
template<typename To, typename Tw>
4141
__device__ __host__ void stable_mean(To *lhs, Tw *l_wt, To rhs, Tw r_wt) {
42+
Tw l_scale = (*l_wt);
43+
(*l_wt) += r_wt;
4244
if (((*l_wt) != (Tw)0) || (r_wt != (Tw)0)) {
43-
Tw l_scale = (*l_wt);
44-
(*l_wt) += r_wt;
4545
l_scale = l_scale / (*l_wt);
4646

4747
Tw r_scale = r_wt / (*l_wt);
4848
(*lhs) = (l_scale * *lhs) + (r_scale * rhs);
4949
}
5050
}
5151

52+
// Folds one more sample (rhs with weight r_wt) into the running mean *lhs,
// using *c as a Kahan-style compensation term carried between calls.
//   c    - in/out: compensation term
//   lhs  - in/out: running mean
//   l_wt - in/out: accumulated weight; r_wt is added to it unconditionally
//   rhs  - value being folded in
//   r_wt - weight of rhs
// NOTE(review): the guard is tested after *l_wt has been updated and is an
// OR, so a non-zero r_wt that brings *l_wt to exactly zero still reaches the
// division by *l_wt - confirm that callers never pass such weights.
template<typename To, typename Tw>
53+
__device__ __host__ void stable_mean(To *c, To *lhs, Tw *l_wt, To rhs,
54+
Tw r_wt) {
55+
(*l_wt) += r_wt;
56+
if (((*l_wt) != (Tw)0) || (r_wt != (Tw)0)) {
57+
// Only one pass is made over the data, so rounding errors become
58+
// important: the longer the series, the larger the difference between
59+
// the two numbers being summed. See
60+
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm for the
61+
// compensation scheme used here to reduce this error.
62+
// Uncompensated form: (*lhs) = (*lhs) + (rhs - (*lhs)) * r_wt / (*l_wt);
63+
64+
// *c is zero the first time around
65+
To y = (rhs - (*lhs)) * r_wt / (*l_wt) - (*c);
66+
// Alas, (*lhs) is big and y is small, so low-order digits of y are lost
67+
To t = (*lhs) + y;
68+
// (t - (*lhs)) cancels the high-order part of y;
69+
// subtracting y recovers the negative of the low part of y
70+
(*c) = (t - (*lhs)) - y;
71+
// Algebraically, *c should always be zero. Beware of overly aggressive
72+
// optimizing compilers!
73+
(*lhs) = t;
74+
// Next time around, the lost low part will be added to y in a fresh
75+
// attempt
76+
}
77+
}
78+
5279
template<typename Ti, typename Tw, typename To, uint dim, uint DIMY>
5380
__global__ static void mean_dim_kernel(Param<To> out, Param<Tw> owt,
5481
CParam<Ti> in, CParam<Tw> iwt,
@@ -74,24 +101,32 @@ __global__ static void mean_dim_kernel(Param<To> out, Param<Tw> owt,
74101

75102
int ooffset = ids[3] * out.strides[3] + ids[2] * out.strides[2] +
76103
ids[1] * out.strides[1] + ids[0];
104+
optr += ooffset;
105+
if (owptr != NULL) {
106+
int owoffset = ids[3] * owt.strides[3] + ids[2] * owt.strides[2] +
107+
ids[1] * owt.strides[1] + ids[0];
108+
owptr += owoffset;
109+
}
110+
77111
// There is only one element per block for out
78112
// There are blockDim.y elements per block for in
79113
// Hence increment ids[dim] just after offsetting out and before offsetting
80114
// in
81-
optr += ooffset;
82-
if (owptr != NULL) owptr += ooffset;
83-
84115
const uint blockIdx_dim = ids[dim];
85-
86-
ids[dim] = ids[dim] * blockDim.y + tidy;
116+
ids[dim] = ids[dim] * blockDim.y + tidy;
87117

88118
int ioffset = ids[3] * in.strides[3] + ids[2] * in.strides[2] +
89119
ids[1] * in.strides[1] + ids[0];
90120
iptr += ioffset;
91-
if (iwptr != NULL) iwptr += ioffset;
121+
if (iwptr != NULL) {
122+
int iwoffset = ids[3] * iwt.strides[3] + ids[2] * iwt.strides[2] +
123+
ids[1] * iwt.strides[1] + ids[0];
124+
iwptr += iwoffset;
125+
}
92126

93-
const uint id_dim_in = ids[dim];
94-
const uint istride_dim = in.strides[dim];
127+
const uint id_dim_in = ids[dim];
128+
const uint istride_dim = in.strides[dim];
129+
const uint iwstride_dim = iwt.strides[dim];
95130

96131
bool is_valid = (ids[0] < in.dims[0]) && (ids[1] < in.dims[1]) &&
97132
(ids[2] < in.dims[2]) && (ids[3] < in.dims[3]);
@@ -106,7 +141,7 @@ __global__ static void mean_dim_kernel(Param<To> out, Param<Tw> owt,
106141
if (iwptr != NULL) {
107142
weight = *iwptr;
108143
} else {
109-
weight = (Tw)1;
144+
weight = compute_t<Tw>(1);
110145
}
111146
}
112147

@@ -115,16 +150,17 @@ __global__ static void mean_dim_kernel(Param<To> out, Param<Tw> owt,
115150
__shared__ compute_t<To> s_val[THREADS_X * DIMY];
116151
__shared__ compute_t<Tw> s_idx[THREADS_X * DIMY];
117152

153+
compute_t<To> correction = scalar<compute_t<To>>(0);
118154
for (int id = id_dim_in_start; is_valid && (id < in.dims[dim]);
119155
id += offset_dim * blockDim.y) {
120156
iptr = iptr + offset_dim * blockDim.y * istride_dim;
121157
if (iwptr != NULL) {
122-
iwptr = iwptr + offset_dim * blockDim.y * istride_dim;
123-
stable_mean(&val, &weight, transform(*iptr), compute_t<Tw>(*iwptr));
158+
iwptr = iwptr + offset_dim * blockDim.y * iwstride_dim;
159+
stable_mean(&correction, &val, &weight, transform(*iptr),
160+
compute_t<Tw>(*iwptr));
124161
} else {
125-
// Faster version of stable_mean when iwptr is NULL
126-
val = val + (transform(*iptr) - val) / (weight + (Tw)1);
127-
weight = weight + (Tw)1;
162+
stable_mean(&correction, &val, &weight, transform(*iptr),
163+
compute_t<Tw>(1));
128164
}
129165
}
130166

@@ -299,16 +335,16 @@ __global__ static void mean_first_kernel(Param<To> out, Param<Tw> owt,
299335
__shared__ compute_t<To> s_val[THREADS_PER_BLOCK];
300336
__shared__ compute_t<Tw> s_idx[THREADS_PER_BLOCK];
301337

338+
compute_t<To> correction = scalar<compute_t<To>>(0);
302339
if (iwptr != NULL) {
303340
for (int id = xid + DIMX; id < lim; id += DIMX) {
304-
stable_mean(&val, &weight, transform(iptr[id]),
341+
stable_mean(&correction, &val, &weight, transform(iptr[id]),
305342
compute_t<Tw>(iwptr[id]));
306343
}
307344
} else {
308345
for (int id = xid + DIMX; id < lim; id += DIMX) {
309-
// Faster version of stable_mean when iwptr is NULL
310-
val = val + (transform(iptr[id]) - val) / (weight + (Tw)1);
311-
weight = weight + (Tw)1;
346+
stable_mean(&correction, &val, &weight, transform(iptr[id]),
347+
compute_t<Tw>(1));
312348
}
313349
}
314350

@@ -406,8 +442,7 @@ void mean_first(Param<To> out, CParam<Ti> in, CParam<Tw> iwt) {
406442
threads_x);
407443

408444
if (blocks_x > 1) {
409-
Param<Tw> owt;
410-
owt.ptr = NULL;
445+
Array<Tw> owt = createEmptyArray<Tw>(dim4());
411446
mean_first_launcher<To, Tw, To>(out, owt, tmpOut, tmpWt, 1, blocks_y,
412447
threads_x);
413448
}
@@ -425,25 +460,24 @@ void mean_weighted(Param<To> out, CParam<Ti> in, CParam<Tw> iwt, int dim) {
425460

426461
template<typename Ti, typename Tw, typename To>
427462
void mean(Param<To> out, CParam<Ti> in, int dim) {
428-
Param<Tw> dummy_weight;
463+
Array<Tw> dummy_weight = createEmptyArray<Tw>(dim4());
429464
mean_weighted<Ti, Tw, To>(out, in, dummy_weight, dim);
430465
}
431466

432467
template<typename T, typename Tw>
433468
T mean_all_weighted(CParam<T> in, CParam<Tw> iwt) {
434469
int in_elements = in.dims[0] * in.dims[1] * in.dims[2] * in.dims[3];
435470

436-
// FIXME: Use better heuristics to get to the optimum number
437-
if (in_elements > 4096) {
438-
bool in_is_linear = (in.strides[0] == 1);
439-
bool wt_is_linear = (iwt.strides[0] == 1);
440-
for (int k = 1; k < 4; k++) {
441-
in_is_linear &=
442-
(in.strides[k] == (in.strides[k - 1] * in.dims[k - 1]));
443-
wt_is_linear &=
444-
(iwt.strides[k] == (iwt.strides[k - 1] * iwt.dims[k - 1]));
445-
}
471+
bool in_is_linear = (in.strides[0] == 1);
472+
bool wt_is_linear = (iwt.strides[0] == 1);
473+
for (int k = 1; k < 4; k++) {
474+
in_is_linear &= (in.strides[k] == (in.strides[k - 1] * in.dims[k - 1]));
475+
wt_is_linear &=
476+
(iwt.strides[k] == (iwt.strides[k - 1] * iwt.dims[k - 1]));
477+
}
446478

479+
// FIXME: Use better heuristics to get to the optimum number
480+
if (in_elements > 4096 || !in_is_linear || !wt_is_linear) {
447481
if (in_is_linear && wt_is_linear) {
448482
in.dims[0] = in_elements;
449483
for (int k = 1; k < 4; k++) {
@@ -487,9 +521,11 @@ T mean_all_weighted(CParam<T> in, CParam<Tw> iwt) {
487521

488522
compute_t<T> val = static_cast<compute_t<T>>(h_ptr[0]);
489523
compute_t<Tw> weight = static_cast<compute_t<Tw>>(h_wptr[0]);
524+
compute_t<T> correction =
525+
common::Binary<compute_t<T>, af_add_t>::init();
490526

491527
for (int i = 1; i < tmp_elements; i++) {
492-
stable_mean(&val, &weight, compute_t<T>(h_ptr[i]),
528+
stable_mean(&correction, &val, &weight, compute_t<T>(h_ptr[i]),
493529
compute_t<Tw>(h_wptr[i]));
494530
}
495531

@@ -508,8 +544,10 @@ T mean_all_weighted(CParam<T> in, CParam<Tw> iwt) {
508544

509545
compute_t<T> val = static_cast<compute_t<T>>(h_ptr[0]);
510546
compute_t<Tw> weight = static_cast<compute_t<Tw>>(h_wptr[0]);
547+
compute_t<T> correction =
548+
common::Binary<compute_t<T>, af_add_t>::init();
511549
for (int i = 1; i < in_elements; i++) {
512-
stable_mean(&val, &weight, compute_t<T>(h_ptr[i]),
550+
stable_mean(&correction, &val, &weight, compute_t<T>(h_ptr[i]),
513551
compute_t<Tw>(h_wptr[i]));
514552
}
515553

@@ -566,9 +604,11 @@ To mean_all(CParam<Ti> in) {
566604

567605
compute_t<To> val = static_cast<compute_t<To>>(h_ptr[0]);
568606
compute_t<Tw> weight = static_cast<compute_t<Tw>>(h_cptr[0]);
607+
compute_t<To> correction =
608+
common::Binary<compute_t<To>, af_add_t>::init();
569609

570610
for (int i = 1; i < tmp_elements; i++) {
571-
stable_mean(&val, &weight, compute_t<To>(h_ptr[i]),
611+
stable_mean(&correction, &val, &weight, compute_t<To>(h_ptr[i]),
572612
compute_t<Tw>(h_cptr[i]));
573613
}
574614

@@ -583,11 +623,13 @@ To mean_all(CParam<Ti> in) {
583623

584624
common::Transform<Ti, compute_t<To>, af_add_t> transform;
585625
compute_t<Tw> count = static_cast<compute_t<Tw>>(1);
626+
compute_t<To> correction =
627+
common::Binary<compute_t<To>, af_add_t>::init();
586628

587629
compute_t<To> val = transform(h_ptr[0]);
588630
compute_t<Tw> weight = count;
589631
for (int i = 1; i < in_elements; i++) {
590-
stable_mean(&val, &weight, transform(h_ptr[i]), count);
632+
stable_mean(&correction, &val, &weight, transform(h_ptr[i]), count);
591633
}
592634

593635
return static_cast<To>(val);

0 commit comments

Comments
 (0)