~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

转载请注明出处:

http://www.cnblogs.com/darkknightzh/p/8297793.html

参考网址:

http://pytorch.org/docs/0.3.0/nn.html?highlight=kaiming#torch.nn.init.kaiming_normal

https://github.com/prlz77/ResNeXt.pytorch/blob/master/models/model.py

https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua

https://github.com/bamos/densenet.pytorch/blob/master/densenet.py

https://github.com/szagoruyko/wide-residual-networks/blob/master/models/utils.lua

说明:暂时就这么多吧,错误之处请见谅。前两个初始化的方法见pytorch官方文档

1. xavier初始化

torch.nn.init.xavier_uniform(tensor, gain=1)

对于输入的tensor或者变量,通过论文Understanding the difficulty of training deep feedforward neural networks” - Glorot, X. & Bengio, Y. (2010)的方法初始化数据。初始化服从均匀分布U(−a,a)" role="presentation" style="position: relative;">U(−a,a)U(−a,a),其中a=gain×2/(fan_in+fan_out)×3" role="presentation" style="position: relative;">a=gain×2/(fan_in+fan_out)−−−−−−−−−−−−−−−−−−√×3–√a=gain×2/(fan_in+fan_out)×3,该初始化方法也称Glorot initialisation。

参数:

tensor:n维的 torch.Tensor 或者 autograd.Variable类型的数据

a:可选择的缩放参数

例如:

w = torch.Tensor(3, 5)
nn.init.xavier_uniform(w, gain=nn.init.calculate_gain('relu'))

torch.nn.init.xavier_normal(tensor, gain=1)

对于输入的tensor或者变量,通过论文Understanding the difficulty of training deep feedforward neural networks” - Glorot, X. & Bengio, Y. (2010)的方法初始化数据。初始化服从高斯分布N(0,std)" role="presentation" style="position: relative;">N(0,std)N(0,std),其中std=gain×2/(fan_in+fan_out)" role="presentation" style="position: relative;">std=gain×2/(fan_in+fan_out)−−−−−−−−−−−−−−−−−−√std=gain×2/(fan_in+fan_out),该初始化方法也称Glorot initialisation。

参数:

tensor:n维的 torch.Tensor 或者 autograd.Variable类型的数据

a:可选择的缩放参数

例如:

w = torch.Tensor(3, 5)
nn.init.xavier_normal(w)

2. kaiming初始化

torch.nn.init.kaiming_uniform(tensor, a=0, mode='fan_in')

对于输入的tensor或者变量,通过论文“Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification” - He, K. et al. (2015)的方法初始化数据。初始化服从均匀分布U(−bound,bound)" role="presentation" style="position: relative;">U(−bound,bound)U(−bound,bound),其中bound=2/((1+a2)×fan_in)×3" role="presentation" style="position: relative;">bound=2/((1+a2)×fan_in)−−−−−−−−−−−−−−−−−−√×3–√bound=2/((1+a2)×fan_in)×3,该初始化方法也称He initialisation。

参数:

tensor:n维的 torch.Tensor 或者 autograd.Variable类型的数据

a:该层后面一层的激活函数中负的斜率(默认为ReLU,此时a=0)

mode:‘fan_in’ (default) 或者 ‘fan_out’. 使用fan_in保持weights的方差在前向传播中不变;使用fan_out保持weights的方差在反向传播中不变。

例如:

w = torch.Tensor(3, 5)
nn.init.kaiming_uniform(w, mode='fan_in')

torch.nn.init.kaiming_normal(tensor, a=0, mode='fan_in')

对于输入的tensor或者变量,通过论文“Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification” - He, K. et al. (2015)的方法初始化数据。初始化服从高斯分布N(0,std)" role="presentation" style="position: relative;">N(0,std)N(0,std),其中std=2/((1+a2)×fan_in)" role="presentation" style="position: relative;">std=2/((1+a2)×fan_in)−−−−−−−−−−−−−−−−−−√std=2/((1+a2)×fan_in),该初始化方法也称He initialisation。

参数:

tensor:n维的 torch.Tensor 或者 autograd.Variable类型的数据

a:该层后面一层的激活函数中负的斜率(默认为ReLU,此时a=0)

mode:‘fan_in’ (default) 或者 ‘fan_out’. 使用fan_in保持weights的方差在前向传播中不变;使用fan_out保持weights的方差在反向传播中不变。

例如:

w = torch.Tensor(3, 5)
nn.init.kaiming_normal(w, mode='fan_out')

使用的例子(具体参见原始网址):

https://github.com/prlz77/ResNeXt.pytorch/blob/master/models/model.py

from torch.nn import init
self.classifier = nn.Linear(self.stages[3], nlabels)
init.kaiming_normal(self.classifier.weight)
for key in self.state_dict():
if key.split('.')[-1] == 'weight':
if 'conv' in key:
init.kaiming_normal(self.state_dict()[key], mode='fan_out')
if 'bn' in key:
self.state_dict()[key][...] = 1
elif key.split('.')[-1] == 'bias':
self.state_dict()[key][...] = 0

3. 实际使用中看到的初始化

3.1 ResNeXt,densenet中初始化

https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua

https://github.com/bamos/densenet.pytorch/blob/master/densenet.py

conv

n = kW* kH*nOutputPlane
weight:normal(,math.sqrt(/n))
bias:zero()

batchnorm

weight:fill()
bias:zero()

linear

bias:zero()

3.2 wide-residual-networks中初始化(MSRinit

https://github.com/szagoruyko/wide-residual-networks/blob/master/models/utils.lua

conv

n = kW* kH*nInputPlane
weight:normal(,math.sqrt(/n))
bias:zero()

linear

bias:zero()

最新文章

  1. 我这样理解js里的this
  2. 跨域攻击xss
  3. UI Automator Viewer获取手机镜像时报错
  4. NetStatusEvent info对象的状态或错误情况的属性
  5. editplus如何设置不自动备份
  6. 网易DBA私享会分享会笔记1
  7. 【转】CentOS 6.3(x86_64)下安装Oracle 10g R2
  8. 下载网易云VIP音乐
  9. Git-删除文件后找回-比较文件差异
  10. windows平台在tomcat中启动cas报错解决
  11. c++Builder debug DataSet Visualizer
  12. div模态显示内容
  13. HTTP常见的Post请求
  14. java学习 猜数字
  15. Java中Integer类的方法和request的setAttribute方法的使用与理解
  16. Delphi 7学习开发控件(续)
  17. 【OK210试用体验】进阶篇(1)视频图像采集之MJPG-streamer编译(Ubuntu系统下)
  18. Drools学习笔记2—Conditions / LHS 匹配模式&条件元素
  19. macro expand error
  20. BZOJ3878: [Ahoi2014&Jsoi2014]奇怪的计算器

热门文章

  1. qt and redis desktop manager
  2. lsof,fuser,xargs,print0,cut,paste,cat,tac,rev,exec,{},双引号,单引号,‘(字符串中执行命令)
  3. Spark源码分析 – BlockManager
  4. J.U.C Atomic(二)基本类型原子操作
  5. odoo学习:创建新数据库及修改数据库内容
  6. 创建Java不可变类
  7. PAT 1136 A Delayed Palindrome[简单]
  8. 转载SQL_trace 和10046使用
  9. 随机深林和GBDT
  10. Python 函数的使用小结