import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #this is data
mnist = input_data.read_data_sets("MNIST_data",one_hot=True) lr = 0.001
train_iters = 10000
batch_size = 128
display_step = 10 n_inputs = 28
n_steps = 28
n_hidden_unis = 128
n_classes = 10 x = tf.placeholder(tf.float32,[None,n_steps,n_inputs])
y = tf.placeholder(tf.float32,[None,n_classes]) #define weight
weights = {
#(28,128)
"in":tf.Variable(tf.random_normal([n_inputs,n_hidden_unis])),
#(128,10)
"out":tf.Variable(tf.random_normal([n_hidden_unis,n_classes]))
}
biases = {
#(128,)
"in":tf.Variable(tf.constant(0.1,shape=[n_hidden_unis,])),
#(10,)
"out":tf.Variable(tf.constant(0.1,shape=[n_classes,]))
} def RNN(X,weights,biases):
#形状变换成lstm可以训练的维度
X = tf.reshape(X,[-1,n_inputs]) #(128*28,28)
X_in = tf.matmul(X,weights["in"])+biases["in"] #(128*28,128)
X_in = tf.reshape(X_in,[-1,n_steps,n_hidden_unis]) #(128,28,128) #cell
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden_unis,forget_bias=1.0,state_is_tuple=True)
#lstm cell is divided into two parts(c_state,m_state)
_init_state = lstm_cell.zero_state(batch_size,dtype=tf.float32) outputs,states = tf.nn.dynamic_rnn(lstm_cell,X_in,initial_state=_init_state,time_major = False) #outputs
# results = tf.matmul(states[1],weights["out"])+biases["out"]
#or
outputs = tf.transpose(outputs,[1,0,2])
results = tf.matmul(outputs[-1],weights["out"])+biases["out"] return results pred = RNN(x,weights,biases)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
train_op = tf.train.AdamOptimizer(lr).minimize(loss) correct_pred = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred,tf.float32)) init = tf.initialize_all_variables() with tf.Session() as sess:
sess.run(init)
step = 0
while step*batch_size < train_iters:
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
batch_xs = batch_xs.reshape([batch_size,n_steps,n_inputs])
sess.run(train_op,feed_dict={x:batch_xs,y:batch_ys})
if step%20 ==0:
print(sess.run(accuracy,feed_dict={x:batch_xs,y:batch_ys}))

  

最新文章

  1. NSSortDescriptor 的使用
  2. struts2+spring的两种整合方式
  3. 关于是用dotnet获取本机IP地址+计算机名的方法
  4. [C#]循环输出 000 - 999999
  5. 使用Spring的命名空间p装配属性-摘自《Spring实战(第3版)》
  6. php之常用函数库
  7. Android 内核基本知识
  8. CriticalFinalizerObject的作用
  9. List用法
  10. SlopOne推荐算法
  11. web Deploy发布问题
  12. 闵可夫斯基和(Mincowsky sum)
  13. ThinkPHP 5隐藏public/index.php方法
  14. js对象,字符串 互相 转换
  15. HDU 3594 Cactus (强连通+仙人掌图)
  16. HDU3584 Cube
  17. FastAdmin 的 url 有一个 ref=addtabs 是怎么添加的?
  18. Effective STL 笔记: Item 6--Be alert for C++&#39;s most vexing parse
  19. 如何删除EF4.0以上的版本
  20. 使用 Azure CLI 将 IaaS 资源从经典部署模型迁移到 Azure Resource Manager 部署模型

热门文章

  1. 知识图谱里的知识表示:RDF
  2. 分享一款一直在维护的【网络开发运维|通用调试工具】: http请求, websocket,cmd, RSA,DES, 参数签名工具,脚本批量生成工具,google动态口令,端口检测,组件注册,js混淆...
  3. WPF使用 Gmap.NET 绘制极坐标运动轨迹
  4. webpack4.x 从零开始配置vue 项目(三)
  5. SaaS架构(一) 弱后端强前端的尝试和问题
  6. awk扩展应用
  7. Vulnhub DC-7靶机渗透
  8. docker 相关操作
  9. 计算机网络协议,PPP协议分析
  10. C语言输出 1到20 的阶乘之和