Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit bfa4cb6

Browse files
committed
tranpose op: perm as input tensor
1 parent 45c1509 commit bfa4cb6

File tree

3 files changed

+88
-68
lines changed

3 files changed

+88
-68
lines changed
Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
11
#ifndef _TRANSPOSE_TEST_H
22
#define _TRANSPOSE_TEST_H
33

4-
static const unsigned short transpose_axes_arr[3] = { 2,1,0 };
5-
static const float random_input_arr[15] = { 3.484638214111328, 2.033799886703491, 3.2437448501586914, 4.783249855041504, 3.497023582458496, 3.511240005493164, 1.558927297592163, 3.7084484100341797, 2.570117712020874, 0.2405869960784912, 1.8713605403900146, 4.19132661819458, 0.6596618890762329, 0.9029078483581543, 0.2223271131515503 };
6-
static const float ref_output_arr[15] = { 3.484638214111328, 3.511240005493164, 1.8713605403900146, 2.033799886703491, 1.558927297592163, 4.19132661819458, 3.2437448501586914, 3.7084484100341797, 0.6596618890762329, 4.783249855041504, 2.570117712020874, 0.9029078483581543, 3.497023582458496, 0.2405869960784912, 0.2223271131515503 };
4+
static const int32_t transpose_perm_arr[4] = {2, 1, 0, 3};
5+
static const float random_input_arr[15] = {
6+
3.484638214111328, 2.033799886703491, 3.2437448501586914,
7+
4.783249855041504, 3.497023582458496, 3.511240005493164,
8+
1.558927297592163, 3.7084484100341797, 2.570117712020874,
9+
0.2405869960784912, 1.8713605403900146, 4.19132661819458,
10+
0.6596618890762329, 0.9029078483581543, 0.2223271131515503};
11+
static const float ref_output_arr[15] = {
12+
3.484638214111328, 3.511240005493164, 1.8713605403900146,
13+
2.033799886703491, 1.558927297592163, 4.19132661819458,
14+
3.2437448501586914, 3.7084484100341797, 0.6596618890762329,
15+
4.783249855041504, 2.570117712020874, 0.9029078483581543,
16+
3.497023582458496, 0.2405869960784912, 0.2223271131515503};
717

8-
#endif // _TRANSPOSE
18+
#endif // _TRANSPOSE_TEST_H

TESTS/operators/test_transpose.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
#include <iostream>
33

