Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Sign in
Toggle navigation
R
robo-dividend
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenwen.tang
robo-dividend
Commits
18c7d36d
Commit
18c7d36d
authored
Dec 01, 2023
by
wenwen.tang
😕
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
ai 模块对X和Y列增添数据
parent
f02ff4be
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
498 additions
and
330 deletions
+498
-330
EstimateMarketTrendV20.py
ai/EstimateMarketTrendV20.py
+73
-330
robo_datas.py
ai/dao/robo_datas.py
+20
-0
data_access.py
ai/data_access.py
+108
-0
model_trainer.py
ai/model_trainer.py
+73
-0
noticer.py
ai/noticer.py
+28
-0
training_data_builder.py
ai/training_data_builder.py
+196
-0
No files found.
ai/EstimateMarketTrendV20.py
View file @
18c7d36d
from
datetime
import
datetime
from
typing
import
List
from
typing
import
List
import
matplotlib.pyplot
as
plt
from
py_jftech
import
autowired
import
numpy
as
np
import
pandas
as
pd
# from finta import TA
from
finta
import
TA
from
lightgbm
import
LGBMClassifier
from
py_jftech
import
sendmail
,
format_date
,
autowired
from
sklearn
import
svm
from
sklearn.ensemble
import
RandomForestClassifier
,
VotingClassifier
from
sklearn.metrics
import
classification_report
,
confusion_matrix
,
ConfusionMatrixDisplay
,
accuracy_score
from
sklearn.model_selection
import
train_test_split
from
sklearn.preprocessing
import
MinMaxScaler
# for draw confusion matrix
from
ai.data_access
import
DataAccess
# import sys
from
ai.model_trainer
import
ModelTrainer
# import matplotlib
from
ai.training_data_builder
import
TrainingDataBuilder
# matplotlib.use('Agg')/,nZ'/
# from sklearn import metrics
# from tensorflow.keras.models import Sequential
# from tensorflow.keras.layers import Dense, Dropout, LSTM
from
ai.dao.robo_datas
import
get_eco_list
,
get_index_list
from
api
import
DataSync
from
api
import
DataSync
# List of symbols for technical indicators
# INDICATORS = ['RSI', 'MACD', 'STOCH','ADL', 'ATR', 'MOM', 'MFI', 'ROC', 'OBV', 'CCI', 'EMV', 'VORTEX']
# Note that '14 period MFI' and '14 period EMV' is not available for forecast
from
basic.sync
import
EcoSync
,
IndexSync
INDICATORS
=
[
'RSI'
,
'MACD'
,
'STOCH'
,
'ADL'
,
'ATR'
,
'MOM'
,
'ROC'
,
'OBV'
,
'CCI'
,
'VORTEX'
]
eco
=
[
65
,
66
,
74
]
index
=
[
67
,
68
,
69
,
70
,
71
,
72
,
73
,
75
]
# 预测发送邮箱
email
=
[
'wenwen.tang@thizgroup.com'
]
# 截止日期
# 截止日期
max_date
=
None
# max_date = None
# max_date = '2023-09-01'
max_date
=
'2023-11-24'
# 待预测指数
# PREDICT_LIST = [67]
PREDICT_LIST
=
[
67
,
121
,
122
,
123
]
eco
=
[
65
,
66
,
74
,
134
]
index
=
[
67
,
68
,
69
,
70
,
71
,
72
,
73
,
75
,
116
,
117
,
138
,
139
,
142
,
143
,
140
,
141
,
144
,
145
,
146
]
fund
=
[
121
,
122
,
123
]
@
autowired
@
autowired
def
sync
(
syncs
:
List
[
DataSync
]
=
None
):
def
sync
(
syncs
:
List
[
DataSync
]
=
None
):
for
s
in
syncs
:
for
s
in
syncs
:
if
isinstance
(
s
,
(
IndexSync
,
EcoSync
)):
#
if isinstance(s, (IndexSync, EcoSync)):
s
.
do_sync
()
s
.
do_sync
()
def
send
(
content
):
def
predictionFromMoel
(
the_model
,
scaledX_forecast
,
predict_item
):
receives
=
email
prediction
=
the_model
.
predict
(
scaledX_forecast
)
subject
=
'预测_{today}'
.
format
(
today
=
format_date
(
datetime
.
today
()))
predictionStr
=
'DOWN'
sendmail
(
receives
=
receives
,
copies
=
[],
attach_paths
=
[],
subject
=
subject
,
content
=
content
)
if
(
prediction
>
0.5
):
predictionStr
=
'UP'
content
=
f
"""
\n
On day {forecastDay.strftime("
%
m/
%
d/
%
Y")}, the model predicts {predict_item} to be {predictionStr} in {str(numForecastDays)} business days.
\n
"""
def
_get_indicator_data
(
data
):
print
(
content
)
"""
# upload_predict(predict_item, forecastDay, predictionStr)
Function that uses the finta API to calculate technical indicators used as the features
# send(content)
"""
return
prediction
for
indicator
in
INDICATORS
:
ind_data
=
eval
(
'TA.'
+
indicator
+
'(data)'
)
if
not
isinstance
(
ind_data
,
pd
.
DataFrame
):
ind_data
=
ind_data
.
to_frame
()
data
=
data
.
merge
(
ind_data
,
left_index
=
True
,
right_index
=
True
)
data
.
rename
(
columns
=
{
"14 period EMV."
:
'14 period EMV'
},
inplace
=
True
)
# Also calculate moving averages for features
data
[
'ema50'
]
=
data
[
'close'
]
/
data
[
'close'
]
.
ewm
(
50
)
.
mean
()
data
[
'ema21'
]
=
data
[
'close'
]
/
data
[
'close'
]
.
ewm
(
21
)
.
mean
()
data
[
'ema15'
]
=
data
[
'close'
]
/
data
[
'close'
]
.
ewm
(
15
)
.
mean
()
data
[
'ema5'
]
=
data
[
'close'
]
/
data
[
'close'
]
.
ewm
(
5
)
.
mean
()
# Instead of using the actual volume value (which changes over time), we normalize it with a moving volume average
data
[
'normVol'
]
=
data
[
'volume'
]
/
data
[
'volume'
]
.
ewm
(
5
)
.
mean
()
# get relative values
data
[
'relativeOpen'
]
=
data
[
'open'
]
/
data
[
'close'
]
.
shift
(
1
)
data
[
'relativeHigh'
]
=
data
[
'high'
]
/
data
[
'close'
]
.
shift
(
1
)
data
[
'relativeLow'
]
=
data
[
'low'
]
/
data
[
'close'
]
.
shift
(
1
)
data
[
'relativeClose'
]
=
data
[
'close'
]
/
data
[
'close'
]
.
shift
(
1
)
# Remove columns that won't be used as features
del
(
data
[
'open'
])
del
(
data
[
'high'
])
del
(
data
[
'low'
])
# data['close'] are still needed and will be deleted later
del
(
data
[
'volume'
])
return
data
########################################
########################################
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
# sync()
sync
()
toForecast
=
True
# False means test, True means forecast
toForecast
=
True
# False means test, True means forecast
indexDataPath
=
r"AI_Data"
# define some parameters
# define some parameters
win1W
=
5
# 1 week
win1W
=
5
# 1 week
win1M
=
21
# 1 Month
win1M
=
21
# 1 Month
...
@@ -115,256 +62,52 @@ if __name__ == '__main__':
...
@@ -115,256 +62,52 @@ if __name__ == '__main__':
74
:
"CPURNSA"
,
74
:
"CPURNSA"
,
75
:
"VIX"
,
75
:
"VIX"
,
76
:
"US0001M"
,
76
:
"US0001M"
,
77
:
"US0012M"
}
77
:
"US0012M"
,
# FUND
121
:
"IEF US"
,
122
:
"TLT US"
,
123
:
"UUP US"
,
139
:
"COI TOTL"
,
138
:
"LEI TOTL"
,
116
:
"MID"
,
134
:
"NAPMPMI"
,
142
:
"OE4EKLAC"
,
143
:
"OEA5KLAC"
,
146
:
"OECNKLAC"
,
145
:
"OEJPKLAC"
,
141
:
"OEOTGTAC"
,
144
:
"OEUSKLAC"
,
117
:
"SML"
,
140
:
"USRINDEX"
}
###################
###################
# Step 1: Prepare X and y (features and labels)
# Step 1: Prepare X and y (features and labels)
###### get raw data
# 准备基础数据
data_access
=
DataAccess
(
index
,
eco
,
fund
,
max_date
,
indexDict
)
# indexData = pd.read_excel(indexDataPath + r"\robo_index_datas.xlsx", sheet_name='robo_index_datas')
indexData
=
data_access
.
get_index_datas
()
indexData
=
pd
.
DataFrame
(
get_index_list
(
index_ids
=
index
,
max_date
=
max_date
))
ecoData
=
data_access
.
get_eco_datas
()
indexData
=
indexData
[
fundData
=
data_access
.
get_fund_datas
()
[
"rid_index_id"
,
"rid_date"
,
"rid_high"
,
"rid_open"
,
"rid_low"
,
"rid_close"
,
"rid_pe"
,
"rid_pb"
,
"rid_volume"
]]
# 指数数据准备
indexData
.
rename
(
columns
=
{
"rid_date"
:
'date'
},
inplace
=
True
)
# please use 'date'
vixData
=
data_access
.
get_vix
(
indexData
)
indexData
[
"rid_index_id"
]
=
indexData
[
"rid_index_id"
]
.
map
(
indexDict
)
indexOtherData
=
data_access
.
get_other_index
(
indexData
)
# 经济指标数据准备
# ecoData = pd.read_excel(indexDataPath + r"\robo_index_datas.xlsx", sheet_name='robo_eco_datas')
cpiData
=
data_access
.
get_cpi
(
ecoData
)
ecoData
=
pd
.
DataFrame
(
get_eco_list
(
eco_ids
=
eco
,
max_date
=
max_date
))
FDTRData
=
data_access
.
get_fdtr
(
ecoData
)
ecoData
=
ecoData
[[
"red_eco_id"
,
"red_date"
,
"red_indicator"
]]
# 新增指标 NAPMPMI :美國的ISM製造業指數 (Monthly)
ecoData
.
rename
(
columns
=
{
"red_date"
:
'date'
},
inplace
=
True
)
# please use 'date'
NAPMPMIData
=
data_access
.
get_napmpmi
(
ecoData
)
ecoData
[
"red_eco_id"
]
=
ecoData
[
"red_eco_id"
]
.
map
(
indexDict
)
builder
=
TrainingDataBuilder
(
index
,
eco
,
fund
,
indexDict
,
toForecast
,
win1W
,
win1M
,
win1Q
,
numForecastDays
,
###### get individual data from raw data
theThreshold
)
spxData
=
indexData
[
indexData
[
'rid_index_id'
]
==
"SPX"
]
.
copy
()
for
pid
in
PREDICT_LIST
:
del
(
spxData
[
'rid_index_id'
])
t_data
=
indexData
if
pid
in
index
else
fundData
spxData
.
set_index
(
'date'
,
inplace
=
True
)
X_train
,
X_test
,
y_train
,
y_test
,
scaledX_forecast
,
forecastDay
=
\
spxData
.
index
=
pd
.
to_datetime
(
spxData
.
index
)
builder
.
build_train_test
(
pid
,
t_data
,
vixData
,
indexOtherData
,
cpiData
,
FDTRData
,
NAPMPMIData
)
spxData
.
sort_index
(
inplace
=
True
)
trainer
=
ModelTrainer
(
toForecast
)
spxData
.
reset_index
(
inplace
=
True
)
rf_model
=
trainer
.
train_random_forest
(
X_train
,
y_train
,
X_test
,
y_test
)
gbt_model
=
trainer
.
train_GBT
(
X_train
,
y_train
,
X_test
,
y_test
)
if
(
toForecast
):
svc_model
=
trainer
.
train_SVC
(
X_train
,
y_train
,
X_test
,
y_test
)
forecastDay
=
spxData
[
'date'
]
.
iloc
[
-
1
]
ensemble_model
=
trainer
.
ensemble_model
(
rf_model
,
gbt_model
,
svc_model
,
X_train
,
y_train
,
X_test
,
y_test
)
vixData
=
indexData
[
indexData
[
'rid_index_id'
]
==
"VIX"
]
.
copy
()
vixData
=
vixData
[[
"date"
,
"rid_high"
,
"rid_open"
,
"rid_low"
,
"rid_close"
]]
vixData
.
rename
(
columns
=
{
"rid_high"
:
'vix_high'
,
'rid_open'
:
'vix_open'
,
"rid_low"
:
'vix_low'
,
"rid_close"
:
'vix_close'
},
inplace
=
True
)
vixData
.
set_index
(
'date'
,
inplace
=
True
)
vixData
.
index
=
pd
.
to_datetime
(
vixData
.
index
)
indexOtherData
=
indexData
[(
indexData
[
'rid_index_id'
]
==
"USGG10YR"
)
|
(
indexData
[
'rid_index_id'
]
==
"USGG2YR"
)
|
(
indexData
[
'rid_index_id'
]
==
"CCMP"
)
|
(
indexData
[
'rid_index_id'
]
==
"US0001M"
)
|
(
indexData
[
'rid_index_id'
]
==
"US0012M"
)]
.
copy
()
indexOtherData
=
indexOtherData
[[
'rid_index_id'
,
'date'
,
'rid_close'
]]
indexOtherData
=
indexOtherData
.
pivot
(
index
=
'date'
,
columns
=
'rid_index_id'
,
values
=
'rid_close'
)
indexOtherData
.
index
=
pd
.
to_datetime
(
indexOtherData
.
index
)
cpiData
=
ecoData
[(
ecoData
[
'red_eco_id'
]
==
"CPI_YOY"
)
|
(
ecoData
[
'red_eco_id'
]
==
"CPURNSA"
)]
.
copy
()
cpiData
=
cpiData
.
pivot
(
index
=
'date'
,
columns
=
'red_eco_id'
,
values
=
'red_indicator'
)
cpiData
[
'CPI_MOM'
]
=
(
cpiData
[
'CPURNSA'
]
/
cpiData
[
'CPURNSA'
]
.
shift
(
1
)
-
1.0
)
*
100
*
12
# Annualized Percentage
cpiData
[
'CPI_MOM_Diff'
]
=
cpiData
[
'CPURNSA'
]
-
cpiData
[
'CPURNSA'
]
.
shift
(
1
)
cpiData
.
index
=
pd
.
to_datetime
(
cpiData
.
index
)
FDTRData
=
ecoData
[
ecoData
[
'red_eco_id'
]
==
"FDTR"
]
.
copy
()
del
(
FDTRData
[
'red_eco_id'
])
FDTRData
.
rename
(
columns
=
{
"red_indicator"
:
'FDTR'
},
inplace
=
True
)
FDTRData
.
set_index
(
'date'
,
inplace
=
True
)
FDTRData
.
index
=
pd
.
to_datetime
(
FDTRData
.
index
)
###### Additional preparing SPX Data
# finta expects properly formated ohlc DataFrame, with column names in lowercase:
# ["open", "high", "low", close"] and ["volume"] for indicators that expect ohlcv input.
spxData
.
rename
(
columns
=
{
"rid_high"
:
'high'
,
'rid_open'
:
'open'
,
"rid_low"
:
'low'
,
"rid_close"
:
'close'
,
'rid_volume'
:
'volume'
,
"rid_pe"
:
"SPX_pe"
,
"rid_pb"
:
"SPX_pb"
},
inplace
=
True
)
# Calculate the indicator data
spxData
=
_get_indicator_data
(
spxData
)
# Calculate Historical Return and Volatility
spxData
[
'R1W'
]
=
np
.
log
(
spxData
[
'close'
]
/
spxData
[
'close'
]
.
shift
(
win1W
))
spxData
[
'R1M'
]
=
np
.
log
(
spxData
[
'close'
]
/
spxData
[
'close'
]
.
shift
(
win1M
))
spxData
[
'R1Q'
]
=
np
.
log
(
spxData
[
'close'
]
/
spxData
[
'close'
]
.
shift
(
win1Q
))
price_list
=
spxData
[
'close'
]
rollist
=
price_list
.
rolling
(
win1W
)
spxData
[
'Vol_1W'
]
=
rollist
.
std
(
ddof
=
0
)
rollist
=
price_list
.
rolling
(
win1M
)
spxData
[
'Vol_1M'
]
=
rollist
.
std
(
ddof
=
0
)
rollist
=
price_list
.
rolling
(
win1Q
)
spxData
[
'Vol_1Q'
]
=
rollist
.
std
(
ddof
=
0
)
# The following uses future info for the y label, to be deleted later
spxData
[
'futureR'
]
=
np
.
log
(
spxData
[
'close'
]
.
shift
(
-
numForecastDays
)
/
spxData
[
'close'
])
# spxData = spxData[spxData['futureR'].notna()]
spxData
[
'yLabel'
]
=
(
spxData
[
'futureR'
]
>=
theThreshold
)
.
astype
(
int
)
spxDataCloseSave
=
spxData
[[
'date'
,
'close'
]]
del
(
spxData
[
'close'
])
###### Merge Data to one table
DataAll
=
pd
.
merge
(
spxData
,
vixData
,
how
=
'outer'
,
on
=
'date'
)
DataAll
=
pd
.
merge
(
DataAll
,
indexOtherData
,
how
=
'outer'
,
on
=
'date'
)
DataAll
=
pd
.
merge
(
DataAll
,
cpiData
,
how
=
'outer'
,
on
=
'date'
)
DataAll
=
pd
.
merge
(
DataAll
,
FDTRData
,
how
=
'outer'
,
on
=
'date'
)
DataAll
.
set_index
(
'date'
,
inplace
=
True
)
DataAll
.
sort_index
(
inplace
=
True
)
DataAll
.
reset_index
(
inplace
=
True
)
###### fill eco data
for
col
in
[
'CPI_YOY'
,
'CPURNSA'
,
'CPI_MOM'
,
'CPI_MOM_Diff'
]:
DataAll
[
col
]
.
bfill
(
inplace
=
True
)
for
col
in
[
'FDTR'
]:
DataAll
[
col
]
.
ffill
(
inplace
=
True
)
if
(
toForecast
):
DataAllCopy
=
DataAll
.
copy
()
for
col
in
[
'CPI_YOY'
,
'CPURNSA'
]:
DataAllCopy
[
col
]
.
ffill
(
inplace
=
True
)
for
col
in
[
'CPI_MOM'
,
'CPI_MOM_Diff'
]:
DataAllCopy
[
col
]
=
DataAllCopy
[
col
]
.
fillna
(
0
)
del
(
DataAllCopy
[
'futureR'
])
del
(
DataAllCopy
[
'yLabel'
])
forecastDayIndex
=
DataAllCopy
.
index
[
DataAllCopy
[
'date'
]
==
forecastDay
]
forecastData
=
DataAllCopy
.
iloc
[
forecastDayIndex
.
to_list
(),
1
:]
X_forecast
=
forecastData
.
to_numpy
()
del
DataAllCopy
###### clean NaN
DataAll
.
dropna
(
inplace
=
True
)
DataAll
.
reset_index
(
inplace
=
True
,
drop
=
True
)
###### get X and y
y
=
DataAll
[
'yLabel'
]
.
to_numpy
(
copy
=
True
)
# delete future information
del
(
DataAll
[
'futureR'
])
del
(
DataAll
[
'yLabel'
])
X
=
DataAll
.
iloc
[:,
1
:]
.
values
###################
# scale data
scaler
=
MinMaxScaler
(
feature_range
=
(
0
,
1
))
# scaledX = scaler.fit_transform(X)
DataScaler
=
scaler
.
fit
(
X
)
scaledX
=
DataScaler
.
transform
(
X
)
if
(
toForecast
):
scaledX_forecast
=
DataScaler
.
transform
(
X_forecast
)
X_train
=
scaledX
y_train
=
y
X_test
=
[]
y_test
=
[]
else
:
# Step 2: Split data into train set and test set
X_train
,
X_test
,
y_train
,
y_test
=
train_test_split
(
scaledX
,
y
,
test_size
=
0.02
,
shuffle
=
False
)
# To avoid data leak, test set should start from numForecastDays later
X_test
=
X_test
[
numForecastDays
:]
y_test
=
y_test
[
numForecastDays
:]
def
test_model
(
strMethod
,
classifier
,
X_test
,
y_test
):
print
(
strMethod
+
" ====== test results ======"
)
y_pred
=
classifier
.
predict
(
X_test
)
result0
=
confusion_matrix
(
y_test
,
y_pred
)
print
(
strMethod
+
" Confusion Matrix:"
)
print
(
result0
)
result1
=
classification_report
(
y_test
,
y_pred
,
zero_division
=
1.0
)
print
(
strMethod
+
" Classification Report:"
)
print
(
result1
)
result2
=
accuracy_score
(
y_test
,
y_pred
)
print
(
strMethod
+
" Accuracy:"
,
result2
)
cm_display
=
ConfusionMatrixDisplay
(
confusion_matrix
=
result0
,
display_labels
=
[
'Down'
,
'Up'
])
cm_display
.
plot
()
plt
.
title
(
strMethod
+
' Accuracy: '
+
f
'{result2:.0
%
}'
)
plt
.
show
()
###################
# Step 3: Train the model
def
_train_random_forest
(
X_train
,
y_train
,
X_test
,
y_test
):
classifier
=
RandomForestClassifier
()
classifier
.
fit
(
X_train
,
y_train
)
if
(
not
toForecast
):
test_model
(
'Random Forest'
,
classifier
,
X_test
,
y_test
)
return
classifier
rf_model
=
_train_random_forest
(
X_train
,
y_train
,
X_test
,
y_test
)
def
_train_GBT
(
X_train
,
y_train
,
X_test
,
y_test
):
# Gradient Boosted Tree
classifierGBT
=
LGBMClassifier
()
classifierGBT
.
fit
(
X_train
,
y_train
)
if
(
not
toForecast
):
test_model
(
'Gradient Boosted Tree'
,
classifierGBT
,
X_test
,
y_test
)
return
classifierGBT
gbt_model
=
_train_GBT
(
X_train
,
y_train
,
X_test
,
y_test
)
def
_train_SVC
(
X_train
,
y_train
,
X_test
,
y_test
):
# Support Vector Machines
classifierSVC
=
svm
.
SVC
()
classifierSVC
.
fit
(
X_train
,
y_train
)
if
(
not
toForecast
):
test_model
(
'Support Vector Machines'
,
classifierSVC
,
X_test
,
y_test
)
return
classifierSVC
svc_model
=
_train_SVC
(
X_train
,
y_train
,
X_test
,
y_test
)
def
_ensemble_model
(
rf_model
,
gbt_model
,
svc_model
,
X_train
,
y_train
,
X_test
,
y_test
):
# Create a dictionary of our models
estimators
=
[(
'rf'
,
rf_model
),
(
'gbt'
,
gbt_model
),
(
'svc'
,
svc_model
)]
# Create our voting classifier, inputting our models
ensemble
=
VotingClassifier
(
estimators
,
voting
=
'hard'
)
# fit model to training data
ensemble
.
fit
(
X_train
,
y_train
)
if
(
not
toForecast
):
test_model
(
'Ensemble Model'
,
ensemble
,
X_test
,
y_test
)
return
ensemble
ensemble_model
=
_ensemble_model
(
rf_model
,
gbt_model
,
svc_model
,
X_train
,
y_train
,
X_test
,
y_test
)
def
predictionFromMoel
(
the_model
,
scaledX_forecast
):
prediction
=
the_model
.
predict
(
scaledX_forecast
)
predictionStr
=
'DOWN'
if
(
prediction
>
0.5
):
predictionStr
=
'UP'
content
=
"
\n
On day "
+
forecastDay
.
strftime
(
"
%
m/
%
d/
%
Y"
)
+
", the model predicts SPX to be "
+
predictionStr
+
" in "
+
str
(
numForecastDays
)
+
" business days.
\n
"
print
(
content
)
send
(
content
)
return
prediction
if
(
toForecast
):
if
(
toForecast
):
predictionFromMoel
(
ensemble_model
,
scaledX_forecast
)
predictionFromMoel
(
ensemble_model
,
scaledX_forecast
,
indexDict
[
pid
]
)
ai/dao/robo_datas.py
View file @
18c7d36d
...
@@ -25,3 +25,23 @@ def get_eco_list(eco_ids=None, min_date=None, max_date=None):
...
@@ -25,3 +25,23 @@ def get_eco_list(eco_ids=None, min_date=None, max_date=None):
select * from robo_eco_datas
select * from robo_eco_datas
{where(*sqls, red_eco_id=to_tuple(eco_ids))} order by red_eco_id, red_date
{where(*sqls, red_eco_id=to_tuple(eco_ids))} order by red_eco_id, red_date
'''
'''
@read
def get_fund_list(fund_ids=None, min_date=None, max_date=None):
    """Build SQL selecting robo_fund_navs rows, optionally bounded by date.

    :param fund_ids: fund id or iterable of ids (None = all funds)
    :param min_date: inclusive lower bound on rfn_date, if given
    :param max_date: inclusive upper bound on rfn_date, if given
    """
    # Optional date-range predicates; `where` combines them with the id filter.
    date_filters = []
    if min_date:
        date_filters.append(f"rfn_date >= '{min_date}'")
    if max_date:
        date_filters.append(f"rfn_date <= '{max_date}'")
    return f'''
    select * from robo_fund_navs
    {where(*date_filters, rfn_fund_id=to_tuple(fund_ids))} order by rfn_fund_id, rfn_date
    '''
