Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/source/onnx_torchscript.rst
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,6 @@ Functions
^^^^^^^^^

.. autofunction:: export
.. autofunction:: export_to_pretty_string
.. autofunction:: register_custom_op_symbolic
.. autofunction:: unregister_custom_op_symbolic
.. autofunction:: select_model_mode_for_export
Expand Down
85 changes: 48 additions & 37 deletions test/onnx/test_pytorch_onnx_no_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def forward(self, x):

x = torch.ones(3, 3)
f = io.BytesIO()
torch.onnx.export(AddmmModel(), x, f, verbose=False)
torch.onnx.export(AddmmModel(), x, f)

def test_onnx_transpose_incomplete_tensor_type(self):
# Smoke test to get us into the state where we are attempting to export
Expand Down Expand Up @@ -115,7 +115,8 @@ def foo(x):

traced = torch.jit.trace(foo, (torch.rand([2])))

torch.onnx.export_to_pretty_string(traced, (torch.rand([2]),))
f = io.BytesIO()
torch.onnx.export(traced, (torch.rand([2]),), f)

def test_onnx_export_script_module(self):
class ModuleToExport(torch.jit.ScriptModule):
Expand All @@ -125,7 +126,8 @@ def forward(self, x):
return x + x

mte = ModuleToExport()
torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False)
f = io.BytesIO()
torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)

@common_utils.suppress_warnings
def test_onnx_export_func_with_warnings(self):
Expand All @@ -138,9 +140,8 @@ def forward(self, x):
return func_with_warning(x)

# no exception
torch.onnx.export_to_pretty_string(
WarningTest(), torch.randn(42), verbose=False
)
f = io.BytesIO()
torch.onnx.export(WarningTest(), torch.randn(42), f)

def test_onnx_export_script_python_fail(self):
class PythonModule(torch.jit.ScriptModule):
Expand All @@ -161,7 +162,7 @@ def forward(self, x):
mte = ModuleToExport()
f = io.BytesIO()
with self.assertRaisesRegex(RuntimeError, "Couldn't export Python"):
torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f, verbose=False)
torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)

def test_onnx_export_script_inline_trace(self):
class ModuleToInline(torch.nn.Module):
Expand All @@ -179,7 +180,8 @@ def forward(self, x):
return y + y

mte = ModuleToExport()
torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False)
f = io.BytesIO()
torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)

def test_onnx_export_script_inline_script(self):
class ModuleToInline(torch.jit.ScriptModule):
Expand All @@ -198,7 +200,8 @@ def forward(self, x):
return y + y

mte = ModuleToExport()
torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False)
f = io.BytesIO()
torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)

def test_onnx_export_script_module_loop(self):
class ModuleToExport(torch.jit.ScriptModule):
Expand All @@ -212,7 +215,8 @@ def forward(self, x):
return x

mte = ModuleToExport()
torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False)
f = io.BytesIO()
torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)

@common_utils.suppress_warnings
def test_onnx_export_script_truediv(self):
Expand All @@ -224,9 +228,8 @@ def forward(self, x):

mte = ModuleToExport()

torch.onnx.export_to_pretty_string(
mte, (torch.zeros(1, 2, 3, dtype=torch.float),), verbose=False
)
f = io.BytesIO()
torch.onnx.export(mte, (torch.zeros(1, 2, 3, dtype=torch.float),), f)

def test_onnx_export_script_non_alpha_add_sub(self):
class ModuleToExport(torch.jit.ScriptModule):
Expand All @@ -236,7 +239,8 @@ def forward(self, x):
return bs - 1

mte = ModuleToExport()
torch.onnx.export_to_pretty_string(mte, (torch.rand(3, 4),), verbose=False)
f = io.BytesIO()
torch.onnx.export(mte, (torch.rand(3, 4),), f)

def test_onnx_export_script_module_if(self):
class ModuleToExport(torch.jit.ScriptModule):
Expand All @@ -247,7 +251,8 @@ def forward(self, x):
return x

mte = ModuleToExport()
torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False)
f = io.BytesIO()
torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)

def test_onnx_export_script_inline_params(self):
class ModuleToInline(torch.jit.ScriptModule):
Expand Down Expand Up @@ -277,7 +282,8 @@ def forward(self, x):
torch.mm(torch.zeros(2, 3), torch.ones(3, 3)), torch.ones(3, 4)
)
self.assertEqual(result, reference)
torch.onnx.export_to_pretty_string(mte, (torch.ones(2, 3),), verbose=False)
f = io.BytesIO()
torch.onnx.export(mte, (torch.ones(2, 3),), f)

