tensorflow学习笔记(3)前置数学知识

首先是神经元的模型

接下来是激励函数

神经网络的复杂度计算

层数:隐藏层+输出层

总参数=总的w+b

下图为2层

如下图

w为3*4+4个   b为4*2+2

接下来是损失函数

主流的有均分误差,交叉熵,以及自定义

这里贴上课程里面的代码

# -*- coding: utf-8 -*-
"""
Created on Sat May 26 18:42:08 2018 @author: Administrator
""" import tensorflow as tf
import numpy as np
BATCH_SIZE=8
seed=23455 #基于seed产生随机数
rdm=np.random.RandomState(seed)
#初始化特征值为32个样本*2个特征值
#初始化标签
X=rdm.rand(32,2)
Y_=[[x1+x2+(rdm.rand()/10.0-0.05)] for (x1,x2) in X] #定义输入,参数和输出和传播过程
x=tf.placeholder(tf.float32,shape=(None,2))
y_=tf.placeholder(tf.float32,shape=(None,1))
w1=tf.Variable(tf.random_normal([2,1],stddev=1,seed=1))
y=tf.matmul(x,w1) #定义损失函数以及反向传播方法
loss_mse=tf.reduce_mean(tf.square(y_-y))
train_step=tf.train.GradientDescentOptimizer(0.01).minimize(loss_mse) #会话训练
with tf.Session() as sess:
init_op=tf.global_variables_initializer()
sess.run(init_op)
STEPS=20000
for i in range(STEPS):
start=(i*BATCH_SIZE)%32
end=(i*BATCH_SIZE)%32+BATCH_SIZE
#每次训练抽取start到end的数据
sess.run(train_step,feed_dict={x:X[start:end],y_:Y_[start:end]})
#每500次打印一次参数
if i%500==0:
print("在%d次迭代后,参数为"%(i))
print(sess.run(w1))
#输出训练后的参数
print("\n")
print("FINAL w1 is:",sess.run(w1))

自定义损失函数

loss=tf.reduce_sum(tf.where(tf.greater(y,y_),COST(y-y_),PROFIT(y_-y)))

中间的where是判断y是否大于y_

如图

最新文章

  1. 带有runat="server" 的服务器控件通过 ClientID 获取Id
  2. LINQ系列:Linq to Object联接操作符
  3. JavaWeb开发环境准备之Linux篇
  4. LAMP-五分钟搭建个人论坛
  5. poj3308Paratroopers(dinic)
  6. JavaScript 数组方法总结
  7. SQLSERVER2000使用TSQL将数据导入ACCESS并压缩生成rar
  8. Maven可继承的POM 元素
  9. C语言字符转换ASCII码
  10. boost::thread之while(true)型线程终结方法
  11. SpringData JPA的学习笔记之环境搭建
  12. lua本学习笔记功能
  13. HTML学习总结(四)【canvas绘图、WebGL、SVG】
  14. Core Animation中的组动画
  15. OS模块的常用内置方法
  16. Git帮助之初始化项目设置向导
  17. HTML- 标签语法
  18. 安装mono和monoDevelop开发环境
  19. eclipse 如何安装freemaker ftl 插件
  20. day15 Python函数递归,轻易不要用递归,容易搞出来内存溢出

热门文章

  1. React Native从零到一搭建开发环境
  2. 【读书笔记】The Swift Programming Language (Swift 4.0.3)
  3. 『ACM C++』HDU杭电OJ | 1416 - Gizilch (DFS - 深度优先搜索入门)
  4. git 对文件大小写修改无反应 不敏感解决办法
  5. SQL优化例子
  6. python代理爬取存入csv文件
  7. 03以太网帧结构(链路层 IEEE802.3)
  8. 为什么我要放弃javaScript数据结构与算法(第八章)—— 树
  9. linux signal函数遇到的问题
  10. Java设计模式(18)——行为模式之迭代子模式(Iterator)