Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 347d673

Browse files
authored
Merge pull request #23279 from fengyuentau:add_topk
dnn: add ONNX TopK #23279 Merge with opencv/opencv_extra#1200 Partially fixes #22890 and #20258 To-do: - [x] TopK forward impl - [x] add tests - [x] support Opset 1 & 10 if possible - [ ] ~Support other backends~ (TopK has two outputs, which is not supported by other backends, such as openvino) Perf: M1 (time in milliseconds) | input shape | axis | dnn | ort | | --------------- | ---- | ---- | ---- | | (1000, 100) | 0 | 1.68 | 4.07 | | (1000, 100) K5 | 0 | 1.13 | 0.12 | | (1000, 100) | 1 | 0.96 | 0.77 | | (100, 100, 100) | 0 | 10.00 | 31.13 | | (100, 100, 100) | 1 | 7.33 | 9.17 | | (100, 100, 100) | 2 | 7.52 | 9.48 | M2 (time in milliseconds) | input shape | axis | dnn | ort | | --------------- | ---- | ---- | ---- | | (1000, 100) | 0 | 0.76 | 2.44 | | (1000, 100) K5 | 0 | 0.68 | 0.07 | | (1000, 100) | 1 | 0.41 | 0.50 | | (100, 100, 100) | 0 | 4.83 | 17.52 | | (100, 100, 100) | 1 | 3.60 | 5.08 | | (100, 100, 100) | 2 | 3.73 | 5.10 | ONNXRuntime performance testing script: https://gist.github.com/fengyuentau/a119f94fd16721ec9974b8c7b0a45d4c ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake
1 parent 7cf075c commit 347d673

File tree

6 files changed

+346
-0
lines changed

6 files changed

+346
-0
lines changed

modules/dnn/include/opencv2/dnn/all_layers.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,12 @@ CV__DNN_INLINE_NS_BEGIN
11981198
static Ptr<SpaceToDepthLayer> create(const LayerParams &params);
11991199
};
12001200

1201+
/** @brief Public interface of the TopK layer.
 *
 * Instances are obtained through create(); the layer is configured entirely
 * from the LayerParams handed to that factory.
 */
class CV_EXPORTS TopKLayer : public Layer
{
public:
    /** @brief Builds a TopK layer from @p params. */
    static Ptr<TopKLayer> create(const LayerParams& params);
};
1206+
12011207
//! @}
12021208
//! @}
12031209
CV__DNN_INLINE_NS_END

modules/dnn/perf/perf_layer.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,4 +1043,67 @@ INSTANTIATE_TEST_CASE_P(/**/, Layer_Elementwise,
10431043
/* withWebnn= */ false,
10441044
/* withCann= */ false));
10451045

1046+
// Perf fixture for the TopK layer, parameterized over (backend, target).
struct Layer_TopK : public TestBaseWithParam<tuple<Backend, Target>> {
    // Builds a single-layer TopK network for the given input shape, K and axis,
    // runs one untimed forward pass and then measures forward() inside TEST_CYCLE.
    void test_layer(const std::vector<int> &input_shape, const int K, const int axis) {
        const int backendId = get<0>(GetParam());
        const int targetId = get<1>(GetParam());

        // Random float input; presumably randn's arguments are (mean, stddev) — see cv::randn docs.
        Mat input_data(input_shape, CV_32F);
        randn(input_data, -1.f, 1.f);

        LayerParams layer_params;
        layer_params.type = "TopK";
        layer_params.name = "testLayer";
        layer_params.set("k", K);
        layer_params.set("axis", axis);

        Net net;
        net.addLayerToPrev(layer_params.name, layer_params.type, layer_params);

        // Warmup: bind the input, select backend/target and run once before timing.
        net.setInput(input_data);
        net.setPreferableBackend(backendId);
        net.setPreferableTarget(targetId);
        net.forward();

        TEST_CYCLE()
        {
            net.forward();
        }

        SANITY_CHECK_NOTHING();
    }

