import os
import sys
from logging import DEBUG, getLogger

import pandas as pd
from dateutil.relativedelta import relativedelta
from numpy import NAN
from py_jftech import component, autowired, get_config
from pyomo.environ import *

from api import SolverFactory as Factory, PortfoliosRisk, PortfoliosType, AssetPool, Navs, Solver, Datum, DatumType
from portfolios.utils import format_weight

logger = getLogger(__name__)


def create_solver():
    """Create the solver object for the Bonmin executable located next to this module.

    The executable file name is derived from ``sys.platform``: a value starting
    with ``win`` selects ``bonmin.exe``, the exact value ``linux`` selects
    ``bonmin_linux`` and every other platform falls back to ``bonmin_mac``.

    :return: the object produced by ``SolverFactory('Bonmin', executable=...)``.
    """
    platform = sys.platform
    if platform.startswith('win'):
        binary_name = 'bonmin.exe'
    else:
        binary_name = 'bonmin_linux' if platform == 'linux' else 'bonmin_mac'
    # The binaries are resolved relative to this source file, not the working directory.
    binary_path = os.path.join(os.path.dirname(__file__), binary_name)
    return SolverFactory('Bonmin', executable=binary_path)


@component
class DefaultFactory(Factory):
    """Factory component that hands out :class:`DefaultSolver` instances."""

    def create_solver(self, risk: PortfoliosRisk, type: PortfoliosType = PortfoliosType.NORMAL) -> Solver:
        """Build a solver for the given risk level and portfolio type.

        :param risk: risk level forwarded to the solver.
        :param type: portfolio type forwarded to the solver (``PortfoliosType.NORMAL`` by default).
        :return: a freshly constructed :class:`DefaultSolver`.
        """
        # Arguments are passed positionally, exactly as the decorated constructor receives them today.
        solver = DefaultSolver(risk, type)
        return solver


