"""
torch.float64对应torch.DoubleTensor
torch.float32对应torch.FloatTensor
将真实函数的数据点能够拟合成一个多项式
eg:y = 0.9 +0.5×x + 3×x*x + 2.4 ×x*x*x
"""
import torch from torch import nn def make_features(x):
x = x.unsqueeze(1)#在原来的基础上扩充了一维
return torch.cat([x ** i for i in range(1,4)], 1) def get_batch(batch_size=32): random = torch.randn(batch_size)
# print('random')
# print(random) #32个数 x = make_features(random)#进行维度扩充,扩充后32*1,又进行1,2,3次幂运算,拼接后32*3 '''Compute the actual results'''
y = f(x) # 32*3 *3*1
if torch.cuda.is_available():
return torch.autograd.Variable(x).cuda(), torch.autograd.Variable(y).cuda()
else:
return torch.autograd.Variable(x), torch.autograd.Variable(y) w_target = torch.FloatTensor([0.5,3,2.4]).unsqueeze(1)#三行一列
b_target = torch.FloatTensor([0.9]) def f(x):
return x.mm(w_target)+b_target[0] class poly_model(nn.Module):
def __init__(self):
super(poly_model, self).__init__()
self.poly = nn.Linear(3, 1)# 输入是3维,输出是1维 def forward(self, x):
out = self.poly(x)
return out if torch.cuda.is_available():
model = poly_model().cuda()
else:
model = poly_model() criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) epoch = 0
for epoch in range(20):
batch_x,batch_y = get_batch()#batch_x 和get_batch里面的x是一样的
output = model(batch_x)
loss = criterion(output,batch_y)
print_loss = loss
print(loss.item()) # 0.4版本之后使用loss.item()从标量中获得Python number
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('finished')

最新文章

  1. Flume NG Getting Started(Flume NG 新手入门指南)
  2. IOS中Json解析的四种方法
  3. java编辑器eclipse如何更改jdk版本
  4. code manager tools svn服务安装配置
  5. flume1.5.2安装与简介
  6. UVA 11462 Age Sort(计数排序法 优化输入输出)
  7. 老oj2146 && Pku2135 Farm Tour
  8. 'datetime.datetime' has no attribute 'datetime'问题
  9. IOS9任务管理器特效的实现
  10. 初探Lambda表达式/Java多核编程【2】并行与组合行为
  11. 【原创】07. ajax请求,解决sendRedirect 无效
  12. 查看.ssh文件在哪
  13. 记一次高并发场景下.net监控程序数据上报的性能调优
  14. 读Ghost博客源码与自定义Ghost博客主题
  15. 2016-wing的年度总结
  16. 最新版的Chrome不能设置网页编码怎么解?
  17. Learning-Python【34】:进程之生产者消费者模型
  18. Hbase-2.0.0_02_常用操作
  19. java同一个实体的复制
  20. 30分钟学会JS AST,打造自己的编译器

热门文章

  1. WPF ListView ,XML
  2. python语法01
  3. 如何down掉IB交换机口
  4. maven 学习---使用Maven模板创建项目
  5. maven 学习---Maven Web应用
  6. Kali无法使用Chrome原因及解决方法
  7. LeetCode——Rank Scores
  8. 对于不返回任何键列信息的 SelectCommand,不支持 DeleteCommand 的动态 SQL 生成
  9. anaconda配置清华大学开源软件镜像
  10. svn导一份历史版本出来