def test_onnx_export_speculate(self):
class Foo(torch.jit.ScriptModule):
Expand Down Expand Up @@ -312,8 +318,10 @@ def transpose(x):
f1 = Foo(transpose)
f2 = Foo(linear)

torch.onnx.export_to_pretty_string(f1, (torch.ones(1, 10, dtype=torch.float),))
torch.onnx.export_to_pretty_string(f2, (torch.ones(1, 10, dtype=torch.float),))
f = io.BytesIO()
torch.onnx.export(f1, (torch.ones(1, 10, dtype=torch.float),), f)
f = io.BytesIO()
torch.onnx.export(f2, (torch.ones(1, 10, dtype=torch.float),), f)

def test_onnx_export_shape_reshape(self):
class Foo(torch.nn.Module):
Expand All @@ -326,17 +334,20 @@ def forward(self, x):
return reshaped

foo = torch.jit.trace(Foo(), torch.zeros(1, 2, 3))
torch.onnx.export_to_pretty_string(foo, (torch.zeros(1, 2, 3)))
f = io.BytesIO()
torch.onnx.export(foo, (torch.zeros(1, 2, 3)), f)

def test_listconstruct_erasure(self):
class FooMod(torch.nn.Module):
def forward(self, x):
mask = x < 0.0
return x[mask]

