用于深度学习的自动混合精度

深度神经网络训练传统上依赖IEEE单精度格式,但在混合精度的情况下,可以训练半精度,同时保持单精度网络的精度。这种同时使用单精度和半精度表示的技术称为混合精度技术。

​混合精度训练的好处

通过使用Tensor Core加速数学密集型运算,如线性和卷积层。

与单精度相比,通过访问一半的字节可以加快内存受限的操作。

减少训练模型的内存需求,支持更大的模型或更小的批。

启用混合精度涉及两个步骤:在适当的情况下,将模型移植到使用半精度数据类型;并使用损失缩放来保持较小的梯度值。

TensorFlow、PyTorch和MXNet中的自动混合精度特性为深度学习研究人员和工程师提供了在NVIDIA Volta和Turing gpu上最多3倍的人工智能训练速度,而只需要添加几行代码。

使用自动混合精度的主要深度学习框架

  • TensorFlow

在NVIDIA NGC容器注册表中提供的TensorFlow容器中提供了自动混合精度特性。要在容器内启用此功能,只需设置一个环境变量:

export TF_ENABLE_AUTO_MIXED_PRECISION=1

另外,环境变量可以在TensorFlow Python脚本中设置:

os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'

另外还需要对优化器(Optimizer)作如下修改:

optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer) # 需要添加这句话,该例子是tf1.14.0版本,不同版本可能不一样

自动混合精度在TensorFlow内部应用这两个步骤,使用一个环境变量,并在必要时进行更细粒度的控制。

  • PyTorch

自动混合精度特性在GitHub上的Apex repository中可用。要启用,请将这两行代码添加到您现有的训练脚本中:

model, optimizer = amp.initialize(model, optimizer)

with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
  • MXNet

NVIDIA正在为MXNet构建自动混合精度特性。你可以在GitHub上找到正在进行的工作。要启用该功能,请在现有的训练脚本中添加以下代码行:

amp.init()
amp.init_trainer(trainer)
with amp.scale_loss(loss, trainer) as scaled_loss:
autograd.backward(scaled_loss)

微信公众号:AutoML机器学习

MARSGGBO♥原创

如有意合作或学术讨论欢迎私戳联系~
邮箱:marsggbo@foxmail.com




2020-01-23 17:45:35

最新文章

  1. Git 小技巧
  2. 【翻译十七】java-并发之高性能对象
  3. WPF--Blend制作Button控件模板--问题补充
  4. poj 1195 mobile phone
  5. ubuntukylin提取root权限及mongoDB部署
  6. Error:(6, 0) No such property: outputDir for class: org.gradle.api.internal.project.DefaultProject_Decorated
  7. 在运行时切换 WinForm 程序的界面语言 ---------多语言设置基础
  8. mysql技术内幕InnoDB存储引擎-阅读笔记
  9. Linux下软件的卸载
  10. 【BZOJ3992】序列统计(动态规划,NTT)
  11. 有关Linux ipv6模块加载失败的问题
  12. 四、activiti工作流-第一个HelloWorld
  13. JavaScript基础知识梳理,你能回答几道题?
  14. Mysql 8.0修改密码
  15. Delphi.XE2破解方法
  16. 关于JavaScript和Java的区别和联系
  17. 4月24 php基础及函数的应用
  18. Jquery Ajax 返回数据类型变成document
  19. Java -- 异常的捕获及处理 -- 范例 -- throw与throws的应用
  20. Java不定参数

热门文章

  1. Java集合详解2:一文读懂Queue和LinkedList
  2. PB级数据实时查询,滴滴Elasticsearch多集群架构实践
  3. .NET 微服务 1. Docker 容器简介和选择
  4. 读《PMI 分析手册》
  5. apt-get命令使用
  6. 网络基础 ----------- osi 与 一些协议
  7. [转帖]Redis持久化--Redis宕机或者出现意外删库导致数据丢失--解决方案
  8. day59——orm单表操作
  9. Scala 函数基础入门
  10. java中String字符串