关键代码:
# Excerpt: how the tflearn.DNN wrapper is constructed in the demo below.
# checkpoint_path is the checkpoint filename prefix and max_checkpoints=10
# limits how many are kept; clip_gradients=0. presumably disables gradient
# clipping. NOTE(review): semantics taken from tflearn's API, confirm there.
tflearn.DNN(net, checkpoint_path='model_resnet_cifar10',
max_checkpoints=10, tensorboard_verbose=0,
clip_gradients=0.)
# Excerpt from the model.fit(...) argument list, not a standalone statement.
# Run on its own, this line would bind the tuple (True,) to snapshot_epoch.
snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.
我的demo:
def get_model(width, height, classes=40):
    """Build a small residual network wrapped in a ``tflearn.DNN``.

    Layout:
    - a 3x3 conv with 16 filters;
    - three stages of residual blocks with 16, 32 and 64 filters, where the
      32- and 64-filter stages start with a downsampling block;
    - batch normalisation + ReLU and global average pooling;
    - a softmax classifier trained with momentum SGD on categorical
      cross-entropy.

    With ``n`` residual blocks per stage the network has ``6 * n + 2``
    layers (14 for the ``n = 2`` used here).

    Args:
        width: Size of axis 1 of the input tensor ``[None, width, height, 3]``.
        height: Size of axis 2 of the input tensor.
        classes: Number of output units of the softmax layer.

    Returns:
        A ``tflearn.DNN`` configured with ``checkpoint_path=
        'model_resnet_cifar10'`` and ``max_checkpoints=10``.
    """
    # TODO: modify model
    # 3 input channels (RGB), e.g. shape [None, 224, 224, 3].
    # NOTE(review): for non-square images, confirm that image_preloader yields
    # arrays in (width, height) rather than (height, width) order.
    network = input_data(shape=[None, width, height, 3])

    # Residual blocks. Depth = 6n + 2:
    # 32 layers: n=5, 56 layers: n=9, 110 layers: n=18.
    n = 2
    net = tflearn.conv_2d(network, 16, 3, regularizer='L2', weight_decay=0.0001)
    net = tflearn.residual_block(net, n, 16)
    net = tflearn.residual_block(net, 1, 32, downsample=True)
    net = tflearn.residual_block(net, n - 1, 32)
    net = tflearn.residual_block(net, 1, 64, downsample=True)
    net = tflearn.residual_block(net, n - 1, 64)
    net = tflearn.batch_normalization(net)
    net = tflearn.activation(net, 'relu')
    net = tflearn.global_avg_pool(net)

    # Regression (classification head).
    net = tflearn.fully_connected(net, classes, activation='softmax')
    # Learning rate 0.01 with staircase decay: multiplied by lr_decay=0.1 every
    # decay_step=2000 steps (assuming tflearn's exponential-decay semantics).
    # Previous setting: Momentum(0.1, lr_decay=0.1, decay_step=32000, staircase=True).
    mom = tflearn.Momentum(0.01, lr_decay=0.1, decay_step=2000, staircase=True)
    net = tflearn.regression(net, optimizer=mom,
                             loss='categorical_crossentropy')

    # Training wrapper. clip_gradients=0. presumably turns gradient clipping
    # off (tflearn default is 5.0); confirm against the tflearn version in use.
    model = tflearn.DNN(net, checkpoint_path='model_resnet_cifar10',
                        max_checkpoints=10, tensorboard_verbose=0,
                        clip_gradients=0.)
    return model


