『PyTorch』矩阵乘法总结
1. 二维矩阵乘法 torch.mm()
torch.mm(mat1, mat2, out=None)
,其中mat1
(\(n\times m\)),mat2
(\(m\times d\)),输出out
的维度是(\(n\times d\))。
该函数一般只用来计算两个二维矩阵的矩阵乘法,并且不支持broadcast操作。
2. 三维带batch的矩阵乘法 torch.bmm()
由于神经网络训练一般采用mini-batch,经常输入的时三维带batch的矩阵,所以提供torch.bmm(bmat1, bmat2, out=None)
,其中bmat1
(\(b\times n \times m\)),bmat2
(\(b\times m \times d\)),输出out
的维度是(\(b \times n \times d\))。
该函数的两个输入必须是三维矩阵且第一维相同(表示Batch维度),不支持broadcast操作。
3. 多维矩阵乘法 torch.matmul()
torch.matmul(input, other, out=None)
支持broadcast操作,使用起来比较复杂。
针对多维数据 matmul()
乘法,我们可以认为该matmul()
乘法使用使用两个参数的后两个维度来计算,其他的维度都可以认为是batch维度。假设两个输入的维度分别是input
(\(1000 \times 500 \times 99 \times 11\)), other
(\(500 \times 11 \times 99\))那么我们可以认为torch.matmul(input, other, out=None)
乘法首先是进行后两位矩阵乘法得到\((99 \times 11) \times (11 \times 99)\Rightarrow(99 \times 99)\) ,然后分析两个参数的batch size分别是 \(( 1000 \times 500)\) 和 \(500\) , 可以广播成为 \((1000 \times 500)\), 因此最终输出的维度是(\(1000 \times 500 \times 99 \times 99\))。
4. 矩阵逐元素(Element-wise)乘法 torch.mul()
torch.mul(mat1, other, out=None)
,其中other
乘数可以是标量,也可以是任意维度的矩阵,只要满足最终相乘是可以broadcast的即可
5. 两个运算符 @ 和 *
@
:矩阵乘法,自动执行适合的矩阵乘法函数*
:element-wise乘法
最新文章
- adt_sdk_tools介绍
- 慎用Assembly.LoadFile()和Assembly.LoadFrom()
- new对象时,类名后加括号与不加括号的区别
- Android SDK之API Level
- python在线文档
- 在MS CRM 4.0中引用JS文件
- mybatis模板
- 24.task的运用
- JavaScript中创建命名空间
- Mac 安装工具包brew
- 高端内存映射之kmap持久内核映射--Linux内存管理(二十)
- (二分查找 拓展) leetcode 69. Sqrt(x)
- IDEA eclipse转maven
- 22.纯 CSS 创作出美丽的彩虹条纹文字
- 深入理解net core中的依赖注入、Singleton、Scoped、Transient(三)
- ZLYZD团队第四周项目总结
- ZetCode PyQt4 tutorial Drag and Drop
- Java并发(十六):并发工具类——Exchanger
- Java基础学习-包装类
- XV Open Cup named after E.V. Pankratiev Stage 6, Grand Prix of Japan Problem J. Hyperrectangle