@read
def get_base_info(ids=None):
    """Build SQL selecting id / bloomberg ticker / type from robo_base_datum.

    :param ids: rbd_id or iterable of ids (None = all rows)
    """
    filters = []
    return f"""
    SELECT rbd_id id,v_rbd_bloomberg_ticker ticker,v_rbd_type type FROM `robo_base_datum`
    {where(*filters, rbd_id=to_tuple(ids))}
    """
ai/data_access.py
0 → 100644
View file @
18c7d36d
from
abc
import
ABC
import
pandas
as
pd
from
ai.dao.robo_datas
import
get_eco_list
,
get_fund_list
,
get_index_list
class DataAccess(ABC):
    """Loads raw index / economic / fund data and reshapes it for feature building.

    Each ``get_*`` method returns a pandas DataFrame with a ``date`` column or
    a datetime index, with numeric ids already mapped to tickers via
    ``indexDict``.
    """

    def __init__(self, index, eco, fund, max_date, indexDict) -> None:
        """
        :param index: index ids to load from robo_index_datas
        :param eco: economic-indicator ids to load from robo_eco_datas
        :param fund: fund ids to load from robo_fund_navs
        :param max_date: inclusive upper date bound (None = no bound)
        :param indexDict: maps numeric ids to ticker strings
        """
        super().__init__()
        self._index = index
        self._eco = eco
        self._fund = fund
        self._max_date = max_date
        self._indexDict = indexDict

    def get_index_datas(self):
        """Return index OHLC/valuation rows with ids mapped to tickers."""
        indexData = pd.DataFrame(get_index_list(index_ids=self._index, max_date=self._max_date))
        # TODO: "rid_erp" is not selected because it has no data yet.
        indexData = indexData[
            ["rid_index_id", "rid_date", "rid_high", "rid_open", "rid_low", "rid_close",
             "rid_pe", "rid_pb", "rid_volume", "rid_frdpe", "rid_frdpes", "rid_pc"]]
        indexData.rename(columns={"rid_date": 'date'}, inplace=True)  # please use 'date'
        indexData["rid_index_id"] = indexData["rid_index_id"].map(self._indexDict)
        # Fix: assign the filled column back instead of calling ffill with
        # inplace=True on a column selection, which operates on a temporary
        # and is deprecated in pandas.
        indexData['rid_frdpe'] = indexData['rid_frdpe'].ffill()
        return indexData

    def get_eco_datas(self):
        """Return economic-indicator rows with ids mapped to tickers."""
        ecoData = pd.DataFrame(get_eco_list(eco_ids=self._eco, max_date=self._max_date))
        ecoData = ecoData[["red_eco_id", "red_date", "red_indicator"]]
        ecoData.rename(columns={"red_date": 'date'}, inplace=True)  # please use 'date'
        ecoData["red_eco_id"] = ecoData["red_eco_id"].map(self._indexDict)
        return ecoData

    def get_fund_datas(self):
        """Return fund NAV rows with ids mapped to tickers."""
        fundData = pd.DataFrame(get_fund_list(fund_ids=self._fund, max_date=self._max_date))
        fundData = fundData[["rfn_fund_id", "rfn_date", "rfn_nav_cal"]]
        fundData.rename(columns={"rfn_date": 'date'}, inplace=True)  # please use 'date'
        fundData["rfn_fund_id"] = fundData["rfn_fund_id"].map(self._indexDict)
        return fundData

    def get_vix(self, indexData):
        """Extract VIX (CBOE SPX volatility index) OHLC as a date-indexed frame."""
        vixData = indexData[indexData['rid_index_id'] == "VIX"].copy()
        vixData = vixData[["date", "rid_high", "rid_open", "rid_low", "rid_close"]]
        vixData.rename(columns={"rid_high": 'vix_high', 'rid_open': 'vix_open',
                                "rid_low": 'vix_low', "rid_close": 'vix_close'},
                       inplace=True)
        vixData.set_index('date', inplace=True)
        vixData.index = pd.to_datetime(vixData.index)
        return vixData

    def get_other_index(self, indexData):
        """Pivot the auxiliary indexes into one wide, date-indexed frame.

        Only tickers that were actually requested via ``self._index`` are kept;
        columns that remain entirely NaN after filling are dropped.
        """
        other_index = ["USGG10YR", "USGG2YR", "CCMP", "US0001M", "US0012M",
                       "COI TOTL", "LEI TOTL", "MID", "OE4EKLAC", "OEA5KLAC",
                       "OECNKLAC", "OEJPKLAC", "OEOTGTAC", "OEUSKLAC",
                       "USRINDEX", "SPX"]
        cols = ['date', 'rid_close', 'rid_pe', 'rid_pb', 'rid_volume',
                'rid_frdpe', 'rid_frdpes', 'rid_pc']
        indexOtherData = pd.DataFrame()
        idxs = [self._indexDict[i] for i in self._index]
        for idx in other_index:
            if idx in idxs:
                idx_data = indexData[indexData['rid_index_id'] == idx].copy()
                idx_data = idx_data[cols]
                # Prefix columns with the ticker so merged frames stay unambiguous.
                idx_data.rename(columns={"rid_close": f'{idx}_close',
                                         'rid_pe': f'{idx}_pe',
                                         'rid_pb': f'{idx}_pb',
                                         'rid_volume': f'{idx}_volume',
                                         'rid_frdpe': f'{idx}_frdpe',
                                         'rid_frdpes': f'{idx}_frdpes',
                                         'rid_pc': f'{idx}_pc'},
                                inplace=True)
                idx_data.set_index('date', inplace=True)
                idx_data.index = pd.to_datetime(idx_data.index)
                if indexOtherData.size > 0:
                    indexOtherData = pd.merge(indexOtherData, idx_data, how='outer', on='date')
                else:
                    indexOtherData = idx_data
        indexOtherData.ffill(inplace=True)
        indexOtherData.bfill(inplace=True)
        indexOtherData = indexOtherData.dropna(axis=1)
        return indexOtherData

    def get_cpi(self, ecoData):
        """Pivot CPI series and derive month-over-month change columns.

        CPI_YOY: US urban CPI year-over-year, not seasonally adjusted.
        CPURNSA: US consumer price index, not seasonally adjusted.
        """
        cpiData = ecoData[(ecoData['red_eco_id'] == "CPI_YOY") |
                          (ecoData['red_eco_id'] == "CPURNSA")].copy()
        cpiData = cpiData.pivot(index='date', columns='red_eco_id', values='red_indicator')
        cpiData['CPI_MOM'] = (cpiData['CPURNSA'] / cpiData['CPURNSA'].shift(1) - 1.0) * 100 * 12  # Annualized Percentage
        cpiData['CPI_MOM_Diff'] = cpiData['CPURNSA'] - cpiData['CPURNSA'].shift(1)
        cpiData.index = pd.to_datetime(cpiData.index)
        return cpiData

    def get_fdtr(self, ecoData):
        """Extract FDTR (US federal funds target rate) as a date-indexed frame."""
        FDTRData = ecoData[ecoData['red_eco_id'] == "FDTR"].copy()
        del (FDTRData['red_eco_id'])
        FDTRData.rename(columns={"red_indicator": 'FDTR'}, inplace=True)
        FDTRData.set_index('date', inplace=True)
        FDTRData.index = pd.to_datetime(FDTRData.index)
        return FDTRData

    def get_napmpmi(self, ecoData):
        """Extract NAPMPMI (US ISM manufacturing index, monthly) as a date-indexed frame."""
        NAPMPMIData = ecoData[ecoData['red_eco_id'] == "NAPMPMI"].copy()
        del (NAPMPMIData['red_eco_id'])
        NAPMPMIData.rename(columns={"red_indicator": 'NAPMPMI'}, inplace=True)
        NAPMPMIData.set_index('date', inplace=True)
        NAPMPMIData.index = pd.to_datetime(NAPMPMIData.index)
        return NAPMPMIData
