import json
import logging
from datetime import datetime as dt, date
from functools import reduce
from typing import List

import pandas as pd
from py_jftech import (
    component, autowired, get_config, next_workday, format_date, is_workday, prev_workday, workday_range
)

from api import PortfoliosHolder, PortfoliosRisk, Navs, RoboExecutor, PortfoliosType, PortfoliosBuilder, RoboReportor, \
    DatumType, Datum, RebalanceSignal
from portfolios.dao import robo_hold_portfolios as rhp
from portfolios.utils import format_weight

logger = logging.getLogger(__name__)


@component(bean_name='dividend-holder')
class DividendPortfoliosHolder(PortfoliosHolder):
    """Holder that simulates a dividend-paying portfolio one workday at a time.

    For every workday a hold record is written with ``rhp.insert``:

    * ``do_rebalance`` when a rebalance signal exists and the minimum interval since
      the previous rebalance has elapsed;
    * ``no_rebalance`` otherwise, which carries the previous shares forward.

    Behaviour is driven by this module's config (``get_config(__name__)``). It reads
    ``dividend-rate``, ``dividend-date``, ``dividend-adjust-day``,
    ``dividend-drift-rate``, ``min-interval-days`` and ``init-nav``.
    """

    @autowired(names={'executor': RoboExecutor.use_name()})
    def __init__(self, navs: Navs = None, executor: RoboExecutor = None, builder: PortfoliosBuilder = None,
                 datum: Datum = None, mpt: PortfoliosBuilder = None, signal: RebalanceSignal = None):
        self._navs = navs
        self._executor = executor
        self._builder = builder
        self._config = get_config(__name__)
        self._datum = datum
        self._mpt = mpt
        self._signal = signal

    def get_portfolio_type(self, day, risk: PortfoliosRisk) -> PortfoliosType:
        """Always ``PortfoliosType.NORMAL`` for this holder."""
        return PortfoliosType.NORMAL

    def get_last_rebalance_date(self, risk: PortfoliosRisk, max_date=None):
        """Return the date of the latest rebalance record for ``risk``, bounded by ``max_date``, or None."""
        assert risk, f"get last rebalance date, risk can not be none"
        last = rhp.get_last_one(max_date=max_date, risk=risk, rebalance=True)
        return last['date'] if last else None

    def get_rebalance_date_by_signal(self, signal_id):
        """Return the date of the latest rebalance record created from ``signal_id``, or None."""
        last = rhp.get_last_one(signal_id=signal_id, rebalance=True)
        return last['date'] if last else None

    def get_portfolios_weight(self, day, risk: PortfoliosRisk):
        """Return ``{fund_id: weight}`` stored in the hold record of ``day`` for ``risk``, or None."""
        hold = rhp.get_one(day, risk)
        if hold:
            result = json.loads(hold['portfolios'])['weight']
            return {int(x[0]): x[1] for x in result.items()}
        return None

    def has_hold(self, risk: PortfoliosRisk) -> bool:
        """True if at least one hold record exists for ``risk``."""
        return rhp.get_count(risk=risk) > 0

    def build_hold_portfolio(self, day, risk: PortfoliosRisk, force_mpt=False):
        """Write one hold record per workday, from the day after the last record up to ``day``.

        :param day: last date (inclusive) to build records for.
        :param risk: risk level of the portfolio.
        :param force_mpt: when True, call ``self._mpt.get_portfolios`` for the previous
            workday of each processed date before building its record.

        Any exception is logged with its traceback and then swallowed. Processing stops
        at the failing date and nothing is raised to the caller.
        """
        last_nav = rhp.get_last_one(max_date=day, risk=risk)
        start = next_workday(last_nav['date']) if last_nav else self._executor.start_date
        try:
            while start <= day:
                if force_mpt:
                    logger.info(f'start to get normal portfolio for date[{format_date(start)}]')
                    self._mpt.get_portfolios(day=prev_workday(start), type=PortfoliosType.NORMAL, risk=risk)
                logger.info(f"start to build hold portfolio[{risk.name}] for date[{format_date(start)}]")
                # The signal is looked up for the workday before ``start``.
                signal = self._signal.get_signal(prev_workday(start), risk)
                if signal:
                    last_re_date = self.get_last_rebalance_date(risk=risk, max_date=start)
                    # Minimum interval between two actual rebalances, in trading days.
                    if last_re_date and len(workday_range(last_re_date, start)) <= self.interval_days:
                        self.no_rebalance(start, risk, last_nav)
                    else:
                        self.do_rebalance(start, risk, signal, last_nav)
                else:
                    self.no_rebalance(start, risk, last_nav)
                start = next_workday(start)
                # Reload so the next iteration builds on the record just inserted.
                last_nav = rhp.get_last_one(max_date=day, risk=risk)
        except Exception:
            # logger.exception attaches the active traceback by itself. Passing the
            # exception as an extra positional argument would make it a %-format arg
            # for a message without placeholders, and the log call itself would fail.
            logger.exception(f"build hold portfolio[{risk.name}] for date[{format_date(start)}] failure.")

    def do_rebalance(self, day, risk: PortfoliosRisk, signal, last_nav):
        """Insert a rebalance hold record for ``day`` using the weights in ``signal``.

        :param day: the date of the record.
        :param risk: risk level of the portfolio.
        :param signal: rebalance signal. ``signal['portfolio']`` is a JSON object of
            ``{fund_id: weight}``; ``signal['id']`` is stored as ``signal_id``.
        :param last_nav: previous hold record, or None when this is the first record
            (the portfolio is then bought with ``init-nav`` minus the first payout
            and subscription fees).
        """
        weight = {int(x[0]): x[1] for x in json.loads(signal['portfolio']).items()}
        dividend_acc = 0
        fund_dividend = 0
        # Only some branches recompute the forecast. When it stays None, the previous
        # record's forecast is carried over at insert time. Initialising it avoids an
        # UnboundLocalError outside the adjust months.
        div_forecast = None
        if last_nav:
            share = {int(x): y for x, y in json.loads(last_nav['portfolios'])['share'].items()}
            navs, div_per_share = self.get_navs_and_div(fund_ids=tuple(set(weight) | set(share)), day=day)
            # Fund dividends (latest value from get_navs_and_div) times the shares held.
            fund_dividend = sum(share[k] * div_per_share[k] for k in share if k in div_per_share)
            dividend_acc = last_nav['div_acc']
            # Fund dividends accumulated in the previous record are folded back into fund_av.
            fund_av = round(sum([navs[x] * y for x, y in share.items()]), 4) + last_nav['fund_div']
            # The payout is only re-evaluated in the months listed in ``dividend-adjust-day``
            # (original note: "adjust the payout in the first quarter of each year").
            if day.month in self._config.get('dividend-adjust-day'):
                asset_nav = last_nav['asset_nav']
                # Current annualised payout rate.
                div_rate = last_nav['div_forecast'] * 12 / asset_nav
                # Reset the payout when the relative deviation between the target rate and
                # the current rate exceeds ``dividend-drift-rate``.
                if self.month_dividend > 0 and abs(
                        (self._config['dividend-rate'] - div_rate) / self._config['dividend-rate']) > \
                        self._config['dividend-drift-rate']:
                    # Compute the payout from the previous record's asset_nav.
                    dividend = last_nav['asset_nav'] * self.month_dividend
                else:
                    dividend = last_nav['div_forecast']
                fund_av = fund_av - dividend
                dividend_acc = dividend + dividend_acc
                div_forecast = dividend
            else:
                # While a previously set-aside dividend is still unpaid, do not set aside another.
                if last_nav['dividend'] > 0:
                    dividend = last_nav['dividend']
                else:
                    dividend = last_nav['div_forecast']
                    fund_av = fund_av - last_nav['div_forecast']
                    dividend_acc = last_nav['div_forecast'] + dividend_acc
            asset_nav = fund_av + fund_dividend + dividend
            nav = last_nav['nav'] * asset_nav / last_nav['asset_nav']
            share = {x: fund_av * w / navs[x] for x, w in weight.items()}
        else:
            fund_av = self.init_nav
            navs = self.get_navs_and_div(fund_ids=tuple(weight), day=day)[0]
            dividend = fund_av * self.month_dividend
            div_forecast = dividend
            fund_av = fund_av - dividend
            dividend_acc = dividend + dividend_acc
            nav = self.init_nav
            asset_nav = fund_av + fund_dividend + dividend
            funds = self._datum.get_datums(type=DatumType.FUND)
            funds_subscription_rate = {fund['id']: fund.get('subscriptionRate', 0) for fund in funds}
            share = {x: (1 - funds_subscription_rate[x]) * (fund_av * w) / navs[x] for x, w in weight.items()}
            # Deduct subscription fees on the initial purchase.
            fee = sum(funds_subscription_rate[x] * (fund_av * w) for x, w in weight.items())
            fund_av = fund_av - fee
        rhp.insert({
            'date': day,
            'risk': risk,
            'signal_id': signal['id'],
            'dividend': dividend,
            'div_forecast': div_forecast if div_forecast else last_nav['div_forecast'] if last_nav else None,
            'fund_div': fund_dividend,
            'div_acc': dividend_acc,
            'rebalance': True,
            'portfolios': {
                'weight': weight,
                'share': share,
            },
            'fund_av': fund_av,
            'nav': nav,
            'port_div': 0,
            'asset_nav': asset_nav,
        })

    def no_rebalance(self, day, risk: PortfoliosRisk, last_nav):
        """Insert a non-rebalance record for ``day`` that keeps the previous shares.

        On the payout date (``is_dividend_date``), the pending ``dividend`` is paid out
        as ``port_div`` and reset to 0. Otherwise it stays part of ``asset_nav``.
        """
        share = {int(x): y for x, y in json.loads(last_nav['portfolios'])['share'].items()}
        navs, div_per_share = self.get_navs_and_div(fund_ids=tuple(share), day=day)
        fund_av = round(sum([navs[x] * y for x, y in share.items()]), 4)
        weight = {x: round(y * navs[x] / fund_av, 2) for x, y in share.items()}
        weight = format_weight(weight)
        dividend = last_nav['dividend']
        port_div = 0
        fund_dividend = last_nav['fund_div'] + sum(
            share[k] * div_per_share[k] for k in share if k in div_per_share)
        dividend_acc = last_nav['div_acc']
        if self.is_dividend_date(day):
            port_div = dividend
            asset_nav = fund_av + fund_dividend
            dividend = 0
            # Paid-out amount is added back so nav reflects the return including the payout.
            nav = last_nav['nav'] * (asset_nav + port_div) / last_nav['asset_nav']
        else:
            asset_nav = fund_av + fund_dividend + dividend
            nav = last_nav['nav'] * asset_nav / last_nav['asset_nav']
        rhp.insert({
            'date': day,
            'risk': risk,
            'dividend': dividend,
            'div_forecast': last_nav['div_forecast'],
            'fund_div': fund_dividend,
            'div_acc': dividend_acc,
            'signal_id': last_nav['signal_id'],
            'rebalance': False,
            'portfolios': {
                'weight': weight,
                'share': share,
            },
            'fund_av': fund_av,
            'nav': nav,
            'port_div': port_div,
            'asset_nav': asset_nav,
        })

    def get_navs_and_div(self, day, fund_ids):
        """Return ``(navs, dividends)``: per-fund latest ``av`` and ``dividend`` as of ``day``.

        Both are dicts keyed by fund id, taken from the last row after forward-filling.
        NOTE(review): the dividend row is forward-filled too, so a past dividend value
        may repeat if the source leaves non-dividend days as NaN. Confirm the data shape.
        """
        navs = pd.DataFrame(self._navs.get_fund_navs(fund_ids=fund_ids, max_date=day))
        dividend = navs.pivot_table(index='nav_date', columns='fund_id', values='dividend')
        navs = navs.pivot_table(index='nav_date', columns='fund_id', values='av')
        # DataFrame.ffill() replaces the deprecated fillna(method='ffill').
        navs = navs.ffill()
        dividend = dividend.ffill()
        return dict(navs.iloc[-1]), dict(dividend.iloc[-1])

    def clear(self, day=None, risk: PortfoliosRisk = None):
        """Delete hold records via ``rhp.delete(min_date=day, risk=risk)``."""
        rhp.delete(min_date=day, risk=risk)

    def is_dividend_date(self, day):
        """True when ``day`` is this month's payout date.

        The payout date is the configured ``dividend-date`` day of the month, or the
        next workday when that day is not a workday. Only the day-of-month is compared.
        NOTE(review): ``date(...)`` raises ValueError if ``dividend-date`` does not exist
        in the month (e.g. 31).
        """
        div_date = self._config['dividend-date']
        div_date = date(day.year, day.month, div_date)
        if is_workday(div_date):
            return div_date.day == day.day
        else:
            return next_workday(div_date).day == day.day

    @property
    def month_dividend(self):
        """Monthly payout rate: configured ``dividend-rate`` divided by 12."""
        return self._config['dividend-rate'] / 12

    @property
    def interval_days(self):
        """Minimum number of trading days between two rebalances (``min-interval-days``)."""
        return self._config['min-interval-days']

    @property
    def init_nav(self):
        """Initial portfolio value (``init-nav``)."""
        return self._config['init-nav']


