import json
import os
import sys
from logging import DEBUG

import pandas as pd
from dateutil.relativedelta import relativedelta
from numpy import NAN
from pyomo.environ import *

from api import PortfoliosBuilder, PortfoliosRisk, AssetPool, Navs, PortfoliosType, Datum, SolveType
from framework import component, autowired, get_config, format_date, get_logger
from portfolios.dao import robo_mpt_portfolios as rmp

logger = get_logger(__name__)


def create_solver():
    """Return a Pyomo ``Bonmin`` solver backed by the binary shipped next to this module.

    The executable is chosen from ``sys.platform``: ``bonmin.exe`` when the
    platform name starts with ``'win'``, ``bonmin_linux`` on ``'linux'``, and
    ``bonmin_mac`` for every other platform.
    """
    platform = sys.platform
    if platform.startswith('win'):
        binary_name = 'bonmin.exe'
    elif platform == 'linux':
        binary_name = 'bonmin_linux'
    else:
        binary_name = 'bonmin_mac'
    binary_path = os.path.join(os.path.dirname(__file__), binary_name)
    return SolverFactory('Bonmin', executable=binary_path)


class MptSolver:
    """Build and solve MPT / POEM portfolio models with Pyomo and the Bonmin MINLP solver.

    Typical use: construct the solver, call :meth:`reset_navs` to load the NAV
    history for a date, then call :meth:`solve_max_rtn`, :meth:`solve_min_rtn`
    and finally :meth:`solve_mpt` or :meth:`solve_poem`.

    Every model starts from :meth:`create_model`: weights sum to 1, the number
    of selected assets is bounded by ``asset-count``, and each selected weight
    lies in ``[mpt.low-weight, mpt.high-weight]``. Settings are read through
    :meth:`get_config`.
    """

    @autowired
    def __init__(self, risk: PortfoliosRisk, type: PortfoliosType, assets: AssetPool = None, navs: Navs = None,
                 datum: Datum = None):
        """Create a solver for one risk level and portfolio type.

        :param risk: risk level; dict-valued config entries are narrowed to ``ft<risk.value>``.
        :param type: portfolio type; non-NORMAL types read their own config section first.
            A falsy value falls back to ``PortfoliosType.NORMAL``.
        :param assets: asset pool queried by :meth:`reset_navs` (presumably injected by ``@autowired``).
        :param navs: NAV service queried by :meth:`reset_navs`.
        :param datum: fund datum service queried by :meth:`reset_navs`.
        """
        # NAV DataFrame (dates x fund ids) filled by reset_navs(); exposed read-only via `navs`.
        self.__navs = None
        self.risk = risk
        self.type = type or PortfoliosType.NORMAL
        self._assets = assets
        # The NAV *service*, not the DataFrame stored in `__navs`.
        self._navs = navs
        self._datum = datum
        self._config = get_config(__name__)
        self._solver = create_solver()
        self._solver.options['tol'] = float(self.get_config('tol') or 1E-10)

    @property
    def navs(self):
        """NAV DataFrame loaded by :meth:`reset_navs` (``None`` before the first call)."""
        return self.__navs

    @property
    def rtn_matrix(self):
        """Per-asset return over ``matrix-rtn-days`` rows, with incomplete rows dropped."""
        result = self.navs / self.navs.shift(self.get_config('matrix-rtn-days')) - 1
        result.dropna(inplace=True)
        return result

    @property
    def rtn_annualized(self):
        """Mean of :attr:`rtn_matrix` per asset scaled by 12, as a list ordered like ``navs.columns``."""
        # NOTE(review): the x12 factor presumably assumes `matrix-rtn-days` spans about one month -- confirm.
        return list(self.rtn_matrix.mean() * 12)

    @property
    def sigma(self):
        """Covariance matrix of one-row returns scaled by 252 (presumably annualising daily NAVs)."""
        rtn = (self.navs / self.navs.shift(1) - 1)[1:]
        return rtn.cov() * 252

    @property
    def rtn_history(self):
        """:attr:`rtn_matrix` scaled by 12 as a 2-D numpy array (rows = observations, columns = assets)."""
        result = self.rtn_matrix * 12
        return result.values

    @property
    def beta(self):
        """CVaR tail fraction from ``mpt.cvar-beta``."""
        return self.get_config('mpt.cvar-beta')

    @property
    def k_beta(self):
        """Number of worst scenarios averaged for CVaR, roughly ``ceil(len(rtn_history) * beta)``."""
        return round(len(self.rtn_history) * self.beta + 0.499999)

    @property
    def pct_value(self):
        """Quantile from ``mpt.quantile`` used to interpolate targets between the extreme solutions."""
        return self.get_config('mpt.quantile')

    def _return_expr(self, model):
        """Pyomo expression for the annualised portfolio return ``sum_i w[i] * rtn_annualized[i]``."""
        # Evaluate the property once; indexing it inside the generator would rebuild it for every asset.
        rtn = self.rtn_annualized
        return sum(model.w[i] * rtn[i] for i in model.indices)

    def _variance_expr(self, model):
        """Pyomo expression for the portfolio variance ``w' * sigma * w``."""
        # Compute the covariance matrix once instead of once per (i, j) pair.
        sigma = self.sigma.values
        return sum(model.w[i] * model.w[j] * sigma[i, j] for i in model.indices for j in model.indices)

    def solve_max_rtn(self):
        """Solve for the maximum-return portfolio.

        :return: ``(max_rtn, max_var, minCVaR_whenMaxR)`` -- return, variance and CVaR of that portfolio.
        """
        model = self.create_model()
        model.objective = Objective(expr=self._return_expr(model), sense=maximize)
        self._solver.solve(model)
        self.debug_solve_result(model)
        max_rtn = self.calc_port_rtn(model)
        max_var = self.calc_port_var(model)
        minCVaR_whenMaxR = self.calc_port_cvar(model)
        logger.debug({
            'max_rtn': max_rtn,
            'max_var': max_var,
            'minCVaR_whenMaxR': minCVaR_whenMaxR,
        })
        return max_rtn, max_var, minCVaR_whenMaxR

    def solve_min_rtn(self):
        """Solve for the minimum-variance portfolio.

        :return: ``(min_rtn, min_var, maxCVaR_whenMinV)`` -- return, variance and CVaR of that portfolio.
        """
        model = self.create_model()
        model.objective = Objective(expr=self._variance_expr(model), sense=minimize)
        self._solver.solve(model)
        self.debug_solve_result(model)
        min_rtn = self.calc_port_rtn(model)
        min_var = self.calc_port_var(model)
        maxCVaR_whenMinV = self.calc_port_cvar(model)
        logger.debug({
            'min_rtn': min_rtn,
            'min_var': min_var,
            'maxCVaR_whenMinV': maxCVaR_whenMinV,
        })
        return min_rtn, min_var, maxCVaR_whenMinV

    def solve_mpt(self, min_rtn, max_rtn):
        """Minimise variance subject to ``return >= min_rtn + pct_value * (max_rtn - min_rtn)``.

        :return: ``(weights, cvar)`` where ``weights`` maps fund id -> weight, or
            ``(None, None)`` when the solver reports the model infeasible.
        """
        logger.debug(f'...... ...... ...... ...... ...... ...... ...... ...... '
                     f'MPT ... sub risk : pct_value = {self.pct_value}')
        big_y = min_rtn + self.pct_value * (max_rtn - min_rtn)
        logger.debug(f'big_Y = target_Return = {big_y}')
        model = self.create_model()
        model.cons_rtn = Constraint(expr=self._return_expr(model) >= big_y)
        model.objective = Objective(expr=self._variance_expr(model), sense=minimize)
        result = self._solver.solve(model)
        if result.solver.termination_condition == TerminationCondition.infeasible:
            logger.debug('...... MPT: Infeasible Optimization Problem.')
            return None, None
        logger.debug('...... MPT: Has solution.')
        self.debug_solve_result(model)
        return self.calc_port_weight(model), self.calc_port_cvar(model)

    def solve_poem(self, min_rtn, max_rtn, base_cvar, max_cvar):
        """Find a portfolio meeting both a return target and a CVaR target.

        Return target: ``min_rtn + pct_value * (max_rtn - min_rtn)``.
        CVaR target: ``base_cvar + (max_cvar - base_cvar) * poem.cvar-scale-factor * pct_value``.
        No objective is set, so the model is solved as a feasibility problem.

        :return: ``(weights, cvar)``, or ``(None, None)`` when the solver reports infeasibility.
        """
        history = self.rtn_history  # evaluated once; the constraint rule below indexes it K x n times
        k_history = len(history)
        k_beta = self.k_beta
        quantile = self.pct_value
        logger.debug(f'...... ...... ...... ...... ...... ...... ...... ...... '
                     f'POEM With CVaR constraints ... sub risk : pct_value = {quantile}')
        big_y = min_rtn + quantile * (max_rtn - min_rtn)
        small_y = base_cvar + (max_cvar - base_cvar) * self.get_config('poem.cvar-scale-factor') * quantile
        logger.debug(f'big_Y = target_Return = {big_y} | small_y = target_cvar = {small_y}')
        model = self.create_model()
        model.alpha = Var(domain=Reals)
        model.x = Var(range(k_history), domain=NonNegativeReals)
        # Linearised lower-tail CVaR: x[k] >= alpha - (portfolio return in scenario k), x[k] >= 0.
        model.cons_cvar_aux = Constraint(range(k_history), rule=lambda m, k: m.x[k] >= m.alpha - sum(
            m.w[i] * history[k][i] for i in m.indices))
        model.cons_rtn = Constraint(expr=self._return_expr(model) >= big_y)
        model.cons_cvar = Constraint(
            expr=model.alpha - (1 / k_beta) * sum(model.x[k] for k in range(k_history)) >= small_y)
        result = self._solver.solve(model)
        if result.solver.termination_condition == TerminationCondition.infeasible:
            logger.debug('...... POEM: Infeasible Optimization Problem.')
            return None, None
        logger.debug('...... POEM: Has solution.')
        self.debug_solve_result(model)
        return self.calc_port_weight(model), self.calc_port_cvar(model)

    def calc_port_weight(self, model):
        """Return ``{fund_id: weight}`` for assets with a non-zero ``w * z``, normalised by :meth:`format_weight`."""
        raw = pd.Series([model.w[i].value * model.z[i].value for i in model.indices],
                        index=self.navs.columns, dtype=float)
        # Keep only selected assets: drop zero (and missing) weights.
        selected = raw[raw.notna() & (raw != 0)]
        return self.format_weight(selected).to_dict()

    @staticmethod
    def format_weight(weight_series):
        """Round weights to 2 decimals so that they sum to 1 at that precision.

        A rounding shortfall is added to the smallest positive (pre-rounding)
        weight. An excess is taken from the largest weight.

        :param weight_series: pandas Series of weights; NaN is treated as 0.
        :return: float Series with the same index.
        """
        weight_series = weight_series.fillna(0)
        positive = weight_series[weight_series > 0]
        if positive.empty:
            # Nothing to rebalance (idxmin() would raise on an empty selection).
            return weight_series.astype(float)
        min_idx = positive.idxmin()
        max_idx = weight_series.idxmax()
        weight_series = weight_series.apply(lambda x: round(x, 2))
        # Round the residual as well, so float noise in the sum (e.g. 0.9999999999999999)
        # does not leak long tails such as 0.30000000000000004 into the weights.
        residual = round(1 - float(weight_series.sum()), 2)
        if residual > 0:
            weight_series.loc[min_idx] = round(float(weight_series.loc[min_idx]) + residual, 2)
        elif residual < 0:
            weight_series.loc[max_idx] = round(float(weight_series.loc[max_idx]) + residual, 2)
        return weight_series.astype(float)

    def calc_port_rtn(self, model):
        """Annualised return of the solved portfolio (uses ``w`` only)."""
        rtn = self.rtn_annualized
        return sum(model.w[i].value * rtn[i] for i in model.indices)

    def calc_port_var(self, model):
        """Variance ``w' * sigma * w`` of the solved portfolio."""
        sigma = self.sigma.values
        return sum(model.w[i].value * model.w[j].value * sigma[i, j] for i in model.indices for j in
                   model.indices)

    def calc_port_cvar(self, model):
        """Mean of the ``k_beta`` lowest scenario returns of the solved portfolio (weights ``w * z``)."""
        history = self.rtn_history
        k_beta = self.k_beta
        weights = [model.w[i].value * model.z[i].value for i in model.indices]
        port_r_hist = sorted(sum(weights[i] * row[i] for i in model.indices) for row in history)
        return sum(port_r_hist[0: k_beta]) / k_beta

    def create_model(self):
        """Create the base mixed-integer model shared by every solve.

        Variables: ``w[i] >= 0`` (weight) and binary ``z[i]`` (asset ``i`` selected).
        Constraints: weights sum to 1; ``asset-count`` bounds the number of
        selected assets (an int means exactly that many, a list ``[min, max]``
        a range); a selected asset's weight lies in ``[low_weight, high_weight]``
        and an unselected asset's weight is forced to 0.
        """
        count = self.get_config('asset-count')
        min_count = count[0] if isinstance(count, list) else count
        max_count = count[1] if isinstance(count, list) else count

        low_weight = self.get_config('mpt.low-weight')
        high_weight = self.get_config('mpt.high-weight')
        if isinstance(high_weight, list):
            # A list is indexed by the achievable minimum asset count (capped by the list length).
            high_weight = high_weight[min(len(self.navs.columns), min_count, len(high_weight)) - 1]

        model = ConcreteModel()

        model.indices = range(0, len(self.navs.columns))
        model.w = Var(model.indices, domain=NonNegativeReals)
        model.z = Var(model.indices, domain=Binary)
        model.cons_sum_weight = Constraint(expr=sum(model.w[i] for i in model.indices) == 1)
        model.cons_num_asset = Constraint(
            expr=inequality(min_count, sum(model.z[i] for i in model.indices), max_count, strict=False))
        model.cons_bounds_low = Constraint(model.indices, rule=lambda m, i: m.z[i] * low_weight <= m.w[i])
        model.cons_bounds_up = Constraint(model.indices, rule=lambda m, i: m.z[i] * high_weight >= m.w[i])
        return model

    def reset_navs(self, day):
        """Load and clean the NAV history ending at ``day`` into :attr:`navs`.

        Takes the asset pool for ``day``, drops funds whose datum ``assetType``
        is in ``navs.exclude-asset-type``, and fetches NAVs for the preceding
        ``navs.months`` months. It then pivots them to dates x fund ids and drops
        funds with at least ``navs.max-nan.asset`` missing values. Next it drops
        dates whose missing fraction is at least ``navs.max-nan.day``, and finally
        forward-fills the remaining gaps.
        """
        asset_ids = self._assets.get_pool(day)
        asset_risk = self.get_config('navs.risk')
        datum = self._datum.get_fund_datums(fund_ids=asset_ids, risk=asset_risk)
        exclude = self.get_config('navs.exclude-asset-type') or []
        allowed_ids = {x['id'] for x in datum if x['assetType'] not in exclude}
        asset_ids = list(set(asset_ids) & allowed_ids)

        min_date = day - relativedelta(months=self.get_config('navs.months'))
        navs = pd.DataFrame(self._navs.get_fund_navs(fund_ids=asset_ids, max_date=day, min_date=min_date))
        navs['nav_date'] = pd.to_datetime(navs['nav_date'])
        navs = navs.pivot_table(index='nav_date', columns='fund_id', values='nav_cal')
        navs = navs.sort_index()

        max_nan_asset = self.get_config('navs.max-nan.asset')
        nan_per_asset = navs.isna().sum()
        navs = navs.drop(columns=nan_per_asset.index[nan_per_asset >= max_nan_asset])

        max_nan_day = self.get_config('navs.max-nan.day')
        nan_ratio_per_day = navs.isna().mean(axis=1)
        navs = navs.drop(index=nan_ratio_per_day.index[nan_ratio_per_day >= max_nan_day])

        # fillna(method='ffill') is deprecated since pandas 2.1; ffill() is the equivalent.
        self.__navs = navs.ffill()

    def get_config(self, name):
        """Resolve a dotted config key such as ``'mpt.cvar-beta'``.

        For non-NORMAL types the section named ``self.type.value`` is searched
        first, falling back to the top-level config when that yields ``None``.
        A non-empty dict result is narrowed to its ``ft<risk.value>`` entry.
        Returns ``None`` when the key is not found.
        """
        def load_config(config):
            for key in name.split('.'):
                if key in config:
                    config = config[key]
                else:
                    return None
            return config

        value = load_config(self._config[self.type.value] if self.type is not PortfoliosType.NORMAL else self._config)
        if value is None:
            value = load_config(self._config)
        return value[f'ft{self.risk.value}'] if value and isinstance(value, dict) else value

    def debug_solve_result(self, model):
        """Log the selected assets, their weights and the portfolio statistics at DEBUG level."""
        if not logger.isEnabledFor(DEBUG):
            return
        logger.debug('===============================')
        logger.debug('solution: id  |  w(id)')
        w_sum = 0
        for i in model.indices:
            selected = model.z[i].value
            # Tolerate a binary returned within solver tolerance of 1 rather than exactly 1.
            if selected is not None and round(selected) == 1:
                logger.debug(f'{self.navs.columns[i]}  |  {model.w[i].value}')
                w_sum += model.w[i].value
        logger.debug(f'w_sum = {w_sum}')
        logger.debug({
            'beta': self.beta,
            'kbeta': self.k_beta,
            'port_R': self.calc_port_rtn(model),
            # Previously logged calc_port_cvar here by mistake.
            'port_V': self.calc_port_var(model),
            'port_CVaR': self.calc_port_cvar(model)
        })
        logger.debug('-------------------------------')


