from py_jftech import component, autowired, get_config

from api import SignalType, PortfoliosRisk, DriftSolver
from rebalance.base_signal import BaseRebalanceSignal
from rebalance.dao import robo_rebalance_signal as rrs


@component(bean_name='high-buy')
class HighBuySignal(BaseRebalanceSignal):
    """Rebalance signal of type ``SignalType.HIGH_BUY``.

    Registered under the bean name ``high-buy``. The signal fires when the
    drift reported by the injected solver is greater than or equal to the
    element at index 1 of the configured threshold, provided the most recent
    effective signal is of one of the types in ``include_last_type``.
    """

    @autowired(names={'solver': 'high-weight'})
    def __init__(self, solver: DriftSolver = None):
        """Store the injected drift solver and load this module's config.

        :param solver: drift solver, autowired by the name ``high-weight``.
        """
        super().__init__()
        self._config = get_config(__name__)
        self._solver = solver

    @property
    def include_last_type(self):
        """Signal types the previous effective signal must belong to."""
        allowed = [SignalType.CRISIS_ONE, SignalType.CRISIS_TWO]
        allowed.append(SignalType.MARKET_RIGHT)
        allowed.append(SignalType.LOW_BUY)
        allowed.append(SignalType.INIT)
        return allowed

    @property
    def signal_type(self) -> SignalType:
        """Type of signal produced by this class."""
        return SignalType.HIGH_BUY

    def get_threshold(self, risk: PortfoliosRisk):
        """Return the configured threshold for the given risk level.

        The ``threshold`` config entry is returned unchanged unless it is a
        dict, in which case the value stored under ``ft<risk.value>`` is
        returned instead.

        :param risk: portfolio risk level whose ``value`` selects the entry.
        :return: the threshold configured for ``risk``.
        """
        configured = self._config['threshold']
        if not isinstance(configured, dict):
            return configured
        return configured[f'ft{risk.value}']

    def is_trigger(self, day, risk: PortfoliosRisk) -> bool:
        """Tell whether the signal fires on ``day`` for ``risk``.

        :param day: date used as the upper bound when looking up the last
            effective signal and passed to the solver for the drift.
        :param risk: portfolio risk level being evaluated.
        :return: ``False`` when there is no previous effective signal or its
            type is not in ``include_last_type``; otherwise whether the drift
            is at least the threshold element at index 1.
        """
        previous = rrs.get_last_one(max_date=day, risk=risk, effective=True)
        if previous is None:
            return False
        previous_type = SignalType(previous['type'])
        if previous_type not in self.include_last_type:
            return False
        # The drift is fetched before the threshold, as in the original flow.
        current_drift = self._solver.get_drift(day, risk)
        bounds = self.get_threshold(risk)
        return current_drift >= bounds[1]


@component(bean_name='low-buy')
class LowBuySignal(HighBuySignal):
    """Rebalance signal of type ``SignalType.LOW_BUY``.

    Registered under the bean name ``low-buy``. The signal fires when the
    solver's drift is at least the threshold element at index 0 and strictly
    below the element at index 1, provided the most recent effective signal
    is of one of the types in ``include_last_type``.
    """

    @property
    def include_last_type(self):
        """Signal types the previous effective signal must belong to."""
        allowed = [SignalType.CRISIS_ONE, SignalType.CRISIS_TWO]
        allowed.append(SignalType.MARKET_RIGHT)
        allowed.append(SignalType.INIT)
        return allowed

    @property
    def signal_type(self) -> SignalType:
        """Type of signal produced by this class."""
        return SignalType.LOW_BUY

    def is_trigger(self, day, risk: PortfoliosRisk) -> bool:
        """Tell whether the signal fires on ``day`` for ``risk``.

        :param day: date used as the upper bound when looking up the last
            effective signal and passed to the solver for the drift.
        :param risk: portfolio risk level being evaluated.
        :return: ``False`` when there is no previous effective signal or its
            type is not in ``include_last_type``; otherwise whether the drift
            lies in the half-open range from threshold element 0 (inclusive)
            to threshold element 1 (exclusive).
        """
        previous = rrs.get_last_one(max_date=day, risk=risk, effective=True)
        if previous is None:
            return False
        previous_type = SignalType(previous['type'])
        if previous_type not in self.include_last_type:
            return False
        # The drift is fetched before the threshold, as in the original flow.
        current_drift = self._solver.get_drift(day, risk)
        bounds = self.get_threshold(risk)
        # Chained comparison: the upper bound is only read when the lower
        # bound check passes.
        return bounds[0] <= current_drift < bounds[1]