@@ -126,9 +126,9 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index<IndexCount> FlatToTensorIndex(
126126
127127// A Cuda custom kernel that swaps dimension-0 and dimension-2 of a 3D tensor.
128128template <typename T>
129- __global__ void SwapDimension0And2InTensor3 (int nthreads, const T* input,
130- Dimension<3 > input_dims,
131- T* output) {
129+ __global__ void SwapDimension0And2InTensor3Simple (int nthreads, const T* input,
130+ Dimension<3 > input_dims,
131+ T* output) {
132132 Dimension<3 > output_dims;
133133 output_dims[0 ] = input_dims[2 ];
134134 output_dims[1 ] = input_dims[1 ];
@@ -152,9 +152,9 @@ __global__ void SwapDimension0And2InTensor3(int nthreads, const T* input,
152152
153153// A Cuda custom kernel that swaps dimension-1 and dimension-2 of a 3D tensor.
154154template <typename T>
155- __global__ void SwapDimension1And2InTensor3 (int nthreads, const T* input,
156- Dimension<3 > input_dims,
157- T* output) {
155+ __global__ void SwapDimension1And2InTensor3Simple (int nthreads, const T* input,
156+ Dimension<3 > input_dims,
157+ T* output) {
158158 Dimension<3 > output_dims;
159159 output_dims[0 ] = input_dims[0 ];
160160 output_dims[1 ] = input_dims[2 ];
@@ -348,9 +348,9 @@ struct TransformFilter<GPUDevice, T, int, NDIMS> {
348348 combined_dims[1 ] = in.dimension (NDIMS - 2 ); // input filters
349349 combined_dims[2 ] = in.dimension (NDIMS - 1 ); // output filters
350350 CudaLaunchConfig config = GetCudaLaunchConfig (out.size (), d);
351- SwapDimension0And2InTensor3<
352- T> <<<config.block_count , config.thread_per_block , 0 , d.stream ()>>>(
353- config.virtual_thread_count , in.data (), combined_dims, out.data ());
351+ SwapDimension0And2InTensor3Simple<T>
352+ <<<config.block_count , config.thread_per_block , 0 , d.stream ()>>>(
353+ config.virtual_thread_count , in.data (), combined_dims, out.data ());
354354 }
355355};
356356
@@ -368,9 +368,9 @@ struct ReverseTransformFilter<GPUDevice, T, NDIMS> {
368368 combined_dims[2 ] *= in.dimension (i);
369369 }
370370 CudaLaunchConfig config = GetCudaLaunchConfig (out.size (), d);
371- SwapDimension0And2InTensor3<
372- T> <<<config.block_count , config.thread_per_block , 0 , d.stream ()>>>(
373- config.virtual_thread_count , in.data (), combined_dims, out.data ());
371+ SwapDimension0And2InTensor3Simple<T>
372+ <<<config.block_count , config.thread_per_block , 0 , d.stream ()>>>(
373+ config.virtual_thread_count , in.data (), combined_dims, out.data ());
374374 }
375375};
376376
@@ -442,12 +442,44 @@ void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
442442 } else {
443443 int total_element_count = input_dims[0 ] * input_dims[1 ] * input_dims[2 ];
444444 CudaLaunchConfig config = GetCudaLaunchConfig (total_element_count, d);
445- SwapDimension1And2InTensor3<
446- T> <<<config.block_count , config.thread_per_block , 0 , d.stream ()>>>(
447- config.virtual_thread_count , input, input_dims, output);
445+ SwapDimension1And2InTensor3Simple<T>
446+ <<<config.block_count , config.thread_per_block , 0 , d.stream ()>>>(
447+ config.virtual_thread_count , input, input_dims, output);
448448 }
449449}
450450
451+ // GPU specialization of the functor that swaps dimension 1 and dimension 2 of
452+ // a 3D tensor. `combined_dims` holds the three input extents; `in`/`out` are the source and destination buffers.
453+ template <typename T>
454+ struct SwapDimension1And2InTensor3 <GPUDevice, T> {
455+ typedef GPUDevice Device;
456+ void operator ()(const Device& d, const T* in,
457+ const gtl::ArraySlice<int64>& combined_dims, T* out) {
458+ Dimension<3 > input_dims = {static_cast <int >(combined_dims[0 ]),  // NOTE(review): int64 -> int narrowing; presumably every extent fits in int - confirm
459+ static_cast <int >(combined_dims[1 ]),
460+ static_cast <int >(combined_dims[2 ])};
461+ RunSwapDimension1And2InTensor3 (d, in, input_dims, out);  // kernel launch is delegated to the helper defined earlier in this file
462+ }
463+ };
464+
465+ // GPU specialization of the functor that swaps dimension 0 and dimension 2 of
466+ // a 3D tensor. `combined_dims` holds the three input extents; `in`/`out` are the source and destination buffers.
467+ template <typename T>
468+ struct SwapDimension0And2InTensor3 <GPUDevice, T> {
469+ typedef GPUDevice Device;
470+ void operator ()(const Device& d, const T* in,
471+ const gtl::ArraySlice<int64>& combined_dims, T* out) {
472+ Dimension<3 > input_dims = {static_cast <int >(combined_dims[0 ]),  // NOTE(review): int64 -> int narrowing; presumably every extent fits in int - confirm
473+ static_cast <int >(combined_dims[1 ]),
474+ static_cast <int >(combined_dims[2 ])};
475+ size_t total_size = combined_dims[0 ] * combined_dims[1 ] * combined_dims[2 ];  // element count = product of the three extents, used to size the launch
476+ CudaLaunchConfig config = GetCudaLaunchConfig (total_size, d);  // NOTE(review): behavior for total_size == 0 is not visible here - confirm
477+ SwapDimension0And2InTensor3Simple<T>
478+ <<<config.block_count , config.thread_per_block , 0 , d.stream ()>>>(  // no dynamic shared memory (0); launched on the device's stream
479+ config.virtual_thread_count , in, input_dims, out);  // virtual_thread_count is passed as the kernel's nthreads argument
480+ }
481+ };
482+
451483// A GPU helper functor that converts NHWC TensorFlow data format to
452484// NCHW format that is accepted by Cudnn.
453485template <typename T, int NDIMS>
@@ -497,6 +529,18 @@ template struct functor::ShuffleAndReverse<GPUDevice, Eigen::half, 4,
497529template struct functor ::TransformDepth<GPUDevice, float , int >;
498530template struct functor ::TransformDepth<GPUDevice, Eigen::half, int >;
499531
532+ template struct functor ::SwapDimension1And2InTensor3<GPUDevice, uint8>;  // explicit GPU instantiations of the dim-1/dim-2 swap functor, one per element type listed
533+ template struct functor ::SwapDimension1And2InTensor3<GPUDevice, uint16>;
534+ template struct functor ::SwapDimension1And2InTensor3<GPUDevice, uint32>;
535+ template struct functor ::SwapDimension1And2InTensor3<GPUDevice, uint64>;
536+ template struct functor ::SwapDimension1And2InTensor3<GPUDevice, float4>;
537+
538+ template struct functor ::SwapDimension0And2InTensor3<GPUDevice, uint8>;  // explicit GPU instantiations of the dim-0/dim-2 swap functor, one per element type listed
539+ template struct functor ::SwapDimension0And2InTensor3<GPUDevice, uint16>;
540+ template struct functor ::SwapDimension0And2InTensor3<GPUDevice, uint32>;
541+ template struct functor ::SwapDimension0And2InTensor3<GPUDevice, uint64>;
542+ template struct functor ::SwapDimension0And2InTensor3<GPUDevice, float4>;
543+
500544// For 2d ops.
501545template struct functor ::TransformFilter<GPUDevice, float , int , 4 >;
502546template struct functor ::TransformFilter<GPUDevice, Eigen::half, int , 4 >;
0 commit comments