我们在训练好模型的时候,通常是要将模型进行保存的,以便于下次能够直接的将训练好的模型进行载入。

1.保存模型

首先需要建立一个saver,然后在session中通过saver的save即可将模型保存起来,具体的代码流程如下

# 前面的是定义好的模型结构

# 前面的代码是模型的定义代码

saver = tf.train.Saver()    # 生成saver

with tf.Session() as sess:
sess.run(init) # 模型的初始化
#
# 模型的训练代码,当模型训练完毕后,下面就可以对模型进行保存了
#
saver.save(sess, "model/linear") # 当路径不存在时,会自动创建路径

2.载入模型

将模型保存后,在保存的路径中,可以看到生成的模型路径,下面我们就能够加载模型了:

saver = tf.train.Saver()

with tf.Session() as sess:
# 可以对模型进行初始化,也可以不进行模型的初始化,因为后面的加载会覆盖之前的
# 初始化操作
sess.run(init) saver.restore(sess, "model/linear")

下面我们以linearmodel为例进行讲解:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5 plt.plot(train_x, train_y, 'r.')
plt.grid(True)
plt.show() X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32) w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias') z = tf.multiply(X, w) + b cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) init = tf.global_variables_initializer() training_epochs = 20
display_step = 2 saver = tf.train.Saver() if __name__ == '__main__':
with tf.Session() as sess:
sess.run(init)
if os.path.exists("model/"):
saver.restore(sess, "model/linear") w_, b_ = sess.run([w, b]) print(" Finished ")
print("W: ", w_, " b: ", b_)
plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
plt.grid(True)
plt.show()
else:
loss_list = []
for epoch in range(training_epochs):
for (x, y) in zip(train_x, train_y):
sess.run(optimizer, feed_dict={X: x, Y: y}) if epoch % display_step == 0:
loss = sess.run(cost, feed_dict={X: x, Y: y})
loss_list.append(loss)
print('Iter: ', epoch, ' Loss: ', loss) w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y}) saver.save(sess, "model/linear") print(" Finished ")
print("W: ", w_, " b: ", b_, " loss: ", loss)
plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
plt.grid(True)
plt.show()

3.查看模型的内容

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
modeldir = 'model/'
print_tensors_in_checkpoint_file(modeldir + 'linear.cpkt', None, True)

在上述使用saver的代码中,我们还可以将参数放入Saver中实现指定存储参数的功能,可以指定存储变量名字和变量的对应关系,如下形式:

saver = tf.train.Saver({'weight_':w, 'bias_':b})
# saver = tf.train.Saver([w, b])

最新文章

  1. 常用mysql语句
  2. 升级SSH
  3. 2016年11月29日 星期二 --出埃及记 Exodus 20:20
  4. MySQL下载、安装和修改root密码
  5. Codeforces Round #261 (Div. 2)
  6. 路由器刷机常见第三方固件及管理前端种类(OpenWrt、Tomato、DD-Wrt)
  7. angularJS中如何写服务
  8. Oracle EBS-SQL (PO-3):检查期间手工下达的采购订单记录数.sql
  9. cocos2d-3.x 创建动画
  10. C和C++运行库
  11. 学以致用十九-----shell脚本之引号
  12. java 线程(七)等待与唤醒
  13. ReactiveX 学习笔记(4)过滤数据流
  14. 50x页面放到本地单独目录下,进行显示
  15. Python: ljust()|rjust()|center()字符串对齐
  16. Kafka消息队列
  17. iOS 9应用开发教程之多行读写文本ios9文本视图
  18. UI领域中常常听见的''modal''到底是什么?
  19. 再次谈谈easyui datagrid 的数据加载
  20. Thrift编译错误('::malloc' has not been declared)

热门文章

  1. 洛谷P1310 表达式的值 题解 栈/后缀表达式的应用
  2. laravel中将session由文件保存改为数据库保存
  3. Python--day42--mysql创建用户及授权
  4. java List接口中常用类
  5. codeforce 382 div2 E —— 树状dp
  6. H3C 环路避免机制四:定义最大值
  7. 2018.12.7 浪在ACM 集训队第八次测试赛
  8. Linux 内核kobject 缺省属性
  9. C语言 屏幕截图 (GDI)
  10. [板子]KMP