Thanks to visit codestin.com
Credit goes to github.com

Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/mxnet/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#define MXNET_OPERATOR_H_

#include <dmlc/base.h>
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <vector>
Expand Down Expand Up @@ -389,6 +390,9 @@ class OperatorProperty {
* \return a new constructed OperatorProperty
*/
static OperatorProperty *Create(const char* type_name);

virtual void Save(dmlc::JSONWriter *writer) const = 0;
virtual void Load(dmlc::JSONReader *reader) = 0;
};

/*! \brief typedef the factory function of operator property */
Expand Down
36 changes: 36 additions & 0 deletions include/mxnet/symbolic.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#define MXNET_SYMBOLIC_H_

#include <dmlc/base.h>
#include <dmlc/json.h>
#include <algorithm>
#include <vector>
#include <memory>
#include <string>
Expand Down Expand Up @@ -64,6 +66,11 @@ class StaticGraph {
if (source_id == other.source_id) return index < other.index;
return source_id < other.source_id;
}

/*! \brief interface for json serialization */
void Save(dmlc::JSONWriter *writer) const;
/*! \brief interface for json serialization */
void Load(dmlc::JSONReader *reader);
};
/*!
* \brief Operation Node in static graphs.
Expand Down Expand Up @@ -95,6 +102,23 @@ class StaticGraph {
int32_t backward_source_id;
/*! \brief default constructor */
Node() : backward_source_id(-1) {}

friend void swap(Node& lhs, Node& rhs) {
std::swap(lhs.op, rhs.op);
std::swap(lhs.name, rhs.name);
std::swap(lhs.inputs, rhs.inputs);
std::swap(lhs.backward_source_id, rhs.backward_source_id);
}
/*! \brief copy constructor in favor of serialization. */
Node(const Node& another) : op(another.op.get() ? another.op.get()->Copy() : nullptr),
name(another.name),
inputs(another.inputs),
backward_source_id(another.backward_source_id) {}

