import math
import os
import sys
from logging import DEBUG, getLogger

import numpy as np
import pandas as pd
from dateutil.relativedelta import relativedelta
from numpy import NAN
from py_jftech import component, autowired, get_config, filter_weekend
from pyomo.environ import *

from api import SolverFactory as Factory, PortfoliosRisk, PortfoliosType, AssetPool, Navs, Solver, Datum, DatumType
from portfolios.utils import format_weight

# Module-level logger; the solvers below emit their diagnostics at DEBUG level.
logger = getLogger(name=__name__)


def create_solver():
    """Create a pyomo ``Bonmin`` solver bound to the bundled, platform-specific executable.

    The executable is expected to live next to this module:
    ``bonmin.exe`` on Windows, ``bonmin_linux`` on Linux and ``bonmin_mac`` otherwise.
    """
    if sys.platform.startswith('win'):
        binary_name = 'bonmin.exe'
    elif sys.platform == 'linux':
        binary_name = 'bonmin_linux'
    else:
        binary_name = 'bonmin_mac'
    module_dir = os.path.dirname(__file__)
    binary_path = os.path.join(module_dir, binary_name)
    return SolverFactory('Bonmin', executable=binary_path)


@component
class DefaultFactory(Factory):
    """Solver factory that chooses the solver implementation from the ``model`` setting."""

    def __init__(self):
        self._config = get_config(__name__)

    @property
    def solver_model(self):
        """Upper-cased ``model`` setting, or ``None`` when it is absent or null."""
        if 'model' not in self._config:
            return None
        configured = self._config['model']
        return None if configured is None else configured.upper()

    def create_solver(self, risk: PortfoliosRisk = None, type: PortfoliosType = PortfoliosType.NORMAL) -> Solver:
        """Return a solver for the given risk level and portfolio type.

        ``ARC`` always yields an :class:`ARCSolver`; ``PRR`` yields a :class:`PRRSolver` only for
        ``PortfoliosRisk.FT3``; every other combination falls back to :class:`DefaultSolver`.
        """
        selected = self.solver_model
        if selected == 'ARC':
            return ARCSolver(type=type, risk=risk)
        if selected == 'PRR' and risk == PortfoliosRisk.FT3:
            return PRRSolver(type=type, risk=risk)
        return DefaultSolver(type=type, risk=risk)


