import json
from abc import ABC, abstractmethod
from datetime import datetime as dt, timedelta

import pandas as pd
from dateutil.relativedelta import relativedelta
from empyrical import sortino_ratio
from py_jftech import filter_weekend, dict_remove, get_config, component, autowired, get_quarter_start, next_workday, is_workday

from api import AssetOptimize, Navs, Datum, AssetPoolType
from asset_pool.dao import robo_assets_pool as rop


class SortinoAssetOptimize(AssetOptimize, ABC):
    """
    Base optimizer that keeps, from each group of candidate assets, the one with the
    highest weighted Sortino ratio.

    Configuration is read from ``get_config(__name__)['sortino-weight']``: a list of dicts,
    each holding a ``weight`` plus ``relativedelta`` keyword arguments describing one
    look-back window ending at the evaluation day. An asset's score is the weighted sum
    of its Sortino ratios over all configured windows.
    """

    def __init__(self):
        optimize_config = get_config(__name__)
        # Every config entry gets a derived label built from its first non-weight item,
        # e.g. {'months': 3, 'weight': 0.5} -> 'sortino_3_months'.
        self._config = [{
            **x,
            'name': [f"sortino_{y[1]}_{y[0]}" for y in x.items() if y[0] != 'weight'][0]
        } for x in optimize_config['sortino-weight']] if 'sortino-weight' in optimize_config else []

    @property
    def delta_kwargs(self):
        """
        :return: for every configured window, the config entry without its ``weight`` and
                 ``name`` keys, i.e. the keyword arguments to pass to ``relativedelta``
        """
        return [{k: v for k, v in item.items() if k not in ('weight', 'name')} for item in self._config]

    def find_optimize(self, fund_ids, day):
        """
        Pick the asset with the highest weighted Sortino ratio among ``fund_ids``.

        :param fund_ids: ids of the candidate assets (one group)
        :param day: evaluation date; every look-back window ends on this date
        :return: the asset id (column label of the return table) with the best score;
                 when no score can be computed (all NaN) the first column is returned
        :raises AssertionError: when no sortino configuration is present
        """
        assert self._config, "find optimize, but not found sortino config."
        pct_change = pd.DataFrame(self.get_pct_change(fund_ids, day))
        pct_change.set_index('date', inplace=True)

        # Accumulate weight * sortino per asset directly, so duplicate config names
        # cannot collide and no DataFrame is re-concatenated inside the loop.
        scores = pd.Series(0.0, index=pct_change.columns)
        for item, kwargs in zip(self._config, self.delta_kwargs):
            window = pct_change.truncate(before=day - relativedelta(**kwargs))
            # sortino_ratio returns a positionally indexed Series for a DataFrame, but a
            # bare ndarray when the window has fewer than two rows; list() normalises both
            # so the ratios can be re-labelled with the asset ids.
            ratios = pd.Series(list(sortino_ratio(window)), index=pct_change.columns, dtype=float)
            scores += item['weight'] * ratios

        # NaN scores never win; if nothing is comparable, fall back to the first asset
        # (matches the previous sort-descending-with-NaN-last behaviour).
        if scores.notna().any():
            return scores.idxmax()
        return pct_change.columns[0]

    def get_optimize_pool(self, day):
        """
        Return the optimized asset pool for ``day``, rebuilding and persisting it when needed.

        The pool stored through ``rop`` is reused unless there is none yet, it is dated before
        the start of the quarter, or ``has_incept_asset`` reports a change between the day after
        the stored pool and ``day``. When rebuilding, every group from ``get_groups`` contributes
        at most one asset: the only eligible one, or the ``find_optimize`` winner.

        :param day: evaluation date
        :return: the stored pool's ``asset_ids`` decoded from JSON
        """
        last_one = rop.get_last_one(day=day, type=AssetPoolType.OPTIMIZE)
        # NOTE(review): ``day`` is used unguarded elsewhere in this method; the dt.today()
        # fallback only affects this line.
        start = get_quarter_start(day or dt.today())
        if not last_one or start > last_one['date'] or self.has_incept_asset(last_one['date'] + timedelta(1), day):
            pool = []
            min_dates = self.nav_min_dates
            # Earliest window start across all configs: an asset must have NAV data from
            # (at least) here to be scored over every window.
            max_incept_date = min(day - relativedelta(**x) for x in self.delta_kwargs)
            if not is_workday(max_incept_date):
                max_incept_date = next_workday(max_incept_date)
            for fund_group in self.get_groups():
                # Assets with no known NAV start date are skipped instead of raising KeyError.
                eligible = [x for x in fund_group if x in min_dates and min_dates[x] <= max_incept_date]
                if len(eligible) > 1:
                    pool.append(self.find_optimize(tuple(eligible), day))
                elif eligible:
                    pool.append(eligible[0])
            rop.insert(day, AssetPoolType.OPTIMIZE, sorted(pool))
            last_one = rop.get_last_one(day=day, type=AssetPoolType.OPTIMIZE)
        return json.loads(last_one['asset_ids'])

    @abstractmethod
    def has_incept_asset(self, start_date, end_date):
        """
        :return: truthy when assets incepted between ``start_date`` and ``end_date`` require the
                 stored pool to be rebuilt (the exact criterion is defined by the subclass)
        """
        pass

    @property
    @abstractmethod
    def nav_min_dates(self) -> dict:
        """
        :return: mapping from asset id to its earliest NAV date (values are compared with dates)
        """
        pass

    @abstractmethod
    def get_groups(self):
        '''
        :return: the groups of asset ids to process; one asset is kept per group
        '''
        pass

    @abstractmethod
    def get_pct_change(self, fund_ids, day):
        '''
        Return the return rates of the given ids up to the given date.

        :param fund_ids: sequence of asset ids
        :param day: the reference date
        :return: list of records, each with a ``date`` key plus one return value per asset id
        '''
        pass


