import json
from typing import List

import numpy as np
from py_jftech import autowired, parse_date, prev_workday, format_date

from ai.config import LABEL_RANGE, LABEL_TAG
from ai.dao import robo_predict
from ai.dao.robo_datas import get_base_info, get_index_list, get_fund_list
from ai.data_access import DataAccess
from ai.model_trainer import ModelTrainer
from ai.noticer import send
from ai.training_data_builder import TrainingDataBuilder
from api import DataSync

# Data cut-off date. When unset, judgement() falls back to the previous workday.
max_date = None

# ---- Run-mode switches ----
toForecast = False  # False means test, True means forecast
syncData = False  # True: sync index and fund data into the database before running
uploadData = False  # True: upload the prediction results
doReport = True  # True: generate the Excel report

# ---- Ids (base-info ids) to predict ----
PREDICT_LIST = [156]
# Full prediction universe, kept for reference:
# PREDICT_LIST = [67, 121, 122, 123, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163,
#                 164, 165, 166, 167, 168, 169, 170, 171, 174, 175, 177, 178]

# ---- Data universes (base-info ids): economic indicators, indices, funds ----
eco = [65, 66, 74, 134, 191]
index = [
    67, 68, 69, 70, 71, 72, 73, 75, 76, 77,
    105, 106, 116, 117,
    138, 139, 142, 143, 140, 141,
    144, 145, 146,
]
# Full fund universe, kept for reference:
# fund = [121, 122, 123, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165,
#         166, 167, 168, 169, 170, 171, 174, 175, 177, 178]
fund = [156]


@autowired
def sync(syncs: List[DataSync] = None):
    """Call ``do_sync()`` on every injected :class:`DataSync` implementation.

    ``syncs`` is supplied by ``py_jftech``'s ``@autowired`` injection (presumably
    all registered ``DataSync`` implementations -- verify against py_jftech).
    If nothing is injected and the parameter keeps its ``None`` default, there is
    nothing to synchronise and the call is a no-op. The original raised
    ``TypeError`` when iterating ``None``.
    """
    for data_sync in syncs or ():
        data_sync.do_sync()


def report_prediction(label, predict_item, indexDict: dict, forecast_day=None, num_forecast_days=None):
    """Print a prediction, optionally upload it, and send it as a notification.

    Args:
        label: Predicted label. Either a scalar, or an array-like such as the
            output of ``model.predict(X)``. Array-likes are reduced to their first
            element, because numpy arrays are unhashable and cannot be used as a
            ``LABEL_TAG`` key.
        predict_item: Ticker name of the predicted asset, i.e. a value of
            ``indexDict``.
        indexDict: Mapping of base-info id -> ticker name. It is used to recover
            the id of ``predict_item`` when uploading.
        forecast_day: Date the prediction is made on. Defaults to the module-level
            ``forecastDay`` set by the ``__main__`` script.
        num_forecast_days: Forecast horizon in business days. Defaults to the
            module-level ``numForecastDays`` set by the ``__main__`` script.

    Returns:
        The scalar predicted label.

    Raises:
        KeyError: ``uploadData`` is on and ``predict_item`` is not a value of
            ``indexDict``.
    """
    # Normalise array output (e.g. ``model.predict(X)``) to a scalar so it can be
    # used as a dict key and stored as a single value.
    prediction = np.ravel(label)[0] if np.ndim(label) > 0 else label
    # Fall back to the globals the __main__ script defines. They are only read
    # when the caller does not pass explicit values.
    if forecast_day is None:
        forecast_day = forecastDay
    if num_forecast_days is None:
        num_forecast_days = numForecastDays

    predictionStr = LABEL_TAG.get(prediction)
    content = f"""\n On day {forecast_day.strftime("%m/%d/%Y")}, the model predicts {predict_item} to be {predictionStr} in {str(num_forecast_days)} business days. \n"""
    print(content)

    # Upload the prediction result. The id lookup and the DB query are only
    # needed here.
    if uploadData:
        key = next((k for k, v in indexDict.items() if v == predict_item), None)
        if key is None:
            raise KeyError(f'{predict_item!r} is not a value of indexDict')
        if len(LABEL_RANGE) > 2:
            data = {"rbd_id": key, "date": forecast_day, "predict": prediction}
            robo_predict.insert(data)
        else:
            from ai.noticer import upload_predict
            index_info = get_base_info(key)[0]
            upload_predict(index_info['ticker'], forecast_day, predictionStr)
    send(content)
    return prediction


