from datetime import datetime as dt
from typing import List

import pandas as pd
from py_jftech import component, autowired
from empyrical import annual_return, annual_volatility, max_drawdown, sharpe_ratio

from api import RoboReportor, Navs


@component(bean_name='combo-report')
class ComboDatasReport(RoboReportor):
    """Report merging per-risk portfolio NAVs, benchmark data and SPX closes by date.

    The hold report supplies rows with ``date``/``risk``/``nav`` fields, the
    benchmark report supplies additional ``date``-keyed columns, and ``Navs``
    supplies SPX Index closing prices.
    """

    # NOTE(review): 'benckmark-report' looks misspelled, but it presumably has to
    # match the bean name the benchmark reportor is registered under — verify
    # before renaming it on either side.
    @autowired(names={'hold_reportor': 'hold-report', 'benchmark': 'benckmark-report'})
    def __init__(self, hold_reportor: RoboReportor = None, benchmark: RoboReportor = None, navs: Navs = None):
        """Store the injected collaborators.

        :param hold_reportor: report providing per-date, per-risk portfolio NAVs.
        :param benchmark: report providing benchmark rows keyed by ``date``.
        :param navs: data source used to fetch SPX Index closing prices.
        """
        self._hold_reportor = hold_reportor
        self._benchmark = benchmark
        self._navs = navs

    @property
    def report_name(self) -> str:
        """Display name of this report ("combined data")."""
        return '混合数据'

    def load_report(self, max_date=None, min_date=None) -> List[dict]:
        """Build the combined per-date table.

        :param max_date: upper date bound forwarded to every data source;
            ``None`` (the default) means "today", resolved at call time.
        :param min_date: lower date bound forwarded to every data source.
        :return: one dict per date present in the hold report, containing a
            ``date`` key, one NAV column per risk level, the benchmark columns
            and ``SPX``; missing values are forward-filled. Returns an empty
            list when the hold report has no rows.
        """
        # Resolve "today" per call: a ``dt.today()`` default would be evaluated
        # only once, when the module is imported.
        if max_date is None:
            max_date = dt.today()

        holds = pd.DataFrame(self._hold_reportor.load_report(max_date=max_date, min_date=min_date))
        if holds.empty:
            return []
        # One NAV column per risk level, indexed by date.
        datas = holds.pivot_table(index='date', columns='risk', values='nav')

        benchmark = pd.DataFrame(self._benchmark.load_report(max_date=max_date, min_date=min_date))
        # An empty report has no 'date' column, so set_index would raise; skip it.
        if not benchmark.empty:
            # Left join: only dates present in the hold report are kept.
            datas = datas.join(benchmark.set_index('date'))

        spx = pd.DataFrame(self._navs.get_index_close(ticker='SPX Index', min_date=min_date, max_date=max_date))
        if spx.empty:
            # Keep the 'SPX' column so every record has the same keys.
            datas['SPX'] = float('nan')
        else:
            spx = spx.pivot_table(index='date', columns='index_id', values='close')
            spx.columns = ['SPX']
            datas = datas.join(spx)

        # Carry the last known value forward over dates a source is missing;
        # ``fillna(method='ffill')`` is deprecated in recent pandas.
        return datas.ffill().reset_index().to_dict('records')