【新闻】:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测、医学图像、时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会。微信:cyx645016617.

参考目录:

本文主要讲述TF2.0的模型文件的存储和载入的多种方法。主要分成两类型:模型结构和参数一起载入,模型的结构载入。

1 模型的构建

import tensorflow.keras as keras

class CBR(keras.layers.Layer):
def __init__(self,output_dim):
super(CBR,self).__init__()
self.conv = keras.layers.Conv2D(filters=output_dim, kernel_size=4, padding='same', strides=1)
self.bn = keras.layers.BatchNormalization(axis=3)
self.ReLU = keras.layers.ReLU() def call(self, inputs):
inputs = self.conv(inputs)
inputs = self.ReLU(self.bn(inputs))
return inputs class MyNet(keras.Model):
def __init__ (self):
super(MyNet,self).__init__()
self.cbr1 = CBR(16)
self.maxpool1 = keras.layers.MaxPool2D(pool_size=(2,2))
self.cbr2 = CBR(32)
self.maxpool2 = keras.layers.MaxPool2D(pool_size=(2,2)) def call(self, inputs):
inputs = self.maxpool1(self.cbr1(inputs))
inputs = self.maxpool2(self.cbr2(inputs))
return inputs model = MyNet()

部分朋友可以发现,上面的代码就是上一次课程所构建的一个自定义的网络。

我们现在需要展示这个模型的框架:

model.build((16,224,224,3))
print(model.summary())

运行结果为:

这里需要对网络执行一个构建.build()函数,之后才能生成model.summary()这样的模型的描述。 这是因为模型的参数量是需要知道输入数据的通道数的,假如我们输入的是单通道的图片,那么就是:

model.build((16,224,224,1))
print(model.summary())

输出结果为:

2 结构参数的存储与载入

model.save('save_model.h5')
new_model = keras.models.load_model('save_model.h5')

这里并不能保存成功,出现这样的错误:

大概的意思就是:因为你的模型不是官方的模型,是自定义的,所以并不能同时保存结构和参数。只有官方的模型可以时候上面的保存的方法,同时保存参数和权重;自定义的模型建议只保存参数

3 参数的存储与载入

model.save_weights('model_weight')
new_model = MyNet()
new_model.load_weights('model_weight')

这样子就可以保存自定义的模型了。在对应的目录下会出现这几个文件:

我们来看一下原来的模型和载入的模型对于同一个样本给出的结果是否相同:

# 看一下原来的模型和载入的模型预测相同的样本的输出
test = tf.ones((1,8,8,3))
prediction = model.predict(test)
new_prediction = new_model.predict(test)
print(prediction,new_prediction)
>>> [[[[0.02559286]]]] [[[[0.02559286]]]]

结果相同,载入的没有问题~

4 结构的存储与载入

结构的存储有两种方法:

  • model.get_config()
  • model.to_json()

需要注意的是,上面的两个方法和save的问题一样,是不能用在自定义的模型中的,如果你在其中使用了自定义的Layer类,那么只能!只能用save_weights的方式进行保存

下面依然给出这两种方法的代码,对于简单的、已经封装好的一些网络层构成的网络,是可以使用这些的。我个人还是常用save_weights啦

# 第一种方法
config = model.get_config()
reinitialized_model = keras.Model.from_config(config)
# 第二种方法
json_config = model.to_json()
# 把json写的文件中
with open('model_config.json', 'w') as json_file:
json_file.write(json_config)
# 读取本地json文件
with open('model_config.json') as json_file:
json_config = json_file.read()
reinitialized_model = keras.models.model_from_json(json_config)

今天的内容就是这么多,虽然提供了四种方法,但是对于自定义程度较高的模型,还是要使用save_weights哦~

最新文章

  1. openscales实现漂亮的冒泡效果
  2. 编写 Unity Editor 插件
  3. bzoj4403: 序列统计
  4. jqgrid如何在一个页面点击按钮后,传递参数到新页面
  5. Naive Bayes理论与实践
  6. 开源项目:X265
  7. create-maximum-number(难)
  8. quartz简单实现
  9. JavaScript操作剪贴板(转)
  10. android入门——UI(6)——ViewPager+Menu+PopupWindow
  11. stl 迭代子的失效
  12. c# gdi+输出成不同mime类型的图片
  13. CF374 Journey
  14. [Lugu3380]【模板】二逼平衡树(树套树)
  15. Python爬虫——Python 岗位分析报告
  16. Elasticsearch 创建以及修改索引结构
  17. mysql外键使用
  18. 本地上传文件至服务器的技巧(linux文件压缩及解压文件)
  19. java基础---->Java的格式化输出
  20. ORB feature(O for orientation)

热门文章

  1. 【NodeJS】-init
  2. EditText设置输入的类型,只能输入纯数字,只能输入手机号码,只能输入邮箱等等。
  3. 如何自制WC3地形纹理贴图
  4. 集成react-native-image-picker时,报错Couldn't get file path for photo
  5. vue实现局部预览打印
  6. google protocol buffer——protobuf的问题及改进一
  7. el-select 封装
  8. PyTorch ResNet 使用与源码解析
  9. 零基础一分钟入门Python
  10. Vue和d3.js(v4)力导向图force结合使用,v3版本升级v4【一】