ai/model_trainer.py
0 → 100644
View file @
18c7d36d
from
abc
import
ABC
import
matplotlib.pyplot
as
plt
from
lightgbm
import
LGBMClassifier
from
sklearn
import
svm
from
sklearn.ensemble
import
RandomForestClassifier
,
VotingClassifier
from
sklearn.metrics
import
classification_report
,
confusion_matrix
,
ConfusionMatrixDisplay
,
accuracy_score
class ModelTrainer(ABC):
    """Trains and evaluates the classifiers used for market-trend prediction."""

    def __init__(self, toForecast) -> None:
        super().__init__()
        # When forecasting, hold-out evaluation is skipped entirely.
        self._toForecast = toForecast

    ###################
    # Step 3: Train the model
    def test_model(self, strMethod, classifier, X_test, y_test):
        """Print confusion matrix, report and accuracy; plot the matrix."""
        print(strMethod + " ====== test results ======")
        predicted = classifier.predict(X_test)
        cm = confusion_matrix(y_test, predicted)
        print(strMethod + " Confusion Matrix:")
        print(cm)
        report = classification_report(y_test, predicted, zero_division=1.0)
        print(strMethod + " Classification Report:")
        print(report)
        accuracy = accuracy_score(y_test, predicted)
        print(strMethod + " Accuracy:", accuracy)
        display = ConfusionMatrixDisplay(confusion_matrix=cm,
                                         display_labels=['Down', 'Up'])
        display.plot()
        plt.title(strMethod + ' Accuracy: ' + f'{accuracy:.0%}')
        plt.show()

    def train_random_forest(self, X_train, y_train, X_test, y_test):
        """Fit a random forest; evaluate on the test split unless forecasting."""
        model = RandomForestClassifier()
        model.fit(X_train, y_train)
        if not self._toForecast:
            self.test_model('Random Forest', model, X_test, y_test)
        return model

    def train_GBT(self, X_train, y_train, X_test, y_test):
        """Fit a LightGBM gradient-boosted tree; evaluate unless forecasting."""
        model = LGBMClassifier()
        model.fit(X_train, y_train)
        if not self._toForecast:
            self.test_model('Gradient Boosted Tree', model, X_test, y_test)
        return model

    def train_SVC(self, X_train, y_train, X_test, y_test):
        """Fit a support-vector classifier; evaluate unless forecasting."""
        model = svm.SVC()
        model.fit(X_train, y_train)
        if not self._toForecast:
            self.test_model('Support Vector Machines', model, X_test, y_test)
        return model

    def ensemble_model(self, rf_model, gbt_model, svc_model,
                       X_train, y_train, X_test, y_test):
        """Hard-voting ensemble over the three fitted models."""
        members = [('rf', rf_model), ('gbt', gbt_model), ('svc', svc_model)]
        voter = VotingClassifier(members, voting='hard')
        voter.fit(X_train, y_train)
        if not self._toForecast:
            self.test_model('Ensemble Model', voter, X_test, y_test)
        return voter
