Table Of Contents
- Description
- How does this sample work?
- Prerequisites
- Running the sample
- Additional resources
- License
- Changelog
- Known issues
Description
This sample, network_api_pytorch_mnist, trains a convolutional model on the MNIST dataset and runs inference with a TensorRT engine.
How does this sample work?
This end-to-end sample trains a model in PyTorch, recreates the same network in TensorRT, imports the trained weights into it, and finally runs inference with the resulting TensorRT engine. For more information, see Creating A Network Definition In Python.
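Recreating the network in TensorRT only works if every layer produces the same shape as its PyTorch counterpart, because the imported weights are sized for those shapes. The sketch below traces the shapes with the standard output-size formula `(size + 2 * padding - kernel) // stride + 1`. It assumes a LeNet-style network (two 5x5 convolutions with 20 and 50 output maps, each followed by 2x2 max pooling with stride 2). These layer sizes are illustrative assumptions, so check the model definition in mnist.py before relying on them.

```python
def output_size(size, kernel, stride=1, padding=0):
    """Output length along one spatial axis of a convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_shapes(channels=1, height=28, width=28):
    """Trace (channels, height, width) through an assumed LeNet-style MNIST network."""
    # Assumed layer list (illustrative, not read from mnist.py):
    # (kind, output channels or None, kernel size, stride)
    layers = [
        ("conv", 20, 5, 1),
        ("pool", None, 2, 2),
        ("conv", 50, 5, 1),
        ("pool", None, 2, 2),
    ]
    c, h, w = channels, height, width
    shapes = []
    for kind, out_c, kernel, stride in layers:
        h = output_size(h, kernel, stride)
        w = output_size(w, kernel, stride)
        if kind == "conv":
            c = out_c  # a convolution sets the channel count; pooling keeps it
        shapes.append((kind, (c, h, w)))
    # The first fully connected layer sees the flattened feature map
    return shapes, c * h * w


shapes, fc_inputs = trace_shapes()
for kind, shape in shapes:
    print(kind, shape)
print("first fully connected layer inputs:", fc_inputs)  # 800
```

Under these assumptions the feature map shrinks from 28x28 to 24x24, 12x12, 8x8 and finally 4x4, so the first FullyConnected layer needs 50 * 4 * 4 = 800 inputs, and the weight matrix imported from PyTorch must have 800 columns.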
The sample.py script imports functions from the mnist.py script to train the PyTorch model and to retrieve test cases from the PyTorch DataLoader.
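The prediction reported for each test case comes from the ten output values of the network, one per digit. A minimal sketch of that last step, assuming the engine output has been copied into a flat list of ten scores; `predict_digit` is a hypothetical helper rather than a function from sample.py, and the scores are made-up values.

```python
def predict_digit(scores):
    """Return the index of the largest score, which is the predicted digit 0-9."""
    if len(scores) != 10:
        raise ValueError("expected 10 class scores, got %d" % len(scores))
    # max() over the indices, ranked by score; ties resolve to the lowest index
    return max(range(len(scores)), key=lambda i: scores[i])


# Made-up engine output for a test case whose label is 7
label = 7
scores = [0.01, 0.02, 0.00, 0.05, 0.01, 0.03, 0.00, 0.85, 0.02, 0.01]
pred = predict_digit(scores)
print("Test Case: %d Prediction: %d" % (label, pred))  # Test Case: 7 Prediction: 7
```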
This sample uses the following layers. For more information about these layers, see the TensorRT Developer Guide: Layers documentation.
Activation layer
The Activation layer implements element-wise activation functions. Specifically, this sample uses the Activation layer with the type RELU.
Convolution layer
The Convolution layer computes a 2D (channel, height, and width) convolution, with or without bias.
FullyConnected layer
The FullyConnected layer implements a matrix-vector product, with or without bias.
Pooling layer
The Pooling layer implements pooling within a channel. Supported pooling types are maximum, average and maximum-average blend.
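To make the layer descriptions above concrete, the following pure-Python sketch computes what each layer does on a tiny input. It illustrates the math only and says nothing about how TensorRT implements these layers. Assumptions: stride-1 convolution without padding, non-overlapping pooling windows, and a maximum-average blend of `(1 - blend) * max + blend * average`, which is our reading of TensorRT's blend factor; verify it against the TensorRT API reference.

```python
def conv2d(x, weights, bias):
    """Stride-1 convolution without padding (no kernel flip, as in PyTorch's Conv2d).

    x: [in_c][h][w], weights: [out_c][in_c][kh][kw], bias: [out_c].
    """
    in_c, h, w = len(x), len(x[0]), len(x[0][0])
    kh, kw = len(weights[0][0]), len(weights[0][0][0])
    out = []
    for oc, kernel in enumerate(weights):
        plane = []
        for i in range(h - kh + 1):
            row = []
            for j in range(w - kw + 1):
                acc = float(bias[oc])
                for c in range(in_c):
                    for di in range(kh):
                        for dj in range(kw):
                            acc += kernel[c][di][dj] * x[c][i + di][j + dj]
                row.append(acc)
            plane.append(row)
        out.append(plane)
    return out


def relu(x):
    """Element-wise max(0, v) over a nested list of any depth (Activation layer, RELU)."""
    if isinstance(x, list):
        return [relu(v) for v in x]
    return max(0.0, x)


def pool2d(x, size=2, mode="max", blend=0.5):
    """Pooling within each channel; windows do not overlap (stride == size)."""
    out = []
    for plane in x:
        pooled = []
        for i in range(0, len(plane) - size + 1, size):
            row = []
            for j in range(0, len(plane[0]) - size + 1, size):
                window = [plane[i + a][j + b] for a in range(size) for b in range(size)]
                mx, avg = max(window), sum(window) / len(window)
                if mode == "max":
                    row.append(mx)
                elif mode == "average":
                    row.append(avg)
                elif mode == "blend":
                    # Assumed blend formula (see the note above the code)
                    row.append((1 - blend) * mx + blend * avg)
                else:
                    raise ValueError("unknown pooling mode: " + mode)
            pooled.append(row)
        out.append(pooled)
    return out


def fully_connected(x, weights, bias):
    """Matrix-vector product y = W @ flatten(x) + b (FullyConnected layer)."""
    flat = [v for plane in x for row in plane for v in row]
    if any(len(w_row) != len(flat) for w_row in weights):
        raise ValueError("each weight row must have one entry per input value")
    return [sum(wi * xi for wi, xi in zip(w_row, flat)) + b
            for w_row, b in zip(weights, bias)]


# Tiny example: one 4x4 channel, one 2x2 kernel of ones with bias -4
x = [[[1, 2, 0, 1],
      [0, 1, 3, 1],
      [2, 0, 1, 0],
      [1, 1, 0, 2]]]
conv = conv2d(x, weights=[[[[1, 1], [1, 1]]]], bias=[-4])
act = relu(conv)
pooled = pool2d(act, size=2, mode="max")
logits = fully_connected(pooled, weights=[[1.5], [-1.0]], bias=[0.5, 0.0])
print(conv)    # [[[0.0, 2.0, 1.0], [-1.0, 1.0, 1.0], [0.0, -2.0, -1.0]]]
print(pooled)  # [[[2.0]]]
print(logits)  # [3.5, -2.0]
```

The layers run in the same order as a typical convolutional block: convolution, RELU activation, pooling, and finally a fully connected layer that maps the flattened features to output scores.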
Prerequisites
- Install the dependencies for Python.
python3 -m pip install -r requirements.txt
NOTE: On PowerPC systems, you will need to manually install PyTorch using IBM's PowerAI.
Running the sample
- Run the sample to create a TensorRT inference engine and run inference:
python sample.py
- Verify that the sample ran successfully. If it did, the prediction matches the test case label, for example:
Test Case: 0 Prediction: 0
To see the full list of available options and their descriptions, use the -h or --help command line option.
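If you run the sample from a script, for example in a CI job, you can check the result by scanning its output for lines in the `Test Case: N Prediction: M` format shown above. A minimal sketch, assuming the sample prints that line to standard output and is launched with `python sample.py` from the sample directory; `run_sample` and `all_predictions_match` are hypothetical helper names.

```python
import re
import subprocess

# Matches result lines such as "Test Case: 0 Prediction: 0"
RESULT_RE = re.compile(r"Test Case:\s*(\d+)\s+Prediction:\s*(\d+)")


def all_predictions_match(output):
    """True if the output has at least one result line and every label equals its prediction."""
    pairs = RESULT_RE.findall(output)
    return bool(pairs) and all(label == pred for label, pred in pairs)


def run_sample(python="python"):
    """Run sample.py in the current directory and return its standard output."""
    completed = subprocess.run(
        [python, "sample.py"],
        capture_output=True,
        text=True,
        check=True,  # raise CalledProcessError if the sample exits with an error
    )
    return completed.stdout
```

For example, `all_predictions_match(run_sample())` returns `True` when the sample's prediction matches its test case. Requiring at least one match guards against a run that crashed before printing any result.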
Additional resources
The following resources provide more information about getting started with TensorRT using Python:
Model
Dataset
Documentation
- Introduction To NVIDIA’s TensorRT Samples
- Working With TensorRT Using The Python API
- NVIDIA’s TensorRT Documentation Library
License
For terms and conditions for use, reproduction, and distribution, see the TensorRT Software License Agreement documentation.
Changelog
February 2019
This README.md file was recreated, updated and reviewed.
Known issues
There are no known issues in this sample.