# NOTE(review): registered under the same bean_name as DividendPortfoliosHolder
# ('dividend-holder'). Which class the container resolves depends on py_jftech's
# registry; confirm that this override is intentional.
@component(bean_name='dividend-holder')
class InvTrustPortfoliosHolder(DividendPortfoliosHolder):
    """Dividend holder variant that pays out by redeeming fund shares.

    On the first workday of each month a payout forecast (``div_forecast``) is set to
    ``asset_nav * month_dividend``. On the payout date, that amount is redeemed from
    the funds listed in the ``redeem-list`` config, in the order returned by
    ``get_datums``.
    """

    def do_rebalance(self, day, risk: PortfoliosRisk, signal, last_nav):
        """Insert a rebalance hold record for ``day`` using the weights in ``signal``.

        :param day: the date of the record.
        :param risk: risk level of the portfolio.
        :param signal: rebalance signal. ``signal['portfolio']`` is a JSON object of
            ``{fund_id: weight}``; ``signal['id']`` is stored as ``signal_id``.
        :param last_nav: previous hold record, or None when this is the first record.
        """
        weight = {int(x[0]): x[1] for x in json.loads(signal['portfolio']).items()}
        dividend_acc = 0
        fund_dividend = 0
        # None means "not recomputed today": the previous record's forecast is carried
        # over at insert time. Initialising it avoids an UnboundLocalError when the
        # rebalance does not fall on the month's first workday.
        div_forecast = None
        if last_nav:
            # Not the first record: re-weight the existing portfolio value.
            share = {int(x): y for x, y in json.loads(last_nav['portfolios'])['share'].items()}
            navs, div_per_share = self.get_navs_and_div(fund_ids=tuple(set(weight) | set(share)), day=day)
            fund_dividend = sum(share[k] * div_per_share[k] for k in share if k in div_per_share)
            dividend_acc = last_nav['div_acc']
            # Fund dividends accumulated in the previous record are folded back into fund_av.
            fund_av = round(sum([navs[x] * y for x, y in share.items()]), 4) + last_nav['fund_div']
            asset_nav = fund_av + fund_dividend
            nav = last_nav['nav'] * asset_nav / last_nav['asset_nav']
            share = {x: fund_av * w / navs[x] for x, w in weight.items()}
            if self.is_first_workday(day):
                div_forecast = asset_nav * self.month_dividend
        else:
            fund_av = self.init_nav
            nav = self.init_nav
            asset_nav = self.init_nav
            navs = self.get_navs_and_div(fund_ids=tuple(weight), day=day)[0]
            # First record: no payout forecast yet; record an explicit 0.
            div_forecast = 0
            funds = self._datum.get_datums(type=DatumType.FUND)
            funds_subscription_rate = {fund['id']: fund.get('subscriptionRate', 0) for fund in funds}
            share = {x: (1 - funds_subscription_rate[x]) * (fund_av * w) / navs[x] for x, w in weight.items()}
            # Deduct subscription fees on the initial purchase.
            fee = sum(funds_subscription_rate[x] * (fund_av * w) for x, w in weight.items())
            fund_av = fund_av - fee
        rhp.insert({
            'date': day,
            'risk': risk,
            'signal_id': signal['id'],
            'dividend': 0,
            'fund_div': fund_dividend,
            # `is not None` so the explicit 0 of the first record is stored instead of None.
            # A stored None would make `need_div > 0` raise TypeError on the payout date.
            'div_forecast': div_forecast if div_forecast is not None else last_nav['div_forecast'],
            'div_acc': dividend_acc,
            'rebalance': True,
            'portfolios': {
                'weight': weight,
                'share': share,
            },
            'fund_av': fund_av,
            'nav': nav,
            'port_div': 0,
            'asset_nav': asset_nav,
        })

    def is_first_workday(self, day):
        """True when ``day`` is the first workday of its month (day-of-month comparison)."""
        first_day = date(day.year, day.month, 1)
        first_work_day = first_day if is_workday(first_day) else next_workday(first_day)
        return day.day == first_work_day.day

    def no_rebalance(self, day, risk: PortfoliosRisk, last_nav):
        """Insert a non-rebalance record for ``day``, paying out the forecast on the payout date.

        On the payout date, the forecast amount from the previous record is converted
        into shares to redeem. Funds from ``redeem-list`` are drained in order until
        the amount is covered. The amount is then recorded as ``port_div`` and added to
        ``div_acc``.
        """
        port_div = 0
        dividend_acc = last_nav['div_acc']
        share = {int(x): y for x, y in json.loads(last_nav['portfolios'])['share'].items()}
        navs, div_per_share = self.get_navs_and_div(fund_ids=tuple(share), day=day)
        # Amount forecast at the start of the month. Older records may hold None.
        need_div = last_nav['div_forecast'] or 0
        if self.is_dividend_date(day) and need_div > 0:
            funds = self._datum.get_datums(type=DatumType.FUND, ticker=self._config['redeem-list'])
            for fund in funds:
                fund_id = fund['id']
                if fund_id not in share:
                    continue
                held_value = share[fund_id] * navs[fund_id]
                if held_value <= need_div:
                    # Position cannot cover the remainder: redeem it fully. Subtract the
                    # value actually redeemed before zeroing, so the next fund only
                    # covers what is still missing.
                    share[fund_id] = 0
                    need_div -= held_value
                else:
                    # Partial redemption covers the rest.
                    share[fund_id] = (held_value - need_div) / navs[fund_id]
                    break
            port_div = last_nav['div_forecast']
            dividend_acc = dividend_acc + port_div
        fund_av = round(sum([navs[x] * y for x, y in share.items()]), 4)
        weight = {x: round(y * navs[x] / fund_av, 2) for x, y in share.items()}
        weight = format_weight(weight)
        fund_dividend = last_nav['fund_div'] + sum(
            share[k] * div_per_share[k] for k in share if k in div_per_share)
        asset_nav = fund_av + fund_dividend
        nav = last_nav['nav'] * asset_nav / last_nav['asset_nav']
        div_forecast = last_nav['div_forecast']
        if self.is_first_workday(day):
            div_forecast = asset_nav * self.month_dividend
        rhp.insert({
            'date': day,
            'risk': risk,
            'dividend': 0,
            'fund_div': fund_dividend,
            'div_forecast': div_forecast,
            'div_acc': dividend_acc,
            'signal_id': last_nav['signal_id'],
            'rebalance': False,
            'portfolios': {
                'weight': weight,
                'share': share,
            },
            'fund_av': fund_av,
            'nav': nav,
            'port_div': port_div,
            'asset_nav': asset_nav,
        })


