onnx/defs/shape_inference.cc (2 changes: 1 addition & 1 deletion)

@@ -446,7 +446,7 @@ void propagateElemTypeWithValidation(const TypeProto* input_type, TypeProto* out
   }
 }
 
-TensorShapeProto getShapeInput(InferenceContext& ctx, size_t input_index, bool& found) {
+TensorShapeProto getShapeInput(const InferenceContext& ctx, size_t input_index, bool& found) {
   TensorShapeProto shape_input;
 
   // First, check initializer.
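As the truncated body above hints ("First, check initializer."), getShapeInput resolves the target-shape input from a constant initializer before falling back to propagated symbolic data and, finally, rank inference. A minimal sketch of the initializer path through the public Python API; the model and tensor names below are illustrative, not part of this PR:

import numpy as np
from onnx import TensorProto, helper, numpy_helper, shape_inference

# "shape" is a constant initializer, so inference can read the target
# shape directly and produce a fully concrete output shape for y.
x = helper.make_tensor_value_info("x", TensorProto.FLOAT, (2, 4, 3))
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, None)
shape_init = numpy_helper.from_array(np.array([4, 6], dtype=np.int64), name="shape")
graph = helper.make_graph(
    [helper.make_node("Reshape", ["x", "shape"], ["y"])],
    "reshape_from_initializer",
    [x],
    [y],
    initializer=[shape_init],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 14)])
inferred = shape_inference.infer_shapes(model)
print(inferred.graph.output[0].type.tensor_type.shape)  # expected dims: 4, 6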
onnx/defs/shape_inference.h (8 changes: 4 additions & 4 deletions)

@@ -291,12 +291,12 @@ inline bool hasShape(const TypeProto& type) {
 }
 
 template <typename Context>
-inline bool hasInputShape(Context& ctx, size_t n) {
+inline bool hasInputShape(const Context& ctx, size_t n) {
   return ctx.getNumInputs() > static_cast<size_t>(n) && ctx.getInputType(n) && hasShape(*ctx.getInputType(n));
 }
 
 template <typename Context>
-inline bool hasNInputShapes(Context& ctx, size_t n) {
+inline bool hasNInputShapes(const Context& ctx, size_t n) {
   for (size_t i = 0; i < n; i++) {
     if (!hasInputShape(ctx, i)) {
       return false;
@@ -305,7 +305,7 @@ inline bool hasNInputShapes(Context& ctx, size_t n) {
   return true;
 }
 
-inline const TensorShapeProto& getInputShape(InferenceContext& ctx, size_t n) {
+inline const TensorShapeProto& getInputShape(const InferenceContext& ctx, size_t n) {
   const auto* input_type = ctx.getInputType(n);
   const auto value_case = input_type->value_case();
   if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) {
@@ -532,7 +532,7 @@ inline void updateOutputShape(
 // If neither is available, try rank inference.
 // When one of above succeeds, `true` is stored in `found`.
 // Otherwise, `false` is stored, which means that returned TensorShapeProto does not make sense.
-TensorShapeProto getShapeInput(InferenceContext& ctx, size_t input_index, bool& found);
+TensorShapeProto getShapeInput(const InferenceContext& ctx, size_t input_index, bool& found);
 
 // Infer shape of an output from the value of a specified attribute, which is
 // expected to be a list of integers specifying a valid shape.
onnx/defs/tensor/defs.cc (17 changes: 3 additions & 14 deletions)

@@ -307,20 +307,9 @@ ONNX_OPERATOR_SET_SCHEMA(
         .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
           // Type inference
           propagateElemTypeFromInputToOutput(ctx, 0, 0);
-          // Shape Inference if 2nd input data (the target shape) is available
-          // or the target shape is generated via partial data propagation
-          const TensorProto* targetShapeInitializer = ctx.getInputData(1);
-          const auto* shapeInput = ctx.getSymbolicInput(1);
-          // The targetShapeProto represents the specified shape for output.
-          TensorShapeProto targetShapeProto;
-          if (targetShapeInitializer) {
-            auto targetShape = ParseData<int64_t>(targetShapeInitializer);
-            for (auto val : targetShape) {
-              targetShapeProto.add_dim()->set_dim_value(val);
-            }
-          } else if (shapeInput) {
-            targetShapeProto.CopyFrom(*shapeInput);
-          } else {
+          bool found;
+          TensorShapeProto targetShapeProto = getShapeInput(ctx, 1, found);
+          if (!found) {
             return;
           }
 
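The deleted branches show the two sources the old inline code consulted: a constant initializer (ctx.getInputData) and a propagated symbolic value (ctx.getSymbolicInput); getShapeInput now performs the same lookup plus rank inference. A minimal sketch of the symbolic path from Python, assuming data_prop=True symbolic data propagation is enabled; the graph and tensor names are illustrative:

from onnx import TensorProto, helper, shape_inference

# x2 is reshaped to the shape of x1. With data_prop=True, the symbolic
# value computed for Shape(x1) feeds Reshape's target-shape input,
# exercising the getSymbolicInput path consolidated above.
x1 = helper.make_tensor_value_info("x1", TensorProto.FLOAT, (3, 4))
x2 = helper.make_tensor_value_info("x2", TensorProto.FLOAT, (4, 3))
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, None)
nodes = [
    helper.make_node("Shape", ["x1"], ["s"]),
    helper.make_node("Reshape", ["x2", "s"], ["y"]),
]
graph = helper.make_graph(nodes, "reshape_via_data_prop", [x1, x2], [y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 14)])
inferred = shape_inference.infer_shapes(model, data_prop=True)
print(inferred.graph.output[0].type.tensor_type.shape)  # expected dims: 3, 4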
onnx/defs/tensor/old.cc (17 changes: 3 additions & 14 deletions)

@@ -5032,20 +5032,9 @@ ONNX_OPERATOR_SET_SCHEMA(
         .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
          // Type inference
           propagateElemTypeFromInputToOutput(ctx, 0, 0);
-          // Shape Inference if 2nd input data (the target shape) is available
-          // or the target shape is generated via partial data propagation
-          const TensorProto* targetShapeInitializer = ctx.getInputData(1);
-          const auto* shapeInput = ctx.getSymbolicInput(1);
-          // The targetShapeProto represents the specified shape for output.
-          TensorShapeProto targetShapeProto;
-          if (targetShapeInitializer) {
-            auto targetShape = ParseData<int64_t>(targetShapeInitializer);
-            for (auto val : targetShape) {
-              targetShapeProto.add_dim()->set_dim_value(val);
-            }
-          } else if (shapeInput) {
-            targetShapeProto.CopyFrom(*shapeInput);
-          } else {
+          bool found;
+          TensorShapeProto targetShapeProto = getShapeInput(ctx, 1, found);
+          if (!found) {
            return;
           }
 
onnx/test/shape_inference_test.py (35 changes: 31 additions & 4 deletions)

@@ -108,7 +108,7 @@ def _make_graph(
             )
             input_value_infos.append(
                 make_tensor_value_info(
-                    "UNKNOWN_SHAPE_" + seed_name, TensorProto.INT64, ()
+                    "UNKNOWN_SHAPE_" + seed_name, TensorProto.INT64, (None,)
                 )
             )
         nodes[:0] = [
@@ -157,7 +157,7 @@ def _assert_inferred(
         vis = [x for x in graph.value_info if x.name not in names_in_vis] + vis
         inferred_model = self._inferred(graph_or_model, **kwargs)
         inferred_vis = list(inferred_model.graph.value_info)
-        vis = sorted(vis, key=lambda x: x.name)
+        vis = sorted(vis, key=lambda x: x.name)  # type: ignore[no-any-return]
         inferred_vis = sorted(inferred_vis, key=lambda x: x.name)  # type: ignore
         assert len(vis) == len(inferred_vis)
         for v, inferred_v in zip(vis, inferred_vis):
@@ -604,12 +604,39 @@ def test_concat_param_single_input(self, _, version) -> None:
         )
 
     @parameterized.expand(all_versions_for("Reshape"))
-    def test_reshape_dynamic_shape(self, _, version) -> None:
+    def test_reshape_dynamic_shape_known_rank(self, _, version) -> None:
+        self.skipIf(version < 14, "Rank inference is added from Version 14")
         graph = self._make_graph(
             [("x", TensorProto.UINT8, (2, 4, 3)), ("shape", TensorProto.INT64, (2,))],
             [make_node("Reshape", ["x", "shape"], ["y"])],
             [],
         )
+        self._assert_inferred(
+            graph,
+            [make_tensor_value_info("y", TensorProto.UINT8, (None, None))],
+            opset_imports=[helper.make_opsetid(ONNX_DOMAIN, version)],
+        )
+
+    @parameterized.expand(all_versions_for("Reshape"))
+    def test_reshape_dynamic_shape_symbolic(self, _, version) -> None:
+        graph = self._make_graph(
+            [("x", TensorProto.UINT8, (2, 4, 3)), ("shape", TensorProto.INT64, ("M",))],
+            [make_node("Reshape", ["x", "shape"], ["y"])],
+            [],
+        )
+        self._assert_inferred(
+            graph,
+            [make_tensor_value_info("y", TensorProto.UINT8, None)],
+            opset_imports=[helper.make_opsetid(ONNX_DOMAIN, version)],
+        )
+
+    @parameterized.expand(all_versions_for("Reshape"))
+    def test_reshape_dynamic_unknown_shape(self, _, version) -> None:
+        graph = self._make_graph(
+            [("x", TensorProto.UINT8, (2, 4, 3)), ("shape", TensorProto.INT64, None)],
+            [make_node("Reshape", ["x", "shape"], ["y"])],
+            [],
+        )
         self._assert_inferred(
             graph,
             [make_tensor_value_info("y", TensorProto.UINT8, None)],
@@ -1515,7 +1542,7 @@ def test_scatternd_noshape(self, _, version) -> None:
                 ("x", TensorProto.FLOAT, (4, 5, 6)),
                 ("indices", TensorProto.INT64, (3, 3, 2)),
                 ("updates", TensorProto.FLOAT, (3, 3, 6)),
-                ("shape", TensorProto.INT64, (2,)),
+                ("shape", TensorProto.INT64, ("M",)),
             ],
             [
                 make_node("Reshape", ["x", "shape"], ["x_reshaped"]),
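The three new tests pin down each dynamic-shape case: a shape input of known static length yields a known output rank, while a symbolic or fully unknown shape input leaves the output shape unset; the scatternd_noshape fixture switches its shape input to a symbolic length, presumably so the intermediate stays shapeless under the new rank inference. A small end-to-end sketch of the first two cases through onnx.shape_inference, outside the test harness; infer_reshape is an illustrative helper, not from this PR:

from onnx import TensorProto, helper, shape_inference

def infer_reshape(shape_dims):
    # Build x:(2,4,3) -> Reshape -> y with a dynamic "shape" input of the
    # given dims, run shape inference, and return y's inferred tensor type.
    # Opset 14 is used since that is where Reshape rank inference applies
    # (per the skipIf in the tests above).
    x = helper.make_tensor_value_info("x", TensorProto.UINT8, (2, 4, 3))
    s = helper.make_tensor_value_info("shape", TensorProto.INT64, shape_dims)
    y = helper.make_tensor_value_info("y", TensorProto.UINT8, None)
    graph = helper.make_graph(
        [helper.make_node("Reshape", ["x", "shape"], ["y"])],
        "reshape_dynamic",
        [x, s],
        [y],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 14)])
    return shape_inference.infer_shapes(model).graph.output[0].type.tensor_type

# Static length 2: rank inference should give y two unknown dims.
print(infer_reshape((2,)))
# Symbolic length "M": rank is unknown too, so y should keep no shape at all.
print(infer_reshape(("M",)))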