class DefaultSolver(Solver):
    """Portfolio optimiser (MPT, POEM with a CVaR constraint, risk parity) built on pyomo + Bonmin.

    Every ``solve_*`` method builds a fresh mixed-integer model with :meth:`create_model`
    (non-negative weights ``w[i]`` and binary selection flags ``z[i]``, one per column of
    :attr:`navs`), solves it with the Bonmin executable from :func:`create_solver` and reads the
    results back from the model variables. :attr:`navs` must hold a NAV table (rows = dates,
    columns = fund ids) before solving; see :meth:`set_navs`.
    """

    @autowired
    def __init__(self, type: PortfoliosType, risk: PortfoliosRisk, assets: AssetPool = None, navs: Navs = None,
                 datum: Datum = None):
        self._category = None
        self._transfer_type = None
        self.__navs = None  # NAV data being optimised (see the ``navs`` property)
        self.risk = risk
        self.type = type or PortfoliosType.NORMAL
        self._assets = assets  # asset-pool service (``get_pool``)
        self._navs = navs  # NAV data service (``get_fund_navs``), not the NAV table itself
        self._datum = datum  # fund datum service (``get_datums``)
        self._config = get_config(__name__)
        self._solver = create_solver()
        # Falls back to 1e-10 when 'tol' is unset or falsy.
        self._solver.options['tol'] = float(self.get_config('tol') or 1E-10)

    # ------------------------------------------------------------------ inputs

    @property
    def navs(self):
        """NAV data currently used by the solver (set by ``set_navs`` / ``reset_navs``)."""
        return self.__navs

    @property
    def rtn_matrix(self):
        """Returns over ``matrix-rtn-days`` rows of :attr:`navs`; rows containing any NaN are dropped."""
        result = self.navs / self.navs.shift(self.get_config('matrix-rtn-days')) - 1
        result.dropna(inplace=True)
        return result

    @property
    def rtn_annualized(self):
        """Per-asset mean of :attr:`rtn_matrix` scaled by 12, as a list in column order.

        NOTE(review): the x12 factor presumably annualises monthly returns; confirm it matches
        ``matrix-rtn-days``. Each access recomputes the matrix, so read it once into a local.
        """
        return list(self.rtn_matrix.mean() * 12)

    @property
    def sigma(self):
        """Covariance of one-row NAV returns scaled by 252 (presumably trading days per year)."""
        rtn = (self.navs / self.navs.shift(1) - 1)[1:]
        return rtn.cov() * 252

    @property
    def risk_parity_sigma(self):
        """Covariance of the raw NAV levels. NOTE(review): levels, not returns; confirm intended."""
        return self.navs.cov()

    @property
    def rtn_history(self):
        """``rtn_matrix * 12`` as a NumPy array (rows = periods, columns = assets in column order)."""
        result = self.rtn_matrix * 12
        return result.values

    @property
    def beta(self):
        """CVaR tail fraction from ``mpt.cvar-beta``."""
        return self.get_config('mpt.cvar-beta')

    @property
    def k_beta(self):
        """Number of worst historical periods averaged for CVaR: about ceil(len(rtn_history) * beta)."""
        # Adding 0.499999 before round() makes it behave (almost) like math.ceil.
        return round(len(self.rtn_history) * self.beta + 0.499999)

    @property
    def quantile(self):
        """Position of the return target between the min- and max-return portfolios (``mpt.quantile``)."""
        return self.get_config('mpt.quantile')

    @property
    def category(self):
        return self._category

    @property
    def transfer_type(self):
        """``normal-ratio`` config; also cached on ``self._transfer_type``."""
        self._transfer_type = self.get_config("normal-ratio")
        return self._transfer_type

    def set_navs(self, navs):
        self.__navs = navs

    def set_category(self, category):
        self._category = category

    # ------------------------------------------------------------------ solves

    @staticmethod
    def _variance_expr(model, sigma):
        """Pyomo expression ``w' * sigma * w`` over the model weights (``sigma`` is a 2-D array)."""
        return sum(model.w[i] * model.w[j] * sigma[i, j] for i in model.indices for j in model.indices)

    def solve_max_rtn(self):
        """Maximise annualised return under the base constraints.

        :return: ``(max_rtn, max_var, minCVaR_whenMaxR)``, i.e. return, variance and CVaR of the optimum.
        """
        model = self.create_model()
        rtn = self.rtn_annualized  # read once: the property rebuilds the return matrix on each access
        model.objective = Objective(expr=sum(model.w[i] * rtn[i] for i in model.indices), sense=maximize)
        self._solver.solve(model)
        self.debug_solve_result(model)
        max_rtn = self.calc_port_rtn(model)
        max_var = self.calc_port_var(model)
        minCVaR_whenMaxR = self.calc_port_cvar(model)
        logger.debug({
            'max_rtn': max_rtn,
            'max_var': max_var,
            'minCVaR_whenMaxR': minCVaR_whenMaxR,
        })
        return max_rtn, max_var, minCVaR_whenMaxR

    def solve_min_rtn(self):
        """Minimise portfolio variance under the base constraints.

        :return: ``(min_rtn, min_var, maxCVaR_whenMinV)``, i.e. return, variance and CVaR of the optimum.
        """
        model = self.create_model()
        model.objective = Objective(expr=self._variance_expr(model, self.sigma.to_numpy()), sense=minimize)
        self._solver.solve(model)
        self.debug_solve_result(model)
        min_rtn = self.calc_port_rtn(model)
        min_var = self.calc_port_var(model)
        maxCVaR_whenMinV = self.calc_port_cvar(model)
        logger.debug({
            'min_rtn': min_rtn,
            'min_var': min_var,
            'maxCVaR_whenMinV': maxCVaR_whenMinV,
        })
        return min_rtn, min_var, maxCVaR_whenMinV

    def solve_mpt(self, min_rtn, max_rtn):
        """Minimise variance subject to a return floor interpolated by :attr:`quantile`.

        :return: ``(weights dict, portfolio CVaR)``, or ``(None, None)`` when the solver reports infeasible.
        """
        logger.debug(f'...... ...... ...... ...... ...... ...... ...... ...... '
                     f'MPT ... sub risk : pct_value = {self.quantile}')
        big_y = min_rtn + self.quantile * (max_rtn - min_rtn)
        logger.debug(f'big_Y = target_Return = {big_y}')
        model = self.create_model()
        rtn = self.rtn_annualized
        model.cons_rtn = Constraint(expr=sum(model.w[i] * rtn[i] for i in model.indices) >= big_y)
        model.objective = Objective(expr=self._variance_expr(model, self.sigma.to_numpy()), sense=minimize)
        result = self._solver.solve(model)
        if result.solver.termination_condition == TerminationCondition.infeasible:
            logger.debug('...... MPT: Infeasible Optimization Problem.')
            return None, None
        logger.debug('...... MPT: Has solution.')
        self.debug_solve_result(model)
        return self.calc_port_weight(model), self.calc_port_cvar(model)

    def solve_poem(self, min_rtn, max_rtn, base_cvar, max_cvar):
        """Find weights meeting both a return floor and a CVaR floor (no objective is set).

        The return floor interpolates between ``min_rtn`` and ``max_rtn`` by :attr:`quantile`; the
        CVaR floor interpolates between ``base_cvar`` and ``max_cvar`` scaled by
        ``poem.cvar-scale-factor``. CVaR is linearised with auxiliary variables:
        ``alpha - (1 / k_beta) * sum(x_k)`` with ``x_k >= alpha - r_k . w`` and ``x_k >= 0``.

        :return: ``(weights dict, portfolio CVaR)``, or ``(None, None)`` when the solver reports infeasible.
        """
        history = self.rtn_history  # hoisted: the rule below indexes it K * n times
        k_history = len(history)
        k_beta = self.k_beta
        rtn = self.rtn_annualized
        quantile = self.quantile
        logger.debug(f'...... ...... ...... ...... ...... ...... ...... ...... '
                     f'POEM With CVaR constraints ... sub risk : pct_value = {quantile}')
        big_y = min_rtn + quantile * (max_rtn - min_rtn)
        small_y = base_cvar + (max_cvar - base_cvar) * self.get_config('poem.cvar-scale-factor') * quantile
        logger.debug(f'big_Y = target_Return = {big_y} | small_y = target_cvar = {small_y}')
        model = self.create_model()
        model.alpha = Var(domain=Reals)
        model.x = Var(range(k_history), domain=NonNegativeReals)
        model.cons_cvar_aux = Constraint(range(k_history), rule=lambda m, k: m.x[k] >= m.alpha - sum(
            m.w[i] * history[k, i] for i in m.indices))
        model.cons_rtn = Constraint(expr=sum(model.w[i] * rtn[i] for i in model.indices) >= big_y)
        model.cons_cvar = Constraint(
            expr=model.alpha - (1 / k_beta) * sum(model.x[k] for k in range(k_history)) >= small_y)
        result = self._solver.solve(model)
        if result.solver.termination_condition == TerminationCondition.infeasible:
            logger.debug('...... POEM: Infeasible Optimization Problem.')
            return None, None
        logger.debug('...... POEM: Has solution.')
        self.debug_solve_result(model)
        return self.calc_port_weight(model), self.calc_port_cvar(model)

    def solve_risk_parity(self):
        """Minimise pairwise squared differences of per-asset risk contributions; return the weights dict."""
        rp_sigma = self.risk_parity_sigma  # hoisted: the original recomputed the covariance 2 * n^2 times
        model = self.create_model()
        # Risk contribution of asset i: z_i * w_i * (Sigma w)_i, built once per asset instead of once per pair.
        contrib = [model.z[i] * model.w[i] * (rp_sigma.iloc[i] @ model.w) for i in model.indices]
        model.objective = Objective(
            expr=sum((contrib[i] - contrib[j]) ** 2 for i in model.indices for j in model.indices),
            sense=minimize)
        self._solver.solve(model)
        return self.calc_port_weight(model)

    # ------------------------------------------------------------------ results

    def calc_port_weight(self, model):
        """Map fund id -> formatted weight for the selected assets (zero weights are dropped)."""
        weights = [model.w[i]._value * model.z[i]._value for i in model.indices]
        df_w = pd.DataFrame(data=weights, index=self.navs.columns, columns=['weight'])
        # Unselected assets have weight 0; turn them into NaN so dropna removes them.
        df_w.replace(0, np.nan, inplace=True)
        df_w.dropna(axis=0, inplace=True)
        df_w['weight'] = pd.Series(format_weight(dict(df_w['weight']), self.get_weight()))
        return df_w.to_dict()['weight']

    def calc_port_rtn(self, model):
        """Annualised return of the solved weights."""
        rtn = self.rtn_annualized
        return sum(model.w[i]._value * rtn[i] for i in model.indices)

    def calc_port_var(self, model):
        """Variance ``w' * sigma * w`` of the solved weights."""
        sigma = self.sigma.to_numpy()
        return sum(model.w[i]._value * model.w[j]._value * sigma[i, j] for i in model.indices for j in
                   model.indices)

    def calc_port_cvar(self, model):
        """Historical CVaR: mean of the ``k_beta`` lowest portfolio returns in :attr:`rtn_history`."""
        history = self.rtn_history
        k_beta = self.k_beta
        selected = [model.w[i]._value * model.z[i]._value for i in model.indices]
        port_r_hist = sorted(sum(selected[i] * row[i] for i in model.indices) for row in history)
        return sum(port_r_hist[0: k_beta]) / k_beta

    def get_weight(self):
        """Total weight budget of the current category: first entry of ``normal-ratio[category]``."""
        # The category is the value of the corresponding asset-include key.
        return self.transfer_type[self.category][0]

    # ------------------------------------------------------------------ model

    def create_model(self):
        """Build the base MIP shared by every solve.

        Variables: ``w[i] >= 0`` (weight) and binary ``z[i]`` (asset i selected).
        Constraints: the weights sum to :meth:`get_weight`; the number of selected assets lies in
        ``asset-count`` (an int or ``[min, max]``, with min capped at the number of assets); a
        selected asset's weight lies in ``[mpt.low-weight, get_weight()]``, an unselected one is 0.
        """
        n_assets = len(self.navs.columns)
        count = self.get_config('asset-count')
        min_count = count[0] if isinstance(count, list) else count
        max_count = count[1] if isinstance(count, list) else count
        min_count = min(min_count, n_assets)  # never demand more assets than exist

        low_weight = self.get_config('mpt.low-weight')
        high_weight = self.get_weight()
        model = ConcreteModel()
        model.indices = range(0, n_assets)
        model.w = Var(model.indices, domain=NonNegativeReals)
        model.z = Var(model.indices, domain=Binary)
        model.cons_sum_weight = Constraint(expr=sum(model.w[i] for i in model.indices) == high_weight)
        model.cons_num_asset = Constraint(
            expr=inequality(min_count, sum(model.z[i] for i in model.indices), max_count, strict=False))
        model.cons_bounds_low = Constraint(model.indices, rule=lambda m, i: m.z[i] * low_weight <= m.w[i])
        model.cons_bounds_up = Constraint(model.indices, rule=lambda m, i: m.z[i] * high_weight >= m.w[i])
        return model

    def _load_navs(self, fund_ids, day):
        """Fetch NAVs of ``fund_ids`` over ``navs.range`` up to ``day`` and clean them into a date x fund table."""
        min_date = day - relativedelta(**self.get_config('navs.range'))
        navs = pd.DataFrame(self._navs.get_fund_navs(fund_ids=fund_ids, max_date=day, min_date=min_date))
        # Convert first: the .dt accessor below requires a datetime column.
        navs['nav_date'] = pd.to_datetime(navs['nav_date'])
        navs = navs[navs['nav_date'].dt.day_of_week < 5]  # keep Monday-Friday only
        navs = navs.pivot_table(index='nav_date', columns='fund_id', values='nav_cal').sort_index()

        # Drop funds with too many missing NAVs, then days on which too large a share of funds is missing.
        nan_per_fund = navs.isna().sum()
        navs = navs.drop(columns=nan_per_fund[nan_per_fund >= self.get_config('navs.max-nan.asset')].index)
        nan_share_per_day = navs.isna().mean(axis=1)
        navs = navs.drop(index=nan_share_per_day[nan_share_per_day >= self.get_config('navs.max-nan.day')].index)

        navs = navs.ffill()
        # After a forward fill only leading gaps remain, so checking the first row is enough.
        if not navs.empty and navs.iloc[0].isna().any():
            navs = navs.bfill()
        return navs

    def reset_navs(self, day):
        """Load NAVs of the asset pool on ``day``, grouped by the first ``asset-include`` key.

        :return: ``{category value: NAV table}``; the same dict is stored as :attr:`navs`.
        """
        asset_ids = self._assets.get_pool(day)
        datum = self._datum.get_datums(type=DatumType.FUND, datum_ids=asset_ids)
        category_key = list(get_config('asset-pool')['asset-optimize']['asset-include'].keys())[0]
        asset_ids_group = {k: [d['id'] for d in datum if d[category_key] == k]
                           for k in set(d[category_key] for d in datum)}
        navs_group = {group: self._load_navs(group_ids, day) for group, group_ids in asset_ids_group.items()}
        self.__navs = navs_group
        return navs_group

    # ------------------------------------------------------------------ config / debug

    def get_config(self, name):
        """Resolve a dotted config key, preferring the section of ``self.type``.

        For non-NORMAL types the key is first looked up in ``config[type.value]`` and falls back to
        the top-level config when that yields None. If the resolved value is a dict containing
        ``ft<risk.value>``, that risk-specific entry is returned instead.
        """
        def load_config(config):
            for key in name.split('.'):
                if key not in config:
                    return None
                config = config[key]
            return config

        scoped = self._config[self.type.value] if self.type is not PortfoliosType.NORMAL else self._config
        value = load_config(scoped)
        if value is None:
            value = load_config(self._config)
        if value and isinstance(value, dict):
            # Built lazily: self.risk may be None when no risk-specific lookup is needed.
            risk_key = f'ft{self.risk.value}'
            if risk_key in value:
                return value[risk_key]
        return value

    def debug_solve_result(self, model):
        """Log the selected assets, their weights and the portfolio metrics (DEBUG level only)."""
        if not logger.isEnabledFor(DEBUG):
            return
        logger.debug('===============================')
        logger.debug('solution: id  |  w(id)')
        w_sum = 0
        for i in model.indices:
            z = model.z[i]._value
            # The solver may report a binary as e.g. 0.9999999, so round before comparing.
            if z is not None and round(z) == 1:
                logger.debug(f'{self.navs.columns[i]}  |  {model.w[i]._value}')
                w_sum += model.w[i]._value
        logger.debug(f'w_sum = {w_sum}')
        logger.debug({
            'beta': self.beta,
            'kbeta': self.k_beta,
            'port_R': self.calc_port_rtn(model),
            'port_V': self.calc_port_var(model),
            'port_CVaR': self.calc_port_cvar(model)
        })
        logger.debug('-------------------------------')


