Commit e34f0a89 authored by wenwen.tang

新增预测基金

parent 69854f30
......@@ -10,15 +10,15 @@ from ai.training_data_builder import TrainingDataBuilder
from api import DataSync
# 截止日期
max_date = None
# max_date = '2023-12-01'
# max_date = None
max_date = '2024-01-05'
# 待预测指数
# PREDICT_LIST = [67]
PREDICT_LIST = [67, 121, 122, 123]
# PREDICT_LIST = [67, 121, 122, 123, 155, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 168, 169, 170, 171, 174, 175]
eco = [65, 66, 74, 134]
index = [67, 68, 69, 70, 71, 72, 73, 75, 116, 117, 138, 139, 142, 143, 140, 141, 144, 145, 146]
fund = [121, 122, 123]
fund = [121, 122, 123, 155, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 168, 169, 170, 171, 174, 175]
@autowired
......@@ -36,9 +36,9 @@ def predictionFromMoel(the_model, scaledX_forecast, predict_item, indexDict: dic
content = f"""\n On day {forecastDay.strftime("%m/%d/%Y")}, the model predicts {predict_item} to be {predictionStr} in {str(numForecastDays)} business days. \n"""
print(content)
# 上传预测结果
# key = [k for k, v in indexDict.items() if v == predict_item]
# index_info = get_base_info(key)[0]
# upload_predict(index_info['ticker'], forecastDay, predictionStr)
key = [k for k, v in indexDict.items() if v == predict_item]
index_info = get_base_info(key)[0]
upload_predict(index_info['ticker'], forecastDay, predictionStr)
# send(content)
return prediction
......@@ -46,46 +46,17 @@ def predictionFromMoel(the_model, scaledX_forecast, predict_item, indexDict: dic
########################################
if __name__ == '__main__':
sync()
toForecast = False # False means test, True means forecast
toForecast = True # False means test, True means forecast
# define some parameters
win1W = 5 # 1 week
win1M = 21 # 1 Month
win1Q = 63 # 1 Quarter
numForecastDays = 21 # business days, 21 business days means one month
theThreshold = 0.0
indexDict = {
65: "CPI_YOY",
66: "FDTR",
67: "SPX",
68: "USGG10YR",
69: "USGG2YR",
70: "MXWO", # not use now
71: "MXWD", # not use now
72: "CCMP",
73: "TWSE", # not use now
74: "CPURNSA",
75: "VIX",
76: "US0001M",
77: "US0012M",
# FUND
121: "IEF_US",
122: "TLT_US",
123: "UUP_US",
139: "COI_TOTL",
138: "LEI_TOTL",
116: "MID",
134: "NAPMPMI",
142: "OE4EKLAC",
143: "OEA5KLAC",
146: "OECNKLAC",
145: "OEJPKLAC",
141: "OEOTGTAC",
144: "OEUSKLAC",
117: "SML",
140: "USRINDEX"
}
ids = set(PREDICT_LIST) | set(eco) | set(index) | set(fund)
infos = get_base_info(ids)
indexDict = {info['id']: info['ticker'].replace(' Index', '').replace(' Equity', '').replace(' ', '_') for info in
infos}
###################
# Step 1: Prepare X and y (features and labels)
# 准备基础数据
......@@ -114,6 +85,9 @@ if __name__ == '__main__':
gbt_model = trainer.train_GBT(X_train, y_train, X_test, y_test)
svc_model = trainer.train_SVC(X_train, y_train, X_test, y_test)
ensemble_model = trainer.ensemble_model(rf_model, gbt_model, svc_model, X_train, y_train, X_test, y_test)
if (toForecast):
print(f'forest predict{rf_model.predict(scaledX_forecast)}'.center(60, '+'))
print(f'gbt predict{gbt_model.predict(scaledX_forecast)}'.center(60, '+'))
print(f'svc predict{svc_model.predict(scaledX_forecast)}'.center(60, '+'))
print(f'ensemble predict{ensemble_model.predict(scaledX_forecast)}'.center(60, '+'))
if toForecast:
predictionFromMoel(ensemble_model, scaledX_forecast, indexDict[pid], indexDict)
......@@ -23,7 +23,7 @@ class DataAccess(ABC):
"rid_volume", "rid_frdpe", "rid_frdpes", "rid_pc"]]
indexData.rename(columns={"rid_date": 'date'}, inplace=True) # please use 'date'
indexData["rid_index_id"] = indexData["rid_index_id"].map(self._indexDict)
indexData['rid_frdpe'].ffill(inplace=True)
indexData.fillna(method='ffill', inplace=True)
return indexData
def get_eco_datas(self):
......@@ -45,9 +45,13 @@ class DataAccess(ABC):
def get_vix(self, indexData):
# VIX: Chicago Board Options Exchange (CBOE) SPX volatility index
vixData = indexData[indexData['rid_index_id'] == "VIX"].copy()
vixData = vixData[["date", "rid_high", "rid_open", "rid_low", "rid_close"]]
vixData = vixData[
["date", "rid_high", "rid_open", "rid_low", "rid_close", "rid_pc", "rid_pb", "rid_pe", "rid_frdpe",
"rid_frdpes"]]
vixData.rename(
columns={"rid_high": 'vix_high', 'rid_open': 'vix_open', "rid_low": 'vix_low', "rid_close": 'vix_close'},
columns={"rid_high": 'vix_high', 'rid_open': 'vix_open', "rid_low": 'vix_low', "rid_volume": 'vix_volume',
"rid_close": 'vix_close', "rid_pc": 'vix_pc', "rid_pb": 'vix_pb', "rid_pe": 'vix_pe',
"rid_frdpe": 'vix_frdpe', "rid_frdpes": 'vix_frdpes'},
inplace=True)
vixData.set_index('date', inplace=True)
vixData.index = pd.to_datetime(vixData.index)
......
......@@ -87,7 +87,7 @@ class TrainingDataBuilder(ABC):
predictData.rename(
columns={"rid_high": 'high', 'rid_open': 'open', "rid_low": 'low', "rid_close": 'close',
'rid_volume': 'volume',
"rid_pe": "SPX_pe", "rid_pb": "SPX_pb"},
"rid_pe": f"{self._indexDict[pid]}_pe", "rid_pb": f"{self._indexDict[pid]}_pb"},
inplace=True)
elif pid in self._fund:
predictData = indexData[indexData['rfn_fund_id'] == self._indexDict[pid]].copy()
......
......@@ -42,7 +42,7 @@ class DefaultDatum(Datum):
if DatumType(datum['type']) is DatumType.FUND:
return {
**datum,
'inceptDate': parse_date(datum['inceptDate'])
'inceptDate': parse_date(datum.get('inceptDate')) if datum.get('inceptDate') else None
}
return datum
......
......@@ -177,7 +177,9 @@ class IndexSync(JDCDataSync):
return next_workday(last['date']) if last else self.start_date
def build_urls(self, datum, start_date, page=0) -> str:
    """Build the JDC index-value query URL for one datum.

    The Bloomberg ticker is used when the datum carries a non-empty one;
    otherwise the THS ticker is used and the source type switches to THS.

    :param datum: datum dict; reads ``bloombergTicker`` and ``thsTicker``
    :param start_date: start date, rendered into the URL via ``format_date``
    :param page: value of the ``page`` query parameter (defaults to 0)
    :return: the request URL as a string
    :raises KeyError: if the Bloomberg ticker is absent/empty and the datum
        has no ``thsTicker`` key
    """
    # ``.get`` plus a truthiness check: a missing, None or empty Bloomberg
    # ticker falls through to THS instead of raising inside ``quote``.
    bloomberg_ticker = datum.get("bloombergTicker")
    if bloomberg_ticker:
        source_code, source_type = quote(bloomberg_ticker), 'BLOOMBERG'
    else:
        source_code, source_type = quote(datum["thsTicker"]), 'THS'
    return (
        f'http://jdcprod.thiztech.com/api/datas/index-value?page={page}&size=200'
        f'&sourceCode={source_code}&sourceType={source_type}'
        f'&startDate={format_date(start_date)}'
    )
def store_date(self, datumid, datas: List[dict]):
# add frdpe,frdpes,erp,pc
......@@ -195,7 +197,8 @@ class IndexSync(JDCDataSync):
'frdpes': x['forwardEps'] if 'forwardEps' in x else None,
'erp': x['erp'] if 'erp' in x else None,
'pc': x['pcRatio'] if 'pcRatio' in x else None,
} for x in datas if is_workday(dt.fromtimestamp(x['date'] / 1000, tz=pytz.timezone('Asia/Shanghai'))) and 'close' in x]
} for x in datas if
is_workday(dt.fromtimestamp(x['date'] / 1000, tz=pytz.timezone('Asia/Shanghai'))) and 'close' in x]
if save_datas:
rid.batch_insert(save_datas)
......@@ -270,7 +273,6 @@ class FundNavSync(JDCDataSync):
} for x in response['body']['content'] if x['fundId'] in ft_tickers
}}
def store_date(self, datumid, datas: List[dict]):
save_navs = [{
'fund_id': datumid,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment