import datetime
import json
import logging
from datetime import datetime as dt, date
from functools import reduce
from typing import List

import pandas as pd
from py_jftech import (
    component, autowired, get_config, next_workday, format_date, is_workday, prev_workday, workday_range
)

from api import PortfoliosHolder, PortfoliosRisk, Navs, RoboExecutor, PortfoliosType, PortfoliosBuilder, RoboReportor, \
    DatumType, Datum, RebalanceSignal, SignalType
from portfolios.dao import robo_hold_portfolios as rhp
from portfolios.utils import format_weight

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


@component(bean_name='dividend-holder')
class DividendPortfoliosHolder(PortfoliosHolder):
    """Builds the daily hold rows of a dividend-paying robo portfolio.

    ``build_hold_portfolio`` walks workday by workday from the day after the last stored row. For each
    day it either rebalances into the rebalance signal's weights (``do_rebalance``) or carries the
    previous shares forward (``no_rebalance``). Every step inserts one row through
    ``robo_hold_portfolios`` (``rhp``). Settings are read from ``get_config(__name__)``.
    """

    @autowired(names={'executor': RoboExecutor.use_name()})
    def __init__(self, navs: Navs = None, executor: RoboExecutor = None, builder: PortfoliosBuilder = None,
                 datum: Datum = None, mpt: PortfoliosBuilder = None, signal: RebalanceSignal = None):
        self._navs = navs
        self._executor = executor
        self._builder = builder
        self._config = get_config(__name__)
        self._datum = datum
        self._mpt = mpt
        self._signal = signal

    def get_portfolio_type(self, day, risk: PortfoliosRisk) -> PortfoliosType:
        """Always ``PortfoliosType.NORMAL`` for this holder."""
        return PortfoliosType.NORMAL

    def get_last_rebalance_date(self, risk: PortfoliosRisk, max_date=None):
        """Return the date of the latest rebalance row of ``risk`` bounded by ``max_date``, or ``None``."""
        assert risk, f"get last rebalance date, risk can not be none"
        last = rhp.get_last_one(max_date=max_date, risk=risk, rebalance=True)
        return last['date'] if last else None

    def get_rebalance_date_by_signal(self, signal_id):
        """Return the date of the latest rebalance row produced by ``signal_id``, or ``None``."""
        last = rhp.get_last_one(signal_id=signal_id, rebalance=True)
        return last['date'] if last else None

    def get_portfolios_weight(self, day, risk: PortfoliosRisk):
        """Return the stored ``{fund_id: weight}`` of the hold row for ``day``/``risk``, or ``None``."""
        hold = rhp.get_one(day, risk)
        if hold:
            result = json.loads(hold['portfolios'])['weight']
            return {int(x[0]): x[1] for x in result.items()}
        return None

    def has_hold(self, risk: PortfoliosRisk) -> bool:
        """Return whether at least one hold row exists for ``risk``."""
        return rhp.get_count(risk=risk) > 0

    def build_hold_portfolio(self, day, risk: PortfoliosRisk, force_mpt=False):
        """Insert hold rows for ``risk`` from the day after the last stored row up to ``day``.

        :param day: last date (inclusive) to build.
        :param risk: portfolio risk level.
        :param force_mpt: when true, first ask the mpt builder for the NORMAL portfolio of the previous
            workday before building each day.

        Any exception is logged with its traceback and building stops; rows already inserted are not
        rolled back here.
        """
        last_nav = rhp.get_last_one(max_date=day, risk=risk)
        start = next_workday(last_nav['date']) if last_nav else self._executor.start_date
        try:
            while start <= day:
                if force_mpt:
                    logger.info(f'start to get normal portfolio for date[{format_date(start)}]')
                    self._mpt.get_portfolios(day=prev_workday(start), type=PortfoliosType.NORMAL, risk=risk)
                logger.info(f"start to build hold portfolio[{risk.name}] for date[{format_date(start)}]")
                signal = self._signal.get_signal(prev_workday(start), risk)
                if signal:
                    last_re_date = self.get_last_rebalance_date(risk=risk, max_date=start)
                    # Minimum interval between two actual rebalances, in trading days.
                    if last_re_date and len(workday_range(last_re_date, start)) <= self.interval_days:
                        self.no_rebalance(start, risk, last_nav)
                    else:
                        self.do_rebalance(start, risk, signal, last_nav)
                else:
                    self.no_rebalance(start, risk, last_nav)
                start = next_workday(start)
                last_nav = rhp.get_last_one(max_date=day, risk=risk)
        except Exception:
            # Best-effort by design: log and stop. logger.exception already attaches the traceback;
            # the exception must not be passed as a format argument (the message has no placeholder).
            logger.exception(f"build hold portfolio[{risk.name}] for date[{format_date(start)}] failure.")

    def do_rebalance(self, day, risk: PortfoliosRisk, signal, last_nav):
        """Rebalance into the signal's weights on ``day`` and insert the hold row.

        :param signal: signal row; ``signal['portfolio']`` is a JSON ``{fund_id: weight}`` map and
            ``signal['id']`` is stored as ``signal_id``.
        :param last_nav: previous hold row, or ``None`` for the very first build. In that case the
            initial purchase uses ``init-nav`` with subscription fees deducted.
        """
        weight = {int(x[0]): x[1] for x in json.loads(signal['portfolio']).items()}
        if last_nav:
            share = {int(x): y for x, y in json.loads(last_nav['portfolios'])['share'].items()}
            navs, div_per_share, _ = self.get_navs_and_div(fund_ids=tuple(set(weight) | set(share)), day=day)
            # Dividend cash paid today on the shares currently held.
            fund_dividend = sum(share[k] * div_per_share[k] for k in share if k in div_per_share)
            dividend_acc = last_nav['div_acc'] + fund_dividend
            fund_av = round(sum([navs[x] * y for x, y in share.items()]), 4)
            fund_nav = fund_av + dividend_acc
            cash = last_nav['cash'] + fund_dividend
            div_forecast = last_nav['div_forecast']
            # Re-evaluate the payout in the configured adjustment months.
            if day.month in self._config.get('dividend-adjust-day'):
                asset_nav = last_nav['asset_nav']
                # Annualised payout rate implied by the current monthly forecast.
                div_rate = last_nav['div_forecast'] * 12 / asset_nav
                # Reset when the relative drift from the target 'dividend-rate' exceeds 'dividend-drift-rate'.
                if self.month_dividend > 0 and abs(
                        (self._config['dividend-rate'] - div_rate) / self._config['dividend-rate']) > \
                        self._config['dividend-drift-rate']:
                    # Base the payout on the previous day's asset value.
                    div_forecast = last_nav['asset_nav'] * self.month_dividend
            asset_nav = fund_av + cash
            nav = last_nav['nav'] * asset_nav / last_nav['asset_nav']
            share = {x: fund_av * w / navs[x] for x, w in weight.items()}
            share_nav = {x: fund_nav * w / navs[x] for x, w in weight.items()}
        else:
            fund_av = self.init_nav
            navs = self.get_navs_and_div(fund_ids=tuple(weight), day=day)[0]
            fund_dividend = 0
            cash = 0
            div_forecast = fund_av * self.month_dividend
            dividend_acc = 0
            nav = self.init_nav
            asset_nav = fund_av + cash
            funds = self._datum.get_datums(type=DatumType.FUND)
            funds_subscription_rate = {fund['id']: fund.get('subscriptionRate', 0) for fund in funds}
            share = {x: (1 - funds_subscription_rate[x]) * (fund_av * w) / navs[x] for x, w in weight.items()}
            share_nav = share
            # Deduct the subscription fee on the initial purchase.
            fee = sum(funds_subscription_rate[x] * (fund_av * w) for x, w in weight.items())
            fund_av = fund_av - fee
            fund_nav = fund_av
        rhp.insert({
            'date': day,
            'risk': risk,
            'signal_id': signal['id'],
            'div_forecast': div_forecast if div_forecast else last_nav['div_forecast'] if last_nav else None,
            'fund_div': fund_dividend,
            'div_acc': dividend_acc,
            'rebalance': True,
            'portfolios': {
                'weight': weight,
                'weight_nav': weight,
                'share': share,
                'share_nav': share_nav,
            },
            'fund_av': fund_av,
            'fund_nav': fund_nav,
            'nav': nav,
            'port_div': 0,
            'cash': cash,
            'asset_nav': asset_nav,
        })

    def no_rebalance(self, day, risk: PortfoliosRisk, last_nav):
        """Carry ``last_nav``'s shares forward to ``day`` without trading and insert the hold row.

        Fund dividends paid today are added to ``cash``. In the ``share_nav`` view they are also
        reinvested as extra shares. On the portfolio dividend date the previous ``div_forecast`` is
        recorded as ``port_div``.
        """
        portfolios = json.loads(last_nav['portfolios'])
        share = {int(x): y for x, y in portfolios['share'].items()}
        # NOTE(review): share_nav is seeded from the stored 'share', not 'share_nav' (the InvTrust
        # holder reads 'share_nav'), so reinvested shares do not carry over between days — confirm.
        share_nav = {int(x): y for x, y in portfolios['share'].items()}
        navs, div_per_share, _ = self.get_navs_and_div(fund_ids=tuple(share), day=day)
        # On a fund's dividend day, reinvest the dividend as extra shares: shares * dividend / nav.
        # The former divisor (shares * nav) made the added amount independent of the holding size.
        for k in share_nav:
            if k in div_per_share:
                share_nav[k] = (share_nav[k] * div_per_share[k]) / navs[k] + share_nav[k]
        fund_av = round(sum([navs[x] * y for x, y in share.items()]), 4)
        fund_nav = round(sum([navs[x] * y for x, y in share_nav.items()]), 4)
        weight = format_weight({x: round(y * navs[x] / fund_av, 2) for x, y in share.items()})
        # share_nav weights are relative to their own total value, fund_nav.
        weight_nav = format_weight({x: round(y * navs[x] / fund_nav, 2) for x, y in share_nav.items()})
        port_div = 0
        fund_dividend = sum(share[k] * div_per_share[k] for k in share if k in div_per_share)
        dividend_acc = last_nav['div_acc']
        cash = last_nav['cash'] + fund_dividend
        if self.is_dividend_date(day):
            port_div = last_nav['div_forecast']
            cash += port_div
            asset_nav = fund_av + cash
            dividend_acc += port_div
            # NOTE(review): port_div is already in cash (hence in asset_nav) and is added once more
            # here — confirm this is intended.
            nav = last_nav['nav'] * (asset_nav + port_div) / last_nav['asset_nav']
        else:
            asset_nav = fund_av + cash
            nav = last_nav['nav'] * asset_nav / last_nav['asset_nav']
        rhp.insert({
            'date': day,
            'risk': risk,
            'div_forecast': last_nav['div_forecast'],
            'fund_div': fund_dividend,
            'div_acc': dividend_acc,
            'signal_id': last_nav['signal_id'],
            'rebalance': False,
            'portfolios': {
                'weight': weight,
                'weight_nav': weight_nav,
                'share': share,
                'share_nav': share_nav
            },
            'fund_av': fund_av,
            'fund_nav': fund_nav,
            'nav': nav,
            'cash': cash,
            'port_div': port_div,
            'asset_nav': asset_nav,
        })

    def get_navs_and_div(self, day, fund_ids):
        """Return ``(navs, dividends, nav_cals)`` dicts keyed by fund id, as of ``day``.

        Fund navs for the 22 calendar days up to ``day`` are loaded. ``av`` and ``nav_cal`` are
        forward-filled so each fund's latest known value is returned. Dividends are taken from the
        last entry of a daily range ending at ``day``, filled with 0 where no dividend row exists.
        This assumes ``nav_date`` values are datetime-like so the reindex matches; TODO confirm.
        """
        navs = pd.DataFrame(
            self._navs.get_fund_navs(fund_ids=fund_ids, max_date=day, min_date=day - datetime.timedelta(22)))
        dividend = navs.pivot_table(index='nav_date', columns='fund_id', values='dividend')
        nav_cal = navs.pivot_table(index='nav_date', columns='fund_id', values='nav_cal')
        navs = navs.pivot_table(index='nav_date', columns='fund_id', values='av')
        # DataFrame.ffill() replaces the deprecated fillna(method='ffill').
        navs = navs.ffill()
        nav_cal = nav_cal.ffill()
        dividend = dividend.fillna(value=0)
        dividend = dividend.reindex(pd.date_range(start=dividend.index.min(), end=day, freq='D'), fill_value=0)
        return dict(navs.iloc[-1]), dict(dividend.iloc[-1]), dict(nav_cal.iloc[-1])

    def clear(self, day=None, risk: PortfoliosRisk = None):
        """Delete hold rows via ``rhp.delete`` using ``day`` as the minimum date, for ``risk``."""
        rhp.delete(min_date=day, risk=risk)

    def is_dividend_date(self, day):
        """Return whether ``day`` is this month's portfolio dividend date.

        The configured ``dividend-date`` day of the month is used, shifted to the next workday when it
        is not a workday. Only day-of-month numbers are compared.
        """
        # NOTE(review): if the shifted date falls into the next month, its day number matches an early
        # day of the current month; date() also raises ValueError when the configured day exceeds the
        # month's length.
        div_date = self._config['dividend-date']
        div_date = date(day.year, day.month, div_date)
        if is_workday(div_date):
            return div_date.day == day.day
        else:
            return next_workday(div_date).day == day.day

    @property
    def month_dividend(self):
        """Monthly payout rate: the configured annual ``dividend-rate`` divided by 12."""
        return self._config['dividend-rate'] / 12

    @property
    def interval_days(self):
        """Configured minimum number of trading days between two actual rebalances."""
        return self._config['min-interval-days']

    @property
    def init_nav(self):
        """Configured starting value of the portfolio."""
        return self._config['init-nav']


