前回に引き続き,予測やっていきます!!
今回は,前回作成した予測モデルをもとに実際に色々データを入れてみます
Quandlを用いたデータ取得
こちらAPIを叩くだけで,簡単にデータが取得できるものになります(ただし,データの取得幅はあまり大きくない)
$ pip install quandl
または
$ easy_install quandl
でインストールできます
しかし,こちらのAPIには使用制限があり一日あたり50回までの使用制限が設けられています
50回以上アクセスしたい方はgoogleアカウント等でサインインし,右上の”ME”からマイページに入ってください
その後, ”API KEY"を開き, "Request new api key"を押すことでKeyを入手できます
そのキーを以下のソースコード内指定の場所に埋め込むことで何度も利用できるようになります
作成した予測モデルで株価終値予測
import tensorflow as tf import tflearn import pandas as pd import numpy as np import matplotlib.pyplot as plt import datetime import quandl #CSVデータのインポート def import_data(train_csv): dataframe = pd.read_csv(train_csv,engine='python').iloc[::-1] return dataframe #データ生成 def create_testdata(dataset): X = [] for i in range(1, len(dataset), 1): X.append(dataset[i - 1: i]) X = np.reshape(np.array(X), [-1, 1, 1]) return X #学習モデルの読み込み def load_train_model(model_name): net = tflearn.input_data(shape=[None, 1, 1]) net = tflearn.gru(net, n_units=150, activation='relu',return_seq=True) net = tflearn.gru(net, n_units=150, activation='relu') net = tflearn.fully_connected(net, 1, activation='linear') net = tflearn.regression(net, optimizer='adam', learning_rate=0.001, loss='mape') model = tflearn.DNN(net) model.load(model_name) return model #プロットの表示 def drow_predict_plot(testdata,predictdata,date,name='result_of_predict.png'): test_predict_plot = np.vstack((np.zeros([1,1]),predictdata)) fig=plt.figure(figsize=(10, 10)) fig.autofmt_xdate() fig.add_subplot(111) plt.plot(date,testdata, label = "実際のデータ", color = "black") plt.plot(date,test_predict_plot,label = "実際のデータ", color = "green") #plt.xlim(0, len(testdata)) plt.ylim(0, np.max(testdata)) plt.savefig(name) if __name__=='__main__': #QuandlのAPI Keyを入れる #これを入れないと一日に50回までしか利用できない quandl.ApiConfig.api_key = 'your api key' #学習モデルの呼び出し trainModel = load_train_model('model.tfl') #予測したいデータの取得 name_data = 'GOOG/TYO_9684' oridata = quandl.get(name_data) testdata = oridata['Close'].astype('float32').values date = pd.to_datetime(oridata.index, format='%Y-%m-%d') testX = create_testdata(testdata) #予測とプロット test_predict = trainModel.predict(testX) drow_predict_plot(testdata, test_predict,date, name_data.replace('/','-') + '.png')
好きな銘柄を入れて予測してみる
name_data = 'XXX'
'XXX'に好きなデータを入れてみましょう
先ほどのQuandlのサイト左上タブから”DATA”を選択し, "Core Financial Data" を開きます
適当な会社名を入れて,出てきたやつを適当にクリックします(適当
右上の"Quandl Code"をコピペしてやればおしまい
色々やった結果がこれだよ
とりあえず,思いついた大きい会社 日立
ベンチャーもやってみる Cyber Agent
目下炎上中 DeNA
前回も言ったけど…ずれる!!!
売買ルールについて学習させたほうが利口かもしれない
最近,堅実に勉強はじめました。。。。