# This sample uses a UFF MNIST model to create a TensorRT Inference Engine
from random import randint
from PIL import Image
import numpy as np
import pycuda.driver as cuda
# This import causes pycuda to automatically manage CUDA context creation and cleanup.
import pycuda.autoinit
import tensorrt as trt
import time
import sys, os

sys.path.insert(1, os.path.join(sys.path[0], ".."))
import common

# You can set the logger severity higher to suppress messages (or lower to display more messages).
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

batch_size = 128


class ModelData(object):
    MODEL_FILE = os.path.join(os.path.dirname(__file__), "model2/frozen_model.uff")
    INPUT_NAME = "input_1"
    INPUT_SHAPE = (3, 256, 256)
    OUTPUT_NAME = "predictions/Softmax"
    DTYPE = trt.float32


def build_engine(model_file):
    # For more information on TRT basics, refer to the introductory samples.
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.UffParser() as parser:
        builder.max_batch_size = batch_size
        builder.max_workspace_size = common.GiB(1)
        # Parse the UFF network.
        parser.register_input(ModelData.INPUT_NAME, ModelData.INPUT_SHAPE)
        parser.register_output(ModelData.OUTPUT_NAME)
        parser.parse(model_file, network)
        # Build and return an engine.
        return builder.build_cuda_engine(network)


# Loads a test case into the provided pagelocked_buffer.
def load_normalized_test_case(data_path, pagelocked_buffer, case_num=randint(0, 9)):
    # test_case_path = os.path.join(data_path, str(case_num) + ".pgm")
    # Flatten the image into a 1D array, normalize, and copy to pagelocked memory.
    def normalize_image(image):
        # Resize, antialias and transpose the image to CHW layout.
        c, h, w = ModelData.INPUT_SHAPE
        return np.asarray(image.resize((w, h), Image.ANTIALIAS)).transpose([2, 0, 1]).astype(trt.nptype(ModelData.DTYPE))

    test_case_path = "lena.jpg"
    img = normalize_image(Image.open(test_case_path))
    # Replicate the single test image batch_size times so the whole batch is filled.
    img_array = []
    for i in range(batch_size):
        img_array.append(img)
    img_array = np.array(img_array, dtype=trt.nptype(ModelData.DTYPE))
    img_array = img_array.ravel()
    np.copyto(pagelocked_buffer, img_array)
    return case_num


def main():
    # data_path = common.find_sample_data(description="Runs an MNIST network using a UFF model file", subfolder="mnist")
    data_path = "/home/bjxiangboren/tools/TensorRT-5.0.2.6/data/mnist/"
    model_file = ModelData.MODEL_FILE
    # To reuse a previously serialized engine instead of rebuilding it:
    # with open("inception_batch.engine", "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
    #     engine = runtime.deserialize_cuda_engine(f.read())
    with build_engine(model_file) as engine:
        # Build an engine, allocate buffers and create a stream.
        # For more information on buffer allocation, refer to the introductory samples.
        with open("inception_batch.engine", "wb") as f:
            f.write(engine.serialize())
        inputs, outputs, bindings, stream = common.allocate_buffers(engine)
        with engine.create_execution_context() as context:
            case_num = load_normalized_test_case(data_path, pagelocked_buffer=inputs[0].host)
            # For more information on performing inference, refer to the introductory samples.
            # The common.do_inference function will return a list of outputs - we only have one in this case.
            while True:
                start_time = time.time()
                [output] = common.do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream, batch_size=batch_size)
                end_time = time.time()
                print("inference time: %s s" % (end_time - start_time))
                # output = output.reshape((30, 1001))
                # print(output)
                # print(output.shape)
                # print(np.argmax(output, axis=1))
                # pred = np.argmax(output)
                # print("Test Case: " + str(case_num))
                # print("Prediction: " + str(pred))


if __name__ == '__main__':
    main()
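
The commented-out lines in main() hint at how the softmax output would be decoded. A minimal sketch of that post-processing, assuming the flat output buffer holds one row of class probabilities per image in the batch and that the class count is 1001 (taken from the commented reshape; adjust for your model):

```python
import numpy as np

def decode_predictions(output, batch_size, num_classes=1001):
    # Reshape the flat output buffer into one softmax row per image in the batch.
    probs = np.asarray(output).reshape((batch_size, num_classes))
    # Take the most likely class index for every image.
    return np.argmax(probs, axis=1)

# Example usage with the script above:
# preds = decode_predictions(output, batch_size=128)
```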

1. First, convert the frozen .pb model to UFF format:

python  /usr/lib/python3.5/dist-packages/uff/bin/convert_to_uff.py --input_file models/lenet5.pb
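
The convert_to_uff.py script wraps the uff Python package, so the same conversion can also be done in-process. A minimal sketch, assuming the frozen-graph path and the output node name that the inference script above expects (parameter names reflect the uff package shipped with TensorRT 5; double-check against your version):

```python
import uff

# Convert a frozen TensorFlow graph to UFF and write the result to disk.
uff.from_tensorflow_frozen_model(
    frozen_file="model2/frozen_model.pb",        # assumed path to the frozen graph
    output_nodes=["predictions/Softmax"],        # same output node the UFF parser registers
    output_filename="model2/frozen_model.uff",
)
```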

2. Accelerate inference with the TensorRT engine

The speedup is quite noticeable, but the converted model can no longer be served with TF Serving; it can only be run with TensorRT's own engine.
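
Once the script above has written inception_batch.engine, later runs can skip the UFF parse and engine build entirely and just deserialize the plan, which is the pattern the commented-out lines in main() sketch. A minimal version, assuming the same common.py helpers from the TensorRT samples are on the path:

```python
import tensorrt as trt
import pycuda.autoinit  # creates and manages the CUDA context
import common

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

# Rebuild the engine from the serialized plan instead of re-parsing the UFF model.
with open("inception_batch.engine", "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())

inputs, outputs, bindings, stream = common.allocate_buffers(engine)
with engine.create_execution_context() as context:
    # Fill inputs[0].host exactly as load_normalized_test_case() does, then run the batch.
    [output] = common.do_inference(context, bindings=bindings, inputs=inputs,
                                   outputs=outputs, stream=stream, batch_size=128)
```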

References: https://devtalk.nvidia.com/default/topic/1044466/tensorrt/uff-inference-time-large-than-pb-time-when-process-vgg-19/

https://blog.csdn.net/zong596568821xp/article/details/86077553

https://blog.csdn.net/g11d111/article/details/92061884

https://mp.weixin.qq.com/s/Ps49ZTfJprcOYrc6xo-gLg?
