from abc import ABC

import numpy as np
import pandas as pd
from finta import TA
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

from ai.config import LABEL_RANGE


def imp():
    """Print the finta ``TA`` class imported at module level."""
    indicator_api = TA
    print(indicator_api)


class TrainingDataBuilder(ABC):
    """Build feature matrices and labels for forecasting one index or fund.

    Raw price data for a single index/fund is turned into technical-indicator,
    return and volatility features, outer-merged with other data on ``date``,
    labelled via ``LABEL_RANGE`` from the forward log-return over
    ``numForecastDays`` rows, scaled with ``MinMaxScaler`` and split into
    train/test sets (or into a training set plus one forecast row).
    """

    def __init__(self, index, eco, fund, indexDict, toForecast, win1W, win1M, win1Q, numForecastDays,
                 theThreshold) -> None:
        """
        @param index: collection of pids treated as indices (OHLCV price data)
        @param eco: stored as ``self._eco``; not read by this class's own methods
        @param fund: collection of pids treated as funds (close/NAV only)
        @param indexDict: maps a pid to the id used in the raw data's ``rid_index_id`` / ``rfn_fund_id`` column
        @param toForecast: if True, build_train_test also returns the scaled feature row of the last price date
        @param win1W: window length in rows for 1-week return/volatility
        @param win1M: window length in rows for 1-month return/volatility
        @param win1Q: window length in rows for 1-quarter return/volatility
        @param numForecastDays: forecast horizon in rows (business days; 21 business days is about one month)
        @param theThreshold: stored as ``self._theThreshold``; not read by this class's own methods
        """
        super().__init__()
        self._index = index
        self._eco = eco
        self._fund = fund
        self._indexDict = indexDict
        self._toForecast = toForecast
        self._win1W = win1W  # 1 week
        self._win1M = win1M  # 1 month
        self._win1Q = win1Q  # 1 quarter
        self._numForecastDays = numForecastDays  # business days, 21 business days means one month
        self._theThreshold = theThreshold
        # Names of finta TA methods used as technical-indicator features for indices.
        # Full candidate list: RSI, MACD, STOCH, ADL, ATR, MOM, MFI, ROC, OBV, CCI, EMV, VORTEX.
        # '14 period MFI' and '14 period EMV' are excluded because they are not available for forecasting.
        self.INDICATORS = ['RSI', 'MACD', 'STOCH', 'ADL', 'ATR', 'MOM', 'ROC', 'OBV', 'CCI', 'VORTEX']
        # Indicators applied to funds (which only have a close price); empty by default.
        self.FUND_INDICATORS = []

    def get_indicator_data(self, data, pid):
        """Add technical-indicator and moving-average features to ``data`` using the finta API.

        For indices (``pid in self._index``) the finta indicators in
        ``self.INDICATORS`` are joined on, plus normalised volume and
        open/high/low relative to the previous close; the raw
        open/high/low/volume columns are then dropped. For funds
        (``pid in self._fund``) only ``self.FUND_INDICATORS`` are applied.
        In every case close-to-EWM ratios and the one-row close ratio are added.

        @param data: price frame with a ``close`` column (plus open/high/low/volume for indices)
        @param pid: id of the index or fund
        @return: the feature frame; ``close`` is kept (the caller removes it later)
        """

        def indicator_calcu(frame, indicators):
            """Inner-join the output of each named finta ``TA`` indicator onto ``frame``.

            Indices and funds differ: funds only have a closing price, so fewer
            indicators can be generated for them.

            @param frame: price frame passed to every indicator
            @param indicators: names of ``TA`` methods to call
            @return: a new merged frame (``frame`` itself when ``indicators`` is empty)
            """
            for indicator in indicators:
                # getattr performs the same attribute lookup as the former eval(), without executing a built string.
                ind_data = getattr(TA, indicator)(frame)
                if not isinstance(ind_data, pd.DataFrame):
                    ind_data = ind_data.to_frame()
                frame = frame.merge(ind_data, left_index=True, right_index=True)
            return frame

        if pid in self._index:
            data = indicator_calcu(data, self.INDICATORS)
            # Instead of the raw volume (whose scale drifts over time), normalise it by a moving volume average.
            data['normVol'] = data['volume'] / data['volume'].ewm(5).mean()
            # Express open/high/low relative to the previous close.
            previous_close = data['close'].shift(1)
            data['relativeOpen'] = data['open'] / previous_close
            data['relativeHigh'] = data['high'] / previous_close
            data['relativeLow'] = data['low'] / previous_close
            # Remove columns that are not features; 'close' is still needed and is deleted later.
            data = data.drop(columns=['open', 'high', 'low', 'volume'])
        elif pid in self._fund:
            # indicator_calcu returns the merged frame rather than modifying in place, so keep its result.
            data = indicator_calcu(data, self.FUND_INDICATORS)
        # Close relative to exponential moving averages.
        # NOTE(review): ewm's first positional argument is ``com``, not ``span`` (ewm(50) == span 101).
        # Kept unchanged so feature values stay identical; confirm the intended smoothing.
        data['ema50'] = data['close'] / data['close'].ewm(50).mean()
        data['ema21'] = data['close'] / data['close'].ewm(21).mean()
        data['ema15'] = data['close'] / data['close'].ewm(15).mean()
        data['ema5'] = data['close'] / data['close'].ewm(5).mean()
        data['relativeClose'] = data['close'] / data['close'].shift(1)
        return data

    def build_predict_data(self, indexData, pid):
        """Build the per-date feature and label frame for one index or fund.

        @param indexData: raw data with a ``date`` column; index rows are keyed by ``rid_index_id``
            and use ``rid_*`` columns, fund rows are keyed by ``rfn_fund_id`` with ``rfn_nav_cal`` as price
        @param pid: id of the index or fund to forecast (must be in ``self._index`` or ``self._fund``)
        @return: frame with ``date`` as the first column, feature columns, ``futureR`` (forward log-return)
            and ``yLabel``; the ``close`` column is removed
        @raise ValueError: if ``pid`` is neither a known index nor a known fund
        """

        def map_to_label(ret):
            """Return the LABEL_RANGE key whose [lower, upper) interval contains ``ret``.

            Returns None when no interval matches, which includes NaN returns
            (every comparison with NaN is False); such rows are dropped later.
            """
            for label, (lower, upper) in LABEL_RANGE.items():
                if float(lower) <= ret < float(upper):
                    return label
            return None

        if pid in self._index:
            raw_id = self._indexDict[pid]
            predictData = indexData[indexData['rid_index_id'] == raw_id].copy()
            del predictData['rid_index_id']
            # finta expects a properly formatted OHLC DataFrame with lowercase column names:
            # ["open", "high", "low", "close"], plus ["volume"] for indicators that need OHLCV input.
            predictData = predictData.rename(
                columns={"rid_high": 'high', 'rid_open': 'open', "rid_low": 'low', "rid_close": 'close',
                         'rid_volume': 'volume',
                         "rid_pe": f"{raw_id}_pe", "rid_pb": f"{raw_id}_pb"})
        elif pid in self._fund:
            # NOTE(review): funds are also resolved through self._indexDict; confirm it holds fund ids too.
            predictData = indexData[indexData['rfn_fund_id'] == self._indexDict[pid]].copy()
            del predictData['rfn_fund_id']
            predictData = predictData.rename(columns={"rfn_nav_cal": 'close'})
        else:
            raise ValueError(f"pid {pid!r} is neither a known index nor a known fund")

        # Parse and sort by date. reset_index moves 'date' back to the first column,
        # which build_train_test relies on (it treats column 0 as the date).
        predictData = predictData.set_index('date')
        predictData.index = pd.to_datetime(predictData.index)
        predictData = predictData.sort_index().reset_index()

        predictData = self.get_indicator_data(predictData, pid)

        close = predictData['close']
        windows = (('1W', self._win1W), ('1M', self._win1M), ('1Q', self._win1Q))
        # Historical log-returns over each window (columns R1W, R1M, R1Q).
        for suffix, window in windows:
            predictData[f'R{suffix}'] = np.log(close / close.shift(window))
        # Rolling population standard deviation (ddof=0) of the close price (columns Vol_1W, Vol_1M, Vol_1Q).
        for suffix, window in windows:
            predictData[f'Vol_{suffix}'] = close.rolling(window).std(ddof=0)

        # The label uses future information (close numForecastDays rows ahead); build_train_test deletes
        # futureR and yLabel from the features.
        predictData['futureR'] = np.log(close.shift(-self._numForecastDays) / close)
        predictData['yLabel'] = predictData['futureR'].apply(map_to_label)
        del predictData['close']
        return predictData

    def build_train_test(self, pid, indexData, vixData, indexOtherData, cpiData, FDTRData, NAPMPMIData, TTM,
                         test_size=0.04):
        """Merge all inputs, build features/labels, scale them and split into train/test sets.

        @param pid: id of the index or fund to forecast
        @param indexData: raw price data, see build_predict_data
        @param vixData, indexOtherData, cpiData, FDTRData, NAPMPMIData, TTM: frames with a ``date``
            column that are outer-merged onto the price features; together they must provide the columns
            CPI_YOY, CPURNSA, CPI_MOM, CPI_MOM_Diff, FDTR, NAPMPMI and JIFU_SPX_OPEPS_CURRQ_TTM
        @param test_size: fraction of rows (chronologically last) used as the test split when not forecasting
        @return: (X_train, X_test, y_train, y_test, scaledX_forecast, forecastDay, date_index).
            When forecasting, X_test/y_test/date_index are empty lists, scaledX_forecast holds the scaled
            feature row(s) for forecastDay (the last price date), and all rows are used for training.
            Otherwise scaledX_forecast and forecastDay are None and date_index holds the test dates.
        """
        predictData = self.build_predict_data(indexData, pid)
        forecastDay = predictData['date'].iloc[-1] if self._toForecast else None

        # Outer-merge every source on 'date' so that no release date is lost before the eco columns are filled.
        DataAll = predictData
        for other in (vixData, indexOtherData, cpiData, FDTRData, NAPMPMIData, TTM):
            DataAll = pd.merge(DataAll, other, how='outer', on='date')
        DataAll = DataAll.set_index('date').sort_index().reset_index()

        # Forward-fill eco data *before* dropping non-price dates, so each value reaches later trading days.
        # Assign back rather than calling DataAll[col].ffill(inplace=True): that chained in-place call
        # does not update DataAll under pandas Copy-on-Write.
        for col in ('CPI_YOY', 'CPURNSA', 'CPI_MOM', 'CPI_MOM_Diff', 'FDTR', 'NAPMPMI', 'JIFU_SPX_OPEPS_CURRQ_TTM'):
            DataAll[col] = DataAll[col].ffill()
        eps = DataAll['JIFU_SPX_OPEPS_CURRQ_TTM']
        # EPS change relative to the value 252 rows earlier.
        # NOTE(review): rows here still include non-price dates from the outer merge, so 252 rows is not
        # exactly 252 trading days; confirm this is intended.
        DataAll['EPS_TTM_YOY'] = eps / eps.shift(252) - 1.0
        # Drop rows whose first non-date column (which comes from predictData) is NaN, e.g. dates present
        # only in the other sources. The index keeps gaps after this.
        DataAll = DataAll.dropna(subset=[DataAll.columns[1]])
        DataAll = DataAll.ffill()

        X_forecast = None
        if self._toForecast:
            # CPI_YOY: US urban CPI year-over-year, not seasonally adjusted;
            # CPURNSA: US CPI, not seasonally adjusted.
            DataAllCopy = DataAll.copy()
            for col in ('CPI_YOY', 'CPURNSA'):
                DataAllCopy[col] = DataAllCopy[col].ffill()
            for col in ('CPI_MOM', 'CPI_MOM_Diff'):
                DataAllCopy[col] = DataAllCopy[col].fillna(0)
            DataAllCopy = DataAllCopy.drop(columns=['futureR', 'yLabel'])
            # Select the forecast row with a boolean mask. The index has gaps after the dropna above, so the
            # former approach (index labels used as iloc positions) picked the wrong row or went out of bounds.
            forecastData = DataAllCopy.loc[DataAllCopy['date'] == forecastDay].iloc[:, 1:]
            # NOTE(review): dropping NaN columns can leave fewer columns than X, making transform() fail.
            forecastData = forecastData.dropna(axis=1)
            X_forecast = forecastData.to_numpy()
            del DataAllCopy

        # Clean NaN (including rows whose label is None / futureR is NaN).
        DataAll = DataAll.dropna().reset_index(drop=True)

        y = DataAll['yLabel'].to_numpy(copy=True)
        # Delete future information from the features; column 0 is 'date'.
        DataAll = DataAll.drop(columns=['futureR', 'yLabel'])
        X = DataAll.iloc[:, 1:].values

        # Scale features into the label range.
        # NOTE(review): assumes LABEL_RANGE keys are numeric and ordered from high to low; confirm.
        labels = list(LABEL_RANGE.keys())
        scaler = MinMaxScaler(feature_range=(labels[-1], labels[0])).fit(X)
        scaledX = scaler.transform(X)

        scaledX_forecast = None
        if self._toForecast:
            scaledX_forecast = scaler.transform(X_forecast)
            X_train, y_train = scaledX, y
            X_test, y_test, date_index = [], [], []
        else:
            # Chronological split (no shuffling).
            X_train, X_test, y_train, y_test = train_test_split(scaledX, y, test_size=test_size, shuffle=False)
            # The last training labels use closes up to numForecastDays rows ahead, so skip that many test rows
            # to avoid leakage.
            X_test = X_test[self._numForecastDays:]
            y_test = y_test[self._numForecastDays:]
            # Explicit positional slice: DataAll['date'][-len(X_test):] returned every date when X_test is
            # empty (because -0 == 0).
            date_index = DataAll['date'].iloc[len(DataAll) - len(X_test):].to_numpy()
        return X_train, X_test, y_train, y_test, scaledX_forecast, forecastDay, date_index