class ARCSolver(DefaultSolver):
    """Solver whose weights sum to 1, with optional asset-risk-class (ARC) weight bounds.

    Compared with :class:`DefaultSolver`, the per-asset cap comes from ``mpt.high-weight``
    and the NAV universe is built from the pool on a single (weekend-filtered) date.
    """

    def __init__(self, type: PortfoliosType, risk: PortfoliosRisk, assets: AssetPool = None, navs: Navs = None,
                 datum: Datum = None):
        # Forward only explicitly supplied dependencies (the original discarded them all); ones left
        # as None are not passed, so the @autowired base constructor presumably still injects them.
        supplied = {key: value for key, value in (('assets', assets), ('navs', navs), ('datum', datum))
                    if value is not None}
        super().__init__(type, risk, **supplied)
        self.__date = None

    @property
    def date(self):
        """Date last passed through ``filter_weekend`` in :meth:`reset_navs`."""
        return self.__date

    def calc_port_weight(self, model):
        """Map fund id -> formatted weight for the selected assets (zero weights are dropped)."""
        id_list = self.navs.columns
        weight_list = [model.w[i]._value * model.z[i]._value for i in model.indices]
        df_w = pd.DataFrame(data=weight_list, index=id_list, columns=['weight'])
        df_w.replace(0, math.nan, inplace=True)
        df_w.dropna(axis=0, inplace=True)
        df_w['weight'] = pd.Series(format_weight(dict(df_w['weight'])))
        return df_w.to_dict()['weight']

    @property
    def max_count(self):
        """Upper bound of ``asset-count`` (an int or ``[min, max]``)."""
        count = self.get_config('asset-count')
        return count[1] if isinstance(count, list) else count

    @property
    def min_count(self):
        """Lower bound of ``asset-count``, capped at the number of available assets."""
        count = self.get_config('asset-count')
        # len(navs.columns) equals len(rtn_annualized) without rebuilding the return matrix.
        return min(count[0] if isinstance(count, list) else count, len(self.navs.columns))

    def create_model(self):
        """Build the ARC base model.

        Weights sum to 1; the number of selected assets lies in ``[min_count, max_count]``; a
        selected asset's weight lies in ``[mpt.low-weight, high]``, where ``high`` is
        ``mpt.high-weight`` or, if that is a list, its entry for the effective minimum asset count.
        When ``arc`` is enabled, see :meth:`_add_arc_constraints`.
        """
        n_assets = len(self.navs.columns)
        min_count = self.min_count
        max_count = self.max_count
        low_weight = self.get_config('mpt.low-weight')
        high_weight = self.get_config('mpt.high-weight')
        if isinstance(high_weight, list):
            high_weight = high_weight[min(n_assets, min_count, len(high_weight)) - 1]

        model = ConcreteModel()
        model.indices = range(0, n_assets)
        model.w = Var(model.indices, domain=NonNegativeReals)
        model.z = Var(model.indices, domain=Binary)
        model.cons_sum_weight = Constraint(expr=sum(model.w[i] for i in model.indices) == 1)
        model.cons_num_asset = Constraint(
            expr=inequality(min_count, sum(model.z[i] for i in model.indices), max_count, strict=False))
        model.cons_bounds_low = Constraint(model.indices, rule=lambda m, i: m.z[i] * low_weight <= m.w[i])
        model.cons_bounds_up = Constraint(model.indices, rule=lambda m, i: m.z[i] * high_weight >= m.w[i])
        # A missing 'arc' key disables ARC constraints instead of raising KeyError.
        if 'arc' in self._config and self._config['arc']:
            self._add_arc_constraints(model)
        return model

    def _add_arc_constraints(self, model):
        """Bound the total weight of each asset-risk class m (fund ``customType``, 1-based) by LARC/UARC[m-1].

        :raises ValueError: if a fund's ``customType`` is outside ``1..len(LARC)``.
        """
        LARC = self._config['LARC']
        UARC = self._config['UARC']
        num_arc = len(LARC)
        columns = list(self.navs.columns)
        datums = self._datum.get_datums(type=DatumType.FUND, datum_ids=columns)
        # Index by fund id: nothing here guarantees the datum response follows the column order.
        arc_by_id = {x['id']: int(x['customType']) for x in datums}

        # A[m, j] == 1 iff fund j belongs to class m + 1.
        A = np.zeros((num_arc, len(columns)), dtype=int)
        for j, fund_id in enumerate(columns):
            arc = arc_by_id[fund_id]
            if not 1 <= arc <= num_arc:
                # customType 0 would otherwise silently land in the last class via A[-1, j].
                raise ValueError(f'fund {fund_id} has customType {arc}; expected a value in 1..{num_arc}')
            A[arc - 1, j] = 1
        model.cons_arc_low = Constraint(range(num_arc),
                                        rule=lambda m, i: LARC[i] <= sum([A[i, j] * m.w[j] for j in m.indices]))
        model.cons_arc_up = Constraint(range(num_arc),
                                       rule=lambda m, i: UARC[i] >= sum([A[i, j] * m.w[j] for j in m.indices]))

    def reset_navs(self, day):
        """Load and clean NAVs of the eligible pool on ``filter_weekend(day)`` into :attr:`navs`.

        Fund datums are fetched with ``risk=navs.risk``; funds whose ``assetType`` is listed in
        ``navs.exclude-asset-type`` are skipped.
        """
        self.__date = filter_weekend(day)
        asset_ids = self._assets.get_pool(self.date)
        asset_risk = self.get_config('navs.risk')
        datum = self._datum.get_datums(type=DatumType.FUND, datum_ids=asset_ids, risk=asset_risk)
        exclude = self.get_config('navs.exclude-asset-type') or []
        asset_ids = list(set(asset_ids) & {x['id'] for x in datum if x['assetType'] not in exclude})

        min_date = self.date - relativedelta(**self.get_config('navs.range'))
        navs = pd.DataFrame(self._navs.get_fund_navs(fund_ids=asset_ids, max_date=self.date, min_date=min_date))
        # Convert first: the .dt accessor below requires a datetime column.
        navs['nav_date'] = pd.to_datetime(navs['nav_date'])
        navs = navs[navs['nav_date'].dt.day_of_week < 5]  # keep Monday-Friday only
        navs = navs.pivot_table(index='nav_date', columns='fund_id', values='nav_cal').sort_index()

        # Drop funds with too many missing NAVs, then days on which too large a share of funds is missing.
        nan_per_fund = navs.isna().sum()
        navs = navs.drop(columns=nan_per_fund[nan_per_fund >= self.get_config('navs.max-nan.asset')].index)
        nan_share_per_day = navs.isna().mean(axis=1)
        navs = navs.drop(index=nan_share_per_day[nan_share_per_day >= self.get_config('navs.max-nan.day')].index)

        navs = navs.ffill()
        # After a forward fill only leading gaps remain, so checking the first row is enough.
        if not navs.empty and navs.iloc[0].isna().any():
            navs = navs.bfill()
        self.set_navs(navs)


