pytorch 的 Variable 对象中有两个方法,detach和 detach_ 本文主要介绍这两个方法的效果和 能用这两个方法干什么。

detach

官方文档中,对这个方法是这么介绍的。

返回一个新的 从当前图中分离的 Variable。
返回的 Variable 永远不会需要梯度
如果 被 detach 的Variable volatile=True, 那么 detach 出来的 volatile 也为 True
还有一个注意事项,即:返回的 Variable 和 被 detach 的Variable 指向同一个 tensor
import torch
from torch.nn import init
from torch.autograd import Variable
t1 = torch.FloatTensor([1., 2.])
v1 = Variable(t1)
t2 = torch.FloatTensor([2., 3.])
v2 = Variable(t2)
v3 = v1 + v2
v3_detached = v3.detach()
v3_detached.data.add_(t1) # 修改了 v3_detached Variable中 tensor 的值
print(v3, v3_detached) # v3 中tensor 的值也会改变
1
2
3
4
5
6
7
8
9
10
11
# detach 的源码
def detach(self):
result = NoGrad()(self) # this is needed, because it merges version counters
result._grad_fn = None
return result
1
2
3
4
5
detach_

官网给的解释是:将 Variable 从创建它的 graph 中分离,把它作为叶子节点。

从源码中也可以看出这一点

将 Variable 的grad_fn 设置为 None,这样,BP 的时候,到这个 Variable 就找不到 它的 grad_fn,所以就不会再往后BP了。
将 requires_grad 设置为 False。这个感觉大可不必,但是既然源码中这么写了,如果有需要梯度的话可以再手动 将 requires_grad 设置为 true
# detach_ 的源码
def detach_(self):
"""Detaches the Variable from the graph that created it, making it a
leaf.
"""
self._grad_fn = None
self.requires_grad = False
1
2
3
4
5
6
7
能用来干啥

如果我们有两个网络 A,BA,B, 两个关系是这样的 y=A(x),z=B(y)y=A(x),z=B(y) 现在我们想用 z.backward()z.backward() 来为 BB 网络的参数来求梯度,但是又不想求 AA 网络参数的梯度。我们可以这样:

# y=A(x), z=B(y) 求B中参数的梯度,不求A中参数的梯度
# 第一种方法
y = A(x)
z = B(y.detach())
z.backward()

# 第二种方法
y = A(x)
y.detach_()
z = B(y)
z.backward()
1
2
3
4
5
6
7
8
9
10
11
在这种情况下,detach 和 detach_ 都可以用。但是如果 你也想用 yy 来对 AA 进行 BP 呢?那就只能用第一种方法了。因为 第二种方法 已经将 AA 模型的输出 给 detach(分离)了。
---------------------
作者:ke1th
来源:CSDN
原文:https://blog.csdn.net/u012436149/article/details/76714349
版权声明:本文为博主原创文章,转载请附上博文链接!

最新文章

  1. Swift函数的定义
  2. 傻瓜式操作Nagios
  3. JSONP跨域的原理解析( 一种脚本注入行为)
  4. wireshark 和 Httpwatch tcpdump
  5. [转]crontab命令指南
  6. SQL总结(六)触发器
  7. ADO.NET笔记——利用Command对象的ExecuteScalar()方法返回一个数据值
  8. sizeof的用法的一些归纳
  9. C#获取上个月的第一天零点和最后一天23点59分59秒
  10. Joel在耶鲁大学的演讲
  11. 全选js实现
  12. javaweb作業中的幾個要點
  13. Description has only two Sentences(欧拉定理 +快速幂+分解质因数)
  14. 使用Dreamweaver正则表达式替换href中的内容
  15. ipfs上传下载
  16. Ubuntu下常用指令
  17. Django Context对象 + 过滤器 + 标签
  18. mysql 报错You can't specify target table 'wms_cabinet_form' for update in FROM clause
  19. trap命令的实战用法
  20. Android 实现简单 倒计时60秒,一次1秒

热门文章

  1. 【转】浏览器中输入url后发生了什么
  2. 初探js闭包
  3. Android Debuggerd 简要介绍和源码分析(转载)
  4. Ubuntu 12.04 root默认密码? 如何使用root登录? (转载)
  5. PowerDesigner在PDM转换为sql脚本时报错Generation aborted due to errors detected during the verification of the mod
  6. docker学习教程
  7. Qt事件系统之五:事件过滤器和事件的发送
  8. Hdu 5358 First One (尺取法+枚举)
  9. STL之map基础知识
  10. 题解报告:NYOJ 题目143 第几是谁?(逆康托展开)