Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit ec2f876

Browse files
Yangzihao Wangtensorflower-gardener
authored andcommitted
Improved transpose operator's performance
Use specialized GPU kernels on tensors when the permutation can be reduced to {0,2,1}, {2,1,0} or {1,0}. Change: 151147354
1 parent 4f20605 commit ec2f876

10 files changed

Lines changed: 492 additions & 53 deletions

File tree

tensorflow/core/kernels/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,7 @@ tf_kernel_library(
11051105
visibility = ["//visibility:private"],
11061106
deps = [
11071107
"//tensorflow/core:framework",
1108+
"//tensorflow/core/kernels:conv_ops",
11081109
"//third_party/eigen3",
11091110
],
11101111
alwayslink = 0,

tensorflow/core/kernels/conv_2d.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,26 @@ struct NCHWToNHWC {
256256
typename TTypes<T, NDIMS>::Tensor out);
257257
};
258258

259+
// Converts a tensor from:
260+
// [dim0, dim1, dim2]
261+
// to:
262+
// [dim0, dim2, dim1]
263+
template <typename Device, typename T>
264+
struct SwapDimension1And2InTensor3 {
265+
void operator()(const Device& d, const T* in,
266+
const gtl::ArraySlice<int64>& input_dims, T* out);
267+
};
268+
269+
// Converts a tensor from:
270+
// [dim0, dim1, dim2]
271+
// to:
272+
// [dim2, dim1, dim0]
273+
template <typename Device, typename T>
274+
struct SwapDimension0And2InTensor3 {
275+
void operator()(const Device& d, const T* in,
276+
const gtl::ArraySlice<int64>& input_dims, T* out);
277+
};
278+
259279
// Reverses the effect of TransformFilter above.
260280
template <typename Device, typename T, int NDIMS>
261281
struct ReverseTransformFilter {

tensorflow/core/kernels/conv_ops_gpu_3.cu.cc

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,9 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index<IndexCount> FlatToTensorIndex(
126126

127127
// A Cuda custom kernel that swaps dimension-0 and dimension-2 of a 3D tensor.
128128
template <typename T>
129-
__global__ void SwapDimension0And2InTensor3(int nthreads, const T* input,
130-
Dimension<3> input_dims,
131-
T* output) {
129+
__global__ void SwapDimension0And2InTensor3Simple(int nthreads, const T* input,
130+
Dimension<3> input_dims,
131+
T* output) {
132132
Dimension<3> output_dims;
133133
output_dims[0] = input_dims[2];
134134
output_dims[1] = input_dims[1];
@@ -152,9 +152,9 @@ __global__ void SwapDimension0And2InTensor3(int nthreads, const T* input,
152152

153153
// A Cuda custom kernel that swaps dimension-1 and dimension-2 of a 3D tensor.
154154
template <typename T>
155-
__global__ void SwapDimension1And2InTensor3(int nthreads, const T* input,
156-
Dimension<3> input_dims,
157-
T* output) {
155+
__global__ void SwapDimension1And2InTensor3Simple(int nthreads, const T* input,
156+
Dimension<3> input_dims,
157+
T* output) {
158158
Dimension<3> output_dims;
159159
output_dims[0] = input_dims[0];
160160
output_dims[1] = input_dims[2];
@@ -348,9 +348,9 @@ struct TransformFilter<GPUDevice, T, int, NDIMS> {
348348
combined_dims[1] = in.dimension(NDIMS - 2); // input filters
349349
combined_dims[2] = in.dimension(NDIMS - 1); // output filters
350350
CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
351-
SwapDimension0And2InTensor3<
352-
T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
353-
config.virtual_thread_count, in.data(), combined_dims, out.data());
351+
SwapDimension0And2InTensor3Simple<T>
352+
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
353+
config.virtual_thread_count, in.data(), combined_dims, out.data());
354354
}
355355
};
356356

@@ -368,9 +368,9 @@ struct ReverseTransformFilter<GPUDevice, T, NDIMS> {
368368
combined_dims[2] *= in.dimension(i);
369369
}
370370
CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
371-
SwapDimension0And2InTensor3<
372-
T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
373-
config.virtual_thread_count, in.data(), combined_dims, out.data());
371+
SwapDimension0And2InTensor3Simple<T>
372+
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
373+
config.virtual_thread_count, in.data(), combined_dims, out.data());
374374
}
375375
};
376376

@@ -442,12 +442,44 @@ void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
442442
} else {
443443
int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
444444
CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d);
445-
SwapDimension1And2InTensor3<
446-
T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
447-
config.virtual_thread_count, input, input_dims, output);
445+
SwapDimension1And2InTensor3Simple<T>
446+
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
447+
config.virtual_thread_count, input, input_dims, output);
448448
}
449449
}
450450

