11#include " ctranslate2/ops/median_filter.h"
22
3- #include < cuda_fp16.h>
4- #ifdef CUDA_BF16_AVAILABLE
5- #include < cuda_bf16.h>
6- #endif
7-
8- #include " type_dispatch.h"
93#include " cuda/helpers.h"
10- #include < type_traits>
114
125namespace ctranslate2 {
136 namespace ops {
147
158 constexpr dim_t num_threads = 256 ;
16-
17- // Conversion helpers
18- __device__ __forceinline__ float to_float (float v) { return v; }
19- __device__ __forceinline__ float to_float (const half v) { return __half2float (v); }
20- #ifdef CUDA_BF16_AVAILABLE
21- __device__ __forceinline__ float to_float (const __nv_bfloat16 v) { return __bfloat162float (v); }
22- #endif
23-
24- __device__ __forceinline__ float from_float (float v) { return v; }
25- __device__ __forceinline__ half from_float_half (float v) { return __float2half (v); }
26- #ifdef CUDA_BF16_AVAILABLE
27- __device__ __forceinline__ __nv_bfloat16 from_float_bf16 (float v) { return __float2bfloat16 (v); }
28- #endif
29-
30- namespace {
31- constexpr int kMaxWindow = 129 ; // supports window widths up to 129 (rank 64)
32- }
9+ constexpr int kMaxWindow = 129 ; // supports window widths up to 129 (rank 64)
3310
3411 template <typename DeviceT, int kMax >
3512 __global__ void sliding_median_lastdim_kernel (const DeviceT* input,
@@ -45,15 +22,6 @@ namespace ctranslate2 {
4522 int col = tid % depth;
4623 const int rank = width / 2 ;
4724
48- if (depth <= rank) {
49- output[tid] = input[tid];
50- return ;
51- }
52- if (width > kMax ) {
53- output[tid] = input[tid];
54- return ;
55- }
56-
5725 float window[kMax ];
5826
5927 const int row_offset = row * depth;
@@ -62,7 +30,7 @@ namespace ctranslate2 {
6230 int read = col + k;
6331 if (read < 0 ) read = -read;
6432 if (read >= depth) read = 2 * depth - read - 2 ;
65- window[k + rank] = to_float (input[row_offset + read]);
33+ window[k + rank] = float (input[row_offset + read]);
6634 }
6735
6836 // Insertion sort (width is small: <= kMax, typically < 129).
@@ -75,24 +43,13 @@ namespace ctranslate2 {
7543 }
7644 window[j + 1 ] = key;
7745 }
78- float median = window[rank];
79-
80- if constexpr (std::is_same<DeviceT, float >::value) {
81- output[tid] = median;
82- } else if constexpr (std::is_same<DeviceT, half>::value) {
83- output[tid] = from_float_half (median);
84- #ifdef CUDA_BF16_AVAILABLE
85- } else if constexpr (std::is_same<DeviceT, __nv_bfloat16>::value) {
86- output[tid] = from_float_bf16 (median);
87- #endif
88- }
46+ output[tid] = DeviceT (window[rank]);
8947 }
9048
9149 template <Device D, typename T>
9250 void MedianFilter::compute (const StorageView& input,
9351 const dim_t axis_size,
9452 StorageView& output) const {
95- output.resize_as (input);
9653 const int depth = static_cast <int >(axis_size);
9754 const int rows = static_cast <int >(input.size () / depth);
9855 const int width = static_cast <int >(_width);
@@ -130,12 +87,10 @@ namespace ctranslate2 {
13087 rows,
13188 depth,
13289 width);
133- CUDA_CHECK (cudaGetLastError ());
134- CUDA_CHECK (cudaDeviceSynchronize ());
13590 }
13691
137- #define DECLARE_IMPL (T ) \
138- template void \
92+ #define DECLARE_IMPL (T ) \
93+ template void \
13994 MedianFilter::compute<Device::CUDA , T>(const StorageView& input, \
14095 const dim_t axis_size, \
14196 StorageView& output) const ;
0 commit comments