from datetime import datetime as dt, timedelta
from typing import List

import pandas as pd
from py_jftech import component, autowired, prev_workday, filter_weekend, next_workday, get_config

from api import RoboReportor, PortfoliosRisk, RoboExecutor, Navs, Datum, DatumType
from reports.dao import robo_benckmark as rb


@component(bean_name='benckmark-report')
class BenchmarkReportor(RoboReportor):
    """Build, persist and report a fixed-allocation benchmark NAV series per risk level.

    For every risk level configured under ``stock-rate`` the benchmark splits its value
    between STOCK funds (``stock_rate`` share, equal-weighted) and all other funds (the
    remainder, equal-weighted). The holdings then drift with daily fund returns and are
    re-balanced every ``REBALANCE_PERIOD`` NAV dates.
    """

    # Number of NAV dates between two re-balances of the benchmark holdings.
    REBALANCE_PERIOD = 5

    @autowired(names={'executor': RoboExecutor.use_name()})
    def __init__(self, executor: RoboExecutor = None, navs: Navs = None, datum: Datum = None):
        self._exec = executor
        self._navs = navs
        self._datum = datum
        self._config = get_config(__name__)

    @property
    def report_name(self) -> str:
        return 'benchmark'

    @property
    def risks(self):
        """Risk levels for which a benchmark is built (the keys of the ``stock-rate`` config)."""
        return self._config['stock-rate'].keys()

    @property
    def init_amount(self):
        """Starting NAV used when no re-balance record has been persisted yet."""
        return self._config['init-amount']

    def stock_rate(self, risk):
        """Target weight of the STOCK bucket for ``risk``; the other bucket gets ``1 - stock_rate``."""
        return self._config['stock-rate'][risk]

    def load_nav_rtn(self, risk, day):
        """Load forward-filled fund NAVs and their daily returns from the resume point up to ``day``.

        The resume point is the date of the last persisted re-balance record for ``risk``, or
        ``next_workday(executor.start_date)`` when there is none.

        :param risk: risk level key from the ``stock-rate`` config
        :param day: last date (inclusive) to load
        :return: ``(frame, fund_count)``. ``frame`` is indexed by ``nav_date``; its first
            ``fund_count`` columns are fund NAVs (one per fund id), followed by one
            ``rtn_<fund_id>`` return column per fund. ``frame`` is empty (and ``fund_count``
            is 0) when no NAVs are returned.
        """
        last = rb.get_last_one(risk=risk, max_date=day, re=True)
        start_date = last['date'] if last else next_workday(self._exec.start_date)

        # dict.fromkeys de-duplicates ids while keeping their order.
        fund_ids = tuple(dict.fromkeys(x['id'] for x in self._datum.get_datums(type=DatumType.FUND)))
        # Rows before start_date feed the forward-fill and the first return; they are dropped below.
        navs = pd.DataFrame(self._navs.get_fund_navs(fund_ids=fund_ids, min_date=prev_workday(start_date - timedelta(10)), max_date=day))
        if navs.empty:
            return navs, 0
        navs = navs.pivot_table(index='nav_date', columns='fund_id', values='nav_cal')
        # fillna(method='ffill') is deprecated since pandas 2.1; ffill() is the equivalent.
        navs = navs.ffill()
        fund_columns = list(navs.columns)
        for fund_id in fund_columns:
            navs[f'rtn_{fund_id}'] = navs[fund_id] / navs[fund_id].shift() - 1
        navs = navs[navs.index >= start_date]
        return navs, len(fund_columns)

    def find_datum_asset(self):
        """Map each FUND datum id to its ``assetType``."""
        return {x['id']: x['assetType'] for x in self._datum.get_datums(type=DatumType.FUND)}

    @staticmethod
    def _available_funds(row, fund_count):
        """Return the ids of funds that have a (forward-filled) NAV in ``row``."""
        return list(row.iloc[:fund_count].dropna().index)

    @staticmethod
    def _allocate(amount, fund_ids, asset_types, stock_rate):
        """Split ``amount`` equally within the STOCK bucket and within the non-STOCK bucket.

        :param amount: total value to allocate
        :param fund_ids: funds eligible for allocation
        :param asset_types: mapping fund id -> asset type string
        :param stock_rate: share of ``amount`` going to STOCK funds
        :return: dict fund id -> allocated value
        :raises ValueError: if a bucket with a non-zero target weight has no funds
        """
        stock_ids = [f for f in fund_ids if asset_types[f] == 'STOCK']
        other_ids = [f for f in fund_ids if asset_types[f] != 'STOCK']
        holdings = {}
        for label, ids, rate in (('STOCK', stock_ids, stock_rate), ('non-STOCK', other_ids, 1 - stock_rate)):
            if not ids:
                if rate:
                    raise ValueError(f'no {label} funds available to hold a {rate:.2%} weight')
                continue  # a zero-weight bucket needs no funds
            share = amount * rate / len(ids)
            holdings.update(dict.fromkeys(ids, share))
        return holdings

    def build_benchmark(self, risk, day=None):
        """Compute benchmark NAV records for ``risk`` from the last re-balance point up to ``day``.

        :param risk: risk level key from the ``stock-rate`` config
        :param day: last date to compute; defaults to ``dt.today()`` evaluated at call time
        :return: one ``{'date', 'nav', 're', 'risk'}`` dict per NAV date. ``re`` is 1 on rows
            where the holdings were (re)initialised and 0 otherwise; ``nav`` is rounded to
            4 decimals.
        :raises ValueError: if a bucket with a non-zero target weight has no funds
        """
        # A dt.today() default argument would be frozen at class-definition time.
        day = dt.today() if day is None else day
        nav_rtn, fund_count = self.load_nav_rtn(risk=risk, day=day)
        if nav_rtn.empty:
            return []
        asset_types = self.find_datum_asset()

        last = rb.get_last_one(risk=risk, max_date=day, re=True)
        # Resume from the last persisted re-balance NAV, or start from the configured amount.
        nav = last['nav'] if last else self.init_amount
        stock_rate = self.stock_rate(risk)

        holdings = {}  # fund id -> current value of the position
        records = []
        for pos, (nav_date, row) in enumerate(nav_rtn.iterrows()):
            if pos == 0:
                # The first row is (re)initialised with the starting NAV.
                holdings = self._allocate(nav, self._available_funds(row, fund_count), asset_types, stock_rate)
                rebalanced = True
            else:
                holdings = {fund_id: value * (1 + row[f'rtn_{fund_id}']) for fund_id, value in holdings.items()}
                # Skip NaN positions, like pandas' default skipna summation.
                nav = sum(value for value in holdings.values() if not pd.isna(value))
                rebalanced = pos % self.REBALANCE_PERIOD == 0
                if rebalanced:
                    holdings = self._allocate(nav, self._available_funds(row, fund_count), asset_types, stock_rate)
            records.append({'date': nav_date, 'nav': round(float(nav), 4), 're': 1 if rebalanced else 0, 'risk': risk})
        return records

    def load_report(self, max_date=None, min_date=None) -> List[dict]:
        """Persist any missing benchmark NAVs up to ``max_date`` and return the stored series.

        :param max_date: last date to report; defaults to ``dt.today()`` evaluated at call time
        :param min_date: optional first date to report
        :return: one dict per date, holding a ``date`` key and one NAV value per risk level;
            an empty list when nothing is stored in the range
        """
        # A dt.today() default argument would be frozen at class-definition time.
        max_date = dt.today() if max_date is None else max_date
        for risk in self.risks:
            last = rb.get_last_one(max_date=max_date, risk=risk)
            if not last or last['date'] < filter_weekend(max_date):
                benchmarks = pd.DataFrame(self.build_benchmark(risk=risk, day=max_date))
                # An empty frame has no 'date' column, so only filter when there is data.
                if last and not benchmarks.empty:
                    benchmarks = benchmarks[benchmarks.date > last['date']]
                if not benchmarks.empty:
                    rb.batch_insert(benchmarks.to_dict('records'))
        result = pd.DataFrame(rb.get_list(max_date=max_date, min_date=min_date))
        if result.empty:
            return []
        result = result.pivot_table(index='date', columns='risk', values='nav')
        result.reset_index(inplace=True)
        return result.to_dict('records')