导入Keras函数模型

假设使用Keras的函数API开始定义一个简单的MLP:

from keras.models import Model
from keras.layers import Dense, Input

inputs = Input(shape=(100,))
x = Dense(64, activation='relu')(inputs)
predictions = Dense(10, activation='softmax')(x)
model = Model(inputs=inputs, outputs=predictions)
model.compile(loss='categorical_crossentropy',optimizer='sgd', metrics=['accuracy'])

在Keras,有几种保存模型的方法。可以将整个模型(模型定义、权重和训练配置)存储为HDF5文件,仅存储模型配置(作为JSON或YAML文件)或仅存储权重(作为HDF5文件):

model.save('full_model.h5')  # save everything in HDF5 format

model_json = model.to_json()  # save just the config. replace with "to_yaml" for YAML serialization
with open("model_config.json", "w") as f:
    f.write(model_json)

model.save_weights('model_weights.h5') # save just the weights.
如果你决定保存完整的模型,那么将能够访问模型的训练配置,否则将不访问。因此,如果想在导入之后在DL4J中进一步训练模型,请记住这一点,并使用model.save(...)来持久化模型。

载加Keras模型

将完整模型加载回DL4J(假设它在类路径上):

String fullModel = new ClassPathResource("full_model.h5").getFile().getPath();
ComputationGraph model = KerasModelImport.importKerasModelAndWeights(fullModel);

万一没有编译Keras模型,它就不会有一个训练配置。在这种情况下,需要显式地告诉模型导入忽略训练配置,方法是将enforceTrainingConfig标志设置为false,如下所示:

ComputationGraph model = KerasModelImport.importKerasModelAndWeights(fullModel, false);

若要仅从JSON加载模型配置,请按如下使用KerasModelImport

String modelJson = new ClassPathResource("model_config.json").getFile().getPath();
ComputationGraphConfiguration modelConfig = KerasModelImport.importKerasModelConfiguration(modelJson)

如果另外还想加载模型权重与配置,那么需要做:

String modelWeights = new ClassPathResource("model_weights.h5").getFile().getPath();
MultiLayerNetwork network = KerasModelImport.importKerasModelAndWeights(modelJson, modelWeights)
在后面两种情况下,将不读取训练配置。

KerasModel

Github:KerasModel.java - 从Keras(函数API)模型或序列模型配置构建计算图

KerasModel(建议)

public KerasModel(KerasModelBuilder modelBuilder)
            throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException 
// 函数API模型的构建器模式构造器
参数 modelBuilder 构建器对象
抛出 IOException IO 异常
抛出 InvalidKerasConfigurationException 无效的 Keras 配置
抛出 UnsupportedKerasConfigurationException 不支持的 Keras 配置

getComputationGraphConfiguration(不推荐)

public ComputationGraphConfiguration getComputationGraphConfiguration()
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException 
// 来自模型配置(JSON或YAML)、训练配置(JSON)、权重和“训练模式”布尔指示符的(函数 API)模型的构造器。当内置在训练模式时,某些不支持的配置(例如,未知的正则化器)将抛出异常。当强制TrainingConfig= false时,这些将生成警告,但将被忽略。
参数 modelJson 模型配置JSON 字符串
参数 modelYaml 模型配置 YAML 字符串
参数 enforceTrainingConfig 是否实施训练相关配置
抛出 IOException IO 异常
抛出 InvalidKerasConfigurationException 无效的 Keras 配置
抛出 UnsupportedKerasConfigurationException 不支持的 Keras 配置

getComputationGraph

public ComputationGraph getComputationGraph()
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException 
// 从这个Keras模型配置构建计算图并导入权重
返回 ComputationGraph

getComputationGraph

public ComputationGraph getComputationGraph(boolean importWeights)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException 
// 从这个Keras模型配置构建计算图并(可选的)导入权重。
参数 importWeights 是否导入权重
返回 ComputationGraph

最新文章

  1. Java调优知识汇总
  2. 1125MySQL Sending data导致查询很慢的问题详细分析
  3. Android图表类库:WilliamChart
  4. ZooKeeper日志与快照文件简单分析
  5. 【leetcode】Word Ladder II
  6. 配置SharePoint 2013 Search 拓扑结构
  7. 【编程之美】计算1-N中含1的个数
  8. Android Studio API 文档_下载与使用
  9. (ssh整合web导出excel)在ssh框架中使用poi正确导出具有比较高级固定格式的excel 整体过程,查询导出前后台下载
  10. js关闭当前页面不弹出提示的方法
  11. python的xml模块用法
  12. 腾讯云CentOS7.4服务器添加swap分区
  13. windows10环境下安装Tensorflow
  14. Rpgmakermv(31)MOG插件与YEP的结合
  15. oracle 11g AUTO_SAMPLE_SIZE动态采用工作机制
  16. VS中生成时“sgen.exe”已退出,代码为 1解决办法
  17. mysql数据库进阶篇
  18. Ubuntu登录Windows Server 2008r2 密码总是错误与NLA验证
  19. loadrunner11--集合点(Rendezvous )菜单是灰色不能点击
  20. TCP/IP、SOCKET、HTTP之间的联系与区别

热门文章

  1. [java] 将整数在千分位或万分位以逗号分隔表示
  2. APScheduler 3.0.1浅析
  3. 前端知识点回顾——HTML,CSS篇
  4. 信息学竞赛一本通提高版AC题解—例题1.1活动安排
  5. Handle的特点
  6. PCL中有哪些可用的PointT类型(2)
  7. Java对象和集合的拷贝/克隆/复制
  8. 小D课堂-SpringBoot 2.x微信支付在线教育网站项目实战_5-7.授权登录获取微信用户个人信息实战
  9. 小D课堂 - 新版本微服务springcloud+Docker教程_4-03 高级篇幅之Ribbon负载均衡源码分析实战
  10. NFS PersistentVolume(8)