    // Input shapes shared by the perf cases.
    std::vector<int> input_shape_2d{1000, 100};
    std::vector<int> input_shape_3d{100, 100, 100};
};
1080+
1081+
// 2D input, K = half of the axis-0 extent.
PERF_TEST_P_(Layer_TopK, TopK_2D_Axis0) {
    const int axis = 0;
    test_layer(input_shape_2d, input_shape_2d[axis] / 2, axis);
}
// 2D input, small fixed K along axis 0.
PERF_TEST_P_(Layer_TopK, TopK_2D_Axis0_K5) {
    const int axis = 0;
    const int K = 5;
    test_layer(input_shape_2d, K, axis);
}
// 2D input, K = half of the axis-1 extent.
PERF_TEST_P_(Layer_TopK, TopK_2D_Axis1) {
    const int axis = 1;
    test_layer(input_shape_2d, input_shape_2d[axis] / 2, axis);
}
// 3D input, K = half of the selected axis extent, one case per axis.
PERF_TEST_P_(Layer_TopK, TopK_3D_Axis0) {
    const int axis = 0;
    test_layer(input_shape_3d, input_shape_3d[axis] / 2, axis);
}
PERF_TEST_P_(Layer_TopK, TopK_3D_Axis1) {
    const int axis = 1;
    test_layer(input_shape_3d, input_shape_3d[axis] / 2, axis);
}
PERF_TEST_P_(Layer_TopK, TopK_3D_Axis2) {
    const int axis = 2;
    test_layer(input_shape_3d, input_shape_3d[axis] / 2, axis);
}
// Only the OpenCV CPU backend is enabled for these cases.
INSTANTIATE_TEST_CASE_P(/**/, Layer_TopK,
    dnnBackendsAndTargets(/* withInferenceEngine= */ false,
                          /* withHalide= */ false,
                          /* withCpuOCV= */ true,
                          /* withVkCom= */ false,
                          /* withCUDA= */ false,
                          /* withNgraph= */ false,
                          /* withWebnn= */ false,
                          /* withCann= */ false));
1108+
10461109
} // namespace

modules/dnn/src/init.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ void initializeLayerFactory()
199199
CV_DNN_REGISTER_LAYER_CLASS(Scatter, ScatterLayer);
200200
CV_DNN_REGISTER_LAYER_CLASS(ScatterND, ScatterNDLayer);
201201
CV_DNN_REGISTER_LAYER_CLASS(Tile, TileLayer);
202+
CV_DNN_REGISTER_LAYER_CLASS(TopK, TopKLayer);
202203

203204
CV_DNN_REGISTER_LAYER_CLASS(Quantize, QuantizeLayer);
204205
CV_DNN_REGISTER_LAYER_CLASS(Dequantize, DequantizeLayer);
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
// This file is part of OpenCV project.
2+
// It is subject to the license terms in the LICENSE file found in the top-level directory
3+
// of this distribution and at http://opencv.org/license.html.
4+
5+
#include "../precomp.hpp"
6+
#include "layers_common.hpp"
7+
8+
#include <opencv2/dnn/shape_utils.hpp>
9+
10+
namespace cv { namespace dnn {
11+
12+
namespace {
13+
14+
// Index comparator for descending order.
//
// Compares two element indices by the values they address in a strided
// array: index i refers to data[i * step]. An index ranks first when its
// value is larger; equal values are ordered by ascending index.
template<typename T>
class ComparatorGreater {
public:
    // data: base pointer of the strided sequence; step: distance (in elements)
    // between consecutive indices.
    ComparatorGreater(const T* data, size_t step)
        : data_(data), step_(step) {}

    // Moves the base pointer forward by `offset` elements.
    void addOffset(size_t offset) {
        data_ += offset;
    }

    // Moves the base pointer back by `offset` elements (undoes addOffset).
    void minusOffset(size_t offset) {
        data_ -= offset;
    }

    // Returns true when lhs_idx must be ordered before rhs_idx.
    // Declared const: it only reads the addressed values.
    bool operator()(const size_t lhs_idx, const size_t rhs_idx) const {
        const T lhs = data_[lhs_idx * step_];
        const T rhs = data_[rhs_idx * step_];
        return (lhs > rhs || (lhs == rhs && lhs_idx < rhs_idx));
    }

private:
    const T* data_;
    size_t step_;
};
38+
39+
// Index comparator for ascending order.
//
// Compares two element indices by the values they address in a strided
// array: index i refers to data[i * step]. An index ranks first when its
// value is smaller; equal values are ordered by ascending index.
template<typename T>
class ComparatorLess {
public:
    // data: base pointer of the strided sequence; step: distance (in elements)
    // between consecutive indices.
    ComparatorLess(const T* data, size_t step)
        : data_(data), step_(step) {}

