|
11 | 11 | #include <gtest/gtest.h> |
12 | 12 | #include <xnnpack.h> |
13 | 13 |
|
| 14 | +using executorch::aten::Tensor; |
14 | 15 | using executorch::backends::xnnpack::delegate::XNNExecutor; |
15 | 16 | using executorch::runtime::Error; |
16 | 17 | using executorch::runtime::EValue; |
@@ -95,3 +96,96 @@ TEST(XNNExecutorTest, ArgumentWithTooManyDimensions) { |
95 | 96 | // Check for invalid number of dimensions should fail without stack overflow. |
96 | 97 | EXPECT_EQ(executor.prepare_args(stack_args), Error::InvalidArgument); |
97 | 98 | } |
| 99 | + |
| 100 | +// Tests that resize_outputs correctly converts int32 indices to int64. |
| 101 | +TEST(XNNExecutorTest, ResizeOutputsWithLongTensorConvertsInt32ToInt64) { |
| 102 | + XNNExecutor executor({}); |
| 103 | + xnn_runtime_t rt = nullptr; |
| 104 | + et_pal_init(); |
| 105 | + ASSERT_EQ(xnn_initialize(nullptr), xnn_status_success); |
| 106 | + |
| 107 | + xnn_subgraph_t subgraph = nullptr; |
| 108 | + ASSERT_EQ(xnn_create_subgraph(3, 0, &subgraph), xnn_status_success); |
| 109 | + std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph( |
| 110 | + subgraph, xnn_delete_subgraph); |
| 111 | + |
| 112 | + std::vector<size_t> in_dims = {1, 4, 4, 1}, out_dims = {1, 2, 2, 1}; |
| 113 | + uint32_t input_id = XNN_INVALID_VALUE_ID; |
| 114 | + uint32_t value_id = XNN_INVALID_VALUE_ID; |
| 115 | + uint32_t index_id = XNN_INVALID_VALUE_ID; |
| 116 | + |
| 117 | + ASSERT_EQ( |
| 118 | + xnn_status_success, |
| 119 | + xnn_define_tensor_value( |
| 120 | + subgraph, |
| 121 | + xnn_datatype_fp32, |
| 122 | + in_dims.size(), |
| 123 | + in_dims.data(), |
| 124 | + nullptr, |
| 125 | + 0, |
| 126 | + XNN_VALUE_FLAG_EXTERNAL_INPUT, |
| 127 | + &input_id)); |
| 128 | + ASSERT_EQ( |
| 129 | + xnn_status_success, |
| 130 | + xnn_define_tensor_value( |
| 131 | + subgraph, |
| 132 | + xnn_datatype_fp32, |
| 133 | + out_dims.size(), |
| 134 | + out_dims.data(), |
| 135 | + nullptr, |
| 136 | + 1, |
| 137 | + XNN_VALUE_FLAG_EXTERNAL_OUTPUT, |
| 138 | + &value_id)); |
| 139 | + ASSERT_EQ( |
| 140 | + xnn_status_success, |
| 141 | + xnn_define_tensor_value( |
| 142 | + subgraph, |
| 143 | + xnn_datatype_int32, |
| 144 | + out_dims.size(), |
| 145 | + out_dims.data(), |
| 146 | + nullptr, |
| 147 | + 2, |
| 148 | + XNN_VALUE_FLAG_EXTERNAL_OUTPUT, |
| 149 | + &index_id)); |
| 150 | + ASSERT_EQ( |
| 151 | + xnn_status_success, |
| 152 | + xnn_define_argmax_pooling_2d( |
| 153 | + subgraph, 0, 0, 0, 0, 2, 2, input_id, value_id, index_id, 0)); |
| 154 | + |
| 155 | + ASSERT_EQ(xnn_create_runtime(subgraph, &rt), xnn_status_success); |
| 156 | + ASSERT_EQ(executor.initialize(rt, {0}, {1, 2}, {}), Error::Ok); |
| 157 | + |
| 158 | + TensorFactory<executorch::aten::ScalarType::Float> tf_float; |
| 159 | + TensorFactory<executorch::aten::ScalarType::Long> tf_long; |
| 160 | + |
| 161 | + auto input = tf_float.make( |
| 162 | + {1, 4, 4, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); |
| 163 | + auto out_value = tf_float.make({1, 2, 2, 1}, {0, 0, 0, 0}); |
| 164 | + auto out_index = tf_long.make({1, 2, 2, 1}, {0, 0, 0, 0}); |
| 165 | + |
| 166 | + EValue ev_in(input), ev_val(out_value), ev_idx(out_index); |
| 167 | + std::array<EValue*, 3> args = {&ev_in, &ev_val, &ev_idx}; |
| 168 | + Span<EValue*> span(args.data(), 3); |
| 169 | + |
| 170 | + ASSERT_EQ(executor.prepare_args(span), Error::Ok); |
| 171 | + executorch::ET_RUNTIME_NAMESPACE::BackendExecutionContext context; |
| 172 | + ASSERT_EQ(executor.forward(context), Error::Ok); |
| 173 | + ASSERT_EQ(executor.resize_outputs(span), Error::Ok); |
| 174 | + |
| 175 | + Tensor& result = args[2]->toTensor(); |
| 176 | + ASSERT_EQ(result.scalar_type(), executorch::aten::ScalarType::Long); |
| 177 | + |
| 178 | + /* |
| 179 | + Input 4x4: Output values: Output indices: |
| 180 | + 1 2 | 3 4 6 | 8 3 | 3 |
| 181 | + 5 6 | 7 8 14 | 16 3 | 3 |
| 182 | + ------|----- |
| 183 | + 9 10 |11 12 |
| 184 | + 13 14 |15 16 |
| 185 | +
|
| 186 | + Each 2x2 quadrant → max value + index of max (3 = bottom-right). |
| 187 | + */ |
| 188 | + for (ssize_t i = 0; i < result.numel(); ++i) { |
| 189 | + EXPECT_EQ(result.const_data_ptr<int64_t>()[i], 3); |
| 190 | + } |
| 191 | +} |
0 commit comments