import json
import logging
from datetime import datetime as dt
from typing import List

import pandas as pd
from py_jftech import (
    component, autowired, get_config, next_workday, prev_workday, transaction, workday_range, format_date
)

from api import PortfoliosHolder, PortfoliosRisk, RebalanceRuler, Navs, SignalType, RoboExecutor, PortfoliosType, RoboReportor
from portfolios.dao import robo_hold_portfolios as rhp
from portfolios.utils import format_weight

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


@component(bean_name='next-re')
class NextReblanceHolder(PortfoliosHolder):
    """Portfolio holder that builds one hold record per workday from rebalance signals.

    On each workday it either rebalances into the portfolio carried by the next signal
    from the `RebalanceRuler`, or carries the previous shares forward and revalues them.
    Once a record exists, a new signal is only requested when the workday range from the
    last rebalance date to the current day is longer than `interval_days`.
    """

    @autowired(names={'executor': RoboExecutor.use_name()})
    def __init__(self, rule: RebalanceRuler, navs: Navs = None, executor: RoboExecutor = None):
        self._rule = rule
        self._navs = navs
        self._executor = executor
        self._config = get_config(__name__)

    def get_portfolio_type(self, day, risk: PortfoliosRisk) -> PortfoliosType:
        """Return the portfolio type of the hold record for `risk` on `day`.

        The type comes from the signal the record was built from. `PortfoliosType.NORMAL`
        is returned when there is no record or the signal type cannot be resolved.
        """
        hold = rhp.get_one(day, risk)
        if hold:
            signal_type = self._rule.get_signal_type(hold['signal_id'])
            return signal_type.p_type if signal_type else PortfoliosType.NORMAL
        return PortfoliosType.NORMAL

    def get_last_rebalance_date(self, risk: PortfoliosRisk, max_date=None):
        """Return the date of the latest rebalance record for `risk` up to `max_date`, or None.

        :raises AssertionError: if `risk` is falsy.
        """
        assert risk, f"get last rebalance date, risk can not be none"
        last = rhp.get_last_one(max_date=max_date, risk=risk, rebalance=True)
        return last['date'] if last else None

    def get_rebalance_date_by_signal(self, signal_id):
        """Return the date of the latest rebalance record built from `signal_id`, or None."""
        last = rhp.get_last_one(signal_id=signal_id, rebalance=True)
        return last['date'] if last else None

    def get_portfolios_weight(self, day, risk: PortfoliosRisk):
        """Return the `{fund_id: weight}` map stored for `risk` on `day`, or None if no record.

        Fund ids are stored as JSON object keys (strings) and converted back to `int`.
        """
        hold = rhp.get_one(day, risk)
        if hold:
            result = json.loads(hold['portfolios'])['weight']
            return {int(fund_id): weight for fund_id, weight in result.items()}
        return None

    def has_hold(self, risk: PortfoliosRisk) -> bool:
        """Return True if at least one hold record exists for `risk`."""
        return rhp.get_count(risk=risk) > 0

    def build_hold_portfolio(self, day, risk: PortfoliosRisk):
        """Build hold records for `risk` on every workday up to and including `day`.

        Building starts at `next_workday` of the latest existing record's date, or of the
        executor's start date when there is no record. Any failure is logged with its
        traceback and stops the build; it is not re-raised.
        """
        last_nav = rhp.get_last_one(max_date=day, risk=risk)
        start = next_workday(last_nav['date'] if last_nav else self._executor.start_date)
        try:
            while start <= day:
                logger.info(f"start to build hold portfolio[{risk.name}] for date[{format_date(start)}]")
                signal = self._take_signal(start, risk, last_nav)
                if signal and not signal['effective']:
                    logger.info(f"start to rebalance hold portfolio[{risk.name}] for date[{format_date(start)}] "
                                f"with signal[{SignalType(signal['type']).name}]")
                    self.do_rebalance(start, risk, signal, last_nav)
                elif last_nav and signal is None:
                    self.no_rebalance(start, risk, last_nav)
                start = next_workday(start)
                # Re-read so the next iteration sees the record just inserted (if any).
                last_nav = rhp.get_last_one(max_date=day, risk=risk)
        except Exception:
            # The message is already fully formatted. Passing extra positional args would
            # make logging's %-formatting fail and the message would be lost.
            logger.exception(f"build hold portfolio[{risk.name}] for date[{format_date(start)}] failure.")

    def _take_signal(self, day, risk: PortfoliosRisk, last_nav):
        """Return the next signal to consider on `day`, or None while the minimum interval has not elapsed."""
        if last_nav:
            # NOTE(review): assumes a rebalance record exists whenever `last_nav` does;
            # otherwise `last_re_date` is None here -- confirm workday_range's handling.
            last_re_date = self.get_last_rebalance_date(risk=risk, max_date=day)
            if len(workday_range(last_re_date, day)) <= self.interval_days:
                return None
        # The signal is requested with the previous workday's date.
        return self._rule.take_next_signal(prev_workday(day), risk)

    @transaction
    def do_rebalance(self, day, risk: PortfoliosRisk, signal, last_nav):
        """Rebalance `risk` into the portfolio of `signal` on `day`, then commit the signal.

        Target weights come from the JSON in `signal['portfolio']`. With a previous record,
        the NAV is the previous shares valued at `day`'s fund NAVs (rounded to 4 places).
        Otherwise it is the configured `init_nav`. New shares are sized so that each fund's
        value equals `nav * weight`.
        """
        weight = {int(fund_id): w for fund_id, w in json.loads(signal['portfolio']).items()}
        if last_nav:
            share = {int(fund_id): s for fund_id, s in json.loads(last_nav['portfolios'])['share'].items()}
            # One lookup covering both the funds currently held and the target funds.
            navs = self.get_navs(fund_ids=tuple(set(weight) | set(share)), day=day)
            nav = round(sum(navs[fund_id] * s for fund_id, s in share.items()), 4)
        else:
            nav = self.init_nav
            navs = self.get_navs(fund_ids=tuple(weight), day=day)
        share = {fund_id: nav * w / navs[fund_id] for fund_id, w in weight.items()}
        rhp.insert({
            'date': day,
            'risk': risk,
            'signal_id': signal['id'],
            'rebalance': True,
            'portfolios': {
                'weight': weight,
                'share': share,
            },
            'nav': nav,
        })
        self._rule.commit_signal(signal['id'])

    def no_rebalance(self, day, risk: PortfoliosRisk, last_nav):
        """Carry the previous shares forward to `day` and record the revalued portfolio.

        The NAV is the shares valued at `day`'s fund NAVs (rounded to 4 places). Weights are
        each fund's value share of the NAV, rounded to 2 places and passed through
        `format_weight` (presumably to fix rounding drift -- verify). The previous record's
        signal id is kept.
        """
        share = {int(fund_id): s for fund_id, s in json.loads(last_nav['portfolios'])['share'].items()}
        navs = self.get_navs(fund_ids=tuple(share), day=day)
        nav = round(sum(navs[fund_id] * s for fund_id, s in share.items()), 4)
        weight = {fund_id: round(s * navs[fund_id] / nav, 2) for fund_id, s in share.items()}
        weight = format_weight(weight)
        rhp.insert({
            'date': day,
            'risk': risk,
            'signal_id': last_nav['signal_id'],
            'rebalance': False,
            'portfolios': {
                'weight': weight,
                'share': share,
            },
            'nav': nav,
        })

    def get_navs(self, day, fund_ids):
        """Return `{fund_id: nav_cal}` for `fund_ids` at the latest `nav_date` fetched up to `day`.

        Gaps are forward-filled, so a fund with no value on the latest date reports its most
        recent earlier value.
        """
        navs = pd.DataFrame(self._navs.get_fund_navs(fund_ids=fund_ids, max_date=day))
        navs = navs.pivot_table(index='nav_date', columns='fund_id', values='nav_cal')
        # `fillna(method='ffill')` is deprecated (pandas 2.1) and removed in 3.0; `ffill()` is equivalent.
        navs = navs.ffill()
        return dict(navs.iloc[-1])

    def clear(self, day=None, risk: PortfoliosRisk = None):
        """Delete hold records from `day` onwards (per `rhp.delete(min_date=...)`), optionally for one `risk`."""
        rhp.delete(min_date=day, risk=risk)

    @property
    def interval_days(self):
        """Minimum rebalance interval in workdays (config key `min-interval-days`)."""
        return self._config['min-interval-days']

    @property
    def init_nav(self):
        """NAV assigned to the first rebalance of a portfolio (config key `init-nav`)."""
        return self._config['init-nav']


