Tensorflow的slim框架可以写出像keras一样简单的代码来实现网络结构(虽然现在keras也已经集成在tf.contrib中了),而且models/slim提供了类似之前说过的object detection接口类似的image classification接口,可以很方便的进行fine-tuning利用自己的数据集训练自己所需的模型。

官方文档提供了比较详细的从数据准备,预训练模型的model zoo,fine-tuning,freeze model等一系列流程的步骤,但是缺少了inference的文档,不过tf所有模型的加载方式是通用的,所以调用方法和调用其他pb模型是一样的。

根据TF开发人员是说法Tensorflow对于模型读写的保存和调用的步骤一般如下:Build your graph --> write your graph --> import from written graph --> run compute etc

以下我们使用slim提供的网络inception-resnet-v2作为例子:

1. export inference graph

import tensorflow as tf
import nets.inception_resnet_v2 as net slim = tf.contrib.slim # checkpoint path
checkpoint_path = "/your/path/to/inception_resnet_v2.ckpt" # ckpt file obtained during model training or fine-tuning # set up and load session
sess = tf.Session()
arg_scope = net.inception_resnet_v2_arg_scope()
# initialize tensor suitable for model input
input_tensor = tf.placeholder(tf.float32, [None, 299, 299, 3])
with slim.arg_scope(arg_scope):
logits, end_points = net.inception_resnet_v2(inputs=input_tensor) # set up model saver
saver = tf.train.Saver()
saver.restore(sess, checkpoint_path)
with tf.gfile.GFile('/your/path/to/model_graph.pb', 'w') as f: # save model to given pb file
f.write(sess.graph_def.SerializeToString())
f.close()

2. freeze model

这里用tf提供的tensorflow/python/tools下的freeze_graph工具:

$ bazel build tensorflow/python/tools:freeze_graph
$ bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=/your/path/to/model_graph.pb \ # obtained above
--input_checkpoint=/your/path/to/inception_resnet_v2.ckpt \
--input_binary=true
--output_graph=/your/path/to/frozen_graph.pb \
--output_node_names=InceptionResnetV2/Logits/Predictions # output node name defined in inception resnet v2 net

(Optional) visualize frozen graph

LOG_DIR = ‘/tmp/graphdeflogdir’
model_filename = '/your/path/to/frozen_graph.pb' with tf.Session() as sess:
with tf.gfile.FastGFile(model_filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
writer = tf.summary.FileWriter(LOG_DIR, graph_def)
writer.close()

然后用tensorborad --logdir=LOG_DIR选择graph就可以查看到frozen后的网络结构。

3. inference

import cv2
import numpy as np def preprocess_inception(image_np, central_fraction=0.875):
image_height, image_width, image_channel = image_np.shape
if central_fraction:
bbox_start_h = int(image_height * (1 - central_fraction) / 2)
bbox_end_h = int(image_height - bbox_start_h)
bbox_start_w = int(image_width * (1 - central_fraction) / 2)
bbox_end_w = int(image_width - bbox_start_w)
image_np = image_np[bbox_start_h:bbox_end_h, bbox_start_w:bbox_end_w]
# normalize
image_np = 2 * (image_np / 255.) - 1
return image_np image_np = cv2.imread("test.jpg")
# preprocess image as inception resnet v2 does
image_np = preprcess_inception(image_np)
# resize to model input image size
image_np = cv2.resize(image_np, (299, 299))
# expand dims to shape [None, 299, 299, 3]
image_np = np.expand_dims(image_np, 0)
# load model
with tf.gfile.GFile('/your/path/to/frozen_graph.pb')
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
with tf.Session(graph=graph) as sess:
input tensor = sess.graph.get_tensor_by_name("input:0") # get input tensor
output_tensor = sess.graph.get_tensor_by_name("InceptionResnetV2/Logits/Predictions:0") # get output tensor
logits = sess.run(output_tensor, feed_dict={input_tensor: image_np})
print "Prediciton label index:", np.argmax(logits[0], 1)
print "Top 3 Prediciton label index:", np.argsort(logits[0], 3)

参考:

  1. https://stackoverflow.com/questions/42961243/using-pre-trained-inception-v4-model
  2. https://gist.github.com/cchadowitz-pf/f1c3e781c125813f9976f6e69c06fec2
  3. https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc
  4. https://github.com/tensorflow/models/blob/master/slim/README.md
  5. https://gist.github.com/tokestermw/795cc1fd6d0c9069b20204cbd133e36b

最新文章

  1. Apache Server 添加虚拟主机(Virtual Host )
  2. Java-继承,多态练习0922-06
  3. Http进行网络通信
  4. Codeforces Round #325 (Div. 2) F. Lizard Era: Beginning meet in the mid
  5. Contest 20140923 登月计划 BabyStepGaintStep
  6. 史上最强NDK入门项目实战
  7. 转:SSE:服务器发送事件
  8. vim中使用gdb。
  9. jquery ajax方法和其他api回顾
  10. [CSS3] 学习笔记-HTML与CSS简单页面效果实例
  11. Natas Wargame Level 17 Writeup(Time-based Blind SQL Injection)
  12. 【2017集美大学1412软工实践_助教博客】团队作业10——项目复审与事后分析(Beta版本)
  13. 十招让Ubuntu 16.04用起来更得心应手(转)
  14. ajaxFileUpload上传带参数,返回值改成json格式
  15. Windows内核驱动中操作文件
  16. show出相应单据列表
  17. convert 批量文件的格式转换
  18. 点击threadItem查看MessageList时传递数据
  19. VirtualBox导入已存在的VHD遇到的uuid冲突问题
  20. [Flutter] 支持描边效果的Text

热门文章

  1. json数据格式 net.sf.json.JSONException: A JSONObject text must begin with '{' at character 1 of Error:(findColumns1)Read timed out
  2. iOS主流机型更新
  3. PS文字生成头像
  4. index-document-shard
  5. 广州移动宽带DNS
  6. 访问win10的远程桌面(Remote Desktop)总是凭据或者用户密码错误
  7. Python中Queue模块及多线程使用
  8. Codeforces Round #277.5 (Div. 2)C——Given Length and Sum of Digits...
  9. Qt下多线程日之类
  10. 五分钟搞清楚MySQL事务隔离级别