@@ -39,16 +39,43 @@ namespace kernel {
39
39
40
40
/// Folds one weighted sample into a running weighted mean, in place.
///
/// \param[in,out] lhs  current running mean; updated to include \p rhs
/// \param[in,out] l_wt accumulated weight; updated to include \p r_wt
/// \param[in]     rhs  new sample value
/// \param[in]     r_wt weight of the new sample
///
/// Instead of accumulating sum(value * weight) and dividing at the end,
/// both operands are rescaled by the new total weight each step, which
/// keeps intermediate magnitudes bounded (the "stable" part of the name).
template<typename To, typename Tw>
__device__ __host__ void stable_mean(To *lhs, Tw *l_wt, To rhs, Tw r_wt) {
    // Capture the previous weight, then accumulate the new one
    // unconditionally so the total weight stays correct even when the
    // mean itself is not updated below.
    Tw l_scale = (*l_wt);
    (*l_wt) += r_wt;

    // Guard: if both the accumulated and incoming weights are zero the
    // rescale would divide by zero, and there is nothing to fold in.
    if (((*l_wt) != (Tw)0) || (r_wt != (Tw)0)) {
        l_scale = l_scale / (*l_wt);

        Tw r_scale = r_wt / (*l_wt);
        (*lhs)     = (l_scale * *lhs) + (r_scale * rhs);
    }
}
51
51
52
/// Kahan-compensated overload of stable_mean for single-pass mean updates.
///
/// \param[in,out] c    running compensation term; must be zero on the first
///                     call and carried across calls of one accumulation
/// \param[in,out] lhs  current running mean; updated to include \p rhs
/// \param[in,out] l_wt accumulated weight; updated to include \p r_wt
/// \param[in]     rhs  new sample value
/// \param[in]     r_wt weight of the new sample
///
/// Since only one pass is used, rounding errors grow with series length:
/// the longer the series, the larger the gap between the two numbers being
/// summed. Kahan summation (see
/// https://en.wikipedia.org/wiki/Kahan_summation_algorithm) reduces this
/// error. The uncompensated update would be:
///   (*lhs) = (*lhs) + (rhs - (*lhs)) * r_wt / (*l_wt);
template<typename To, typename Tw>
__device__ __host__ void stable_mean(To *c, To *lhs, Tw *l_wt, To rhs,
                                     Tw r_wt) {
    (*l_wt) += r_wt;
    // Skip when both weights are zero: the increment would divide by zero.
    if (((*l_wt) != (Tw)0) || (r_wt != (Tw)0)) {
        // *c is zero the first time around.
        To y = (rhs - (*lhs)) * r_wt / (*l_wt) - (*c);
        // Alas, (*lhs) is big, y small, so low-order digits of y are lost.
        To t = (*lhs) + y;
        // (t - (*lhs)) cancels the high-order part of y; subtracting y
        // recovers the negated low-order part of y.
        (*c) = (t - (*lhs)) - y;
        // Algebraically, *c should always be zero. Beware overly-aggressive
        // optimizing compilers!
        (*lhs) = t;
        // Next time around, the lost low part is added back to y in a
        // fresh attempt.
    }
}
78
+
52
79
template <typename Ti, typename Tw, typename To, uint dim, uint DIMY>
53
80
__global__ static void mean_dim_kernel (Param<To> out, Param<Tw> owt,
54
81
CParam<Ti> in, CParam<Tw> iwt,
@@ -74,24 +101,32 @@ __global__ static void mean_dim_kernel(Param<To> out, Param<Tw> owt,
74
101
75
102
int ooffset = ids[3 ] * out.strides [3 ] + ids[2 ] * out.strides [2 ] +
76
103
ids[1 ] * out.strides [1 ] + ids[0 ];
104
+ optr += ooffset;
105
+ if (owptr != NULL ) {
106
+ int owoffset = ids[3 ] * owt.strides [3 ] + ids[2 ] * owt.strides [2 ] +
107
+ ids[1 ] * owt.strides [1 ] + ids[0 ];
108
+ owptr += owoffset;
109
+ }
110
+
77
111
// There is only one element per block for out
78
112
// There are blockDim.y elements per block for in
79
113
// Hence increment ids[dim] just after offseting out and before offsetting
80
114
// in
81
- optr += ooffset;
82
- if (owptr != NULL ) owptr += ooffset;
83
-
84
115
const uint blockIdx_dim = ids[dim];
85
-
86
- ids[dim] = ids[dim] * blockDim.y + tidy;
116
+ ids[dim] = ids[dim] * blockDim.y + tidy;
87
117
88
118
int ioffset = ids[3 ] * in.strides [3 ] + ids[2 ] * in.strides [2 ] +
89
119
ids[1 ] * in.strides [1 ] + ids[0 ];
90
120
iptr += ioffset;
91
- if (iwptr != NULL ) iwptr += ioffset;
121
+ if (iwptr != NULL ) {
122
+ int iwoffset = ids[3 ] * iwt.strides [3 ] + ids[2 ] * iwt.strides [2 ] +
123
+ ids[1 ] * iwt.strides [1 ] + ids[0 ];
124
+ iwptr += iwoffset;
125
+ }
92
126
93
- const uint id_dim_in = ids[dim];
94
- const uint istride_dim = in.strides [dim];
127
+ const uint id_dim_in = ids[dim];
128
+ const uint istride_dim = in.strides [dim];
129
+ const uint iwstride_dim = iwt.strides [dim];
95
130
96
131
bool is_valid = (ids[0 ] < in.dims [0 ]) && (ids[1 ] < in.dims [1 ]) &&
97
132
(ids[2 ] < in.dims [2 ]) && (ids[3 ] < in.dims [3 ]);
@@ -106,7 +141,7 @@ __global__ static void mean_dim_kernel(Param<To> out, Param<Tw> owt,
106
141
if (iwptr != NULL ) {
107
142
weight = *iwptr;
108
143
} else {
109
- weight = (Tw) 1 ;
144
+ weight = compute_t <Tw>( 1 ) ;
110
145
}
111
146
}
112
147
@@ -115,16 +150,17 @@ __global__ static void mean_dim_kernel(Param<To> out, Param<Tw> owt,
115
150
__shared__ compute_t <To> s_val[THREADS_X * DIMY];
116
151
__shared__ compute_t <Tw> s_idx[THREADS_X * DIMY];
117
152
153
+ compute_t <To> correction = scalar<compute_t <To>>(0 );
118
154
for (int id = id_dim_in_start; is_valid && (id < in.dims [dim]);
119
155
id += offset_dim * blockDim.y ) {
120
156
iptr = iptr + offset_dim * blockDim.y * istride_dim;
121
157
if (iwptr != NULL ) {
122
- iwptr = iwptr + offset_dim * blockDim.y * istride_dim;
123
- stable_mean (&val, &weight, transform (*iptr), compute_t <Tw>(*iwptr));
158
+ iwptr = iwptr + offset_dim * blockDim.y * iwstride_dim;
159
+ stable_mean (&correction, &val, &weight, transform (*iptr),
160
+ compute_t <Tw>(*iwptr));
124
161
} else {
125
- // Faster version of stable_mean when iwptr is NULL
126
- val = val + (transform (*iptr) - val) / (weight + (Tw)1 );
127
- weight = weight + (Tw)1 ;
162
+ stable_mean (&correction, &val, &weight, transform (*iptr),
163
+ compute_t <Tw>(1 ));
128
164
}
129
165
}
130
166
@@ -299,16 +335,16 @@ __global__ static void mean_first_kernel(Param<To> out, Param<Tw> owt,
299
335
__shared__ compute_t <To> s_val[THREADS_PER_BLOCK];
300
336
__shared__ compute_t <Tw> s_idx[THREADS_PER_BLOCK];
301
337
338
+ compute_t <To> correction = scalar<compute_t <To>>(0 );
302
339
if (iwptr != NULL ) {
303
340
for (int id = xid + DIMX; id < lim; id += DIMX) {
304
- stable_mean (&val, &weight, transform (iptr[id]),
341
+ stable_mean (&correction, & val, &weight, transform (iptr[id]),
305
342
compute_t <Tw>(iwptr[id]));
306
343
}
307
344
} else {
308
345
for (int id = xid + DIMX; id < lim; id += DIMX) {
309
- // Faster version of stable_mean when iwptr is NULL
310
- val = val + (transform (iptr[id]) - val) / (weight + (Tw)1 );
311
- weight = weight + (Tw)1 ;
346
+ stable_mean (&correction, &val, &weight, transform (iptr[id]),
347
+ compute_t <Tw>(1 ));
312
348
}
313
349
}
314
350
@@ -406,8 +442,7 @@ void mean_first(Param<To> out, CParam<Ti> in, CParam<Tw> iwt) {
406
442
threads_x);
407
443
408
444
if (blocks_x > 1 ) {
409
- Param<Tw> owt;
410
- owt.ptr = NULL ;
445
+ Array<Tw> owt = createEmptyArray<Tw>(dim4 ());
411
446
mean_first_launcher<To, Tw, To>(out, owt, tmpOut, tmpWt, 1 , blocks_y,
412
447
threads_x);
413
448
}
@@ -425,25 +460,24 @@ void mean_weighted(Param<To> out, CParam<Ti> in, CParam<Tw> iwt, int dim) {
425
460
426
461
template <typename Ti, typename Tw, typename To>
427
462
void mean (Param<To> out, CParam<Ti> in, int dim) {
428
- Param <Tw> dummy_weight;
463
+ Array <Tw> dummy_weight = createEmptyArray<Tw>( dim4 ()) ;
429
464
mean_weighted<Ti, Tw, To>(out, in, dummy_weight, dim);
430
465
}
431
466
432
467
template <typename T, typename Tw>
433
468
T mean_all_weighted (CParam<T> in, CParam<Tw> iwt) {
434
469
int in_elements = in.dims [0 ] * in.dims [1 ] * in.dims [2 ] * in.dims [3 ];
435
470
436
- // FIXME: Use better heuristics to get to the optimum number
437
- if (in_elements > 4096 ) {
438
- bool in_is_linear = (in.strides [0 ] == 1 );
439
- bool wt_is_linear = (iwt.strides [0 ] == 1 );
440
- for (int k = 1 ; k < 4 ; k++) {
441
- in_is_linear &=
442
- (in.strides [k] == (in.strides [k - 1 ] * in.dims [k - 1 ]));
443
- wt_is_linear &=
444
- (iwt.strides [k] == (iwt.strides [k - 1 ] * iwt.dims [k - 1 ]));
445
- }
471
+ bool in_is_linear = (in.strides [0 ] == 1 );
472
+ bool wt_is_linear = (iwt.strides [0 ] == 1 );
473
+ for (int k = 1 ; k < 4 ; k++) {
474
+ in_is_linear &= (in.strides [k] == (in.strides [k - 1 ] * in.dims [k - 1 ]));
475
+ wt_is_linear &=
476
+ (iwt.strides [k] == (iwt.strides [k - 1 ] * iwt.dims [k - 1 ]));
477
+ }
446
478
479
+ // FIXME: Use better heuristics to get to the optimum number
480
+ if (in_elements > 4096 || !in_is_linear || !wt_is_linear) {
447
481
if (in_is_linear && wt_is_linear) {
448
482
in.dims [0 ] = in_elements;
449
483
for (int k = 1 ; k < 4 ; k++) {
@@ -487,9 +521,11 @@ T mean_all_weighted(CParam<T> in, CParam<Tw> iwt) {
487
521
488
522
compute_t <T> val = static_cast <compute_t <T>>(h_ptr[0 ]);
489
523
compute_t <Tw> weight = static_cast <compute_t <Tw>>(h_wptr[0 ]);
524
+ compute_t <T> correction =
525
+ common::Binary<compute_t <T>, af_add_t >::init ();
490
526
491
527
for (int i = 1 ; i < tmp_elements; i++) {
492
- stable_mean (&val, &weight, compute_t <T>(h_ptr[i]),
528
+ stable_mean (&correction, & val, &weight, compute_t <T>(h_ptr[i]),
493
529
compute_t <Tw>(h_wptr[i]));
494
530
}
495
531
@@ -508,8 +544,10 @@ T mean_all_weighted(CParam<T> in, CParam<Tw> iwt) {
508
544
509
545
compute_t <T> val = static_cast <compute_t <T>>(h_ptr[0 ]);
510
546
compute_t <Tw> weight = static_cast <compute_t <Tw>>(h_wptr[0 ]);
547
+ compute_t <T> correction =
548
+ common::Binary<compute_t <T>, af_add_t >::init ();
511
549
for (int i = 1 ; i < in_elements; i++) {
512
- stable_mean (&val, &weight, compute_t <T>(h_ptr[i]),
550
+ stable_mean (&correction, & val, &weight, compute_t <T>(h_ptr[i]),
513
551
compute_t <Tw>(h_wptr[i]));
514
552
}
515
553
@@ -566,9 +604,11 @@ To mean_all(CParam<Ti> in) {
566
604
567
605
compute_t <To> val = static_cast <compute_t <To>>(h_ptr[0 ]);
568
606
compute_t <Tw> weight = static_cast <compute_t <Tw>>(h_cptr[0 ]);
607
+ compute_t <To> correction =
608
+ common::Binary<compute_t <To>, af_add_t >::init ();
569
609
570
610
for (int i = 1 ; i < tmp_elements; i++) {
571
- stable_mean (&val, &weight, compute_t <To>(h_ptr[i]),
611
+ stable_mean (&correction, & val, &weight, compute_t <To>(h_ptr[i]),
572
612
compute_t <Tw>(h_cptr[i]));
573
613
}
574
614
@@ -583,11 +623,13 @@ To mean_all(CParam<Ti> in) {
583
623
584
624
common::Transform<Ti, compute_t <To>, af_add_t > transform;
585
625
compute_t <Tw> count = static_cast <compute_t <Tw>>(1 );
626
+ compute_t <To> correction =
627
+ common::Binary<compute_t <To>, af_add_t >::init ();
586
628
587
629
compute_t <To> val = transform (h_ptr[0 ]);
588
630
compute_t <Tw> weight = count;
589
631
for (int i = 1 ; i < in_elements; i++) {
590
- stable_mean (&val, &weight, transform (h_ptr[i]), count);
632
+ stable_mean (&correction, & val, &weight, transform (h_ptr[i]), count);
591
633
}
592
634
593
635
return static_cast <To>(val);
0 commit comments