莫烦pytorch学习笔记(七)——Optimizer优化器
2024-10-18 09:58:52
"""Compare SGD, Momentum, RMSprop and Adam on a toy 1-D regression problem.

Morvan PyTorch tutorial notes (7): four identical two-layer networks are
trained side by side on the same mini-batches of noisy samples of
``y = x**2``, one network per optimizer, and their per-step training losses
are plotted on a shared axis.
"""
import torch
import torch.nn.functional as F
import torch.utils.data as Data

# Hyper-parameters
LR = 0.01
BATCH_SIZE = 32
EPOCH = 12
# DataLoader worker processes. Because workers may re-import this module
# (the "spawn" start method), all work is kept behind the __main__ guard below.
NUM_WORKERS = 2

# One label per optimizer, in the order returned by make_optimizers().
LABELS = ["SGD", "Momentum", "RMSprop", "Adam"]


def make_dataset(n_samples=1000, noise=0.1):
    """Return noisy samples of ``y = x**2``.

    Args:
        n_samples: number of evenly spaced points on ``[-1, 1]``.
        noise: standard deviation of the additive Gaussian noise.

    Returns:
        ``(x, y)``: two float tensors of shape ``(n_samples, 1)``.
    """
    # torch.linspace yields shape (n,); nn.Linear treats the last dimension as
    # the feature axis, so add one to get n rows with a single feature each.
    x = torch.unsqueeze(torch.linspace(-1, 1, n_samples), dim=1)
    y = x.pow(2) + noise * torch.normal(torch.zeros(*x.size()))
    return x, y


def make_loader(x, y, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS):
    """Wrap ``x``/``y`` in a shuffling mini-batch ``DataLoader``."""
    # TensorDataset takes its tensors positionally; the data_tensor= /
    # target_tensor= keywords of early PyTorch releases are no longer accepted.
    dataset = Data.TensorDataset(x, y)
    return Data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )


class Net(torch.nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Linear."""

    def __init__(self, n_feature=1, n_hidden=20, n_output=1):
        super().__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)  # hidden layer
        self.predict = torch.nn.Linear(n_hidden, n_output)  # output layer

    def forward(self, x):
        """Forward pass; ``x`` has ``n_feature`` values in its last dimension."""
        x = F.relu(self.hidden(x))
        return self.predict(x)


def make_optimizers(nets, lr=LR):
    """Build one optimizer per network, in ``LABELS`` order.

    ``nets`` must contain exactly four networks (SGD, Momentum, RMSprop, Adam).
    """
    net_sgd, net_momentum, net_rmsprop, net_adam = nets
    return [
        torch.optim.SGD(net_sgd.parameters(), lr=lr),
        torch.optim.SGD(net_momentum.parameters(), lr=lr, momentum=0.8),
        torch.optim.RMSprop(net_rmsprop.parameters(), lr=lr, alpha=0.9),
        torch.optim.Adam(net_adam.parameters(), lr=lr, betas=(0.9, 0.99)),
    ]


def train(nets, optimizers, loader, epochs=EPOCH, loss_func=None):
    """Train each net with its optimizer on identical mini-batches.

    Args:
        nets: networks to train, paired element-wise with ``optimizers``.
        optimizers: one optimizer per network.
        loader: iterable of ``(batch_x, batch_y)`` tensor pairs.
        epochs: number of passes over ``loader``.
        loss_func: loss module; defaults to ``torch.nn.MSELoss()``.

    Returns:
        A list with one list per net holding the loss (Python float) of every
        training step, recorded before that step's parameter update.
    """
    if loss_func is None:
        loss_func = torch.nn.MSELoss()
    loss_his = [[] for _ in nets]
    for epoch in range(epochs):
        print(epoch)
        for batch_x, batch_y in loader:
            # Every optimizer sees exactly the same batch, so the curves are
            # directly comparable.
            for net, opt, l_his in zip(nets, optimizers, loss_his):
                output = net(batch_x)
                loss = loss_func(output, batch_y)
                opt.zero_grad()  # clear gradients left over from the last step
                loss.backward()  # backpropagation, compute gradients
                opt.step()  # apply gradients
                l_his.append(loss.item())
    return loss_his


def plot_loss_history(loss_his, labels=LABELS):
    """Plot one loss curve per optimizer and show the figure."""
    # Imported lazily so the training code can be imported without a display.
    import matplotlib.pyplot as plt

    for label, l_his in zip(labels, loss_his):
        plt.plot(l_his, label=label)
    plt.legend(loc="best")
    plt.xlabel("Steps")
    plt.ylabel("Loss")
    plt.ylim((0, 0.2))
    plt.show()


def main():
    """Generate the data, train all four networks and plot their losses."""
    x, y = make_dataset()
    loader = make_loader(x, y)
    nets = [Net() for _ in LABELS]
    optimizers = make_optimizers(nets)
    loss_his = train(nets, optimizers, loader)
    plot_loss_history(loss_his)


if __name__ == "__main__":
    main()
最新文章
- python2-gst0.10制作静态包的补丁
- CSS3初学篇章_7(布局/浏览器默认样式重置)
- 利用ffmpeg给小视频结尾增加logo水印
- 关于APP接口设计
- java中的toString方法
- Delta-wave
- svn的初级使用
- EasyUI - 一般处理程序 返回 Json值
- Matlab强迫症产生的图像
- 【网络】 应用&传输层笔记
- web.py模块使用
- zookeeper选举流程
- C#多线程--信号量(Semaphore)[z]
- 【bzoj2023/1630】[Usaco2005 Nov]Ant Counting 数蚂蚁 dp
- Codeforces Round #440 (Div. 2) A,B,C
- wget 无法建立ssl连接 [ERROR: certificate common name '*.ssl.fastly.net' doesn't match requested host name 'cache.ruby-lang.org'. To connect to cache.ruby-lang.org insecurely, use '--no-check-certificate']
- canvas 实现鼠标画出矩形
- Iterator 迭代器模式 MD
- Python 列表 index() 方法
- jquery操作select(选中,取值)