|
| 1 | +#ifndef UTENSOR_TRANSPOSE_H |
| 2 | +#define UTENSOR_TRANSPOSE_H |
| 3 | + |
| 4 | +#include <cstring> |
| 5 | + |
| 6 | +#include "context.hpp" |
| 7 | +#include "operatorBase.hpp" |
| 8 | +#include "tensor.hpp" |
| 9 | +#include "types.hpp" |
| 10 | +#include "uTensor_util.hpp" |
| 11 | + |
| 12 | +namespace uTensor { |
| 13 | +namespace ReferenceOperators { |
| 14 | + |
// Transpose (Swap Axes) as a port from NumPy
// using stride iteration in the order of transpose axes
| 17 | +template <typename Tin> |
| 18 | +class TransposeOperator : public OperatorInterface<2, 1> { |
| 19 | + /* reshape input as the shape of output*/ |
| 20 | + public: |
| 21 | + enum names_in : uint8_t { input, perm }; |
| 22 | + enum names_out : uint8_t { output }; |
| 23 | + |
| 24 | + virtual void compute() { |
| 25 | + const Tensor& perm_tensor = inputs[perm].tensor(); |
| 26 | + if (perm_tensor.get_shape().num_dims() > 1) { |
| 27 | + uTensor_printf( |
| 28 | + "the input tensor perm should be a vector (dimension should be 1)\n"); |
| 29 | + Context::get_default_context()->throwError(new InvalidTensorInputError); |
| 30 | + } |
| 31 | + if (perm_tensor->get_type() != i32) { |
| 32 | + uTensor_printf("expecting perm tensor of element type int32_t\n"); |
| 33 | + Context::get_default_context()->throwError( |
| 34 | + new InvalidTensorDataTypeError); |
| 35 | + } |
| 36 | + Tensor& input_tensor = inputs[input].tensor(); |
| 37 | + TensorShape& input_shape = input_tensor.get_shape(); |
| 38 | + input_shape.update_dims(); |
| 39 | + |
| 40 | + // Strides are used to iterate over the dataset, and transfer |
| 41 | + // the input tensor data, into the output tensor |
| 42 | + TensorStrides input_strides = TensorStrides(input_shape); |
| 43 | + |
| 44 | + Tensor& output_tensor = outputs[output].tensor(); |
| 45 | + |
| 46 | + // Create a placeholder to calculate the output shape |
| 47 | + // Normally this would reference output shape, but since this could (usually |
| 48 | + // would) be referencing the input, let's keep a dedicated value |
| 49 | + TensorShape output_shape = TensorShape(1, 1, 1, 1); |
| 50 | + TensorStrides output_strides = TensorStrides(output_shape); |
| 51 | + TensorShape offsets = TensorShape(input_shape.num_dims()); |
| 52 | + |
| 53 | + for (size_t i = 0; i < 4; ++i) { |
| 54 | + output_shape[i] = 0; |
| 55 | + output_strides[i] = 0; |
| 56 | + |
| 57 | + // Offsets are used to avoid multiple for loops |
| 58 | + offsets[i] = 0; |
| 59 | + } |
| 60 | + |
| 61 | + for (size_t i = 0; i < (size_t)input_shape.num_dims(); ++i) { |
| 62 | + int32_t axis = static_cast<int32_t>(perm_tensor(i)); |
| 63 | + output_shape[axis] = input_shape[i]; |
| 64 | + |
| 65 | + // output_strides(i) is derived from axes and input_strides |
| 66 | + output_strides[axis] = input_strides[i]; |
| 67 | + } |
| 68 | + |
| 69 | + // Output shape can be asserted once the transform |
| 70 | + // effect has been determined |
| 71 | + output_shape.update_dims(); |
| 72 | + output_tensor->resize(output_shape); |
| 73 | + |
| 74 | + // Perform some basic checks |
| 75 | + if (input_tensor->num_elems() != output_tensor->num_elems()) { |
| 76 | + uTensor_printf("inconsistent input and output shape for reshape\n"); |
| 77 | + Context::get_default_context()->throwError(new InvalidReshapeError); |
| 78 | + return; |
| 79 | + } |
| 80 | + if (input_tensor->get_type() != output_tensor->get_type()) { |
| 81 | + uTensor_printf("inconsistent input and output data type for reshape\n"); |
| 82 | + Context::get_default_context()->throwError( |
| 83 | + new InvalidTensorDataTypeError); |
| 84 | + return; |
| 85 | + } |
| 86 | + if (!_check_input_shape()) { |
| 87 | + Context::get_default_context()->throwError( |
| 88 | + new InvalidTensorDataTypeError); |
| 89 | + return; |
| 90 | + } |
| 91 | + |
| 92 | + // copy data |
| 93 | + for (uint32_t i = 0; i < input_tensor->num_elems(); ++i) { |
| 94 | + // Index of the source value, must be calculated |
| 95 | + // using the output strides and output shape |
| 96 | + uint32_t idx = 0; |
| 97 | + for (uint32_t j = 0; j < output_shape.num_dims(); j++) { |
| 98 | + idx += offsets[j] * output_strides[j]; |
| 99 | + } |
| 100 | + |
| 101 | + // this is not copy: `output_tensor(i) = input_tensor(i);` |
| 102 | + output_tensor(i) = static_cast<Tin>(input_tensor(idx)); |
| 103 | + |
| 104 | + // Update offsets, to iterate sequentially along strides |
| 105 | + // in the order of axes |
| 106 | + for (int32_t j = output_shape.num_dims() - 1; j >= 0; j--) { |
| 107 | + offsets[j] = (offsets[j] + 1) % (output_shape[j]); |
| 108 | + if (offsets[j] > 0) { |
| 109 | + break; |
| 110 | + } |
| 111 | + } |
| 112 | + } |
| 113 | + } |
| 114 | + |
| 115 | + private: |
| 116 | + bool _check_input_shape() { |
| 117 | + const Tensor& input_tensor = inputs[input].tensor(); |
| 118 | + const TensorShape& shape = input_tensor->get_shape(); |
| 119 | + uint8_t num_dims = shape.num_dims(); |
| 120 | + for (int i = 0; i < num_dims; ++i) { |
| 121 | + if (shape[i] < 0) { |
| 122 | + uTensor_printf("the output shape must be all positive\n"); |
| 123 | + return false; |
| 124 | + } |
| 125 | + } |
| 126 | + return true; |
| 127 | + } |
| 128 | +}; |
| 129 | + |
| 130 | +} // namespace ReferenceOperators |
| 131 | +} // namespace uTensor |
| 132 | + |
| 133 | +#endif // UTENSOR_TRANSPOSE_H |