451+
// A GPU helper functor that does general dimension 1 and 2 switch for 3D
452+
// tensor.
453+
template <typename T>
454+
struct SwapDimension1And2InTensor3<GPUDevice, T> {
455+
typedef GPUDevice Device;
456+
void operator()(const Device& d, const T* in,
457+
const gtl::ArraySlice<int64>& combined_dims, T* out) {
458+
Dimension<3> input_dims = {static_cast<int>(combined_dims[0]),
459+
static_cast<int>(combined_dims[1]),
460+
static_cast<int>(combined_dims[2])};
461+
RunSwapDimension1And2InTensor3(d, in, input_dims, out);
462+
}
463+
};
464+
465+
// A GPU helper functor that does general dimension 0 and 2 switch for 3D
466+
// tensor.
467+
template <typename T>
468+
struct SwapDimension0And2InTensor3<GPUDevice, T> {
469+
typedef GPUDevice Device;
470+
void operator()(const Device& d, const T* in,
471+
const gtl::ArraySlice<int64>& combined_dims, T* out) {
472+
Dimension<3> input_dims = {static_cast<int>(combined_dims[0]),
473+
static_cast<int>(combined_dims[1]),
474+
static_cast<int>(combined_dims[2])};
475+
size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2];
476+
CudaLaunchConfig config = GetCudaLaunchConfig(total_size, d);
477+
SwapDimension0And2InTensor3Simple<T>
478+
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
479+
config.virtual_thread_count, in, input_dims, out);
480+
}
481+
};
482+
451483
// A GPU helper functor that converts NHWC TensorFlow data format to
452484
// NCHW format that is accepted by Cudnn.
453485
template <typename T, int NDIMS>
@@ -497,6 +529,18 @@ template struct functor::ShuffleAndReverse<GPUDevice, Eigen::half, 4,
497529
template struct functor::TransformDepth<GPUDevice, float, int>;
498530
template struct functor::TransformDepth<GPUDevice, Eigen::half, int>;
499531

532+
template struct functor::SwapDimension1And2InTensor3<GPUDevice, uint8>;
533+
template struct functor::SwapDimension1And2InTensor3<GPUDevice, uint16>;
534+
template struct functor::SwapDimension1And2InTensor3<GPUDevice, uint32>;
535+
template struct functor::SwapDimension1And2InTensor3<GPUDevice, uint64>;
536+
template struct functor::SwapDimension1And2InTensor3<GPUDevice, float4>;
537+
538+
template struct functor::SwapDimension0And2InTensor3<GPUDevice, uint8>;
539+
template struct functor::SwapDimension0And2InTensor3<GPUDevice, uint16>;
540+
template struct functor::SwapDimension0And2InTensor3<GPUDevice, uint32>;
541+
template struct functor::SwapDimension0And2InTensor3<GPUDevice, uint64>;
542+
template struct functor::SwapDimension0And2InTensor3<GPUDevice, float4>;
543+
500544
// For 2d ops.
501545
template struct functor::TransformFilter<GPUDevice, float, int, 4>;
502546
template struct functor::TransformFilter<GPUDevice, Eigen::half, int, 4>;

tensorflow/core/kernels/transpose_functor.h

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -111,25 +111,14 @@ template <typename Device, typename T, int NDIMS>
111111
void TransposeUsingEigen(const Device& d, const Tensor& in,
112112
const gtl::ArraySlice<int32> perm, Tensor* out);
113113

114-
template <typename Device, typename T>
115-
void Transpose(const Device& d, const Tensor& in,
116-
const gtl::ArraySlice<int32> perm, Tensor* out) {
117-
switch (in.dims()) {
118-
case 2:
119-
TransposeUsingEigen<Device, T, 2>(d, in, perm, out);
120-
break;
121-
case 3:
122-
TransposeUsingEigen<Device, T, 3>(d, in, perm, out);
123-
break;
124-
case 4:
125-
TransposeUsingEigen<Device, T, 4>(d, in, perm, out);
126-
break;
127-
default:
128-
TransposeSimple<Device, T>(d, in, perm, out);
129-
break;
130-
}
131-
}
132114
} // namespace internal
115+
116+
template <typename Device, typename T>
117+
struct Transpose {
118+
static void run(const Device& d, const Tensor& in,
119+
const gtl::ArraySlice<int32> perm, Tensor* out);
120+
};
121+
133122
} // namespace tensorflow
134123

135124
#endif // TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_

tensorflow/core/kernels/transpose_functor_cpu.cc

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,38 @@ void TransposeUsingEigen(const Device& d, const Tensor& in,
6161

6262
} // end namespace internal
6363