    // Moves the base pointer forward by `offset` elements.
    void addOffset(size_t offset) {
        data_ += offset;
    }

    // Moves the base pointer back by `offset` elements (undoes addOffset).
    void minusOffset(size_t offset) {
        data_ -= offset;
    }

    // Returns true when lhs_idx must be ordered before rhs_idx.
    // Declared const: it only reads the addressed values.
    bool operator()(const size_t lhs_idx, const size_t rhs_idx) const {
        const T lhs = data_[lhs_idx * step_];
        const T rhs = data_[rhs_idx * step_];
        return (lhs < rhs || (lhs == rhs && lhs_idx < rhs_idx));
    }

private:
    const T* data_;
    size_t step_;
};
63+
}
64+
65+
// OpenCV-backend implementation of the TopK layer.
//
// For every 1-D slice of the input along `axis`, the K best elements (largest
// or smallest, depending on `largest`) are written to the first output and
// their positions along the axis to the second output. Both outputs are
// produced as float data; only sorted output is supported.
class TopKLayerImpl CV_FINAL : public TopKLayer
{
public:
    // Reads `axis` (default -1), `largest` (default 1), `sorted` (default 1)
    // and the mandatory `k` from the layer parameters.
    // Fails if `sorted` is not 1 or if `k` is missing.
    TopKLayerImpl(const LayerParams& params)
    {
        setParamsFrom(params);

        axis = params.get<int>("axis", -1);
        largest = params.get<int>("largest", 1) == 1;
        sorted = params.get<int>("sorted", 1) == 1;
        CV_CheckTrue(sorted, "TopK: sorted == false is not supported"); // TODO: support sorted

        CV_CheckTrue(params.has("k"), "TopK: parameter k is required but missing");
        K = params.get<int>("k");
    }

    // Only the default OpenCV backend is supported.
    virtual bool supportBackend(int backendId) CV_OVERRIDE
    {
        return backendId == DNN_BACKEND_OPENCV;
    }

    // Validates axis and K against the first input shape and reports two
    // outputs (values, indices), each shaped like the input with the extent
    // along the axis replaced by K.
    virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
                                 const int requiredOutputs,
                                 std::vector<MatShape> &outputs,
                                 std::vector<MatShape> &internals) const CV_OVERRIDE
    {
        CV_CheckTrue(!inputs.empty(), "TopK: an input is required");
        const auto &input_shape = inputs.front();
        int input_dims = input_shape.size();

        // Check if axis is valid
        CV_CheckGE(axis, -input_dims, "TopK: axis is out of range");
        CV_CheckLT(axis, input_dims, "TopK: axis is out of range");
        // Normalize axis
        int axis_normalized = normalize_axis(axis, input_shape.size());

        // Check if K is in range (0, input_shape[axis]]; K equal to the axis
        // extent is a valid request (it selects and orders every element).
        CV_CheckGT(K, 0, "TopK: K needs to be a positive integer");
        CV_CheckLE(K, input_shape[axis_normalized], "TopK: K is out of range");

        // Assign output shape: values and indices share the same shape
        auto output_shape = input_shape;
        output_shape[axis_normalized] = K;
        outputs.assign(2, output_shape); // TODO: support indices of type CV_32S on 5.x

        return false;
    }

    // Replaces a possibly negative `axis` by its normalized, non-negative value
    // so that FindTopK can use it directly for shape arithmetic.
    virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE {
        std::vector<Mat> inputs;
        inputs_arr.getMatVector(inputs);

        // Normalize axis
        auto input_shape = shape(inputs.front());
        axis = normalize_axis(axis, input_shape.size());
    }

