Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit d09369d

Browse files
authored
Fix CUDA bf16 median filter (OpenNMT#1972)
1 parent f737c90 commit d09369d

4 files changed

Lines changed: 15 additions & 57 deletions

File tree

include/ctranslate2/ops/median_filter.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,13 @@
11
#pragma once
2+
23
#include "op.h"
34

45
namespace ctranslate2 {
56
namespace ops {
67

78
class MedianFilter : public Op {
89
public:
9-
explicit MedianFilter(dim_t width);
10+
MedianFilter(const dim_t width);
1011
void operator()(const StorageView& input, StorageView& output) const;
1112

1213
private:

src/ops/median_filter.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
namespace ctranslate2 {
66
namespace ops {
77

8-
MedianFilter::MedianFilter(dim_t width)
8+
MedianFilter::MedianFilter(const dim_t width)
99
: _width(width)
1010
{
1111
}

src/ops/median_filter_gpu.cu

Lines changed: 5 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -1,35 +1,12 @@
11
#include "ctranslate2/ops/median_filter.h"
22

3-
#include <cuda_fp16.h>
4-
#ifdef CUDA_BF16_AVAILABLE
5-
#include <cuda_bf16.h>
6-
#endif
7-
8-
#include "type_dispatch.h"
93
#include "cuda/helpers.h"
10-
#include <type_traits>
114

125
namespace ctranslate2 {
136
namespace ops {
147

158
constexpr dim_t num_threads = 256;
16-
17-
// Conversion helpers
18-
__device__ __forceinline__ float to_float(float v) { return v; }
19-
__device__ __forceinline__ float to_float(const half v) { return __half2float(v); }
20-
#ifdef CUDA_BF16_AVAILABLE
21-
__device__ __forceinline__ float to_float(const __nv_bfloat16 v) { return __bfloat162float(v); }
22-
#endif
23-
24-
__device__ __forceinline__ float from_float(float v) { return v; }
25-
__device__ __forceinline__ half from_float_half(float v) { return __float2half(v); }
26-
#ifdef CUDA_BF16_AVAILABLE
27-
__device__ __forceinline__ __nv_bfloat16 from_float_bf16(float v) { return __float2bfloat16(v); }
28-
#endif
29-
30-
namespace {
31-
constexpr int kMaxWindow = 129; // supports window widths up to 129 (rank 64)
32-
}
9+
constexpr int kMaxWindow = 129; // supports window widths up to 129 (rank 64)
3310

3411
template <typename DeviceT, int kMax>
3512
__global__ void sliding_median_lastdim_kernel(const DeviceT* input,
@@ -45,15 +22,6 @@ namespace ctranslate2 {
4522
int col = tid % depth;
4623
const int rank = width / 2;
4724

48-
if (depth <= rank) {
49-
output[tid] = input[tid];
50-
return;
51-
}
52-
if (width > kMax) {
53-
output[tid] = input[tid];
54-
return;
55-
}
56-
5725
float window[kMax];
5826

5927
const int row_offset = row * depth;
@@ -62,7 +30,7 @@ namespace ctranslate2 {
6230
int read = col + k;
6331
if (read < 0) read = -read;
6432
if (read >= depth) read = 2 * depth - read - 2;
65-
window[k + rank] = to_float(input[row_offset + read]);
33+
window[k + rank] = float(input[row_offset + read]);
6634
}
6735

6836
// Insertion sort (width is small: <= kMax, typically < 129).
@@ -75,24 +43,13 @@ namespace ctranslate2 {
7543
}
7644
window[j + 1] = key;
7745
}
78-
float median = window[rank];
79-
80-
if constexpr (std::is_same<DeviceT, float>::value) {
81-
output[tid] = median;
82-
} else if constexpr (std::is_same<DeviceT, half>::value) {
83-
output[tid] = from_float_half(median);
84-
#ifdef CUDA_BF16_AVAILABLE
85-
} else if constexpr (std::is_same<DeviceT, __nv_bfloat16>::value) {
86-
output[tid] = from_float_bf16(median);
87-
#endif
88-
}
46+
output[tid] = DeviceT(window[rank]);
8947
}
9048

9149
template <Device D, typename T>
9250
void MedianFilter::compute(const StorageView& input,
9351
const dim_t axis_size,
9452
StorageView& output) const {
95-
output.resize_as(input);
9653
const int depth = static_cast<int>(axis_size);
9754
const int rows = static_cast<int>(input.size() / depth);
9855
const int width = static_cast<int>(_width);
@@ -130,12 +87,10 @@ namespace ctranslate2 {
13087
rows,
13188
depth,
13289
width);
133-
CUDA_CHECK(cudaGetLastError());
134-
CUDA_CHECK(cudaDeviceSynchronize());
13590
}
13691

137-
#define DECLARE_IMPL(T) \
138-
template void \
92+
#define DECLARE_IMPL(T) \
93+
template void \
13994
MedianFilter::compute<Device::CUDA, T>(const StorageView& input, \
14095
const dim_t axis_size, \
14196
StorageView& output) const;

tests/ops_test.cc

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -125,8 +125,10 @@ class OpDeviceFPTest : public ::testing::TestWithParam<FloatType> {
125125
};
126126

127127

128-
TEST_P(OpDeviceTest, MedianFilter) {
129-
Device device = GetParam();
128+
TEST_P(OpDeviceFPTest, MedianFilter) {
129+
Device device = GetParam().device;
130+
const DataType dtype = GetParam().dtype;
131+
const float error = GetParam().error;
130132
StorageView x({2, 8}, std::vector<float>{
131133
0.2556743323802948, 0.8028775453567505, 0.3514494299888611, 0.3542254865169525,
132134
0.5881291031837463, 0.1458204835653305, 0.6845740675926208, 0.543143630027771,
@@ -139,9 +141,9 @@ TEST_P(OpDeviceTest, MedianFilter) {
139141
0.9039326310157776, 0.4063926637172699, 0.7943458557128906, 0.4063926637172699,
140142
0.7943458557128906, 0.4063926637172699, 0.7943458557128906, 0.289182186126709},
141143
device);
142-
StorageView y(device);
143-
ops::MedianFilter(5)(x, y);
144-
expect_storage_eq(y, expected);
144+
StorageView y(dtype, device);
145+
ops::MedianFilter(5)(x.to(dtype), y);
146+
expect_storage_eq(y.to_float32(), expected, error);
145147
}
146148

147149
TEST_P(OpDeviceTest, Add) {

0 commit comments

Comments
 (0)