目录

Fashion MNIST数据库

分类模型的建立

模型预测

总体代码


主要介绍基于tf.keras的Fashion MNIST数据库分类,

官方文档地址为:https://tensorflow.google.cn/tutorials/keras/basic_classification

文本分类类似,官网文档地址为https://tensorflow.google.cn/tutorials/keras/basic_text_classification

首先是函数的调用,对于tensorflow只有在版本1.2以上的版本才有tf.keras库。另外推荐使用python3,而不是python2。

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras # 其他库
import numpy as np
import matplotlib.pyplot as plt
#查看版本
print(tf.__version__)
#1.9.0

Fashion MNIST数据库

fashion mnist数据库是mnist数据库的一个拓展。目的是取代mnist数据库,类似MINST数据库,fashion mnist数据库为训练集60000张,测试集10000张的28X28大小的服装彩色图片。具体分类如下:

标注编号 描述
0 T-shirt/top(T恤)
1 Trouser(裤子)
2 Pullover(套衫)
3 Dress(裙子)
4 Coat(外套)
5 Sandal(凉鞋)
6 Shirt(汗衫)
7 Sneaker(运动鞋)
8 Bag(包)
9 Ankle boot(踝靴)

样本描述如下:

名称 描述 样本数量 文件大小 链接
train-images-idx3-ubyte.gz 训练集的图像 60,000 26 MBytes 下载
train-labels-idx1-ubyte.gz 训练集的类别标签 60,000 29 KBytes 下载
t10k-images-idx3-ubyte.gz 测试集的图像 10,000 4.3 MBytes 下载
t10k-labels-idx1-ubyte.gz 测试集的类别标签 10,000 5.1 KBytes 下载

单张图像展示代码:

#分类标签
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
#单张图像展示,推荐使用python3
plt.figure()
plt.imshow(train_images[0])
#添加颜色渐变条
plt.colorbar()
#不显示网格线
plt.gca().grid(False)

效果图:

样本的展示代码:

#图像预处理
train_images = train_images / 255.0
test_images = test_images / 255.0 #样本展示
plt.figure(figsize=(10,10))
for i in range(25):
plt.subplot(5,5,i+1)
plt.xticks([])
plt.yticks([])
plt.grid('off')
plt.imshow(train_images[i], cmap=plt.cm.binary)
plt.xlabel(class_names[train_labels[i]])

效果图:

分类模型的建立

检测模型输入数据为28X28,1个隐藏层节点数为128,输出类别10类,代码如下:

#检测模型
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28, 28)),
keras.layers.Dense(128, activation=tf.nn.relu),
keras.layers.Dense(10, activation=tf.nn.softmax)
])

模型训练参数设置:

model.compile(optimizer=tf.train.AdamOptimizer(),
loss='sparse_categorical_crossentropy', #多分类的对数损失函数
metrics=['accuracy']) #准确度

模型的训练:

model.fit(train_images, train_labels, epochs=5)

模型预测

预测函数:

predictions = model.predict(test_images)

分类器是softmax分类器,输出的结果一个predictions是一个长度为10的数组,数组中每一个数字的值表示其所对应分类的概率值。如下所示:

predictions[0]
array([2.1840347e-07, 1.9169457e-09, 4.5915922e-08, 5.3185740e-08,
6.6372898e-08, 2.6090498e-04, 6.5197796e-06, 4.7861701e-03,
2.9425648e-06, 9.9494308e-01], dtype=float32)

对于predictions[0]其中第10个值最大,则该值对应的分类为class[9]ankle boot。

np.argmax(predictions[0]) #9
test_labels[0] #9

前25张图的分类效果展示:

#前25张图分类效果
plt.figure(figsize=(10,10))
for i in range(25):
plt.subplot(5,5,i+1)
plt.xticks([])
plt.yticks([])
plt.grid('off')
plt.imshow(test_images[i], cmap=plt.cm.binary)
predicted_label = np.argmax(predictions[i])
true_label = test_labels[i]
if predicted_label == true_label:
color = 'green'
else:
color = 'red'
plt.xlabel("{} ({})".format(class_names[predicted_label],
class_names[true_label]),
color=color)

