Python深度学习读书笔记-2.初识神经网络
2024-08-28 04:02:06
MNIST 数据集
包含60 000 张训练图像和10 000 张测试图像,由美国国家标准与技术研究院(National Institute of Standards and Technology,即MNIST 中
的NIST)在20 世纪80 年代收集得到。
类和标签
在机器学习中,分类问题中的某个类别叫作类(class)。数据点叫作样本(sample)。某
个样本对应的类叫作标签(label)。
MNIST 数据集预先加载在Keras 库中,其中包括4 个Numpy 数组。
from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images 和train_labels 组成了训练集(training set),模型将从这些数据中进行
学习。然后在测试集(test set,即test_images 和test_labels)上对模型进行测试。
图像被编码为Numpy 数组,而标签是数字数组,取值范围为0~9。图像和标签一一对应。
我们来看一下训练数据:
>>> train_images.shape
(60000, 28, 28)
>>> len(train_labels)
60000
>>> train_labels
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)
测试数据:
>>> test_images.shape
(10000, 28, 28)
>>> len(test_labels)
10000
>>> test_labels
array([7, 2, 1, ..., 4, 5, 6], dtype=uint8)
神经网络架构
from keras import models
from keras import layers
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,)))
network.add(layers.Dense(10, activation='softmax'))
本例中的网络包含2 个Dense 层,它们是密集连接(也叫全连接)的神经层。第二层(也
是最后一层)是一个10 路softmax 层,它将返回一个由10 个概率值(总和为1)组成的数组。
每个概率值表示当前数字图像属于10 个数字类别中某一个的概率。
要想训练网络,我们还需要选择编译(compile)步骤的三个参数。
损失函数(loss function):网络如何衡量在训练数据上的性能,即网络如何朝着正确的
方向前进。
优化器(optimizer):基于训练数据和损失函数来更新网络的机制。
在训练和测试过程中需要监控的指标(metric):本例只关心精度,即正确分类的图像所
占的比例。
编译步骤
network.compile(optimizer='rmsprop',loss='categorical_crossentropy', metrics=['accuracy'])
在开始训练之前,我们将对数据进行预处理,将其变换为网络要求的形状,并缩放到所
有值都在[0, 1] 区间。比如,之前训练图像保存在一个uint8 类型的数组中,其形状为
(60000, 28, 28),取值区间为[0, 255]。我们需要将其变换为一个float32 数组,其形
状为(60000, 28 * 28),取值范围为0~1。
准备图像数据
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype('float32') / 255
test_images = test_images.reshape((10000, 28 * 28))
test_images = test_images.astype('float32') / 255
准备标签
from keras.utils import to_categorical
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
开始训练网络
>>> network.fit(train_images, train_labels, epochs=5, batch_size=128)
Epoch 1/5
60000/60000 [=============================] - 9s - loss: 0.2524 - acc: 0.9273
Epoch 2/5
51328/60000 [=======================>.....] - ETA: 1s - loss: 0.1035 - acc: 0.9692
检查模型在测试集上的性能
>>> test_loss, test_acc = network.evaluate(test_images, test_labels)
>>> print('test_acc:', test_acc)
test_acc: 0.9785
最新文章
- Troubleshooting:重新安装Vertica建库后无法启动
- 剑指Offer面试题:31.两个链表的第一个公共节点
- SQL Server里的闩锁耦合(Latch Coupling)
- js 日期对象Date以及传参
- java 27 - 8 反射之 通过反射来设置某个对象的某个属性为指定值
- Google Protocol Buffer 简单介绍
- Codeforces Round #258 E Devu and Flowers --容斥原理
- firefox与chrome中对select下拉框中的option支持问题
- 递推DP 赛码 1005 Game
- 【Todo】Mybatis学习-偏理论
- 20145120黄玄曦 《java程序设计》 寒假学习总结
- 使用CXF与Spring集成实现RESTFul WebService
- Mac OS命令行运行Sublime Text
- c++ 资源索引
- support.SerializationFailedException: Failed to deserialize payload.
- ssl通关的概念(一个)
- Oracle表空间及分区表
- springboot07-security
- Struts2与spingmvc区别
- 课堂测试代码(未完全实现,部分代码有bug,仅供参考)
热门文章
- Oracle及SQLPLUS使用笔记
- 关于jQuery获取不到动态添加的元素节点的问题
- 初探 -2 JavaScript
- Xcode中常用的快捷键(原文链接http://www.cocoachina.com/ios/20141224/10752.html)
- CodeReview的一些原则
- AIX中逻辑卷管理
- Metafile::EmfToWmfBits的使用
- H5手机端开发问题汇总及解决方案
- 一、redis安装、配置、命令
- Codeforces Round #568 (Div. 2) D. Extra Element