def judgement(id, type, predict, predict_term=21):
    """Score each model's direction call against realised data and append it to ``predict.txt``.

    Loads up to ``predict_term`` observations starting at the evaluation date:
    closing prices (``rid_close``) for ``'INDEX'``, or calculated NAVs
    (``rfn_nav_cal``) for ``'FUND'``. The realised move counts as "up" when the
    last observation is >= the first. Each model's call counts as "up" when its
    first predicted value is > 0. A record is appended only when exactly
    ``predict_term`` observations are available, i.e. when the outcome is
    already known.

    Args:
        id: Id of the predicted index/fund, passed to ``get_index_list`` /
            ``get_fund_list``.
        type: ``'INDEX'`` or ``'FUND'``. Any other value records nothing.
        predict: Mapping of model name -> prediction array (``predict()`` output).
        predict_term: Number of observations to load. Defaults to 21, the same
            value as ``numForecastDays`` in the ``__main__`` script.

    The evaluation date is ``max_date`` when it is set, otherwise the previous
    workday.
    """
    from datetime import datetime

    start = parse_date(max_date) if max_date else prev_workday(datetime.today())
    if type == 'INDEX':
        rows = get_index_list(index_ids=id, min_date=start, limit=predict_term)
        navs = [row['rid_close'] for row in rows]
    elif type == 'FUND':
        rows = get_fund_list(fund_ids=id, min_date=start, limit=predict_term)
        navs = [row['rfn_nav_cal'] for row in rows]
    else:
        navs = []

    # Not enough observations yet to know the outcome, so there is nothing to score.
    if len(navs) != predict_term:
        return

    # NOTE(review): assumes rows come back in ascending date order -- confirm in robo_datas.
    upper = navs[-1] >= navs[0]
    # bool() so numpy bools produced by array comparisons stay JSON-serialisable.
    result = {name: bool(upper == (values[0] > 0)) for name, values in predict.items()}
    record = {
        'id': id,
        'date': format_date(start),
        'result': result,
    }
    with open('predict.txt', 'a', encoding='utf-8') as file:
        file.write(json.dumps(record) + '\n')


########################################
if __name__ == '__main__':
    # Optionally sync index/fund data into the database before training.
    if syncData:
        sync()
    # define some parameters
    win1W = 5  # 1 week
    win1M = 21  # 1 Month
    win1Q = 63  # 1 Quarter
    # NOTE: forecastDay (set in the loop below) and numForecastDays are module
    # globals that report_prediction() reads.
    numForecastDays = 21  # business days, 21 business days means one month
    theThreshold = 0.0
    ids = set(PREDICT_LIST) | set(eco) | set(index) | set(fund)
    infos = get_base_info(ids)
    infos_type = {info['id']: info['type'] for info in infos}
    # id -> normalised ticker, e.g. "XYZ US Equity" -> "XYZ_US".
    indexDict = {
        info['id']: info['ticker'].replace(' Index', '').replace(' Equity', '').replace(' ', '_')
        for info in infos
    }
    ###################
    # Step 1: Prepare X and y (features and labels)
    # Prepare the base data
    data_access = DataAccess(index, eco, fund, max_date, indexDict)
    indexData = data_access.get_index_datas()
    ecoData = data_access.get_eco_datas()
    fundData = data_access.get_fund_datas()
    # Index data preparation
    vixData = data_access.get_vix(indexData)
    indexOtherData = data_access.get_other_index(indexData)
    # Economic indicator data preparation
    cpiData = data_access.get_cpi(ecoData)
    FDTRData = data_access.get_fdtr(ecoData)
    # Newly added indicator NAPMPMI: US ISM Manufacturing index (Monthly)
    NAPMPMIData = data_access.get_napmpmi(ecoData)
    TTM = data_access.get_jifu_spx_opeps_currq_ttm(ecoData)

    builder = TrainingDataBuilder(index, eco, fund, indexDict, toForecast, win1W, win1M, win1Q, numForecastDays,
                                  theThreshold)
    for pid in PREDICT_LIST:
        print(f'{indexDict[pid]} start '.center(50, '='))
        t_data = indexData if pid in index else fundData
        X_train, X_test, y_train, y_test, scaledX_forecast, forecastDay, date_index = \
            builder.build_train_test(pid, t_data, vixData, indexOtherData, cpiData, FDTRData, NAPMPMIData, TTM)
        trainer = ModelTrainer(toForecast, pid)
        rf_model = trainer.train_random_forest(X_train, y_train, X_test, y_test, date_index)
        gbt_model = trainer.train_GBT(X_train, y_train, X_test, y_test, date_index)
        svc_model = trainer.train_SVC(X_train, y_train, X_test, y_test, date_index)
        knn_model = trainer.train_nearest_neighbors(X_train, y_train, X_test, y_test, date_index)
        ada_model = trainer.train_AdaBoost(X_train, y_train, X_test, y_test, date_index)
        ensemble_model = trainer.ensemble_model(rf_model, gbt_model, svc_model,
                                                knn_model, ada_model, X_train, y_train, X_test, y_test, date_index)
        if toForecast:
            model_predict = {'forest': rf_model.predict(scaledX_forecast),
                             'gbt': gbt_model.predict(scaledX_forecast),
                             'svc': svc_model.predict(scaledX_forecast),
                             'knn': knn_model.predict(scaledX_forecast),
                             'adaboost': ada_model.predict(scaledX_forecast),
                             'ensemble': ensemble_model.predict(scaledX_forecast)}
            print(f'预测结果:{model_predict}'.center(60, '+'))
            judgement(pid, infos_type[pid], model_predict)
            if len(LABEL_RANGE) > 2:
                # Report the rounded mean of all models' predictions.
                average = round(np.mean(list(model_predict.values())))
                report_prediction(average, indexDict[pid], indexDict)
            else:
                # Report the ensemble's call. Pass a scalar rather than the
                # prediction array: report_prediction uses it as a LABEL_TAG key,
                # and numpy arrays are unhashable. This reuses the prediction
                # computed above instead of calling predict() again.
                report_prediction(model_predict['ensemble'][0], indexDict[pid], indexDict)
    if doReport:
        # Generate the Excel report.
        if len(LABEL_RANGE) > 2:
            from ai.reporter import do_reporter2

            do_reporter2()
        else:
            from ai.reporter import do_reporter

            do_reporter()