import json
from datetime import datetime as dt
from typing import List

import pandas as pd
from dateutil.relativedelta import relativedelta
from empyrical import downside_risk, sortino_ratio
from py_jftech import component, autowired, get_config, dict_remove, prev_workday, filter_weekend

from api import RoboReportor, Datum, DatumType, PortfoliosRisk, Navs
from asset_pool.asset_optimize import SortinoAssetOptimize
from portfolios.solver import DefaultFactory
from rebalance.dao import robo_rebalance_signal as rrs


@component(bean_name='fund-selection-process-report')
class FundSelectionProcessReportor(RoboReportor):
    """Intermediate-data report for the fund selection process.

    For each fund datum with enough return history, a row lists its downside
    risk, the configured Sortino ratios and their weighted score, the
    per-``customType`` ranks of score and downside risk, an annualized return
    figure, and the weight given to it by the latest effective FT3 rebalance
    signal.
    """

    @autowired
    def __init__(self, datum: Datum = None, optimize: SortinoAssetOptimize = None, navs: Navs = None):
        self._datum = datum
        self._optimize = optimize
        self._navs = navs

    @property
    def report_name(self) -> str:
        return '中間數據报告'

    def load_report(self, max_date=None, min_date=None) -> List[dict]:
        """Build the report rows, one per fund with enough history.

        :param max_date: reference date; defaults to the current time, evaluated
            on every call. Figures are computed as of
            ``prev_workday(filter_weekend(max_date))``.
        :param min_date: accepted for interface compatibility; not used.
        :return: list of row dicts, in fund-datum order.
        """
        # A ``dt.today()`` default would be evaluated once at import time and then go stale.
        if max_date is None:
            max_date = dt.today()
        day = prev_workday(filter_weekend(max_date))
        funds = self._datum.get_datums(type=DatumType.FUND)
        # De-duplicated fund ids, in first-seen order.
        fund_ids = list(dict.fromkeys(fund['id'] for fund in funds))
        optimize_conf = get_config('asset-pool')['asset-optimize']

        pct_change = pd.DataFrame(self._optimize.get_pct_change(fund_ids, day))
        pct_change.set_index('date', inplace=True)
        pct_change = pct_change.truncate(
            before=(day - relativedelta(**optimize_conf['annual-volatility-section'][0])))
        # Funds whose history is too short to compute annualized volatility are dropped.
        funds = [fund for fund in funds if fund['id'] in pct_change.columns]

        d_risk = self._by_column(downside_risk(pct_change), pct_change).to_dict()
        sortino = self._sortino_scores(pct_change, day, self._sortino_conf(optimize_conf))

        last_re = rrs.get_last_one(max_date=day, risk=PortfoliosRisk.FT3, effective=True)
        # Weights of the latest effective FT3 signal, looked up by str(fund id); empty when there is no signal.
        portfolios = json.loads(last_re['portfolio']) if last_re else {}

        rtn_annualized = self.get_rtn_annualized(self.get_navs(fund_ids, day))
        datas = []
        for fund in funds:
            scores = sortino.get(fund['id'], {})
            weight = portfolios.get(str(fund['id']))
            datas.append({
                'customType': fund['customType'],
                'chineseName': fund['chineseName'],
                'ftRisk': fund['ftRisk'],
                '下行風險': d_risk.get(fund['id']),
                'downside_risk_rank': None,
                '短期sortino': scores.get('sortino_3_months'),
                '中期sortino': scores.get('sortino_6_months'),
                '長期sortino': scores.get('sortino_1_years'),
                'score': scores.get('score'),
                'sortino_rank': None,
                '加權後各類排名前3名': None,
                # Funds dropped by get_navs' NaN filter have no return figure -> None instead of KeyError.
                '短期報酬': rtn_annualized.get(fund['id']),
                'MPT計算': f"{weight * 100:.2f}%" if weight else None,
            })
        self._fill_ranks(datas)
        return datas

    @staticmethod
    def _by_column(values, frame: pd.DataFrame) -> pd.Series:
        """Label a per-column result (Series or 1-D array) with ``frame``'s column names, by position."""
        return pd.Series(list(values), index=frame.columns, dtype=float)

    @staticmethod
    def _sortino_conf(optimize_conf) -> List[dict]:
        """Return the ``sortino-weight`` config entries, each tagged with a ``name``.

        The name is ``sortino_<value>_<key>`` built from the first non-``weight``
        key, e.g. ``{'months': 3, 'weight': w}`` -> ``sortino_3_months``.
        Returns an empty list when ``sortino-weight`` is not configured.
        """
        entries = optimize_conf['sortino-weight'] if 'sortino-weight' in optimize_conf else []
        return [{**x, 'name': [f"sortino_{v}_{k}" for k, v in x.items() if k != 'weight'][0]}
                for x in entries]

    def _sortino_scores(self, pct_change: pd.DataFrame, day, sortino_conf: List[dict]) -> dict:
        """Return ``{fund_id: {<sortino name>: ratio, ..., 'score': weighted sum}}``.

        Each conf entry minus its ``weight``/``name`` keys is passed to
        ``relativedelta`` to select the look-back window ending at ``day``.
        """
        ratios = {}
        for item in sortino_conf:
            window = pct_change.truncate(before=(day - relativedelta(**dict_remove(item, ('weight', 'name')))))
            # sortino_ratio's DataFrame output may not carry fund-id labels, so attach them from the
            # window's columns by position. Keying by position in ``fund_ids`` would mislabel funds
            # whenever pct_change drops or reorders columns.
            ratios[item['name']] = self._by_column(sortino_ratio(window), window)
        table = pd.DataFrame(ratios, index=pct_change.columns)
        table['score'] = sum(item['weight'] * table[item['name']] for item in sortino_conf)
        return table.to_dict(orient='index')

    @staticmethod
    def _fill_ranks(datas: List[dict]) -> None:
        """Fill per-``customType`` dense ranks (1 = highest value) into ``datas`` in place.

        ``sortino_rank`` ranks ``score`` and flags ranks <= 3 in ``加權後各類排名前3名``;
        ``downside_risk_rank`` ranks ``下行風險``. Rows whose value cannot be ranked keep ``None``.
        """
        if not datas:
            return
        grouped = pd.DataFrame(datas).groupby('customType')
        sortino_rank = grouped['score'].rank(method='dense', ascending=False)
        risk_rank = grouped['下行風險'].rank(method='dense', ascending=False)
        # The rank Series keep the DataFrame's row order, which is the order of ``datas``.
        for data, s_rank, r_rank in zip(datas, sortino_rank, risk_rank):
            if pd.notna(s_rank):
                data['sortino_rank'] = int(s_rank)
                if s_rank <= 3:
                    data['加權後各類排名前3名'] = True
            if pd.notna(r_rank):
                data['downside_risk_rank'] = int(r_rank)

    def get_navs(self, fund_ids, day):
        """Load a cleaned NAV matrix (index ``nav_date``, columns ``fund_id``) ending at ``day``.

        Uses ``portfolios.solver.navs`` config: ``range`` (relativedelta kwargs for
        the look-back window), ``max-nan.asset`` (NaN count at or above which a fund
        column is dropped) and ``max-nan.day`` (NaN fraction at or above which a date
        row is dropped). Remaining gaps are forward-filled, then leading gaps back-filled.
        Returns an empty DataFrame when no NAVs are found.
        """
        navs_conf = get_config('portfolios')['solver']['navs']
        min_date = day - relativedelta(**navs_conf['range'])
        navs = pd.DataFrame(self._navs.get_fund_navs(fund_ids=fund_ids, max_date=day, min_date=min_date))
        if navs.empty:
            return pd.DataFrame()
        # Convert first: the ``.dt`` accessor only works on datetime-like columns.
        navs['nav_date'] = pd.to_datetime(navs['nav_date'])
        navs = navs[navs['nav_date'].dt.day_of_week < 5]
        navs = navs.pivot_table(index='nav_date', columns='fund_id', values='nav_cal').sort_index()
        nan_per_fund = navs.isna().sum()
        navs = navs.drop(columns=nan_per_fund[nan_per_fund >= navs_conf['max-nan']['asset']].index)
        nan_ratio_per_day = navs.isna().mean(axis=1)
        navs = navs.drop(index=nan_ratio_per_day[nan_ratio_per_day >= navs_conf['max-nan']['day']].index)
        # After ffill only leading gaps can remain; bfill is a no-op when there are none.
        return navs.ffill().bfill()

    def get_rtn_annualized(self, navs):
        """Return the per-fund mean ``matrix-rtn-days``-period simple return, multiplied by 12.

        NOTE(review): the x12 factor presumably assumes the configured period is
        about one month; confirm against ``portfolios.solver.matrix-rtn-days``.
        """
        result = navs / navs.shift(get_config('portfolios')['solver']['matrix-rtn-days']) - 1
        return result.dropna().mean() * 12