torch.onnx.export_to_pretty_string(
f = io.BytesIO()
torch.onnx.export(
FooMod(),
(torch.rand(3, 4),),
f,
add_node_names=False,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
Expand All @@ -351,13 +362,10 @@ def forward(self, x):
retval += torch.sum(x[0:i], dim=0)
return retval

mod = DynamicSliceExportMod()

input = torch.rand(3, 4, 5)

torch.onnx.export_to_pretty_string(
DynamicSliceExportMod(), (input,), opset_version=10
)
f = io.BytesIO()
torch.onnx.export(DynamicSliceExportMod(), (input,), f, opset_version=10)

def test_export_dict(self):
class DictModule(torch.nn.Module):
Expand All @@ -368,10 +376,12 @@ def forward(self, x_in: torch.Tensor) -> Dict[str, torch.Tensor]:
mod = DictModule()
mod.train(False)

torch.onnx.export_to_pretty_string(mod, (x_in,))
f = io.BytesIO()
torch.onnx.export(mod, (x_in,), f)

with self.assertRaisesRegex(RuntimeError, r"DictConstruct.+is not supported."):
torch.onnx.export_to_pretty_string(torch.jit.script(mod), (x_in,))
f = io.BytesIO()
torch.onnx.export(torch.jit.script(mod), (x_in,), f)

def test_source_range_propagation(self):
class ExpandingModule(torch.nn.Module):
Expand Down Expand Up @@ -497,11 +507,11 @@ def forward(self, box_regression: Tensor, proposals: List[Tensor]):
proposal = [torch.randn(2, 4), torch.randn(2, 4)]

with self.assertRaises(RuntimeError) as cm:
onnx_model = io.BytesIO()
f = io.BytesIO()
torch.onnx.export(
model,
(box_regression, proposal),
onnx_model,
f,
)

def test_initializer_sequence(self):
Expand Down Expand Up @@ -637,7 +647,7 @@ def forward(self, x):

x = torch.randn(1, 2, 3, requires_grad=True)
f = io.BytesIO()
torch.onnx.export(Model(), x, f)
torch.onnx.export(Model(), (x,), f)
model = onnx.load(f)
model.ir_version = 0

Expand Down Expand Up @@ -744,7 +754,7 @@ def forward(self, x):

f = io.BytesIO()
with warnings.catch_warnings(record=True):
torch.onnx.export(MyDrop(), (eg,), f, verbose=False)
torch.onnx.export(MyDrop(), (eg,), f)

def test_pack_padded_pad_packed_trace(self):
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
Expand Down Expand Up @@ -791,7 +801,7 @@ def forward(self, x, seq_lens):
self.assertEqual(grad, grad_traced)

f = io.BytesIO()
torch.onnx.export(m, (x, seq_lens), f, verbose=False)
torch.onnx.export(m, (x, seq_lens), f)

# Suppression: ONNX warns when exporting RNNs because of potential batch size mismatch.
@common_utils.suppress_warnings
Expand Down Expand Up @@ -851,7 +861,7 @@ def forward(self, x, seq_lens):
self.assertEqual(grad, grad_traced)

f = io.BytesIO()
torch.onnx.export(m, (x, seq_lens), f, verbose=False)
torch.onnx.export(m, (x, seq_lens), f)

def test_pushpackingpastrnn_in_peephole_create_own_gather_input(self):
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
Expand Down Expand Up @@ -931,7 +941,8 @@ class Mod(torch.nn.Module):
def forward(self, x, w):
return torch.matmul(x, w).detach()

torch.onnx.export_to_pretty_string(Mod(), (torch.rand(3, 4), torch.rand(4, 5)))
f = io.BytesIO()
torch.onnx.export(Mod(), (torch.rand(3, 4), torch.rand(4, 5)), f)

def test_aten_fallback_must_fallback(self):
class ModelWithAtenNotONNXOp(torch.nn.Module):
Expand Down Expand Up @@ -1088,12 +1099,12 @@ def sym_scatter_max(g, src, index, dim, out, dim_size):
torch.onnx.register_custom_op_symbolic(
"torch_scatter::scatter_max", sym_scatter_max, 1
)
f = io.BytesIO()
with torch.no_grad():
torch.onnx.export(
m,
(src, idx),
"mymodel.onnx",
verbose=False,
f,
opset_version=13,
custom_opsets={"torch_scatter": 1},
do_constant_folding=True,
Expand Down Expand Up @@ -1176,7 +1187,7 @@ def forward(self, x):
model = Net(C).cuda().half()
x = torch.randn(N, C).cuda().half()
f = io.BytesIO()
torch.onnx.export(model, x, f, opset_version=14)
torch.onnx.export(model, (x,), f, opset_version=14)
onnx_model = onnx.load_from_string(f.getvalue())
const_node = [n for n in onnx_model.graph.node if n.op_type == "Constant"]
self.assertNotEqual(len(const_node), 0)
Expand Down
2 changes: 0 additions & 2 deletions torch/onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
"JitScalarType",
# Public functions
"export",
"export_to_pretty_string",
"is_in_onnx_export",
"select_model_mode_for_export",
"register_custom_op_symbolic",
Expand Down Expand Up @@ -68,7 +67,6 @@
from .utils import (
_run_symbolic_function,
_run_symbolic_method,
export_to_pretty_string,
is_in_onnx_export,
register_custom_op_symbolic,
select_model_mode_for_export,
Expand Down
79 changes: 0 additions & 79 deletions torch/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
"model_signature",
"warn_on_static_input_change",
"unpack_quantized_tensor",
"export_to_pretty_string",
"unconvertible_ops",
"register_custom_op_symbolic",
"unregister_custom_op_symbolic",
Expand Down Expand Up @@ -1140,84 +1139,6 @@ def _model_to_graph(
return graph, params_dict, torch_out


@torch._disable_dynamo
@_deprecation.deprecated("2.5", "the future", "use onnx.printer.to_text() instead")
Comment thread
titaiwangms marked this conversation as resolved.
def export_to_pretty_string(
model,
args,
export_params=True,
verbose=False,
training=_C_onnx.TrainingMode.EVAL,
input_names=None,
output_names=None,
operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
export_type=None,
google_printer=False,
opset_version=None,
keep_initializers_as_inputs=None,
custom_opsets=None,
add_node_names=True,
do_constant_folding=True,
dynamic_axes=None,
):
"""Similar to :func:`export`, but returns a text representation of the ONNX model.

Only differences in args listed below. All other args are the same
as :func:`export`.

Args:
add_node_names (bool, default True): Whether or not to set
NodeProto.name. This makes no difference unless
``google_printer=True``.
google_printer (bool, default False): If False, will return a custom,
compact representation of the model. If True will return the
protobuf's `Message::DebugString()`, which is more verbose.

Returns:
A UTF-8 str containing a human-readable representation of the ONNX model.
"""
if opset_version is None:
opset_version = _constants.ONNX_DEFAULT_OPSET
if custom_opsets is None:
custom_opsets = {}
GLOBALS.export_onnx_opset_version = opset_version
GLOBALS.operator_export_type = operator_export_type

with exporter_context(model, training, verbose):
val_keep_init_as_ip = _decide_keep_init_as_input(
keep_initializers_as_inputs, operator_export_type, opset_version
)
val_add_node_names = _decide_add_node_names(
add_node_names, operator_export_type
)
val_do_constant_folding = _decide_constant_folding(
do_constant_folding, operator_export_type, training
)
args = _decide_input_format(model, args)
graph, params_dict, torch_out = _model_to_graph(
model,
args,
verbose,
input_names,
output_names,
operator_export_type,
val_do_constant_folding,
training=training,
dynamic_axes=dynamic_axes,
)

return graph._pretty_print_onnx( # type: ignore[attr-defined]
params_dict,
opset_version,
False,
operator_export_type,
google_printer,
val_keep_init_as_ip,
custom_opsets,
val_add_node_names,
)


@_deprecation.deprecated("2.5", "the future", "avoid using this function")
def unconvertible_ops(
model,
Expand Down