@component
class FundSortinoAssetOptimize(SortinoAssetOptimize):
    '''
    Fund optimizer: within each fund group, select the fund with the best weighted Sortino ratio.
    '''

    @autowired
    def __init__(self, navs: Navs = None, datum: Datum = None):
        super().__init__()
        self._navs = navs
        self._datum = datum

    @property
    def nav_min_dates(self) -> dict:
        """
        :return: the mapping returned by ``Navs.get_nav_start_date()``
                 (presumably fund id -> first NAV date; confirm against Navs)
        """
        return self._navs.get_nav_start_date()

    def has_incept_asset(self, start_date, end_date):
        """
        Tell whether any fund's NAV start date lies in ``[start_date, end_date]`` after both ends
        are shifted back by the configured windows (taking the earliest shifted date for each end).

        :return: True if at least one fund's start date falls in the shifted range
        """
        deltas = [relativedelta(**x) for x in self.delta_kwargs]
        window_start = min(start_date - d for d in deltas)
        window_end = min(end_date - d for d in deltas)
        return any(window_start <= first_date <= window_end for first_date in self.nav_min_dates.values())

    def get_groups(self):
        '''
        :return: tuples of fund ids, one tuple per (category, assetType) combination
        '''
        funds = pd.DataFrame(self._datum.get_fund_datums())
        if funds.empty:
            # No fund data: groupby would fail on the missing columns.
            return []
        return [tuple(group['id']) for _, group in funds.groupby(by=['category', 'assetType'])]

    def get_pct_change(self, fund_ids, day):
        '''
        Return the daily return rates of the given funds up to ``day``.

        :param fund_ids: sequence of fund ids
        :param day: the reference (last) date
        :return: list of records ``{'date': ..., <fund_id>: return, ...}`` rounded to 4 decimals,
                 or an empty list when no NAV data is available
        :raises ValueError: when no sortino configuration is present
        '''
        if not self._config:
            # Previously raised the undefined name BusinessException (a NameError).
            raise ValueError("find optimize, but not found sortino config.")
        # Go one extra day before the earliest window start, presumably so the first return
        # inside the window has a prior NAV (the first pct_change row is dropped below).
        # Subtracting the day separately also works when a window config itself uses ``days``.
        start = filter_weekend(min(day - relativedelta(**x) - relativedelta(days=1) for x in self.delta_kwargs))
        fund_navs = pd.DataFrame(self._navs.get_fund_navs(fund_ids=tuple(fund_ids), min_date=start, max_date=day))
        if fund_navs.empty:
            return []
        # pivot_table sorts by nav_date itself; gaps are forward-filled before computing returns,
        # so pct_change must not fill again (its implicit fill_method is deprecated).
        navs = fund_navs.pivot_table(index='nav_date', columns='fund_id', values='nav_cal').ffill()
        result = navs.pct_change(fill_method=None).dropna().round(4)
        result.reset_index(inplace=True)
        result.rename(columns={'nav_date': 'date'}, inplace=True)
        return result.to_dict('records')