#coding=utf-8
import tensorflow as tf
import numpy as np
import matplotlib .pyplot as plt
from tensorflow .examples .tutorials .mnist import input_data #define dataset mnist=input_data .read_data_sets ("/home/nvidia/Downloads/",one_hot= True ) #defien agruments batch_zize=20
iter=np.int(mnist .train.images.shape[0]/batch_zize )
print(iter ) #define learning_rate LEARNING_RATE_STEP=100
LEARNING_RATE_BASE=0.001
LEARNING_RATE_DECAY=0.99
global_step=tf.Variable (0,trainable= False )
learning_rate=tf.train.exponential_decay (learning_rate= LEARNING_RATE_BASE ,global_step= global_step ,decay_steps= LEARNING_RATE_STEP
,decay_rate= LEARNING_RATE_DECAY ,staircase= True ) #define tool def Weight_V(shape):
weight=tf.truncated_normal (shape=shape,stddev= 0.1)
return tf.Variable (weight ) def bias_V(shape):
bia_=tf.constant (shape=shape,value= 0.1)
return tf.Variable (bia_ ) def conv2d_(x,w):
return tf.nn.conv2d (x,filter= w,padding= "SAME",strides= [1,1,1,1]) def max_pool(x):
return tf.nn.max_pool (x,ksize= [1,2,2,1],strides=[1,2,2,1],padding="SAME") #define net x_input=tf.placeholder (shape=[None,784],dtype= tf.float32)
y_input=tf.placeholder (shape= [None,10],dtype= tf.float32) x =tf.reshape(x_input ,shape= [-1,28,28,1]) #
w_conv1=Weight_V(shape= [5,5,1,32])
b_conv1=bias_V(shape= [32])
c_conv1=tf.nn.relu (conv2d_(x ,w_conv1 )+b_conv1 )
m_conv1=max_pool(c_conv1 )
#14*14*32 w_conv2=Weight_V(shape= [5,5,32,64])
b_conv2=bias_V(shape= [64])
c_conv2=tf.nn.relu (conv2d_(m_conv1 ,w_conv2 )+b_conv2 )
m_conv2=max_pool(c_conv2 )
#7*7*64 w_fc1=Weight_V([7*7*64,1024])
b_fc1=bias_V(shape= [1024])
c_fc1=tf.reshape(m_conv2 ,[-1,7*7*64])
fc1=tf.nn.relu(tf.matmul(c_fc1 ,w_fc1 )+b_fc1 ) w_fc2=Weight_V(shape= [1024,10])
b_fc2=bias_V(shape= [10])
prediction=tf.nn.softmax (tf.matmul(fc1,w_fc2 )+b_fc2 ) #define # correct_accurcy=tf.equal(tf.argmax(prediction,axis=1),tf.argmax(y_input,axis=1))
# accurcy=tf.reduce_mean(tf.cast(correct_accurcy,dtype=tf.float32)) correct_accurcy=tf.equal (tf.argmax (prediction ,axis= 1),tf.argmax (y_input ,axis= 1)) accurcy=tf.reduce_mean (tf.cast(correct_accurcy ,dtype= tf.float32)) #traing backward
#
crosss_entropy =-tf.reduce_mean (y_input *tf.log(prediction ))
train_step=tf.train.GradientDescentOptimizer (learning_rate).minimize(crosss_entropy,global_step= global_step ) #initial global argumnets init=tf.global_variables_initializer () #SESS with tf.Session() as sess:
sess.run(init)
for i in range(21):
X,Y=mnist .test.next_batch(100)
for j in range(iter ):
xt,yt=mnist .train.next_batch (batch_zize )
sess.run(train_step ,feed_dict= {x_input :xt,y_input :yt}) acc=sess.run(accurcy ,feed_dict= {x_input :X,y_input :Y})
print(acc)

最新文章

  1. jQuery实现checkbox反选(转载)
  2. ubuntu graphic cannot display
  3. 快速入门系列--WCF--07传输安全、授权与审核
  4. codeforces 360 C - NP-Hard Problem
  5. HDU 2089 数位dp/字符串处理 两种方法
  6. WPF之小动画二
  7. 基于Ubuntu 15.04 LTS编译Android5.1.0源代码 (转)
  8. UVA 1599 Ideal Path
  9. git解决冲突
  10. java中this关键字解析
  11. webpack code splitting
  12. springboot shiro 项目前端页面访问问题总结
  13. XSplit Quality, VBV-Buffer, VBV-Maxrate and Preset Settings
  14. 微信公众号自定义菜单中添加emoji表情
  15. Qt5中的lambda表达式和使用lambda来写connect
  16. LAMP平台部署
  17. CodeReview实践与总结
  18. HSTS 与 307 状态码
  19. Linux 系统的目录结构_【all】
  20. 洛谷 P1306 斐波那契公约数

热门文章

  1. ORACLE 删除重复的数据
  2. Vue二次精度随笔(2)
  3. 【STM32H7教程】第53章 STM32H7的LTDC应用之汉字小字库和全字库制作
  4. oracle数据库的完整性约束规则详解
  5. 前端学习笔记系列一:15vscode汉化、快速复制行、网页背景图有效设置、 dl~dt~dd标签使用
  6. Mac的VIM中delete键失效的原因和解决方案
  7. IEEE Spectrum 2014年十大编程语言盘点
  8. POJ 3273:Monthly Expense 二分好题啊啊啊啊啊啊
  9. 选择本地文件上传控件 input标签
  10. 【pwnable.tw】 starbound