{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Convert TensorFlow model to ONNX\n", "The general procedure for converting a TensorFlow model to ONNX is covered in [part 1](./TensorflowToOnnx-1.ipynb).\n", "\n", "In this tutorial, we will cover the following topics in order:\n", "1. convert TensorFlow models saved in other formats\n", "   - convert with frozen graph\n", "   - convert with checkpoint\n", "2. convert in Python script\n", "3. useful command line options of `tensorflow-onnx`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Convert with frozen graph\n", "TensorFlow provides an API for producing a model's frozen graph, and `tensorflow-onnx` accepts a frozen graph as input.\n", "\n", "Besides the frozen graph, the names of the input and output tensors are also needed. These names typically end with \":0\". You can either look them up with TensorFlow's [summarize_graph tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms), or set them explicitly with `tf.identity` in your TensorFlow script." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Extracting /tmp/tensorflow/mnist/input_data/train-images-idx3-ubyte.gz\n", "Extracting /tmp/tensorflow/mnist/input_data/train-labels-idx1-ubyte.gz\n", "Extracting /tmp/tensorflow/mnist/input_data/t10k-images-idx3-ubyte.gz\n", "Extracting /tmp/tensorflow/mnist/input_data/t10k-labels-idx1-ubyte.gz\n", "step 0, training accuracy 0.08\n", "step 1000, training accuracy 0.98\n", "step 2000, training accuracy 1\n", "step 3000, training accuracy 0.98\n", "step 4000, training accuracy 1\n", "test accuracy 0.984\n" ] } ], "source": [ "import tensorflow as tf\n", "from assets.tensorflow_to_onnx_example import create_and_train_mnist\n", "\n", "def save_model_to_frozen_proto(sess):\n", "    # strip the \":0\" suffix: the freeze API expects node names, not tensor names\n", "    frozen_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, [output_tensor.name[:-2]])\n", "    with open(\"./output/mnist_frozen.pb\", \"wb\") as file:\n", "        file.write(frozen_graph.SerializeToString())\n", "\n", "sess_tf, saver, input_tensor, output_tensor = create_and_train_mnist()\n", "save_model_to_frozen_proto(sess_tf)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2019-06-17 07:19:39,492 - INFO - Using tensorflow=1.12.0, onnx=1.5.0, tf2onnx=1.5.1/0c735a\n", "2019-06-17 07:19:39,492 - INFO - Using opset \n", "2019-06-17 07:19:39,606 - INFO - \n", "2019-06-17 07:19:39,635 - INFO - Optimizing ONNX model\n", "2019-06-17 07:19:39,652 - INFO - After optimization: Add -2 (4->2), Identity -3 (3->0), Transpose -8 (9->1)\n", "2019-06-17 07:19:39,654 - INFO - \n", "2019-06-17 07:19:39,654 - INFO - Successfully converted TensorFlow model ./output/mnist_frozen.pb to ONNX\n", "2019-06-17 07:19:39,667 - INFO - ONNX model is saved at ./output/mnist2.onnx\n" ] } ], "source": [ "# generate mnist2.onnx from the frozen graph\n", "!python -m tf2onnx.convert \\\n", " 
 --input ./output/mnist_frozen.pb \\\n", " --inputs {input_tensor.name} \\\n", " --outputs {output_tensor.name} \\\n", " --output ./output/mnist2.onnx \\\n", " --opset 7" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Convert with checkpoint\n", "As with a frozen graph, you need to specify the path to the checkpoint file and the model's input and output names." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "def save_model_to_checkpoint(saver, sess):\n", "    save_path = saver.save(sess, \"./output/ckpt/model.ckpt\")\n", "\n", "save_model_to_checkpoint(saver, sess_tf)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2019-06-17 07:19:42,533 - INFO - Using tensorflow=1.12.0, onnx=1.5.0, tf2onnx=1.5.1/0c735a\n", "2019-06-17 07:19:42,534 - INFO - Using opset \n", "2019-06-17 07:19:42,660 - INFO - \n", "2019-06-17 07:19:42,684 - INFO - Optimizing ONNX model\n", "2019-06-17 07:19:42,700 - INFO - After optimization: Add -2 (4->2), Identity -3 (3->0), Transpose -8 (9->1)\n", "2019-06-17 07:19:42,705 - INFO - \n", "2019-06-17 07:19:42,705 - INFO - Successfully converted TensorFlow model ./output/ckpt/model.ckpt.meta to ONNX\n", "2019-06-17 07:19:42,718 - INFO - ONNX model is saved at ./output/mnist3.onnx\n" ] } ], "source": [ "# generate mnist3.onnx from the checkpoint\n", "!python -m tf2onnx.convert \\\n", " --checkpoint ./output/ckpt/model.ckpt.meta \\\n", " --inputs {input_tensor.name} \\\n", " --outputs {output_tensor.name} \\\n", " --output ./output/mnist3.onnx \\\n", " --opset 7" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Convert in Python script\n", "`tensorflow-onnx` exposes conversion APIs so that users can convert a TensorFlow model to ONNX directly in their own script. The following code is an example." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "generating mnist.onnx in python script\n", "ONNX model is saved at ./output/mnist4.onnx\n" ] } ], "source": [ "from tf2onnx.tfonnx import process_tf_graph, tf_optimize\n", "import tensorflow as tf\n", "from tensorflow.graph_util import convert_variables_to_constants as freeze_graph\n", "\n", "print(\"generating mnist.onnx in python script\")\n", "graph_def = freeze_graph(sess_tf, sess_tf.graph_def, [output_tensor.name[:-2]])\n", "with tf.Graph().as_default() as graph:\n", "    tf.import_graph_def(graph_def, name='')\n", "    onnx_graph = process_tf_graph(graph, opset=7, input_names=[input_tensor.name], output_names=[output_tensor.name])\n", "model_proto = onnx_graph.make_model(\"test\")\n", "print(\"ONNX model is saved at ./output/mnist4.onnx\")\n", "with open(\"./output/mnist4.onnx\", \"wb\") as f:\n", "    f.write(model_proto.SerializeToString())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Useful command line options\n", "The first useful option is \"**opset**\", which has been covered in [part 1](./TensorflowToOnnx-1.ipynb).\n", "\n", "The second option is \"**inputs-as-nchw**\". TensorFlow supports both NCHW and NHWC, while ONNX currently supports only NCHW, so if your model uses NHWC the tool inserts extra transpose nodes to convert it. Although `tensorflow-onnx` has optimizers that remove as many of these transpose nodes as possible, it is better to use NCHW directly when you can. If building the model with NCHW is not possible, this option tells the tool that the actual input will be in NCHW format, which lets it remove more of the inserted transpose nodes. 
For example, `--inputs input0:0,input1:0 --inputs-as-nchw input0:0` assumes that images are passed into `input0:0` as NCHW while the given TensorFlow model uses NHWC.\n", "\n", "As mentioned in part 1, ONNX defines its own set of operations for representing machine learning computations, and this set differs from TensorFlow's. Two main differences can make the conversion fail: unsupported input dtypes and unsupported operations. `tensorflow-onnx` has two options that close the gap where possible. The option \"**target**\" may insert cast operations to convert unsupported dtypes into float on some target platforms; see the details [here](https://github.com/onnx/tensorflow-onnx/wiki/target). The option \"**custom-ops**\" is useful when the runtime you use supports custom ops that are not defined in ONNX. For example, `--custom-ops Print` will insert an op `Print` in the ONNX domain `ai.onnx.converters.tensorflow` into the graph." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "More details on `tensorflow-onnx` can be found in its [README](https://github.com/onnx/tensorflow-onnx/blob/master/README.md) file, for example the internal procedure that `tensorflow-onnx` follows to convert a TensorFlow model." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }