import json
from abc import ABC
from datetime import datetime as dt
from datetime import timedelta
from functools import reduce
from typing import List

import pandas as pd
from py_jftech import component, autowired, get_config, prev_workday
from py_jftech import is_workday

from api import PortfoliosBuilder
from api import (
    PortfoliosRisk, RebalanceSignal, SignalType, PortfoliosType, PortfoliosHolder,
    RoboReportor, Datum, DatumType
)
from rebalance.dao import robo_rebalance_signal as rrs


@component(bean_name='base-signal')
class BaseRebalanceSignal(RebalanceSignal, ABC):
    """Default (``SignalType.NORMAL``) rebalance signal.

    Signals are persisted through ``robo_rebalance_signal``. Whether a new
    signal is due is driven by the ``portfolios.holder`` configuration
    (``warehouse-frequency`` / ``warehouse-transfer-date``).
    """

    @autowired
    def __init__(self, builder: PortfoliosBuilder = None):
        self._builder = builder

    def get_signal(self, day, risk: PortfoliosRisk):
        """Return the signal of this class's type for ``risk`` on ``day``.

        An already-stored signal is returned as-is. Otherwise, when
        :meth:`need_rebalance` reports that a rebalance is due, a portfolio is
        built with the injected builder and inserted as an effective signal.
        The freshly stored row is then returned.

        :return: the signal record, or ``None`` when no rebalance is due.
        """
        signal = rrs.get_one(type=self.signal_type, risk=risk, date=day)
        if signal:
            return signal
        if not self.need_rebalance(day, risk):
            return None
        portfolio_type = self.portfolio_type
        portfolio = self._builder.get_portfolios(day, risk, portfolio_type)
        # Named signal_id (not id) to avoid shadowing the builtin.
        signal_id = rrs.insert({
            'date': day,
            'type': self.signal_type,
            'risk': risk,
            'portfolio_type': portfolio_type,
            'portfolio': portfolio,
            'effective': 1
        })
        return rrs.get_by_id(signal_id)

    def need_rebalance(self, day, risk: PortfoliosRisk) -> bool:
        """Decide whether a scheduled rebalance is due for ``risk`` on ``day``.

        The decision is based on the last ``NORMAL`` signal returned by
        ``rrs.get_last_one`` for this risk.

        - If there is no such signal, ``day`` is the initial date and a build is due.
        - Otherwise the due date is derived as follows. Move the last signal's
          date to the configured transfer day of its month. Push it one month
          further when the signal fell after the transfer day and effective FT3
          signals exist. Add ``warehouse-frequency`` months, subtract one day,
          and step back with ``prev_workday`` if the result is not a workday.

        :return: ``True`` when ``day`` equals that due date or equals the last
            signal's own date, else ``False``.
        """
        # NOTE(review): the signal type is fixed to NORMAL here, not self.signal_type;
        # effective=None presumably disables the effective filter — confirm.
        signal = rrs.get_last_one(day, risk, SignalType.NORMAL, effective=None)
        if not signal:
            # No previous record: treat the given date as the initial date and build.
            return True
        holder_config = get_config('portfolios')['holder']
        frequency = holder_config['warehouse-frequency']
        transfer_day = holder_config['warehouse-transfer-date']
        # NOTE(review): replace() raises ValueError when transfer_day exceeds the
        # number of days in the signal's month — confirm the configured value is safe.
        date = pd.to_datetime(signal['date'].replace(day=transfer_day))
        # The last signal fell after the transfer day, i.e. a cross-month case.
        if signal['date'].day > transfer_day:
            # NOTE(review): counts FT3 signals rather than the ``risk`` argument —
            # confirm this is intentional.
            if rrs.get_count(risk=PortfoliosRisk.FT3, effective=True) > 0:
                date = date + pd.DateOffset(months=1)
        date = date + pd.DateOffset(months=frequency)
        date = date - timedelta(days=1)
        # Use the end-of-period date if it is a workday, otherwise the previous workday.
        date = date if is_workday(date) else prev_workday(date)
        return bool(date == day or signal['date'] == day)

    @property
    def portfolio_type(self):
        """Portfolio type associated with :attr:`signal_type` (its ``p_type``)."""
        return self.signal_type.p_type

    @property
    def signal_type(self) -> SignalType:
        """Signal type produced by this class; ``SignalType.NORMAL`` here."""
        return SignalType.NORMAL

    def get_last_signal(self, day, risk: PortfoliosRisk):
        """Return the latest effective signal for ``risk`` with ``max_date=day``."""
        return rrs.get_last_one(max_date=day, risk=risk, effective=True)

    def clear(self, min_date=None, risk: PortfoliosRisk = None):
        """Delete stored signals, forwarding ``min_date`` and ``risk`` as filters."""
        rrs.delete(min_date=min_date, risk=risk)


@component(bean_name='signal-report')
class SignalReportor(RoboReportor):
    """Report of effective rebalance signals that have a rebalance date, one row per fund."""

    @autowired
    def __init__(self, hold: PortfoliosHolder = None, datum: Datum = None):
        self._hold = hold
        self._datum = datum

    @property
    def report_name(self) -> str:
        return '調倉信號'

    def load_report(self, max_date=None, min_date=None) -> List[dict]:
        """Build one record per (signal, fund) for effective signals in the window.

        :param max_date: upper bound passed to ``rrs.get_list``. It defaults to
            ``dt.today()`` evaluated on each call. Previously the default was
            computed once at import time and went stale in long-running processes.
        :param min_date: lower bound; it is passed through ``prev_workday`` first.
        :return: dicts with ``risk``, ``type``, ``signal_date``,
            ``rebalance_date``, ``portfolio_type``, ``ft_ticker``,
            ``bloomberg_ticker``, ``fund_name`` and ``weight``. Signals without
            a rebalance date are skipped.
        """
        if max_date is None:
            max_date = dt.today()
        # Fund datums keyed by str(id), because the portfolio JSON keys are strings.
        datums = {str(x['id']): x for x in self._datum.get_datums(type=DatumType.FUND, exclude=False)}
        result = []
        # NOTE(review): with the default min_date=None this calls prev_workday(None);
        # confirm py_jftech handles None as intended.
        for signal in rrs.get_list(max_date=max_date, min_date=prev_workday(min_date), effective=True):
            rebalance_date = self._hold.get_rebalance_date_by_signal(signal['id'])
            if not rebalance_date:
                continue
            # Per-signal fields are resolved once and shared by every fund row.
            common = {
                'risk': PortfoliosRisk(signal['risk']).name,
                'type': SignalType(signal['type']).name,
                'signal_date': signal['date'],
                'rebalance_date': rebalance_date,
                'portfolio_type': PortfoliosType(signal['portfolio_type']).name,
            }
            for fund_id, weight in json.loads(signal['portfolio']).items():
                fund = datums[fund_id]
                result.append({
                    **common,
                    'ft_ticker': fund['ftTicker'],
                    'bloomberg_ticker': fund['bloombergTicker'],
                    'fund_name': fund['chineseName'],
                    'weight': weight,
                })
        return result


@component(bean_name='daily-signal-report')
class DailySignalReportor(RoboReportor):
    """Report listing the holdings of the most recent rebalance signal, one row per fund."""

    @autowired
    def __init__(self, hold: PortfoliosHolder = None, datum: Datum = None):
        self._hold = hold
        self._datum = datum

    @property
    def report_name(self) -> str:
        return '每月調倉信號'

    def load_report(self, max_date=None, min_date=None) -> List[dict]:
        """Build per-fund rows for the last signal returned by ``rrs.get_list``.

        :param max_date: upper bound passed to ``rrs.get_list``. It defaults to
            ``prev_workday(dt.today())`` evaluated on each call. Previously the
            default was computed once at import time and went stale.
        :param min_date: lower bound passed to ``rrs.get_list``.
        :return: dicts with ``lipper_id``, ``asset_ids``, ``name``, ``weight``
            (formatted as a whole percentage), ``risk``, ``date`` and
            ``rebalance_type``. The list is empty when no signal is found.
        """
        if max_date is None:
            max_date = prev_workday(dt.today())
        # Only the last row is reported; presumably get_list returns signals in
        # date order — confirm. copy() avoids writing into a slice of the frame.
        signals = pd.DataFrame(rrs.get_list(max_date=max_date, min_date=min_date)).tail(1).copy()
        if signals.empty:
            return []

        # Parse each portfolio JSON ({fund_id: weight}) once.
        portfolios = signals['portfolio'].map(json.loads)
        datum_ids = set().union(*(p.keys() for p in portfolios))
        datums = pd.DataFrame(self._datum.get_datums(type=DatumType.FUND, datum_ids=datum_ids))
        datums.set_index('id', inplace=True)

        signals['risk'] = signals['risk'].map(lambda v: PortfoliosRisk(v).name)
        signals['rebalance_type'] = signals['type'].map(lambda v: SignalType(v).name)
        signals['portfolio_type'] = signals['portfolio_type'].map(lambda v: PortfoliosType(v).name)
        # One (fund_id, weight) pair per row after exploding.
        signals['portfolio'] = portfolios.map(lambda p: list(p.items()))
        signals = signals.explode('portfolio', ignore_index=True)

        items = signals['portfolio']
        # JSON keys are strings; the datum index is looked up with int ids.
        fund_ids = items.map(lambda item: int(item[0]))
        signals['weight'] = items.map(lambda item: format(item[1], '.0%'))
        signals['asset_ids'] = fund_ids.map(lambda fid: datums.loc[fid]['ftTicker'])
        signals['name'] = fund_ids.map(lambda fid: datums.loc[fid]['chineseName'])
        signals['lipper_id'] = fund_ids.map(lambda fid: datums.loc[fid]['lipperKey'])

        signals = signals[['lipper_id', 'asset_ids', 'name', 'weight', 'risk', 'date', 'rebalance_type']]
        return signals.to_dict('records')
