import numpy as np
from keras.datasets import mnist
from keras.models import Sequential, Model
from keras.layers.core import Dense, Activation, Dropout
from keras.utils import np_utils import matplotlib.pyplot as plt
import matplotlib.image as processimage # Load mnist RAW dataset
# 训练集28*28的图片X_train = (60000, 28, 28) 训练集标签Y_train = (60000,1)
# 测试集图片X_test = (10000, 28, 28) 测试集标签Y_test = (10000,1)
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()
print(X_train.shape, Y_train.shape)
print(X_test.shape, Y_test.shape) '''
第一步,准备数据
'''
# Prepare 准备数据
# Reshape 60k个图片,每个28*28的图片,降维成一个784的一维数组
X_train = X_train.reshape(60000, 784) # 28*28 = 784
X_test = X_test.reshape(10000, 784)
# set type into float32 设置成浮点型,因为使用的是GPU,GPU可以加速运算浮点型
# CPU使用int型计算会更快
X_train = X_train.astype('float32') # astype SET AS TYPE INTO
X_test = X_test.astype('float32')
# 归一化颜色
X_train = X_train/255 # 除以255个颜色,X_train(0, 255)-->(0, 1) 更有利于浮点运算
X_test = X_test/255 '''
第二步,给神经网络设置基本参数
'''
# Prepare basic setups
batch_sizes = 4096 # 一次给神经网络注入多少数据,别超过6万,和GPU内存有关
nb_class = 10 # 设置多少个分类
nb_epochs = 10 # 60k数据训练20次,一般小数据10次就够了 '''
第三步,设置标签
'''
# Class vectors label(7) into [0,0,0,0,0,0,0,1,0,1] 把7设置成向量
Y_test = np_utils.to_categorical(Y_test, nb_class) # Label
Y_train = np_utils.to_categorical(Y_train, nb_class) '''
第四步,设置网络结构
'''
model = Sequential() # 顺序搭建层
# 1st layer
model.add(Dense(512, input_shape=(784,))) # Dense是输出给下一层, input_dim = 784 [X*784]
model.add(Activation('relu')) # tanh
model.add(Dropout(0.2)) # overfitting # 2nd layer
model.add(Dense(256)) # 256是因为上一层已经输出512了,所以不用标注输入
model.add(Activation('relu'))
model.add(Dropout(0.2)) # 3rd layer
model.add(Dense(10))
model.add(Activation('softmax')) # 根据10层输出,softmax做分类 '''
第五步,编译compile
'''
model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
metrics=['accuracy']
) # 启动网络训练 Fire up
Trainning = model.fit(
X_train, Y_train,
batch_size=batch_sizes,
epochs=nb_epochs,
validation_data=(X_test, Y_test)
)
# 以上就可运行 '''
最后,检查工作
'''
# Trainning.history # 检查训练历史
# Trainning.params # 检查训练参数 # 拉取test里的图
testrun = X_test[9999].reshape(1, 784) testlabel = Y_test[9999]
print('label:-->', testlabel)
print(testrun.shape)
plt.imshow(testrun.reshape([28, 28])) # 判断输出结果
pred = model.predict(testrun)
print(testrun)
print('label of test same Y_test[9999]-->>', testlabel)
print('预测结果-->>', pred)
print([final.argmax() for final in pred]) # 找到pred数组中的最大值 # 用自己的画的图28*28预测一下 (不太准,可以用卷积)
# 可以用PS创建28*28像素的图,且是灰度,没有色彩
target_img = processimage.imread('/.../picture.jpg')
print(' before reshape:->>', target_img.shape)
plt.imshow(target_img)
target_img = target_img.reshape(1, 784) # reshape
print(' after reshape:->>', target_img.shape) target_img = np.array(target_img) # img --> numpy array
target_img = target_img.astype('float32') # int --> float32
target_img /= 255 # (0,255) --> (0,1) print(target_img) mypred = model.predict(target_img)
print(mypred)
print(myfinal.argmax() for myfinal in mypred)

参考:https://www.bilibili.com/video/av29806227

最新文章

  1. [Java] SoapUI使用Java获取各时间日期方法
  2. link和@import导入css文件的区别
  3. codevs 1281 Xn数列
  4. (转)TeamCity配置笔记
  5. Helpers\Sessions
  6. js 比较日期大小
  7. Swift 的类、结构体、枚举等的构造过程Initialization(下)
  8. 002Java概述
  9. windows FileZilla Server 开启FTP over TLS
  10. 【XSY2665】没有上司的舞会 LCT DP
  11. 【tmos】使用joda-time来个格式化时间
  12. 南方IT学校期末PCB结课项目考试(实操)说明书
  13. 操作系统学习笔记(二) 页式映射及windbg验证方式
  14. POJ 2135.Farm Tour 消负圈法最小费用最大流
  15. mongodb的搭建
  16. 【转】每天一个linux命令(57):ss命令
  17. POJ.1330 Nearest Common Ancestors (LCA 倍增)
  18. #import 指令
  19. osx下查看jar文件
  20. Android中Tablayout设置下划线宽度 和 dp和px之间进行相互转换

热门文章

  1. java常用API的总结(1)
  2. ASP.NET Core中使用GraphQL - 第七章 Mutation
  3. 『练手』003 Laura.SqlForever如何扩展 兼容更多数据库引擎
  4. c# 获取当前时间的微秒
  5. AspNetCore 中使用 InentityServer4(2)
  6. c#编写一个简单的http服务器
  7. HTML中的Hack条件注释语句
  8. React 16.x 新特性思维导图
  9. vue中使用provide和inject刷新当前路由(页面)
  10. layui 轮播图动态数据不显示问题