本文发布于 2020-12-27,很可能已经过时

fashion_mnist 计算准确率、召回率、F1值

1、定义

首先需要明确几个概念:

假设某次预测结果统计为下图:

那么各个指标的计算方法为:

  • A类的准确率:TP1/(TP1+FP5+FP9+FP13+FP17) 即预测为A的结果中,真正为A的比例
  • A类的召回率:TP1/(TP1+FP1+FP2+FP3+FP4) 即实际上所有为A的样例中,能预测出来多少个A(的比例)
  • A类的F1值:(准确率*召回率*2)/(准确率+召回率)

实际上我们在训练出某个模型后,会将测试集中每个测试样例进行一次结果预测,因此只需统计这些结果,经过计算即可得到各类数据的准确率、召回率、F1值

2、使用fashion_mnist

需要提前pip安装tensorflow、prettytable、numpy

from tensorflow import keras
import numpy as np
import prettytable # 下载数据集
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data() # 制作标签名称
class_names = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Boot']
# 图片数据归一化
train_images = train_images / 255.0
test_images = test_images / 255.0 # 构建3层DNN模型,使用激活函数softmax
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28, 28)),
keras.layers.Dense(128, activation='relu'),
keras.layers.Dense(10, activation='softmax')
])
# 定义模型的损失函数,优化器与评估指标
model.compile(
optimizer=keras.optimizers.Adam(learning_rate=0.001),
loss=keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy']
)
# 训练模型
model.fit(train_images, train_labels, epochs=5)
# 评估模型
predictions = model.predict(test_images)
train_result = np.zeros((10, 10), dtype=int)
for i in range(10000):
train_result[test_labels[i]][np.argmax(predictions[i])] += 1 result_table = prettytable.PrettyTable()
result_table.field_names = ['Type', 'Accu', 'Recall', 'F1']
for i in range(10):
ac = train_result[i][i] / sum(train_result.T[i])
rc = train_result[i][i] / sum(train_result[i])
result_table.add_row([class_names[i], round(ac, 3), round(rc, 3), round(ac * rc * 2 / (ac + rc), 3)]) print(result_table)

实际效果:

最新文章

  1. 利用节点更改table内容
  2. HTML5 移动浏览器支持
  3. Struts2中过滤器和拦截器的区别
  4. JQuery:JQuery删除元素
  5. HDU2073(暴力) VS HDU5214(贪心)
  6. HW6.10
  7. Wireshark 使用教程
  8. GLFW库文件配置
  9. 嵌入Python | 调用Python模块中有参数的函数
  10. unity打包exe中的资源管理
  11. python深浅拷贝与赋值
  12. nowcoder练习赛28
  13. [LeetCode] 312. Burst Balloons_hard tag: 区间Dynamic Programming
  14. java修改AD域用户密码使用SSL连接方式
  15. 通过UNIX域套接字传递描述符和 sendmsg/recvmsg 函数
  16. 重新来认识你的老朋友Spring框架
  17. iOS- 指压即达,如何集成iOS9里的3D Touch
  18. MongoDB GridFS规范
  19. Python3之itertools模块
  20. sed----Linux下文本处理五大神器之一

热门文章

  1. Springcloud-微服务
  2. 替小白整理的 linux基操命令 切勿扣6 不用感谢
  3. 2022寒假集训day2
  4. Spring5源码解析系列一——IoC容器核心类图
  5. 从MVC到DDD的架构演进
  6. Solution -「多校联训」Sample
  7. Solution -「洛谷 P3911」最小公倍数之和
  8. CentOS7防火墙firewall
  9. curl常用参数详解及示例
  10. 忘掉cmd.exe吧!选用优雅的控制台终端(ConsoleZ)