class PRRSolver(ARCSolver):
    """:class:`ARCSolver` plus portfolio risk-rating (RR) constraints.

    Each fund's ``risk`` datum (loaded by :meth:`reset_navs`) is its RR. On top of the ARC model:

    * weighted RR ``sum(w_i * RR_i) <= trr``;
    * at least ``0.7 + brr`` of the weight sits in funds with ``RR_i <= trr``;
    * when ``trr < 5``: ``max_count * #selected(1 < RR < 5) >= #selected(RR == 5)``, i.e. picking
      any RR-5 fund requires picking at least one fund with RR strictly between 1 and 5.
    """

    def __init__(self, type: PortfoliosType, risk: PortfoliosRisk, assets: AssetPool = None, navs: Navs = None,
                 datum: Datum = None):
        # Pass the dependencies on instead of silently discarding them.
        super().__init__(type, risk, assets=assets, navs=navs, datum=datum)
        self.__risk = None

    def create_model(self):
        """Extend the ARC model with the RR constraints described on the class."""
        model = super().create_model()
        # Look RR up by column id so RR[i] lines up with w[i]; the risks dict is ordered like the
        # datum response, which need not match the NAV column order.
        RR = [self.risks[fund_id] for fund_id in self.navs.columns]
        TRR = self._config['trr']
        # 0.7 is a hard-coded base share; 'brr' is added on top of it.
        min_rr_weight_within_trr = 0.7 + self._config['brr']

        model.cons_TRR = Constraint(expr=sum(model.w[i] * RR[i] for i in model.indices) <= TRR)

        rr_le_trr = [1 if rr <= TRR else 0 for rr in RR]
        rr_in_1_5 = [1 if 1 < rr < 5 else 0 for rr in RR]
        rr_eq_5 = [1 if rr == 5 else 0 for rr in RR]

        model.cons_RR_LE_TRR = Constraint(
            expr=sum(model.w[i] * rr_le_trr[i] for i in model.indices) >= min_rr_weight_within_trr)
        if TRR < 5:
            max_count = self.max_count
            model.cons_RR_in_1_5 = Constraint(
                expr=sum(model.z[i] * (rr_in_1_5[i] * max_count - rr_eq_5[i]) for i in model.indices) >= 0)
        return model

    def reset_navs(self, day):
        """Reload NAVs via :class:`ARCSolver` and cache each fund's ``risk`` datum by fund id."""
        super().reset_navs(day=day)
        datums = self._datum.get_datums(type=DatumType.FUND, datum_ids=list(self.navs.columns))
        self.__risk = {x['id']: x['risk'] for x in datums}

    @property
    def risks(self):
        """Fund id -> risk rating mapping; None until :meth:`reset_navs` has run."""
        return self.__risk