ai/noticer.py
0 → 100644
View file @
18c7d36d
from
datetime
import
datetime
import
requests
from
py_jftech
import
sendmail
,
format_date
# Recipients of the prediction e-mails.
email = ['wenwen.tang@thizgroup.com']

# Base URL of the JRP service that receives uploaded predictions.
jrp_domain = 'https://jrp.jfquant.com/api/v1.0'
# jrp_domain = 'http://localhost:7090/jrp'
def send(content):
    """E-mail `content` to the configured recipients, dated today."""
    today_str = format_date(datetime.today())
    subject = '预测_{today}'.format(today=today_str)
    sendmail(receives=email, copies=[], attach_paths=[],
             subject=subject, content=content)
def upload_predict(ticker, predictDate, predict):
    """Push one AI prediction to the JRP service (best effort).

    :param ticker: bloomberg ticker the prediction refers to
    :param predictDate: date the prediction was made for
    :param predict: 'UP' or 'DOWN'; encoded as 1 / -1 for the API
    """
    predict_data = {
        "aiPredict": {
            "predictDate": format_date(predictDate),
            "predict": 1 if predict == 'UP' else -1
        },
        "bloombergTicker": ticker
    }
    # SECURITY: hard-coded API token; should be moved to config/env, and the
    # credential rotated since it has been committed to version control.
    headers = {"X-AUTH-token": "rt7297LwQvyAYTke2iD8Vg"}
    try:
        # Fix: requests.post without a timeout can block forever; cap it, and
        # keep the call best-effort by reporting (not raising) network errors.
        response = requests.post(url=f'{jrp_domain}/ai/predict', json=predict_data,
                                 headers=headers, timeout=30)
    except requests.RequestException:
        print("上传ai预测结果失败,请重试")
        return
    if response.status_code != 200:
        print("上传ai预测结果失败,请重试")
