pytorch循环神经网络实现回归预测

学习视频:莫烦python

# RNN for classification
import torch
import numpy as np
import torch.nn as nn
import torch.utils.data as Data
import matplotlib.pyplot as plt
import torchvision #hyper parameters
TIME_STEP=10 #run time step
INPUT_SIZE=1
LR=0.02 #learning rate # t=np.linspace(0,np.pi*2,100,dtype=float) #from zero to pi*2, and one hundred point there
# x=np.sin(t)
# y=np.cos(t)
# plt.plot(t,x,'r-',label='input (sin)')
# plt.plot(t,y,'b-',label='target (cos)')
# plt.legend(loc='best')
# plt.show() class RNN_Net(nn.Module):
def __init__(self):
super(RNN_Net,self).__init__()
self.rnn=nn.RNN(
input_size=INPUT_SIZE,
hidden_size=32,
num_layers=1,
batch_first=True,
)
self.out=nn.Linear(32,1) def forward(self,x,h_state):
r_out,h_state=self.rnn(x,h_state)
outs=[]
for time_step in range(r_out.size(1)):
outs.append(self.out(r_out[:,time_step,:]))
return torch.stack(outs,dim=1),h_state # the type of return data is torch, and the return data also include h_state rnn=RNN_Net()
# print(rnn) optimizer=torch.optim.Adam(rnn.parameters(),lr=LR)
loss_func=nn.MSELoss() plt.ion()
h_state=None
for step in range (60):
start,end=step*np.pi,(step+1)*np.pi
#using sin predicts cos
steps=np.linspace(start,end,TIME_STEP,dtype=np.float32)
x_np=np.sin(steps)
y_np=np.cos(steps) x=torch.from_numpy(x_np[np.newaxis,:,np.newaxis]) # np.newaxis means increase a dim
y=torch.from_numpy(y_np[np.newaxis,:,np.newaxis])
predition,h_state=rnn(x,h_state) #the first h_state is None
h_state=h_state.data #?????
loss=loss_func(predition,y)
optimizer.zero_grad()
loss.backward()
optimizer.step() plt.plot(steps,y_np,'r-')
plt.plot(steps,predition.detach().numpy().flatten(),'b-') #flatten() 展平维度
plt.draw()
plt.pause(0.05)
plt.ioff()
plt.show()

最新文章

  1. CodeForces 676D代码 哪里有问题呢?
  2. iOS开发之邓白氏编码申请流程
  3. 单点登陆CAS安装过程中可能遇到的问题
  4. table的border-collapse属性与border-spacing属性
  5. .net 开发必备小抄(电子书)
  6. Gradle 载入中 Android 下一个.so档
  7. 【Dijkstra堆优化】洛谷P2243电路维修
  8. python复杂网络库networkx:基础
  9. Java进阶篇之十五 ----- JDK1.8的Lambda、Stream和日期的使用详解(很详细)
  10. 【反编译系列】二、反编译代码(jeb)
  11. Android:图解四种启动模式 及 实际应用场景解说
  12. oracle数据库报错ora-01653表空间扩展失败解决方案
  13. 关于Quad PLL /CPLL参考时钟的选择
  14. 解决Android adjustresize全屏无效问题
  15. windows下使用nginx配置tomcat集群
  16. HDU 4055 Number String dp
  17. 【转】使用python编写网络通信程序
  18. springboot5
  19. CCSpriteBatchNode CCSpriteFrameCache
  20. [技术分享]借用UAC完成的提权思路分享

热门文章

  1. 浏览器tab标签切换触发监听事件visibilitychange
  2. C++ MFC学习 (六)
  3. FMC DA子卡设计原理图:FMCJ465-2路 16bit 12.6GSPS FMC DA子卡
  4. python win32 microsoft excel 类range的copyPictrue方法无效
  5. Oracle备份脚本(数据泵)-Windows平台
  6. node版本和用的包不兼容问题,头疼
  7. kubeSphere+kubernetes 集群更新证书
  8. Oracle查询表中的各列的列名,数据类型,以及类型长度
  9. Sql Sugar 拾遗
  10. 遍历dom节点