这一节使用TF搭建一个简单的神经网络用于回归预测,首先随机生成一组数据

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
tf.set_random_seed(42)
np.random.seed(42)
x = np.linspace(-1,1,100)[:,np.newaxis] #<==>x=x.reshape(100,1)
noise = np.random.normal(0,0.1,size = x.shape)
y=np.power(x,2) + x +noise #y=x^2 + x+噪音
plt.scatter(x,y)
plt.show()

随机生成了一组数据,模型为\(y=x^2+x\),看一下数据的分布

接下来搭建一个含有一个隐藏层的神经网络,损失选择使用均方差误差

#模型部分
tf_X = tf.placeholder(tf.float32,x.shape) #=>X
tf_y = tf.placeholder(tf.float32,y.shape) #=>y output = tf.layers.dense(tf_X,10,tf.nn.relu,name="hidden")#隐藏层10个节点
output = tf.layers.dense(output,1,name='output') #1个输出层
#loss = tf.losses.mean_squared_error(tf_y,output)
loss = tf.reduce_mean(tf.sqrt(tf.pow(tf_y-output,2)))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.2)
train_op = optimizer.minimize(loss)

其中tf.losses中提供了常用的损失函数实现,也可以自己去实现,开始训练模型

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
plt.ion()
for step in range(100):
_,err,pred = sess.run([train_op,loss,output],feed_dict={tf_X:x,tf_y:y})
#cla() # Clear axis
#clf() # Clear figure
#close() # Close a figure window
plt.cla()#
plt.scatter(x,y)
plt.plot(x,pred,'r-',lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % err, fontdict={'size': 20, 'color': 'red'})
#plt.show()
plt.ioff()
plt.show()

看一看效果:

note:上面使用了plt.cla方法,这是由于方便看到变化过程,将plot过程写入到了for循环中,为了避免发生意外错误将对象从内存中清空。

最新文章

  1. (AS3)关于arguments
  2. (转)SVN 服务端、客户端安装及配置、导入导出项目
  3. IDT hook KiTrap03
  4. phpstorm8 设置及license key
  5. RecordWriter接口解析
  6. 3.12php
  7. DC-DC芯片 同步和異步方式有什么區別
  8. POI读写Excel-操作包含合并单元格操作
  9. Redis集群教程(Redis cluster tutorial)
  10. Android版数据结构与算法(八):二叉排序树
  11. python之路-数据运算
  12. 线程变量---ThreadLocal类
  13. k8s中的api server的ca证书,可以和front proxy ca证书一样么?
  14. Linux vfpd锁定用户目录
  15. js如何将选中图片文件转换成Base64字符串?
  16. html 简单的table样式
  17. hadoop ha环境下的datanode启动报错java.lang.NumberFormatException: For input string: &quot;10m&quot;
  18. javascript 计算文件MD5 浏览器 javascript读取文件内容
  19. javascript 判断数据类型
  20. centos6.5 宽带连接

热门文章

  1. Storm日志分析调研及其实时架构
  2. TCP协议的三次握手和四次分手
  3. Python基础学习参考(七):字典和集合
  4. R+NLP︱text2vec包——BOW词袋模型做监督式情感标注案例(二,情感标注)
  5. Alibaba阿里巴巴开源软件列表
  6. APACHE服务器出现No input file specified.的完美解决方案
  7. 安装Android的SDK
  8. [mysql] 2进制安装和简单优化
  9. ASP.NET 初识Cookie
  10. Python基础__字典、集合、运算符