import pandas as pd
from py_jftech import component, autowired, get_config
from scipy.stats import norm

from api import SignalType, PortfoliosRisk, Navs
from rebalance.base_signal import BaseRebalanceSignal
from rebalance.dao import robo_rebalance_signal as rrs


@component(bean_name='market-right')
class MarketRight(BaseRebalanceSignal):
    """Rebalance signal of type ``SignalType.MARKET_RIGHT`` driven by SPX returns.

    The latest SPX ``rtn_days``-period return is compared against a parametric (normal)
    CVaR computed over SPX returns from a start date chosen by :meth:`find_cvar_start_date`.

    Configuration keys (read via ``get_config(__name__)``):
        rtn-days: number of rows (index observations) the SPX return looks back over.
        min-threshold: the latest return must not exceed this value for the signal to fire.
        coef: the CVaR tail probability used is ``round(1 - coef, 2)``.
        cvar-min-volume: minimum number of return observations required to compute CVaR.
    """

    @autowired
    def __init__(self, navs: Navs = None):
        super().__init__()
        self._navs = navs
        self._config = get_config(__name__)

    @property
    def rtn_days(self):
        """Look-back, in rows, of the SPX return (config ``rtn-days``)."""
        return self._config['rtn-days']

    @property
    def min_threshold(self):
        """Upper bound on the latest SPX return for the signal to fire (config ``min-threshold``)."""
        return self._config['min-threshold']

    @property
    def coef(self):
        """CVaR coefficient; the tail probability is ``1 - coef`` (config ``coef``)."""
        return self._config['coef']

    @property
    def signal_type(self) -> SignalType:
        return SignalType.MARKET_RIGHT

    @property
    def cvar_min_volume(self):
        """Minimum number of observations needed to compute CVaR (config ``cvar-min-volume``)."""
        return self._config['cvar-min-volume']

    def is_trigger(self, day, risk: PortfoliosRisk) -> bool:
        """Return True when a MARKET_RIGHT rebalance should fire on ``day`` for ``risk``.

        The signal is suppressed when the last effective signal on or before ``day`` is a
        CRISIS_ONE, CRISIS_TWO, MARKET_RIGHT or INIT signal. Otherwise it fires when the
        latest SPX return is at or below ``min_threshold`` and strictly below the CVaR
        returned by :meth:`get_cvar`.
        """
        last_re = rrs.get_last_one(risk=risk, max_date=day, effective=True)
        if last_re is not None and SignalType(last_re['type']) in (
                SignalType.CRISIS_ONE, SignalType.CRISIS_TWO, SignalType.MARKET_RIGHT, SignalType.INIT):
            return False
        spx = self.load_spx_close_rtns(day)
        # No SPX return available (no index data, or too few closes for rtn_days): nothing to compare.
        if not spx:
            return False
        latest_rtn = spx[-1]['rtn']
        if latest_rtn > self.min_threshold:
            return False
        cvar = self.get_cvar(day, risk, spx=spx)
        return cvar is not None and latest_rtn < cvar

    def is_fall(self, day, risk: PortfoliosRisk, spx=None):
        """Return True if the SPX close on the last record up to ``day`` is below the close on
        the CVaR start date; False when no start date can be determined.

        ``spx`` may be passed in (records from :meth:`load_spx_close_rtns`) to avoid reloading.
        """
        if spx is None:
            spx = self.load_spx_close_rtns(day)
        start_date = self.find_cvar_start_date(day, risk, spx=spx)
        if start_date is None:
            return False
        window = self._slice_spx(spx, start_date, day)
        return not window.empty and window.iloc[-1]['close'] < window.iloc[0]['close']

    def get_cvar(self, day, risk: PortfoliosRisk, spx=None):
        """Return the parametric (normal) CVaR of SPX returns from the CVaR start date to ``day``.

        Uses the closed-form left-tail expected shortfall of a normal distribution,
        ``mean - std * pdf(ppf(alpha)) / alpha`` with ``alpha = round(1 - coef, 2)``.

        Returns None when no start date can be determined or fewer than
        ``cvar_min_volume`` observations fall in the window.
        """
        if spx is None:
            spx = self.load_spx_close_rtns(day)
        start_date = self.find_cvar_start_date(day, risk, spx=spx)
        if start_date is None:
            return None
        window = self._slice_spx(spx, start_date, day)
        if len(window) < self.cvar_min_volume:
            return None
        alpha = round(1 - self.coef, 2)
        rtn = window['rtn']
        return rtn.mean() - rtn.std() * norm.pdf(norm.ppf(alpha)) / alpha

    def find_cvar_start_date(self, day, risk: PortfoliosRisk, spx=None):
        """Return the date of the lowest SPX close between the last right signal and the next buy.

        The "right" signal is the latest effective MARKET_RIGHT or INIT signal for ``risk`` on or
        before ``day``. The buy is the first effective LOW_BUY/HIGH_BUY signal returned by
        ``rrs.get_first_after`` with ``min_date`` set to the right signal's date.

        Returns None when either signal is missing, the buy is not strictly later than the
        right signal, or at most two SPX records fall between the two dates.
        """
        last_right = rrs.get_last_one(type=(SignalType.MARKET_RIGHT, SignalType.INIT), max_date=day, risk=risk,
                                      effective=True)
        # Guard before dereferencing last_right['date'] in the buy lookup below.
        if not last_right:
            return None
        last_buy = rrs.get_first_after(type=(SignalType.LOW_BUY, SignalType.HIGH_BUY), risk=risk, effective=True,
                                       min_date=last_right['date'])
        if not last_buy or last_buy['date'] <= last_right['date']:
            return None
        if spx is None:
            spx = self.load_spx_close_rtns(day)
        window = self._slice_spx(spx, last_right['date'], last_buy['date'])
        if len(window) <= 2:
            return None
        return window.loc[window['close'].idxmin(), 'date']

    @staticmethod
    def _slice_spx(spx, start_date, end_date):
        """Return the SPX records dated within [start_date, end_date] (inclusive) as a DataFrame."""
        frame = pd.DataFrame(spx)
        # An empty record list yields a frame without a 'date' column; nothing to filter.
        if frame.empty:
            return frame
        return frame[(frame['date'] >= start_date) & (frame['date'] <= end_date)]

    def load_spx_close_rtns(self, day):
        """Load SPX closes up to ``day`` together with their ``rtn_days``-row simple return.

        Returns a date-sorted list of ``{'date', 'close', 'rtn'}`` records. Rows whose return
        cannot be computed (the first ``rtn_days`` rows, or NaN closes) are dropped. Returns an
        empty list when no index data is available.
        """
        spx = pd.DataFrame(self._navs.get_index_close(ticker='SPX Index', max_date=day))
        if spx.empty:
            return []
        spx = spx.sort_values('date')[['date', 'close']]
        spx = spx.assign(rtn=spx['close'] / spx['close'].shift(self.rtn_days) - 1)
        # Drop only on the columns used here, so NaNs in unrelated columns of the source
        # data cannot discard otherwise valid rows.
        return spx.dropna(subset=['close', 'rtn']).to_dict('records')