@component(bean_name='hold-report')
class DivHoldReportor(RoboReportor):
    """Report of the portfolio's hold/net-value records ('投组净值')."""

    @property
    def report_name(self) -> str:
        return '投组净值'

    def load_report(self, max_date=None, min_date=None) -> List[dict]:
        """Return hold records between ``min_date`` and ``max_date`` as report rows.

        :param max_date: upper date bound. Defaults to ``dt.today()`` evaluated at call
            time; a default in the signature would be frozen at import time and go
            stale in a long-running process.
        :param min_date: lower date bound passed through to ``rhp.get_list``.
        :return: list of dicts with keys date, signal_type, fund_av, fund_div, cash,
            real_av, port_div, acc_av and nav. Empty list when there are no records.
        """
        if max_date is None:
            max_date = dt.today()
        holds = pd.DataFrame(rhp.get_list(max_date=max_date, min_date=min_date))
        if holds.empty:
            return []
        holds['signal_type'] = 'INIT'
        # Cash = dividend not yet paid out + accumulated fund dividends.
        holds['cash'] = holds['dividend'] + holds['fund_div']
        holds['real_av'] = holds['asset_nav']
        # NOTE(review): 'acc_av' is not written by the rhp.insert calls in this file;
        # presumably the DAO/table provides it. Confirm, otherwise this raises KeyError.
        holds = holds[
            ['date', 'signal_type', 'fund_av', 'fund_div', 'cash', 'real_av', 'port_div', 'acc_av', 'nav']]
        return holds.to_dict('records')


