Commit 70d57df

tjingrant authored and prasanthpul committed
ONNX-Tensorflow Frontend Tutorial (#27)
* Create OnnxTensorflowExport.ipynb
* frontend tutorial
* add inference asset
* add link
* update tutorial
* more spaces for indentation
1 parent 6be9576 commit 70d57df

4 files changed

Lines changed: 369 additions & 1 deletion

File tree

README.md
tutorials/OnnxTensorflowExport.ipynb
tutorials/assets/image.npz
tutorials/assets/tf-train-mnist.py

README.md

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 | [Cognitive Toolkit (CNTK)](https://www.microsoft.com/en-us/cognitive-toolkit/) | [built-in](https://docs.microsoft.com/en-us/cognitive-toolkit/setup-cntk-on-your-machine) | [Exporting](tutorials/CntkOnnxExport.ipynb) | [Importing](tutorials/OnnxCntkImport.ipynb) |
 | [Apache MXNet](http://mxnet.incubator.apache.org/) | [onnx/onnx-mxnet](https://github.com/onnx/onnx-mxnet) | coming soon | [Importing](tutorials/OnnxMxnetImport.ipynb) [experimental] |
 | [Chainer](https://chainer.org/) | [chainer/onnx-chainer](https://github.com/chainer/onnx-chainer) | [Exporting](tutorials/ChainerOnnxExport.ipynb) | coming soon |
-| [TensorFlow](https://www.tensorflow.org/) | [onnx/onnx-tensorflow](https://github.com/onnx/onnx-tensorflow) | coming soon | [Importing](tutorials/OnnxTensorflowImport.ipynb) [experimental] |
+| [TensorFlow](https://www.tensorflow.org/) | [onnx/onnx-tensorflow](https://github.com/onnx/onnx-tensorflow) | [Exporting](tutorials/OnnxTensorflowExport.ipynb) | [Importing](tutorials/OnnxTensorflowImport.ipynb) [experimental] |
 | [Apple CoreML](https://developer.apple.com/documentation/coreml) | [onnx/onnx-coreml](https://github.com/onnx/onnx-coreml) and [onnx/onnxmltools](https://github.com/onnx/onnxmltools) | [Exporting](https://github.com/onnx/onnxmltools) | [Importing](tutorials/OnnxCoremlImport.ipynb) |
 | [SciKit-Learn](http://scikit-learn.org/) | [onnx/onnxmltools](https://github.com/onnx/onnxmltools) | [Exporting](https://github.com/onnx/onnxmltools) | n/a |

tutorials/OnnxTensorflowExport.ipynb

Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "# Train in Tensorflow, Export to ONNX\n",
    "In this tutorial, we demonstrate the complete process of training an MNIST model in Tensorflow and exporting the trained model to ONNX.\n",
    "\n",
    "### Training\n",
    "\n",
    "First, start the [training script](./assets/tf-train-mnist.py) by running `python tf-train-mnist.py` in your terminal. Shortly, we should obtain a trained MNIST model. The training process needs no special instrumentation; however, to convert the trained model successfully, onnx-tensorflow requires three pieces of information, all of which are available once training is complete:\n",
    "\n",
    " - *Graph definition*: You need the graph definition in the form of a GraphProto. The easiest way to obtain it is to use the following snippet of code, as shown in the example training script:\n",
    "```\n",
    "    with open(\"graph.proto\", \"wb\") as file:\n",
    "        graph = tf.get_default_graph().as_graph_def(add_shapes=True)\n",
    "        file.write(graph.SerializeToString())\n",
    "```\n",
    " - *Shape information*: By default, `as_graph_def` does not serialize any information about the shapes of intermediate tensors, and such information is required by onnx-tensorflow. Thus we ask Tensorflow to serialize the shape information by passing the keyword argument `add_shapes=True`, as demonstrated above.\n",
    " - *Checkpoint*: Tensorflow checkpoint files contain the trained weights; thus they are needed to convert the trained model to ONNX format.\n",
    "\n",
    "### Graph Freezing\n",
    "\n",
    "Secondly, we freeze the graph. Here is a quote from the Tensorflow documentation explaining what graph freezing is:\n",
    "> One confusing part about this is that the weights usually aren't stored inside the file format during training. Instead, they're held in separate checkpoint files, and there are Variable ops in the graph that load the latest values when they're initialized. It's often not very convenient to have separate files when you're deploying to production, so there's the freeze_graph.py script that takes a graph definition and a set of checkpoints and freezes them together into a single file.\n",
    "\n",
    "Thus we build the freeze_graph tool in the Tensorflow source tree and run it, telling it where the GraphProto is, where the checkpoint file is, and where to put the frozen graph. One caveat is that you need to supply the name of the output node to this utility. If you have trouble finding the name of the output node, refer to [this article](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md#inspecting-graphs) for help.\n",
    "```\n",
    "bazel build tensorflow/python/tools:freeze_graph\n",
    "bazel-bin/tensorflow/python/tools/freeze_graph \\\n",
    "    --input_graph=/home/mnist-tf/graph.proto \\\n",
    "    --input_checkpoint=/home/mnist-tf/ckpt/model.ckpt \\\n",
    "    --output_graph=/tmp/frozen_graph.pb \\\n",
    "    --output_node_names=fc2/add \\\n",
    "    --input_binary=True\n",
    "```\n",
    "\n",
    "Note that we have now obtained `frozen_graph.pb`, which holds the graph definition and the weights in a single file.\n",
    "\n",
    "### Model Conversion\n",
    "\n",
    "Thirdly, we convert the model to ONNX format using onnx-tensorflow, calling `tensorflow_graph_to_onnx_model` from the onnx-tensorflow API (documentation available at https://github.com/onnx/onnx-tensorflow/blob/master/onnx_tf/doc/API.md)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from onnx_tf.frontend import tensorflow_graph_to_onnx_model\n",
    "\n",
    "with tf.gfile.GFile(\"frozen_graph.pb\", \"rb\") as f:\n",
    "    graph_def = tf.GraphDef()\n",
    "    graph_def.ParseFromString(f.read())\n",
    "\n",
    "onnx_model = tensorflow_graph_to_onnx_model(graph_def,\n",
    "                                            \"fc2/add\",\n",
    "                                            opset=6)\n",
    "\n",
    "with open(\"mnist.onnx\", \"wb\") as f:\n",
    "    f.write(onnx_model.SerializeToString())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "As a simple sanity check that we have obtained the correct model, we print the first node of the converted ONNX graph, which corresponds to the reshape operation that turns the flat serial input into an image tensor:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input: \"Placeholder\"\n",
      "input: \"reshape/Reshape/shape\"\n",
      "output: \"reshape/Reshape\"\n",
      "op_type: \"Reshape\"\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(onnx_model.graph.node[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "### Inference using Backend\n",
    "\n",
    "We continue the demonstration by performing inference using the converted ONNX model. Here, we have exported an image representing a handwritten 7 and stored the numpy array as `image.npz`. Using our backend, we classify this image with the converted ONNX model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The digit is classified as 7\n"
     ]
    }
   ],
   "source": [
    "import onnx\n",
    "import numpy as np\n",
    "from onnx_tf.backend import prepare\n",
    "\n",
    "model = onnx.load('mnist.onnx')\n",
    "tf_rep = prepare(model)\n",
    "\n",
    "img = np.load(\"./assets/image.npz\")\n",
    "output = tf_rep.run(img.reshape([1, 784]))\n",
    "print \"The digit is classified as \", np.argmax(output)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
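The Reshape node printed in the sanity check turns the flat 784-element input into an image tensor in NCHW layout, the layout the training script selects via `data_format="NCHW"`. As a minimal pure-Python sketch of the two layout conventions (no TensorFlow required; the helper names are ours, not part of any library):

```python
def nchw_index(n, c, h, w, C, H, W):
    """Flat offset of element (n, c, h, w) in a contiguous NCHW tensor."""
    return ((n * C + c) * H + h) * W + w

def nhwc_index(n, h, w, c, H, W, C):
    """Flat offset of element (n, h, w, c) in a contiguous NHWC tensor."""
    return ((n * H + h) * W + w) * C + c

# For a single-channel 28x28 MNIST image (N=1, C=1) the layouts agree:
# pixel (row 3, col 5) lives at flat index 3 * 28 + 5 = 89.
print(nchw_index(0, 0, 3, 5, C=1, H=28, W=28))  # 89
print(nhwc_index(0, 3, 5, 0, H=28, W=28, C=1))  # 89

# With C=3 channels the layouts diverge:
print(nchw_index(0, 1, 3, 5, C=3, H=28, W=28))  # 873 = 1*28*28 + 89
print(nhwc_index(0, 3, 5, 1, H=28, W=28, C=3))  # 268 = 89*3 + 1
```

This is why `img.reshape([1, 784])` in the inference cell is safe for grayscale MNIST: with one channel, NCHW and NHWC flatten to the same order.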

tutorials/assets/image.npz

3.19 KB
Binary file not shown.
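Although the binary asset is not shown, an `.npz` file is conventionally a ZIP archive whose members are `.npy` arrays. The sketch below is a toy pure-Python encoder for the `.npy` v1.0 container (little-endian float64 only) packed into a ZIP, just to illustrate the on-disk format; it is our own illustration, not NumPy's implementation, and says nothing about the actual contents of `image.npz`:

```python
import io
import struct
import zipfile

def npy_bytes(values, shape):
    """Minimal .npy v1.0 payload for little-endian float64 data."""
    header = "{'descr': '<f8', 'fortran_order': False, 'shape': %r, }" % (shape,)
    # The 10-byte preamble plus header must end in '\n' and align to 64 bytes.
    pad = 64 - (10 + len(header) + 1) % 64
    header = header + " " * pad + "\n"
    preamble = b"\x93NUMPY\x01\x00" + struct.pack("<H", len(header))
    return preamble + header.encode("latin1") + b"".join(
        struct.pack("<d", v) for v in values)

# Pack one toy 2x2 array into a ZIP container, the way .npz files are laid out.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as z:
    z.writestr("image.npy", npy_bytes([0.0, 0.5, 1.0, 0.25], (2, 2)))

# Read it back as a plain ZIP and inspect the .npy magic string.
with zipfile.ZipFile(io.BytesIO(buf.getvalue())) as z:
    print(z.namelist())   # ['image.npy']
    raw = z.read("image.npy")
    print(raw[:6])        # b'\x93NUMPY'
```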

tutorials/assets/tf-train-mnist.py

Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
1+
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
"""A deep MNIST classifier using convolutional layers.
17+
18+
See extensive documentation at
19+
https://www.tensorflow.org/get_started/mnist/pros
20+
"""
21+
# Disable linter warnings to maintain consistency with tutorial.
22+
# pylint: disable=invalid-name
23+
# pylint: disable=g-bad-import-order
24+
25+
from __future__ import absolute_import
26+
from __future__ import division
27+
from __future__ import print_function
28+
29+
import argparse
30+
import sys
31+
import tempfile
32+
33+
from tensorflow.examples.tutorials.mnist import input_data
34+
35+
import tensorflow as tf
36+
37+
FLAGS = None
38+
39+
def add(x, y):
40+
return tf.nn.bias_add(x, y, data_format="NCHW")
41+
42+
def deepnn(x):
43+
"""deepnn builds the graph for a deep net for classifying digits.
44+
45+
Args:
46+
x: an input tensor with the dimensions (N_examples, 784), where 784 is the
47+
number of pixels in a standard MNIST image.
48+
49+
Returns:
50+
A tuple (y, keep_prob). y is a tensor of shape (N_examples, 10), with values
51+
equal to the logits of classifying the digit into one of 10 classes (the
52+
digits 0-9). keep_prob is a scalar placeholder for the probability of
53+
dropout.
54+
"""
55+
# Reshape to use within a convolutional neural net.
56+
# Last dimension is for "features" - there is only one here, since images are
57+
# grayscale -- it would be 3 for an RGB image, 4 for RGBA, etc.
58+
with tf.name_scope('reshape'):
59+
x_image = tf.reshape(x, [-1, 1, 28, 28])
60+
61+
# First convolutional layer - maps one grayscale image to 32 feature maps.
62+
with tf.name_scope('conv1'):
63+
W_conv1 = weight_variable([5, 5, 1, 32])
64+
b_conv1 = bias_variable([32])
65+
h_conv1 = tf.nn.relu(add(conv2d(x_image, W_conv1), b_conv1))
66+
67+
# Pooling layer - downsamples by 2X.
68+
with tf.name_scope('pool1'):
69+
h_pool1 = max_pool_2x2(h_conv1)
70+
71+
# Second convolutional layer -- maps 32 feature maps to 64.
72+
with tf.name_scope('conv2'):
73+
W_conv2 = weight_variable([5, 5, 32, 64])
74+
b_conv2 = bias_variable([64])
75+
h_conv2 = tf.nn.relu(add(conv2d(h_pool1, W_conv2), b_conv2))
76+
77+
# Second pooling layer.
78+
with tf.name_scope('pool2'):
79+
h_pool2 = max_pool_2x2(h_conv2)
80+
81+
# Fully connected layer 1 -- after 2 round of downsampling, our 28x28 image
82+
# is down to 7x7x64 feature maps -- maps this to 1024 features.
83+
with tf.name_scope('fc1'):
84+
W_fc1 = weight_variable([7 * 7 * 64, 1024])
85+
b_fc1 = bias_variable([1024])
86+
87+
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
88+
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
89+
90+
# Map the 1024 features to 10 classes, one for each digit
91+
with tf.name_scope('fc2'):
92+
W_fc2 = weight_variable([1024, 10])
93+
b_fc2 = bias_variable([10])
94+
95+
y_conv = tf.matmul(h_fc1, W_fc2) + b_fc2
96+
97+
return y_conv
98+
99+
100+
def conv2d(x, W):
101+
"""conv2d returns a 2d convolution layer with full stride."""
102+
return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME', data_format="NCHW")
103+
104+
105+
def max_pool_2x2(x):
106+
"""max_pool_2x2 downsamples a feature map by 2X."""
107+
return tf.nn.max_pool(x, ksize=[1, 1, 2, 2],
108+
strides=[1, 1, 2, 2], padding='SAME', data_format="NCHW")
109+
110+
111+
def weight_variable(shape):
112+
"""weight_variable generates a weight variable of a given shape."""
113+
initial = tf.truncated_normal(shape, stddev=0.1)
114+
return tf.Variable(initial)
115+
116+
117+
def bias_variable(shape):
118+
"""bias_variable generates a bias variable of a given shape."""
119+
initial = tf.constant(0.1, shape=shape)
120+
return tf.Variable(initial)
121+
122+
123+
def main(_):
124+
# Import data
125+
mnist = input_data.read_data_sets(FLAGS.data_dir)
126+
127+
# Create the model
128+
x = tf.placeholder(tf.float32, [None, 784])
129+
130+
# Build the graph for the deep net
131+
y_conv = deepnn(x)
132+
133+
with open("graph.proto", "wb") as file:
134+
graph = tf.get_default_graph().as_graph_def(add_shapes=True)
135+
file.write(graph.SerializeToString())
136+
137+
# Define loss and optimizer
138+
y_ = tf.placeholder(tf.int64, [None])
139+
140+
with tf.name_scope('loss'):
141+
cross_entropy = tf.losses.sparse_softmax_cross_entropy(
142+
labels=y_, logits=y_conv)
143+
cross_entropy = tf.reduce_mean(cross_entropy)
144+
145+
with tf.name_scope('adam_optimizer'):
146+
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
147+
148+
with tf.name_scope('accuracy'):
149+
correct_prediction = tf.equal(tf.argmax(y_conv, 1), y_)
150+
correct_prediction = tf.cast(correct_prediction, tf.float32)
151+
accuracy = tf.reduce_mean(correct_prediction)
152+
153+
graph_location = tempfile.mkdtemp()
154+
print('Saving graph to: %s' % graph_location)
155+
train_writer = tf.summary.FileWriter(graph_location)
156+
train_writer.add_graph(tf.get_default_graph())
157+
158+
saver = tf.train.Saver()
159+
160+
with tf.Session() as sess:
161+
sess.run(tf.global_variables_initializer())
162+
for i in range(20000):
163+
batch = mnist.train.next_batch(50)
164+
165+
if i % 1000 == 0:
166+
train_accuracy = accuracy.eval(feed_dict={
167+
x: batch[0], y_: batch[1]})
168+
print('step %d, training accuracy %g' % (i, train_accuracy))
169+
170+
save_path = saver.save(sess, "./ckpt/model.ckpt")
171+
print("Model saved in path: %s" % save_path)
172+
train_step.run(feed_dict={x: batch[0], y_: batch[1]})
173+
174+
print('test accuracy %g' % accuracy.eval(feed_dict={
175+
x: mnist.test.images, y_: mnist.test.labels}))
176+
177+
if __name__ == '__main__':
178+
parser = argparse.ArgumentParser()
179+
parser.add_argument('--data_dir', type=str,
180+
default='/tmp/tensorflow/mnist/input_data',
181+
help='Directory for storing input data')
182+
FLAGS, unparsed = parser.parse_known_args()
183+
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
184+
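In `deepnn`, the width of the first fully connected layer follows from the spatial arithmetic: each SAME-padded convolution (stride 1) preserves 28x28, and each 2x2 max pool (stride 2) halves it, giving 28 -> 14 -> 7 and hence the `7 * 7 * 64` flattened size. A small pure-Python check of that arithmetic, assuming the standard SAME-padding output rule `ceil(size / stride)`:

```python
import math

def same_out(size, stride):
    """Output spatial size under SAME padding: ceil(size / stride)."""
    return int(math.ceil(size / float(stride)))

h = w = 28
for layer in ["conv1", "pool1", "conv2", "pool2"]:
    stride = 2 if layer.startswith("pool") else 1
    h, w = same_out(h, stride), same_out(w, stride)
    print(layer, (h, w))  # conv layers keep 28x28 / 14x14; pools halve

print("fc1 input width:", h * w * 64)  # 7 * 7 * 64 = 3136
```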
