再来一个tensorflow的测试性能的代码
2024-09-08 10:37:21
感觉这个比前一套,容易理解些~~
关于数据提前下载的问题:
https://www.jianshu.com/p/5116046733fe
如果使用keras的cifar10.load_data()函数,你会发现,代码会自动去下载 cifar-10-python.tar.gz 文件
实际上,通过查看cifar10.py和site-packages/keras/utils/data_utils.py的get_file函数,你会发现,代码将将下载后的文件存放在 ~./keras/datasets目录下,但是!!!!文件名却被改成了 cifar-10-batches-py.tar.gz
惊不惊喜,意不意外?所以如果要避免下载,已经有数据集了,应该:
cp cifar-10-python.tar.gz ~./keras/datasets/cifar-10-batches-py.tar.gz
完美解决问题!
作者:不爱吃饭的小孩怎么办
链接:https://www.jianshu.com/p/5116046733fe
来源:简书
简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。
import timeit import tensorflow as tf import numpy as np from tensorflow.keras.datasets.cifar10 import load_data def model(): x = tf.placeholder(tf.float32, shape=[None, 32, 32, 3]) y = tf.placeholder(tf.float32, shape=[None, 10]) rate = tf.placeholder(tf.float32) # convolutional layer 1 conv_1 = tf.layers.conv2d(x, 32, [3, 3], padding='SAME', activation=tf.nn.relu) max_pool_1 = tf.layers.max_pooling2d(conv_1, [2, 2], strides=2, padding='SAME') drop_1 = tf.layers.dropout(max_pool_1, rate=rate) # convolutional layer 2 conv_2 = tf.layers.conv2d(drop_1, 64, [3, 3], padding="SAME", activation=tf.nn.relu) max_pool_2 = tf.layers.max_pooling2d(conv_2, [2, 2], strides=2, padding="SAME") drop_2 = tf.layers.dropout(max_pool_2, rate=rate) # convolutional layers 3 conv_3 = tf.layers.conv2d(drop_2, 128, [3, 3], padding="SAME", activation=tf.nn.relu) max_pool_3 = tf.layers.max_pooling2d(conv_3, [2, 2], strides=2, padding="SAME") drop_3 = tf.layers.dropout(max_pool_3, rate=rate) # fully connected layer 1 flat = tf.reshape(drop_3, shape=[-1, 4 * 4 * 128]) fc_1 = tf.layers.dense(flat, 80, activation=tf.nn.relu) drop_4 = tf.layers.dropout(fc_1 , rate=rate) # fully connected layer 2 or the output layers fc_2 = tf.layers.dense(drop_4, 10) output = tf.nn.relu(fc_2) # accuracy correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(output, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # loss loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=output, labels=y)) # optimizer optimizer = tf.train.AdamOptimizer(1e-4, beta1=0.9, beta2=0.999, epsilon=1e-8).minimize(loss) return x, y, rate, accuracy, loss, optimizer def one_hot_encoder(y): ret = np.zeros(len(y) * 10) ret = ret.reshape([-1, 10]) for i in range(len(y)): ret[i][y[i]] = 1 return (ret) def train(x_train, y_train, sess, x, y, rate, optimizer, accuracy, loss): batch_size = 128 y_train_cls = one_hot_encoder(y_train) start = end = 0 for i in range(int(len(x_train) / batch_size)): if (i + 1) % 100 == 1: start = timeit.default_timer() batch_x = x_train[i * batch_size:(i + 1) * batch_size] batch_y = y_train_cls[i * batch_size:(i + 1) * batch_size] _, batch_loss, batch_accuracy = sess.run([optimizer, loss, accuracy], feed_dict={x:batch_x, y:batch_y, rate:0.4}) if (i + 1) % 100 == 0: end = timeit.default_timer() print("Time:", end-start, "s the loss is ", batch_loss, " and the accuracy is ", batch_accuracy * 100, "%") def test(x_test, y_test, sess, x, y, rate, accuracy, loss): batch_size = 64 y_test_cls = one_hot_encoder(y_test) global_loss = 0 global_accuracy = 0 for t in range(int(len(x_test) / batch_size)): batch_x = x_test[t * batch_size : (t + 1) * batch_size] batch_y = y_test_cls[t * batch_size : (t + 1) * batch_size] batch_loss, batch_accuracy = sess.run([loss, accuracy], feed_dict={x:batch_x, y:batch_y, rate:1}) global_loss += batch_loss global_accuracy += batch_accuracy global_loss = global_loss / (len(x_test) / batch_size) global_accuracy = global_accuracy / (len(x_test) / batch_size) print("In Test Time, loss is ", global_loss, ' and the accuracy is ', global_accuracy) EPOCH = 100 (x_train, y_train), (x_test, y_test) = load_data() print("There is ", len(x_train), " training images and ", len(x_test), " images") x, y, rate, accuracy, loss, optimizer = model() sess = tf.Session() sess.run(tf.global_variables_initializer()) for i in range(EPOCH): print("Train on epoch ", i ," start") train(x_train, y_train, sess, x, y, rate, optimizer, accuracy, loss) test(x_train, y_train, sess, x, y, rate, accuracy, loss)
最新文章
- 【Redis】简介与安装
- Javascript运动基础
- Asp.Net回车键触发Button的OnClick事件解决方案
- C+命令行+方向键=简易版扫雷
- struts 标签库注解
- ASP.NET MVC5 高级编程 第2章 控制器
- 【笨嘴拙舌WINDOWS】tagTEXTMETRIC结构
- win32 sdk树形控件的项拖拽实现
- python之字符串
- 移动应用开发者最应该知道的8款SDK
- Java模拟http上传文件请求(HttpURLConnection,HttpClient4.4,RestTemplate)
- 自制操作系统Antz(10)——实现shell(上)
- 一个小时就能理解Java的NIO必须掌握这三大要素!
- qs.parse()、qs.stringify()使用方法
- RocketMQ最佳实践(一)4.0版本/概念介绍/安装调试/客户端demo
- django —— Celery实现异步和定时任务
- HTML5游戏开发系列教程9(译)
- Docker 快速入门教程
- 【BioCode】读文件夹以发现缺失文件
- 网络协议之ts---ts 流分析详解