from abc import ABC

import pandas as pd

from ai.dao.robo_datas import get_eco_list, get_fund_list, get_index_list


class DataAccess(ABC):
    """Load index, economic and fund data through the ``robo_datas`` DAO and
    reshape it into per-indicator DataFrames.

    Raw database ids are translated to names (e.g. ``"VIX"``, ``"SPX"``,
    ``"CPI_YOY"``) through ``indexDict``; the helpers that slice out a single
    indicator select rows by those names.
    """

    # Index ids 105/106 are used from 2022-01-03 on; their history before that
    # date is taken from the legacy ids 76/77 respectively (legacy -> current).
    _SPLICE_DATE = '2022-01-03'
    _SPLICE_IDS = {76: 105, 77: 106}

    # TODO: add "rid_erp" once erp data is available (it currently has none).
    _INDEX_COLUMNS = ["rid_index_id", "rid_date", "rid_high", "rid_open", "rid_low", "rid_close", "rid_pe",
                      "rid_pb", "rid_volume", "rid_frdpe", "rid_frdpes", "rid_pc"]
    _ECO_COLUMNS = ["red_eco_id", "red_release_date", "red_indicator"]
    _FUND_COLUMNS = ["rfn_fund_id", "rfn_date", "rfn_nav_cal"]

    # Fields exported by get_vix (renamed to vix_<field>).
    _VIX_COLUMNS = ["rid_high", "rid_open", "rid_low", "rid_close", "rid_pc", "rid_pb", "rid_pe", "rid_frdpe",
                    "rid_frdpes"]
    # Indices exported by get_other_index (in this column order), and their fields.
    _OTHER_INDEX_NAMES = ["USGG10YR", "USGG2YR", "CCMP", "TSFR1M", "TSFR12M", "COI_TOTL", "LEI_TOTL", "MID",
                          "OE4EKLAC", "OEA5KLAC", "OECNKLAC", "OEJPKLAC", "OEOTGTAC", "OEUSKLAC", "USRINDEX", "SPX"]
    _OTHER_INDEX_COLUMNS = ['rid_close', 'rid_pe', 'rid_pb', 'rid_volume', 'rid_frdpe', 'rid_frdpes', 'rid_pc']

    def __init__(self, index, eco, fund, max_date, indexDict) -> None:
        """Store the query parameters.

        :param index: index ids passed to ``get_index_list`` as ``index_ids``.
        :param eco: economic indicator ids passed to ``get_eco_list`` as ``eco_ids``.
        :param fund: fund ids passed to ``get_fund_list`` as ``fund_ids``.
        :param max_date: passed as ``max_date`` to all three DAO queries.
        :param indexDict: mapping from raw ids to indicator names.
        """
        super().__init__()
        self._index = index
        self._eco = eco
        self._fund = fund
        self._max_date = max_date
        self._indexDict = indexDict

    @staticmethod
    def _frame_with_columns(rows, columns):
        """Build a DataFrame from DAO ``rows`` keeping only ``columns`` (in that order).

        An empty ``rows`` yields a frame with no columns at all, on which the
        column selection would raise KeyError; an empty frame that carries
        ``columns`` is returned instead.
        """
        frame = pd.DataFrame(rows)
        if frame.empty and len(frame.columns) == 0:
            return pd.DataFrame(columns=columns)
        return frame[columns]

    @classmethod
    def _splice_legacy_ids(cls, indexData):
        """Stitch each legacy index id onto its replacement id (see ``_SPLICE_IDS``).

        Rows of a legacy id dated on/after ``_SPLICE_DATE`` and rows of its
        replacement id dated before it are dropped; the remaining legacy rows are
        relabelled with the replacement id. Returns a new frame.
        """
        ids = indexData['rid_index_id']
        dates = indexData['rid_date']
        stale = pd.Series(False, index=indexData.index)
        for legacy_id, current_id in cls._SPLICE_IDS.items():
            # legacy rows superseded by the replacement id
            stale |= (ids == legacy_id) & (dates >= cls._SPLICE_DATE)
            # replacement-id rows before the cut-over (covered by the legacy id)
            stale |= (ids == current_id) & (dates < cls._SPLICE_DATE)
        spliced = indexData.drop(indexData[stale].index).copy()
        for legacy_id, current_id in cls._SPLICE_IDS.items():
            spliced.loc[spliced['rid_index_id'] == legacy_id, 'rid_index_id'] = current_id
        return spliced

    def get_index_datas(self):
        """Load the configured indices up to ``max_date``.

        Keeps the ``_INDEX_COLUMNS`` fields, splices the legacy ids into their
        replacements, renames ``rid_date`` to ``date`` and maps ``rid_index_id``
        to its name via ``indexDict``.
        """
        indexData = self._frame_with_columns(
            get_index_list(index_ids=self._index, max_date=self._max_date), self._INDEX_COLUMNS)
        indexData = self._splice_legacy_ids(indexData)
        indexData = indexData.rename(columns={"rid_date": 'date'})  # always expose the date column as 'date'
        indexData["rid_index_id"] = indexData["rid_index_id"].map(self._indexDict)
        return indexData

    def get_eco_datas(self):
        """Load the configured economic indicators up to ``max_date``.

        Columns: ``red_eco_id`` (mapped to its name via ``indexDict``), ``date``
        (from ``red_release_date``) and ``red_indicator``.
        """
        ecoData = self._frame_with_columns(
            get_eco_list(eco_ids=self._eco, max_date=self._max_date), self._ECO_COLUMNS)
        ecoData = ecoData.rename(columns={"red_release_date": 'date'})  # always expose the date column as 'date'
        ecoData["red_eco_id"] = ecoData["red_eco_id"].map(self._indexDict)
        return ecoData

    def get_fund_datas(self):
        """Load the configured funds up to ``max_date``.

        Columns: ``rfn_fund_id`` (mapped to its name via ``indexDict``), ``date``
        (from ``rfn_date``) and ``rfn_nav_cal``.
        """
        fundData = self._frame_with_columns(
            get_fund_list(fund_ids=self._fund, max_date=self._max_date), self._FUND_COLUMNS)
        fundData = fundData.rename(columns={"rfn_date": 'date'})  # always expose the date column as 'date'
        fundData["rfn_fund_id"] = fundData["rfn_fund_id"].map(self._indexDict)
        return fundData

    @staticmethod
    def _index_frame(indexData, name, columns, prefix):
        """Slice one named index out of ``indexData`` (output of ``get_index_datas``).

        Keeps rows whose ``rid_index_id`` equals ``name`` and the ``date`` column
        plus ``columns``; each ``rid_<field>`` column is renamed to
        ``<prefix>_<field>`` and the frame is indexed by ``pd.to_datetime(date)``.
        """
        frame = indexData.loc[indexData['rid_index_id'] == name, ['date', *columns]]
        frame = frame.rename(columns={col: f"{prefix}_{col[len('rid_'):]}" for col in columns})
        frame = frame.set_index('date')
        frame.index = pd.to_datetime(frame.index)
        return frame

    def get_vix(self, indexData):
        """Return the VIX (CBOE SPX volatility index) fields as ``vix_<field>`` columns.

        The result is indexed by date; any column containing a NaN is dropped.
        """
        return self._index_frame(indexData, "VIX", self._VIX_COLUMNS, 'vix').dropna(axis=1)

    def get_other_index(self, indexData):
        """Return the configured ``_OTHER_INDEX_NAMES`` indices side by side.

        Each index contributes ``<name>_<field>`` columns; the frames are
        outer-joined on date, forward- then back-filled, and any column that
        still contains a NaN is dropped.
        """
        # Ids without an indexDict entry cannot match any name, so they are
        # ignored rather than raising KeyError (consistent with ``.map`` above).
        configured = {self._indexDict.get(i) for i in self._index}
        merged = pd.DataFrame()
        for name in self._OTHER_INDEX_NAMES:
            if name not in configured:
                continue
            frame = self._index_frame(indexData, name, self._OTHER_INDEX_COLUMNS, name)
            # The first non-empty frame seeds the result; later ones are outer-joined on date.
            merged = pd.merge(merged, frame, how='outer', on='date') if merged.size > 0 else frame
        # NOTE(review): bfill copies later observations into earlier dates; confirm
        # this is acceptable for the downstream consumer.
        return merged.ffill().bfill().dropna(axis=1)

    def get_cpi(self, ecoData):
        """Return CPI_YOY and CPURNSA side by side plus month-over-month columns.

        CPI_YOY: US urban consumer price index YoY, not seasonally adjusted.
        CPURNSA: US consumer price index, not seasonally adjusted.
        ``CPI_MOM`` is CPURNSA's change vs. the previous row in percent,
        annualised by multiplying by 12 (no compounding); ``CPI_MOM_Diff`` is the
        absolute change. Rows are the union of both series' dates.
        """
        cpiData = ecoData[ecoData['red_eco_id'].isin(["CPI_YOY", "CPURNSA"])]
        cpiData = cpiData.pivot(index='date', columns='red_eco_id', values='red_indicator')
        previous = cpiData['CPURNSA'].shift(1)
        cpiData['CPI_MOM'] = (cpiData['CPURNSA'] / previous - 1.0) * 100 * 12  # annualised percentage
        cpiData['CPI_MOM_Diff'] = cpiData['CPURNSA'] - previous
        cpiData.index = pd.to_datetime(cpiData.index)
        return cpiData

    @staticmethod
    def _eco_series(ecoData, eco_id):
        """Return ``eco_id``'s ``red_indicator`` values as a one-column frame named
        ``eco_id``, indexed by ``pd.to_datetime(date)``."""
        frame = ecoData[ecoData['red_eco_id'] == eco_id].drop(columns='red_eco_id')
        frame = frame.rename(columns={"red_indicator": eco_id}).set_index('date')
        frame.index = pd.to_datetime(frame.index)
        return frame

    def get_fdtr(self, ecoData):
        """Return FDTR (US federal funds target rate) as column ``FDTR``."""
        return self._eco_series(ecoData, "FDTR")

    def get_napmpmi(self, ecoData):
        """Return NAPMPMI (US ISM manufacturing index, monthly) as column ``NAPMPMI``."""
        return self._eco_series(ecoData, "NAPMPMI")

    def get_jifu_spx_opeps_currq_ttm(self, ecoData):
        """Return SP500 Operating EPS Current Quarter TTM as column ``JIFU_SPX_OPEPS_CURRQ_TTM``."""
        return self._eco_series(ecoData, "JIFU_SPX_OPEPS_CURRQ_TTM")