from py_jftech import component, autowired
from dateutil.relativedelta import relativedelta

from api import PortfoliosRisk, SignalType, Datum, PortfoliosHolder, DriftSolver
from rebalance.base_signal import BaseRebalanceSignal
from rebalance.dao import robo_rebalance_signal as rrs


@component(bean_name='curve-drift')
class CurveDrift(BaseRebalanceSignal):
    """Rebalance signal of type ``SignalType.DRIFT_BUY`` driven by weight drift.

    The signal fires when the combined weight of the high-risk datums in the
    builder's portfolio exceeds their combined weight in the held portfolio
    by at least the drift returned by the solver.

    NOTE(review): the other signal classes in this module are registered under
    the same bean name ``'curve-drift'`` -- confirm how the container resolves
    the duplicate registration.
    """

    @autowired(names={'solver': 'date-curve'})
    def __init__(self, datum: Datum = None, hold: PortfoliosHolder = None, solver: DriftSolver = None):
        """Store the collaborators used by :meth:`is_trigger`.

        The arguments default to ``None``; presumably the ``autowired``
        decorator supplies them (the solver under the name ``'date-curve'``)
        -- confirm against the container documentation.

        :param datum: source of the high-risk datum list.
        :param hold: source of the currently held portfolio weights.
        :param solver: provider of the drift threshold.
        """
        super().__init__()
        self._datum = datum
        self._hold = hold
        self._solver = solver

    @property
    def exclude_last_type(self):
        """Signal types that suppress this signal when they were the last one.

        :return: list of ``SignalType`` members; if the last effective signal
            has one of these types, :meth:`is_trigger` returns ``False``.
        """
        return [
            SignalType.CRISIS_ONE,
            SignalType.CRISIS_TWO,
            SignalType.MARKET_RIGHT,
            SignalType.INIT
        ]

    def is_trigger(self, day, risk: PortfoliosRisk) -> bool:
        """Decide whether the drift signal fires for ``day`` and ``risk``.

        :param day: evaluation date, also used as the upper bound when looking
            up the last effective signal.
        :param risk: risk level being evaluated.
        :return: ``False`` when there is no previous effective signal or its
            type is listed in :attr:`exclude_last_type`; otherwise whether the
            high-risk weight shortfall reaches the solver's drift.
        """
        last_re = rrs.get_last_one(max_date=day, risk=risk, effective=True)
        if last_re is None or SignalType(last_re['type']) in self.exclude_last_type:
            return False

        # A set gives constant-time membership tests in the sums below.
        high_risk_ids = {datum['id'] for datum in self._datum.get_high_risk_datums(risk)}

        # ``_builder`` is not assigned in this class; presumably it is provided
        # by ``BaseRebalanceSignal`` -- TODO confirm.
        target_portfolio = self._builder.get_portfolios(day, risk)
        target_weight = round(
            sum(weight for datum_id, weight in target_portfolio.items() if datum_id in high_risk_ids), 2)

        held_portfolio = self._hold.get_portfolios_weight(day, risk)
        held_weight = round(
            sum(weight for datum_id, weight in held_portfolio.items() if datum_id in high_risk_ids), 2)

        # Both weights are compared at two-decimal precision, so the difference
        # has to be rounded as well: subtracting two rounded floats reintroduces
        # representation error (0.3 - 0.2 == 0.09999999999999998), which would
        # otherwise miss a drift that sits exactly on the threshold.
        shortfall = round(target_weight - held_weight, 2)
        return shortfall >= self._solver.get_drift(day, risk)

    @property
    def signal_type(self) -> SignalType:
        """Type of signal produced by this class: ``SignalType.DRIFT_BUY``."""
        return SignalType.DRIFT_BUY


@component(bean_name='curve-drift')
class Max120tCurveDrift(CurveDrift):
    """Curve-drift signal that also fires 120 days after the last signal.

    NOTE(review): ``CurveDrift`` and ``AbsCurveDrift`` in this module are
    registered under the same bean name ``'curve-drift'`` -- confirm how the
    container resolves the duplicate registration.
    """

    def is_trigger(self, day, risk: PortfoliosRisk) -> bool:
        """Decide whether the signal fires for ``day`` and ``risk``.

        :param day: evaluation date, also used as the upper bound when looking
            up the last effective signal.
        :param risk: risk level being evaluated.
        :return: ``False`` when there is no previous effective signal or its
            type is listed in ``exclude_last_type``; ``True`` once the last
            signal's date plus 120 days is on or before ``day``; otherwise the
            result of ``CurveDrift.is_trigger``.
        """
        previous = rrs.get_last_one(max_date=day, risk=risk, effective=True)
        if previous is None:
            return False
        if SignalType(previous['type']) in self.exclude_last_type:
            return False

        # Time-based trigger: 120 days after the previous effective signal.
        deadline = previous['date'] + relativedelta(days=120)
        if deadline <= day:
            return True

        # Not yet due by time, so defer to the weight-drift rule of the parent
        # (which repeats the last-signal lookup itself).
        return super().is_trigger(day, risk)


@component(bean_name='curve-drift')
class AbsCurveDrift(CurveDrift):
    """Curve-drift signal that reacts to drift in either direction.

    Unlike ``CurveDrift``, which only compares the signed shortfall of the
    held high-risk weight, this class compares the absolute difference between
    the builder's and the held high-risk weights with the solver's drift.

    NOTE(review): ``CurveDrift`` and ``Max120tCurveDrift`` in this module are
    registered under the same bean name ``'curve-drift'`` -- confirm how the
    container resolves the duplicate registration.
    """

    def is_trigger(self, day, risk: PortfoliosRisk) -> bool:
        """Decide whether the drift signal fires for ``day`` and ``risk``.

        :param day: evaluation date, also used as the upper bound when looking
            up the last effective signal.
        :param risk: risk level being evaluated.
        :return: ``False`` when there is no previous effective signal or its
            type is listed in ``exclude_last_type``; otherwise whether the
            absolute high-risk weight difference reaches the solver's drift.
        """
        last_re = rrs.get_last_one(max_date=day, risk=risk, effective=True)
        if last_re is None or SignalType(last_re['type']) in self.exclude_last_type:
            return False

        # A set gives constant-time membership tests in the sums below.
        high_risk_ids = {datum['id'] for datum in self._datum.get_high_risk_datums(risk)}

        # ``_builder`` is not assigned in this class or in ``CurveDrift``;
        # presumably it is provided by ``BaseRebalanceSignal`` -- TODO confirm.
        target_portfolio = self._builder.get_portfolios(day, risk)
        target_weight = round(
            sum(weight for datum_id, weight in target_portfolio.items() if datum_id in high_risk_ids), 2)

        held_portfolio = self._hold.get_portfolios_weight(day, risk)
        held_weight = round(
            sum(weight for datum_id, weight in held_portfolio.items() if datum_id in high_risk_ids), 2)

        # Both weights are compared at two-decimal precision, so the difference
        # has to be rounded as well: subtracting two rounded floats reintroduces
        # representation error (0.3 - 0.2 == 0.09999999999999998), which would
        # otherwise miss a drift that sits exactly on the threshold.
        deviation = abs(round(target_weight - held_weight, 2))
        return deviation >= self._solver.get_drift(day, risk)