from api import PortfoliosRisk, SignalType, Datum, PortfoliosHolder, DriftSolver
from framework import component, autowired, get_config
from rebalance.base_signal import BaseRebalanceSignal
from rebalance.dao import robo_rebalance_signal as rrs


@component(bean_name='curve-drift')
class CurveDrift(BaseRebalanceSignal):
    """Rebalance signal driven by drift in the high-risk part of a portfolio.

    The signal compares the summed weight of the high-risk datums in the
    portfolio returned by ``self._builder`` with the same sum taken over the
    currently held portfolio, and fires when the difference reaches the drift
    value supplied by the injected solver.

    NOTE(review): ``self._builder`` is not assigned in this class; presumably
    it is set up by ``BaseRebalanceSignal.__init__`` -- confirm.
    """

    @autowired(names={'solver': 'date-curve'})
    def __init__(self, datum: Datum = None, hold: PortfoliosHolder = None, solver: DriftSolver = None):
        """Store the injected collaborators and load this module's config.

        Args:
            datum: Source of the high-risk datum records.
            hold: Provider of the currently held portfolio weights.
            solver: Drift solver; wired to the bean named ``date-curve``.
        """
        super().__init__()
        self._datum = datum
        self._hold = hold
        self._solver = solver
        self._config = get_config(__name__)

    @property
    def signal_type(self) -> SignalType:
        """Signal type reported by this component (``DRIFT_BUY``)."""
        return SignalType.DRIFT_BUY

    @property
    def exclude_last_type(self):
        """Signal types that, when they were the last signal, block a trigger.

        A fresh list is returned on every access.
        """
        blocked = [
            SignalType.CRISIS_ONE,
            SignalType.CRISIS_TWO,
            SignalType.MARKET_RIGHT,
            SignalType.INIT,
            SignalType.LOW_BUY,
        ]
        return blocked

    @property
    def diff_threshold(self):
        """Value of the ``diff-threshold`` entry in the module config."""
        return self._config['diff-threshold']

    @property
    def init_factor(self):
        """Value of the ``init-factor`` entry in the module config."""
        return self._config['init-factor']

    def is_trigger(self, day, risk: PortfoliosRisk) -> bool:
        """Decide whether the drift signal fires for ``risk`` on ``day``.

        Args:
            day: Date being evaluated; also the upper bound used when looking
                up the last effective signal.
            risk: Risk level of the portfolio under consideration.

        Returns:
            ``False`` when there is no previous effective signal or when its
            type is listed in ``exclude_last_type``. Otherwise, whether the
            high-risk weight of the built portfolio minus that of the held
            portfolio (each rounded to 2 decimals) is at least the solver's
            drift for ``day`` and ``risk``.
        """
        previous = rrs.get_last_one(max_date=day, risk=risk, effective=True)
        if previous is None:
            return False
        if SignalType(previous['type']) in self.exclude_last_type:
            return False

        high_risk_ids = [item['id'] for item in self._datum.get_high_risk_datums(risk)]

        def high_risk_share(weights):
            # Sum only the entries keyed by a high-risk datum id, rounded to 2 dp.
            total = sum(weight for key, weight in weights.items() if key in high_risk_ids)
            return round(total, 2)

        target_share = high_risk_share(self._builder.get_portfolios(day, risk))
        held_share = high_risk_share(self._hold.get_portfolios_weight(day, risk))
        gap = target_share - held_share
        # TODO: the left-hand side should use the absolute value.
        return gap >= self._solver.get_drift(day, risk)