class DefaultSolver(Solver):
    """Portfolio-weight solver that builds Pyomo models from a NAV table.

    Settings are read through :meth:`get_config`. The NAV table used by the
    return / covariance properties has to be supplied with :meth:`set_navs`
    before any ``solve_*`` or ``calc_*`` method is called.
    """

    @autowired
    def __init__(self, risk: PortfoliosRisk, type: PortfoliosType, assets: AssetPool = None, navs: Navs = None,
                 datum: Datum = None):
        """Store the collaborators, load this module's configuration and create the solver.

        :param risk: risk level; its ``value`` selects ``ft<value>`` entries in the configuration.
        :param type: portfolio type; a falsy value is replaced by ``PortfoliosType.NORMAL``.
        :param assets: asset pool service (presumably injected by ``@autowired`` when omitted - confirm).
        :param navs: NAV data service (presumably injected by ``@autowired`` when omitted - confirm).
        :param datum: datum service (presumably injected by ``@autowired`` when omitted - confirm).
        """
        self._category = None
        self._transfer_type = None
        # NAV table used for optimisation; distinct from ``self._navs``, which is the NAV data service.
        self.__navs = None
        self.risk = risk
        self.type = type or PortfoliosType.NORMAL
        self._assets = assets
        self._navs = navs
        self._datum = datum
        self._config = get_config(__name__)
        self._solver = create_solver()
        # Fall back to 1E-10 when no tolerance is configured.
        self._solver.options['tol'] = float(self.get_config('tol') or 1E-10)

    @property
    def navs(self):
        """NAV data most recently stored by :meth:`set_navs` or :meth:`reset_navs`."""
        return self.__navs

    @property
    def rtn_matrix(self):
        """``navs / navs.shift(days) - 1`` for the configured ``matrix-rtn-days``, rows with NaN removed."""
        days = self.get_config('matrix-rtn-days')
        rtn = self.navs / self.navs.shift(days) - 1
        return rtn.dropna()

    @property
    def rtn_annualized(self):
        """Column means of :attr:`rtn_matrix` multiplied by 12, as a list in column order."""
        return list(self.rtn_matrix.mean() * 12)

    @property
    def sigma(self):
        """Covariance matrix of the one-row returns of :attr:`navs`, multiplied by 252."""
        rtn = (self.navs / self.navs.shift(1) - 1)[1:]
        return rtn.cov() * 252

    @property
    def rtn_history(self):
        """:attr:`rtn_matrix` multiplied by 12, as a 2-D array (rows: dates, columns: assets)."""
        return (self.rtn_matrix * 12).values

    @property
    def beta(self):
        """Configured ``mpt.cvar-beta`` value."""
        return self.get_config('mpt.cvar-beta')

    @property
    def k_beta(self):
        """Number of history rows in the CVaR tail: ``len(rtn_history) * beta + 0.499999``, rounded."""
        return round(len(self.rtn_history) * self.beta + 0.499999)

    @property
    def quantile(self):
        """Configured ``mpt.quantile`` value."""
        return self.get_config('mpt.quantile')

    @property
    def category(self):
        """Category stored by :meth:`set_category`."""
        return self._category

    @property
    def transfer_type(self):
        """Configured ``normal-ratio`` value; re-read from the configuration on every access."""
        self._transfer_type = self.get_config("normal-ratio")
        return self._transfer_type

    def set_navs(self, navs):
        """Store the NAV table used by the return / covariance properties."""
        self.__navs = navs

    def set_category(self, category):
        """Store the category used by :meth:`get_weight`."""
        self._category = category

    def _rtn_expr(self, model):
        """Expression for the portfolio return: ``sum(w[i] * rtn_annualized[i])``."""
        # Read the property once; every access recomputes the whole return matrix.
        rtn = self.rtn_annualized
        return sum([model.w[i] * rtn[i] for i in model.indices])

    def _var_expr(self, model):
        """Expression for the portfolio variance: ``sum(w[i] * w[j] * sigma[i, j])``."""
        # Read the property once; every access recomputes the covariance matrix.
        cov = self.sigma.values
        return sum([model.w[i] * model.w[j] * cov[i, j] for i in model.indices for j in model.indices])

    def solve_max_rtn(self):
        """Solve the base model with the return as a maximisation objective.

        :return: tuple ``(return, variance, cvar)`` of the resulting portfolio.
        """
        model = self.create_model()
        model.objective = Objective(expr=self._rtn_expr(model), sense=maximize)
        self._solver.solve(model)
        self.debug_solve_result(model)
        max_rtn = self.calc_port_rtn(model)
        max_var = self.calc_port_var(model)
        minCVaR_whenMaxR = self.calc_port_cvar(model)
        logger.debug({
            'max_rtn': max_rtn,
            'max_var': max_var,
            'minCVaR_whenMaxR': minCVaR_whenMaxR,
        })
        return max_rtn, max_var, minCVaR_whenMaxR

    def solve_min_rtn(self):
        """Solve the base model with the variance as a minimisation objective.

        :return: tuple ``(return, variance, cvar)`` of the resulting portfolio.
        """
        model = self.create_model()
        model.objective = Objective(expr=self._var_expr(model), sense=minimize)
        self._solver.solve(model)
        self.debug_solve_result(model)
        min_rtn = self.calc_port_rtn(model)
        min_var = self.calc_port_var(model)
        maxCVaR_whenMinV = self.calc_port_cvar(model)
        logger.debug({
            'min_rtn': min_rtn,
            'min_var': min_var,
            'maxCVaR_whenMinV': maxCVaR_whenMinV,
        })
        return min_rtn, min_var, maxCVaR_whenMinV

    def solve_mpt(self, min_rtn, max_rtn):
        """Minimise the variance subject to a return floor between ``min_rtn`` and ``max_rtn``.

        The floor is ``min_rtn + quantile * (max_rtn - min_rtn)``.

        :return: ``(weights, cvar)``, or ``(None, None)`` when the solver reports infeasibility.
        """
        logger.debug(f'...... ...... ...... ...... ...... ...... ...... ...... '
                     f'MPT ... sub risk : pct_value = {self.quantile}')
        big_y = min_rtn + self.quantile * (max_rtn - min_rtn)
        logger.debug(f'big_Y = target_Return = {big_y}')
        model = self.create_model()
        model.cons_rtn = Constraint(expr=self._rtn_expr(model) >= big_y)
        model.objective = Objective(expr=self._var_expr(model), sense=minimize)
        result = self._solver.solve(model)
        if result.solver.termination_condition == TerminationCondition.infeasible:
            logger.debug('...... MPT: Infeasible Optimization Problem.')
            return None, None
        logger.debug('...... MPT: Has solution.')
        self.debug_solve_result(model)
        return self.calc_port_weight(model), self.calc_port_cvar(model)

    def solve_poem(self, min_rtn, max_rtn, base_cvar, max_cvar):
        """Solve the base model with a return floor and a CVaR floor.

        The return floor is ``min_rtn + quantile * (max_rtn - min_rtn)``; the CVaR floor is
        ``base_cvar + (max_cvar - base_cvar) * poem.cvar-scale-factor * quantile``.

        :return: ``(weights, cvar)``, or ``(None, None)`` when the solver reports infeasibility.
        """
        # Read the derived data once; the properties recompute the return matrix on every access.
        history = self.rtn_history
        k_history = len(history)
        k_beta = self.k_beta
        quantile = self.quantile
        logger.debug(f'...... ...... ...... ...... ...... ...... ...... ...... '
                     f'POEM With CVaR constraints ... sub risk : pct_value = {quantile}')
        big_y = min_rtn + quantile * (max_rtn - min_rtn)
        small_y = base_cvar + (max_cvar - base_cvar) * self.get_config('poem.cvar-scale-factor') * quantile
        logger.debug(f'big_Y = target_Return = {big_y} | small_y = target_cvar = {small_y}')
        model = self.create_model()
        model.alpha = Var(domain=Reals)
        model.x = Var(range(k_history), domain=NonNegativeReals)
        model.cons_cvar_aux = Constraint(range(k_history), rule=lambda m, k: m.x[k] >= m.alpha - sum(
            [m.w[i] * history[k][i] for i in m.indices]))
        model.cons_rtn = Constraint(expr=self._rtn_expr(model) >= big_y)
        model.cons_cvar = Constraint(
            expr=model.alpha - (1 / k_beta) * sum([model.x[k] for k in range(k_history)]) >= small_y)
        # NOTE(review): no objective is attached to this model; confirm a pure feasibility solve is intended.
        result = self._solver.solve(model)
        if result.solver.termination_condition == TerminationCondition.infeasible:
            logger.debug('...... POEM: Infeasible Optimization Problem.')
            return None, None
        logger.debug('...... POEM: Has solution.')
        self.debug_solve_result(model)
        return self.calc_port_weight(model), self.calc_port_cvar(model)

    def calc_port_weight(self, model):
        """Collect the non-zero ``w * z`` weights of a solved model, keyed by NAV column.

        The weights are passed through ``format_weight`` together with :meth:`get_weight`.

        :return: dict mapping NAV column label to formatted weight.
        """
        weight_list = [model.w[i].value * model.z[i].value for i in model.indices]
        df_w = pd.DataFrame(data=weight_list, index=self.navs.columns, columns=['weight'])
        # Zero weights are turned into NaN so that a single dropna removes them.
        df_w = df_w.replace(0, NAN).dropna(axis=0)
        df_w['weight'] = pd.Series(format_weight(dict(df_w['weight']), self.get_weight()))
        return df_w.to_dict()['weight']

    def calc_port_rtn(self, model):
        """Return ``sum(w[i] * rtn_annualized[i])`` using the solved variable values."""
        rtn = self.rtn_annualized
        return sum([model.w[i].value * rtn[i] for i in model.indices])

    def calc_port_var(self, model):
        """Return ``sum(w[i] * w[j] * sigma[i, j])`` using the solved variable values."""
        cov = self.sigma.values
        weights = {i: model.w[i].value for i in model.indices}
        return sum([weights[i] * weights[j] * cov[i, j] for i in model.indices for j in model.indices])

    def calc_port_cvar(self, model):
        """Return the mean of the ``k_beta`` smallest portfolio returns over :attr:`rtn_history`."""
        history = self.rtn_history
        k_beta = self.k_beta
        holdings = {i: model.w[i].value * model.z[i].value for i in model.indices}
        port_r_hist = sorted(sum([holdings[i] * row[i] for i in model.indices]) for row in history)
        return sum(port_r_hist[0: k_beta]) / k_beta

    def get_weight(self):
        """Return the first entry configured under ``normal-ratio`` for the current category."""
        # TODO: look up the configuration based on self.risk
        return self.transfer_type[self.category][0]

    def create_model(self):
        """Build the base model shared by all ``solve_*`` methods.

        Variables: ``w[i]`` (non-negative weight) and ``z[i]`` (binary) for each NAV column.
        Constraints: the weights sum to ``high_weight``; the sum of ``z`` lies between the
        configured asset counts; ``z[i] * low_weight <= w[i] <= z[i] * high_weight``.

        :return: a ``ConcreteModel`` without objective.
        """
        count = self.get_config('asset-count')
        # ``asset-count`` is either a single number or a ``[min, max]`` list.
        min_count, max_count = (count[0], count[1]) if isinstance(count, list) else (count, count)
        min_count = min(min_count, len(self.rtn_annualized))
        asset_total = len(self.navs.columns)

        low_weight = self.get_config('mpt.low-weight')
        high_weight = self.get_config('mpt.high-weight')
        if isinstance(high_weight, list):
            # A list is indexed by the smallest of asset total, minimum count and list length.
            high_weight = high_weight[min(asset_total, min_count, len(high_weight)) - 1]
        model = ConcreteModel()

        model.indices = range(0, asset_total)
        model.w = Var(model.indices, domain=NonNegativeReals)
        model.z = Var(model.indices, domain=Binary)
        model.cons_sum_weight = Constraint(expr=sum([model.w[i] for i in model.indices]) == high_weight)
        model.cons_num_asset = Constraint(
            expr=inequality(min_count, sum([model.z[i] for i in model.indices]), max_count, strict=False))
        model.cons_bounds_low = Constraint(model.indices, rule=lambda m, i: m.z[i] * low_weight <= m.w[i])
        model.cons_bounds_up = Constraint(model.indices, rule=lambda m, i: m.z[i] * high_weight >= m.w[i])
        return model

    def reset_navs(self, day):
        """Load and clean the NAV tables of the asset pool for ``day``, grouped by category.

        For each category the NAVs between ``day - navs.range`` and ``day`` are fetched, rows
        with ``day_of_week >= 5`` are removed, the data is pivoted to date x fund, columns and
        rows with too many NaN (``navs.max-nan.asset`` / ``navs.max-nan.day``) are dropped and
        the remaining gaps are forward filled (and backward filled if the first row has gaps).

        :param day: reference date; also the upper bound of the NAV window.
        :return: dict mapping category to the cleaned NAV ``DataFrame`` (also stored on the instance).
        """
        pool_ids = self._assets.get_pool(day)
        datum = self._datum.get_datums(type=DatumType.FUND, datum_ids=pool_ids)
        ids_by_category = {}
        for d in datum:
            ids_by_category.setdefault(d['category'], []).append(d['id'])

        navs_group = {}
        for category, fund_ids in ids_by_category.items():
            min_date = day - relativedelta(**self.get_config('navs.range'))
            navs = pd.DataFrame(self._navs.get_fund_navs(fund_ids=fund_ids, max_date=day, min_date=min_date))
            # Convert first: the ``.dt`` accessor below requires a datetime column.
            navs['nav_date'] = pd.to_datetime(navs['nav_date'])
            navs = navs[navs['nav_date'].dt.day_of_week < 5]
            navs = navs.pivot_table(index='nav_date', columns='fund_id', values='nav_cal')
            navs = navs.sort_index()

            # Drop funds whose NaN count reaches the configured limit.
            max_nan_asset = self.get_config('navs.max-nan.asset')
            nan_per_asset = navs.isna().sum()
            navs = navs.drop(columns=nan_per_asset.index[nan_per_asset >= max_nan_asset])

            # Drop dates whose share of NaN values reaches the configured limit.
            max_nan_day = self.get_config('navs.max-nan.day')
            nan_ratio_per_day = navs.isna().mean(axis=1)
            navs = navs.drop(index=nan_ratio_per_day.index[nan_ratio_per_day >= max_nan_day])

            navs = navs.ffill()
            # Leading gaps cannot be forward filled; guard against a table with no rows left.
            if len(navs) > 0 and navs.iloc[0].isna().sum() > 0:
                navs = navs.bfill()
            navs_group[category] = navs
        self.__navs = navs_group
        return navs_group

    def get_config(self, name):
        """Look up a dotted configuration key.

        For a non-``NORMAL`` portfolio type the section named after ``self.type.value`` is
        searched first and the top-level configuration is the fallback. When the value found is
        a non-empty dict containing ``ft<risk.value>``, that entry is returned instead.

        :param name: key path, segments separated by ``.``.
        :return: the configured value, or ``None`` when the path does not exist.
        """
        def lookup(config):
            for key in name.split('.'):
                if key not in config:
                    return None
                config = config[key]
            return config

        if self.type is PortfoliosType.NORMAL:
            value = lookup(self._config)
        else:
            value = lookup(self._config[self.type.value])
            if value is None:
                value = lookup(self._config)

        if value and isinstance(value, dict):
            risk_key = f'ft{self.risk.value}'
            if risk_key in value:
                return value[risk_key]
        return value

    def debug_solve_result(self, model):
        """Log the selected assets, their weights and the portfolio statistics at DEBUG level."""
        if not logger.isEnabledFor(DEBUG):
            return
        columns = self.navs.columns
        logger.debug('===============================')
        logger.debug('solution: id  |  w(id)')
        w_sum = 0
        for i in model.indices:
            if model.z[i].value == 1:
                logger.debug(f'{columns[i]}  |  {model.w[i].value}')
                w_sum += model.w[i].value
        logger.debug(f'w_sum = {w_sum}')
        logger.debug({
            'beta': self.beta,
            'kbeta': self.k_beta,
            'port_R': self.calc_port_rtn(model),
            # Variance, not CVaR: the previous code logged the CVaR under this key.
            'port_V': self.calc_port_var(model),
            'port_CVaR': self.calc_port_cvar(model)
        })
        logger.debug('-------------------------------')