@component(bean_name='mpt')
class MptPortfoliosBuilder(PortfoliosBuilder):
    """Portfolio builder that solves a mean-variance (MPT) model for every risk level.

    Results are cached through ``rmp``. The first request for a (day, type)
    builds and stores the portfolios for all risk levels.
    """

    @autowired
    def __init__(self, assets: AssetPool = None, navs: Navs = None, datum: Datum = None):
        self._assets = assets
        self._navs = navs
        self._datum = datum

    def get_portfolios(self, day, risk: PortfoliosRisk, type: PortfoliosType = PortfoliosType.NORMAL):
        """Return the portfolio for ``(day, type, risk)``, building and storing it first if needed.

        :return: the decoded ``portfolio`` JSON of the stored row, or ``None``
            when that row's solve type is ``SolveType.INFEASIBLE``.
        """
        stored = rmp.get_one(day, type, risk)
        if not stored:
            built, _ = self.build_portfolio(day, type)
            for built_risk, row in built.items():
                rmp.insert({**row, 'risk': built_risk, 'type': type, 'date': day})
            stored = rmp.get_one(day, type, risk)
        if SolveType(stored['solve']) is SolveType.INFEASIBLE:
            return None
        return json.loads(stored['portfolio'])

    def build_portfolio(self, day, type: PortfoliosType):
        """Solve the MPT model for every ``PortfoliosRisk`` on ``day``.

        :return: ``(result, detail)``. ``result`` maps each risk to a row dict
            (``solve``, plus ``portfolio`` JSON and ``cvar`` when feasible).
            ``detail`` maps each risk to the extreme-portfolio statistics.
        """
        result, detail = {}, {}
        for risk in PortfoliosRisk:
            logger.info(
                f"start build protfolio of type[{type.name}] and risk[{risk.name}] with date[{format_date(day)}]")
            solver = MptSolver(risk, type)
            solver.reset_navs(day)
            logger.debug({
                'Khist': len(solver.rtn_history),
                'beta': solver.get_config('mpt.cvar-beta'),
                'Kbeta': solver.k_beta,
            })
            max_rtn, max_var, cvar_at_max_rtn = solver.solve_max_rtn()
            min_rtn, min_var, cvar_at_min_var = solver.solve_min_rtn()
            portfolio, cvar = solver.solve_mpt(min_rtn, max_rtn)
            if portfolio:
                result[risk] = {'solve': SolveType.MPT, 'portfolio': json.dumps(portfolio), 'cvar': cvar}
            else:
                result[risk] = {'solve': SolveType.INFEASIBLE}
            detail[risk] = {
                'max_rtn': max_rtn,
                'max_var': max_var,
                'minCVaR_whenMaxR': cvar_at_max_rtn,
                'min_rtn': min_rtn,
                'min_var': min_var,
                'maxCVaR_whenMinV': cvar_at_min_var,
            }
        return result, detail


