Thanks to visit codestin.com
Credit goes to github.com

Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1067,21 +1067,6 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym,



/*!
* \brief infer storage type of unknown input types given the known one.
*/
MXNET_DLL int MXSymbolInferStorageType(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const int *arg_storage_type_data,
mx_uint *in_storage_type_size,
const int **in_storage_type_data,
mx_uint *out_storage_type_size,
const int **out_storage_type_data,
mx_uint *aux_storage_type_size,
const int **aux_storage_type_data,
int *complete);

//--------------------------------------------
// Part 4: Executor interface
//--------------------------------------------
Expand Down
6 changes: 6 additions & 0 deletions include/mxnet/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ using FComputeEx = std::function<void (const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs)>;

using FInferStorageType = std::function<bool (const NodeAttrs& attrs,
const Context& ctx,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs)>;

} // namespace mxnet

#endif // MXNET_OP_ATTR_TYPES_H_
2 changes: 1 addition & 1 deletion nnvm
85 changes: 1 addition & 84 deletions python/mxnet/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .context import Context, cpu
from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
from .name import NameManager # pylint: disable=unused-import
from .ndarray import _STORAGE_TYPE_ID_TO_STR, _STORAGE_TYPE_STR_TO_ID
from .ndarray import _STORAGE_TYPE_STR_TO_ID
from .sparse_ndarray import _ndarray_cls
from .executor import Executor
from . import _symbol_internal as _internal
Expand Down Expand Up @@ -723,89 +723,6 @@ def list_auxiliary_states(self):
self.handle, ctypes.byref(size), ctypes.byref(sarr)))
return [py_str(sarr[i]) for i in range(size.value)]

def infer_storage_type(self, *args, **kwargs):
"""Infer the storage type of outputs and arguments of given known types of arguments.

User can either pass in the known types in positional way or keyword argument way.
Tuple of Nones is returned if there is not enough information passed in.
An error will be raised if there is inconsistency found in the known types passed in.

Parameters
----------
*args :
Provide type of arguments in a positional way.
Unknown type can be marked as None

**kwargs :
Provide keyword arguments of known types.

Returns
-------
arg_storage_types : list of numpy.dtype or None
List of types of arguments.
The order is in the same order as list_arguments()
out_storage_types : list of numpy.dtype or None
List of types of outputs.
The order is in the same order as list_outputs()
aux_storage_types : list of numpy.dtype or None
List of types of outputs.
The order is in the same order as list_auxiliary_states()
"""
# pylint: disable=too-many-locals
if len(args) != 0 and len(kwargs) != 0:
raise ValueError('Can only specify known argument \
types either by positional or kwargs way.')
sdata = []
if len(args) != 0:
keys = None
for s in args:
if s is not None:
if s not in _STORAGE_TYPE_STR_TO_ID or not isinstance(s, basestring):
raise TypeError('Argument need to be one of '+str(_STORAGE_TYPE_STR_TO_ID))
sdata.append(_STORAGE_TYPE_STR_TO_ID[s])
else:
sdata.append(_STORAGE_TYPE_STR_TO_ID['undefined'])
else:
keys = []
for k, v in kwargs.items():
if v in _STORAGE_TYPE_STR_TO_ID:
keys.append(c_str(k))
sdata.append(_STORAGE_TYPE_STR_TO_ID[v])
arg_storage_type_size = mx_uint()
arg_storage_type_data = ctypes.POINTER(ctypes.c_int)()
out_storage_type_size = mx_uint()
out_storage_type_data = ctypes.POINTER(ctypes.c_int)()
aux_storage_type_size = mx_uint()
aux_storage_type_data = ctypes.POINTER(ctypes.c_int)()
complete = ctypes.c_int()
check_call(_LIB.MXSymbolInferStorageType(
self.handle,
mx_uint(len(sdata)),
c_array(ctypes.c_char_p, keys),
c_array(ctypes.c_int, sdata),
ctypes.byref(arg_storage_type_size),
ctypes.byref(arg_storage_type_data),
ctypes.byref(out_storage_type_size),
ctypes.byref(out_storage_type_data),
ctypes.byref(aux_storage_type_size),
ctypes.byref(aux_storage_type_data),
ctypes.byref(complete)))
if complete.value != 0:
arg_storage_types = [
_STORAGE_TYPE_ID_TO_STR[arg_storage_type_data[i]] \
for i in range(arg_storage_type_size.value)]
out_storage_types = [
_STORAGE_TYPE_ID_TO_STR[out_storage_type_data[i]] \
for i in range(out_storage_type_size.value)]
aux_storage_types = [
_STORAGE_TYPE_ID_TO_STR[aux_storage_type_data[i]] \
for i in range(aux_storage_type_size.value)]
return (arg_storage_types, out_storage_types, aux_storage_types)
else:
return (None, None, None)
# pylint: enable=too-many-locals