# NOTE(review): registered under the same bean name as DividendPortfoliosHolder ('dividend-holder');
# confirm which implementation the container is meant to resolve.
@component(bean_name='dividend-holder')
class InvTrustPortfoliosHolder(DividendPortfoliosHolder):
    """Dividend holder variant that pays the portfolio dividend by redeeming fund shares.

    The stored ``portfolios`` JSON keeps three share views:

    * ``share`` holds the actual holdings.
    * ``share_nav`` has fund dividends reinvested as shares.
    * ``share_nodiv_nav`` is priced with ``nav_cal``, is never redeemed for payouts, and gives ``nav``.
    """

    def do_rebalance(self, day, risk: PortfoliosRisk, signal, last_nav):
        """Rebalance into the signal's weights on ``day`` and insert the hold row.

        With a previous row, each share view is re-split by ``weight`` using its current value:

        * ``fund_av`` for ``share``;
        * ``fund_nav`` plus today's dividends for ``share_nav``;
        * ``nav`` at ``nav_cal`` prices for ``share_nodiv_nav``.

        Without one, the initial purchase uses ``init-nav`` less subscription fees.
        ``div_forecast`` is recomputed only on the transfer workday. Otherwise the previous row's
        value is stored, or ``None`` on the first build.
        """
        weight = {int(x[0]): x[1] for x in json.loads(signal['portfolio']).items()}
        dividend_acc = 0
        fund_dividend = 0
        # Only set on the transfer workday; when still falsy the previous row's forecast is stored.
        # It used to be unbound on other days, which raised UnboundLocalError.
        div_forecast = None
        if last_nav:
            # Not the first build: re-split the existing holdings.
            portfolios = json.loads(last_nav['portfolios'])
            share = {int(x): y for x, y in portfolios['share'].items()}
            # Fund shares that take part in dividends.
            share_nav = {int(x): y for x, y in portfolios['share_nav'].items()}
            share_nodiv_nav = {int(x): y for x, y in portfolios['share_nodiv_nav'].items()}
            navs, div_per_share, nav_cals = self.get_navs_and_div(fund_ids=tuple(set(weight) | set(share)),
                                                                  day=day)
            fund_dividend_nav = sum(share_nav[k] * div_per_share[k] for k in share_nav if k in div_per_share)
            fund_dividend = sum(share[k] * div_per_share[k] for k in share if k in div_per_share)
            dividend_acc = last_nav['div_acc'] + fund_dividend
            fund_av = round(sum([navs[x] * y for x, y in share.items()]), 4)
            fund_nav = round(sum([navs[x] * y for x, y in share_nav.items()]), 4)
            nav = round(sum([nav_cals[x] * y for x, y in share_nodiv_nav.items()]), 4)
            # Dividends paid on the rebalance day itself are kept in the share_nav view.
            fund_nav += fund_dividend_nav
            asset_nav = fund_av
            share = {x: fund_av * w / navs[x] for x, w in weight.items()}
            share_nav = {x: fund_nav * w / navs[x] for x, w in weight.items()}
            share_nodiv_nav = {x: nav * w / nav_cals[x] for x, w in weight.items()}
            if self.is_transfer_workday(day):
                div_forecast = asset_nav * self.month_dividend
        else:
            fund_av = self.init_nav
            asset_nav = self.init_nav
            nav = self.init_nav
            navs, _, nav_cals = self.get_navs_and_div(fund_ids=tuple(weight), day=day)
            # First payout amount, kept for the record (a falsy value stores None below).
            div_forecast = 0
            funds = self._datum.get_datums(type=DatumType.FUND)
            funds_subscription_rate = {fund['id']: fund.get('subscriptionRate', 0) for fund in funds}
            share = {x: (1 - funds_subscription_rate[x]) * (fund_av * w) / navs[x] for x, w in weight.items()}
            share_nav = share
            # Ignores payouts: bought at nav_cal prices.
            share_nodiv_nav = {x: (1 - funds_subscription_rate[x]) * (fund_av * w) / nav_cals[x] for x, w in
                               weight.items()}
            # Deduct the subscription fee on the initial purchase.
            fee = sum(funds_subscription_rate[x] * (fund_av * w) for x, w in weight.items())
            fund_av = fund_av - fee
            fund_nav = fund_av

        rhp.insert({
            'date': day,
            'risk': risk,
            'signal_id': signal['id'],
            'fund_div': fund_dividend,
            'div_forecast': div_forecast if div_forecast else last_nav['div_forecast'] if last_nav else None,
            'div_acc': dividend_acc,
            'rebalance': True,
            'portfolios': {
                'weight': weight,
                'weight_nav': weight,
                'weight_nodiv_nav': weight,
                'share': share,
                'share_nav': share_nav,
                'share_nodiv_nav': share_nodiv_nav
            },
            'fund_av': fund_av,
            'fund_nav': fund_nav,
            'nav': nav,
            'port_div': 0,
            'asset_nav': asset_nav,
        })

    def is_transfer_workday(self, day):
        """Return whether ``day`` is this month's transfer workday.

        The configured ``warehouse-transfer-date`` day of the month is used, shifted to the next workday
        when it is not a workday. Only day-of-month numbers are compared.
        """
        # NOTE(review): same caveats as is_dividend_date — a shift into the next month matches an early
        # day of the current month, and date() raises ValueError for days beyond the month's length.
        transfer_date = self._config['warehouse-transfer-date']
        # The n-th calendar day of the current month.
        transfer_date = date(day.year, day.month, transfer_date)
        first_work_day = transfer_date if is_workday(transfer_date) else next_workday(transfer_date)
        return day.day == first_work_day.day

    def no_rebalance(self, day, risk: PortfoliosRisk, last_nav):
        """Carry ``last_nav``'s shares forward to ``day`` without trading and insert the hold row.

        Fund dividends paid today are reinvested as extra shares in the ``share_nav`` view. On the
        dividend date, ``share`` and ``share_nav`` are redeemed for the previous ``div_forecast``, which
        is recorded as ``port_div``. On the transfer workday, ``div_forecast`` is recomputed from
        today's asset value.
        """
        port_div = 0
        dividend_acc = last_nav['div_acc']
        portfolios = json.loads(last_nav['portfolios'])
        share = {int(x): y for x, y in portfolios['share'].items()}
        share_nav = {int(x): y for x, y in portfolios['share_nav'].items()}
        share_nodiv_nav = {int(x): y for x, y in portfolios['share_nodiv_nav'].items()}
        navs, div_per_share, nav_cals = self.get_navs_and_div(fund_ids=tuple(share), day=day)
        # On a fund's dividend day, reinvest the dividend as extra shares.
        for k in share_nav:
            if k in div_per_share:
                share_nav[k] = (share_nav[k] * div_per_share[k]) / navs[k] + share_nav[k]
        # On the dividend date, redeem shares worth the amount computed on the transfer workday.
        # A stored None (the first rebalance stores None) is treated as nothing to pay, instead of
        # raising TypeError on the comparison.
        need_div = last_nav['div_forecast'] or 0
        if self.is_dividend_date(day) and need_div > 0:
            funds = self._datum.get_datums(type=DatumType.FUND, ticker=self._config['redeem-list'])
            self.exec_redeem(funds, navs, need_div, share)
            self.exec_redeem(funds, navs, need_div, share_nav)
            port_div = last_nav['div_forecast']
        fund_dividend = sum(share[k] * div_per_share[k] for k in share if k in div_per_share)
        dividend_acc = dividend_acc + port_div + fund_dividend
        fund_av = round(sum([navs[x] * y for x, y in share.items()]), 4)
        nav = round(sum([nav_cals[x] * y for x, y in share_nodiv_nav.items()]), 4)
        fund_nav = round(sum([navs[x] * y for x, y in share_nav.items()]), 4)
        # Each view's weights are relative to that view's own total, matching how do_rebalance builds
        # the shares (share_nav from fund_nav, share_nodiv_nav from nav at nav_cal prices).
        weight = format_weight({x: round(y * navs[x] / fund_av, 2) for x, y in share.items()})
        weight_nav = format_weight({x: round(y * navs[x] / fund_nav, 2) for x, y in share_nav.items()})
        weight_nodiv_nav = format_weight(
            {x: round(y * nav_cals[x] / nav, 2) for x, y in share_nodiv_nav.items()})
        asset_nav = fund_av
        div_forecast = last_nav['div_forecast']
        if self.is_transfer_workday(day):
            div_forecast = asset_nav * self.month_dividend
        rhp.insert({
            'date': day,
            'risk': risk,
            'fund_div': fund_dividend,
            'div_forecast': div_forecast,
            'div_acc': dividend_acc,
            'signal_id': last_nav['signal_id'],
            'rebalance': False,
            'portfolios': {
                'weight': weight,
                'weight_nav': weight_nav,
                'weight_nodiv_nav': weight_nodiv_nav,
                'share': share,
                'share_nav': share_nav,
                'share_nodiv_nav': share_nodiv_nav
            },
            'fund_av': fund_av,
            'fund_nav': fund_nav,
            'nav': nav,
            'port_div': port_div,
            'asset_nav': asset_nav,
        })

    def exec_redeem(self, funds, navs, need_div, share):
        """Redeem fund shares in ``share`` (mutated in place) worth ``need_div`` in total.

        Funds are used in the order of ``funds`` and those not in ``share`` are skipped. A held fund
        whose value does not exceed the remaining amount is fully redeemed. The first fund worth more
        is reduced by the remainder, and redemption stops there.
        """
        remaining = need_div
        for fund in funds:
            fund_id = fund['id']
            if fund_id not in share:
                continue
            held_value = share[fund_id] * navs[fund_id]
            if held_value <= remaining:
                # Subtract the redeemed value before zeroing the holding; zeroing first left the
                # remaining amount unchanged, so every listed fund was wiped out.
                remaining -= held_value
                share[fund_id] = 0
            else:
                share[fund_id] = (held_value - remaining) / navs[fund_id]
                break


