Pytorch手写线性回归
2024-08-28 08:12:46
pytorch手写线性回归
import torch
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation LEARN_RATE = 0.1
#1.准备数据
x = torch.randn([500,1])
y_true = x*0.8+3 #2.计算预测值 t_tred = x*w + b w = torch.rand([],requires_grad=True)
b = torch.tensor(0.,requires_grad=True) plt.figure()
plt.grid(True) #开启交互模式
plt.ion()
for i in range(50): plt.cla() for j in [w,b]:
if j.grad is not None:
j.grad.zero_()
y_predict = x*w+b #3.计算损失,把参数的梯度置为0,进行反向传播 loss = (y_predict-y_true).pow(2).mean() loss.backward() #4.更新参数,grad表示导数 w.data = w.data - LEARN_RATE*w.grad
b.data = b.data - LEARN_RATE*b.grad plt.scatter(x.numpy(),y_true.numpy())
plt.plot(x.numpy(),y_predict.detach().numpy(),color="g") plt.pause(0.1) if i %50 ==0:
print( "第{}次,损失{},权重w={},偏执b={}".format(i,loss.data,w.data,b.data)) #关闭交互模式
plt.ioff()
plt.show()
最新文章
- Word基础
- 关于 python
- WinForm开发框架--动态读取DLL模式
- iOS push与present Controller的区别
- DP:Cheapest Palindrome(POJ 3280)
- 【重读】The C++ Programming Language/C++编程语言(一)
- Chp17: Moderate
- NPOI技术,
- POJ 1651	Multiplication PuzzleDP方法:
- 推荐大家一本学习php模式的书
- Java的函数与函数重载
- 【ThinkingInC++】52、函数内部的静态变量
- iOS 7用户界面过渡指南
- 官网.jar包下载技巧
- Java Random介绍
- lbp特征提取(等价模式)
- PE 001~010
- C++Sizeof与Strlen的区别与联系
- deepin linux学习笔记
- 【blog】谷歌浏览器如何设置编码