ai/training_data_builder.py
0 → 100644
View file @
18c7d36d
from
abc
import
ABC
import
numpy
as
np
import
pandas
as
pd
from
finta
import
TA
from
sklearn.model_selection
import
train_test_split
from
sklearn.preprocessing
import
MinMaxScaler
def imp():
    """Print the finta TA class object (keeps the import visibly used)."""
    print(TA)
class TrainingDataBuilder(ABC):
    """Builds train/test matrices (and the forecast row) for one index or fund."""

    def __init__(self, index, eco, fund, indexDict, toForecast, win1W, win1M,
                 win1Q, numForecastDays, theThreshold) -> None:
        """
        :param index: index ids handled by this builder
        :param eco: economic-indicator ids
        :param fund: fund ids
        :param indexDict: maps numeric ids to ticker strings
        :param toForecast: True = build forecast row, False = train/test split
        :param win1W: window length of one week (trading days)
        :param win1M: window length of one month
        :param win1Q: window length of one quarter
        :param numForecastDays: forecast horizon in business days (21 ≈ 1 month)
        :param theThreshold: future log-return threshold separating UP from DOWN
        """
        super().__init__()
        self._index = index
        self._eco = eco
        self._fund = fund
        self._indexDict = indexDict
        self._toForecast = toForecast
        self._win1W = win1W  # 1 week
        self._win1M = win1M  # 1 Month
        self._win1Q = win1Q  # 1 Quarter
        self._numForecastDays = numForecastDays  # business days; 21 business days means one month
        self._theThreshold = theThreshold
        # List of symbols for technical indicators.
        # '14 period MFI' and '14 period EMV' are not available for forecast,
        # so MFI and EMV are excluded from the full list.
        self.INDICATORS = ['RSI', 'MACD', 'STOCH', 'ADL', 'ATR', 'MOM', 'ROC',
                           'OBV', 'CCI', 'VORTEX']
        # Funds only have a close price, so no OHLCV indicators apply (yet).
        self.FUND_INDICATORS = []

    def get_indicator_data(self, data, pid):
        """Use the finta API to append technical-indicator feature columns."""

        def indicator_calcu(frame, indicators):
            # Indexes and funds differ: funds only have a close price, so
            # fewer indicators can be generated.
            for indicator in indicators:
                # Fix: dispatch via getattr instead of eval() on a built
                # string — same behavior, no code-injection surface.
                ind_data = getattr(TA, indicator)(frame)
                if not isinstance(ind_data, pd.DataFrame):
                    ind_data = ind_data.to_frame()
                frame = frame.merge(ind_data, left_index=True, right_index=True)
            return frame

        if pid in self._index:
            data = indicator_calcu(data, self.INDICATORS)
            # Instead of using the actual volume value (which changes over
            # time), normalize it with a moving volume average.
            data['normVol'] = data['volume'] / data['volume'].ewm(5).mean()
            # get relative values
            data['relativeOpen'] = data['open'] / data['close'].shift(1)
            data['relativeHigh'] = data['high'] / data['close'].shift(1)
            data['relativeLow'] = data['low'] / data['close'].shift(1)
            # Remove columns that won't be used as features;
            # data['close'] is still needed and will be deleted later.
            data.drop(['open', 'high', 'low', 'volume'], axis=1, inplace=True)
        elif pid in self._fund:
            # Fix: keep the returned frame — the original discarded it, which
            # is only harmless while FUND_INDICATORS stays empty.
            data = indicator_calcu(data, self.FUND_INDICATORS)
        # Also calculate moving averages for features.
        data['ema50'] = data['close'] / data['close'].ewm(50).mean()
        data['ema21'] = data['close'] / data['close'].ewm(21).mean()
        data['ema15'] = data['close'] / data['close'].ewm(15).mean()
        data['ema5'] = data['close'] / data['close'].ewm(5).mean()
        data['relativeClose'] = data['close'] / data['close'].shift(1)
        return data

    def build_predict_data(self, indexData, pid):
        """Return the per-target feature frame, including futureR / yLabel.

        :param indexData: raw index or fund frame from DataAccess
        :param pid: id of the index or fund to predict
        """
        if pid in self._index:
            # get individual data from raw data
            predictData = indexData[indexData['rid_index_id'] == self._indexDict[pid]].copy()
            del (predictData['rid_index_id'])
            # finta expects a properly formatted ohlc DataFrame with lowercase
            # column names ["open","high","low","close"], plus ["volume"] for
            # indicators that expect ohlcv input.
            predictData.rename(columns={"rid_high": 'high', 'rid_open': 'open',
                                        "rid_low": 'low', "rid_close": 'close',
                                        'rid_volume': 'volume',
                                        "rid_pe": "SPX_pe", "rid_pb": "SPX_pb"},
                               inplace=True)
        elif pid in self._fund:
            predictData = indexData[indexData['rfn_fund_id'] == self._indexDict[pid]].copy()
            del (predictData['rfn_fund_id'])
            predictData.rename(columns={"rfn_nav_cal": 'close'}, inplace=True)
        predictData.set_index('date', inplace=True)
        predictData.index = pd.to_datetime(predictData.index)
        predictData.sort_index(inplace=True)
        predictData.reset_index(inplace=True)
        # Calculate the indicator data
        predictData = self.get_indicator_data(predictData, pid)
        # Calculate historical return and volatility
        predictData['R1W'] = np.log(predictData['close'] / predictData['close'].shift(self._win1W))
        predictData['R1M'] = np.log(predictData['close'] / predictData['close'].shift(self._win1M))
        predictData['R1Q'] = np.log(predictData['close'] / predictData['close'].shift(self._win1Q))
        price_list = predictData['close']
        rollist = price_list.rolling(self._win1W)
        predictData['Vol_1W'] = rollist.std(ddof=0)
        rollist = price_list.rolling(self._win1M)
        predictData['Vol_1M'] = rollist.std(ddof=0)
        rollist = price_list.rolling(self._win1Q)
        predictData['Vol_1Q'] = rollist.std(ddof=0)
        # futureR uses future information for the y label; it is removed from
        # the features before training.
        predictData['futureR'] = np.log(predictData['close'].shift(-self._numForecastDays) / predictData['close'])
        predictData['yLabel'] = (predictData['futureR'] >= self._theThreshold).astype(int)
        # 'close' itself is not a feature. (The original also saved an unused
        # date/close copy here; that dead local has been removed.)
        del (predictData['close'])
        return predictData

    def build_train_test(self, pid, indexData, vixData, indexOtherData,
                         cpiData, FDTRData, NAPMPMIData):
        """Merge all sources, fill gaps, scale, and produce train/test arrays.

        Returns (X_train, X_test, y_train, y_test, scaledX_forecast,
        forecastDay). When not forecasting, the last two are None.
        """
        # Fix: initialize these so the non-forecast path returns None instead
        # of raising NameError at the return statement.
        scaledX_forecast = None
        forecastDay = None
        ###### Merge data into one table
        predictData = self.build_predict_data(indexData, pid)
        if self._toForecast:
            forecastDay = predictData['date'].iloc[-1]
        DataAll = pd.merge(predictData, vixData, how='outer', on='date')
        DataAll = pd.merge(DataAll, indexOtherData, how='outer', on='date')
        DataAll = pd.merge(DataAll, cpiData, how='outer', on='date')
        DataAll = pd.merge(DataAll, FDTRData, how='outer', on='date')
        DataAll = pd.merge(DataAll, NAPMPMIData, how='outer', on='date')
        DataAll.set_index('date', inplace=True)
        DataAll.sort_index(inplace=True)
        DataAll.reset_index(inplace=True)
        ###### fill eco data
        # Fix: assign the filled column back instead of bfill/ffill with
        # inplace=True on a column selection (acts on a temporary; deprecated).
        for col in ['CPI_YOY', 'CPURNSA', 'CPI_MOM', 'CPI_MOM_Diff']:
            DataAll[col] = DataAll[col].bfill()
        for col in ['FDTR']:
            DataAll[col] = DataAll[col].ffill()
        # NAPMPMI: US ISM manufacturing index (monthly)
        for col in ['NAPMPMI']:
            DataAll[col] = DataAll[col].bfill()
            DataAll[col] = DataAll[col].ffill()
        if self._toForecast:
            # CPI_YOY / CPURNSA publish with a lag, so forward-fill them only
            # on the copy used to extract the forecast-day snapshot.
            DataAllCopy = DataAll.copy()
            for col in ['CPI_YOY', 'CPURNSA']:
                DataAllCopy[col] = DataAllCopy[col].ffill()
            for col in ['CPI_MOM', 'CPI_MOM_Diff']:
                DataAllCopy[col] = DataAllCopy[col].fillna(0)
            DataAllCopy.drop(['futureR', 'yLabel'], axis=1, inplace=True)
            forecastDayIndex = DataAllCopy.index[DataAllCopy['date'] == forecastDay]
            # Column 0 is 'date'; everything after is a feature.
            forecastData = DataAllCopy.iloc[forecastDayIndex.to_list(), 1:]
            X_forecast = forecastData.to_numpy()
            del DataAllCopy
        ###### clean NaN
        DataAll.dropna(inplace=True)
        DataAll.reset_index(inplace=True, drop=True)
        ###### get X and y
        y = DataAll['yLabel'].to_numpy(copy=True)
        # delete future information
        DataAll.drop(['futureR', 'yLabel'], axis=1, inplace=True)
        X = DataAll.iloc[:, 1:].values
        ###################
        # scale data into [0, 1]
        scaler = MinMaxScaler(feature_range=(0, 1))
        DataScaler = scaler.fit(X)
        scaledX = DataScaler.transform(X)
        if self._toForecast:
            scaledX_forecast = DataScaler.transform(X_forecast)
            X_train = scaledX
            y_train = y
            X_test = []
            y_test = []
        else:
            # Step 2: split chronologically into train and test sets.
            X_train, X_test, y_train, y_test = train_test_split(
                scaledX, y, test_size=0.02, shuffle=False)
            # To avoid data leak, the test set starts numForecastDays later.
            X_test = X_test[self._numForecastDays:]
            y_test = y_test[self._numForecastDays:]
        return X_train, X_test, y_train, y_test, scaledX_forecast, forecastDay
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment