原文地址:

https://blog.csdn.net/weixin_40759186/article/details/87547795

---------------------------------------------------------------------------------------------------------------

用pytorch做dropout和BN时需要注意的地方

pytorch做dropout:

就是train的时候使用dropout,训练的时候不使用dropout,
pytorch里面是通过net.eval()固定整个网络参数,包括不会更新一些前向的参数,没有dropout,BN参数固定,理论上对所有的validation set都要使用net.eval()
net.train()表示会纳入梯度的计算。

net_dropped = torch.nn.Sequential(
torch.nn.Linear(1, N_HIDDEN),
torch.nn.Dropout(0.5), # drop 50% of the neuron
torch.nn.ReLU(),
torch.nn.Linear(N_HIDDEN, N_HIDDEN),
torch.nn.Dropout(0.5), # drop 50% of the neuron
torch.nn.ReLU(),
torch.nn.Linear(N_HIDDEN, 1),
)

for t in range(500):
pred_drop = net_dropped(x)
loss_drop = loss_func(pred_drop, y) optimizer_drop.zero_grad()
loss_drop.backward()
optimizer_drop.step() if t % 10 == 0:
# change to eval mode in order to fix drop out effect
net_dropped.eval() # parameters for dropout differ from train mode test_pred_drop = net_dropped(test_x) # change back to train mode
net_dropped.train()

pytorch做Batch Normalization:

net.eval()固定整个网络参数,固定BN的参数,moving_mean 和moving_var,不懂这个看下图:

            if self.do_bn:
bn = nn.BatchNorm1d(10, momentum=0.5)
setattr(self, 'bn%i' % i, bn) # IMPORTANT set layer to the Module
self.bns.append(bn) for epoch in range(EPOCH):
print('Epoch: ', epoch)
for net, l in zip(nets, losses):
net.eval() # set eval mode to fix moving_mean and moving_var
pred, layer_input, pre_act = net(test_x) net.train() # free moving_mean and moving_var
plot_histogram(*layer_inputs, *pre_acts)

moving_mean   和   moving_var

用tensorflow做dropout和BN时需要注意的地方

dropout和BN都有一个training的参数表明到底是train还是test, 表明test那dropout就是不dropout,BN就是固定住了BN的参数;

tf_is_training = tf.placeholder(tf.bool, None)  # to control dropout when training and testing

# dropout net
d1 = tf.layers.dense(tf_x, N_HIDDEN, tf.nn.relu)
d1 = tf.layers.dropout(d1, rate=0.5, training=tf_is_training) # drop out 50% of inputs

d2 = tf.layers.dense(d1, N_HIDDEN, tf.nn.relu)
d2 = tf.layers.dropout(d2, rate=0.5, training=tf_is_training) # drop out 50% of inputs

d_out = tf.layers.dense(d2, 1) for t in range(500):
sess.run([o_train, d_train], {tf_x: x, tf_y: y, tf_is_training: True}) # train, set is_training=True if t % 10 == 0:
# plotting
plt.cla()
o_loss_, d_loss_, o_out_, d_out_ = sess.run(
[o_loss, d_loss, o_out, d_out], {tf_x: test_x, tf_y: test_y, tf_is_training: False} # test, set is_training=False
)
    def add_layer(self, x, out_size, ac=None):
x = tf.layers.dense(x, out_size, kernel_initializer=self.w_init, bias_initializer=B_INIT)
self.pre_activation.append(x)
# the momentum plays important rule. the default 0.99 is too high in this case!
if self.is_bn: x = tf.layers.batch_normalization(x, momentum=0.4, training=tf_is_train) # when have BN
out = x if ac is None else ac(x)
return out
 

当BN的training的参数为train时,只是表示BN的参数是可变化的,并不是代表BN会自己更新moving_mean 和moving_var,因为这个操作是前向更新的op,在做train之前必须确保moving_mean 和moving_var更新了,更新moving_mean 和moving_var的操作在tf.GraphKeys.UPDATE_OPS

        # !! IMPORTANT !! the moving_mean and moving_variance need to be updated,
# pass the update_ops with control_dependencies to the train_op
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
self.train = tf.train.AdamOptimizer(LR).minimize(self.loss)

最新文章

  1. Java 日志性能优化
  2. VBA在WORD中给表格外的字体设置为标题
  3. HDU 2602 (简单的01背包) Bone Collector
  4. How to use For loop in CruiseControl.net
  5. JAVA首选五款开源Web开发框架
  6. Swift - 使用storyboard创建表格视图(TableViewController)
  7. JSF教程(9)——生命周期之Process Validations Phase
  8. excel导入到Orcle
  9. [置顶] 一个demo学会c#
  10. c#实战开发:以太坊Geth 命令发布智能合约 (五)
  11. 灵雀云受邀加入VMware 创新网络,共同助力企业数字化进程
  12. MySql数据库实现分布式的主从结构
  13. MySQL 在Windows平台上的安装及实例多开
  14. NOIp2018爆零记
  15. python文件操作r+,w+,a+,rb+,
  16. GO语言-基础语法:条件判断
  17. php 对象转字符串
  18. 交叉编译sudo
  19. 489. Robot Room Cleaner扫地机器人
  20. Django在Win7下安装与创建项目hello word示例

热门文章

  1. C# 3.0 / C# 3.5 隐式(推断)类型 var
  2. 基本数据类型list,tuple
  3. Visual Studio编译时报错“函数名:重定义;不同的基类型”
  4. 解决iOS第三方SDK之间重复的symbols问题
  5. 再谈数据库优化(database tuning)的真谛和误区
  6. Spring Boot + Spring Cloud 实现权限管理系统 后端篇(十一):集成 Shiro 框架
  7. TinyXML C++解析XML
  8. powerdesidgner1
  9. vim : Depends: vim-common (= 2:7.4.052-1ubuntu3.1) but 2:7.4.273-2ubuntu4 is to be installed
  10. 平面图转对偶图&19_03_21校内训练 [Everfeel]