    // Computes the top-K values and indices of `input` along `axis`.
    // Comparator decides the ordering (see ComparatorGreater / ComparatorLess).
    // The data is addressed through flat float pointers, which presumably
    // relies on all three Mats being continuous CV_32F — TODO confirm.
    template<class Comparator>
    void FindTopK(const Mat &input, Mat &output_value, Mat &output_index) {
        const auto input_shape = shape(input);
        // loops: number of slices before the axis; step: number of elements after it.
        const size_t loops = std::accumulate(input_shape.begin(), input_shape.begin() + axis, 1, std::multiplies<int>());
        const size_t step = std::accumulate(input_shape.begin() + axis + 1, input_shape.end(), 1, std::multiplies<int>());
        const int dim_axis = input_shape[axis];

        const auto *input_ptr = input.ptr<const float>(); // TODO: support other input type
        auto *output_value_ptr = output_value.ptr<float>();
        auto *output_index_ptr = output_index.ptr<float>(); // TODO: use CV_32S on 5.x

        if (loops == 1) {
            // A single outer slice: parallelize over the inner offsets instead.
            auto worker = [&](const Range &r) {
                AutoBuffer<int> buffer_index(dim_axis);
                for (int offset = r.start; offset < r.end; offset++) {
                    rankSlice<Comparator>(input_ptr + offset,
                                          output_value_ptr + offset,
                                          output_index_ptr + offset,
                                          step, dim_axis, buffer_index.data());
                }
            };
            parallel_for_(Range(0, static_cast<int>(step)), worker);
        } else {
            // Several outer slices: parallelize over them, walk offsets serially.
            auto worker = [&](const Range &r) {
                AutoBuffer<int> buffer_index(dim_axis);
                for (int batch_index = r.start; batch_index < r.end; batch_index++) {
                    // Widen before multiplying so the products are not computed in int.
                    const size_t input_base = static_cast<size_t>(batch_index) * dim_axis * step;
                    const size_t output_base = static_cast<size_t>(batch_index) * K * step;
                    for (size_t offset = 0; offset < step; offset++) {
                        rankSlice<Comparator>(input_ptr + input_base + offset,
                                              output_value_ptr + output_base + offset,
                                              output_index_ptr + output_base + offset,
                                              step, dim_axis, buffer_index.data());
                    }
                }
            };
            parallel_for_(Range(0, static_cast<int>(loops)), worker);
        }
    }

    // Runs TopK on the first input and fills outputs[0] (values) and the last
    // output (indices). CV_16F inputs are routed through forward_fallback.
    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

        if (inputs_arr.depth() == CV_16F)
        {
            forward_fallback(inputs_arr, outputs_arr, internals_arr);
            return;
        }

        std::vector<Mat> inputs, outputs;
        inputs_arr.getMatVector(inputs);
        outputs_arr.getMatVector(outputs);

        const auto &input = inputs.front();
        auto &output_value = outputs.front();
        auto &output_index = outputs.back();

        if (largest) {
            FindTopK<ComparatorGreater<float>>(input, output_value, output_index);
        } else {
            FindTopK<ComparatorLess<float>>(input, output_value, output_index);
        }
    }

private:
    // Orders one strided 1-D slice and emits its first K entries.
    // slice_in / slice_values / slice_indices point at the slice's first
    // element in the input and the two outputs; consecutive elements of the
    // slice are `step` floats apart. `order` is scratch space for dim_axis ints.
    template<class Comparator>
    void rankSlice(const float *slice_in, float *slice_values, float *slice_indices,
                   size_t step, int dim_axis, int *order) const {
        Comparator cmp(slice_in, step);
        std::iota(order, order + dim_axis, 0);
        std::stable_sort(order, order + dim_axis, cmp);

        for (int i = 0; i < K; i++) {
            const int source_index = order[i];
            slice_values[i * step] = slice_in[source_index * step];
            slice_indices[i * step] = static_cast<float>(source_index);
        }
    }

    int axis;
    bool largest;
    bool sorted;

    int K; // FIXIT: make it layer input once dynamic shape is supported
};
222+
223+
// Factory: constructs the TopKLayerImpl for the given parameters and hands it
// back through the public TopKLayer interface.
Ptr<TopKLayer> TopKLayer::create(const LayerParams& params)
{
    return Ptr<TopKLayer>(makePtr<TopKLayerImpl>(params));
}
227+
228+
}} // namespace cv::dnn

