forked from triton-inference-server/server
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclassification.cc
More file actions
129 lines (111 loc) · 4.93 KB
/
classification.cc
File metadata and controls
129 lines (111 loc) · 4.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "classification.h"
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <string>
#include <vector>
#include "common.h"
namespace triton { namespace server {
namespace {
// Append the top-ranked classification entries of one output tensor to
// 'class_strs'.
//
// 'base' is read as an array of 'element_cnt' values of type T. The
// min(element_cnt, req_class_cnt) highest values are selected and, in
// descending order, one string per value is appended to 'class_strs' in
// the form "<value>:<index>", followed by ":<label>" when the label
// lookup for that index yields a non-null label.
//
// Ranking rules:
//  * larger values rank first;
//  * equal values rank by ascending index, so the result is deterministic;
//  * NaN values rank after every non-NaN value. A plain '>' comparison is
//    not a strict weak ordering when NaN is present, which the standard
//    sorting algorithms require.
//
// Returns nullptr on success. If a label lookup fails its error is
// returned immediately; entries appended so far (including the one for
// the failing index) remain in 'class_strs'.
template <typename T>
TRITONSERVER_Error*
AddClassResults(
    TRITONSERVER_InferenceResponse* response, const uint32_t output_idx,
    const char* base, const size_t element_cnt, const uint32_t req_class_cnt,
    std::vector<std::string>* class_strs)
{
  const T* probs = reinterpret_cast<const T*>(base);

  std::vector<size_t> idx(element_cnt);
  std::iota(idx.begin(), idx.end(), size_t{0});

  const size_t class_cnt =
      std::min(element_cnt, static_cast<size_t>(req_class_cnt));

  // Strict weak ordering over indices: non-NaN before NaN, then larger
  // value first, then lower index first.
  const auto ranks_before = [probs](
                                const size_t lhs, const size_t rhs) -> bool {
    // The cast to double keeps the NaN test valid for integral T as well
    // (it is simply always false there).
    const bool lhs_nan = std::isnan(static_cast<double>(probs[lhs]));
    const bool rhs_nan = std::isnan(static_cast<double>(probs[rhs]));
    if (lhs_nan != rhs_nan) {
      return rhs_nan;
    }
    if (!lhs_nan && (probs[lhs] != probs[rhs])) {
      return probs[lhs] > probs[rhs];
    }
    return lhs < rhs;
  };

  // Only the first 'class_cnt' positions are read below, so ordering the
  // remainder of the index array would be wasted work.
  std::partial_sort(
      idx.begin(), idx.begin() + class_cnt, idx.end(), ranks_before);

  class_strs->reserve(class_strs->size() + class_cnt);
  for (size_t k = 0; k < class_cnt; ++k) {
    const size_t class_idx = idx[k];
    class_strs->push_back(
        std::to_string(probs[class_idx]) + ":" + std::to_string(class_idx));

    const char* label = nullptr;
    RETURN_IF_ERR(TRITONSERVER_InferenceResponseOutputClassificationLabel(
        response, output_idx, class_idx, &label));
    if (label != nullptr) {
      class_strs->back() += ":";
      class_strs->back().append(label);
    }
  }

  return nullptr;  // success
}
} // namespace
// Build the top-k classification strings for one output tensor.
//
// 'base' points at 'byte_size' bytes of tensor data whose element type is
// 'datatype'. The element count is derived from the byte size of that
// datatype and the work is dispatched to AddClassResults<T> for the
// matching C++ type; at most 'req_class_count' entries are appended to
// 'class_strs'.
//
// Returns nullptr on success, the error produced by AddClassResults, or a
// TRITONSERVER_ERROR_INVALID_ARG error when 'datatype' is not one of the
// integer / FP32 / FP64 types handled here. A datatype whose reported byte
// size is zero (the Triton C API documents this for BYTES and INVALID) is
// rejected with that same error before any division takes place, since
// dividing by it would be a division by zero.
TRITONSERVER_Error*
TopkClassifications(
    TRITONSERVER_InferenceResponse* response, const uint32_t output_idx,
    const char* base, const size_t byte_size,
    const TRITONSERVER_DataType datatype, const uint32_t req_class_count,
    std::vector<std::string>* class_strs)
{
  const size_t dtype_byte_size = TRITONSERVER_DataTypeByteSize(datatype);

  if (dtype_byte_size != 0) {
    const size_t element_cnt = byte_size / dtype_byte_size;

    switch (datatype) {
      case TRITONSERVER_TYPE_UINT8:
        return AddClassResults<uint8_t>(
            response, output_idx, base, element_cnt, req_class_count,
            class_strs);
      case TRITONSERVER_TYPE_UINT16:
        return AddClassResults<uint16_t>(
            response, output_idx, base, element_cnt, req_class_count,
            class_strs);
      case TRITONSERVER_TYPE_UINT32:
        return AddClassResults<uint32_t>(
            response, output_idx, base, element_cnt, req_class_count,
            class_strs);
      case TRITONSERVER_TYPE_UINT64:
        return AddClassResults<uint64_t>(
            response, output_idx, base, element_cnt, req_class_count,
            class_strs);
      case TRITONSERVER_TYPE_INT8:
        return AddClassResults<int8_t>(
            response, output_idx, base, element_cnt, req_class_count,
            class_strs);
      case TRITONSERVER_TYPE_INT16:
        return AddClassResults<int16_t>(
            response, output_idx, base, element_cnt, req_class_count,
            class_strs);
      case TRITONSERVER_TYPE_INT32:
        return AddClassResults<int32_t>(
            response, output_idx, base, element_cnt, req_class_count,
            class_strs);
      case TRITONSERVER_TYPE_INT64:
        return AddClassResults<int64_t>(
            response, output_idx, base, element_cnt, req_class_count,
            class_strs);
      case TRITONSERVER_TYPE_FP32:
        return AddClassResults<float>(
            response, output_idx, base, element_cnt, req_class_count,
            class_strs);
      case TRITONSERVER_TYPE_FP64:
        return AddClassResults<double>(
            response, output_idx, base, element_cnt, req_class_count,
            class_strs);
      default:
        break;  // unsupported type, reported below
    }
  }

  // Reached for every datatype that has no classification support,
  // including those with a zero byte size.
  const std::string msg =
      std::string("class result not available for output due to "
                  "unsupported type '") +
      std::string(TRITONSERVER_DataTypeString(datatype)) + "'";
  return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG, msg.c_str());
}
}} // namespace triton::server