from framework import component, autowired, get_config, workday_range, next_workday
from api import RebalanceRuler, PortfoliosRisk, RebalanceSignal, SignalType, PortfoliosType, PortfoliosHolder
from typing import List
from rebalance.dao import robo_rebalance_signal as rrs


@component
class LevelRebalanceRuler(RebalanceRuler):
    '''
    定义:
    1.定义所有调仓类型为非NORMAL类型的信号为清仓信号
    2.定义所有调仓类型为NORMAL类型的信号为加仓信号
    3.定义持久信号为上次选用调仓的信号时间到当前时间内,该信号都有效
    4.定义临时信号为仅当天有效
    规则:
    1.所有清仓信号为持久信号,所有加仓信号为临时信号
    2.对于持久信号规则如下:
        2.1 上一次选用信号到当前时间内,是否有持久信号
        2.2 如果有,则看级别是否高于上一次选用信号
        2.3 如果高于,则输出该信号
    3.如果没有持久信号,则从临时信号中根据级别排序找出第一个,作为输出信号
    '''

    @autowired
    def __init__(self, signals: List[RebalanceSignal] = None, hold: PortfoliosHolder = None):
        self._signals = signals
        self._hold = hold
        self._config = get_config(__name__)

    @property
    def disable_period(self):
        result = self._config['disable-period']
        return {PortfoliosType(x[0]): x[1] for x in result.items()}

    def take_next_signal(self, day, risk: PortfoliosRisk):
        last_re = rrs.get_last_one(max_date=day, risk=risk, effective=True)
        if not last_re:
            builder = [x for x in self._signals if x.signal_type is SignalType.INIT][0]
            return builder.get_signal(day, risk)

        long_signals = [x for x in self._signals if x.signal_type.p_type is not PortfoliosType.NORMAL and x.signal_type.level < SignalType(last_re['type']).level]
        for long_signal in sorted(long_signals, key=lambda x: x.signal_type.level):
            workdays = workday_range(next_workday(last_re['date']), day)
            if len(workdays) <= self._hold.interval_days:
                for date in workdays:
                    signal = long_signal.get_signal(date, risk)
                    if signal:
                        return signal
            else:
                signal = long_signal.get_signal(day, risk)
                if signal:
                    return signal
        if SignalType(last_re['type']).p_type in self.disable_period:
            re_date = self._hold.get_last_rebalance_date(risk=risk, max_date=day)
            if re_date:
                workdays = workday_range(re_date, day)
                if len(workdays) < self.disable_period[SignalType(last_re['type']).p_type]:
                    return None
        for temp_signal in sorted([x for x in self._signals if x.signal_type.p_type is PortfoliosType.NORMAL], key=lambda x: x.signal_type.level):
            signal = temp_signal.get_signal(day, risk)
            if signal:
                return signal
        return None

    def get_signal_type(self, sign_id) -> SignalType:
        signal = rrs.get_by_id(sign_id)
        return SignalType(signal['type']) if signal else None

    def commit_signal(self, sign_id):
        rrs.update(sign_id, {'effective': True})