44
#include "RamTensor.hpp"
5-
#include "Transpose.hpp"
65
#include "RomTensor.hpp"
6+
#include "Transpose.hpp"
77
#include "arenaAllocator.hpp"
88
#include "constants_transpose.hpp"
99
#include "context.hpp"
@@ -19,19 +19,19 @@ TEST(Transpose, transpose_test) {
1919
localCircularArenaAllocator<15 * 2 * sizeof(float), uint32_t> ram_allocator;
2020
Context::get_default_context()->set_metadata_allocator(&meta_allocator);
2121
Context::get_default_context()->set_ram_data_allocator(&ram_allocator);
22-
22+
2323
Tensor input_tensor = new RomTensor({3, 1, 5, 1}, flt, random_input_arr);
24+
Tensor perm_tensor = new RomTensor({4}, i32, transpose_perm_arr);
2425

2526
TensorShape input_target_shape(3, 1, 5, 1);
2627
TensorShape input_shape = input_tensor->get_shape();
2728
EXPECT_TRUE(input_target_shape == input_shape);
2829

29-
Tensor transpose_axes = new RomTensor({4}, u8, transpose_axes_arr);
3030
Tensor output_tensor = new RamTensor(flt);
31-
TransposeOperator<float> op({2,1,0,3});
32-
31+
TransposeOperator<float> op;
3332

34-
op.set_inputs({{TransposeOperator<float>::input, input_tensor}})
33+
op.set_inputs({{TransposeOperator<float>::input, input_tensor},
34+
{TransposeOperator<float>::perm, perm_tensor}})
3535
.set_outputs({{TransposeOperator<float>::output, output_tensor}})
3636
.eval();
3737

src/uTensor/ops/Transpose.hpp

Lines changed: 68 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,38 @@
11
#ifndef UTENSOR_TRANSPOSE_H
22
#define UTENSOR_TRANSPOSE_H
33

4+
#include <cstring>
5+
46
#include "context.hpp"
5-
#include "types.hpp"
7+
#include "operatorBase.hpp"
68
#include "tensor.hpp"
9+
#include "types.hpp"
710
#include "uTensor_util.hpp"
8-
#include "operatorBase.hpp"
9-
10-
#include <cstring>
1111

1212
namespace uTensor {
1313
namespace ReferenceOperators {
1414

1515
// Transpose (Swap Axes) as a port from Numpy
1616
// using stride interation in the order of transpose axes
1717
template <typename Tin>
18-
class TransposeOperator : public OperatorInterface<1, 1> {
19-
/* reshape input as the shape of output*/
20-
public:
21-
TransposeOperator(const TensorShape&& axes) : _axes(axes) {}
22-
TransposeOperator(const TensorShape& axes) : _axes(axes) {}
23-
24-
enum names_in : uint8_t { input };
18+
class TransposeOperator : public OperatorInterface<2, 1> {
19+
/* reshape input as the shape of output*/
20+
public:
21+
enum names_in : uint8_t { input, perm };
2522
enum names_out : uint8_t { output };
2623

27-
virtual void compute(){
24+
virtual void compute() {
25+
Tensor& perm_tensor = inputs[perm].tensor();
26+
if (perm_tensor.get_shape().num_dims() > 1) {
27+
uTensor_printf(
28+
"the input tensor perm should be a vector (dimension should be 1)\n");
29+
Context::get_default_context()->throwError(new InvalidTensorInputError);
30+
}
31+
if (perm_tensor->get_type() != i32) {
32+
uTensor_printf("expecting perm tensor of element type int32_t\n");
33+
Context::get_default_context()->throwError(
34+
new InvalidTensorDataTypeError);
35+
}
2836
Tensor& input_tensor = inputs[input].tensor();
2937
TensorShape& input_shape = input_tensor.get_shape();
3038
input_shape.update_dims();
@@ -36,78 +44,80 @@ class TransposeOperator : public OperatorInterface<1, 1> {
3644
Tensor& output_tensor = outputs[output].tensor();
3745

3846
// Create a placeholder to calculate the output shape
39-
// Normally this would reference output shape, but since this could (usually would) be referencing the input, let's keep a dedicated value
40-
TensorShape output_shape = TensorShape(1,1,1,1);
47+
// Normally this would reference output shape, but since this could (usually
48+
// would) be referencing the input, let's keep a dedicated value
49+
TensorShape output_shape = TensorShape(1, 1, 1, 1);
4150
TensorStrides output_strides = TensorStrides(output_shape);
4251
TensorShape offsets = TensorShape(input_shape.num_dims());
4352

44-
for (size_t i = 0; i < 4; ++i) {
53+
for (size_t i = 0; i < 4; ++i) {
4554
output_shape[i] = 0;
4655
output_strides[i] = 0;
4756

4857
// Offsets are used to avoid multiple for loops
4958
offsets[i] = 0;
5059
}
5160

52-
for (size_t i = 0; i < (size_t) input_shape.num_dims(); ++i) {
53-
output_shape[_axes[i]] = input_shape[i];
61+
for (size_t i = 0; i < (size_t)input_shape.num_dims(); ++i) {
62+
int32_t axis = static_cast<int32_t>(perm_tensor(i));
63+
output_shape[axis] = input_shape[i];
5464

5565
// output_strides(i) is derived from axes and input_strides
56-
output_strides[_axes[i]] = input_strides[i];
66+
output_strides[axis] = input_strides[i];
5767
}
58-
59-
// Output shape can be asserted once the transform
68+
69+
// Output shape can be asserted once the transform
6070
// effect has been determined
6171
output_shape.update_dims();
6272
output_tensor->resize(output_shape);
6373

6474
// Perform some basic checks
65-
if (input_tensor->num_elems() != output_tensor->num_elems()){
66-
uTensor_printf("inconsistent input and output shape for reshape\n");
67-
Context::get_default_context()->throwError(new InvalidReshapeError);
68-
return;
69-
}
70-
if (input_tensor->get_type() != output_tensor->get_type()){
71-
uTensor_printf("inconsistent input and output data type for reshape\n");
72-
Context::get_default_context()->throwError(new InvalidTensorDataTypeError);
73-
return;
75+
if (input_tensor->num_elems() != output_tensor->num_elems()) {
76+
uTensor_printf("inconsistent input and output shape for reshape\n");
77+
Context::get_default_context()->throwError(new InvalidReshapeError);
78+
return;
79+
}
80+
if (input_tensor->get_type() != output_tensor->get_type()) {
81+
uTensor_printf("inconsistent input and output data type for reshape\n");
82+
Context::get_default_context()->throwError(
83+
new InvalidTensorDataTypeError);
84+
return;
7485
}
75-
if (!_check_input_shape()){
76-
Context::get_default_context()->throwError(new InvalidTensorDataTypeError);
77-
return;
86+
if (!_check_input_shape()) {
87+
Context::get_default_context()->throwError(
88+
new InvalidTensorDataTypeError);
89+
return;
7890
}
7991

8092
// copy data
81-
for (uint32_t i = 0; i < input_tensor->num_elems(); ++i) {
82-
// Index of the source value, must be calculated
83-
// using the output strides and output shape
84-
uint32_t idx = 0;
85-
for (uint32_t j = 0; j < output_shape.num_dims(); j++) {
86-
idx += offsets[j] * output_strides[j];
87-
}
88-
89-
// this is not copy: `output_tensor(i) = input_tensor(i);`
90-
output_tensor(i) = static_cast<Tin>(input_tensor(idx));
93+
for (uint32_t i = 0; i < input_tensor->num_elems(); ++i) {
94+
// Index of the source value, must be calculated
95+
// using the output strides and output shape
96+
uint32_t idx = 0;
97+
for (uint32_t j = 0; j < output_shape.num_dims(); j++) {
98+
idx += offsets[j] * output_strides[j];
99+
}
91100

92-
// Update offsets, to iterate sequentially along strides
93-
// in the order of axes
94-
for (int32_t j = output_shape.num_dims() - 1; j >= 0; j--) {
95-
offsets[j] = (offsets[j] + 1) % (output_shape[j]);
96-
if( offsets[j] > 0 ) {
97-
break;
98-
}
99-
}
100-
}
101+
// this is not copy: `output_tensor(i) = input_tensor(i);`
102+
output_tensor(i) = static_cast<Tin>(input_tensor(idx));
101103

104+
// Update offsets, to iterate sequentially along strides
105+
// in the order of axes
106+
for (int32_t j = output_shape.num_dims() - 1; j >= 0; j--) {
107+
offsets[j] = (offsets[j] + 1) % (output_shape[j]);
108+
if (offsets[j] > 0) {
109+
break;
110+
}
111+
}
112+
}
102113
}
103-
private:
104-
TensorShape _axes;
105114

106-
bool _check_input_shape(){
115+
private:
116+
bool _check_input_shape() {
107117
const Tensor& input_tensor = inputs[input].tensor();
108118
const TensorShape& shape = input_tensor->get_shape();
109119
uint8_t num_dims = shape.num_dims();
110-
for (int i = 0; i < num_dims; ++i){
120+
for (int i = 0; i < num_dims; ++i) {
111121
if (shape[i] < 0) {
112122
uTensor_printf("the output shape must be all positive\n");
113123
return false;
@@ -117,7 +127,7 @@ class TransposeOperator : public OperatorInterface<1, 1> {
117127
}
118128
};
119129

120-
}
121-
}
130+
} // namespace ReferenceOperators
131+
} // namespace uTensor
122132

123-
#endif // UTENSOR_TRANSPOSE_H
133+
#endif // UTENSOR_TRANSPOSE_H

0 commit comments

Comments
 (0)