PyTorch's all_gather operation does not propagate gradients. If the computation graph has to pass through an all_gather and gradients must still flow back to the corresponding pre-gather tensors on each process, you need to subclass torch.autograd.Function.

https://pytorch.org/docs/stable/autograd.html documents torch.autograd.Function.

https://pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd walks through, with examples, how to implement such a subclass.

The code below illustrates this behavior of all_gather and how to make gradients flow through it.

\(x\), \(y\) and \(z\) are all 2x2 matrices, related by \(y=x+2\) and \(z=y*y\) (all products here are elementwise).

Next, inter-process communication is needed to collect z onto every process, i.e. the all_gather operation. The two gathered matrices are then multiplied elementwise, and the scalar r is the mean of the result.

Writing \({}_{g_0}y\) for the y on gpu0 and \({}_{g_1}y\) for the y on gpu1, r and its derivative with respect to y are:

\(r=0.25({}_{g_0}y_{11}^2*{}_{g_1}y_{11}^2+{}_{g_0}y_{12}^2*{}_{g_1}y_{12}^2+{}_{g_0}y_{21}^2*{}_{g_1}y_{21}^2+{}_{g_0}y_{22}^2*{}_{g_1}y_{22}^2)\)

\(\frac{dr}{d\,{}_{g_0}y}=
\begin{Bmatrix}
0.5\,{}_{g_0}y_{11}*{}_{g_1}y_{11}^2 & 0.5\,{}_{g_0}y_{12}*{}_{g_1}y_{12}^2 \\
0.5\,{}_{g_0}y_{21}*{}_{g_1}y_{21}^2 & 0.5\,{}_{g_0}y_{22}*{}_{g_1}y_{22}^2
\end{Bmatrix}\)

On gpu0 the value of x is \(\begin{Bmatrix} 1 & 1 \\ 1 & 1 \end{Bmatrix}\), and on gpu1 it is \(\begin{Bmatrix} 0 & 0 \\ 0 & 0 \end{Bmatrix}\). Plugging into the formula, the derivative of r with respect to y on gpu0 is \(\begin{Bmatrix} 6 & 6 \\ 6 & 6 \end{Bmatrix}\), and with respect to y on gpu1 it is \(\begin{Bmatrix} 9 & 9 \\ 9 & 9 \end{Bmatrix}\).
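As a quick sanity check of those numbers: every entry of y on gpu0 is \(1+2=3\) and every entry on gpu1 is \(0+2=2\), so for each position \(ij\)

\(\frac{dr}{d\,{}_{g_0}y_{ij}}=0.5*3*2^2=6,\qquad \frac{dr}{d\,{}_{g_1}y_{ij}}=0.5*2*3^2=9\)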

import os
import sys
import torch
import torch.distributed as dist

sys.path.append('./')
from utils import GatherLayer  # the autograd-aware gather shown below


def test():
    # torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    dist.init_process_group(backend="nccl", init_method="env://")
    rank = dist.get_rank()
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    world_size = dist.get_world_size()
    torch.cuda.set_device(local_rank)
    print('world_size: {}, rank: {}, local_rank: {}'.format(world_size, rank, local_rank))

    if local_rank == 0:
        x = torch.ones(2, 2, requires_grad=True, device='cuda')
    else:
        x = torch.zeros(2, 2, requires_grad=True, device='cuda')
    y = x + 2
    y.retain_grad()  # y is not a leaf, so keep its gradient for printing
    z = y * y

    # gather z from every process; dist.all_gather itself records no
    # autograd history on the output tensors
    z_gather = [torch.zeros_like(z) for _ in range(world_size)]
    dist.all_gather(z_gather, z)
    # z_gather = GatherLayer.apply(z)
    r = z_gather[0] * z_gather[1]
    out = r.mean()
    out.backward()
    if local_rank == 0:
        print('rank:0', y.grad)
    else:
        print('rank:1', y.grad)


if __name__ == '__main__':
    test()
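The script needs two processes. One way to launch it (assuming a PyTorch version that ships torchrun, which sets the LOCAL_RANK, RANK, WORLD_SIZE and MASTER_* environment variables consumed by init_method="env://"; the script path is taken from the traceback below):

torchrun --nproc_per_node=2 test/test_all_gather.py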

(1) The code above first uses the all_gather provided by PyTorch. Running it fails; both processes raise the same error, shown once below:

Traceback (most recent call last):
  File "test/test_all_gather.py", line 46, in <module>
    test()
  File "test/test_all_gather.py", line 36, in test
    out.backward()
  File "/usr/local/lib/python3.6/dist-packages/torch/tensor.py", line 185, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py", line 127, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
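The failure is expected: dist.all_gather only copies data into the preallocated output tensors, outside of autograd, so every element of z_gather comes back with requires_grad=False and no grad_fn, and out has no graph to backpropagate through. A quick way to see this (a hypothetical check added here for illustration, placed right after the dist.all_gather call):

print(z.requires_grad, z.grad_fn)                      # True, <MulBackward0 ...>
print(z_gather[0].requires_grad, z_gather[0].grad_fn)  # False, None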

(2) See https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/gather.py: this class subclasses torch.autograd.Function and wraps all_gather so that gradients can flow back through it.

Enabling the line z_gather = GatherLayer.apply(z) in the code above restores gradient propagation; the printed gradients of y are:

world_size: 2, rank: 0, local_rank: 0
world_size: 2, rank: 1, local_rank: 1
rank:0 tensor([[6., 6.],
[6., 6.]], device='cuda:0')
rank:1 tensor([[9., 9.],
[9., 9.]], device='cuda:1')

The GatherLayer class is implemented as follows:

class GatherLayer(torch.autograd.Function):
    """Gather tensors from all processes, supporting backward propagation."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        # grads holds one gradient tensor per forward output; each rank
        # keeps only the slice corresponding to its own input
        (input,) = ctx.saved_tensors
        grad_out = torch.zeros_like(input)
        grad_out[:] = grads[dist.get_rank()]
        return grad_out
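Note that this backward returns only the local rank's incoming gradient and never exchanges gradients between processes. That is sufficient here because every rank computes the identical loss, so grads[rank] already equals the full derivative. If the ranks computed different losses from the gathered tensors, the contributions held on other ranks would be lost; a variant that sums them first might look like this (a sketch, not from the original post; GatherLayerAllReduce is a made-up name):

class GatherLayerAllReduce(torch.autograd.Function):
    """Like GatherLayer, but backward sums gradient contributions over ranks."""

    @staticmethod
    def forward(ctx, input):
        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        all_grads = torch.stack(grads)
        dist.all_reduce(all_grads)  # defaults to SUM across ranks
        return all_grads[dist.get_rank()]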

The following thread discusses gradient propagation through all_gather:

https://discuss.pytorch.org/t/will-dist-all-gather-break-the-auto-gradient-graph/47350
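As an aside (not from the original post): newer PyTorch releases also ship an autograd-aware gather in torch.distributed.nn, which may replace a hand-written GatherLayer; it is worth checking your version's documentation for its exact backward semantics. A minimal sketch of its use:

import torch.distributed.nn as dist_nn

# returns gathered tensors that participate in autograd
z_gather = dist_nn.all_gather(z)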