modules/dnn/src/onnx/onnx_importer.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ class ONNXImporter
194194
void parseScatter (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
195195
void parseTile (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
196196
void parseLayerNorm (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
197+
void parseTopK (LayerParams& LayerParams, const opencv_onnx::NodeProto& node_proto);
197198
void parseSimpleLayers (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
198199
void parseEinsum (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
199200

@@ -3121,6 +3122,21 @@ void ONNXImporter::parseLayerNorm(LayerParams& layerParams, const opencv_onnx::N
31213122
}
31223123
}
31233124

3125+
// Imports a TopK node.
// When the node has a second input, it must be a constant blob holding K; its
// first element is copied into the layer parameter "k". With any other input
// count the parameters are forwarded unchanged.
void ONNXImporter::parseTopK(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
    // K needs to be constant in case of being input (since opset 10)
    if (node_proto.input_size() == 2) {
        const bool K_const = constBlobs.find(node_proto.input(1)) != constBlobs.end();
        CV_CheckTrue(K_const, "OnnxImporter/TopK: K being non-constant is not supported");

        Mat input_K = getBlob(node_proto, 1);
        // Guard the element access below against an empty constant tensor.
        CV_CheckTrue(!input_K.empty(), "OnnxImporter/TopK: K tensor must not be empty");
        // NOTE(review): reads K as a 32-bit int; assumes getBlob yields CV_32S data — confirm.
        const int K = input_K.at<int>(0);
        layerParams.set("k", K);
    }

    addLayer(layerParams, node_proto);
}
3139+
31243140
void ONNXImporter::parseSimpleLayers(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
31253141
{
31263142
bool is_all_input_const = true;
@@ -3931,6 +3947,7 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
39313947
dispatch["Tile"] = &ONNXImporter::parseTile;
39323948
dispatch["LayerNormalization"] = &ONNXImporter::parseLayerNorm;
39333949
dispatch["GroupNormalization"] = &ONNXImporter::parseInstanceNormalization;
3950+
dispatch["TopK"] = &ONNXImporter::parseTopK;
39343951

39353952
dispatch["Equal"] = dispatch["Greater"] = dispatch["Less"] = dispatch["Pow"] = dispatch["Add"] =
39363953
dispatch["Sub"] = dispatch["Mul"] = dispatch["Div"] = dispatch["GreaterOrEqual"] =

modules/dnn/test/test_onnx_importer.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3202,6 +3202,37 @@ TEST_P(Test_ONNX_layers, ClipDivSharedConstant) {
32023202
testONNXModels("clip_div_shared_constant");
32033203
}
32043204

3205+
// Accuracy test for the imported TopK layer: both outputs ("values" and
// "indices") are compared against the stored reference tensors.
TEST_P(Test_ONNX_layers, TopK) {
    // Loads models/<basename>.onnx with its input and two reference outputs,
    // runs the net and checks both results. A zero l1/lInf selects the
    // fixture's default tolerance.
    auto test = [&](const std::string &basename, double l1 = 0, double lInf = 0) {
        const std::string model_path = _tf("models/" + basename + ".onnx", true);
        Mat input = readTensorFromONNX(_tf("data/input_" + basename + ".pb"));
        Mat output_ref_val = readTensorFromONNX(_tf("data/output_" + basename + "_0.pb"));
        Mat output_ref_ind = readTensorFromONNX(_tf("data/output_" + basename + "_1.pb"));

        checkBackend(&input, &output_ref_val);
        checkBackend(&input, &output_ref_ind);

        Net net = readNetFromONNX(model_path);
        net.setPreferableBackend(backend);
        net.setPreferableTarget(target);
        net.setInput(input);

        std::vector<Mat> outputs;
        const std::vector<std::string> output_names{"values", "indices"};
        net.forward(outputs, output_names);

        Mat output_res_val = outputs.front();
        Mat output_res_ind = outputs.back();
        output_res_ind.convertTo(output_res_ind, CV_32S); // TODO: remove this conversion on 5.x

        const double tolerance_l1 = l1 ? l1 : default_l1;
        const double tolerance_lInf = lInf ? lInf : default_lInf;
        normAssert(output_ref_val, output_res_val, (basename + " values").c_str(), tolerance_l1, tolerance_lInf);
        normAssert(output_ref_ind, output_res_ind, (basename + " indices").c_str(), tolerance_l1, tolerance_lInf);
        expectNoFallbacksFromIE(net);
    };

    test("top_k");
    test("top_k_negative_axis");
    test("top_k_smallest");
}
3235+
32053236
INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets());
32063237

32073238
}} // namespace

0 commit comments

Comments
 (0)