@component(bean_name='hold-report')
class HoldReportor(RoboReportor):
    """Report of hold-portfolio NAVs with the risk level and signal type of each record."""

    # Sentinel default for `load_report(max_date=...)`, resolved to the current time on
    # each call. A sentinel (rather than None) keeps an explicit `max_date=None`
    # forwarded unchanged to the DAO.
    _TODAY = object()

    @autowired
    def __init__(self, rule: RebalanceRuler = None):
        self._rule = rule

    @property
    def report_name(self) -> str:
        """Display name of this report ('投组净值', i.e. portfolio NAV)."""
        return '投组净值'

    def load_report(self, max_date=_TODAY, min_date=None) -> List[dict]:
        """Return hold records between `min_date` and `max_date` as report rows.

        Each row has the keys `risk` (risk enum name), `date`, `nav` and `signal_type`
        (signal type name). An empty list is returned when there are no records.

        :param max_date: upper date bound; defaults to `dt.today()` evaluated at call time.
        :param min_date: lower date bound, or None.
        """
        if max_date is self._TODAY:
            # A `dt.today()` default in the signature would be frozen at import time.
            max_date = dt.today()
        holds = pd.DataFrame(rhp.get_list(max_date=max_date, min_date=min_date))
        if holds.empty:
            return []
        signal_types = self._rule.get_signal_type(tuple(set(holds['signal_id'])))
        holds['signal_type'] = holds['signal_id'].map(lambda signal_id: signal_types[signal_id].name)
        holds['risk'] = holds['risk'].map(lambda risk: PortfoliosRisk(risk).name)
        return holds[['risk', 'date', 'nav', 'signal_type']].to_dict('records')