def main():
    """Train the network on folder-organised images and save it twice.

    Steps:
    - load ``data/train`` and ``data/test`` with ``image_preloader``
      (``mode='folder'``, one-hot labels, normalised pixels);
    - build a 3755-class model and try to resume from a fixed checkpoint;
    - train with an early-stopping callback;
    - save a full model to ``tflearn_resnet/model.tflearn``;
    - save an inference copy, with the graph's TRAIN_OPS collection emptied,
      to ``tflearn_resnet/model-infer.tflearn``.

    NOTE(review): ``width``, ``height`` and ``EarlyStoppingCallback`` are not
    defined in this function; they are assumed to be module-level names
    defined elsewhere in the file. Confirm.
    """
    import os

    trainX, trainY = image_preloader("data/train", image_shape=(width, height, 3),
                                     mode='folder', categorical_labels=True, normalize=True)
    testX, testY = image_preloader("data/test", image_shape=(width, height, 3),
                                   mode='folder', categorical_labels=True, normalize=True)
    # For grayscale data, reshape to [-1, width, height, 1] and use a single
    # input channel in get_model().

    print("sample data:")
    print(trainX[0])
    print(trainY[0])
    print(testX[-1])
    print(testY[-1])

    model = get_model(width, height, classes=3755)
    model_file = 'tflearn_resnet/model.tflearn'
    infer_file = 'tflearn_resnet/model-infer.tflearn'
    resume_checkpoint = "model_resnet_cifar10-195804"

    # Try to resume training from an earlier checkpoint. This is best-effort:
    # if loading fails, report it and train from scratch instead of crashing.
    # Unlike a bare except, this does not swallow KeyboardInterrupt/SystemExit.
    try:
        model.load(resume_checkpoint)
        print("Model loaded OK. Resume training!")
    except Exception as exc:
        print("Could not load checkpoint %r, training from scratch: %s"
              % (resume_checkpoint, exc))

    early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.94)
    try:
        model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,
                  snapshot_epoch=True,  # Snapshot (save & evaluate) model every epoch.
                  show_metric=True, batch_size=1024, callbacks=early_stopping_cb,
                  run_id='cnn_handwrite')
    except StopIteration:
        # Presumably raised by EarlyStoppingCallback once val_acc_thresh is
        # reached (its definition is not shown here). Treated as normal completion.
        print("OK, stop iterate!Good!")

    # TF1's Saver refuses to save into a missing parent directory, so make sure it exists.
    os.makedirs(os.path.dirname(model_file), exist_ok=True)
    model.save(model_file)

    # Empty the TRAIN_OPS collection in place before the second save, presumably
    # so the exported model carries no training ops (inference only).
    del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]
    os.makedirs(os.path.dirname(infer_file), exist_ok=True)
    model.save(infer_file)

最新文章

  1. Lind.DDD.ConfigConstants统一管理系统配置
  2. JSPatch 实现原理详解
  3. 武汉新芯:定位存储器制造,两年后或推3D NAND
  4. 基于DDD的.NET项目搭建
  5. python--对于装饰器的理解
  6. 笔记-windbg及时调试
  7. vue基础入门
  8. Electron应用使用electron-builder配合electron-updater实现自动更新(windows + mac)
  9. Construct Binary Tree from Preorder and Inorder Traversal(根据前序中序构建二叉树)
  10. [Reinforcement Learning] 马尔可夫决策过程
  11. Centos解除端口占用
  12. 【css】适配iphoneX
  13. hdu-1176免费馅饼
  14. 数据绑定和第一个AngularJS Web应用
  15. linux用户的增加与删除
  16. Thunder团队第七周 - Scrum会议3
  17. redhat9安装gcc(转)
  18. openresty安装
  19. Jmeter非GUI分布式测试
  20. ubuntu下使用code::blocks编译运行一个简单的gtk+2.0项目

热门文章

  1. storm笔记:Storm+Kafka简单应用
  2. HttpClient 模拟登录搜狐微博
  3. C# 字节数组拼接的速度实验(Array.Copy(),Buffer.BlockCopy(),Concat())
  4. ETL拉链算法汇总大全
  5. 笔记本WIFI卡简介
  6. zabbix 3.2.4 安装
  7. golang 格式化时间成datetime
  8. 【Android】第三方库使用的问题集
  9. 再看GS线程
  10. Devexpress Spreadsheet 中文教程