Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 83 additions & 3 deletions apps/model-diagnostics/model_diagnostics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,36 @@ static std::string checkFileExists(const std::string& fileName)
"Please, specify a full path to the file.");
}

static std::vector<int> parseShape(const std::string &shape_str) {
std::stringstream ss(shape_str);
std::string item;
std::vector<std::string> items;

while (std::getline(ss, item, ',')) {
items.push_back(item);
}

std::vector<int> shape;
for (size_t i = 0; i < items.size(); i++) {
shape.push_back(std::stoi(items[i]));
}
return shape;
}

std::string diagnosticKeys =
"{ model m | | Path to the model file. }"
"{ config c | | Path to the model configuration file. }"
"{ framework f | | [Optional] Name of the model framework. }";


"{ framework f | | [Optional] Name of the model framework. }"
"{ input0_name | | [Optional] Name of input0. Use with input0_shape}"
"{ input0_shape | | [Optional] Shape of input0. Use with input0_name}"
"{ input1_name | | [Optional] Name of input1. Use with input1_shape}"
"{ input1_shape | | [Optional] Shape of input1. Use with input1_name}"
"{ input2_name | | [Optional] Name of input2. Use with input2_shape}"
"{ input2_shape | | [Optional] Shape of input2. Use with input2_name}"
"{ input3_name | | [Optional] Name of input3. Use with input3_shape}"
"{ input3_shape | | [Optional] Shape of input3. Use with input3_name}"
"{ input4_name | | [Optional] Name of input4. Use with input4_shape}"
"{ input4_shape | | [Optional] Shape of input4. Use with input4_name}";

int main( int argc, const char** argv )
{
Expand All @@ -55,6 +79,17 @@ int main( int argc, const char** argv )
std::string config = checkFileExists(argParser.get<std::string>("config"));
std::string frameworkId = argParser.get<std::string>("framework");

std::string input0_name = argParser.get<std::string>("input0_name");
std::string input0_shape = argParser.get<std::string>("input0_shape");
std::string input1_name = argParser.get<std::string>("input1_name");
std::string input1_shape = argParser.get<std::string>("input1_shape");
std::string input2_name = argParser.get<std::string>("input2_name");
std::string input2_shape = argParser.get<std::string>("input2_shape");
std::string input3_name = argParser.get<std::string>("input3_name");
std::string input3_shape = argParser.get<std::string>("input3_shape");
std::string input4_name = argParser.get<std::string>("input4_name");
std::string input4_shape = argParser.get<std::string>("input4_shape");

CV_Assert(!model.empty());

enableModelDiagnostics(true);
Expand All @@ -63,5 +98,50 @@ int main( int argc, const char** argv )

Net ocvNet = readNet(model, config, frameworkId);

std::vector<std::string> input_names;
std::vector<std::vector<int>> input_shapes;
if (!input0_name.empty() || !input0_shape.empty()) {
CV_CheckFalse(input0_name.empty(), "input0_name cannot be empty");
CV_CheckFalse(input0_shape.empty(), "input0_shape cannot be empty");
input_names.push_back(input0_name);
input_shapes.push_back(parseShape(input0_shape));
}
if (!input1_name.empty() || !input1_shape.empty()) {
CV_CheckFalse(input1_name.empty(), "input1_name cannot be empty");
CV_CheckFalse(input1_shape.empty(), "input1_shape cannot be empty");
input_names.push_back(input1_name);
input_shapes.push_back(parseShape(input1_shape));
}
if (!input2_name.empty() || !input2_shape.empty()) {
CV_CheckFalse(input2_name.empty(), "input2_name cannot be empty");
CV_CheckFalse(input2_shape.empty(), "input2_shape cannot be empty");
input_names.push_back(input2_name);
input_shapes.push_back(parseShape(input2_shape));
}
if (!input3_name.empty() || !input3_shape.empty()) {
CV_CheckFalse(input3_name.empty(), "input3_name cannot be empty");
CV_CheckFalse(input3_shape.empty(), "input3_shape cannot be empty");
input_names.push_back(input3_name);
input_shapes.push_back(parseShape(input3_shape));
}
if (!input4_name.empty() || !input4_shape.empty()) {
CV_CheckFalse(input4_name.empty(), "input4_name cannot be empty");
CV_CheckFalse(input4_shape.empty(), "input4_shape cannot be empty");
input_names.push_back(input4_name);
input_shapes.push_back(parseShape(input4_shape));
}

if (!input_names.empty() && !input_shapes.empty() && input_names.size() == input_shapes.size()) {
ocvNet.setInputsNames(input_names);
for (size_t i = 0; i < input_names.size(); i++) {
Mat input(input_shapes[i], CV_32F);
ocvNet.setInput(input, input_names[i]);
}

size_t dot_index = model.rfind('.');
std::string graph_filename = model.substr(0, dot_index) + ".pbtxt";
ocvNet.dumpToPbtxt(graph_filename);
}

return 0;
}
8 changes: 8 additions & 0 deletions modules/dnn/include/opencv2/dnn/dnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,14 @@ CV__DNN_INLINE_NS_BEGIN
* @see dump()
*/
CV_WRAP void dumpToFile(CV_WRAP_FILE_PATH const String& path);
/** @brief Dump net structure, hyperparameters, backend, target and fusion to pbtxt file
* @param path path to output file with .pbtxt extension
*
* Use Netron (https://netron.app) to open the target file to visualize the model.
* Call method after setInput(). To see correct backend, target and fusion run after forward().
*/
CV_WRAP void dumpToPbtxt(CV_WRAP_FILE_PATH const String& path);

/** @brief Adds new layer to the net.
* @param name unique name of the adding layer.
* @param type typename of the adding layer (type must be registered in LayerRegister).
Expand Down
10 changes: 10 additions & 0 deletions modules/dnn/src/net.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,16 @@ void Net::dumpToFile(const String& path)
file.close();
}

void Net::dumpToPbtxt(const String& path)
{
CV_TRACE_FUNCTION();
CV_Assert(impl);
CV_Assert(!empty());
std::ofstream file(path.c_str());
file << impl->dumpToPbtxt(true);
file.close();
}

Ptr<Layer> Net::getLayer(int layerId) const
{
CV_Assert(impl);
Expand Down
Loading