#基于mnist数据集的手写数字识别

#构造了三层全连接层组成的多层感知机,最后一层为输出层

#基于Keras 2.1.1 Tensorflow 1.4.0

代码:

 import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense,Dropout
from keras.optimizers import RMSprop (x_train,y_train),(x_test,y_test) = mnist.load_data()
#载入数据,第一次运行时会从外部网络下载数据集到对应目录下
print(x_train.shape,y_train.shape)
print(x_test.shape,y_test.shape) # import matplotlib.pyplot as plt
# im = plt.imshow(x_train[0],cmap='gray')
# plt.show()
# im2 = plt.imshow(x_train[1],cmap='gray')
# plt.show()
x_train = x_train.reshape(60000,784)
x_test = x_test.reshape(10000,784)
x_train = x_train.astype('float32')
x_train = x_train.astype('float32')
print(x_train.shape)
x_train = x_train/255
x_test = x_test/255
y_train = keras.utils.to_categorical(y_train,10)
y_test = keras.utils.to_categorical(y_test,10) model = Sequential()
model.add(Dense(512,activation='relu',input_shape=(784,)))
model.add(Dropout(0.2))
model.add(Dense(512,activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(10,activation='softmax'))
model.summary()
model.compile(loss='categorical_crossentropy',optimizer=RMSprop(),metrics=['accuracy'])
model.fit(x_train,y_train,batch_size=64,epochs=2,verbose=1,validation_data=(x_test,y_test))
score = model.evaluate(x_test,y_test,verbose=1)
print('Test loss:',score[0])
print('Test accuracy',score[1])

结果:

Test loss: 0.123420921481
Test accuracy 0.9682

最新文章

  1. scheduletask任务调度(2间隔时间)
  2. C语言-《通讯录》
  3. java 缓冲
  4. C# 4.0中dynamic的作用
  5. Java Web的开始学习
  6. What is Proguard?
  7. 查看80端口被占用的方法(IIS、apmserv、system)
  8. L003-oldboy-mysql-dba-lesson03
  9. android:layout_gravity="bottom"不起作用问题
  10. poj1163The Triangle(简单DP)
  11. Chapter 1 First Sight——32
  12. c#DES加密解密代码
  13. 打开控制台F12弹出弹窗
  14. jquery-- json字符串没有自动包装为 json对象
  15. EBS查询默认应用用户,比如是否需要锁定、修改这些用户
  16. BFS-广度优先遍历
  17. 【Python】【Flask】前端调用后端方法
  18. MIPS rop gadgets记录贴&&持续更新
  19. svn 和 git的区别
  20. python 判断是否是元音字母

热门文章

  1. spoj Distinct Substrings 后缀数组
  2. Ubuntu 如何编译安装第三方库
  3. BZOJ 1008 越狱题解
  4. 转载 LibGDX: 使用 Gradle 命令运行和打包项目
  5. seleium 滑动到底部
  6. oracle获取中文出现乱码问题解决
  7. MongoDB -- JAVA基本API操作
  8. python 数据集变量的数据类型总结
  9. Flask学习之十二 使用boostrap
  10. laravel 队列重启