from py_jftech import component, autowired, get_config, workday_range, next_workday

from api import DriftSolver, PortfoliosRisk, PortfoliosBuilder, Datum, RoboExecutor
from rebalance.dao import robo_rebalance_signal as rrs, robo_weight_drift as rwd


@component(bean_name='date-curve')
class DateCurve(DriftSolver):
    """Drift solver whose value shrinks with the time elapsed since the last rebalance.

    Parameters come from the ``date-curve`` section of this module's configuration.
    """

    def __init__(self):
        module_config = get_config(__name__)
        self._config = module_config['date-curve']

    @property
    def diff_threshold(self):
        """Starting value of the curve (config key ``diff-threshold``)."""
        return self._config['diff-threshold']

    @property
    def init_factor(self):
        """Coefficient applied to the 4th power of elapsed days (config key ``init-factor``)."""
        return self._config['init-factor']

    def get_drift(self, day, risk: PortfoliosRisk):
        """Return ``diff_threshold - init_factor * elapsed_days ** 4``, clamped at 0.

        ``elapsed_days`` is the whole-day gap between ``day`` and the ``date`` of the
        effective rebalance signal returned by ``rrs.get_last_one(max_date=day, ...)``.
        NOTE(review): if that lookup returns ``None`` the subscript below raises; confirm
        a prior effective signal always exists.
        """
        previous_signal = rrs.get_last_one(max_date=day, risk=risk, effective=True)
        threshold = self.diff_threshold
        factor = self.init_factor
        elapsed_days = (day - previous_signal['date']).days
        remaining = threshold - factor * elapsed_days ** 4
        # Never report a negative drift.
        return remaining if remaining > 0 else 0


@component(bean_name='high-weight')
class PortfolioHighWeight(DriftSolver):
    """Drift solver driven by the portfolio weight held in high-risk datums.

    For each workday it stores, via ``robo_weight_drift``, the total portfolio weight
    allocated to the ids returned by ``Datum.get_high_risk_datums`` and a smoothed
    "drift" of that weight (``weight * coef + (1 - coef) * previous_drift``).
    """

    @autowired(names={'executor': RoboExecutor.use_name()})
    def __init__(self, builder: PortfoliosBuilder = None, datum: Datum = None, executor: RoboExecutor = None):
        self._builder = builder
        self._datum = datum
        self._executor = executor
        self._config = get_config(__name__)['high-weight']

    @property
    def drift_coef(self):
        """Smoothing coefficient for the drift (config key ``coef``)."""
        return self._config['coef']

    def get_drift(self, day, risk: PortfoliosRisk):
        """Return the stored drift for ``risk`` at ``day``, backfilling missing days first.

        If ``rwd.get_one(day, risk)`` finds nothing, workday records are computed and
        inserted from the day after the last stored record (or from the executor's
        ``start_date`` when none exists) through ``day``. The latest record up to ``day``
        is then read back.

        NOTE(review): if no record exists up to ``day`` even after backfilling (e.g. the
        workday range was empty), ``drift`` is ``None`` and the subscript raises.
        """
        drift = rwd.get_one(day, risk)
        if not drift:
            self._backfill(day, risk)
            # Use the same keyword arguments as the lookup in ``_backfill`` so the
            # call does not depend on ``get_last_one``'s positional parameter order.
            drift = rwd.get_last_one(max_date=day, risk=risk)
        return drift['drift']

    def _backfill(self, day, risk):
        """Compute and insert weight/drift rows for each missing workday up to ``day``."""
        # Set built once: membership is tested for every portfolio item on every day.
        high_risk_ids = {x['id'] for x in self._datum.get_high_risk_datums(risk)}
        last_one = rwd.get_last_one(max_date=day, risk=risk)
        start = next_workday(last_one['date']) if last_one else self._executor.start_date
        last_drift = last_one['drift'] if last_one else 0
        for date in workday_range(start, day):
            portfolio = self._builder.get_portfolios(date, risk)
            weight = round(sum(w for asset_id, w in portfolio.items() if asset_id in high_risk_ids), 2)
            # Seed with the raw weight when there is no previous drift, otherwise smooth.
            # NOTE(review): this is a truthiness test, so a previous drift of exactly 0
            # also re-seeds instead of smoothing -- confirm that is intended.
            last_drift = round((weight * self.drift_coef + (1 - self.drift_coef) * last_drift) if last_drift else weight, 2)
            rwd.insert({
                'date': date,
                'risk': risk,
                'weight': weight,
                'drift': last_drift,
            })