效果图,绿色标签表示分类正确,红色标签表示分类错误:

对于单个图像的预测,需要将图像28X28的输入转换为1X28X28的输入,转换函数为np.expand_dims。函数使用如下:https://www.zhihu.com/question/265545749

#格式转换
img = (np.expand_dims(img,0))
print(img.shape) #1X28X28 predictions = model.predict(img)
prediction = predictions[0]
np.argmax(prediction) #9

总体代码

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras # 其他库
import numpy as np
import matplotlib.pyplot as plt
#查看版本
print(tf.__version__)
#1.9.0 fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data() #分类标签
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
#单张图像展示,推荐使用python3
plt.figure()
plt.imshow(train_images[0])
#添加颜色渐变条
plt.colorbar()
#不显示网格线
plt.gca().grid(False) #图像预处理
train_images = train_images / 255.0
test_images = test_images / 255.0 #样本展示
plt.figure(figsize=(10,10))
for i in range(25):
plt.subplot(5,5,i+1)
plt.xticks([])
plt.yticks([])
plt.grid('off')
plt.imshow(train_images[i], cmap=plt.cm.binary)
plt.xlabel(class_names[train_labels[i]]) #检测模型
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28, 28)),
keras.layers.Dense(128, activation=tf.nn.relu),
keras.layers.Dense(10, activation=tf.nn.softmax)
]) model.compile(optimizer=tf.train.AdamOptimizer(),
loss='sparse_categorical_crossentropy', #多分类的对数损失函数
metrics=['accuracy']) #准确度 model.fit(train_images, train_labels, epochs=5) predictions = model.predict(test_images) #前25张图分类效果
plt.figure(figsize=(10,10))
for i in range(25):
plt.subplot(5,5,i+1)
plt.xticks([])
plt.yticks([])
plt.grid('off')
plt.imshow(test_images[i], cmap=plt.cm.binary)
predicted_label = np.argmax(predictions[i])
true_label = test_labels[i]
if predicted_label == true_label:
color = 'green'
else:
color = 'red'
plt.xlabel("{} ({})".format(class_names[predicted_label],
class_names[true_label]),
color=color) #单个图像检测
img = test_images[0]
print(img.shape) #28X28 #格式转换
img = (np.expand_dims(img,0))
print(img.shape) #1X28X28 predictions = model.predict(img)
prediction = predictions[0]
np.argmax(prediction) #9

最新文章

  1. webrtc进阶-信令篇-之三:信令、stun、turn、ice
  2. 日志框架只打印出Mybatis SQL的配置
  3. JS基础回顾,小练习(去除字符串空格)
  4. 列出本机JCE提供者,支持消息摘要算法,支持公钥私钥算法
  5. ABBYY可以给我们解决那些问题
  6. 常用JS
  7. Awesomplete 屌爆了
  8. ThinkPHP 发送post请求
  9. ASP.NET MVC 用户登录Login
  10. 开启MongoDB客户端访问控制
  11. Linux基础教程
  12. OC学习12——字符串、日期、日历
  13. linux端口详解大全
  14. 在 Ubuntu14.04 上搭建 Spark 2.3.1(latest version)
  15. Eclipse+Maven整合开发Java项目(一)➣Maven基础环境配置
  16. vue mand-mobile ui加class不起作用的问题 css权重问题
  17. hessian 在spring中的使用 (bean 如 Dao无法注入的问题)
  18. Eclipse配置Maven的一些问题
  19. [POJ2337]Catenyms
  20. kubectl get componentstatus ERROR:HTTP probe failed with statuscode: 503

热门文章

  1. mysql工具的使用、增删改查
  2. 详解商业智能“前世今生”,“嵌入式BI”到底是如何产生的?
  3. 齐博x1齐博首创钩子的使用方法
  4. 【MySQL】03_数据类型
  5. 折腾黑苹果-小新Pro13
  6. 题解 AT2361 [AGC012A] AtCoder Group Contest
  7. JS常见问题总结
  8. 【Bluetooth蓝牙开发】一、开篇词 | 打造全网最详细的Bluetooth开发教程
  9. Day16异常1
  10. nginx安装及相关操作