def infer_type(self, *args, **kwargs):
"""Infers the type of all arguments and all outputs, given the known types
for some arguments.
Expand Down
4 changes: 2 additions & 2 deletions src/c_api/c_api_ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ void SetShapeType(const nnvm::Op* op,
std::vector<NDArray>& ndoutputs = *p_ndoutputs;
static auto& infershape = nnvm::Op::GetAttr<nnvm::FInferShape>("FInferShape");
static auto& infertype = nnvm::Op::GetAttr<nnvm::FInferType>("FInferType");
static auto& inferstorage = nnvm::Op::GetAttr<nnvm::FInferStorageType>("FInferStorageType");
static auto& inferstorage = nnvm::Op::GetAttr<FInferStorageType>("FInferStorageType");
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
// infer shape
std::vector<TShape>& in_shapes = ret->arg_shapes;
Expand Down Expand Up @@ -184,7 +184,7 @@ void SetShapeType(const nnvm::Op* op,
out_storage_types.push_back(i.storage_type());
}
if (inferstorage.count(op)) {
CHECK(inferstorage[op](attrs, &in_storage_types, &out_storage_types));
CHECK(inferstorage[op](attrs, ctx, &in_storage_types, &out_storage_types));
CHECK_EQ(out_storage_types.size(), static_cast<size_t>(infered_num_outputs));
} else {
#if IMPERATIVE_EXEC_DEBUG
Expand Down
57 changes: 3 additions & 54 deletions src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <nnvm/symbolic.h>
#include "./c_api_common.h"
#include "../operator/operator_common.h"
#include "../executor/exec_pass.h"

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -457,7 +458,7 @@ int MXSymbolInferShape(SymbolHandle sym,
}

try {
g = nnvm::pass::InferShape(std::move(g), arg_shapes, "__shape__");
g = mxnet::exec::InferShape(std::move(g), arg_shapes, "__shape__");
} catch (const mxnet::op::InferShapeError &err) {
throw dmlc::Error(err.msg);
}
Expand Down Expand Up @@ -512,58 +513,6 @@ int MXSymbolInferShapePartial(SymbolHandle sym,
&succ);
}

// TODO(haibin) refactor with infer_type
int MXSymbolInferStorageType(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const int *arg_storage_type_data,
mx_uint *in_storage_type_size,
const int **in_storage_type_data,
mx_uint *out_storage_type_size,
const int **out_storage_type_data,
mx_uint *aux_storage_type_size,
const int **aux_storage_type_data,
int *complete) {
nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
API_BEGIN();
nnvm::Graph g = Symbol2Graph(*s);
nnvm::StorageTypeVector arg_storage_types(g.indexed_graph().input_nodes().size(),
kUndefinedStorage);
if (keys == nullptr && num_args != 0) {
std::vector<uint32_t> read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph());
CHECK_LE(num_args, read_only_args.size());
for (mx_uint i = 0; i < num_args; ++i) {
arg_storage_types[read_only_args[i]] = arg_storage_type_data[i];
}
} else {
std::unordered_map<std::string, int> kwargs;
for (mx_uint i = 0; i < num_args; ++i) {
kwargs[keys[i]] = arg_storage_type_data[i];
}
mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_storage_types, "InferStorageType");
}

g = nnvm::pass::InferStorageType(std::move(g), arg_storage_types, "__storage_type__");
// copy back
CopyAttr(g.indexed_graph(), g.GetAttr<nnvm::StorageTypeVector>("storage_type"),
&(ret->arg_storage_types), &(ret->out_storage_types), &(ret->aux_storage_types));

*in_storage_type_size = static_cast<mx_uint>(ret->arg_storage_types.size());
*in_storage_type_data = dmlc::BeginPtr(ret->arg_storage_types);
*out_storage_type_size = static_cast<mx_uint>(ret->out_storage_types.size());
*out_storage_type_data = dmlc::BeginPtr(ret->out_storage_types);
*in_storage_type_size = static_cast<mx_uint>(ret->arg_storage_types.size());
*in_storage_type_data = dmlc::BeginPtr(ret->arg_storage_types);
*out_storage_type_size = static_cast<mx_uint>(ret->out_storage_types.size());
*out_storage_type_data = dmlc::BeginPtr(ret->out_storage_types);
*aux_storage_type_size = static_cast<mx_uint>(ret->aux_storage_types.size());
*aux_storage_type_data = dmlc::BeginPtr(ret->aux_storage_types);
*complete = (g.GetAttr<size_t>("storage_type_num_unknown_nodes") == 0);
API_END();
}


int MXSymbolInferType(SymbolHandle sym,
mx_uint num_args,
const char** keys,
Expand Down Expand Up @@ -594,7 +543,7 @@ int MXSymbolInferType(SymbolHandle sym,
mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_types, "InferType");
}

g = nnvm::pass::InferType(std::move(g), arg_types, "__dtype__");
g = mxnet::exec::InferType(std::move(g), arg_types, "__dtype__");
// copy back
CopyAttr(g.indexed_graph(), g.GetAttr<nnvm::DTypeVector>("dtype"),
&(ret->arg_types), &(ret->out_types), &(ret->aux_types));
Expand Down
3 changes: 2 additions & 1 deletion src/c_api/c_predict_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <unordered_map>
#include "./c_api_common.h"
#include "../operator/operator_common.h"
#include "../executor/exec_pass.h"

using namespace mxnet;

Expand Down Expand Up @@ -176,7 +177,7 @@ int MXPredCreatePartialOut(const char* symbol_json_str,
}
}
nnvm::Graph g; g.outputs = sym.outputs;
g = nnvm::pass::InferShape(std::move(g), in_shapes, "__shape__");
g = mxnet::exec::InferShape(std::move(g), in_shapes, "__shape__");
bool infer_complete = (g.GetAttr<size_t>("shape_num_unknown_nodes") == 0);
CHECK(infer_complete)
<< "The shape information of is not enough to get the shapes";
Expand Down
41 changes: 41 additions & 0 deletions src/executor/exec_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
#include <mxnet/ndarray.h>
#include <mxnet/operator.h>
#include <nnvm/graph.h>
#include <nnvm/graph_attr_types.h>
#include <vector>
#include <memory>
#include <string>

namespace mxnet {
namespace exec {
Expand Down Expand Up @@ -107,6 +109,45 @@ Graph AttachOpResources(Graph g);
*/
Graph DetectInplaceAddTo(Graph g);

/*!
* \brief Infer shapes in the graph given the information.
* \param graph The input graph.
* \param shape_inputs The shapes of input symbols to the graph.
* \param shape_attr_key The key to the node attribute that can indicate shape. This is
* the place where manual hint for shapes could be injected.
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
* The index of ShapeVector is given by graph.indexed_graph().entry_id.
*/
Graph InferShape(Graph graph,
nnvm::ShapeVector shape_inputs,
const std::string& shape_attr_key = "");

/*!
* \brief Infer types in the graph given the information.
* \param graph The input graph.
* \param dtype_inputs The types of input symbols to the graph.
* \param dtype_attr_key The key to the node attribute that can indicate types. This is
* the place where manual hint for types could be injected.
* \return A graph with new attribute "dtype" containing inferred type of each NodeEntry.
* The index of ShapeVector is given by graph.indexed_graph().entry_id.
*/
Graph InferType(Graph graph,
nnvm::DTypeVector dtype_inputs,
const std::string& dtype_attr_key = "");

/*!
* \brief Infer storage types in the graph given the information.
* \param graph The input graph.
* \param storage_type_inputs The storage types of input symbols to the graph.
* \param storage_type_attr_key The key to the node attribute that can indicate storage types.
This is the place where manual hint for types could be injected.
* \return A graph with new attribute "storage_type" containing inferred type of each NodeEntry.
* The index of StorageTypeVector is given by graph.indexed_graph().entry_id.
*/
Graph InferStorageType(Graph graph,
nnvm::StorageTypeVector storage_type_inputs,
const std::string& storage_type_attr_key = "");

} // namespace exec
} // namespace mxnet

Expand Down
Loading