在进行多卡训练的时候,经常会出现GPU利用率上不来的情况,无法发挥硬件的最大实力。 造成这种现象最有可能的原因是,CPU生成数据的能力,已经跟不上GPU处理数据的能力。


方法一


常见的方法为修改Dataloader里面的线程数量,利用多线程技术提高数据生产能力,但是这种方法提速并不是特别明显。

train_loader = DataLoader(dataset, batch_size,shuffle=True, num_worker=4)

而且windows机器上,num_worker大于0时,有时会出现卡死的情况,这应该是pytorch的bug,因此不是特别建议这种方法。

不过这种方法最简单,还是可以尝试一下更改线程数能否缓解你遇到的问题。nun_worker一般设置为处理器的物理线程数,不宜过大,因为会导致额外的线程开销。

方法二


本文主要介绍第二种方法,也就是Data Prefetcher,最早见于NVIDIA APEX

这里我把代码抠出来了,删除掉了一些不必要的注释,可以将其复用到自己的项目里来。

import torch

class data_prefetcher():
def __init__(self, loader):
self.loader = iter(loader)
self.stream = torch.cuda.Stream()
self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
self.preload() def preload(self):
try:
self.next_input, self.next_target = next(self.loader)
except StopIteration:
self.next_input = None
self.next_target = None
return with torch.cuda.stream(self.stream):
self.next_input = self.next_input.cuda(non_blocking=True)
self.next_target = self.next_target.cuda(non_blocking=True)
self.next_input = self.next_input.float()
self.next_input = self.next_input.sub_(self.mean).div_(self.std) def next(self):
torch.cuda.current_stream().wait_stream(self.stream)
input = self.next_input
target = self.next_target
if input is not None:
input.record_stream(torch.cuda.current_stream())
if target is not None:
target.record_stream(torch.cuda.current_stream())
self.preload()
return input, target

首先我们来看初始化函数,在初始化函数中,会直接调用preload,所以当这个对象初始化时,就会生成第一份的输入数据。

核心逻辑也就在预加载函数preload中,其中第13行是从原来的dataloader中取数,这一步和常规数据加载没有差别。有差别的是第19行,这里出现了Stream的概念。

一般来说,CUDA程序默认都运行在同一个Stream上,因此CPU->GPU,GPU->GPU以及GPU->CPU的一系列计算都是在同一个Stream里面串行运行的。 深度学习一般流程是先从dataloader中取数,这里是内存->CPU的运算,然后执行to_device操作,让数据从CPU->GPU,再是GPU->GPU的神经网络计算。

代码19行,使得data_prefetecher这个类是单独运行在一个Stream上的,因此它让数据加载和神经网络计算可以并行执行,也就加速了整体的运行速度。这样做带来的负面结果就是GPU同时在做两项任务,所以显存占用会增加。

这里不知道解释清楚没有,建议去看一下原作者的回答link

另外,重要的是,使用这个方法的时候一定要将Dataloader里面的pin_memory设置为True。

使用方法如下,非常简单,改造前是从dataloader里取数,改造后是将dataloader包在prefetecher里面,从prefetecher里面取数。

train_loader = DataLoader(dataset, batch_size,shuffle=True, num_worker=4,pin_memory=True)
prefetcher = data_prefetcher(train_loader)
input, target = prefetcher.next() while input is not None:
##
前后向计算...
###
input, target = prefetcher.next()

最新文章

  1. post请求报文
  2. TinyFrame再续篇:整合Spring AOP实现日志拦截
  3. 【JavaScript】直接拿来用!最火的前端开源项目(一)
  4. Oracle 递归查询
  5. TextView设置样式的3种方式
  6. encodeURI和encodeURIComponent的比较
  7. defgen工具
  8. Eclipse perl的IDE环境插件-EPIC
  9. 解析Excel文件并把数据存入数据库
  10. python基础教程(二)
  11. [亲测]ASP.NET Core 2.0怎么发布/部署到Ubuntu Linux服务器并配置Nginx反向代理实现域名访问
  12. Spring MVC CORS 跨域
  13. strtok函数读写冲突问题
  14. 7. mybatis:mapper-locations: 路径放在java路径下报错:org.apache.ibatis.binding.BindingException: Invalid bound statement (not found)
  15. android--------实现Activity和Fragment通信的面向对象的万能接口
  16. Maven运行的方式
  17. shell 命令 netstat 查看端口占用
  18. [Deep-Learning-with-Python]机器学习基础
  19. junit中test注解测试使用案列解析一
  20. widows终端远程连接Linux服务器

热门文章

  1. Vite2+Vue3+ts的eslint设置踩坑
  2. python 处理网络帧时,CRC算法中整数按位取反运算(~)得到负数的规避方法
  3. Java基础语法Day_06(面相对象和封装)
  4. javaScript中Math内置对象基本方法入门
  5. 斯坦福NLP课程 | 第2讲 - 词向量进阶
  6. CSS躬行记(11)——管理后台响应式改造
  7. 小数据池,is和==的区别,id()
  8. linux下nginx软件的学习
  9. 690. Employee Importance - LeetCode
  10. 482. License Key Formatting - LeetCode