开始导入 MinMaxScaler 时会报错 “from . import _arpack ImportError: DLL load failed: 找不到指定的程序。” (把sklearn更新下)和“AttributeError: module 'numpy' has no attribute 'testing'”,然后把numpy卸载重装(pip uninstall numpy; pip install numpy),问题解决。

#import datetime
import pandas as pd
import numpy as np
#from numpy import row_stack,column_stack
import tushare as ts
#import matplotlib
import matplotlib.pyplot as plt
#from matplotlib.pylab import date2num
#from matplotlib.dates import DateFormatter, WeekdayLocator, DayLocator, MONDAY,YEARLY
#from matplotlib.finance import quotes_historical_yahoo_ohlc, candlestick_ohlc
from sklearn.preprocessing import MinMaxScaler
#https://scikit-learn.org/stable/auto_examples/preprocessing/plot_all_scaling.html#sphx-glr-auto-examples-preprocessing-plot-all-scaling-py
from keras.models import Sequential
from keras.layers import LSTM, Dense, Activation df=ts.get_hist_data('601857',start='2016-06-15',end='2018-01-12')
dd=df[['open','high','low','close']] #print(dd.values.shape[0]) dd1=dd .sort_index() dd2=dd1.values.flatten() dd3=pd.DataFrame(dd1['close']) def load_data(df, sequence_length=10, split=0.8): #df = pd.read_csv(file_name, sep=',', usecols=[1])
#data_all = np.array(df).astype(float) data_all = np.array(df).astype(float)
scaler = MinMaxScaler()
data_all = scaler.fit_transform(data_all)
data = []
for i in range(len(data_all) - sequence_length - 1):
data.append(data_all[i: i + sequence_length + 1])
reshaped_data = np.array(data).astype('float64')
#np.random.shuffle(reshaped_data)
# 对x进行统一归一化,而y则不归一化
x = reshaped_data[:, :-1]
y = reshaped_data[:, -1]
split_boundary = int(reshaped_data.shape[0] * split)
train_x = x[: split_boundary]
test_x = x[split_boundary:] train_y = y[: split_boundary]
test_y = y[split_boundary:] return train_x, train_y, test_x, test_y, scaler def build_model():
# input_dim是输入的train_x的最后一个维度,train_x的维度为(n_samples, time_steps, input_dim)
model = Sequential()
model.add(LSTM(input_dim=1, output_dim=6, return_sequences=True))
#model.add(LSTM(6, input_dim=1, return_sequences=True))
#model.add(LSTM(6, input_shape=(None, 1),return_sequences=True)) """
#model.add(LSTM(input_dim=1, output_dim=6,input_length=10, return_sequences=True))
#model.add(LSTM(6, input_dim=1, input_length=10, return_sequences=True))
model.add(LSTM(6, input_shape=(10, 1),return_sequences=True))
"""
print(model.layers)
#model.add(LSTM(100, return_sequences=True))
#model.add(LSTM(100, return_sequences=True))
model.add(LSTM(100, return_sequences=False))
model.add(Dense(output_dim=1))
model.add(Activation('linear')) model.compile(loss='mse', optimizer='rmsprop')
return model def train_model(train_x, train_y, test_x, test_y):
model = build_model() try:
model.fit(train_x, train_y, batch_size=512, nb_epoch=300, validation_split=0.1)
predict = model.predict(test_x)
predict = np.reshape(predict, (predict.size, ))
except KeyboardInterrupt:
print(predict)
print(test_y)
print(predict)
print(test_y)
try:
fig = plt.figure(1)
plt.plot(predict, 'r:')
plt.plot(test_y, 'g-')
plt.legend(['predict', 'true'])
except Exception as e:
print(e)
return predict, test_y if __name__ == '__main__':
#train_x, train_y, test_x, test_y, scaler = load_data('international-airline-passengers.csv')
train_x, train_y, test_x, test_y, scaler =load_data(dd3, sequence_length=10, split=0.8)
train_x = np.reshape(train_x, (train_x.shape[0], train_x.shape[1], 1))
test_x = np.reshape(test_x, (test_x.shape[0], test_x.shape[1], 1))
predict_y, test_y = train_model(train_x, train_y, test_x, test_y)
predict_y = scaler.inverse_transform([[i] for i in predict_y])
test_y = scaler.inverse_transform(test_y)
fig2 = plt.figure(2)
plt.plot(predict_y, 'g:')
plt.plot(test_y, 'r-')
plt.show()

  

参考资料:

基于keras 的lstm 股票收盘价预测

RNN,LSTM,GRU基本原理的个人理解

  

最新文章

  1. Leetcode: plus one
  2. JS 拼接字符串数组
  3. php高级研发或架构师必了解---很多问题面试中常问到!
  4. 通过FTP命令上传下载
  5. Android 生成颜色器
  6. excel导入导出
  7. Java之重载与覆盖
  8. 关于自定义tabBar时修改系统自带tabBarItem属性造成的按钮顺序错乱的问题相关探究
  9. asp.net运行原理(一)总体概要
  10. kickstartInstalls
  11. word和.txt文件转html 及pdf文件, 使用poi jsoup itext心得
  12. Linux CentOS7 安装 Qt 5.9.2
  13. CSS深入理解学习笔记之vertical-align
  14. (转)SQL中的循环、for循环、游标
  15. 1—ARM中的寄存器
  16. Ubuntu 14.04循环登录问题(密码正确,无法登录)
  17. phpstudy如何安装ssl证书
  18. [cookie篇]cookie-parser之parser.js
  19. Codeforces 558C Amr and Chemistry 暴力 - -
  20. 取出当前会话的sid、process_id.sql

热门文章

  1. Mac下编译libpomelo静态库,并在cocos2dx项目中引用
  2. CodeForces - 1251D (贪心+二分)
  3. sql客户端工具Navicat_Premiun12中文破解版
  4. XOR加密作业
  5. 201871010116-祁英红《面向对象程序设计(java)》第十五周学习总结
  6. LINUX上安装JDK+tomcat+mysql操作笔记
  7. C语言快速入门一:win10系统环境搭建
  8. 使用OC实现单链表:创建、删除、插入、查询、遍历、反转、合并、判断相交、求成环入口
  9. 如何将Azure SQL 数据库还原到本地数据库实例中
  10. Android Monkey的用法(一)