近来看batch normalization的代码时,遇到tf.train.ExponentialMovingAverage()函数,特此记录。

tf.train.ExponentialMovingAverage()函数实现滑动平均模型和计算变量的移动平均值。

TensorFlow官网上对于这个方法的介绍:

Some training algorithms, such as GradientDescent and Momentum often benefit from maintaining a moving average of variables during optimization. Using the moving averages for evaluations often improve results significantly.

一些训练算法,如梯度下降(GradientDescent)和动量(Momentum),经常受益于在优化过程中保持变量的移动平均。使用移动平均线进行评估通常会显著改善结果。

# 类,用于计算滑动平均
tf.train.ExponentialMovingAverage __init__(
decay,
num_updates=None,
zero_debias=False,
name='ExponentialMovingAverage')

decay是衰减率。在创建ExponentialMovingAverage对象时,需要指定衰减率(decay),用于控制模型的更新速度。影子变量的初始值与训练变量的初始值相同。当运行变量更新时,每个影子变量都会更新为:

  shadowvariable=decay∗shadowvariable+(1−decay)∗variable

num_updates是ExponentialMovingAverage提供用来动态设置decay的参数,当初始化时,函数提供了num_updates参数,即不为none时,每次的衰减率是:

apply()方法添加了训练变量的影子副本,并保持了其影子副本中训练变量的移动平均值操作。在每次训练之后调用此操作,更新移动平均值。average()和average_name()方法可以获取影子变量及其名称。

decay设置为接近1的值比较合理,通常为:0.999,0.9999等,decay越大模型越稳定,因为decay越大,参数更新的速度就越慢,趋于稳定。

官网中的示例:

# 创建variables.
var0 = tf.Variable(...)
var1 = tf.Variable(...)
# ... 使用variables去创建一个训练模型...
...
# 创建一个使用the optimizer对的op.
# 这是我们通常会使用作为一个training op.
opt_op = opt.minimize(my_loss, [var0, var1]) # 创建一个ExponentialMovingAverage object
ema = tf.train.ExponentialMovingAverage(decay=0.9999) # 创建the shadow variables,然后把ops加到maintain moving averages of var0 and var1.
maintain_averages_op = ema.apply([var0, var1]) # 创建一个op,在每次训练之后用来更新the moving averages.
# 用来代替the usual training op.
with tf.control_dependencies([opt_op]):
training_op = tf.group(maintain_averages_op)
# run这个op获取当前时刻 ema_value
get_var0_average_op = ema.average(var0)

例子:

import tensorflow as tf
import numpy as np v1 = tf.Variable(0, dtype=tf.float32)
step = tf.Variable(tf.constant(0)) ema = tf.train.ExponentialMovingAverage(0.99, step)
maintain_average = ema.apply([v1]) with tf.Session() as sess:
init = tf.initialize_all_variables()
sess.run(init)
print(sess.run([v1, ema.average(v1)])) #初始的值都为0 sess.run(tf.assign(v1, 5)) #把v1变为5
sess.run(maintain_average)
print(sess.run([v1, ema.average(v1)]))
# decay=min(0.99, 1/10)=0.1, v1=0.1*0+0.9*5=4.5 sess.run(tf.assign(step, 10000)) # steps=10000
sess.run(tf.assign(v1, 10)) # v1=10
sess.run(maintain_average)
print(sess.run([v1, ema.average(v1)]))
# decay=min(0.99,(1+10000)/(10+10000))=0.99, v1=0.99*4.5+0.01*10=4.555 sess.run(maintain_average)
print(sess.run([v1, ema.average(v1)]))
# decay=min(0.99,(1+10000)/(10+10000))=0.99, v1=0.99*4.555+0.01*10=4.60945
> [0.0, 0.0]
> [5.0, 4.5]
> [10.0, 4.555]
> [10.0, 4.60945]

每次更新完之后,影子变量(shadow_variable)的值就会更新,varible的值就是我们设定的值。如果在下一次运行这个函数的时候我们不再指定新的值,那varible的值就不变,影子变量更新。如果指定varible的值,那variable就改变为对应的指定值,相应的影子变量也改变。
原文链接:https://blog.csdn.net/tefuirnever/article/details/88902132

最新文章

  1. javascript 全局对象--w3school
  2. FastReport.Net 常用功能总汇
  3. Qt安装后配置环境变量(Mac)
  4. C#.NET数据库访问类DBHelper
  5. MIUI是小米的核心竞争力
  6. jquery radio取值,checkbox取值,select取值,radio选中,checkbox选中,select选中
  7. [Codeforces Round #237 (Div. 2)] A. Valera and X
  8. 【转】使用Boost Graph library(二)
  9. 回文质数 Prime Palindromes
  10. MO_GLOBAL - EBS R12 中 Multi Org 设计的深入研究 (2)
  11. MySQL DATE_FORMAT函数使用
  12. json对象和字符串的相互转换
  13. 录音--获取语音流(pyAudio)
  14. No.3 数组中重复的数字 (P39)
  15. ansible的模块使用说明
  16. ubuntu 18.04 使用 nvm 安装 nodejs
  17. des加密解密JAVA与.NET互通实例
  18. ACM:油田(Oil Deposits,UVa 572)
  19. Extending Markov to Hidden Markov
  20. call、apply、bind的异同

热门文章

  1. ArrayList去除集合中字符串的重复值
  2. Python之Numpy:线性代数/矩阵运算
  3. docker笔记、常遇问题、常用命令
  4. GitLab 架构
  5. Spark分区实例(teacher)
  6. SparkCore的性能优化
  7. bash: ./vmware-install.pl: /user/bin/perl: 坏的解释器:没有那个文件或目录
  8. Zebra架构与大数据架构优劣对比
  9. 【转帖】超能课堂(186) CPU中的那些指令集都有什么用?
  10. Python3学习笔记-更新中