@component(bean_name='daily-hold-report')
class DailyHoldReportor(RoboReportor):
    """Report of each fund held on a single day ('每日持仓信息')."""

    @autowired
    def __init__(self, datum: Datum = None):
        self._datum = datum

    @property
    def report_name(self) -> str:
        return '每日持仓信息'

    def load_report(self, max_date=None, min_date=None) -> List[dict]:
        """Return one row per fund held on ``max_date``.

        :param max_date: the day to report (datetime or date). Defaults to the previous
            workday, evaluated at call time rather than frozen at import time.
        :param min_date: lower date bound passed through to ``rhp.get_list``.
        :return: list of dicts with keys lipper_id, asset_ids, name, weight, risk, date,
            rebalance_type and rebalance_date. Empty when there is no record for
            ``max_date``.
        """
        if max_date is None:
            max_date = prev_workday(dt.today())
        report_day = max_date.date() if isinstance(max_date, dt) else max_date
        holds = pd.DataFrame(rhp.get_list(max_date=max_date, min_date=min_date))
        # An empty result has no 'date' column, so return before filtering on it.
        if holds.empty:
            return []
        # Keep only the records of the reported day. copy() so the column assignments
        # below modify a new frame rather than a slice of the original.
        holds = holds[holds['date'].dt.date == report_day].copy()
        if holds.empty:
            return []
        # NOTE(review): the latest rebalance is looked up without a risk filter. With
        # several risk levels, rebalance_date may come from another risk; confirm.
        portfolio = rhp.get_last_one(max_date=max_date, rebalance=True)
        datum_ids = reduce(lambda x, y: x | y,
                           holds['portfolios'].apply(lambda x: set(json.loads(x)['weight'].keys())))
        datums = pd.DataFrame(self._datum.get_datums(type=DatumType.FUND, datum_ids=datum_ids))
        datums.set_index('id', inplace=True)

        # Same value for every row: plain scalar assignment.
        holds['rebalance_type'] = PortfoliosType.NORMAL.name
        holds['rebalance_date'] = prev_workday(portfolio['date'])
        holds['risk'] = holds.apply(lambda row: PortfoliosRisk(row['risk']).name, axis=1)
        # Turn each record's weights into a list of (fund_id, weight) pairs, then
        # explode to one row per fund.
        holds['portfolios'] = holds.apply(lambda row: [x for x in json.loads(row['portfolios'])['weight'].items()],
                                          axis=1)
        holds = holds.explode('portfolios', ignore_index=True)
        holds['weight'] = holds.apply(lambda row: format(row['portfolios'][1], '.0%'), axis=1)
        holds['asset_ids'] = holds.apply(lambda row: datums.loc[int(row['portfolios'][0])]['ftTicker'], axis=1)
        holds['name'] = holds.apply(lambda row: datums.loc[int(row['portfolios'][0])]['chineseName'], axis=1)
        holds['lipper_id'] = holds.apply(lambda row: datums.loc[int(row['portfolios'][0])]['lipperKey'], axis=1)
        holds = holds[
            ['lipper_id', 'asset_ids', 'name', 'weight', 'risk', 'date', 'rebalance_type', 'rebalance_date']]
        return holds.to_dict('records')