import tensorflow as tf
from sklearn.datasets import load_digits
from sklearn.cross_validation import train_test_split
from sklearn.preprocessing import LabelBinarizer #load data
digits = load_digits()
X = digits.data
y = digits.target
y = LabelBinarizer().fit_transform(y)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.3) #
# add layer
#
def add_layer(inputs, in_size, out_size, n_layer, activation_function = None):
layer_name = 'layer%s' % n_layer Weights = tf.Variable(tf.random_normal([in_size, out_size]), name='W') # hang lie
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1, name = 'b') Wx_plus_b = tf.matmul(inputs, Weights) + biases
Wx_plus_b = tf.nn.dropout(Wx_plus_b, keep_prob) # if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b) tf.summary.histogram(layer_name + '/outputs', outputs)
return outputs #
# define placeholder for inputs to network
#
keep_prob = tf.placeholder(tf.float32) #
xs = tf.placeholder(tf.float32, [None, 64]) # 8x8
ys = tf.placeholder(tf.float32, [None, 10]) #
# add output layer
#
l1 = add_layer(xs, 64, 50, 'l1', activation_function = tf.nn.tanh)
prediction = add_layer(l1, 50, 10, 'l2', activation_function = tf.nn.softmax) #
# the error between prediction and real data
#
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),
reduction_indices=[1])) #loss
tf.summary.scalar('loss', cross_entropy)
train_step = tf.train.GradientDescentOptimizer(0.6).minimize(cross_entropy) sess = tf.Session()
merged = tf.summary.merge_all() #summary writer goes here
train_writer = tf.summary.FileWriter("logs/train", sess.graph)
test_writer = tf.summary.FileWriter("logs/test", sess.graph) sess.run(tf.global_variables_initializer()) for i in range(500):
#sess.run(train_step, feed_dict={xs:X_train, ys:y_train, keep_prob:1.0}) # overfitted
sess.run(train_step, feed_dict={xs:X_train, ys:y_train, keep_prob:0.5}) # keep 0.5, drop 0.5
if i% 50 == 0:
#record loss
train_result = sess.run(merged, feed_dict={xs:X_train, ys:y_train, keep_prob:1})
test_result = sess.run(merged, feed_dict={xs:X_test, ys:y_test, keep_prob:1})
train_writer.add_summary(train_result, i)
test_writer.add_summary(test_result, i)

  

最新文章

  1. asp.net identity 2.2.0 中角色启用和基本使用(五)
  2. Spring No mapping found for HTTP request with URI错误
  3. 连接Oracle错误:800a0e7a未找到提供程序的解决
  4. 深入浅出 nginx lua 为什么高性能
  5. Mysql 5.6.17-win64.zip配置
  6. 自定义的dialog
  7. 【Properties文件】Java使用Properties来读取配置文件
  8. Leetcode048. Rotate Image
  9. FreeBSD 路由详解
  10. poj 3792 Area of Polycubes
  11. poj1651 最优矩阵乘法动态规划解题
  12. 【VBA研究】查找目录以下全部文件的名称
  13. ASP.NET Core MVC压缩样式、脚本及总是复制文件到输出目录
  14. 使用Windows Server 2012+ 搭建VPN 简单 高效 稳定
  15. java中a=a+1和a+=1的区别
  16. Scratch 2.0-Find The Mouse 发布!
  17. Linux上安装Oracle的辛酸史
  18. [hadoop] hadoop 运行 wordcount
  19. 求[1,n]中与m互素的个数
  20. Swift5 语言指南(二十四) 泛型

热门文章

  1. ICCV2019《KPConv: Flexible and Deformable Convolution for Point Clouds》
  2. 提高python运行效率-numba
  3. [PHP] 破Laravel白屏问题
  4. Vue 事件的基本使用与语法差异
  5. H5视频、音频不能自动播放,Uncaught (in promise) DOMException: play() failed because the user didn't
  6. Paper | No-reference Quality Assessment of Deblocked Images
  7. Unity C# CSV文件解析与加载(已更新移动端处理方式)
  8. Gin框架 - 使用 Logrus 进行日志记录
  9. Java & PHP RSA 互通密钥、签名、验签、加密、解密
  10. Spring Cloud Eureka 服务注册中心(二)