@component(bean_name='poem')
class PoemPortfoliosBuilder(MptPortfoliosBuilder):
    """MPT builder that additionally tries a CVaR-constrained (POEM) solve per risk level."""

    def build_portfolio(self, day, type: PortfoliosType):
        """Build MPT portfolios, then replace each feasible one with a POEM solution when one exists.

        Risk levels whose MPT solve was infeasible are left untouched. When the
        POEM solve finds no portfolio, the MPT row is kept. When it succeeds,
        the MPT CVaR is recorded as ``detail[risk]['mpt_cvar']``.
        """
        result, detail = super(PoemPortfoliosBuilder, self).build_portfolio(day, type)
        for risk in PortfoliosRisk:
            mpt_row = result[risk]
            if mpt_row['solve'] is not SolveType.INFEASIBLE:
                solver = MptSolver(risk, type)
                solver.reset_navs(day)
                risk_detail = detail[risk]
                mpt_cvar = mpt_row['cvar']
                portfolio, cvar = solver.solve_poem(risk_detail['min_rtn'], risk_detail['max_rtn'],
                                                    mpt_cvar, risk_detail['maxCVaR_whenMinV'])
                if portfolio:
                    result[risk] = {'solve': SolveType.POEM, 'portfolio': json.dumps(portfolio), 'cvar': cvar}
                    risk_detail['mpt_cvar'] = mpt_cvar
        return result, detail