diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 000000000..fb9b3b0fd
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,9 @@
+[submodule "tftrt/examples/third_party/models"]
+    path = tftrt/examples/third_party/models
+    url = https://github.com/tensorflow/models.git
+[submodule "tftrt/examples/third_party/cocoapi"]
+    path = tftrt/examples/third_party/cocoapi
+    url = https://github.com/cocodataset/cocoapi.git
+[submodule "tftrt/examples/third_party/DeepLearningExamples"]
+    path = tftrt/examples/third_party/DeepLearningExamples
+    url = https://github.com/NVIDIA/DeepLearningExamples.git
diff --git a/README.md b/README.md
index 57e3269b2..95272a1ef 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,85 @@
-Coming soon: Examples using [NVIDIA TensorRT](https://developer.nvidia.com/tensorrt) in TensorFlow.
+# Documentation for TensorRT in TensorFlow (TF-TRT)
+
+The documentation on how to accelerate inference in TensorFlow with TensorRT (TF-TRT) is here: https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html
+
+# Examples for TensorRT in TensorFlow (TF-TRT)
+
+This repository contains a number of examples that show how to use
+[TF-TRT](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/tensorrt).
+TF-TRT is a part of TensorFlow that optimizes TensorFlow graphs using
+[TensorRT](https://developer.nvidia.com/tensorrt).
+We have used these examples to verify the accuracy and performance of TF-TRT.
+For more information see
+[Verified Models](https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html#verified-models).
+
+## Examples
+
+* [Image Classification](tftrt/examples/image-classification)
+* [Object Detection](tftrt/examples/object_detection)
+
+
+# Using TensorRT in TensorFlow (TF-TRT)
+
+This module provides the necessary bindings and introduces the
+`TRTEngineOp` operator that wraps a subgraph in TensorRT.
+This module is under active development.
+
+
+## Installing TF-TRT
+
+Currently, TensorFlow nightly builds include TF-TRT by default,
+which means you don't need to install TF-TRT separately.
+You can pull the latest TensorFlow containers from Docker Hub or
+install the latest TensorFlow pip package to get access to the latest TF-TRT.
+
+If you want to use TF-TRT on the NVIDIA Jetson platform, you can find
+the download links for the relevant TensorFlow pip packages here:
+https://docs.nvidia.com/deeplearning/dgx/index.html#installing-frameworks-for-jetson
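+
+As a quick sanity check, you can confirm that your TensorFlow build ships with
+TF-TRT before running the examples. This is a minimal sketch that assumes a
+TF 1.x build where TF-TRT is exposed under `tensorflow.contrib.tensorrt`:
+
+```python
+# Minimal sketch: verify that this TensorFlow build includes TF-TRT.
+# Assumes a TF 1.x build exposing TF-TRT as tensorflow.contrib.tensorrt;
+# the import fails if TensorFlow was built without TensorRT support.
+import tensorflow as tf
+import tensorflow.contrib.tensorrt as trt
+
+print('TensorFlow version:', tf.__version__)
+print('TF-TRT conversion available:', hasattr(trt, 'create_inference_graph'))
+```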
+
+
+## Installing TensorRT
+
+In order to make use of TF-TRT, you will need a local installation
+of TensorRT from the
+[NVIDIA Developer website](https://developer.nvidia.com/tensorrt).
+Installation instructions for compatibility with TensorFlow are provided in the
+[TensorFlow GPU support](https://www.tensorflow.org/install/gpu) guide.
+
+
+## Documentation
+
+The [TF-TRT documentation](https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html)
+gives an overview of the supported functionality, provides tutorials
+and verified models, and explains best practices and troubleshooting.
+
+
+## Tests
+
+TF-TRT includes both Python tests and C++ unit tests.
+Most of the Python tests are located in the test directory
+and can be executed using `bazel test` or directly
+with the Python command. Most of the C++ unit tests are
+used to test the conversion functions that convert each TF op to
+a number of TensorRT layers.
+
+
+## Compilation
+
+In order to compile the module, you need to have a local TensorRT installation
+(libnvinfer.so and the respective include files). During the configuration step,
+TensorRT should be enabled and the installation path should be set. If TensorRT
+was installed through a package manager (deb, rpm), the configure script should
+find the necessary components on the system automatically. If it was installed
+from a tar package, you have to point the configure script to the location where
+the library is installed.
+
+```shell
+bazel build --config=cuda --config=opt //tensorflow/tools/pip_package:build_pip_package
+bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/
+```
+
+
+## License
+
+[Apache License 2.0](LICENSE)
diff --git a/setup.py b/setup.py
new file mode 100644
index 000000000..a727bc46e
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+from setuptools import find_packages, setup
+
+setup(
+    name='tftrt',
+    version='0.0',
+    description='NVIDIA TensorRT integration in TensorFlow',
+    author='NVIDIA',
+    packages=find_packages(),
+    install_requires=['tqdm']
+)
diff --git a/tftrt/__init__.py b/tftrt/__init__.py
new file mode 100644
index 000000000..04285a017
--- /dev/null
+++ b/tftrt/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
diff --git a/tftrt/examples/__init__.py b/tftrt/examples/__init__.py
new file mode 100644
index 000000000..04285a017
--- /dev/null
+++ b/tftrt/examples/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
diff --git a/tftrt/examples/image-classification/README.md b/tftrt/examples/image-classification/README.md
index d4b8d662f..9c69df1ee 100644
--- a/tftrt/examples/image-classification/README.md
+++ b/tftrt/examples/image-classification/README.md
@@ -1,24 +1,34 @@
-# TensorFlow-TensorRT Examples
+# Image classification example
 
-This script will run inference using a few popular image classification models
-on the ImageNet validation set.
-
-You can turn on TensorFlow-TensorRT integration with the flag `--use_trt`. This
-will apply TensorRT inference optimization to speed up execution for portions of
-the model's graph where supported, and will fall back to native TensorFlow for
-layers and operations which are not supported. See
-https://devblogs.nvidia.com/tensorrt-integration-speeds-tensorflow-inference/
+The example script `image_classification.py` runs inference using a number of
+popular image classification models. This script is included in the NVIDIA
+TensorFlow Docker containers under `/workspace/nvidia-examples`. See [Preparing
+To Use NVIDIA
+Containers](https://docs.nvidia.com/deeplearning/dgx/preparing-containers/index.html)
 for more information.
 
-When using TF-TRT, you can also control the precision with `--precision`.
-float32 is the default (`--precision fp32`) with float16 (`--precision fp16`) or
-int8 (`--precision int8`) allowing further performance improvements, at the cost
-of some accuracy. int8 mode requires a calibration step which is done
-automatically.
+You can enable TF-TRT integration by passing the `--use_trt` flag to the script.
+This causes the script to apply TensorRT inference optimization to speed up
+execution for portions of the model's graph where supported, and to fall back on
+native TensorFlow for layers and operations which are not supported. See
+[Accelerating Inference In TensorFlow With TensorRT User
+Guide](https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html) for
+more information.
+
+When TF-TRT integration is enabled, you can control the precision with the
+`--precision` option. float32 is the default (`--precision fp32`), with float16
+(`--precision fp16`) or int8 (`--precision int8`) allowing further performance
+improvements.
+
+int8 mode requires a calibration step (which is done automatically), but you
+also must specify the directory in which the calibration dataset is stored
+with `--calib_data_dir /imagenet_validation_data`. You can use the same data
+for both calibration and validation.
 
 ## Models
 
-This test supports the following models for image classification:
+We have verified the following models:
+
 * MobileNet v1
 * MobileNet v2
 * NASNet - Large
@@ -30,41 +40,217 @@ This test supports the following models for image classification:
 * Inception v3
 * Inception v4
 
+For the accuracy numbers of these models on the
+ImageNet validation dataset, see
+[Verified Models](https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html#verified-models).
+
 ## Setup
+
+### Setup for running within an NVIDIA TensorFlow Docker container
+
+If you are running these examples within the [NVIDIA TensorFlow Docker
+container](https://ngc.nvidia.com/catalog/containers/nvidia:tensorflow) under
+`/workspace/nvidia-examples/tensorrt/tftrt/examples/image-classification`, run
+the `install_dependencies.sh` setup script. Then skip below to the
+[Data](#Data) section.
+
+```
+cd /workspace/nvidia-examples/tensorrt/tftrt/examples/image-classification
+./install_dependencies.sh
+cd ../third_party/models
+export PYTHONPATH="$PYTHONPATH:$PWD"
 ```
-# Clone [tensorflow/models](https://github.com/tensorflow/models)
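+
+After running the setup script and setting `PYTHONPATH`, you can optionally
+check that the dependencies are importable. This is a minimal sketch; the
+module paths assume the `tensorflow/models` layout used by this example:
+
+```python
+# Minimal sketch: check that the example's Python dependencies resolve.
+# `nets` and `preprocessing` come from the TF-Slim install, and `official`
+# comes from the tensorflow/models checkout added to PYTHONPATH above.
+import nets.nets_factory
+import preprocessing.vgg_preprocessing
+import official.resnet.imagenet_main
+
+print('image classification dependencies look importable')
+```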
+
+### Setup for running standalone
+
+If you are running these examples within your own TensorFlow environment,
+perform the following steps:
+
+```
+# Clone this repository (tensorflow/tensorrt) if you haven't already.
+git clone https://github.com/tensorflow/tensorrt.git
+
+# Clone tensorflow/models.
 git clone https://github.com/tensorflow/models.git
 
 # Add the models directory to PYTHONPATH to install tensorflow/models.
 cd models
 export PYTHONPATH="$PYTHONPATH:$PWD"
 
-# Run the TF Slim setup.
+# Run the TensorFlow Slim setup.
 cd research/slim
 python setup.py install
 
-# You may also need to install the requests package
+# Install the requests package.
 pip install requests
 ```
 
-Note: the PYTHONPATH environment variable will be not be saved between different
-shells. You can either repeat that step each time you work in a new shell, or
-add `export PYTHONPATH="$PYTHONPATH:/path/to/tensorflow_models"` to your .bashrc
-file (replacing /path/to/tensorflow_models with the path to your
-tensorflow/models repository).
+
+### PYTHONPATH environment variable
+
+The `PYTHONPATH` environment variable is not saved between different shell
+sessions. To avoid having to set `PYTHONPATH` in each new shell session, you
+can add the following line to your `.bashrc` file:
+
+```export PYTHONPATH="$PYTHONPATH:/path/to/tensorflow_models"```
+
+replacing `/path/to/tensorflow_models` with the path to your `tensorflow/models`
+repository.
+
+Also see [Setting Up The Environment
+](https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html#image-class-envirn)
+for more information.
 
 ### Data
 
-The script supports only TFRecord format for data. The script
-assumes that validation TFRecords are named according to the pattern:
-`validation-*-of-00128`.
+The example script supports either a real dataset (TFRecord format in
+validation mode, JPEG images in benchmark mode) or autogenerated synthetic data
+(with the `--use_synthetic` flag). If you use TFRecord files, the script
+assumes that the TFRecords are named according to the pattern:
+`validation-*-of-00128`.
+
+Note: The reported accuracy numbers are the results of running the scripts on
+the ImageNet validation dataset.
+
+To download and process the ImageNet data, you can:
+
+- Use the scripts provided in the `nvidia-examples/build_imagenet_data`
+  directory under `/workspace` in the NVIDIA TensorFlow Docker container.
+  Follow the `README` file in that directory for instructions on how to use
+  these scripts.
+
+or
+
+- Use the scripts provided by TF Slim in the `tensorflow/models` repository at
+  `research/slim`. Consult the `README` file under `research/slim` for
+  instructions on how to use these scripts. Note that these scripts download
+  both the training and validation sets, and this example only requires the
+  validation set.
+
+Also see [Obtaining The ImageNet Data
+](https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html#image-class-data)
+for more information.
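+
+Once the data is in place, you can sanity-check the layout with a short
+snippet. This is a minimal sketch; the directory is an assumption taken from
+the Usage example below, so substitute wherever you stored the dataset:
+
+```python
+# Minimal sketch: confirm the validation TFRecords are where the script expects.
+import os
+import tensorflow as tf
+
+data_dir = '/data/imagenet/train-val-tfrecord'  # assumed location; adjust as needed
+validation_files = tf.gfile.Glob(os.path.join(data_dir, 'validation-*-of-00128'))
+print('found {} validation TFRecord files'.format(len(validation_files)))
+```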
+
+## Running the examples as a Jupyter notebook
+
+You can run the examples as a Jupyter notebook (`image_classification.ipynb`)
+from this directory:
+
+```
+jupyter notebook --ip=0.0.0.0
+```
 
-You can download and process Imagenet using [this script provided by TF
-Slim](https://github.com/tensorflow/models/blob/master/research/slim/datasets/download_imagenet.sh).
-Please note that this script downloads both the training and validation sets,
-and this example only requires the validation set.
+If you want to run these examples as a Jupyter notebook within an NVIDIA
+TensorFlow Docker container, first you need to run the container with the
+`--publish 0.0.0.0:8888:8888` option to publish Jupyter's port `8888` to the
+host machine at port `8888` over all network interfaces (`0.0.0.0`). Then you
+can use the following command in the
+`/workspace/nvidia-examples/tensorrt/tftrt/examples/image-classification`
+directory:
+
+```
+jupyter notebook --ip=0.0.0.0 --allow-root
+```
 
 ## Usage
 
-`python inference.py --data_dir /imagenet_validation_data --model vgg_16 [--use_trt]`
+The main Python script is `image_classification.py`. Assuming that the ImageNet
+validation data are located under `/data/imagenet/train-val-tfrecord`, you can
+evaluate inference with TF-TRT integration using the pre-trained ResNet V1 50
+model as follows:
+
+```
+python image_classification.py --model resnet_v1_50 \
+    --data_dir /data/imagenet/train-val-tfrecord \
+    --use_trt \
+    --precision fp16 \
+    --mode validation
+```
+
+Where:
+
+`--model`: Which model to use to run inference, in this case ResNet V1 50.
+
+`--data_dir`: Path to the ImageNet TFRecord validation files.
+
+`--use_trt`: Convert the graph to a TensorRT graph.
+
+`--precision`: Precision mode to use, in this case FP16.
+
+`--mode`: Which mode to use (validation or benchmark). In validation mode the
+script measures both accuracy and performance; in benchmark mode it measures
+performance only.
+
+Run with `--help` to see all available options.
+
+Also see [General Script Usage
+](https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html#image-class-usage)
+for more information.
+
+## Output
+
+The script first loads the pre-trained model. If given the flag `--use_trt`,
+the model is converted to a TensorRT graph, and the script displays (in addition
+to its initial configuration options):
+
+- the number of nodes before conversion (`num_nodes(native_tf)`)
+
+- the number of nodes after conversion (`num_nodes(trt_total)`)
+
+- the number of separate TensorRT nodes (`num_nodes(trt_only)`)
+
+- the size of the graph before conversion (`graph_size(MB)(native_tf)`)
+
+- the size of the graph after conversion (`graph_size(MB)(trt)`)
+
+- how long the conversion took (`time(s)(trt_conversion)`)
+
+For example:
+
+```
+num_nodes(native_tf): 741
+num_nodes(trt_total): 10
+num_nodes(trt_only): 1
+graph_size(MB)(native_tf): ***
+graph_size(MB)(trt): ***
+time(s)(trt_conversion): ***
+```
+
+Note: For a list of supported operations that can be converted to a TensorRT
+graph, see the [Supported
+Ops](https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html#support-ops)
+section of the [Accelerating Inference In TensorFlow With TensorRT User
+Guide](https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html).
+
+The script then begins running inference on the ImageNet validation set,
+displaying run times of each iteration after the interval defined by the
+`--display_every` option (default: `100`):
+
+```
+running inference...
+ step 100/6202, iter_time(ms)=**.****, images/sec=*** + step 200/6202, iter_time(ms)=**.****, images/sec=*** + step 300/6202, iter_time(ms)=**.****, images/sec=*** + ... +``` + +On completion, the script prints overall accuracy and timing information over +the inference session: + +``` +results of resnet_v1_50: + accuracy: 75.95 + images/sec: *** + 99th_percentile(ms): *** + total_time(s): *** + latency_mean(ms): *** +``` + +The accuracy metric measures the percentage of predictions from inference that +match the labels on the ImageNet Validation set. The remaining metrics capture +various performance measurements: + +- number of images processed per second (`images/sec`) + +- total time of the inference session (`total_time(s)`) + +- the mean duration for each iteration (`latency_mean(ms)`) + +- the slowest duration for an iteration (`99th_percentile(ms)`) diff --git a/tftrt/examples/image-classification/image_classification.ipynb b/tftrt/examples/image-classification/image_classification.ipynb new file mode 100644 index 000000000..40202dacd --- /dev/null +++ b/tftrt/examples/image-classification/image_classification.ipynb @@ -0,0 +1,417 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Image classification example\n", + "\n", + "This example script runs inference using a number of popular image classification models. This script is included in the NVIDIA TensorFlow Docker containers under `/workspace/nvidia-examples`. See [Preparing To Use NVIDIA Containers](https://docs.nvidia.com/deeplearning/dgx/preparing-containers/index.html) for more information.\n", + "\n", + "You can enable TF-TRT integration by passing the `--use_trt` flag to the script. This causes the script to apply TensorRT inference optimization to speed up execution for portions of the model's graph where supported, and to fall back on native TensorFlow for layers and operations which are not supported. See [Accelerating Inference In TensorFlow With TensorRT User Guide](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html) for more information. \n", + "\n", + "When using TF-TRT, you can use the precision option (`--precision`) to control precision. float32 is the default (`--precision fp32`) with float16 (`--precision fp16`) or int8 (`--precision int8`) allowing further performance improvements. \n", + "\n", + "int8 mode requires a calibration step (which is done automatically), but you also must specificy the directory in which the calibration dataset is stored with `--calib_data_dir /imagenet_validation_data`. You can use the same data for both calibration and validation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Models\n", + "\n", + "We have verified the following models.\n", + "\n", + "* MobileNet v1\n", + "* MobileNet v2\n", + "* NASNet - Large\n", + "* NASNet - Mobile\n", + "* ResNet50 v1\n", + "* ResNet50 v2\n", + "* VGG16\n", + "* VGG19\n", + "* Inception v3\n", + "* Inception v4\n", + "\n", + "For the accuracy numbers of these models on the ImageNet validation dataset, see [Verified Models](https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html#verified-models)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Usage\n", + "\n", + "The example Python script is `image_classification.py`. 
You can evaluate inference with TF-TRT integration using the pre-trained ResNet V1 50 model by calling the script with the following arguments:\n", + "\n", + "```\n", + "python image_classification.py --model resnet_v1_50 \\\n", + " --data_dir /path/to/imagenet/tfrecord/files \\\n", + " --use_trt \\\n", + " --precision fp16\n", + "```\n", + "\n", + "Where:\n", + "\n", + "`--model`: Which model to use to run inference, in this case ResNet V1 50.\n", + "\n", + "`--data_dir`: Path to the ImageNet TFRecord validation files.\n", + "\n", + "`--use_trt`: Convert the graph to a TensorRT graph.\n", + "\n", + "`--precision`: Precision mode to use, in this case FP16.\n", + "\n", + "Run with `--help` to see all available options.\n", + "\n", + "Note: In this notebook, we run the script inside IPython using the `%run` built-in command, so that realtime output and tracebacks are displayed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%run image_classification --help" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Also see [General Script Usage](https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html#image-class-usage) for more information." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Output\n", + "\n", + "The script first loads the pre-trained model. If given the flag `--use_trt`, the model is converted to a TensorRT graph, and the script displays (in addition to its inital configuration options):\n", + "\n", + "- the number of nodes before conversion (`num_nodes(native_tf)`)\n", + "\n", + "- the number of nodes after conversion (`num_nodes(trt_total)`)\n", + "\n", + "- the number of separate TensorRT nodes (`num_nodes(trt_only)`)\n", + "\n", + "- the size of the graph before conversion (`graph_size(MB)(native_tf)`)\n", + "\n", + "- the size of the graph after conversion (`graph_size(MB)(trt)`)\n", + "\n", + "- how long the conversion took (`time(s)(trt_conversion)`)\n", + "\n", + "For example:\n", + "\n", + "```\n", + "num_nodes(native_tf): 741\n", + "num_nodes(trt_total): 10\n", + "num_nodes(trt_only): 1\n", + "graph_size(MB)(native_tf): ***\n", + "graph_size(MB)(tft): ***\n", + "time(s)(trt_conversion): ***\n", + "```\n", + "\n", + "Note: For a list of supported operations that can be converted to a TensorRT graph, see the [Supported\n", + "Ops](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#support-ops) section of the [Accelerating Inference In TensorFlow With TensorRT User Guide](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html).\n", + "\n", + "The script then begins running inference on the ImageNet validation set, displaying run times of each iteration after the interval defined by the `--display_every` option (default: `100`):\n", + "\n", + "```\n", + "running inference...\n", + " step 100/6202, iter_time(ms)=**.****, images/sec=***\n", + " step 200/6202, iter_time(ms)=**.****, images/sec=***\n", + " step 300/6202, iter_time(ms)=**.****, images/sec=***\n", + " ...\n", + "```\n", + "\n", + "On completion, the script prints overall accuracy and timing information over the inference session:\n", + "\n", + "```\n", + "results of resnet_v1_50:\n", + " accuracy: 75.95\n", + " images/sec: ***\n", + " 99th_percentile(ms): ***\n", + " total_time(s): ***\n", + " latency_mean(ms): ***\n", + "```\n", + "\n", + "The accuracy metric measures the percentage of predictions from inference that match the labels on the ImageNet 
Validation set. The remaining metrics capture various performance measurements:\n", + "\n", + "- number of images processed per second (`images/sec`)\n", + "\n", + "- total time of the inference session (`total_time(s)`)\n", + "\n", + "- the mean duration for each iteration (`latency_mean(ms)`)\n", + "\n", + "- the slowest duration for an iteration (`99th_percentile(ms)`)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using TF-TRT With ResNet V1 50\n", + "\n", + "Here we walk through how to use the example Python scripts in the with the ResNet V1 50 model.\n", + "\n", + "Using TF-TRT with precision modes lower than FP32, that is, FP16 and INT8, improves the performance of inference. The FP16 precision mode uses Tensor Cores or half-precision hardware instructions, if possible, while the INT8 precision mode uses Tensor Cores or integer hardware instructions. INT8 mode also requires running a calibration step, which the script does automatically.\n", + "\n", + "Below we use the example script to compare the accuracy and timing performance of all the precision modes when running inference using the ResNet V1 50 model." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Native TensorFlow Using FP32\n", + "\n", + "This is our baseline session running inference using native TensorFlow without TensorRT integration/conversion.\n", + "\n", + "First, set `DATA_DIR` to where you stored the ImageNet TFRecord validation files:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "DATA_DIR = \"/path/to/imagenet/tfrecord/files\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can run the baseline session with native TensorFlow.\n", + "\n", + "Note: We use the `--cache` flag to allow the script to cache checkpoint and frozen graph files to use with future sessions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%run image_classification --model resnet_v1_50 \\\n", + " --data_dir $DATA_DIR \\\n", + " --cache" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Look for the accuracy and timing information under:\n", + "\n", + "```\n", + "results of resnet_v1_50:\n", + " ...\n", + "```\n", + "\n", + "You can compare the accuracy metrics for the ResNet 50 models with the metrics listed at: [Pre-trained model](https://github.com/tensorflow/models/tree/master/official/resnet#pre-trained-model)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### TF-TRT Using FP32\n", + "\n", + "In this session, we use the same precision mode as in our native TensorFlow session (FP32), but this time we use the `--use_trt` flag to convert the graph to a TensorRT optimized graph." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%run image_classification --model resnet_v1_50 \\\n", + " --data_dir $DATA_DIR \\\n", + " --use_trt \\\n", + " --cache" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Before the script starts running inference, it converts the TensorFlow graph to a TensorRT optimized graph with fewer nodes. 
Look for the following metrics in the log:\n", + "\n", + "```\n", + "num_nodes(native_tf): ***\n", + "num_nodes(tftrt_total): ***\n", + "num_nodes(trt_only): ***\n", + "graph_size(MB)(native_tf): ***\n", + "graph_size(MB)(tft): ***\n", + "...\n", + "time(s)(trt_conversion): ***\n", + "```\n", + "\n", + "Note: For a list of supported operations that can be converted to a TensorRT graph, see [Supported Ops](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#support-ops).\n", + "\n", + "Again, note the accuracy and timing information under:\n", + "\n", + "```\n", + "results of resnet_v1_50:\n", + " ...\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### TF-TRT Using FP16\n", + "\n", + "In this session, we continue to use TF-TRT conversion, but we reduce the precision mode to FP16, allowing the use of Tensor Cores for performance improvements during inference, while preserving accuracy within the acceptable tolerance level (0.1%)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%run image_classification --model resnet_v1_50 \\\n", + " --data_dir $DATA_DIR \\\n", + " --use_trt \\\n", + " --precision fp16 \\\n", + " --cache" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Again, we see that the native TensorFlow graph gets converted to a TensorRT graph. Look again for the following in the log to confirm:\n", + "\n", + "```\n", + "num_nodes(native_tf): ***\n", + "num_nodes(tftrt_total): ***\n", + "num_nodes(trt_only): ***\n", + "graph_size(MB)(native_tf): ***\n", + "graph_size(MB)(tft): ***\n", + "...\n", + "time(s)(trt_conversion): ***\n", + "```\n", + "\n", + "Compare the results with the previous sessions:\n", + "\n", + "```\n", + "results of resnet_v1_50:\n", + " ...\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### TF-TRT Using INT8\n", + "\n", + "For this session we continue to use TF-TRT conversion, and we reduce the precision further to INT8 for faster computation. Because INT8 has significantly lower precision and dynamic range than FP32, the INT8 precision mode requires an additional calibration step before performing the type conversion. In this calibration step, inference is first run with FP32 precision on a calibration dataset to generate many INT8 quantizations of the weights and activations in the trained TensorFlow graph, from which are chosen the INT8 quantizations that minimize information loss. For more details on the calibration process, see the [8-bit Inference with TensorRT presentation](http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf).\n", + "\n", + "The calibration dataset should closely reflect the distribution of the problem dataset. In this walkthrough, we use the same ImageNet validation set training data for the calibration data, with `--calib_data_dir $DATA_DIR`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%run image_classification --model resnet_v1_50 \\\n", + " --data_dir $DATA_DIR \\\n", + " --use_trt \\\n", + " --precision int8 \\\n", + " --calib_data_dir $DATA_DIR \\\n", + " --cache" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This time, we see the script performing the calibration step:\n", + "\n", + "```\n", + "Calibrating INT8...\n", + "...\n", + "INFO:tensorflow:Evaluation [6/62]\n", + "INFO:tensorflow:Evaluation [12/62]\n", + "INFO:tensorflow:Evaluation [18/62]\n", + "...\n", + "```\n", + "\n", + "The process completes with the message:\n", + "\n", + "```\n", + "INT8 graph created.\n", + "```\n", + "\n", + "When the calibration step completes -- it may take some time -- we again see that the native TensorFlow graph gets converted to a TensorRT graph. Look again for the following in the log to confirm:\n", + "\n", + "```\n", + "num_nodes(native_tf): ***\n", + "num_nodes(tftrt_total): ***\n", + "num_nodes(trt_only): ***\n", + "graph_size(MB)(native_tf): ***\n", + "graph_size(MB)(tft): ***\n", + "...\n", + "time(s)(trt_conversion): ***\n", + "```\n", + "\n", + "Also notice the following INT8-specific timing information:\n", + "\n", + "```\n", + "time(s)(trt_calibration): ***\n", + "...\n", + "time(s)(trt_int8_conversion): ***\n", + "```\n", + "\n", + "Compare the results with the previous sessions:\n", + "\n", + "```\n", + "results of resnet_v1_50:\n", + " ...\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "Congratulations! You have run inference with an image classification model using various modes of precision and taking advantge of TensorRT inference optimization where possible." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index 0090b802a..e70fcad23 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -38,13 +38,12 @@ def __init__(self, batch_size, num_records, display_every): self.num_steps = (num_records + batch_size - 1) / batch_size self.batch_size = batch_size - def begin(self): + def before_run(self, run_context): self.start_time = time.time() def after_run(self, run_context, run_values): current_time = time.time() duration = current_time - self.start_time - self.start_time = current_time self.iter_times.append(duration) current_step = len(self.iter_times) if current_step % self.display_every == 0: @@ -52,19 +51,53 @@ def after_run(self, run_context, run_values): current_step, self.num_steps, duration * 1000, self.batch_size / self.iter_times[-1])) -def run(frozen_graph, model, data_dir, batch_size, - num_iterations, num_warmup_iterations, use_synthetic, display_every=100): +class BenchmarkHook(tf.train.SessionRunHook): + """Limits run duration and number of iterations""" + def __init__(self, target_duration=None, iteration_limit=None): + self.target_duration = target_duration + self.start_time = None + self.current_iteration = 0 + self.iteration_limit = iteration_limit + + def before_run(self, run_context): + if not self.start_time: + self.start_time = time.time() + if self.target_duration: + print(" running for target duration {} seconds from {}".format( + self.target_duration, time.asctime(time.localtime(self.start_time)))) + + def after_run(self, run_context, run_values): + if self.target_duration: + current_time = time.time() + if (current_time - self.start_time) > self.target_duration: + print(" target duration {} reached at {}, requesting stop".format( + self.target_duration, time.asctime(time.localtime(current_time)))) + run_context.request_stop() + + if self.iteration_limit: + self.current_iteration += 1 + if self.current_iteration >= self.iteration_limit: + run_context.request_stop() + +def run(frozen_graph, model, data_files, batch_size, + num_iterations, num_warmup_iterations, use_synthetic, display_every=100, + mode='validation', target_duration=None): """Evaluates a frozen graph - + This function evaluates a graph on the ImageNet validation set. tf.estimator.Estimator is used to evaluate the accuracy of the model and a few other metrics. The results are returned as a dict. 
     frozen_graph: GraphDef, a graph containing input node 'input' and outputs 'logits' and 'classes'
     model: string, the model name (see NETS table in graph.py)
-    data_dir: str, directory containing ImageNet validation TFRecord files
+    data_files: List of TFRecord files used for inference
     batch_size: int, batch size for TensorRT optimizations
     num_iterations: int, number of iterations(batches) to run for
+    num_warmup_iterations: int, number of iterations (batches) to exclude from benchmark measurements
+    use_synthetic: bool, if True, run with synthetic data instead of real data
+    display_every: int, print a log line every `display_every` iterations
+    mode: validation - uses estimator.evaluate with accuracy measurements,
+          benchmark - uses estimator.predict
     """
     # Define model function for tf.estimator.Estimator
     def model_fn(features, labels, mode):
@@ -72,17 +105,19 @@ def model_fn(features, labels, mode):
             input_map={'input': features},
             return_elements=['logits:0', 'classes:0'],
             name='')
-        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits_out)
-        accuracy = tf.metrics.accuracy(labels=labels, predictions=classes_out, name='acc_op')
+        if mode == tf.estimator.ModeKeys.PREDICT:
+            return tf.estimator.EstimatorSpec(mode=mode,
+                predictions={'classes': classes_out})
         if mode == tf.estimator.ModeKeys.EVAL:
+            loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits_out)
+            accuracy = tf.metrics.accuracy(labels=labels, predictions=classes_out, name='acc_op')
             return tf.estimator.EstimatorSpec(
                 mode, loss=loss, eval_metric_ops={'accuracy': accuracy})
 
-    # Create the dataset
-    preprocess_fn = get_preprocess_fn(model)
-    validation_files = tf.gfile.Glob(os.path.join(data_dir, 'validation*'))
+    # Preprocessing function for input data
+    preprocess_fn = get_preprocess_fn(model, mode)
 
     def get_tfrecords_count(files):
         num_records = 0
@@ -92,53 +127,85 @@ def get_tfrecords_count(files):
         return num_records
 
     # Define the dataset input function for tf.estimator.Estimator
-    def eval_input_fn():
+    def input_fn():
         if use_synthetic:
             input_width, input_height = get_netdef(model).get_input_dims()
             features = np.random.normal(
                 loc=112, scale=70,
                 size=(batch_size, input_height, input_width, 3)).astype(np.float32)
             features = np.clip(features, 0.0, 255.0)
-            features = tf.identity(tf.constant(features))
             labels = np.random.randint(
                 low=0,
                 high=get_netdef(model).get_num_classes(),
                 size=(batch_size),
                 dtype=np.int32)
-            labels = tf.identity(tf.constant(labels))
+            with tf.device('/device:GPU:0'):
+                features = tf.identity(tf.constant(features))
+                labels = tf.identity(tf.constant(labels))
         else:
-            dataset = tf.data.TFRecordDataset(validation_files)
-            dataset = dataset.apply(tf.contrib.data.map_and_batch(map_func=preprocess_fn, batch_size=batch_size, num_parallel_calls=8))
-            dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
-            dataset = dataset.repeat(count=1)
-            iterator = dataset.make_one_shot_iterator()
-            features, labels = iterator.get_next()
+            if mode == 'validation':
+                dataset = tf.data.TFRecordDataset(data_files)
+                dataset = dataset.apply(tf.contrib.data.map_and_batch(map_func=preprocess_fn, batch_size=batch_size, num_parallel_calls=8))
+                dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
+                dataset = dataset.repeat(count=1)
+                iterator = dataset.make_one_shot_iterator()
+                features, labels = iterator.get_next()
+            elif mode == 'benchmark':
+                dataset = tf.data.Dataset.from_tensor_slices(data_files)
+                dataset = dataset.apply(tf.contrib.data.map_and_batch(map_func=preprocess_fn,
batch_size=batch_size, num_parallel_calls=8)) + dataset = dataset.repeat(count=1) + iterator = dataset.make_one_shot_iterator() + features = iterator.get_next() + labels = np.random.randint( + low=0, + high=get_netdef(model).get_num_classes(), + size=(batch_size), + dtype=np.int32) + labels = tf.identity(tf.constant(labels)) + else: + raise ValueError("Mode must be either 'validation' or 'benchmark'") return features, labels # Evaluate model + if use_synthetic: + num_records = num_iterations * batch_size + elif mode == 'validation': + num_records = get_tfrecords_count(data_files) + elif mode == 'benchmark': + num_records = len(data_files) + else: + raise ValueError("Mode must be either 'validation' or 'benchmark'") logger = LoggerHook( display_every=display_every, batch_size=batch_size, - num_records=get_tfrecords_count(validation_files)) + num_records=num_records) tf_config = tf.ConfigProto() tf_config.gpu_options.allow_growth = True estimator = tf.estimator.Estimator( model_fn=model_fn, config=tf.estimator.RunConfig(session_config=tf_config), model_dir='model_dir') - results = estimator.evaluate(eval_input_fn, steps=num_iterations, hooks=[logger]) - + results = {} + if mode == 'validation': + results = estimator.evaluate(input_fn, steps=num_iterations, hooks=[logger]) + elif mode == 'benchmark': + benchmark_hook = BenchmarkHook(target_duration=target_duration, iteration_limit=num_iterations) + prediction_results = [p for p in estimator.predict(input_fn, predict_keys=["classes"], hooks=[logger, benchmark_hook])] + else: + raise ValueError("Mode must be either 'validation' or 'benchmark'") # Gather additional results iter_times = np.array(logger.iter_times[num_warmup_iterations:]) results['total_time'] = np.sum(iter_times) results['images_per_sec'] = np.mean(batch_size / iter_times) results['99th_percentile'] = np.percentile(iter_times, q=99, interpolation='lower') * 1000 results['latency_mean'] = np.mean(iter_times) * 1000 + results['latency_median'] = np.median(iter_times) * 1000 + results['latency_min'] = np.min(iter_times) * 1000 return results class NetDef(object): """Contains definition of a model - + name: Name of model url: (optional) Where to download archive containing checkpoint model_dir_in_archive: (optional) Subdirectory in archive containing @@ -176,11 +243,13 @@ def get_input_dims(self): def get_num_classes(self): return self.num_classes + def get_url(self): + return self.url + + def get_netdef(model): - """ - Creates the dictionary NETS with model names as keys and NetDef as values. + """Creates the dictionary NETS with model names as keys and NetDef as values. Returns the NetDef corresponding to the model specified in the parameter. 
- model: string, the model name (see NETS table) """ NETS = { @@ -203,16 +272,16 @@ def get_netdef(model): 'resnet_v1_50': NetDef( name='resnet_v1_50', - url='http://download.tensorflow.org/models/official/20180601_resnet_v1_imagenet_checkpoint.tar.gz', - model_dir_in_archive='20180601_resnet_v1_imagenet_checkpoint', + url='http://download.tensorflow.org/models/official/20181001_resnet/checkpoints/resnet_imagenet_v1_fp32_20181001.tar.gz', + model_dir_in_archive='resnet_imagenet_v1_fp32_20181001', slim=False, preprocess='vgg', model_fn=official.resnet.imagenet_main.ImagenetModel(resnet_size=50, resnet_version=1)), 'resnet_v2_50': NetDef( name='resnet_v2_50', - url='http://download.tensorflow.org/models/official/20180601_resnet_v2_imagenet_checkpoint.tar.gz', - model_dir_in_archive='20180601_resnet_v2_imagenet_checkpoint', + url='http://download.tensorflow.org/models/official/20181001_resnet/checkpoints/resnet_imagenet_v2_fp32_20181001.tar.gz', + model_dir_in_archive='resnet_imagenet_v2_fp32_20181001', slim=False, preprocess='vgg', model_fn=official.resnet.imagenet_main.ImagenetModel(resnet_size=50, resnet_version=2)), @@ -267,14 +336,14 @@ def deserialize_image_record(record): text = obj['image/class/text'] return imgdata, label, bbox, text -def get_preprocess_fn(model, mode='classification'): +def get_preprocess_fn(model, mode='validation'): """Creates a function to parse and process a TFRecord using the model's parameters model: string, the model name (see NETS table) - mode: string, whether the model is for classification or detection + mode: string, which mode to use (validation or benchmark) returns: function, the preprocessing function for a record """ - def process(record): + def validation_process(record): # Parse TFRecord imgdata, label, bbox, text = deserialize_image_record(record) label -= 1 # Change to 0-based (don't use background class) @@ -285,7 +354,22 @@ def process(record): image = netdef.preprocess(image, netdef.input_height, netdef.input_width, is_training=False) return image, label - return process + def benchmark_process(path): + image = tf.read_file(path) + image = tf.image.decode_jpeg(image, channels=3) + net_def = get_netdef(model) + input_width, input_height = net_def.get_input_dims() + image = net_def.preprocess(image, input_width, input_height, is_training=False) + return image + + if mode == 'validation': + return validation_process + elif mode == 'benchmark': + return benchmark_process + else: + raise ValueError("Mode must be either 'validation' or 'benchmark'") + + def build_classification_graph(model, model_dir=None, default_models_dir='./data'): """Builds an image classification model by name @@ -372,7 +456,7 @@ def get_checkpoint(model, model_dir=None, default_models_dir='.'): if get_netdef(model).url: download_checkpoint(model, model_dir) return find_checkpoint_in_dir(model_dir) - + print('No model_dir was provided and the model does not define a download' \ ' URL.') exit(1) @@ -399,7 +483,17 @@ def find_checkpoint_in_dir(model_dir): checkpoint_path = '.'.join(parts[:ckpt_index+1]) return checkpoint_path + def download_checkpoint(model, destination_path): + #copy files from source to destination (without any directories) + def copy_files(source, destination): + try: + shutil.copy2(source, destination) + except OSError as e: + pass + except shutil.Error as e: + pass + # Make directories if they don't exist. 
if not os.path.exists(destination_path): os.makedirs(destination_path) @@ -407,7 +501,7 @@ def download_checkpoint(model, destination_path): archive_path = os.path.join(destination_path, os.path.basename(get_netdef(model).url)) if not os.path.isfile(archive_path): - subprocess.call(['wget', '--no-check-certificate', + subprocess.call(['wget', '--no-check-certificate', '-q', get_netdef(model).url, '-O', archive_path]) # Extract. subprocess.call(['tar', '-xzf', archive_path, '-C', destination_path]) @@ -417,21 +511,23 @@ def download_checkpoint(model, destination_path): get_netdef(model).model_dir_in_archive, '*') for f in glob.glob(source_files): - shutil.copy2(f, destination_path) + copy_files(f, destination_path) def get_frozen_graph( model, model_dir=None, use_trt=False, + engine_dir=None, use_dynamic_op=False, precision='fp32', batch_size=8, minimum_segment_size=2, - calib_data_dir=None, + calib_files=None, num_calib_inputs=None, use_synthetic=False, cache=False, - default_models_dir='./data'): + default_models_dir='./data', + max_workspace_size=(1<<32)): """Retreives a frozen GraphDef from model definitions in classification.py and applies TF-TRT model: str, the model name (see NETS table in classification.py) @@ -442,6 +538,7 @@ def get_frozen_graph( """ num_nodes = {} times = {} + graph_sizes = {} # Load from pb file if frozen graph was already created and cached if cache: @@ -456,11 +553,13 @@ def get_frozen_graph( times['loading_frozen_graph'] = time.time() - start_time num_nodes['loaded_frozen_graph'] = len(frozen_graph.node) num_nodes['trt_only'] = len([1 for n in frozen_graph.node if str(n.op)=='TRTEngineOp']) - return frozen_graph, num_nodes, times + graph_sizes['loaded_frozen_graph'] = len(frozen_graph.SerializeToString()) + return frozen_graph, num_nodes, times, graph_sizes # Build graph and load weights frozen_graph = build_classification_graph(model, model_dir, default_models_dir) num_nodes['native_tf'] = len(frozen_graph.node) + graph_sizes['native_tf'] = len(frozen_graph.SerializeToString()) # Convert to TensorRT graph if use_trt: @@ -469,27 +568,41 @@ def get_frozen_graph( input_graph_def=frozen_graph, outputs=['logits', 'classes'], max_batch_size=batch_size, - max_workspace_size_bytes=(4096<<20)-1000, - precision_mode=precision, + max_workspace_size_bytes=max_workspace_size, + precision_mode=precision.upper(), minimum_segment_size=minimum_segment_size, is_dynamic_op=use_dynamic_op ) times['trt_conversion'] = time.time() - start_time num_nodes['tftrt_total'] = len(frozen_graph.node) num_nodes['trt_only'] = len([1 for n in frozen_graph.node if str(n.op)=='TRTEngineOp']) + graph_sizes['trt'] = len(frozen_graph.SerializeToString()) + + if engine_dir: + segment_number = 0 + for node in frozen_graph.node: + if node.op == "TRTEngineOp": + engine = node.attr["serialized_segment"].s + engine_path = engine_dir+'/{}_{}_{}_segment{}.trtengine'.format(model, precision, batch_size, segment_number) + segment_number += 1 + with open(engine_path, "wb") as f: + f.write(engine) if precision == 'int8': calib_graph = frozen_graph + graph_sizes['calib'] = len(calib_graph.SerializeToString()) # INT8 calibration step print('Calibrating INT8...') start_time = time.time() - run(calib_graph, model, calib_data_dir, batch_size, - num_calib_inputs // batch_size, 0, False) + run(calib_graph, model, calib_files, batch_size, + num_calib_inputs // batch_size, 0, use_synthetic=use_synthetic) times['trt_calibration'] = time.time() - start_time start_time = time.time() frozen_graph = 
trt.calib_graph_to_infer_graph(calib_graph) times['trt_int8_conversion'] = time.time() - start_time + # This is already set but overwriting it here to ensure the right size + graph_sizes['trt'] = len(frozen_graph.SerializeToString()) del calib_graph print('INT8 graph created.') @@ -506,7 +619,7 @@ def get_frozen_graph( f.write(frozen_graph.SerializeToString()) times['saving_frozen_graph'] = time.time() - start_time - return frozen_graph, num_nodes, times + return frozen_graph, num_nodes, times, graph_sizes if __name__ == '__main__': parser = argparse.ArgumentParser(description='Evaluate model') @@ -515,22 +628,25 @@ def get_frozen_graph( 'resnet_v1_50', 'resnet_v2_50', 'resnet_v2_152', 'vgg_16', 'vgg_19', 'inception_v3', 'inception_v4'], help='Which model to use.') - parser.add_argument('--data_dir', type=str, required=True, + parser.add_argument('--data_dir', type=str, default=None, help='Directory containing validation set TFRecord files.') parser.add_argument('--calib_data_dir', type=str, help='Directory containing TFRecord files for calibrating int8.') parser.add_argument('--model_dir', type=str, default=None, help='Directory containing model checkpoint. If not provided, a ' \ 'checkpoint may be downloaded automatically and stored in ' \ - '"{--default_models_dir}/{--model}" for future use.') + '"{--default_models_dir}/{--model}" for future use.') parser.add_argument('--default_models_dir', type=str, default='./data', help='Directory where downloaded model checkpoints will be stored and ' \ 'loaded from if --model_dir is not provided.') parser.add_argument('--use_trt', action='store_true', help='If set, the graph will be converted to a TensorRT graph.') + parser.add_argument('--engine_dir', type=str, default=None, + help='Directory where to write trt engines. Engines are written only if the directory ' \ + 'is provided. This option is ignored when not using tf_trt.') parser.add_argument('--use_trt_dynamic_op', action='store_true', help='If set, TRT conversion will be done using dynamic op instead of statically.') - parser.add_argument('--precision', type=str, choices=['fp32', 'fp16', 'int8'], default='fp32', + parser.add_argument('--precision', type=str, choices=['FP32', 'FP16', 'INT8'], default='FP32', help='Precision mode to use. FP16 and INT8 only work in conjunction with --use_trt') parser.add_argument('--batch_size', type=int, default=8, help='Number of images per batch.') @@ -547,13 +663,19 @@ def get_frozen_graph( parser.add_argument('--num_calib_inputs', type=int, default=500, help='Number of inputs (e.g. images) used for calibration ' '(last batch is skipped in case it is not full)') + parser.add_argument('--max_workspace_size', type=int, default=(1<<32), + help='workspace size in bytes') parser.add_argument('--cache', action='store_true', help='If set, graphs will be saved to disk after conversion. 
If a converted graph is present on disk, it will be loaded instead of building the graph again.') + parser.add_argument('--mode', choices=['validation', 'benchmark'], default='validation', + help='Which mode to use (validation or benchmark)') + parser.add_argument('--target_duration', type=int, default=None, + help='If set, script will run for specified number of seconds.') args = parser.parse_args() if args.precision != 'fp32' and not args.use_trt: raise ValueError('TensorRT must be enabled for fp16 or int8 modes (--use_trt).') - if args.precision == 'int8' and not args.calib_data_dir: + if args.precision == 'int8' and not args.calib_data_dir and not args.use_synthetic: raise ValueError('--calib_data_dir is required for int8 mode') if args.num_iterations is not None and args.num_iterations <= args.num_warmup_iterations: raise ValueError('--num_iterations must be larger than --num_warmup_iterations ' @@ -561,28 +683,59 @@ def get_frozen_graph( if args.num_calib_inputs < args.batch_size: raise ValueError('--num_calib_inputs must not be smaller than --batch_size' '({} <= {})'.format(args.num_calib_inputs, args.batch_size)) + if args.mode == 'validation' and args.use_synthetic: + raise ValueError('Cannot use both validation mode and synthetic dataset') + if args.data_dir is None and not args.use_synthetic: + raise ValueError("--data_dir required if you are not using synthetic data") + if args.use_synthetic and args.num_iterations is None: + raise ValueError("--num_iterations is required for --use_synthetic") + + def get_files(data_dir, filename_pattern): + if data_dir == None: + return [] + files = tf.gfile.Glob(os.path.join(data_dir, filename_pattern)) + if files == []: + raise ValueError('Can not find any files in {} with ' + 'pattern "{}"'.format(data_dir, filename_pattern)) + return files + + calib_files = [] + data_files = [] + if not args.use_synthetic: + if args.mode == "validation": + data_files = get_files(args.data_dir, 'validation*') + elif args.mode == "benchmark": + data_files = [os.path.join(path, name) for path, _, files in os.walk(args.data_dir) for name in files] + else: + raise ValueError("Mode must be either 'validation' or 'benchamark'") + calib_files = get_files(args.calib_data_dir, 'train*') - # Retreive graph using NETS table in graph.py - frozen_graph, num_nodes, times = get_frozen_graph( + frozen_graph, num_nodes, times, graph_sizes = get_frozen_graph( model=args.model, model_dir=args.model_dir, use_trt=args.use_trt, + engine_dir=args.engine_dir, use_dynamic_op=args.use_trt_dynamic_op, precision=args.precision, batch_size=args.batch_size, minimum_segment_size=args.minimum_segment_size, - calib_data_dir=args.calib_data_dir, + calib_files=calib_files, num_calib_inputs=args.num_calib_inputs, use_synthetic=args.use_synthetic, cache=args.cache, - default_models_dir=args.default_models_dir) + default_models_dir=args.default_models_dir, + max_workspace_size=args.max_workspace_size) - def print_dict(input_dict, str=''): + def print_dict(input_dict, str='', scale=None): for k, v in sorted(input_dict.items()): headline = '{}({}): '.format(str, k) if str else '{}: '.format(k) + v = v * scale if scale else v print('{}{}'.format(headline, '%.1f'%v if type(v)==float else v)) + print_dict(vars(args)) + print("url: " + get_netdef(args.model).get_url()) print_dict(num_nodes, str='num_nodes') + print_dict(graph_sizes, str='graph_size(MB)', scale=1./(1<<20)) print_dict(times, str='time(s)') # Evaluate model @@ -590,17 +743,22 @@ def print_dict(input_dict, str=''): results = run( 
frozen_graph, model=args.model, - data_dir=args.data_dir, + data_files=data_files, batch_size=args.batch_size, num_iterations=args.num_iterations, num_warmup_iterations=args.num_warmup_iterations, use_synthetic=args.use_synthetic, - display_every=args.display_every) + display_every=args.display_every, + mode=args.mode, + target_duration=args.target_duration) # Display results print('results of {}:'.format(args.model)) - print(' accuracy: %.2f' % (results['accuracy'] * 100)) + if args.mode == 'validation': + print(' accuracy: %.2f' % (results['accuracy'] * 100)) print(' images/sec: %d' % results['images_per_sec']) - print(' 99th_percentile(ms): %.1f' % results['99th_percentile']) + print(' 99th_percentile(ms): %.2f' % results['99th_percentile']) print(' total_time(s): %.1f' % results['total_time']) - print(' latency_mean(ms): %.1f' % results['latency_mean']) + print(' latency_mean(ms): %.2f' % results['latency_mean']) + print(' latency_median(ms): %.2f' % results['latency_median']) + print(' latency_min(ms): %.2f' % results['latency_min']) diff --git a/tftrt/examples/image-classification/install_dependencies.sh b/tftrt/examples/image-classification/install_dependencies.sh new file mode 100755 index 000000000..27fd5767f --- /dev/null +++ b/tftrt/examples/image-classification/install_dependencies.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +set +e + +TF_MODELS_DIR=$PWD/../third_party/models + +echo Install slim +pushd $TF_MODELS_DIR/research/slim +pip install . +popd + +echo Install requests +pip install requests + diff --git a/tftrt/examples/object_detection/README.md b/tftrt/examples/object_detection/README.md new file mode 100644 index 000000000..89dead0d2 --- /dev/null +++ b/tftrt/examples/object_detection/README.md @@ -0,0 +1,138 @@ +# TensorRT / TensorFlow Object Detection + +This package demonstrated object detection using TensorRT integration in TensorFlow. +It includes utilities for accuracy and performance benchmarking, along with +utilities for model construction and optimization. + +* [Setup](#setup) +* [Download](#od_download) +* [Optimize](#od_optimize) +* [Benchmark](#od_benchmark) +* [Test](#od_test) + + +## Setup + +1. Install object detection dependencies (from tftrt/examples/object_detection) + +```bash +git submodule update --init +./install_dependencies.sh +``` + +2. 
Ensure you've installed the tftrt package (from root folder of repository) + +```bash +python setup.py install --user +``` + + +## Object Detection + + +### Download +```python +from tftrt.examples.object_detection import download_model + +config_path, checkpoint_path = download_model('ssd_mobilenet_v1_coco', output_dir='models') +# help(download_model) for more +``` + + +### Optimize + +```python +from tftrt.examples.object_detection import optimize_model + +frozen_graph = optimize_model( + config_path=config_path, + checkpoint_path=checkpoint_path, + use_trt=True, + precision_mode='FP16' +) +# help(optimize_model) for other parameters +``` + + +### Benchmark + +First, we download the validation dataset + +```python +from tftrt.examples.object_detection import download_dataset + +images_dir, annotation_path = download_dataset('val2014', output_dir='dataset') +# help(download_dataset) for more +``` + +Next, we run inference over the dataset to benchmark the optimized model + +```python +from tftrt.examples.object_detection import benchmark_model + +statistics = benchmark_model( + frozen_graph=frozen_graph, + images_dir=images_dir, + annotation_path=annotation_path +) +# help(benchmark_model) for more parameters +``` + + +### Test +To simplify evaluation of different models with different optimization parameters +we include a ``test`` function that ingests a JSON file containing test arguments +and combines the model download, optimization, and benchmark steps. Below is an +example JSON file, call it ``my_test.json`` + +```json +{ + "source_model": { + "model_name": "ssd_inception_v2_coco", + "output_dir": "models" + }, + "optimization_config": { + "use_trt": true, + "precision_mode": "FP16", + "force_nms_cpu": true, + "replace_relu6": true, + "remove_assert": true, + "override_nms_score_threshold": 0.3, + "max_batch_size": 1 + }, + "benchmark_config": { + "images_dir": "coco/val2017", + "annotation_path": "coco/annotations/instances_val2017.json", + "batch_size": 1, + "image_shape": [600, 600], + "num_images": 4096, + "output_path": "stats/ssd_inception_v2_coco_trt_fp16.json" + }, + "assertions": [ + "statistics['map'] > (0.268 - 0.005)" + ] +} +``` + +We execute the test using the ``test`` python function + +```python +from tftrt.examples.object_detection import test + +test('my_test.json') +# help(test) for more details +``` + +Alternatively, we can directly call the object_detection.test module, which +is configured to execute this function by default. + +```shell +python -m tftrt.examples.object_detection.test my_test.json +``` + +For the example configuration shown above, the following steps will be performed + +1. Downloads ssd_inception_v2_coco +2. Optimizes with TensorRT and FP16 precision +3. Benchmarks against the MSCOCO 2017 validation dataset +4. Asserts that the MAP is greater than some reference value diff --git a/tftrt/examples/object_detection/__init__.py b/tftrt/examples/object_detection/__init__.py new file mode 100644 index 000000000..d7675e24e --- /dev/null +++ b/tftrt/examples/object_detection/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +from .object_detection import download_model, download_dataset, optimize_model, benchmark_model +from .test import test diff --git a/tftrt/examples/object_detection/graph_utils.py b/tftrt/examples/object_detection/graph_utils.py new file mode 100644 index 000000000..775127abb --- /dev/null +++ b/tftrt/examples/object_detection/graph_utils.py @@ -0,0 +1,108 @@ +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +import tensorflow as tf + + +def make_const6(const6_name='const6'): + graph = tf.Graph() + with graph.as_default(): + tf_6 = tf.constant(dtype=tf.float32, value=6.0, name=const6_name) + return graph.as_graph_def() + + +def make_relu6(output_name, input_name, const6_name='const6'): + graph = tf.Graph() + with graph.as_default(): + tf_x = tf.placeholder(tf.float32, [10, 10], name=input_name) + tf_6 = tf.constant(dtype=tf.float32, value=6.0, name=const6_name) + with tf.name_scope(output_name): + tf_y1 = tf.nn.relu(tf_x, name='relu1') + tf_y2 = tf.nn.relu(tf.subtract(tf_x, tf_6, name='sub1'), name='relu2') + + #tf_y = tf.nn.relu(tf.subtract(tf_6, tf.nn.relu(tf_x, name='relu1'), name='sub'), name='relu2') + #tf_y = tf.subtract(tf_6, tf_y, name=output_name) + tf_y = tf.subtract(tf_y1, tf_y2, name=output_name) + + graph_def = graph.as_graph_def() + graph_def.node[-1].name = output_name + + # remove unused nodes + for node in graph_def.node: + if node.name == input_name: + graph_def.node.remove(node) + for node in graph_def.node: + if node.name == const6_name: + graph_def.node.remove(node) + for node in graph_def.node: + if node.op == '_Neg': + node.op = 'Neg' + + return graph_def + + +def convert_relu6(graph_def, const6_name='const6'): + # add constant 6 + has_const6 = False + for node in graph_def.node: + if node.name == const6_name: + has_const6 = True + if not has_const6: + const6_graph_def = make_const6(const6_name=const6_name) + graph_def.node.extend(const6_graph_def.node) + + for node in graph_def.node: + if node.op == 'Relu6': + input_name = node.input[0] + output_name = node.name + relu6_graph_def = make_relu6(output_name, input_name, const6_name=const6_name) + graph_def.node.remove(node) + graph_def.node.extend(relu6_graph_def.node) + + return graph_def + + +def remove_node(graph_def, node): + for n in graph_def.node: + if node.name in n.input: + n.input.remove(node.name) + ctrl_name = '^' 
+ node.name + if ctrl_name in n.input: + n.input.remove(ctrl_name) + graph_def.node.remove(node) + + +def remove_op(graph_def, op_name): + matches = [node for node in graph_def.node if node.op == op_name] + for match in matches: + remove_node(graph_def, match) + + +def force_nms_cpu(frozen_graph): + for node in frozen_graph.node: + if 'NonMaxSuppression' in node.name: + node.device = '/device:CPU:0' + return frozen_graph + + +def replace_relu6(frozen_graph): + return convert_relu6(frozen_graph) + + +def remove_assert(frozen_graph): + remove_op(frozen_graph, 'Assert') + return frozen_graph diff --git a/tftrt/examples/object_detection/install_dependencies.sh b/tftrt/examples/object_detection/install_dependencies.sh new file mode 100755 index 000000000..8633f1939 --- /dev/null +++ b/tftrt/examples/object_detection/install_dependencies.sh @@ -0,0 +1,73 @@ +#!/bin/bash +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +echo Setup local variables... +TF_MODELS_DIR=../third_party/models +RESEARCH_DIR=$TF_MODELS_DIR/research +SLIM_DIR=$RESEARCH_DIR/slim +COCO_API_DIR=../third_party/cocoapi +PYCOCO_DIR=$COCO_API_DIR/PythonAPI +PROTO_BASE_URL="https://github.com/google/protobuf/releases/download/v3.5.1/" +PROTOC_DIR=$PWD/protoc + +#echo Install python-tk ... +#python -V 2>&1 | grep "Python 3" || \ +# ( export DEBIAN_FRONTEND=noninteractive && \ +# apt-get update && \ +# apt-get install -y --no-install-recommends python-tk ) + +set -v + +echo Download protobuf... +mkdir -p $PROTOC_DIR +pushd $PROTOC_DIR +ARCH=$(uname -m) +if [ "$ARCH" == "aarch64" ] ; then + filename="protoc-3.5.1-linux-aarch_64.zip" +elif [ "$ARCH" == "x86_64" ] ; then + filename="protoc-3.5.1-linux-x86_64.zip" +else + echo ERROR: $ARCH not supported. + exit 1; +fi +wget --no-check-certificate ${PROTO_BASE_URL}${filename} +unzip -o ${filename} +popd + +echo Compile object detection protobuf files... +pushd $RESEARCH_DIR +$PROTOC_DIR/bin/protoc object_detection/protos/*.proto --python_out=. +popd + +echo Install tensorflow/models/research... +pushd $RESEARCH_DIR +pip install . +popd + +echo Install tensorflow/models/research/slim... +pushd $SLIM_DIR +pip install . +popd + +echo Install cocodataset/cocoapi/PythonAPI... +pushd $PYCOCO_DIR +python setup.py build_ext --inplace +make +# pip install . +python setup.py install +popd diff --git a/tftrt/examples/object_detection/object_detection.py b/tftrt/examples/object_detection/object_detection.py new file mode 100644 index 000000000..c00a17784 --- /dev/null +++ b/tftrt/examples/object_detection/object_detection.py @@ -0,0 +1,687 @@ +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + + +from __future__ import absolute_import + +import tensorflow as tf +import tensorflow.contrib.tensorrt as trt +import pdb + +from collections import namedtuple +from PIL import Image +import numpy as np +import time +import json +import subprocess +import os +import glob + +from .graph_utils import force_nms_cpu as f_force_nms_cpu +from .graph_utils import replace_relu6 as f_replace_relu6 +from .graph_utils import remove_assert as f_remove_assert + +from google.protobuf import text_format +from object_detection.protos import pipeline_pb2, image_resizer_pb2 +from object_detection import exporter + +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval + +Model = namedtuple('Model', ['name', 'url', 'extract_dir']) + +INPUT_NAME = 'image_tensor' +BOXES_NAME = 'detection_boxes' +CLASSES_NAME = 'detection_classes' +SCORES_NAME = 'detection_scores' +MASKS_NAME = 'detection_masks' +NUM_DETECTIONS_NAME = 'num_detections' +FROZEN_GRAPH_NAME = 'frozen_inference_graph.pb' +PIPELINE_CONFIG_NAME = 'pipeline.config' +CHECKPOINT_PREFIX = 'model.ckpt' + +MODELS = { + 'ssd_mobilenet_v1_coco': + Model( + 'ssd_mobilenet_v1_coco', + 'http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz', + 'ssd_mobilenet_v1_coco_2018_01_28', + ), + 'ssd_mobilenet_v1_0p75_depth_quantized_coco': + Model( + 'ssd_mobilenet_v1_0p75_depth_quantized_coco', + 'http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_0.75_depth_quantized_300x300_coco14_sync_2018_07_18.tar.gz', + 'ssd_mobilenet_v1_0.75_depth_quantized_300x300_coco14_sync_2018_07_18' + ), + 'ssd_mobilenet_v1_ppn_coco': + Model( + 'ssd_mobilenet_v1_ppn_coco', + 'http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_ppn_shared_box_predictor_300x300_coco14_sync_2018_07_03.tar.gz', + 'ssd_mobilenet_v1_ppn_shared_box_predictor_300x300_coco14_sync_2018_07_03' + ), + 'ssd_mobilenet_v1_fpn_coco': + Model( + 'ssd_mobilenet_v1_fpn_coco', + 'http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03.tar.gz', + 'ssd_mobilenet_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03' + ), + 'ssd_mobilenet_v2_coco': + Model( + 'ssd_mobilenet_v2_coco', + 'http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v2_coco_2018_03_29.tar.gz', + 'ssd_mobilenet_v2_coco_2018_03_29', + ), + 'ssdlite_mobilenet_v2_coco': + Model( + 'ssdlite_mobilenet_v2_coco', + 'http://download.tensorflow.org/models/object_detection/ssdlite_mobilenet_v2_coco_2018_05_09.tar.gz', + 'ssdlite_mobilenet_v2_coco_2018_05_09'), + 'ssd_inception_v2_coco': + Model( + 'ssd_inception_v2_coco', + 'http://download.tensorflow.org/models/object_detection/ssd_inception_v2_coco_2018_01_28.tar.gz', + 'ssd_inception_v2_coco_2018_01_28', + ), + 'ssd_resnet_50_fpn_coco': + Model( + 'ssd_resnet_50_fpn_coco', + 
'http://download.tensorflow.org/models/object_detection/ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03.tar.gz', + 'ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03', + ), + 'faster_rcnn_resnet50_coco': + Model( + 'faster_rcnn_resnet50_coco', + 'http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet50_coco_2018_01_28.tar.gz', + 'faster_rcnn_resnet50_coco_2018_01_28', + ), + 'faster_rcnn_nas': + Model( + 'faster_rcnn_nas', + 'http://download.tensorflow.org/models/object_detection/faster_rcnn_nas_coco_2018_01_28.tar.gz', + 'faster_rcnn_nas_coco_2018_01_28', + ), + 'mask_rcnn_resnet50_atrous_coco': + Model( + 'mask_rcnn_resnet50_atrous_coco', + 'http://download.tensorflow.org/models/object_detection/mask_rcnn_resnet50_atrous_coco_2018_01_28.tar.gz', + 'mask_rcnn_resnet50_atrous_coco_2018_01_28', + ), + 'facessd_mobilenet_v2_quantized_open_image_v4': + Model( + 'facessd_mobilenet_v2_quantized_open_image_v4', + 'http://download.tensorflow.org/models/object_detection/facessd_mobilenet_v2_quantized_320x320_open_image_v4.tar.gz', + 'facessd_mobilenet_v2_quantized_320x320_open_image_v4') +} + +Dataset = namedtuple( + 'Dataset', + ['images_url', 'images_dir', 'annotation_url', 'annotation_path']) + +DATASETS = { + 'val2014': + Dataset( + 'http://images.cocodataset.org/zips/val2014.zip', 'val2014', + 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip', + 'annotations/instances_val2014.json'), + 'train2014': + Dataset( + 'http://images.cocodataset.org/zips/train2014.zip', 'train2014', + 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip', + 'annotations/instances_train2014.json'), + 'val2017': + Dataset( + 'http://images.cocodataset.org/zips/val2017.zip', 'val2017', + 'http://images.cocodataset.org/annotations/annotations_trainval2017.zip', + 'annotations/instances_val2017.json'), + 'train2017': + Dataset( + 'http://images.cocodataset.org/zips/train2017.zip', 'train2017', + 'http://images.cocodataset.org/annotations/annotations_trainval2017.zip', + 'annotations/instances_train2017.json') +} + + +def download_model(model_name, output_dir='.'): + """Downloads a model from the TensorFlow Object Detection API + + Downloads a model from the TensorFlow Object Detection API to a specific + output directory. The download will be skipped if an existing directory + for the selected model already found under output_dir. + + Args + ---- + model_name: A string representing the model to download. This must be + one of the keys in the module variable + ``trt_samples.object_detection.MODELS``. + output_dir: A string representing the directory to download the model + under. A directory for the specified model will be created at + ``output_dir/``. If output_dir/ + already exists, then the download will be skipped. + + Returns + ------- + config_path: A string representing the path to the object detection + pipeline configuration file of the downloaded model. + checkpoint_path: A string representing the path to the object detection + model checkpoint. 
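+
+    Example (a minimal sketch; the download is skipped when a cached copy
+    already exists under ``output_dir``):
+
+        config_path, checkpoint_path = download_model(
+            'ssd_mobilenet_v1_coco', output_dir='models')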
+ """ + global MODELS + + model_name + + model = MODELS[model_name] + + # make output directory if it doesn't exist + subprocess.call(['mkdir', '-p', output_dir]) + + tar_file = os.path.join(output_dir, os.path.basename(model.url)) + + config_path = os.path.join(output_dir, model.extract_dir, + PIPELINE_CONFIG_NAME) + checkpoint_path = os.path.join(output_dir, model.extract_dir, + CHECKPOINT_PREFIX) + + extract_dir = os.path.join(output_dir, model.extract_dir) + if os.path.exists(extract_dir): + print('Using cached model found at: %s' % extract_dir) + else: + subprocess.call(['wget', '-q', model.url, '-O', tar_file]) + subprocess.call(['tar', '-xzf', tar_file, '-C', output_dir]) + + # hack fix to handle mobilenet_v2 config bug + subprocess.call(['sed', '-i', '/batch_norm_trainable/d', config_path]) + + return config_path, checkpoint_path + + +def optimize_model(config_path, + checkpoint_path, + use_trt=True, + force_nms_cpu=True, + replace_relu6=True, + remove_assert=True, + override_nms_score_threshold=None, + override_resizer_shape=None, + max_batch_size=1, + precision_mode='FP32', + minimum_segment_size=2, + max_workspace_size_bytes=1 << 32, + maximum_cached_engines=100, + calib_images_dir=None, + num_calib_images=None, + calib_image_shape=None, + tmp_dir='.optimize_model_tmp_dir', + remove_tmp_dir=True, + output_path=None, + display_every=100): + """Optimizes an object detection model using TensorRT + + Optimizes an object detection model using TensorRT. This method also + performs pre-tensorrt optimizations specific to the TensorFlow object + detection API models. Please see the list of arguments for other + optimization parameters. + + Args + ---- + config_path: A string representing the path of the object detection + pipeline config file. + checkpoint_path: A string representing the path of the object + detection model checkpoint. + use_trt: A boolean representing whether to optimize with TensorRT. If + False, regular TensorFlow will be used but other optimizations + (like NMS device placement) will still be applied. + force_nms_cpu: A boolean indicating whether to place NMS operations on + the CPU. + replace_relu6: A boolean indicating whether to replace relu6(x) + operations with relu(x) - relu(x-6). + remove_assert: A boolean indicating whether to remove Assert + operations from the graph. + override_nms_score_threshold: An optional float representing + a NMS score threshold to override that specified in the object + detection configuration file. + override_resizer_shape: An optional list/tuple of integers + representing a fixed shape to override the default image resizer + specified in the object detection configuration file. + max_batch_size: An integer representing the max batch size to use for + TensorRT optimization. + precision_mode: A string representing the precision mode to use for + TensorRT optimization. Must be one of 'FP32', 'FP16', or 'INT8'. + minimum_segment_size: An integer representing the minimum segment size + to use for TensorRT graph segmentation. + max_workspace_size_bytes: An integer representing the max workspace + size for TensorRT optimization. + maximum_cached_engines: An integer represenging the number of TRT engines + that can be stored in the cache. + calib_images_dir: A string representing a directory containing images to + use for int8 calibration. + num_calib_images: An integer representing the number of calibration + images to use. If None, will use all images in directory. 
+ calib_image_shape: A tuple of integers representing the height, + width that images will be resized to for calibration. + tmp_dir: A string representing a directory for temporary files. This + directory will be created and removed by this function and should + not already exist. If the directory exists, an error will be + thrown. + remove_tmp_dir: A boolean indicating whether we should remove the + tmp_dir or throw error. + output_path: An optional string representing the path to save the + optimized GraphDef to. + display_every: print log for calibration every display_every iteration + + Returns + ------- + A GraphDef representing the optimized model. + """ + if max_batch_size > 1 and calib_image_shape is None: + raise RuntimeError( + 'Fixed calibration image shape must be provided for max_batch_size > 1') + if os.path.exists(tmp_dir): + if not remove_tmp_dir: + raise RuntimeError( + 'Cannot create temporary directory, path exists: %s' % tmp_dir) + subprocess.call(['rm', '-rf', tmp_dir]) + + # load config from file + config = pipeline_pb2.TrainEvalPipelineConfig() + with open(config_path, 'r') as f: + text_format.Merge(f.read(), config, allow_unknown_extension=True) + + # override some config parameters + if config.model.HasField('ssd'): + config.model.ssd.feature_extractor.override_base_feature_extractor_hyperparams = True + if override_nms_score_threshold is not None: + config.model.ssd.post_processing.batch_non_max_suppression.score_threshold = override_nms_score_threshold + if override_resizer_shape is not None: + config.model.ssd.image_resizer.fixed_shape_resizer.height = override_resizer_shape[ + 0] + config.model.ssd.image_resizer.fixed_shape_resizer.width = override_resizer_shape[ + 1] + elif config.model.HasField('faster_rcnn'): + if override_nms_score_threshold is not None: + config.model.faster_rcnn.second_stage_post_processing.batch_non_max_suppression.score_threshold = override_nms_score_threshold + if override_resizer_shape is not None: + config.model.faster_rcnn.image_resizer.fixed_shape_resizer.height = override_resizer_shape[ + 0] + config.model.faster_rcnn.image_resizer.fixed_shape_resizer.width = override_resizer_shape[ + 1] + + tf_config = tf.ConfigProto() + tf_config.gpu_options.allow_growth = True + + # export inference graph to file (initial), this will create tmp_dir + with tf.Session(config=tf_config): + with tf.Graph().as_default(): + exporter.export_inference_graph( + INPUT_NAME, + config, + checkpoint_path, + tmp_dir, + input_shape=[max_batch_size, None, None, 3]) + + # read frozen graph from file + frozen_graph_path = os.path.join(tmp_dir, FROZEN_GRAPH_NAME) + frozen_graph = tf.GraphDef() + with open(frozen_graph_path, 'rb') as f: + frozen_graph.ParseFromString(f.read()) + + # apply graph modifications + if force_nms_cpu: + frozen_graph = f_force_nms_cpu(frozen_graph) + if replace_relu6: + frozen_graph = f_replace_relu6(frozen_graph) + if remove_assert: + frozen_graph = f_remove_assert(frozen_graph) + + # get input names + output_names = [BOXES_NAME, CLASSES_NAME, SCORES_NAME, NUM_DETECTIONS_NAME] + + # optionally perform TensorRT optimization + if use_trt: + runtimes = [] + with tf.Graph().as_default() as tf_graph: + with tf.Session(config=tf_config) as tf_sess: + graph_size = len(frozen_graph.SerializeToString()) + num_nodes = len(frozen_graph.node) + start_time = time.time() + frozen_graph = trt.create_inference_graph( + input_graph_def=frozen_graph, + outputs=output_names, + max_batch_size=max_batch_size, + 
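+                    # max_workspace_size_bytes bounds the scratch GPU memory TensorRT may use
+                    # while building each engine; precision_mode selects FP32/FP16/INT8 kernels
+                    # (INT8 additionally requires the calibration pass below).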
max_workspace_size_bytes=max_workspace_size_bytes, + precision_mode=precision_mode, + minimum_segment_size=minimum_segment_size, + is_dynamic_op=True, + maximum_cached_engines=maximum_cached_engines) + end_time = time.time() + print("graph_size(MB)(native_tf): %.1f" % (float(graph_size)/(1<<20))) + print("graph_size(MB)(trt): %.1f" % + (float(len(frozen_graph.SerializeToString()))/(1<<20))) + print("num_nodes(native_tf): %d" % num_nodes) + print("num_nodes(tftrt_total): %d" % len(frozen_graph.node)) + print("num_nodes(trt_only): %d" % len([1 for n in frozen_graph.node if str(n.op)=='TRTEngineOp'])) + print("time(s) (trt_conversion): %.4f" % (end_time - start_time)) + + # perform calibration for int8 precision + if precision_mode == 'INT8': + + if calib_images_dir is None: + raise ValueError('calib_images_dir must be provided for int8 optimization.') + + tf.import_graph_def(frozen_graph, name='') + tf_input = tf_graph.get_tensor_by_name(INPUT_NAME + ':0') + tf_boxes = tf_graph.get_tensor_by_name(BOXES_NAME + ':0') + tf_classes = tf_graph.get_tensor_by_name(CLASSES_NAME + ':0') + tf_scores = tf_graph.get_tensor_by_name(SCORES_NAME + ':0') + tf_num_detections = tf_graph.get_tensor_by_name( + NUM_DETECTIONS_NAME + ':0') + + image_paths = glob.glob(os.path.join(calib_images_dir, '*.jpg')) + image_paths = image_paths[0:num_calib_images] + + for image_idx in range(0, len(image_paths), max_batch_size): + + # read batch of images + batch_images = [] + for image_path in image_paths[image_idx:image_idx+max_batch_size]: + image = _read_image(image_path, calib_image_shape) + batch_images.append(image) + + t0 = time.time() + # execute batch of images + boxes, classes, scores, num_detections = tf_sess.run( + [tf_boxes, tf_classes, tf_scores, tf_num_detections], + feed_dict={tf_input: batch_images}) + t1 = time.time() + runtimes.append(float(t1 - t0)) + if len(runtimes) % display_every == 0: + print(" step %d/%d, iter_time(ms)=%.4f" % ( + len(runtimes), + (len(image_path) + max_batch_size - 1) / max_batch_size, + np.mean(runtimes) * 1000)) + + pdb.set_trace() + frozen_graph = trt.calib_graph_to_infer_graph(frozen_graph) + + # re-enable variable batch size, this was forced to max + # batch size during export to enable TensorRT optimization + for node in frozen_graph.node: + if INPUT_NAME == node.name: + node.attr['shape'].shape.dim[0].size = -1 + + # write optimized model to disk + if output_path is not None: + with open(output_path, 'wb') as f: + f.write(frozen_graph.SerializeToString()) + + # remove temporary directory + subprocess.call(['rm', '-rf', tmp_dir]) + + return frozen_graph + + +def download_dataset(dataset_name, output_dir='.'): + """Downloads a COCO dataset + + Downloads a COCO dataset to the specified output directory. A new + directory corresponding to the specified dataset will be created under + output_dir. This directory will contain the images of the dataset. + + Args + ---- + dataset_name: A string representing the name of the dataset, it must + be one of the keys in trt_samples.object_detection.DATASETS. + + Returns + ------- + images_dir: A string representing the path of the directory containing + images of the dataset. + annotation_path: A string representing the path of the COCO annotation + file for the dataset. 
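+
+    Example (a minimal sketch; note that the COCO image archives are a
+    large download):
+
+        images_dir, annotation_path = download_dataset(
+            'val2014', output_dir='dataset')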
+ """ + global DATASETS + + dataset = DATASETS[dataset_name] + + subprocess.call(['mkdir', '-p', output_dir]) + + images_dir = os.path.join(output_dir, dataset.images_dir) + images_zip_file = os.path.join(output_dir, + os.path.basename(dataset.images_url)) + annotation_path = os.path.join(output_dir, dataset.annotation_path) + annotation_zip_file = os.path.join( + output_dir, os.path.basename(dataset.annotation_url)) + + # download or use cached annotation + if os.path.exists(annotation_path): + print('Using cached annotation_path; %s' % (annotation_path)) + else: + subprocess.call( + ['wget', '-q', dataset.annotation_url, '-O', annotation_zip_file]) + subprocess.call(['unzip', annotation_zip_file, '-d', output_dir]) + + # download or use cached images + if os.path.exists(images_dir): + print('Using cached images_dir; %s' % (images_dir)) + else: + subprocess.call(['wget', '-q', dataset.images_url, '-O', images_zip_file]) + subprocess.call(['unzip', images_zip_file, '-d', output_dir]) + + return images_dir, annotation_path + + +def benchmark_model(frozen_graph, + images_dir, + annotation_path, + batch_size=1, + image_shape=None, + num_images=4096, + tmp_dir='.benchmark_model_tmp_dir', + remove_tmp_dir=True, + output_path=None, + display_every=100, + use_synthetic=False, + num_warmup_iterations=50): + """Computes accuracy and performance statistics + + Computes accuracy and performance statistics by executing over many images + from the MSCOCO dataset defined by images_dir and annotation_path. + + Args + ---- + frozen_graph: A GraphDef representing the object detection model to + test. Alternatively, a string representing the path to the saved + frozen graph. + images_dir: A string representing the path of the COCO images + directory. + annotation_path: A string representing the path of the COCO annotation + file. + batch_size: An integer representing the batch size to use when feeding + images to the model. + image_shape: An optional tuple of integers representing a fixed shape + to resize all images before testing. For synthetic data the default + image_shape is [600, 600, 3] + num_images: An integer representing the number of images in the + dataset to evaluate with. + tmp_dir: A string representing the path where the function may create + a temporary directory to store intermediate files. + output_path: An optional string representing a path to store the + statistics in JSON format. + display_every: int, print log every display_every iteration + num_warmup_iteration: An integer represtening number of initial iteration, + that are not cover in performance statistics + Returns + ------- + statistics: A named dictionary of accuracy and performance statistics + computed for the model. 
+ """ + if os.path.exists(tmp_dir): + if not remove_tmp_dir: + raise RuntimeError('Temporary directory exists; %s' % tmp_dir) + subprocess.call(['rm', '-rf', tmp_dir]) + if batch_size > 1 and image_shape is None: + raise RuntimeError( + 'Fixed image shape must be provided for batch size > 1') + + if not use_synthetic: + coco = COCO(annotation_file=annotation_path) + + + # get list of image ids to use for evaluation + image_ids = coco.getImgIds() + if num_images > len(image_ids): + print( + 'Num images provided %d exceeds number in dataset %d, using %d images instead' + % (num_images, len(image_ids), len(image_ids))) + num_images = len(image_ids) + image_ids = image_ids[0:num_images] + + # load frozen graph from file if string, otherwise must be GraphDef + if isinstance(frozen_graph, str): + frozen_graph_path = frozen_graph + frozen_graph = tf.GraphDef() + with open(frozen_graph_path, 'rb') as f: + frozen_graph.ParseFromString(f.read()) + elif not isinstance(frozen_graph, tf.GraphDef): + raise TypeError('Expected frozen_graph to be GraphDef or str') + + tf_config = tf.ConfigProto() + tf_config.gpu_options.allow_growth = True + + coco_detections = [] # list of all bounding box detections in coco format + runtimes = [] # list of runtimes for each batch + image_counts = [] # list of number of images in each batch + + with tf.Graph().as_default() as tf_graph: + with tf.Session(config=tf_config) as tf_sess: + tf.import_graph_def(frozen_graph, name='') + tf_input = tf_graph.get_tensor_by_name(INPUT_NAME + ':0') + tf_boxes = tf_graph.get_tensor_by_name(BOXES_NAME + ':0') + tf_classes = tf_graph.get_tensor_by_name(CLASSES_NAME + ':0') + tf_scores = tf_graph.get_tensor_by_name(SCORES_NAME + ':0') + tf_num_detections = tf_graph.get_tensor_by_name( + NUM_DETECTIONS_NAME + ':0') + + # load batches from coco dataset + for image_idx in range(0, num_images, batch_size): + if use_synthetic: + if image_shape is None: + batch_images = np.random.randint(256, size=(batch_size, 600, 600, 3)) + else: + batch_images = np.random(256, size=(batch_size, image_shape[0], image_shape[1], 3)) + else: + batch_image_ids = image_ids[image_idx:image_idx + batch_size] + batch_images = [] + batch_coco_images = [] + # read images from file + for image_id in batch_image_ids: + coco_img = coco.imgs[image_id] + batch_coco_images.append(coco_img) + image_path = os.path.join(images_dir, + coco_img['file_name']) + image = _read_image(image_path, image_shape) + batch_images.append(image) + + # run num_warmup_iterations outside of timing + if image_idx < num_warmup_iterations: + boxes, classes, scores, num_detections = tf_sess.run( + [tf_boxes, tf_classes, tf_scores, tf_num_detections], + feed_dict={tf_input: batch_images}) + else: + # execute model and compute time difference + t0 = time.time() + boxes, classes, scores, num_detections = tf_sess.run( + [tf_boxes, tf_classes, tf_scores, tf_num_detections], + feed_dict={tf_input: batch_images}) + t1 = time.time() + + # log runtime and image count + runtimes.append(float(t1 - t0)) + if len(runtimes) % display_every == 0: + print(" step %d/%d, iter_time(ms)=%.4f" % ( + len(runtimes), + (len(image_ids) + batch_size - 1) / batch_size, + np.mean(runtimes) * 1000)) + image_counts.append(len(batch_images)) + + if not use_synthetic: + # add coco detections for this batch to running list + batch_coco_detections = [] + for i, image_id in enumerate(batch_image_ids): + image_width = batch_coco_images[i]['width'] + image_height = batch_coco_images[i]['height'] + + for j in 
range(int(num_detections[i])): + bbox = boxes[i][j] + bbox_coco_fmt = [ + bbox[1] * image_width, # x0 + bbox[0] * image_height, # x1 + (bbox[3] - bbox[1]) * image_width, # width + (bbox[2] - bbox[0]) * image_height, # height + ] + + coco_detection = { + 'image_id': image_id, + 'category_id': int(classes[i][j]), + 'bbox': bbox_coco_fmt, + 'score': float(scores[i][j]) + } + + coco_detections.append(coco_detection) + + if not use_synthetic: + # write coco detections to file + subprocess.call(['mkdir', '-p', tmp_dir]) + coco_detections_path = os.path.join(tmp_dir, 'coco_detections.json') + with open(coco_detections_path, 'w') as f: + json.dump(coco_detections, f) + + # compute coco metrics + cocoDt = coco.loadRes(coco_detections_path) + eval = COCOeval(coco, cocoDt, 'bbox') + eval.params.imgIds = image_ids + + eval.evaluate() + eval.accumulate() + eval.summarize() + + statistics = { + 'map': eval.stats[0], + 'avg_latency_ms': 1000.0 * np.mean(runtimes), + 'avg_throughput_fps': np.sum(image_counts) / np.sum(runtimes), + 'runtimes_ms': [1000.0 * r for r in runtimes] + } + else: + statistics = { + 'avg_latency_ms': 1000.0 * np.mean(runtimes), + 'avg_throughput_fps': np.sum(image_counts) / np.sum(runtimes), + 'runtimes_ms': [1000.0 * r for r in runtimes] + } + + if output_path is not None: + subprocess.call(['mkdir', '-p', os.path.dirname(output_path)]) + with open(output_path, 'w') as f: + json.dump(statistics, f) + subprocess.call(['rm', '-rf', tmp_dir]) + + return statistics + + +def _read_image(image_path, image_shape): + image = Image.open(image_path).convert('RGB') + if image_shape is not None: + image = image.resize(image_shape[::-1]) + return np.array(image) diff --git a/tftrt/examples/object_detection/test.py b/tftrt/examples/object_detection/test.py new file mode 100644 index 000000000..89b1175a8 --- /dev/null +++ b/tftrt/examples/object_detection/test.py @@ -0,0 +1,106 @@ +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +import argparse +import json +from .object_detection import download_model, download_dataset, optimize_model, benchmark_model + + +def test(test_config_path): + """Runs an object detection test configuration + + This runs an object detection test configuration. This involves + + 1. Download a model architecture (or use cached). + 2. Optimize the downloaded model architecrue + 3. Benchmark the optimized model against a dataset + 4. (optional) Run assertions to check the benchmark output + + The input to this function is a JSON file which specifies the test + configuration. + + example_test_config.json: + + { + "source_model": { ... }, + "optimization_config": { ... }, + "benchmark_config": { ... }, + "assertions": [ ... ] + } + + source_model: A dictionary of arguments passed to download_model, which + specify the pre-optimized model architure. 
The model downloaded (or + the cached model if found) will be passed to optimize_model. + optimization_config: A dictionary of arguments passed to optimize_model. + Please see help(optimize_model) for more details. + benchmark_config: A dictionary of arguments passed to benchmark_model. + Please see help(benchmark_model) for more details. + assertions: A list of strings containing Python code that will be + evaluated. If the code returns false, an error will be thrown. These + assertions can reference any variables local to this 'test' function. + Some useful values are + + statistics['map'] + statistics['avg_latency_ms'] + statistics['avg_throughput_fps'] + + Args + ---- + test_config_path: A string corresponding to the test configuration + JSON file. + """ + with open(test_config_path, 'r') as f: + test_config = json.load(f) + print(json.dumps(test_config, sort_keys=True, indent=4)) + + # download model or use cached + config_path, checkpoint_path = download_model(**test_config['source_model']) + + # optimize model using source model + frozen_graph = optimize_model( + config_path=config_path, + checkpoint_path=checkpoint_path, + **test_config['optimization_config']) + + # benchmark optimized model + statistics = benchmark_model( + frozen_graph=frozen_graph, + **test_config['benchmark_config']) + + # print some statistics to command line + print_statistics = statistics + if 'runtimes_ms' in print_statistics: + print_statistics.pop('runtimes_ms') + print(json.dumps(print_statistics, sort_keys=True, indent=4)) + + # run assertions + if 'assertions' in test_config: + for a in test_config['assertions']: + if not eval(a): + raise AssertionError('ASSERTION FAILED: %s' % a) + else: + print('ASSERTION PASSED: %s' % a) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + 'test_config_path', + help='Path of JSON file containing test configuration. Please ' + 'see help(tftrt.examples.object_detection.test) for more information') + args = parser.parse_args() + test(args.test_config_path) diff --git a/tftrt/examples/recommendation/README.md b/tftrt/examples/recommendation/README.md new file mode 100644 index 000000000..7a5d4b60e --- /dev/null +++ b/tftrt/examples/recommendation/README.md @@ -0,0 +1,69 @@ +## NCF examples + +The example script `inference.py` runs inference with the NVIDIA NCF model implementation. +This script is included in the NVIDIA TensorFlow Docker +containers under `/workspace/nvidia-examples`. + + +## Model + +The model we use is available here: +`https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/Recommendation/NCF` + +### Setup for running within an NVIDIA TensorFlow Docker container + + +If you are running these examples within the [NVIDIA TensorFlow Docker +container](https://ngc.nvidia.com/catalog/containers/nvidia:tensorflow): + +``` +cd ../third_party/models +export PYTHONPATH="$PYTHONPATH:$PWD" +``` + +### Prepare dataset + +We use the standard MovieLens ml-20m dataset, which is available here: +`https://grouplens.org/datasets/movielens/` + +Before it can be used with our script, the dataset has to be converted to a CSV file. +You can do that with the following script: +`tensorrt/tftrt/examples/third_party/DeepLearningExamples/TensorFlow/Recommendation/NCF/prepare_dataset.sh` +You need to provide the path to the directory where you downloaded the ml-20m dataset.
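+For example, assuming the ml-20m archive has been downloaded and extracted
+under `/data/ml-20m` (a hypothetical location), the call might look like the
+following; check the script itself for the exact arguments it expects:
+
+```
+cd ../third_party/DeepLearningExamples/TensorFlow/Recommendation/NCF
+./prepare_dataset.sh /data/ml-20m
+```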
+ +### Setup for running standalone + +If you are running these examples within your own TensorFlow environment, +perform the following steps: + +``` +# Clone this repository (tensorflow/tensorrt) if you haven't already. +git clone https://github.com/tensorflow/tensorrt.git --recurse-submodules + +# Add official models to python path +cd tensorrt/tftrt/examples/third_party/models/ +export PYTHONPATH="$PYTHONPATH:$PWD" +``` +## Usage + +The main Python script is `inference.py`. Here is some example of usage: + +``` +python inference.py + --data_dir /data/cache/ml-20m/ + --use_trt + --precision FP16 +``` + +Where: + +`--data_dir`: Path to the ml-20m test dataset + +`--use_trt`: Convert the graph to a TensorRT graph. + +`--precision`: Precision mode to use, in this case FP16. + + +Run with `--help` to see all available options. + + diff --git a/tftrt/examples/recommendation/inference.py b/tftrt/examples/recommendation/inference.py new file mode 100644 index 000000000..7ac0163e1 --- /dev/null +++ b/tftrt/examples/recommendation/inference.py @@ -0,0 +1,351 @@ +import tensorflow as tf +import tensorflow.contrib.tensorrt as trt +import time +import random +import numpy as np +import pandas as pd +from official.datasets import movielens + +from neumf import compute_eval_metrics +from neumf import neural_mf +import os +import argparse +import csv + +class LoggerHook(tf.train.SessionRunHook): + """Logs runtime of each iteration""" + def __init__(self, batch_size, num_records, display_every): + self.iter_times = [] + self.display_every = display_every + self.num_steps = (num_records + batch_size - 1) / batch_size + self.batch_size = batch_size + + def before_run(self, run_context): + self.start_time = time.time() + + def after_run(self, run_context, run_values): + current_time = time.time() + duration = current_time - self.start_time + self.iter_times.append(duration) + current_step = len(self.iter_times) + if current_step % self.display_every == 0: + print(" step %d/%d, iter_time(ms)=%.4f, images/sec=%d" % ( + current_step, self.num_steps, duration * 1000, + self.batch_size / self.iter_times[-1])) + +class BenchmarkHook(tf.train.SessionRunHook): + """Limits run duration and number of iterations""" + def __init__(self, target_duration=None, iteration_limit=None): + self.target_duration = target_duration + self.start_time = None + self.current_iteration = 0 + self.iteration_limit = iteration_limit + def before_run(self, run_context): + if not self.start_time: + self.start_time = time.time() + if self.target_duration: + print(" running for target duration {} seconds from {}".format( + self.target_duration, time.asctime(time.localtime(self.start_time)))) + + def after_run(self, run_context, run_values): + if self.target_duration: + current_time = time.time() + if (current_time - self.start_time) > self.target_duration: + print(" target duration {} reached at {}, requesting stop".format( + self.target_duration), time.asctime(time.localtime(current_time))) + run_context.request_stop() + if self.iteration_limit: + self.current_iteration += 1 + if self.current_iteration >= self.iteration_limit: + run_context.request_stop() + + +def get_frozen_graph(model_checkpoint=None, + mode="benchmark", + use_trt=True, + batch_size=1024, + use_dynamic_op=True, + precision="FP32", + model_dtype=tf.float32, + mf_dim=64, + mf_reg=64, + mlp_layer_sizes=[256, 256, 128, 64], + mlp_layer_regs=[.0, .0, .0, .0], + nb_items=26744, + nb_users=138493, + dup_mask=0.1, + K=10, + minimum_segment_size=2, + calib_data_dir=None, + 
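+                     # calib_data_dir and num_calib_inputs are only consulted for INT8 calibration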
num_calib_inputs=None, + use_synthetic=False, + max_workspace_size=(1<<32)): + + num_nodes = {} + times = {} + graph_sizes = {} + + tf_config = tf.ConfigProto() + with tf.Graph().as_default() as tf_graph: + with tf.Session(config=tf_config) as tf_sess: + users = tf.placeholder(shape=(None,), dtype=tf.int32, name="user_input") + items = tf.placeholder(shape=(None,), dtype=tf.int32, name="item_input") + with tf.variable_scope("neumf"): + logits = neural_mf(users, items, model_dtype, nb_users, nb_items, mf_dim, mf_reg, mlp_layer_sizes, mlp_layer_regs, 0.1) + + saver = tf.train.Saver() + saver.restore(tf_sess, model_checkpoint) + graph0 = tf.graph_util.convert_variables_to_constants(tf_sess, + tf_sess.graph_def, output_node_names=['neumf/dense_3/BiasAdd']) + frozen_graph = tf.graph_util.remove_training_nodes(graph0) + + for node in frozen_graph.node: + if node.op == "Assign": + node.op = "Identity" + if 'use_locking' in node.attr: del node.attr['use_locking'] + if 'validate_shape' in node.attr: del node.attr['validate_shape'] + if len(node.input) == 2: + node.input[0] = node.input[1] + del node.input[1] + + if use_trt: + start_time = time.time() + frozen_graph = trt.create_inference_graph( + input_graph_def=frozen_graph, + outputs=['neumf/dense_3/BiasAdd:0'], + max_batch_size=batch_size, + max_workspace_size_bytes=max_workspace_size, + precision_mode=precision, + minimum_segment_size=minimum_segment_size, + is_dynamic_op=use_dynamic_op) + times['trt_conversion'] = time.time() - start_time + num_nodes['tftrt_total'] = len(frozen_graph.node) + num_nodes['trt_only'] = len([1 for n in frozen_graph.node if str(n.op)=='TRTEngineOp']) + graph_sizes['trt'] = len(frozen_graph.SerializeToString()) + + if precision == 'INT8': + calib_graph = frozen_graph + graph_sizes['calib'] = len(calib_graph.SerializeToString()) + # INT8 calibration step + print('Calibrating INT8...') + start_time = time.time() + run(calib_graph, + data_dir=calib_data_dir, + batch_size=batch_size, + num_iterations=num_calib_inputs // batch_size, + num_warmup_iterations=0, + use_synthetic=use_synthetic) + times['trt_calibration'] = time.time() - start_time + start_time = time.time() + frozen_graph = trt.calib_graph_to_infer_graph(calib_graph) + times['trt_int8_conversion'] = time.time() - start_time + graph_sizes['trt'] = len(frozen_graph.SerializeToString()) + + del calib_graph + print('INT8 graph created') + + return frozen_graph, num_nodes, times, graph_sizes + + +def run(frozen_graph, + data_dir=None, + batch_size=1024, + num_iterations=None, + num_warmup_iterations=None, + use_synthetic=False, + display_every=100, + mode='benchmark', + target_duration=None, + nb_items=26744, + nb_users=138493, + dup_mask=0.1, + K=10): + + def model_fn(features, labels, mode): + logits_out = tf.import_graph_def(frozen_graph, + input_map={'user_input:0': features["user_input"], 'item_input:0': features["item_input"]}, + return_elements=['neumf/dense_3/BiasAdd:0'], + name='') + if mode == tf.estimator.ModeKeys.PREDICT: + return tf.estimator.EstimatorSpec(mode=mode, + predictions={'logits': logits_out[0]}) + if mode == tf.estimator.ModeKeys.EVAL: + found_positive, dcg = compute_eval_metrics(logits_out[0], dup_mask, batch_size, K) + hit_rate = tf.metrics.mean(found_positive, name='hit_rate') + ndcg = tf.metrics.mean(dcg, name='ndcg') + return tf.estimator.EstimatorSpec( + mode=mode, + loss=dcg, + eval_metric_ops={'hit_rate': hit_rate, 'ndcg': ndcg}) + + def input_fn(): + if use_synthetic: + items = [random.randint(1, nb_items) for _ in
range(batch_size)] + users = [random.randint(1, nb_users) for _ in range(batch_size)] + with tf.device('/device:GPU:0'): + items = tf.identity(items) + users = tf.identity(users) + else: + data_path = os.path.join(data_dir, 'test_ratings.pickle') + dataset = pd.read_pickle(data_path) + users = dataset["user_id"] + items = dataset["item_id"] + + users = users.astype('int32') + items = items.astype('int32') + user_dataset = tf.data.Dataset.from_tensor_slices(users) + user_dataset = user_dataset.batch(batch_size) + user_dataset = user_dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE) + user_dataset = user_dataset.repeat(count=1) + user_iterator = user_dataset.make_one_shot_iterator() + users = user_iterator.get_next() + + item_dataset = tf.data.Dataset.from_tensor_slices(items) + item_dataset = item_dataset.batch(batch_size) + item_dataset = item_dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE) + item_dataset = item_dataset.repeat(count=1) + item_iterator = item_dataset.make_one_shot_iterator() + items = item_iterator.get_next() + return {"user_input": users, "item_input": items}, [] + + if use_synthetic and num_iterations is None: + num_iterations=1000 + + if use_synthetic: + num_records=num_iterations*batch_size + else: + data_path = os.path.join(data_dir, 'test_ratings.pickle') + dataset = pd.read_pickle(data_path) + users = dataset["user_id"] + num_records = len(users) + if num_iterations is None: + num_iterations = num_records // batch_size + + logger = LoggerHook( + display_every=display_every, + batch_size=batch_size, + num_records=num_records) + tf_config = tf.ConfigProto() + estimator = tf.estimator.Estimator( + model_fn=model_fn, + config=tf.estimator.RunConfig(session_config=tf_config), + model_dir='model_dir') + results = {} + + if mode == 'validation': + results = estimator.evaluate(input_fn, steps=num_iterations, hooks=[logger]) + elif mode == 'benchmark': + benchmark_hook = BenchmarkHook(target_duration=target_duration, iteration_limit=num_iterations) + prediction_results = [p for p in estimator.predict(input_fn, predict_keys=["logits"], hooks=[logger, benchmark_hook])] + print(prediction_results) + else: + raise ValueError("Mode must be either 'validation' or 'benchmark'") + + iter_times = np.array(logger.iter_times[num_warmup_iterations:]) + results['total_time'] = np.sum(iter_times) + results['images_per_sec'] = np.mean(batch_size / iter_times) + results['99th_percentile'] = np.percentile(iter_times, q=99, interpolation='lower') * 1000 + results['latency_mean'] = np.mean(iter_times) * 1000 + results['latency_median'] = np.median(iter_times) * 1000 + results['latency_min'] = np.min(iter_times) * 1000 + + return results + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Evaluate model') + parser.add_argument('--use_synthetic', action='store_true', + default=False, + help='If set, one batch of random data is generated and used at every iteration.') + parser.add_argument('--mode', choices=['validation', 'benchmark'], + default='validation', help='Which mode to use (validation or benchmark)') + parser.add_argument('--data_dir', type=str, default=None, + help='Directory containing validation set csv files.') + parser.add_argument('--calib_data_dir', type=str, + help='Directory containing TFRecord files for calibrating int8.') + parser.add_argument('--model_dir', type=str, default=None, + help='Directory containing model checkpoint.') + parser.add_argument('--use_trt', action='store_true', + help='If set, the graph will be converted to a 
TensorRT graph.') + parser.add_argument('--use_dynamic_op', action='store_true', + help='If set, TRT conversion will be done using dynamic op instead of statically.') + parser.add_argument('--precision', type=str, + choices=['FP32', 'FP16', 'INT8'], default='FP32', + help='Precision mode to use. FP16 and INT8 only work in conjunction with --use_trt') + parser.add_argument('--nb_items', type=int, default=26744, + help='Number of items') + parser.add_argument('--nb_users', type=int, default=138493, + help='Number of users') + parser.add_argument('--batch_size', type=int, default=1024, + help='Batch size') + parser.add_argument('--minimum_segment_size', type=int, default=2, + help='Minimum number of TF ops in a TRT engine') + parser.add_argument('--num_iterations', type=int, default=None, + help='How many iterations(batches) to evaluate. If not supplied, the whole set will be evaluated.') + parser.add_argument('--num_warmup_iterations', type=int, default=50, + help='Number of initial iterations skipped from timing') + parser.add_argument('--num_calib_inputs', type=int, default=500, + help='Number of inputs (e.g. images) used for calibration ' + '(last batch is skipped in case it is not full)') + parser.add_argument('--max_workspace_size', type=int, default=(1<<32), + help='workspace size in bytes') + parser.add_argument('--display_every', type=int, default=100, + help='Number of iterations executed between two consecutive display of metrics') + parser.add_argument('--dup_mask', type=float, default=0.1) + parser.add_argument('--K', type=int, default=10) + parser.add_argument('--mf_dim', type=int, default=64) + parser.add_argument('--mf_reg', type=int, default=64) + parser.add_argument('--mlp_layer_sizes', default=[256, 256, 128, 64]) + parser.add_argument('--mlp_layer_regs', default=[.0, .0, .0, .0]) + + args = parser.parse_args() + if not args.use_synthetic and args.data_dir is None: + raise ValueError("Data_dir is not provided") + + frozen_graph, num_nodes, times, graph_sizes = get_frozen_graph( + model_checkpoint=args.model_dir, + use_trt=args.use_trt, + use_dynamic_op=args.use_dynamic_op, + precision=args.precision, + batch_size=args.batch_size, + mf_dim=args.mf_dim, + mf_reg=args.mf_reg, + mlp_layer_sizes=args.mlp_layer_sizes, + mlp_layer_regs=args.mlp_layer_regs, + nb_items=args.nb_items, + nb_users=args.nb_users, + minimum_segment_size=args.minimum_segment_size, + calib_data_dir=args.calib_data_dir, + num_calib_inputs=args.num_calib_inputs, + use_synthetic=args.use_synthetic, + max_workspace_size=args.max_workspace_size + ) + + + def print_dict(input_dict, str='', scale=None): + for k, v in sorted(input_dict.items()): + headline = '{}({}): '.format(str, k) if str else '{}: '.format(k) + v = v * scale if scale else v + print('{}{}'.format(headline, '%.1f'%v if type(v)==float else v)) + + print_dict(num_nodes) + print_dict(graph_sizes) + print_dict(times) + + results = run(frozen_graph, + data_dir=args.data_dir, + batch_size=args.batch_size, + num_iterations=args.num_iterations, + num_warmup_iterations=args.num_warmup_iterations, + use_synthetic=args.use_synthetic, + display_every=args.display_every, + mode=args.mode, + target_duration=None, + nb_items=args.nb_items, + nb_users=args.nb_users, + dup_mask=args.dup_mask, + K=args.K) + + + print_dict(results) + diff --git a/tftrt/examples/third_party/DeepLearningExamples b/tftrt/examples/third_party/DeepLearningExamples new file mode 160000 index 000000000..531a570c5 --- /dev/null +++ b/tftrt/examples/third_party/DeepLearningExamples @@ -0,0 
+1 @@ +Subproject commit 531a570c5cc1705041ca69f3841bdb437022309b diff --git a/tftrt/examples/third_party/cocoapi b/tftrt/examples/third_party/cocoapi new file mode 160000 index 000000000..ed842bffd --- /dev/null +++ b/tftrt/examples/third_party/cocoapi @@ -0,0 +1 @@ +Subproject commit ed842bffd41f6ff38707c4f0968d2cfd91088688 diff --git a/tftrt/examples/third_party/models b/tftrt/examples/third_party/models new file mode 160000 index 000000000..402b561b0 --- /dev/null +++ b/tftrt/examples/third_party/models @@ -0,0 +1 @@ +Subproject commit 402b561b03857151f684ee00b3d997e5e6be9778