64-
typedef Eigen::ThreadPoolDevice Device;
64+
typedef Eigen::ThreadPoolDevice CPUDevice;
65+
66+
template <typename T>
67+
struct Transpose<CPUDevice, T> {
68+
static void run(const CPUDevice& d, const Tensor& in,
69+
const gtl::ArraySlice<int32> perm, Tensor* out) {
70+
switch (in.dims()) {
71+
case 2:
72+
internal::TransposeUsingEigen<CPUDevice, T, 2>(d, in, perm, out);
73+
break;
74+
case 3:
75+
internal::TransposeUsingEigen<CPUDevice, T, 3>(d, in, perm, out);
76+
break;
77+
case 4:
78+
internal::TransposeUsingEigen<CPUDevice, T, 4>(d, in, perm, out);
79+
break;
80+
case 5:
81+
internal::TransposeUsingEigen<CPUDevice, T, 5>(d, in, perm, out);
82+
break;
83+
default:
84+
internal::TransposeSimple<CPUDevice, T>(d, in, perm, out);
85+
break;
86+
}
87+
}
88+
};
6589

90+
// TODO(yangzihao): Merge this code with its GPU counterpart to reduce code
91+
// duplication.
6692
template <>
67-
Status DoTranspose<Device>(const Device& d, const Tensor& in,
68-
const gtl::ArraySlice<int32> perm, Tensor* out) {
93+
Status DoTranspose<CPUDevice>(const CPUDevice& d, const Tensor& in,
94+
const gtl::ArraySlice<int32> perm, Tensor* out) {
95+
typedef CPUDevice Device;
6996
CHECK_GE(in.dims(), 2);
7097
CHECK_EQ(in.dims(), out->dims());
7198
CHECK_EQ(in.dims(), perm.size());
@@ -76,7 +103,7 @@ Status DoTranspose<Device>(const Device& d, const Tensor& in,
76103
case DT_QINT8:
77104
case DT_QUINT8:
78105
case DT_UINT8:
79-
internal::Transpose<Device, uint8>(d, in, perm, out);
106+
Transpose<Device, uint8>::run(d, in, perm, out);
80107
break;
81108

82109
case DT_BFLOAT16:
@@ -85,27 +112,27 @@ Status DoTranspose<Device>(const Device& d, const Tensor& in,
85112
case DT_QINT16:
86113
case DT_QUINT16:
87114
case DT_UINT16:
88-
internal::Transpose<Device, uint16>(d, in, perm, out);
115+
Transpose<Device, uint16>::run(d, in, perm, out);
89116
break;
90117

91118
case DT_FLOAT:
92119
case DT_INT32:
93120
case DT_QINT32:
94-
internal::Transpose<Device, uint32>(d, in, perm, out);
121+
Transpose<Device, uint32>::run(d, in, perm, out);
95122
break;
96123

97124
case DT_COMPLEX64:
98125
case DT_DOUBLE:
99126
case DT_INT64:
100-
internal::Transpose<Device, uint64>(d, in, perm, out);
127+
Transpose<Device, uint64>::run(d, in, perm, out);
101128
break;
102129

103130
case DT_COMPLEX128:
104-
internal::Transpose<Device, complex128>(d, in, perm, out);
131+
Transpose<Device, complex128>::run(d, in, perm, out);
105132
break;
106133

107134
case DT_STRING:
108-
internal::Transpose<Device, string>(d, in, perm, out);
135+
Transpose<Device, string>::run(d, in, perm, out);
109136
break;
110137

111138
default:
@@ -117,6 +144,14 @@ Status DoTranspose<Device>(const Device& d, const Tensor& in,
117144
#ifdef TENSORFLOW_USE_SYCL
118145
typedef Eigen::SyclDevice SYCLDevice;
119146

147+
template <typename T>
148+
struct internal::Transpose<SYCLDevice, T> {
149+
static void run(const SYCLDevice& d, const Tensor& in,
150+
const gtl::ArraySlice<int32> perm, Tensor* out) {
151+
// Should add a specialized implementation for SYCLDevice here.
152+
}
153+
};
154+
120155
template <>
121156
Status DoTranspose<SYCLDevice>(const SYCLDevice& d, const Tensor& in,
122157
const gtl::ArraySlice<int32> perm, Tensor* out) {
@@ -125,11 +160,10 @@ Status DoTranspose<SYCLDevice>(const SYCLDevice& d, const Tensor& in,
125160
CHECK_EQ(in.dims(), perm.size());
126161
CHECK_EQ(in.dtype(), out->dtype());
127162
switch (in.dtype()) {
128-
129163
case DT_FLOAT:
130164
case DT_DOUBLE:
131165
case DT_INT32:
132-
internal::Transpose<SYCLDevice, uint32>(d, in, perm, out);
166+
internal::Transpose<SYCLDevice, uint32>::run(d, in, perm, out);
133167
break;
134168

135169
default:

0 commit comments

Comments
 (0)