import json
from abc import ABC, abstractmethod
from sys import exception

import numpy as np
import pandas as pd
from dateutil.relativedelta import relativedelta
from empyrical import sortino_ratio, annual_volatility, downside_risk, annual_return, tail_ratio
from py_jftech import filter_weekend, dict_remove, get_config, component, autowired, next_workday, \
    is_workday

from api import AssetOptimize, Navs, Datum, AssetPoolType, DatumType
from asset_pool.dao import robo_assets_pool as rop, robo_indicator


class SortinoAssetOptimize(AssetOptimize, ABC):
    """
    Base class for asset optimizers that rank funds by Sortino ratio.

    The ``sortino-weight`` list is read from this module's config. Every entry
    carries a ``weight``; a ``name`` of the form ``sortino_<value>_<key>`` is
    derived from its first non-``weight`` key and stored alongside it in
    ``self._config``.
    """

    def __init__(self):
        optimize_config = get_config(__name__)
        if 'sortino-weight' in optimize_config:
            entries = optimize_config['sortino-weight']
        else:
            entries = []
        parsed = []
        for entry in entries:
            parsed.append({
                **entry,
                'name': [f"sortino_{value}_{key}" for key, value in entry.items() if key != 'weight'][0],
            })
        self._config = parsed

    @property
    def delta_kwargs(self):
        """
        Copies of the config entries without their ``weight`` and ``name`` keys.

        Presumably these are the look-back keyword arguments for ``relativedelta``.
        ``self._config`` itself is not modified; a missing ``weight`` raises KeyError.
        """
        stripped_entries = []
        for entry in self._config:
            stripped = dict(entry)
            stripped.pop('weight')
            stripped.pop('name')
            stripped_entries.append(stripped)
        return stripped_entries

    def find_optimize(self, fund_ids, day):
        """Default no-op; subclasses select the optimal funds among ``fund_ids``."""
        return None

    def get_optimize_pool(self, day):
        """Default no-op; subclasses return the optimized fund pool for ``day``."""
        return None

    @property
    @abstractmethod
    def nav_min_dates(self) -> dict:
        """Mapping of fund id to its NAV start date."""
        pass

    @abstractmethod
    def get_groups(self, day=None):
        '''
        :return: the groups of ids to be processed
        '''
        pass

    @abstractmethod
    def get_pct_change(self, fund_ids, day):
        '''
        Return the returns (percentage changes) of the given ids for the given date.
        :param fund_ids: sequence of ids
        :param day: the target date
        :return: the returns
        '''
        pass

    @abstractmethod
    def has_change(self, day):
        """Whether the underlying data changed for ``day``; the base answer is ``False``."""
        return False