inline Node& operator=(Node another) {
swap(*this, another);
return *this;
}
/*! \return whether the node is forward op node */
inline bool is_forward() const {
return op != nullptr;
Expand All @@ -107,13 +131,25 @@ class StaticGraph {
inline bool is_variable() const {
return op == nullptr && !is_backward();
}
/*! \brief interface for json serialization */
void Save(dmlc::JSONWriter *writer) const;
/*! \brief interface for json serialization */
void Load(dmlc::JSONReader *reader);
};
/*! \brief all nodes in the graph */
std::vector<Node> nodes;
/*! \brief index of nodes that correspods to arguments */
std::vector<uint32_t> arg_nodes;
/*! \brief heads outputs of the graph */
std::vector<DataEntry> heads;
/*! \brief load static graph from json. TODO: a static creator's better */
void Load(const std::string& json);
/*! \brief save static graph to json */
void Save(std::string* json) const;
/*! \brief interface for json serialization */
void Save(dmlc::JSONWriter *writer) const;
/*! \brief interface for json serialization */
void Load(dmlc::JSONReader *reader);
// funtions to help inference in static graph
/*!
* \brief Perform a topological sort on the graph
Expand Down
6 changes: 1 addition & 5 deletions src/operator/activation-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ template<typename xpu>
Operator* CreateOp(ActivationParam type);

#if DMLC_USE_CXX11
class ActivationProp : public OperatorProperty {
class ActivationProp : public ParamOperatorProperty<ActivationParam> {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
Expand Down Expand Up @@ -139,12 +139,8 @@ class ActivationProp : public OperatorProperty {
}

Operator* CreateOperator(Context ctx) const;

private:
ActivationParam param_;
};
#endif // DMLC_USE_CXX11
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_ACTIVATION_INL_H_

4 changes: 1 addition & 3 deletions src/operator/batch_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ Operator *CreateOp(BatchNormParam param);


#if DMLC_USE_CXX11
class BatchNormProp : public OperatorProperty {
class BatchNormProp : public ParamOperatorProperty<BatchNormParam> {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
Expand Down Expand Up @@ -263,8 +263,6 @@ class BatchNormProp : public OperatorProperty {

Operator* CreateOperator(Context ctx) const;

private:
BatchNormParam param_;
}; // class BatchNormProp

#endif // DMLC_USE_CXX11
Expand Down
5 changes: 1 addition & 4 deletions src/operator/concat-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ template<typename xpu>
Operator *CreateOp(ConcatParam param);

#if DMLC_USE_CXX11
class ConcatProp : public OperatorProperty {
class ConcatProp : public ParamOperatorProperty<ConcatParam> {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
Expand Down Expand Up @@ -223,9 +223,6 @@ class ConcatProp : public OperatorProperty {
}

Operator* CreateOperator(Context ctx) const;

private:
ConcatParam param_;
}; // class ConcatProp
#endif // DMLC_USE_CXX11
} // namespace op
Expand Down
4 changes: 1 addition & 3 deletions src/operator/convolution-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ template<typename xpu>
Operator* CreateOp(ConvolutionParam param);

#if DMLC_USE_CXX11
class ConvolutionProp : public OperatorProperty {
class ConvolutionProp : public ParamOperatorProperty<ConvolutionParam> {
public:
std::vector<std::string> ListArguments() const override {
if (!param_.no_bias) {
Expand Down Expand Up @@ -358,8 +358,6 @@ class ConvolutionProp : public OperatorProperty {

Operator* CreateOperator(Context ctx) const;

private:
ConvolutionParam param_;
}; // class ConvolutionProp
#endif // DMLC_USE_CXX11
} // namespace op
Expand Down
4 changes: 1 addition & 3 deletions src/operator/dropout-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ template<typename xpu>
Operator *CreateOp(DropoutParam param);

#if DMLC_USE_CXX11
class DropoutProp : public OperatorProperty {
class DropoutProp : public ParamOperatorProperty<DropoutParam> {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
Expand Down Expand Up @@ -160,8 +160,6 @@ class DropoutProp : public OperatorProperty {

Operator* CreateOperator(Context ctx) const;

private:
DropoutParam param_;
}; // class DropoutProp
#endif // DMLC_USE_CXX11
} // namespace op
Expand Down
2 changes: 1 addition & 1 deletion src/operator/elementwise_binary_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ Operator* CreateElementWiseBinaryOp(ElementWiseBinaryOpType type);

#if DMLC_USE_CXX11
template<typename ForwardOp>
class ElementWiseBinaryOpProp : public OperatorProperty {
class ElementWiseBinaryOpProp : public NoParamOperatorProperty {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
CHECK_EQ(kwargs.size(), 0)
Expand Down
15 changes: 11 additions & 4 deletions src/operator/elementwise_sum-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,16 @@ class ElementWiseSumOp : public Operator {
Assign(igrad, req[i], F<mshadow_op::identity>(ograd));
}
}
inline void Save(dmlc::JSONWriter *writer) const {
writer->BeginObject();
writer->WriteObjectKeyValue("size_", size_);
writer->EndObject();
}
inline void Load(dmlc::JSONReader *reader) {
dmlc::JSONObjectReadHelper helper;
helper.DeclareField("size_", &size_);
helper.ReadAllFields(reader);
}

private:
int size_;
Expand All @@ -111,7 +121,7 @@ template<typename xpu>
Operator* CreateOp(ElementWiseSumParam param);

#if DMLC_USE_CXX11
class ElementWiseSumProp : public OperatorProperty {
class ElementWiseSumProp : public ParamOperatorProperty<ElementWiseSumParam> {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
Expand Down Expand Up @@ -180,9 +190,6 @@ class ElementWiseSumProp : public OperatorProperty {
}

Operator* CreateOperator(Context ctx) const;

private:
ElementWiseSumParam param_;
}; // class ElementWiseSumProp

#endif // DMLC_USE_CXX11
Expand Down
5 changes: 1 addition & 4 deletions src/operator/fully_connected-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ template<typename xpu>
Operator* CreateOp(FullyConnectedParam param);

#if DMLC_USE_CXX11
class FullyConnectedProp : public OperatorProperty {
class FullyConnectedProp : public ParamOperatorProperty<FullyConnectedParam> {
public:
std::vector<std::string> ListArguments() const override {
if (!param_.no_bias) {
Expand Down Expand Up @@ -189,9 +189,6 @@ class FullyConnectedProp : public OperatorProperty {
}

Operator* CreateOperator(Context ctx) const;

private:
FullyConnectedParam param_;
}; // class FullyConnectedSymbol
#endif
} // namespace op
Expand Down
5 changes: 1 addition & 4 deletions src/operator/leaky_relu-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ template<typename xpu>
Operator* CreateOp(LeakyReLUParam type);

#if DMLC_USE_CXX11
class LeakyReLUProp : public OperatorProperty {
class LeakyReLUProp : public ParamOperatorProperty<LeakyReLUParam> {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
Expand Down Expand Up @@ -298,9 +298,6 @@ class LeakyReLUProp : public OperatorProperty {
}

Operator* CreateOperator(Context ctx) const;

private:
LeakyReLUParam param_;
};
#endif // DMLC_USE_CXX11
} // namespace op
Expand Down
5 changes: 1 addition & 4 deletions src/operator/lrn-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ template<typename xpu>
Operator *CreateOp(LRNParam param);

#if DMLC_USE_CXX11
class LocalResponseNormProp : public OperatorProperty {
class LocalResponseNormProp : public ParamOperatorProperty<LRNParam> {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
Expand Down Expand Up @@ -173,9 +173,6 @@ class LocalResponseNormProp : public OperatorProperty {
}

Operator* CreateOperator(Context ctx) const;

private:
LRNParam param_;
}; // LocalResponseNormProp
#endif // DMLC_USE_CXX11
} // namespace op
Expand Down
40 changes: 40 additions & 0 deletions src/operator/operator_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
#ifndef MXNET_OPERATOR_OPERATOR_COMMON_H_
#define MXNET_OPERATOR_OPERATOR_COMMON_H_

#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <mxnet/operator.h>
#include <mxnet/base.h>
#include <istream>
#include <ostream>
#include <string>

namespace mxnet {
Expand Down Expand Up @@ -93,6 +96,43 @@ struct InferShapeError {
}
#endif

#if DMLC_USE_CXX11
template<class Param>
class ParamOperatorProperty : public OperatorProperty {
public:
ParamOperatorProperty() {}
explicit ParamOperatorProperty(Param param) : param_(param) {}
inline void Save(dmlc::JSONWriter *writer) const {
writer->BeginObject();
std::string value = param_.PrintJson();
writer->WriteObjectKeyValue("param", value);
writer->EndObject();
}
inline void Load(dmlc::JSONReader *reader) {
dmlc::JSONObjectReadHelper helper;
std::string value;
helper.DeclareField("param", &value);
helper.ReadAllFields(reader);
param_.LoadJson(value);
}
inline bool operator==(const ParamOperatorProperty<Param>& other) const {
return param_ == other.param_;
}
protected:
Param param_;
};

class NoParamOperatorProperty : public OperatorProperty {
public:
inline void Save(dmlc::JSONWriter *writer) const {
}
inline void Load(dmlc::JSONReader *reader) {
}
inline bool operator==(const NoParamOperatorProperty& other) const {
return true;
}
};
#endif // DMLC_USE_CXX11
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_OPERATOR_COMMON_H_
1 change: 0 additions & 1 deletion src/operator/param.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,3 @@ struct Param {
} // namespace mxnet

#endif // MXNET_OPERATOR_PARAM_H_

5 changes: 1 addition & 4 deletions src/operator/pooling-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ Operator* CreateOp(PoolingParam param);


#if DMLC_USE_CXX11
class PoolingProp : public OperatorProperty {
class PoolingProp : public ParamOperatorProperty<PoolingParam> {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
Expand Down Expand Up @@ -209,9 +209,6 @@ class PoolingProp : public OperatorProperty {
}

Operator* CreateOperator(Context ctx) const;

private:
PoolingParam param_;
}; // class PoolingProp
#endif // DMLC_USE_CXX11
} // namespace op
Expand Down
7 changes: 2 additions & 5 deletions src/operator/reshape-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,11 @@ template<typename xpu>
Operator* CreateOp();

#if DMLC_USE_CXX11
class ReshapeProp : public OperatorProperty {
class ReshapeProp : public ParamOperatorProperty<ReshapeParam> {
public:
ReshapeProp() {}

explicit ReshapeProp(ReshapeParam param) : param_(param) {}
explicit ReshapeProp(ReshapeParam param) : ParamOperatorProperty<ReshapeParam>(param) {}

void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
Expand Down Expand Up @@ -140,9 +140,6 @@ class ReshapeProp : public OperatorProperty {
}

Operator* CreateOperator(Context ctx) const;

private:
ReshapeParam param_;
}; // class ReshapeProp

class FlattenProp : public ReshapeProp {
Expand Down
Loading