@component(bean_name='sigma-report')
class SigmaReportor(FundSelectionProcessReportor):
    """Report of the annualized covariance matrix of daily fund NAV returns."""

    @autowired
    def __init__(self):
        super().__init__()

    @property
    def report_name(self) -> str:
        return 'sigma'

    def load_report(self, max_date=None, min_date=None) -> List[dict]:
        """Return the covariance of daily simple returns, multiplied by 252.

        :param max_date: reference date; defaults to the current time, evaluated
            on every call. NAVs are loaded up to
            ``prev_workday(filter_weekend(max_date))``.
        :param min_date: accepted for interface compatibility; not used.

        NOTE(review): despite the ``List[dict]`` annotation, this returns a
        ``pd.DataFrame`` (fund_id x fund_id). Confirm what the report consumer expects.
        """
        # A ``dt.today()`` default would be evaluated once at import time and then go stale.
        if max_date is None:
            max_date = dt.today()
        day = prev_workday(filter_weekend(max_date))
        fund_ids = [fund['id'] for fund in self._datum.get_datums(type=DatumType.FUND)]
        navs = self.get_navs(fund_ids, day)
        # Simple daily returns; the first row has no predecessor and is dropped.
        rtn = (navs / navs.shift(1) - 1).iloc[1:]
        return rtn.cov() * 252
