import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784 # 输入节点
OUTPUT_NODE = 10 # 输出节点
LAYER1_NODE = 500 # 隐藏层数 BATCH_SIZE = 100 # 每次batch打包的样本个数 # 模型相关的参数
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99 TRAINING_STEPS = 5000
MOVING_AVERAGE_DECAY = 0.99 def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2):
# 不使用滑动平均类
if avg_class == None:
layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
return tf.matmul(layer1, weights2) + biases2
else:
# 使用滑动平均类
layer1 = tf.nn.relu(tf.matmul(input_tensor, avg_class.average(weights1)) + avg_class.average(biases1))
return tf.matmul(layer1, avg_class.average(weights2)) + avg_class.average(biases2) def train(mnist):
x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')
y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')
# 生成隐藏层的参数。
weights1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev=0.1))
biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))
# 生成输出层的参数。
weights2 = tf.Variable(tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev=0.1))
biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE])) # 计算不含滑动平均类的前向传播结果
y = inference(x, None, weights1, biases1, weights2, biases2) # 定义训练轮数及相关的滑动平均类
global_step = tf.Variable(0, trainable=False)
variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
variables_averages_op = variable_averages.apply(tf.trainable_variables())
average_y = inference(x, variable_averages, weights1, biases1, weights2, biases2) # 计算交叉熵及其平均值
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
cross_entropy_mean = tf.reduce_mean(cross_entropy) # 损失函数的计算
loss = cross_entropy_mean # 设置指数衰减的学习率。
learning_rate = tf.train.exponential_decay(
LEARNING_RATE_BASE,
global_step,
mnist.train.num_examples / BATCH_SIZE,
LEARNING_RATE_DECAY,
staircase=True) # 优化损失函数
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step) # 反向传播更新参数和更新每一个参数的滑动平均值
with tf.control_dependencies([train_step, variables_averages_op]):
train_op = tf.no_op(name='train') # 计算正确率
correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # 初始化会话,并开始训练过程。
with tf.Session() as sess:
tf.global_variables_initializer().run()
validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
test_feed = {x: mnist.test.images, y_: mnist.test.labels} # 循环的训练神经网络。
for i in range(TRAINING_STEPS):
if i % 1000 == 0:
validate_acc = sess.run(accuracy, feed_dict=validate_feed)
print("After %d training step(s), validation accuracy using average model is %g " % (i, validate_acc))
xs,ys=mnist.train.next_batch(BATCH_SIZE)
sess.run(train_op,feed_dict={x:xs,y_:ys})
test_acc=sess.run(accuracy,feed_dict=test_feed)
print(("After %d training step(s), test accuracy using average model is %g" %(TRAINING_STEPS, test_acc))) def main(argv=None):
mnist = input_data.read_data_sets("E:\\MNIST_data\\", one_hot=True)
train(mnist) if __name__=='__main__':
main()

最新文章

  1. jquery的animate({})动画整理
  2. 委托 lambda表达式浅显理解
  3. 操作系统学习笔记(五)--CPU调度
  4. Unity3d插件汇总
  5. HRBUST 1326 循环找父节点神术
  6. 使用 /proc 文件系统来访问 linux操作系统 内核的内容 && 虚拟文件系统vfs及proc详解
  7. Codevs 2307[SDOI2009]HH的项链
  8. oracle静态与动态监听
  9. WinForm聊天室
  10. 狗血phonegap备忘录[3.3]
  11. 为什么总是要求使用position的时候父类是relative
  12. IOS编程学习笔记
  13. 利用JS做网页特效——大图轮播
  14. 002_Add Two Numbers
  15. keepalived weight正负值问题(实现主服务器nginx故障后迅速切换到备服务器)
  16. css 鼠标移上去会变大
  17. 《RESTful Web APIs中文版》
  18. 【Rewrite重定向】Nginx使用rewrite重新定向
  19. WPF制作表示透明区域的马赛克画刷
  20. e.getMessage() e.printStackTrace() 和e.printStackTrace() 小结

热门文章

  1. HTML5学习(1)简介
  2. HTML学习(16)颜色
  3. python项目虚拟环境搭建
  4. intellij手动创建Mybatis遇到java.io.IOException: Could not find resource mybatis.xml
  5. jquery--获取input checkbox是否被选中以及渲染checkbox选中状态
  6. 交换机的MAC地址?
  7. RPA_播放语音
  8. pgspider http fdw http 相关的几个配置参数
  9. 解决wps for linux缺失windows字体
  10. Java - JVM - jinfo