pytorch循环神经网络实现回归预测 代码
2024-10-21 14:30:27
pytorch循环神经网络实现回归预测
学习视频:莫烦python
# RNN for classification
import torch
import numpy as np
import torch.nn as nn
import torch.utils.data as Data
import matplotlib.pyplot as plt
import torchvision #hyper parameters
TIME_STEP=10 #run time step
INPUT_SIZE=1
LR=0.02 #learning rate # t=np.linspace(0,np.pi*2,100,dtype=float) #from zero to pi*2, and one hundred point there
# x=np.sin(t)
# y=np.cos(t)
# plt.plot(t,x,'r-',label='input (sin)')
# plt.plot(t,y,'b-',label='target (cos)')
# plt.legend(loc='best')
# plt.show() class RNN_Net(nn.Module):
def __init__(self):
super(RNN_Net,self).__init__()
self.rnn=nn.RNN(
input_size=INPUT_SIZE,
hidden_size=32,
num_layers=1,
batch_first=True,
)
self.out=nn.Linear(32,1) def forward(self,x,h_state):
r_out,h_state=self.rnn(x,h_state)
outs=[]
for time_step in range(r_out.size(1)):
outs.append(self.out(r_out[:,time_step,:]))
return torch.stack(outs,dim=1),h_state # the type of return data is torch, and the return data also include h_state rnn=RNN_Net()
# print(rnn) optimizer=torch.optim.Adam(rnn.parameters(),lr=LR)
loss_func=nn.MSELoss() plt.ion()
h_state=None
for step in range (60):
start,end=step*np.pi,(step+1)*np.pi
#using sin predicts cos
steps=np.linspace(start,end,TIME_STEP,dtype=np.float32)
x_np=np.sin(steps)
y_np=np.cos(steps) x=torch.from_numpy(x_np[np.newaxis,:,np.newaxis]) # np.newaxis means increase a dim
y=torch.from_numpy(y_np[np.newaxis,:,np.newaxis])
predition,h_state=rnn(x,h_state) #the first h_state is None
h_state=h_state.data #?????
loss=loss_func(predition,y)
optimizer.zero_grad()
loss.backward()
optimizer.step() plt.plot(steps,y_np,'r-')
plt.plot(steps,predition.detach().numpy().flatten(),'b-') #flatten() 展平维度
plt.draw()
plt.pause(0.05)
plt.ioff()
plt.show()
最新文章
- CodeForces 676D代码 哪里有问题呢?
- iOS开发之邓白氏编码申请流程
- 单点登陆CAS安装过程中可能遇到的问题
- table的border-collapse属性与border-spacing属性
- .net 开发必备小抄(电子书)
- Gradle 载入中 Android 下一个.so档
- 【Dijkstra堆优化】洛谷P2243电路维修
- python复杂网络库networkx:基础
- Java进阶篇之十五 ----- JDK1.8的Lambda、Stream和日期的使用详解(很详细)
- 【反编译系列】二、反编译代码(jeb)
- Android:图解四种启动模式 及 实际应用场景解说
- oracle数据库报错ora-01653表空间扩展失败解决方案
- 关于Quad PLL /CPLL参考时钟的选择
- 解决Android adjustresize全屏无效问题
- windows下使用nginx配置tomcat集群
- HDU 4055 Number String dp
- 【转】使用python编写网络通信程序
- springboot5
- CCSpriteBatchNode CCSpriteFrameCache
- [技术分享]借用UAC完成的提权思路分享
热门文章
- 浏览器tab标签切换触发监听事件visibilitychange
- C++ MFC学习 (六)
- FMC DA子卡设计原理图:FMCJ465-2路 16bit 12.6GSPS FMC DA子卡
- python win32 microsoft excel 类range的copyPictrue方法无效
- Oracle备份脚本(数据泵)-Windows平台
- node版本和用的包不兼容问题,头疼
- kubeSphere+kubernetes 集群更新证书
- Oracle查询表中的各列的列名,数据类型,以及类型长度
- Sql Sugar 拾遗
- 遍历dom节点