#基于mnist数据集的手写数字识别

#构造了cnn网络拟合识别函数,前两层为卷积层,第三层为池化层,第四层为Flatten层,最后两层为全连接层

#基于Keras 2.1.1 Tensorflow 1.4.0

代码:

 from __future__ import print_function
import numpy as np
np.random.seed(1337)
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Convolution2D, MaxPooling2D
from keras.utils import np_utils
from keras import backend as K batch_size = 128
nb_classes = 10
nb_epoch = 12 img_rows, img_cols = 28, 28
nb_filters = 32
pool_size = (2,2)
kernel_size = (3,3)
(X_train, y_train), (X_test, y_test) = mnist.load_data() X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255 Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)
# 建立序贯模型
model = Sequential() model.add(Convolution2D(nb_filters, kernel_size[0] ,kernel_size[1],border_mode='valid',input_shape=input_shape))
model.add(Activation('relu')) model.add(Convolution2D(nb_filters, kernel_size[0], kernel_size[1]))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=pool_size))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(nb_classes))
model.add(Activation('softmax')) model.summary()
model.compile(loss='categorical_crossentropy',optimizer='adadelta',metrics=['accuracy'])
model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch,verbose=1, validation_data=(X_test, Y_test)) score = model.evaluate(X_test, Y_test, verbose=0)
print('Test score:', score[0])
print('Test accuracy:', score[1])

最新文章

  1. Win10系统出问题?简单一招即可修复win10!
  2. Warning:mailcious javascript detected on this domain来由
  3. 【UVA 11401】Triangle Counting
  4. 超链接点击后不显示hover
  5. maven搭建java ee项目
  6. shell 循环使用
  7. RGB色彩模式
  8. ios开发-确定/自适应textView的高度
  9. [转载]ubuntu Atheros Communications Device 1083 驱动
  10. HDU 5514 Frogs (容斥原理)
  11. Oracle impdp通过network_link不落地方式导入数据
  12. 【Mysql知识补充】
  13. javscript eval()的优缺点与web安全防范
  14. vue+vuecli+webapck2实现多页面应用
  15. Cent OS安装使用ffmpeg(完整版)
  16. webView 获取内容高度不准确的原因是因为你设置了某个属性
  17. android开发环境配置以及测试所遇到的的问题
  18. 批处理DOS基础命令
  19. 【Hadoop学习之九】MapReduce案例分析一-天气
  20. 吴恩达讲了干货满满的一节全新AI课,全程手写板书充满诚意非常干货

热门文章

  1. Python小技巧整理
  2. 判断php的运行模式
  3. mybatis分页插件pageHelper简单实用
  4. @总结 - 12@ burnside引理与pólya定理
  5. HZOJ Function
  6. 17-1 djanjo进阶-路由,视图,模板
  7. react 问题记录
  8. Android 高仿QQ滑动弹出菜单标记已读、未读消息
  9. laravel 授权使用gate门类
  10. js 数组的拼接