今天进行小批量梯度下降时,代码给我报错,具体代码如下

import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader class DiabetesDataset(Dataset):
def __init__(self, filepath):
xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
self.len = xy.shape[0]
self.x_data = torch.from_numpy(xy[:, :-1])
self.y_data = torch.from_numpy(xy[:, [-1]]) def __getitem__(self, index):
return self.x_data[index], self.y_data[index] def __len__(self):
return self.len dataset = DiabetesDataset('diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2) class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear1 = torch.nn.Linear(8, 6)
self.linear2 = torch.nn.Linear(6, 4)
self.linear3 = torch.nn.Linear(4, 2)
self.linear4 = torch.nn.Linear(2, 1)
self.sigmoid = torch.nn.Sigmoid() def forward(self, x):
x = self.sigmoid(self.linear1(x))
x = self.sigmoid(self.linear2(x))
x = self.sigmoid(self.linear3(x))
x = self.sigmoid(self.linear4(x))
return x model = Model()
criterion = torch.nn.BCELoss(size_average=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) for epoch in range(100):
for i, data in enumerate(train_loader, 0):
inputs, labels = data
y_pred = model(inputs)
loss = criterion(y_pred, labels)
print(epoch, i, loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()

报错内容如下



室友告诉我,需要在主运行的代码,也就是for前面加上

if __name__ == '__main__':

通过查阅大致知道了我这句代码的意思,原因就是我上面有一句

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

这句话的意思就是,当模块被直接运行时,以下代码块将被运行,当模块是被导入时,代码块不被运行。

这样就可以很好的决定模块中那些代码运行,那些代码不运行

还有一个警告就是

UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.
warnings.warn(warning.format(ret))

这里是版本更新导致的问题

criterion = torch.nn.BCELoss(size_average=True)

改为:

criterion = torch.nn.BCELoss(reduction='mean')

即可

最新文章

  1. SQL Server 无法连接到服务器。SQL Server 复制需要有实际的服务器名称才能连接到服务器。请指定实际的服务器名称。
  2. CXF集成Spring实现webservice的发布与请求
  3. NYOJ题目111分数加减法
  4. Linux分区练习(1)
  5. HP 7440老机器重启
  6. java从命令行接收多个数字,求和之后输出结果
  7. 9段高效率开发PHP程序的代码
  8. Linux Kernel Schduler History And Centos7.2's Kernel Resource Analysis
  9. Babelfish(二分)
  10. Timestamp解析0000-00-00 00:00:00报格式错误
  11. python3 爬 妹子图
  12. 径向基网络(RBF network)
  13. centos命令自动补全增强
  14. C++学习,两个小的语法错误-network-programming
  15. UIAlertControl的使用对比与UIAlertView和UIActionSheet
  16. 更改MySQL密码
  17. java14周
  18. Perl正则表达式超详细教程
  19. MQTT 嵌入式端通讯协议解析(转)
  20. 34 char类型转换为int类型

热门文章

  1. 基于Python+Sqlite3实现最简单的CRUD
  2. Deep Learning-深度学习(二)
  3. Deep Learning-深度学习(一)
  4. 一文解决Vue中实现 Excel下载到本地以及上传Excel
  5. 【新人福利】使用CSDN 官方插件,赠永久免站内广告特权 >>电脑端访问:https://t.csdnimg.cn/PVqS
  6. 题解【洛谷 P1466 [USACO2.2]集合 Subset Sums】
  7. 在centos7.6上部署前后端分离项目Nginx反向代理vue.js2.6+Tornado5.1.1,使用supervisor统一管理服务
  8. Win10系统下基于Docker构建Appium容器连接Android模拟器Genymotion完成移动端Python自动化测试
  9. 6.1 NOI 模拟
  10. .NET 6学习笔记(4)——解决VS2022中Nullable警告