From b8ab9eea7b57a3d09248d1ba111dac86fffb8b06 Mon Sep 17 00:00:00 2001 From: Abduragim Date: Tue, 2 Jul 2024 18:56:24 +0300 Subject: [PATCH] Fix for importing attention layer from Torch via ONNX --- modules/dnn/src/layers/concat_layer.cpp | 2 ++ modules/dnn/src/layers/slice_layer.cpp | 14 +++++++++----- modules/dnn/test/test_onnx_importer.cpp | 7 +++++++ 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/modules/dnn/src/layers/concat_layer.cpp b/modules/dnn/src/layers/concat_layer.cpp index a0453842e42e..4dc522d486b5 100644 --- a/modules/dnn/src/layers/concat_layer.cpp +++ b/modules/dnn/src/layers/concat_layer.cpp @@ -316,6 +316,8 @@ class ConcatLayerImpl CV_FINAL : public ConcatLayer ranges[cAxis].start = 0; for (size_t i = 0; i < inputs.size(); i++) { + if (inputs[i].empty()) + continue; ranges[cAxis].end = ranges[cAxis].start + inputs[i].size[cAxis]; for (int j = 0; j < outMat.dims; ++j) { diff --git a/modules/dnn/src/layers/slice_layer.cpp b/modules/dnn/src/layers/slice_layer.cpp index 829a5743f09c..0244dd4d2be2 100644 --- a/modules/dnn/src/layers/slice_layer.cpp +++ b/modules/dnn/src/layers/slice_layer.cpp @@ -69,10 +69,12 @@ Range normalizeRange(const Range& input_range, int n) { Range range = input_range; - range.start = std::min(std::max(range.start, -n), n - 1); - if (range.start < 0) - { - range.start += n; + if (!(range.start == n)){ + range.start = std::min(std::max(range.start, -n), n - 1); + if (range.start < 0) + { + range.start += n; + } } range.end = std::min(std::max(range.end, -n), n); @@ -632,7 +634,9 @@ class SliceLayerImpl : public SliceLayer { for (size_t i = 0; i < outputs.size(); i++) { - inpMat(finalSliceRanges[i]).copyTo(outputs[i]); + if (finalSliceRanges[i][0].start != finalSliceRanges[i][0].end){ + inpMat(finalSliceRanges[i]).copyTo(outputs[i]); + } } } else diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 5e850a21917f..049c72530fbd 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -3099,6 +3099,13 @@ TEST_P(Test_ONNX_layers, Attention) { TEST_P(Test_ONNX_layers, AttentionSingleHead) { testONNXModels("attention_single_head"); } +TEST_P(Test_ONNX_layers, TorchAttentionSingleHead){ + testONNXModels("torch_attention_single_head"); +} + +TEST_P(Test_ONNX_layers, TorchUnflatten){ + testONNXModels("unflatten"); +} TEST_P(Test_ONNX_nets, ViT_B_32) { applyTestTag(CV_TEST_TAG_LONG, CV_TEST_TAG_DEBUG_LONG);