我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试。Tensorflow针对这一需求提供了Saver类。

  • Saver类提供了向checkpoints文件保存和从checkpoints文件中恢复变量的相关方法。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 。
  • 只要提供一个计数器,当计数器触发时,Saver类可以自动的生成checkpoint文件。这让我们可以在训练过程中保存多个中间结果。例如,我们可以保存每一步训练的结果。
  • 为了避免填满整个磁盘,Saver可以自动的管理Checkpoints文件。例如,我们可以指定保存最近的N个Checkpoints文件。

示例代码:

import tensorflow as tf
import numpy as np
from six.moves import xrange x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 2 w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss) #isTrain = True
isTrain = False
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = 'test/' saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1)) with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
if isTrain:
for i in xrange(train_steps):
sess.run(train, feed_dict={x: x_data})
if (i + 1) % checkpoint_steps == 0:
saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i + 1)
else:
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
else:
pass
print(sess.run(w))
print(sess.run(b)) y_result = sess.run(y_predict, feed_dict={x: np.reshape(4, (1, 1))})
print(y_result)

2.1 训练阶段

使用Saver.save()方法保存模型:
  1. sess:表示当前会话,当前会话记录了当前的变量值
  2. checkpoint_dir + 'model.ckpt':表示存储的文件名
  3. global_step:表示当前是第几步

训练完成后,当前目录底下会多出5个文件。

    打开名为“checkpoint”的文件,可以看到保存记录,和最新的模型存储位置。

2.2测试阶段

    测试阶段使用saver.restore()方法恢复变量:
  1. sess:表示当前会话,之前保存的结果将被加载入这个会话
  2. ckpt.model_checkpoint_path:表示模型存储的位置,不需要提供模型的名字,它会去查看checkpoint文件,看看最新的是谁,叫做什么。

运行结果如下图所示,加载了之前训练的参数w和b的结果

最新文章

  1. Android开发之《常用工具及文档汇总》
  2. C# 写的一个生成随机汉语名字的小程序
  3. Intent和Activity知识点总结
  4. 一个简单的代码计算行数demo编写
  5. 在Eclipse设置打开项目或文件目录
  6. Delphi XE5 如何与其他版本共存
  7. HTML5与CSS3权威指南.pdf3
  8. 基于Node.js的强大爬虫 能直接发布抓取的文章哦
  9. Hadoop 的常用组件一览
  10. linux_shell_类似sql的orderby 取最大值
  11. Linux Debian 7部署LEMP(Linux+Nginx+MySQL+PHP)网站环境
  12. sublime 使用快捷键
  13. css3整理--Animation
  14. Ubuntu关机时间过长,总是停在logo界面
  15. Code Signal_练习题_firstDigit
  16. OpenCV学习:播放avi视频文件
  17. SEH分析笔记(X64篇)
  18. NSArray最简单的倒序
  19. C++复习14 构造函数初始化调用顺序
  20. HDU2089:不要62——题解

热门文章

  1. MySQL将DESC等关键字作为列名表名的处理方式
  2. WordPress无插件实现SMTP给评论用户发送邮件提醒
  3. Java中的List集合和迭代器
  4. matlab数组和矩阵
  5. mvc core2.1 Identity.EntityFramework Core 导航状态栏(六)
  6. Codeforces Div3 #498 A-F
  7. 怎样去掉wordpress中默认的未分类目录
  8. 使用VUE搭建tab标签组件
  9. 计算机网络-数据结构-MAC帧头-IP头-TCP头-UDP头
  10. 使用C语言简单模拟Linux的cat程序