莫烦Python 4

新建模板小书匠

RNN Classifier 循环神经网络

问题描述

使用RNN对MNIST里面的图片进行分类

关键

SimpleRNN()参数

  • batch_input_shape

    使用状态RNN的注意事项

可以将RNN设置为‘stateful’,意味着由每个batch计算出的状态都会被重用于初始化下一个batch的初始状态。状态RNN假设连续的两个batch之中,相同下标的元素有一一映射关系。

要启用状态RNN,请在实例化层对象时指定参数stateful=True,并在Sequential模型使用固定大小的batch:通过在模型的第一层传入batch_size=(…)和input_shape来实现。在函数式模型中,对所有的输入都要指定相同的batch_size。

如果要将循环层的状态重置,请调用.reset_states(),对模型调用将重置模型中所有状态RNN的状态。对单个层调用则只重置该层的状态。

(samples,timesteps,input_dim)

代码

'''
RNN Classifier 循环神经网络
'''
import numpy as np
np.random.seed(1337) from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import SimpleRNN, Activation, Dense
from keras.optimizers import Adam time_step = 28
input_size = 28
batch_size = 50
output_size = 10
cell_size = 50
LR = 0.001 (X_train, y_train), (X_test, y_test) = mnist.load_data() X_train = X_train.reshape(-1, 28, 28) / 255. # normalize
X_test = X_test.reshape(-1, 28, 28) / 255. # normalize
y_train = np_utils.to_categorical(y_train, num_classes=10)
y_test = np_utils.to_categorical(y_test, num_classes=10) model = Sequential()
model.add(
SimpleRNN(
batch_input_shape=(None, time_step, input_size),
units=cell_size
)
) model.add(
Dense(output_size)
) model.add(Activation('softmax')) adam = Adam(LR)
model.compile(
optimizer=adam,
loss='categorical_crossentropy',
metrics=['accuracy']
)
model.summary() model.fit(X_train, y_train, batch_size=batch_size, epochs=2, verbose=2, validation_data=(X_test, y_test))

结果

Model: "sequential_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
simple_rnn_2 (SimpleRNN) (None, 50) 3950
_________________________________________________________________
dense_2 (Dense) (None, 10) 510
_________________________________________________________________
activation_2 (Activation) (None, 10) 0
=================================================================
Total params: 4,460
Trainable params: 4,460
Non-trainable params: 0
_________________________________________________________________
Train on 60000 samples, validate on 10000 samples
Epoch 1/2
- 12s - loss: 0.6643 - accuracy: 0.7966 - val_loss: 0.4501 - val_accuracy: 0.8550
Epoch 2/2
- 9s - loss: 0.3220 - accuracy: 0.9087 - val_loss: 0.2445 - val_accuracy: 0.9359

最新文章

  1. [Django]用户权限学习系列之Permission权限基本操作指令
  2. wpf listview
  3. MVC - Action和ActionResult
  4. 【BZOJ 2434】【NOI 2011】阿狸的打字机 fail树
  5. MyBatis知多少(3)
  6. 开发报表时将已有User做成下拉列表,第一项为label为ALL,value为null
  7. aptana 插件离线下载方式
  8. 设计新Xlator扩展GlusterFS[转]
  9. mysql数据库小常识
  10. Python3 SMTP发送邮件
  11. 在CMD命令下安装nexus报错和启动的问题
  12. asp.net core webapi 生成导出excel
  13. 单片机成长之路(51基础篇) - 008 C51 的标示符和关键字
  14. jdbc链接数据库的url两种写法
  15. 【转】理解js中的原型链,prototype与__proto__的关系
  16. TCP拥塞控制-慢启动、拥塞避免、快重传、快启动
  17. Ie11 的改变
  18. LVS原理详解(3种工作方式8种调度算法)
  19. innodb索引统计信息
  20. Centos 安装ImageMagick 与 imagick for php步骤详解

热门文章

  1. configurable 神图
  2. windows10环境下的RabbitMQ安装步骤
  3. 深度优先搜索算法-dfs讲解
  4. Vulhub 漏洞学习之:ECShop
  5. C#计时器 Stopwatch 使用demo
  6. 解决.Net Core3.0 修改cshtml代码之后必须重新生成才可以看到效果
  7. key对象转换数组title
  8. 《话糙理不糙》之如何在学习openfoam时避免坑蒙拐骗
  9. 大道至简的架构设计思想之:封装(C系架构设计法,sishuok)
  10. 多资产VAR风险--基于python处理