两行代码统计模型参数量与FLOPs,这个PyTorch小工具值得一试
你的模型到底有多少参数,每秒的浮点运算到底有多少,这些你都知道吗?近日,GitHub 开源了一个小工具,它可以统计 PyTorch 模型的参数量与每秒浮点运算数(FLOPs)。有了这两种信息,模型大小控制也就更合理了。
其实模型的参数量好算,但浮点运算数并不好确定,我们一般也就根据参数量直接估计计算量了。但是像卷积之类的运算,它的参数量比较小,但是运算量非常大,它是一种计算密集型的操作。反观全连接结构,它的参数量非常多,但运算量并没有显得那么大。
此外,机器学习还有很多结构没有参数但存在计算,例如和 等。因此,PyTorch-OpCounter 这种能直接统计 FLOPs 的工具还是非常有吸引力的。
PyTorch-OpCounter GitHub 地址:https://github.com/Lyken17/pytorch-OpCounter
OpCouter
PyTorch-OpCounter 的安装和使用都非常简单,并且还能定制化统计规则,因此那些特殊的运算也能自定义地统计进去。
我们可以使用 pip 简单地完成安装:pip install thop。不过 GitHub 上的代码总是最新的,因此也可以从 GitHub 上的脚本安装。
对于 torchvision 中自带的模型,Flops 统计通过以下几行代码就能完成:
from torchvision.models import resnet50
from thop import profile
model = resnet50()
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ))
我们测试了一下 DenseNet-121,用 OpCouter 统计了参数量与运算量。API 的输出如下所示,它会告诉我们具体统计了哪些结构,它们的配置又是什么样的。
最后输出的浮点运算数和参数量分别为如下所示,换算一下就能知道 DenseNet-121 的参数量约有 798 万,计算量约有 2.91 GFLOPs。
flops: 2914598912.0
parameters: 7978856.0
OpCouter 是怎么算的
我们可能会疑惑,OpCouter 到底是怎么统计的浮点运算数。其实它的统计代码在项目中也非常可读,从代码上看,目前该工具主要统计了视觉方面的运算,包括各种卷积、激活函数、池化、批归一化等。例如最常见的二维卷积运算,它的统计代码如下所示:
def count_conv2d(m, x, y):
x = x[0]
cin = m.in_channels
cout = m.out_channels
kh, kw = m.kernel_size
batch_size = x.size()[0]
out_h = y.size(2)
out_w = y.size(3)
# ops per output element
# kernel_mul = kh * kw * cin
# kernel_add = kh * kw * cin - 1
kernel_ops = multiply_adds * kh * kw
bias_ops = 1 if m.bias is not None else 0
ops_per_element = kernel_ops + bias_ops
# total ops
# num_out_elements = y.numel()
output_elements = batch_size * out_w * out_h * cout
total_ops = output_elements * ops_per_element * cin // m.groups
m.total_ops = torch.Tensor([int(total_ops)])
总体而言,模型会计算每一个卷积核发生的乘加运算数,再推广到整个卷积层级的总乘加运算数。
定制你的运算统计
有一些运算统计还没有加进去,如果我们知道该怎样算,那么就可以写个自定义函数。
class YourModule(nn.Module):
# your definition
def count_your_model(model, x, y):
# your rule here
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ),
custom_ops={YourModule: count_your_model})
最后,作者利用这个工具统计了各种流行视觉模型的参数量与 FLOPs 量:
欢迎关注磐创博客资源汇总站:http://docs.panchuang.net/
欢迎关注PyTorch官方中文教程站:http://pytorch.panchuang.net/
最新文章
- mongodb安装&;简单使用
- Jumony快速抓取网页 --- Jumony使用笔记--icode
- Spring基础—— Bean 的作用域
- Docker 介绍以及其相关术语、底层原理和技术
- centos下redis和nginx软件的安装
- Oracle SQL的硬解析、软解析、软软解析
- C++之枚举
- I/O体系结构和设备驱动程序
- [置顶] 九度笔记之 1434:今年暑假不AC
- Java学习1——JDK(学前准备)
- MyBatis学习——分步查询与延迟加载
- distributed computing_the World Wide Web
- 设计table表格,用js设计偶数行和奇数行显示不同的颜色
- 【exe4j】如何利用exe4j把java桌面程序生成exe文件
- Linux信号机制
- 【Java字符序列】Pattern
- Linux 分区注意事项
- JavaScript多个音频audio标签,点击其中一个播放时,其他的停止播放
- win10中显示资源管理器扩展
- python的socket网络编程(二)
热门文章
- FPGA小白学习之路(4)PLL中的locked信号解析(转)
- SpringBoot一些基础配置
- 机器学习基础——简单易懂的K邻近算法,根据邻居“找自己”
- js原型继承题目
- redis如何在spring里面的bean配置
- Sequence to Sequence Learning with Neural Networks论文阅读
- Keil MDK版兼容51系列单片机开发环境安装
- express第三方中间件研究之bodyParser中间件
- MySQL InnoDB表的碎片量化和整理(data free能否用来衡量碎片?)
- DVWA Command Injection 解析