import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt #Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
# Import MNIST data (one_hot=False; only the images are fed to the model below).
mnist = input_data.read_data_sets("/niu/mnist_data/", one_hot=False)

# Training parameters
learning_rate = 0.01
training_epochs = 10
batch_size = 256
display_step = 1        # report the cost every `display_step` epochs
examples_to_show = 10   # number of test images reconstructed and plotted at the end

# Network parameters
n_input = 784     # flattened 28x28 image (reshaped back to 28x28 for plotting)
n_hidden_1 = 256  # width of the first hidden layer
n_hidden_2 = 128  # width of the bottleneck (code) layer

# tf Graph input (only pictures)
X = tf.placeholder("float", [None, n_input])

# Layer parameters, initialized with tf.random_normal's defaults.
# The decoder mirrors the encoder:
# n_input -> n_hidden_1 -> n_hidden_2 -> n_hidden_1 -> n_input.
# Variables are created in the same order as the entries are listed.
weights = {
    key: tf.Variable(tf.random_normal(shape))
    for key, shape in (
        ('encoder_h1', [n_input, n_hidden_1]),
        ('encoder_h2', [n_hidden_1, n_hidden_2]),
        ('decoder_h1', [n_hidden_2, n_hidden_1]),
        ('decoder_h2', [n_hidden_1, n_input]),
    )
}
biases = {
    key: tf.Variable(tf.random_normal([width]))
    for key, width in (
        ('encoder_b1', n_hidden_1),
        ('encoder_b2', n_hidden_2),
        ('decoder_b1', n_hidden_1),
        ('decoder_b2', n_input),
    )
}
def encoder(x):
    """Encode a batch of flattened images into the bottleneck representation.

    Applies two fully connected layers, each followed by a sigmoid, using the
    module-level ``weights['encoder_h1']``/``weights['encoder_h2']`` and
    ``biases['encoder_b1']``/``biases['encoder_b2']`` variables.

    Args:
        x: 2-D float tensor whose second dimension is ``n_input`` (required
            by the matmul with ``weights['encoder_h1']``).

    Returns:
        Tensor of shape ``(batch, n_hidden_2)`` with sigmoid outputs in (0, 1).
    """
    # Encoder hidden layer #1 with sigmoid activation: n_input -> n_hidden_1
    layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['encoder_h1']),
                                   biases['encoder_b1']))
    # Encoder hidden layer #2 with sigmoid activation: n_hidden_1 -> n_hidden_2
    layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['encoder_h2']),
                                   biases['encoder_b2']))
    return layer_2
def decoder(x):
    """Reconstruct flattened images from their bottleneck representation.

    Applies two fully connected layers, each followed by a sigmoid, using the
    module-level ``weights['decoder_h1']``/``weights['decoder_h2']`` and
    ``biases['decoder_b1']``/``biases['decoder_b2']`` variables.

    Args:
        x: 2-D float tensor whose second dimension is ``n_hidden_2``
            (required by the matmul with ``weights['decoder_h1']``).

    Returns:
        Tensor of shape ``(batch, n_input)`` with sigmoid outputs in (0, 1).
    """
    # Decoder hidden layer #1 with sigmoid activation: n_hidden_2 -> n_hidden_1
    layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['decoder_h1']),
                                   biases['decoder_b1']))
    # Decoder output layer #2 with sigmoid activation: n_hidden_1 -> n_input
    layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['decoder_h2']),
                                   biases['decoder_b2']))
    return layer_2
# Construct model
encoder_op = encoder(X)           # 128 features (n_hidden_2)
decoder_op = decoder(encoder_op)  # 784 features (n_input)

# Prediction
y_pred = decoder_op
# Targets (labels) are the input data itself.
y_true = X

# Define loss and optimizer: minimize the mean squared reconstruction error.
cost = tf.reduce_mean(tf.pow(y_true - y_pred, 2))
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)

# Launch the graph
with tf.Session() as sess:
    # tf.initialize_all_variables() is deprecated; this is its replacement.
    sess.run(tf.global_variables_initializer())
    total_batch = mnist.train.num_examples // batch_size
    # Training cycle
    for epoch in range(training_epochs):
        # Loop over all batches; the labels are not needed for reconstruction.
        for _ in range(total_batch):
            batch_xs, _labels = mnist.train.next_batch(batch_size)
            # Run optimization op (backprop) and cost op (to get loss value)
            _, c = sess.run([optimizer, cost], feed_dict={X: batch_xs})
        # Display logs per epoch step (c is the cost of the epoch's last batch)
        if epoch % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1),
                  "cost=", "{:.9f}".format(c))
    print("Optimization Finished!")

    # Apply encode and decode over the first few test images
    encode_decode = sess.run(
        y_pred, feed_dict={X: mnist.test.images[:examples_to_show]})

# Compare original images (top row) with their reconstructions (bottom row).
# Size the grid from examples_to_show so the loop below can never index past it;
# squeeze=False keeps `a` 2-D even when only one example is shown.
f, a = plt.subplots(2, examples_to_show, figsize=(examples_to_show, 2),
                    squeeze=False)
# Title the whole figure rather than only the last subplot.
f.suptitle('Matplotlib,AE--Jason Niu')
for i in range(examples_to_show):
    a[0][i].imshow(np.reshape(mnist.test.images[i], (28, 28)))
    a[1][i].imshow(np.reshape(encode_decode[i], (28, 28)))
plt.show()

最新文章

  1. 在windows系统下,在终端快速打开某个路径
  2. pip install 报错原因
  3. jquery制作弹出层带遮罩效果,点击阴影部分层消失
  4. LeetCode:Jump Game I II
  5. 【转】Unity3D研究院之通过C#使用Advanced CSharp Messenger(五十)
  6. 【BZOJ】【1874】取石子游戏
  7. string和stringBuilder的区别
  8. CSS学习------之简单图片切换
  9. NET基础课--对象的筛选和排序(NET之美)
  10. Android中SharedPreferences函数具体解释
  11. [转]关于SQL分页存储过程的分析
  12. select 通过表单提交获取select中的值
  13. jQuery控制input不可编辑
  14. 大牛教你用3行HTML代码卡死一台机器
  15. 数据挖掘进阶之序列模式分析算法GSP的实现
  16. AD证书导入文档(单向认证)
  17. pd16.5增加字段备注
  18. python中class的序列化和反序列化
  19. Java 清理和垃圾回收
  20. stm32f103串口实现映射功能

热门文章

  1. PHP项目笔记
  2. IntelliJ IDEA下写JUnit
  3. Java代码自动部署
  4. Confluence 6 安装 Oracle
  5. iOS ibeacon 使用详解
  6. VMware安装windows10系统
  7. 水果(map的嵌套)
  8. jmeter从CSV中获取非正常string
  9. .Net(C#)用正则表达式清除HTML标签(包括script和style),保留纯文本(UEdit中编写的内容上传到数据库)
  10. python导入import