不加Dropout,训练数据的准确率高,基本上可以接近100%,但是,对于测试集来说,效果并不好;

加上Dropout,训练数据的准确率可能变低,但是,对于测试集来说,效果更好了,所以说Dropout可以防止过拟合。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 载入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) # 每个批次的大小
batch_size = 100
# 计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size # 定义两个placeholder
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32) # 创建一个简单的神经网络
W1 = tf.Variable(tf.truncated_normal([784, 2000], stddev=0.1))
b1 = tf.Variable(tf.zeros([2000]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop = tf.nn.dropout(L1, keep_prob) W2 = tf.Variable(tf.truncated_normal([2000, 2000], stddev=0.1))
b2 = tf.Variable(tf.zeros([2000]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
L2_drop = tf.nn.dropout(L2, keep_prob) W3 = tf.Variable(tf.truncated_normal([2000, 1000], stddev=0.1))
b3 = tf.Variable(tf.zeros([1000]) + 0.1)
L3 = tf.nn.tanh(tf.matmul(L2_drop, W3) + b3)
L3_drop = tf.nn.dropout(L3, keep_prob) W4 = tf.Variable(tf.truncated_normal([1000, 10], stddev=0.1))
b4 = tf.Variable(tf.zeros([10]) + 0.1)
prediction = tf.nn.softmax(tf.matmul(L3_drop, W4) + b4) # 二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))
# 使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) # 初始化变量
init = tf.global_variables_initializer() # 结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1)) #argmax返回一维张量中最大的值所在的位置
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) with tf.Session() as sess:
sess.run(init)
for epoch in range(31):
for batch in range(n_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.7}) test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0})
print("Iter " + str(epoch) + ",Testing Accuracy " + str(test_acc) + ",Training Accuracy " + str(train_acc))

最新文章

  1. [LeetCode] Odd Even Linked List 奇偶链表
  2. DevExpress GridControl 选择整行被选单元格不变色的设置
  3. CC2540自己的配置文件
  4. 由单例模式学到:Lazy<T>
  5. Javascript中typeof instanceof constructor的区别
  6. exp命令ORACLCE10G导出ORACLE11G的数据1455错误
  7. Elsevier期刊网上投稿指南
  8. WCF分布式开发步步为赢(2)自定义托管宿主WCF解决方案开发配置过程详解
  9. java的动态代理机制
  10. select组件
  11. com.mysql.jdbc.exceptions.MySQLSyntaxErrorException错误
  12. 第二次作业--------STEAM
  13. android4.2添加重启菜单项
  14. Android中View的绘制流程(专题讲解)
  15. LeetCode手记-Add Binary
  16. WinThruster清理电脑注册表
  17. github协同开发
  18. bzoj千题计划179:bzoj1237: [SCOI2008]配对
  19. NOIP上机测试注意事项
  20. mysql源码

热门文章

  1. 设计模式之动态代理(Java的JDK动态代理实现)
  2. CSS多种方式实现底部对齐
  3. 如何将阿里云上的RDS 备份的mysql数据还原到windows环境中
  4. 阶段1 语言基础+高级_1-3-Java语言高级_06-File类与IO流_04 IO字节流_3_字节输出流_OutputStream类&FileOutputStream
  5. Caffe::Snapshot的运行过程
  6. shell 比较符号
  7. 【ABAP系列】SAP ABAP选择屏幕(SELECTION SCREEN)事件解析
  8. Creat-React-Native-App 之StackNavigator之踩坑记录
  9. python列表-定义
  10. CentOS7 修复boot目录