Pytorch: parameters(),children(),modules(),named_*区别
2024-09-02 14:41:50
nn.Module vs nn.functional
前者会保存权重等信息,后者只是做运算
parameters()
返回可训练参数
nn.ModuleList vs. nn.ParameterList vs. nn.Sequential
layer_list = [nn.Conv2d(5,5,3), nn.BatchNorm2d(5), nn.Linear(5,2)]
class myNet(nn.Module):
def __init__(self):
super().__init__()
self.layers = layer_list
def forward(x):
for layer in self.layers:
x = layer(x)
net = myNet()
print(list(net.parameters())) # Parameters of modules in the layer_list don't show up.
nn.ModuleList
的作用就是wrap pthon list,这样其中的参数会被注册,因此可以返回可训练参数(ParameterList)。
nn.Sequential
的作用如下:
class myNet(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
nn.Relu(inplace=True),
nn.Linear(10, 10)
)
def forward(x):
x = layer(x)
x = torch.rand(10)
net = myNet()
print(net(x).shape)
可以看到Sequential
的作用就是按照指定的顺序构建网络结构,得到一个完整的模块,而ModuleList
则只是像list那样把元素集合起来而已。
nn.modules vs. nn.children
class myNet(nn.Module):
def __init__(self):
super().__init__()
self.convBN = nn.Sequential(nn.Conv2d(10,10,3), nn.BatchNorm2d(10))
self.linear = nn.Linear(10,2)
def forward(self, x):
pass
Net = myNet()
print("Printing children\n------------------------------")
print(list(Net.children()))
print("\n\nPrinting Modules\n------------------------------")
print(list(Net.modules()))
输出信息如下:
Printing children
------------------------------
[Sequential(
(0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
(1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), Linear(in_features=10, out_features=2, bias=True)]
Printing Modules
------------------------------
[myNet(
(convBN1): Sequential(
(0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
(1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(linear): Linear(in_features=10, out_features=2, bias=True)
), Sequential(
(0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
(1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1)), BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), Linear(in_features=10, out_features=2, bias=True)]
可以看到children
只会返回子元素,子元素可能是单个操作,如Linear,也可能是Sequential。 而modules()
返回的信息更加详细,不仅会返回children
一样的信息,同时还会递归地返回,例如modules()
会迭代地返回Sequential
中包含的若干个子元素。
named_*
- named_parameters: 返回一个
iterator
,每次它会提供包含参数名的元组。
In [27]: x = torch.nn.Linear(2,3)
In [28]: x_name_params = x.named_parameters()
In [29]: next(x_name_params)
Out[29]:
('weight', Parameter containing:
tensor([[-0.5262, 0.3480],
[-0.6416, -0.1956],
[ 0.5042, 0.6732]], requires_grad=True))
In [30]: next(x_name_params)
Out[30]:
('bias', Parameter containing:
tensor([ 0.0595, -0.0386, 0.0975], requires_grad=True))
- named_modules
这个其实就是把上面提到的nn.modules
以iterator
的形式返回,每次读取和上面一样也是用next()
,示例如下:
In [46]: class myNet(nn.Module):
...: def __init__(self):
...: super().__init__()
...: self.convBN1 = nn.Sequential(nn.Conv2d(10,10,3), nn.BatchNorm2d(10))
...: self.linear = nn.Linear(10,2)
...:
...: def forward(self, x):
...: pass
...:
In [47]: net = myNet()
In [48]: net_named_modules = net.named_modules()
In [49]: next(net_named_modules)
Out[49]:
('', myNet(
(convBN1): Sequential(
(0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
(1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(linear): Linear(in_features=10, out_features=2, bias=True)
))
In [50]: next(net_named_modules)
Out[50]:
('convBN1', Sequential(
(0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
(1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
))
In [51]: next(net_named_modules)
Out[51]: ('convBN1.0', Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1)))
In [52]: next(net_named_modules)
Out[52]:
('convBN1.1',
BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
In [53]: next(net_named_modules)
Out[53]: ('linear', Linear(in_features=10, out_features=2, bias=True))
In [54]: next(net_named_modules)
---------------------------------------------------------------------------
StopIteration Traceback (most recent call last)
<ipython-input-54-05e848b071b8> in <module>
----> 1 next(net_named_modules)
StopIteration:
- named_children
同named_modules
参考
https://blog.paperspace.com/pytorch-101-advanced/
最新文章
- Echarts使用
- 一款css3很美的iphone注册表单样式
- jquery笔记(遍历)
- 读JS高级(兼容&;&;BOM&;&;私有变量&;&;面向对象)
- read,for,case,while,if简单例子
- cocos2d-x 2.1.2 bug发现
- Redis 列表(List)
- zoj1013 Great Equipment
- hdu 4790 Just Random 神奇的容斥原理
- 菜鸟级springmvc+spring+mybatis整合开发用户登录功能(下)
- 开始学习机器学习,从Ng的视频开始
- php去除数组中重复值,并返回结果!
- RabbitMQ框架构建系列(一)——AMPQ协议
- IntelliJ IDEA配置Tomcat和Lombok
- Linux 网络编程(一)--Linux操作系统概述
- 异常: Call From * 9000 failed on connection exception: java.net.ConnectException: Connection refused: no further information; For more details see: http://wiki.apache.org/hadoop/ConnectionRefused
- Lodop打印设计(PRINT_DESIGN)里的快捷键
- 哪些intel 网卡支持SR-IOV
- 魔方Newlife.Cube权限系统的使用及模版覆盖详解
- 解决BeautifulSoup库运行时报错问题
热门文章
- [比赛题解]CWOI2019-1
- mac解决安装提示“xxx软件已损坏,打不开,您应该将它移到废纸篓”的提示
- SQLserver 存储过程游标使用
- Let&#39;s Encrypt之acme.sh
- 机器学习之TensorFlow介绍
- 在 ubuntu 下安装 mono 和 xsp4 ,并测试
- Prometheus 与 Alertmanager 通信
- Fuck SELinux :rsyslog无法生成log文件,原来是selinux机制搞的鬼!
- 【spring boot】spring boot 基于redis pipeline 管道,批量操作redis命令
- thinkphp3.2 无法加载模块