@component(bean_name='dividend')
class FundDividendSortinoAssetOptimize(SortinoAssetOptimize):
    """
    Fund-selection ("optimize") implementation driven by the Sortino ratio.

    Mainly targets US assets: US_STOCK, US_HY_BOND, US_IG_BOND.
    Funds are ranked by a weighted Sortino score and the top-ranked ones are
    picked (not one fund per asset class).
    """

    @autowired
    def __init__(self, navs: Navs = None, datum: Datum = None):
        super().__init__()
        self._navs = navs
        self._datum = datum
        # Raw module config; the base class keeps only the parsed 'sortino-weight' list in self._config.
        self._conf = get_config(__name__)

    @property
    def annual_volatility_section(self):
        """List whose first element holds the ``relativedelta`` kwargs of the downside-risk look-back."""
        return self._conf['annual-volatility-section']

    @property
    def annual_volatility_filter(self):
        """Per-``customType`` rules consumed by ``do_annual_volatility_filter``."""
        return self._conf['annual-volatility-filter']

    @property
    def asset_include(self):
        """Optional single-key mapping ``{column: allowed values}`` restricting which groups are returned."""
        return self._conf['asset-include']

    @property
    def asset_filter(self):
        """Optional single-key mapping ``{datum field: allowed values}``; ``None`` when not configured."""
        return self._conf.get('asset-filter')

    @property
    def optimize_count(self):
        """Maximum number of funds kept from each group."""
        return self._conf['optimize-count']

    @property
    def nav_min_dates(self) -> dict:
        """Mapping of fund id -> NAV start date, as reported by ``Navs.get_nav_start_date``."""
        return self._navs.get_nav_start_date()

    def has_change(self, day):
        """Delegate to ``Datum.update_change`` to tell whether the data changed for ``day``."""
        return self._datum.update_change(day)

    def find_optimize(self, fund_ids, day):
        """
        Rank ``fund_ids`` by a weighted Sortino score and return the best ones.

        For every ``sortino-weight`` entry the Sortino ratio is computed over that
        entry's look-back window (its keys other than ``weight``/``name`` are passed
        to ``relativedelta``). The ratios are combined into ``score`` with the
        configured weights and persisted through ``save_sortino``.

        :param fund_ids: sequence of fund ids to rank
        :param day: evaluation date
        :return: tuple of (array with the ``optimize_count`` highest-scoring fund ids,
                 ``score`` Series sorted descending and indexed by column position)
        """
        assert self._config, "find optimize, but not found sortino config."
        pct_change = pd.DataFrame(self.get_pct_change(fund_ids, day))
        pct_change.set_index('date', inplace=True)
        # Native Python fund ids in column order (the order produced by the pivot in get_pct_change).
        columns = pct_change.columns.tolist()
        ratios = {}
        for item in self._config:
            window = pct_change.truncate(before=(day - relativedelta(**dict_remove(item, ('weight', 'name')))))
            # Normalise to a positional Series: sortino_ratio may hand back a bare ndarray
            # (e.g. fewer than two rows), and positions are what map back to `columns`.
            ratios[item['name']] = pd.Series(np.asarray(sortino_ratio(window), dtype=float))
        sortino = pd.DataFrame(ratios)
        sortino['score'] = sum(item['weight'] * sortino[item['name']] for item in self._config)
        sortino.sort_values('score', ascending=False, inplace=True)
        # Index entries are column positions of pct_change. They need not line up with
        # fund_ids, because column order comes from the pivot and data-less funds are dropped.
        data = {columns[k]: v for k, v in sortino.to_dict(orient='index').items()}
        self.save_sortino(day, data)
        # Keep the optimize_count funds with the highest score.
        return pct_change.columns[sortino.index[0:self.optimize_count]].values, sortino['score']

    def save_sortino(self, day, datas):
        """
        Persist each fund's Sortino metrics as JSON, dropping NaN/infinite values.

        :param day: evaluation date
        :param datas: mapping of fund id -> {metric name: value}
        """
        for key, record in datas.items():
            finite = {k: v for k, v in record.items() if np.isfinite(v)}
            robo_indicator.update_sortino(key, day, json.dumps(finite))

    def get_optimize_pool(self, day):
        """
        Return the optimized fund pool for ``day``, building and storing it when needed.

        A pool stored for ``day`` (``rop.get_one``) is returned as is. Otherwise, if
        no earlier pool exists or the latest one is older than ``day``, a new pool is
        built. Funds whose NAV history does not reach back to the longest look-back
        (rolled forward to a workday) are skipped. Groups larger than ``optimize_count``
        are ranked via ``find_optimize``, and smaller groups are taken whole.

        :raises ValueError: when fewer funds than ``portfolios.solver.asset-count[0]`` are selected
        """
        opt_pool = rop.get_one(day=day, type=AssetPoolType.OPTIMIZE)
        if opt_pool is not None:
            return json.loads(opt_pool['asset_ids'])
        last_one = rop.get_last_one(day=day, type=AssetPoolType.OPTIMIZE)
        if not last_one or day > last_one['date']:
            pool = []
            min_dates = self.nav_min_dates
            # The longest look-back gives the latest inception date a fund may have.
            max_incept_date = min(day - relativedelta(**x) for x in self.delta_kwargs)
            if not is_workday(max_incept_date):
                max_incept_date = next_workday(max_incept_date)
            for fund_group in self.get_groups(day):
                # Funds with no recorded NAV start date cannot cover the window either.
                fund_group = [x for x in fund_group
                              if min_dates.get(x) is not None and min_dates[x] <= max_incept_date]
                if len(fund_group) > self.optimize_count:
                    pool.extend(self.find_optimize(tuple(fund_group), day)[0])
                else:
                    pool.extend(fund_group)
            min_count = get_config('portfolios.solver.asset-count')[0]
            if len(pool) < min_count:
                raise ValueError(f"基金优选个数小于{min_count},请调整参数")
            rop.insert(day, AssetPoolType.OPTIMIZE, sorted(pool))
            last_one = rop.get_last_one(day=day, type=AssetPoolType.OPTIMIZE)
        return json.loads(last_one['asset_ids'])

    def do_annual_volatility_filter(self, day, funds):
        """
        Annual volatility filter: drop funds whose downside risk is too high.

        Downside risk is computed over ``annual-volatility-section[0]`` and saved
        via ``save_annual``. Funds without enough NAV data for the metric are
        dropped. Then, for every ``annual-volatility-filter`` rule (matched on
        ``customType``):

        * ``exclude``: drop up to this many of the riskiest funds;
        * ``volatility``: drop funds whose downside risk exceeds it, riskiest first;
        * ``min-retain``: never shrink the group below this many funds (default 0).

        :param day: evaluation date
        :param funds: fund datum dicts (with at least ``id`` and ``customType``)
        :return: the remaining funds, in their original order
        """
        nav_records = self.get_pct_change([fund['id'] for fund in funds], day)
        if not nav_records:
            # No fund has usable NAV data, so none can be kept.
            return []
        pct_change = pd.DataFrame(nav_records).set_index('date')
        pct_change = pct_change.truncate(before=(day - relativedelta(**self.annual_volatility_section[0])))
        # Funds without enough history to compute the metric are dropped outright.
        available = set(pct_change.columns)
        funds = [fund for fund in funds if fund['id'] in available]
        risk = downside_risk(pct_change)  # annual_volatility / downside_risk / tail_ratio
        annual = dict(zip(pct_change.columns, np.asarray(risk, dtype=float).tolist()))
        self.save_annual(day, annual)

        excluded_ids = set()
        for rule in self.annual_volatility_filter:
            custom_type = rule.get('customType')
            exclude = rule.get('exclude')
            volatility = rule.get('volatility')
            retain = rule.get('min-retain') or 0
            # Funds of this type, least risky first.
            group = sorted((fund for fund in funds if fund['customType'] == custom_type and fund['id'] in annual),
                           key=lambda fund: annual[fund['id']])
            if exclude is not None:
                # Drop the `exclude` riskiest funds without going below `retain`.
                exclude = min(exclude, len(group), len(group) - retain)
                if exclude > 0:
                    excluded_ids.update(fund['id'] for fund in group[-exclude:])
                    group = group[:-exclude]
            if volatility is not None and len(group) > retain:
                # Remaining exclusion budget that still keeps `retain` funds.
                budget = len(group) - retain
                too_volatile = [fund for fund in group if annual[fund['id']] > volatility]
                # `group` is sorted ascending, so the riskiest offenders are at the end.
                excluded_ids.update(fund['id'] for fund in too_volatile[-budget:])
        return [fund for fund in funds if fund['id'] not in excluded_ids]

    def save_annual(self, day, annual):
        """
        Persist the per-fund downside-risk values computed for ``day``.

        :param annual: mapping of fund id -> downside risk
        """
        datas = [{"id": key, "date": day, "annual": record} for key, record in annual.items()]
        robo_indicator.insert(datas)

    def get_filtered_funds(self, day):
        """
        Return the fund datums eligible on ``day``.

        Tickers listed for ``day.month`` in ``portfolios.checker.month-fund-filter``
        are removed first. If ``asset-filter`` is configured, only funds whose field
        value is allowed are returned, and the annual-volatility filter is skipped.
        Otherwise ``do_annual_volatility_filter`` is applied.
        """
        funds = self._datum.get_datums(type=DatumType.FUND)
        month_filters = get_config('portfolios.checker.month-fund-filter')
        if month_filters:
            # Monthly exclusions. Build a new list so the datum source's list is never mutated.
            excludes = month_filters.get(day.month)
            if excludes:
                funds = [f for f in funds if f['bloombergTicker'] not in excludes]
        if self.asset_filter:
            field = next(iter(self.asset_filter))
            allowed = self.asset_filter[field]
            return [fund for fund in funds if fund[field] in allowed]
        return self.do_annual_volatility_filter(day, funds)

    def get_groups(self, day=None):
        """
        Group the filtered funds and return each group's ids as a tuple.

        With ``asset-include`` set, funds are grouped by its single key and only the
        listed values are kept. Otherwise funds are grouped by (category, assetType).

        :raises ValueError: when fewer funds than ``portfolios.solver.asset-count[0]`` remain
        """
        funds = pd.DataFrame(self.get_filtered_funds(day))
        min_count = get_config('portfolios.solver.asset-count')[0]
        if len(funds) < min_count:
            raise ValueError(f"{day}==基金优选个数小于{min_count},请调整参数")
        if self.asset_include:
            include = next(iter(self.asset_include))
            wanted = self.asset_include[include]
            return [tuple(group['id']) for key, group in funds.groupby(by=include) if key in wanted]
        return [tuple(group['id']) for _, group in funds.groupby(by=['category', 'assetType'])]

    def get_pct_change(self, fund_ids, day):
        """
        Return daily NAV returns of ``fund_ids`` up to ``day`` as records.

        Each record looks like ``{'date': nav_date, <fund_id>: pct_change, ...}``,
        with values rounded to 4 decimals. The fetch starts 7 days before the longest
        configured look-back (adjusted by ``filter_weekend``). NAVs are forward-filled,
        and rows earlier than ``start + 6 days`` are trimmed. Funds still missing a NAV
        after forward-filling are dropped.

        :return: list of records, or ``[]`` when no NAVs were found
        :raises ValueError: when no Sortino config is present
        """
        if not self._config:
            raise ValueError("find optimize, but not found sortino config.")
        days = [day - relativedelta(days=7, **dict_remove(x, ('weight', 'name'))) for x in self._config]
        days.append(day - relativedelta(days=7, **self.annual_volatility_section[0]))
        start = filter_weekend(min(days))
        fund_navs = pd.DataFrame(self._navs.get_fund_navs(fund_ids=tuple(fund_ids), min_date=start, max_date=day))
        if fund_navs.empty:
            return []
        fund_navs.sort_values('nav_date', inplace=True)
        fund_navs = fund_navs.pivot_table(index='nav_date', columns='fund_id', values='nav_cal')
        # Carry the last known NAV forward (fillna(method='ffill') is deprecated in pandas).
        fund_navs = fund_navs.ffill()
        fund_navs = fund_navs.loc[fund_navs.index >= start + relativedelta(days=6)]
        # After ffill, a NaN remains only before a fund's first NAV, i.e. it lacks enough history.
        fund_navs = fund_navs.dropna(axis=1)
        result = round(fund_navs.pct_change().dropna(), 4)
        result = result.reset_index().rename(columns={'nav_date': 'date'})
        return result.to_dict('records')