标准TensorFlow格式

TensorFlow的训练过程其实就是大量的数据在网络中不断流动的过程,而数据的来源在官方文档[^1](API r1.2)中介绍了三种方式,分别是:

  • Feeding。通过Python直接注入数据。
  • Reading from files。从文件读取数据,本文中的TFRecord属于此类方式。
  • Preloaded data。将数据以constant或者variable的方式直接存储在运算图中。

当数据量较大时,官方推荐采用标准TensorFlow格式[^2](Standard TensorFlow format)来存储训练与验证数据,该格式的后缀名为tfrecord。官方介绍如下:

A TFRecords file represents a sequence of (binary) strings. The format is not random access, so it is suitable for streaming large amounts of data but not suitable if fast sharding or other non-sequential access is desired.

从介绍不难看出,TFRecord文件适用于大量数据的顺序读取。而这正好是神经网络在训练过程中发生的事情。


如何使用TFRecord文件

对于TFRecord文件的使用,官方给出了两份示例代码,分别展示了如何生成与读取该格式的文件。

生成TFRecord文件

第一份代码convert_to_records.py [^3]将MNIST里的图像数据转换为了TFRecord格式 。仔细研读代码,可以发现TFRecord文件中的图像数据存储在Feature下的image_raw里。image_raw来自于data_set.images,而后者又来自mnist.read_data_sets()。因此images的真身藏在mnist.py这个文件里。

mnist.py并不难找,在Pycharm里按下ctrl后单击鼠标左键即可打开源代码。

继续追踪,可以在mnist里发现图像来自extract_images()函数。该函数的说明里清晰的写明:

Extract the images into a 4D uint8 numpy array [index, y, x, depth].
Args:
f: A file object that can be passed into a gzip reader.
Returns:
data: A 4D uint8 numpy array [index, y, x, depth].
Raises:
ValueError: If the bytestream does not start with 2051.

很明显,返回值变量名为data,是一个4D Numpy矩阵,存储值为uint8类型,即图像像素的灰度值(MNIST全部为灰度图像)。四个维度分别代表了:图像的个数,每个图像行数,每个图像列数,每个图像通道数。

在获得这个存储着像素灰度值的Numpy矩阵后,使用numpy的tostring()函数将其转换为Python bytes格式[^4],再使用tf.train.BytesList()函数封装为tf.train.BytesList类,名字为image_raw。最后使用tf.train.Example()image_raw和其它属性一遍打包,并调用tf.python_io.TFRecordWriter将其写入到文件中。

至此,TFRecord文件生成完毕。

可见,将自定义图像转换为TFRecord的过程本质上是将大量图像的像素灰度值转换为Python bytes,并与其它Feature组合在一起,最终拼接成一个文件的过程。

需要注意的是其它Feature的类型不一定必须是BytesList,还可以是Int64List或者FloatList。

读取TFRecord文件

第二份代码fully_connected_reader.py [1]展示了如何从TFRecord文件中读取数据。

读取数据的函数名为input()。函数内部首先通过tf.train.string_input_producer()函数读取TFRecord文件,并返回一个queue;然后使用read_and_decode()读取一份数据,函数内部用tf.decode_raw()解析出图像的灰度值,用tf.cast()解析出label的值。之后通过tf.train.shuffle_batch()的方法生成一批用来训练的数据。并最终返回可供训练的imageslabels,并送入inference部分进行计算。

在这个过程中,有以下几点需要留意:

  1. tf.decode_raw()解析出的数据是没有shape的,因此需要调用set_shape()函数来给出tensor的维度。
  2. read_and_decode()函数返回的是单个的数据,但是后边的tf.train.shuffle_batch()却能够生成批量数据。
  3. 如果需要对图像进行处理的话,需要放在第二项提到的两个函数中间。

其中第2点的原理我暂时没有弄懂。从代码上看read_and_decode()返回的是单个数据,shuffle_batch接收到的也是单个数据,不知道是如何生成批量数据的,猜测与queue有关系。

所以,读取TFRecord文件的本质,就是通过队列的方式依次将数据解码,并按需要进行数据随机化、图像随机化的过程。


参考


  1. Github: fully_connected_reader.py ↩︎

最新文章

  1. spark读取hdfs上的文件和写入数据到hdfs上面
  2. weapp微信小程序初探demo
  3. caffe学习系列(7):Blob,layer,Net介绍
  4. nyoj305_表达式求值
  5. IE10 透明背景的div无法遮罩
  6. java总结第三次//类和对象2、3
  7. JSP或HTML命名规范
  8. TCL语言笔记:TCL中的数学函数
  9. 咦,为DJANGO的ORM的QUERYSET增加数据列的样码,很好用哟
  10. 【转】成为Java顶尖程序员 ,看这11本书就够了
  11. 定制属于自己的Chrome起始页
  12. 使用StackTrace堆栈跟踪记录详细日志(可获取行号)
  13. 经典的SQL语句面试题
  14. Spring的Service层与Dao层解析
  15. 现在都是python 单独开发框架 执行脚本,处理结果,发报告之类的
  16. 用keychain这个特点来保存设备唯一标识。
  17. Reverse Integer - 反转一个int,溢出时返回0
  18. ajax请求service报405错误 - 【服务器不允许的方法】
  19. Nodejs http-proxy代理实战应用
  20. CSS学习笔记_day2

热门文章

  1. 微博预计要火一阵的SleepSort之Shell及C实现
  2. HDU 5358(2015多校联合训练赛第六场1006) First One (区间合并+常数优化)
  3. 微信企业号回调模式配置解说 Java Servlet+Struts2版本号 echostr校验失败解决
  4. 微软继MVC5后,出现ASP.NET VNEXT
  5. JS遮罩层
  6. 神经网络中的激活函数——加入一些非线性的激活函数,整个网络中就引入了非线性部分,sigmoid 和 tanh作为激活函数的话,一定要注意一定要对 input 进行归一话,但是 ReLU 并不需要输入归一化
  7. Node.js:Web 模块
  8. WPF中StringToImage和BoolToImage简单用法
  9. Gym-100935I Farm 计算几何 圆和矩形面积交
  10. 函数的arguments