@component(bean_name='hold-report')
class DivHoldReportor(RoboReportor):
    """Report of the stored daily hold-portfolio values."""

    @property
    def report_name(self) -> str:
        return '投組淨值'

    def load_report(self, max_date=None, min_date=None) -> List[dict]:
        """Return hold rows between ``min_date`` and ``max_date`` as report records.

        :param max_date: upper date bound. It defaults to the current time, resolved on each call;
            the previous ``dt.today()`` default was evaluated only once, at import.
        :param min_date: lower date bound passed to the query.
        :returns: one dict per row with the report columns, or ``[]`` when there are no rows.
        """
        if max_date is None:
            max_date = dt.today()
        holds = pd.DataFrame(rhp.get_list(max_date=max_date, min_date=min_date))
        if holds.empty:
            return []
        holds['signal_type'] = 'INIT'
        holds['real_av'] = holds['asset_nav']
        # NOTE(review): 'acc_av' is not derived here; this assumes rhp.get_list returns that column.
        holds = holds[
            ['date', 'signal_type', 'fund_av', 'fund_nav', 'fund_div', 'cash', 'real_av', 'port_div', 'div_acc',
             'acc_av', 'nav']]
        return holds.to_dict('records')


@component(bean_name='daily-hold-report')
class DailyHoldReportor(RoboReportor):
    """Report listing every fund held by each risk level on one day."""

    @autowired
    def __init__(self, datum: Datum = None):
        self._datum = datum

    @property
    def report_name(self) -> str:
        return '每日持倉信息'

    def load_report(self, max_date=None, min_date=None) -> List[dict]:
        """Return one record per (hold row, fund) for the hold rows dated ``max_date``.

        :param max_date: report datetime (``.date()`` is compared). It defaults to the previous workday
            of today, resolved on each call rather than once at import.
        :param min_date: lower bound passed to the hold-row query.
        :returns: a list of dicts, or ``[]`` when no hold row exists for that day.
        """
        if max_date is None:
            max_date = prev_workday(dt.today())
        holds = pd.DataFrame(rhp.get_list(max_date=max_date, min_date=min_date))
        # An empty query result has no 'date' column, so bail out before filtering on it.
        if holds.empty:
            return []
        holds = holds[holds['date'].dt.date == max_date.date()].copy()
        if holds.empty:
            return []
        portfolio = rhp.get_last_one(max_date=max_date, rebalance=True)
        datum_ids = reduce(lambda x, y: x | y,
                           holds['portfolios'].apply(lambda x: set(json.loads(x)['weight'].keys())))
        datums = pd.DataFrame(self._datum.get_datums(type=DatumType.FUND, datum_ids=datum_ids))
        datums.set_index('id', inplace=True)

        holds['rebalance_type'] = PortfoliosType.NORMAL.name
        holds['rebalance_date'] = prev_workday(portfolio['date'])
        holds['risk'] = holds['risk'].map(lambda r: PortfoliosRisk(r).name)
        holds['portfolios'] = holds['portfolios'].map(lambda p: list(json.loads(p)['weight'].items()))
        # One row per (hold row, (fund_id, weight)) pair.
        holds = holds.explode('portfolios', ignore_index=True)
        holds['weight'] = holds['portfolios'].map(lambda p: format(p[1], '.0%'))
        fund_ids = holds['portfolios'].map(lambda p: int(p[0]))
        holds['asset_ids'] = fund_ids.map(lambda i: datums.loc[i]['ftTicker'])
        holds['name'] = fund_ids.map(lambda i: datums.loc[i]['chineseName'])
        holds['lipper_id'] = fund_ids.map(lambda i: datums.loc[i]['lipperKey'])
        holds = holds[
            ['lipper_id', 'asset_ids', 'name', 'weight', 'risk', 'date', 'rebalance_type', 'rebalance_date']]
        return holds.to_dict('records')


