使用Tensorflow中的神经网络来拟合函数(y = x ^ 3 + 0.7)

# -*- coding:utf-8 -*-
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt #训练数据
x_data = np.linspace(-6.0,6.0,30)[:,np.newaxis]
y_data = np.power(x_data,3) + 0.7
#验证数据
t_data = np.linspace(-20.0,20.0,40)[:,np.newaxis]
ty_data = np.power(t_data,3) + 0.7
#占位符
x = tf.placeholder(tf.float32,[None,1])
y = tf.placeholder(tf.float32,[None,1]) #network
#--layer one--
l_w_1 = tf.Variable(tf.random_normal([1,10]))
l_b_1 = tf.Variable(tf.zeros([1,10]))
l_fcn_1 = tf.matmul(x, l_w_1) + l_b_1
relu_1 = tf.nn.relu(l_fcn_1)
#---layer two----
l_w_2 = tf.Variable(tf.random_normal([10,20]))
l_b_2 = tf.Variable(tf.zeros([1,20]))
l_fcn_2 = tf.matmul(relu_1, l_w_2) + l_b_2
relu_2 = tf.nn.relu(l_fcn_2) #---output---
l_w_3 = tf.Variable(tf.random_normal([20,1]))
l_b_3 = tf.Variable(tf.zeros([1,1]))
l_fcn_3 = tf.matmul(relu_2, l_w_3) + l_b_3
#relu_3 = tf.tanh(l_fcn_3)
# init
init = tf.global_variables_initializer()
#定义 loss func
loss = tf.reduce_mean(tf.square(y-l_fcn_3))
learn_rate =0.001
train_step = tf.train.GradientDescentOptimizer(learn_rate).minimize(loss) with tf.Session() as sess:
sess.run(init);
for epoch in range(20):
for step in range(5000):
sess.run(train_step,feed_dict={x:x_data,y:y_data})
y_pred = sess.run(l_fcn_3,feed_dict={x:t_data})
print sess.run(l_fcn_3,feed_dict={x:[[10.]]})
plt.figure()
plt.scatter(t_data,ty_data)
plt.plot(t_data,y_pred,'r-')
plt.show()

实验结果

最新文章

  1. 【Alpha】Daily Scrum Meeting第三次
  2. 总结一些关于操作数据库是sql语句还是存储过程问题
  3. SQL范式小结
  4. LoRaWAN协议(一)--架构解析
  5. ImageButton如何让图片按比例缩放不被拉伸
  6. 基于DOM的XSS注入漏洞简单解析
  7. 【LA3523】 Knights of the Round Table (点双连通分量+染色问题?)
  8. Linux系统编程(30)—— socket编程之TCP/IP协议
  9. TEA加密
  10. Linq实现t-Sql的各种连接
  11. 第14天dbutils与案例
  12. 如何在windows系统下安装swoole
  13. QQ音乐的动效歌词是如何实践的?
  14. (转)[Python 网络编程] makefile (三)
  15. Redis自学笔记:3.1入门-热身
  16. MT【36】反函数有关的一道题
  17. MXNet官方文档中文版教程(3):神经网络图(Symbol)
  18. [HAOI 2010]订货
  19. IE8及以下的数组处理与其它浏览器的不同
  20. Xcode The operation couldn’t be completed. (NSURLErrorDomain error -1012.)

热门文章

  1. python学习之platform模块
  2. 微信中调起qq
  3. 安全DNS
  4. linux学习笔记30--网络命令ifconfig
  5. linux学习笔记7---命令cp
  6. jQuery 实战读书笔记之第四章:使用特性、属性和数据
  7. Unity3D避免代码被反编译
  8. tomcat 内存溢出原因分析及解决
  9. Windows的静态库使用步骤
  10. oracle 里 插入空字符串会被转成null插入