Pytorch线性规划模型 学习笔记(一)

Pytorch视频学习资料参考:《PyTorch深度学习实践》完结合集

Pytorch搭建神经网络的四大部分

  • 1. 准备数据 Prepare dataset

    准备数据包括数据的读取加载并转换为torch框架下识别的tensor格式,注意数据的dtype为float32格式

  • 2. 设计模型 Design model using class

    网络的基本框架部分,包括自定义的网络layer结构,注意维度的变换要一致,另外,该类中还应包括forward部分

  • 3. 构建损失和优化器 Construct loss and optimizer

    根据处理的问题和模型设置合适的损失,或自己构建损失函数。优化器为梯度下降的解决方案,可选择合适的优化器进行梯度下降

  • 4. 重复训练 Training cycle

    重复训练部分可以后续设置batchsize的大小,按batch进行随机梯度下降(此代码中暂无设置),注意优化器的清零迭代操作

数据部分

X.csv,y.csv链接: https://pan.baidu.com/s/1dJD8zBewCS86fRgv0nL7kQ 密码: 0us0

下载后与程序放置在同一文件夹下

代码部分

# import

import torch
import numpy as np ## 1. prepare dataset
x_data = 'X.csv'
y_data = 'y.csv'
x = np.loadtxt(x_data, delimiter=' ', dtype=np.float32)
y = np.loadtxt(y_data, delimiter=' ', dtype=np.float32).reshape(-1, 1)
x = torch.from_numpy(x[:, :])
y = torch.from_numpy(y[:, :])
print(y.shape)
print(x.shape) ## 2. design model using class
class LinearModel(torch.nn.Module):
def __init__(self):
super(LinearModel, self).__init__()
self.linear1 = torch.nn.Linear(10, 6)
self.linear2 = torch.nn.Linear(6, 6)
self.linear3 = torch.nn.Linear(6, 1)
self.sigmoid = torch.nn.Sigmoid() def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
x = self.sigmoid(x) return x
model = LinearModel() ## 3. construct loss and optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01) ## 4. training cycle
for epoch in range(500):
y_hat = model(x)
loss = criterion(y_hat, y)
print('epoch', epoch, loss.item()) optimizer.zero_grad()
loss.backward()
optimizer.step()

最新文章

  1. scheduleInRunLoop作用
  2. VS2013设置护眼背景颜色
  3. 作弊Q-百威
  4. C语言程序设计第二次作业
  5. NOIP2013pj小朋友的数字[DP 最大子段和]
  6. AChartEngine 安卓折线图 柱形图等利器
  7. func 和 actin 委托的区别
  8. UML类图五种关系与代码的对应关系
  9. [学习嵌入式开发板]iTOP-4412实现NFS网络文件系统
  10. SgmlReader使用方法
  11. 转 java 类 单例
  12. ZooKeeper 安装部署及hello world
  13. ActionBar +Tab+ViewPager +Fragment 支持侧滑动完成办税工具的页面展示
  14. Springmvc异步上传文件
  15. C#的WebBrowser控制浏览
  16. java接口多实现和多继承
  17. 把本人基于Dubbo的毕业设计分享粗来~
  18. 单源最短路径问题(dijkstra算法 及其 优化算法(优先队列实现))
  19. 【iCore1S 双核心板_ARM】例程九:DAC实验——输出直流电压
  20. webservice调用dll

热门文章

  1. PHP Proxy 负载均衡技术
  2. 【python】Leetcode每日一题-删除排序链表中的重复元素
  3. 一款好用的数据血缘关系在线工具--SQLFlow
  4. VS2019解决X64无法内联汇编的问题
  5. ffmpeg实践
  6. JAVA并发(1)-AQS(亿点细节)
  7. mysql基本命令(增,查,改,删)
  8. 解密华为云FusionInsight MRS新特性:一架构三湖
  9. centos7 连接打印机
  10. 结合JVM 浅谈Java 类加载器(Day_03)