@component(bean_name='daily-mpt-report')
class DailyMptReportor(RoboReportor):
    """Report listing the funds of the latest NORMAL mpt portfolio for risk FT3."""

    @autowired
    def __init__(self, hold: PortfoliosHolder = None, datum: Datum = None):
        self._hold = hold
        self._datum = datum

    @property
    def report_name(self) -> str:
        return '每日mpt'

    def load_report(self, max_date=None, min_date=None) -> List[dict]:
        """Return one record per fund of the latest NORMAL/FT3 mpt portfolio, or ``[]`` if none exists.

        ``max_date`` and ``min_date`` are accepted for interface compatibility and are not used. The
        ``max_date`` default no longer calls ``prev_workday`` at import time.
        """
        from portfolios.dao import robo_mpt_portfolios as rmp
        latest = rmp.get_last_one(type=PortfoliosType.NORMAL, risk=PortfoliosRisk.FT3)
        # Guard explicitly: pd.DataFrame([None]) is not empty and has no 'portfolio' column.
        if not latest:
            return []
        signals = pd.DataFrame([latest])
        datum_ids = reduce(lambda x, y: x | y, signals['portfolio'].apply(lambda x: set(json.loads(x).keys())))
        datums = pd.DataFrame(self._datum.get_datums(type=DatumType.FUND, datum_ids=datum_ids))
        datums.set_index('id', inplace=True)

        signals['risk'] = signals['risk'].map(lambda r: PortfoliosRisk(r).name)
        signals['portfolio'] = signals['portfolio'].map(lambda p: list(json.loads(p).items()))
        # One row per (fund_id, weight) pair.
        signals = signals.explode('portfolio', ignore_index=True)

        signals['weight'] = signals['portfolio'].map(lambda p: format(p[1], '.0%'))
        fund_ids = signals['portfolio'].map(lambda p: int(p[0]))
        signals['asset_ids'] = fund_ids.map(lambda i: datums.loc[i]['ftTicker'])
        signals['name'] = fund_ids.map(lambda i: datums.loc[i]['chineseName'])
        signals['lipper_id'] = fund_ids.map(lambda i: datums.loc[i]['lipperKey'])

        signals = signals[['lipper_id', 'asset_ids', 'name', 'weight', 'risk', 'date']]
        return signals.to_dict('records')
