-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Expand file tree
/
Copy pathnormalizePlugin.h
More file actions
134 lines (96 loc) · 4.22 KB
/
normalizePlugin.h
File metadata and controls
134 lines (96 loc) · 4.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef TRT_NORMALIZE_PLUGIN_H
#define TRT_NORMALIZE_PLUGIN_H
#include "cudnn.h"
#include "kernel.h"
#include "plugin.h"
#include <cublas_v2.h>
#include <string>
#include <vector>
namespace nvinfer1
{
namespace plugin
{
class Normalize : public IPluginV2Ext
{
public:
Normalize(const Weights* weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps);
Normalize(
const Weights* weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps, int C, int H, int W);
Normalize(const void* buffer, size_t length);
~Normalize() override = default;
int getNbOutputs() const noexcept override;
Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) noexcept override;
int initialize() noexcept override;
void terminate() noexcept override;
size_t getWorkspaceSize(int maxBatchSize) const noexcept override;
int enqueue(int batchSize, const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) noexcept override;
size_t getSerializationSize() const noexcept override;
void serialize(void* buffer) const noexcept override;
bool supportsFormat(DataType type, PluginFormat format) const noexcept override;
const char* getPluginType() const noexcept override;
const char* getPluginVersion() const noexcept override;
void destroy() noexcept override;
IPluginV2Ext* clone() const noexcept override;
void setPluginNamespace(const char* pluginNamespace) noexcept override;
const char* getPluginNamespace() const noexcept override;
DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override;
bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const noexcept override;
bool canBroadcastInputAcrossBatch(int inputIndex) const noexcept override;
void attachToContext(
cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) noexcept override;
void configurePlugin(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs,
const DataType* inputTypes, const DataType* outputTypes, const bool* inputIsBroadcast,
const bool* outputIsBroadcast, PluginFormat floatFormat, int maxBatchSize) noexcept override;
void detachFromContext() noexcept override;
private:
Weights copyToDevice(const void* hostData, size_t count);
void serializeFromDevice(char*& hostBuffer, Weights deviceWeights) const;
Weights deserializeToDevice(const char*& hostBuffer, size_t count);
cublasHandle_t mCublas;
Weights mWeights{};
int mNbWeights{};
bool acrossSpatial{};
bool channelShared{};
float eps{};
int C{};
int H{};
int W{};
std::string mPluginNamespace;
};
//! Creator (factory) for Normalize plugin instances.
//!
//! Exposes the plugin's name/version and attribute description, and builds Normalize
//! objects either from a PluginFieldCollection or from serialized bytes.
class NormalizePluginCreator : public BaseCreator
{
public:
    NormalizePluginCreator();
    ~NormalizePluginCreator() override = default;

    // --- Identification / attribute description ---------------------------------------
    const char* getPluginName() const noexcept override;
    const char* getPluginVersion() const noexcept override;
    const PluginFieldCollection* getFieldNames() noexcept override;

    // --- Plugin construction -------------------------------------------------------------
    IPluginV2Ext* createPlugin(const char* name, const PluginFieldCollection* fc) noexcept override;
    IPluginV2Ext* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept override;

private:
    // Class-wide (static) state shared by every creator instance.
    static PluginFieldCollection mFC;
    static std::vector<PluginField> mPluginAttributes;

    // Per-instance configuration, zero-initialized.
    // NOTE(review): presumably filled from the field collection in createPlugin(); confirm.
    bool mAcrossSpatial{false};
    bool mChannelShared{false};
    float mEps{0.0F};
    int mNbWeights{0};
};
} // namespace plugin
} // namespace nvinfer1
#endif // TRT_NORMALIZE_PLUGIN_H