import json
import logging
from datetime import datetime as dt

import pandas as pd
from dateutil.relativedelta import relativedelta
from py_jftech import component, autowired, get_config, format_date, transaction, asynchronized
from scipy.stats import norm

from api import AssetRisk, Navs, AssetRiskDateType as DateType, Datum, AssetPoolType, RoboExecutor, DatumType
from asset_pool.dao import asset_risk_dates as ard, asset_ewma_value as aev, robo_assets_pool as rap

logger = logging.getLogger(__name__)


@component
class CvarEwmaAssetRisk(AssetRisk):
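    '''
    CVaR decides when risk control starts. Once it has started, EWMA is computed to look for
    the date risk control ends, i.e. the risk-control start date is also the EWMA start date.
    EWMA decides when risk control ends: once it has ended, the date of the lowest point within
    the risk-control period can be found, and that date is the start date of the next CVaR round.
    '''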

    @autowired(names={'executor': RoboExecutor.use_name()})
    def __init__(self, navs: Navs = None, datum: Datum = None, executor: RoboExecutor = None):
        """Store the collaborators and load this module's configuration.

        The arguments default to ``None``; they are presumably supplied by the
        ``autowired`` decorator (the executor being selected by
        ``RoboExecutor.use_name()``) - confirm against py_jftech.

        :param navs: NAV data provider, used by ``get_income_return``.
        :param datum: datum provider, used by ``get_risk_pool``.
        :param executor: executor whose ``start_date`` anchors ``risk_start_date``.
        """
        self._config = get_config(__name__)
        self._navs, self._datum, self._executor = navs, datum, executor

    @property
    def risk_start_date(self):
        """Executor start date moved back by the configured ``advance-months``."""
        lead_months = self._config['advance-months']
        executor_start = self._executor.start_date
        return executor_start - relativedelta(months=lead_months)

    def get_risk_pool(self, day):
        """Return the ids of the assets in the risk pool for ``day``.

        A stored pool is returned as-is. Otherwise every FUND datum with risk
        level 3, 4 or 5 is checked through ``async_is_risk``, the ids whose
        check is truthy are stored as the pool for ``day``, and the stored
        record is read back.

        :param day: date the pool is looked up / built for.
        :return: the decoded ``asset_ids`` JSON of the pool record.
        """
        stored = rap.get_one(day, AssetPoolType.RISK)
        if stored:
            return json.loads(stored['asset_ids'])

        funds = self._datum.get_datums(type=DatumType.FUND, risk=(3, 4, 5))
        # Submit every check first, then collect, so the checks can overlap.
        pending = {fund['id']: self.async_is_risk(fund['id'], day) for fund in funds}
        flagged = [fund_id for fund_id, handle in pending.items() if handle.result()]
        rap.insert(day, AssetPoolType.RISK, flagged)
        stored = rap.get_one(day, AssetPoolType.RISK)
        return json.loads(stored['asset_ids'])

    @asynchronized
    def async_is_risk(self, id, day):
        """Delegate to ``is_risk`` under the ``asynchronized`` decorator.

        ``get_risk_pool`` calls ``.result()`` on the value returned here, so the
        decorator presumably hands back a future-like handle - confirm against
        py_jftech.
        """
        return self.is_risk(id, day)

    def is_risk(self, id, day) -> bool:
        """Tell whether asset ``id`` is under risk control on ``day``.

        If a risk pool is already stored for ``day``, membership in it decides.
        Otherwise the stored risk dates are extended up to ``day`` when the
        newest one is older, and the asset counts as in risk when its latest
        record up to ``day`` is a START_DATE - or when it has no record at all.

        :param id: asset (fund) id.
        :param day: date to evaluate.
        """
        stored_pool = rap.get_one(day, AssetPoolType.RISK)
        if stored_pool:
            return id in json.loads(stored_pool['asset_ids'])

        newest = ard.get_last_one(fund_id=id)
        # NOTE(review): an asset with no record at all never triggers a build
        # here and is reported as in risk - confirm this is intended.
        if newest and newest['date'] < day:
            self.build_risk_date(id, day)

        marker = ard.get_last_one(id, day)
        if not marker:
            return True
        return DateType(marker['type']) is DateType.START_DATE

    def build_risk_date(self, asset_id, day=None):
        """Generate risk-control boundary dates for an asset up to ``day``.

        Calls ``get_next_date`` repeatedly until it returns ``None``. This is a
        best-effort operation: any exception is logged with its traceback and
        swallowed, so callers never see it.

        :param asset_id: asset whose risk dates are built.
        :param day: upper date bound; defaults to the current date/time.
        """
        if day is None:
            # Resolve at call time: a ``dt.today()`` default in the signature is
            # evaluated once at import and goes stale in a long-running process.
            day = dt.today()
        # Last boundary successfully produced; only used in the error message.
        risk_date = None
        try:
            logger.debug(f"start build risk date for asset[{asset_id}] to date[{format_date(day)}]")
            risk_date = self.get_next_date(asset_id, day=day)
            while risk_date is not None:
                risk_date = self.get_next_date(asset_id, day=day)
        except Exception:
            # logger.exception already attaches the active traceback; passing the
            # exception as a %-argument to a placeholder-free message makes the
            # logging module fail while formatting the record.
            logger.exception(f"build risk date for asset[{asset_id}] after date[{risk_date}] to date[{format_date(day)}] error")

    @transaction
    def clear(self, day=None):
        """Delete stored risk dates and EWMA values through the two DAO ``delete`` calls.

        :param day: forwarded unchanged to ``ard.delete`` and ``aev.delete``; how
            it filters the deletion is defined by those DAOs (not visible here).
        """
        ard.delete(day)
        aev.delete(day)

    def get_next_date(self, asset_id, day=None):
        """Find, store and return the next risk-control boundary for an asset.

        * Latest record missing or a START_DATE (asset is in risk control): scan
          the EWMA series in windows of ``condition-total`` rows; the first window
          with at least ``condition-meet`` values ``>= threshold`` ends risk
          control, and its last date is stored as a STOP_DATE.
        * Latest record a STOP_DATE: starting from the lowest-NAV date of the
          finished risk period, the first day whose NAV falls below that low, or
          whose return breaches the CVaR estimate, is stored as a START_DATE.

        :param asset_id: asset to evaluate.
        :param day: upper date bound; defaults to the current date/time.
        :return: ``{'date': ..., 'type': DateType}`` for the stored boundary, or
            ``None`` when no new boundary exists up to ``day``.
        """
        if day is None:
            # Resolve at call time: a ``dt.today()`` default in the signature is
            # evaluated once at import and goes stale in a long-running process.
            day = dt.today()

        last = ard.get_last_one(asset_id, day)
        if not last or DateType(last['type']) is DateType.START_DATE:
            start_date = last['date'] if last else self.risk_start_date
            ewma = pd.DataFrame(self.get_ewma_value(asset_id, min_date=start_date, max_date=day))
            ewma_config = self._config['ewma']
            total = ewma_config['condition-total']
            meet = ewma_config['condition-meet']
            threshold = ewma_config['threshold']
            if len(ewma) < total:
                return None
            # NOTE(review): this bound never evaluates the windows ending on the
            # last two rows (and evaluates nothing when len(ewma) == total, which
            # the guard above lets through). Looks like an off-by-one, but it is
            # left unchanged because it affects persisted dates - confirm.
            for index in range(total, len(ewma) - 1):
                window = ewma[index - total:index]
                if len(window[window['ewma'] >= threshold]) >= meet:
                    stop_date = window.iloc[-1]['date']
                    ard.insert(asset_id, DateType.STOP_DATE, stop_date)
                    return {'date': stop_date, 'type': DateType.STOP_DATE}
        elif DateType(last['type']) is DateType.STOP_DATE:
            last_start = ard.get_last_one(asset_id, last['date'], type=DateType.START_DATE)
            start_date = last_start['date'] if last_start else self.risk_start_date
            rtns = pd.DataFrame(self.get_income_return(asset_id, min_date=start_date, max_date=day))
            if rtns.empty:
                # No return data: the frame has no 'date' column to filter on.
                return None
            risk_rtns = rtns[rtns['date'] <= last['date']]
            if risk_rtns.empty:
                # No data inside the finished risk period, so it has no low point.
                return None
            cvar_start_date = risk_rtns.loc[risk_rtns['nav'].idxmin()]['date']
            # Loop invariants: NAV at the CVaR start date and the CVaR settings.
            start_nav = rtns[rtns['date'] == cvar_start_date].iloc[0]['nav']
            cvar_config = self._config['cvar']
            for _, row in rtns[rtns['date'] >= cvar_start_date].iterrows():
                trigger = False
                cvar_rtns = rtns[(rtns['date'] >= cvar_start_date) & (rtns['date'] <= row['date'])]
                if row['nav'] < start_nav:
                    # NAV fell below the low of the previous risk period.
                    trigger = True
                elif row['rtn'] <= cvar_config['threshold'] and len(cvar_rtns) >= cvar_config['min-volume']:
                    # The day's return is at or below the threshold and enough
                    # returns have accumulated: compute CVaR and compare.
                    alpha = 1 - cvar_config['coef']
                    mean = cvar_rtns['rtn'].mean()
                    std = cvar_rtns['rtn'].std()
                    cvar = mean - std * norm.pdf(norm.ppf(alpha)) / alpha
                    trigger = row['rtn'] < cvar
                if trigger:
                    ard.insert(asset_id, DateType.START_DATE, row['date'])
                    return {'date': row['date'], 'type': DateType.START_DATE}
        return None

    def get_ewma_value(self, id, min_date=None, max_date=None):
        """Bring the stored EWMA series of an asset up to date and return it.

        If nothing is stored yet, the series is seeded with the first available
        return. Each later return ``r`` then updates it as
        ``ewma = factor * r + (1 - factor) * ewma`` and is stored.

        :param id: asset id.
        :param min_date: lower date bound; returns are loaded from
            ``risk_start_date`` when omitted.
        :param max_date: upper date bound; ``None`` means no upper bound.
        :return: list of ``{'date': ..., 'ewma': ...}`` read back from storage,
            or ``[]`` when no return data is available.
        """
        rtn = pd.DataFrame(self.get_income_return(id, min_date=min_date or self.risk_start_date, max_date=max_date))
        if rtn.empty:
            return []
        rtn.sort_values('date', inplace=True)

        last_one = aev.get_last_one(id, max_date=max_date)
        if not last_one:
            seed = rtn.iloc[0]
            aev.insert(asset_id=id, date=seed['date'], value=seed['rtn'])
            last_one = aev.get_last_one(id, max_date=max_date)
        last_day = last_one['date']
        # max_date defaults to None; comparing a date with None raises TypeError,
        # so an omitted bound is treated as "always extend".
        if max_date is None or last_day < max_date:
            ewma = last_one['value']
            factor = self._config['ewma']['factor']
            for _, row in rtn[rtn['date'] > last_day].iterrows():
                ewma = factor * row['rtn'] + (1 - factor) * ewma
                aev.insert(id, row['date'], ewma)

        result = aev.get_list(id, min_date=min_date, max_date=max_date)
        return [{'date': x['date'], 'ewma': x['value']} for x in result]

    def get_income_return(self, asset_id, min_date=None, max_date=None):
        fund_navs = pd.DataFrame(self._navs.get_fund_navs(fund_ids=asset_id, max_date=max_date))
        if not fund_navs.empty:
            fund_navs['rtn'] = fund_navs['nav_cal'] / fund_navs['nav_cal'].shift(self._config['rtn-days']) - 1
            fund_navs.dropna(inplace=True)
            if min_date:
                fund_navs = fund_navs[fund_navs.nav_date >= pd.to_datetime(min_date)]
            fund_navs.rename(columns={'nav_date': 'date', 'nav_cal': 'nav'}, inplace=True)
            fund_navs = fund_navs[['date', 'nav', 'rtn']]
            return fund_navs.to_dict('records')
        return []