import math
import os
import statistics
import sys
from logging import DEBUG, getLogger

import numpy as np
import pandas as pd
from dateutil.relativedelta import relativedelta
from numpy import NAN
from py_jftech import component, autowired, get_config, filter_weekend
from pyomo.environ import *

from api import SolverFactory as Factory, PortfoliosRisk, PortfoliosType, AssetPool, Navs, Solver, Datum, DatumType
from portfolios.utils import format_weight

# Module logger; the solvers below write their diagnostic output through it at DEBUG level.
logger = getLogger(__name__)


def create_solver():
    """Build a Pyomo ``Bonmin`` solver backed by the platform-specific binary shipped next to this module.

    Windows platforms use ``bonmin.exe``, ``linux`` uses ``bonmin_linux`` and every other platform
    falls back to ``bonmin_mac``.
    """
    platform = sys.platform
    if platform.startswith('win'):
        binary_name = 'bonmin.exe'
    elif platform == 'linux':
        binary_name = 'bonmin_linux'
    else:
        binary_name = 'bonmin_mac'
    module_dir = os.path.dirname(__file__)
    return SolverFactory('Bonmin', executable=os.path.join(module_dir, binary_name))


@component
class DefaultFactory(Factory):
    """Solver factory that chooses an implementation from this module's ``model`` config entry."""

    def __init__(self):
        self._config = get_config(__name__)

    @property
    def solver_model(self):
        """Upper-cased ``model`` config value, or ``None`` when the key is absent or null."""
        configured = self._config['model'] if 'model' in self._config else None
        return None if configured is None else configured.upper()

    def create_solver(self, risk: PortfoliosRisk = None, type: PortfoliosType = PortfoliosType.NORMAL) -> Solver:
        """Return an ``ARCSolver`` for model ``ARC``, a ``PRRSolver`` for model ``PRR`` at risk FT3,
        and a ``DefaultSolver`` in every other case."""
        model_name = self.solver_model
        if model_name == 'ARC':
            return ARCSolver(type=type, risk=risk)
        if model_name == 'PRR' and risk == PortfoliosRisk.FT3:
            return PRRSolver(type=type, risk=risk)
        return DefaultSolver(type=type, risk=risk)


class DefaultSolver(Solver):
    """Mixed-integer portfolio solver built on Pyomo and the bundled Bonmin executable.

    Every model from :meth:`create_model` has one non-negative weight ``w[i]`` and one binary
    selection flag ``z[i]`` per column of :attr:`navs`. The ``solve_*`` methods add an objective
    and/or extra constraints on top of it. Settings are read through :meth:`get_config`.

    The derived statistics (``rtn_annualized``, ``sigma``, ``rtn_history``, ...) are properties
    recomputed from ``navs`` on every access, so methods below evaluate each one once into a
    local before building per-term expressions.
    """

    @autowired
    def __init__(self, type: PortfoliosType, risk: PortfoliosRisk, assets: AssetPool = None, navs: Navs = None,
                 datum: Datum = None):
        self._category = None
        self._transfer_type = None
        self.__navs = None
        self.risk = risk
        self.type = type or PortfoliosType.NORMAL
        self._assets = assets
        self._navs = navs
        self._datum = datum
        self._config = get_config(__name__)
        self._solver = create_solver()
        # Solver tolerance: the `tol` config entry, or 1e-10 when it is not set.
        self._solver.options['tol'] = float(self.get_config('tol') or 1E-10)

    @property
    def navs(self):
        """NAV data last stored by ``set_navs`` or ``reset_navs`` (``None`` until then)."""
        return self.__navs

    @property
    def rtn_matrix(self):
        """``nav / nav.shift(matrix-rtn-days) - 1`` with incomplete rows dropped."""
        result = self.navs / self.navs.shift(self.get_config('matrix-rtn-days')) - 1
        result.dropna(inplace=True)
        return result

    @property
    def rtn_annualized(self):
        """Expected return per asset column, as a list in ``navs.columns`` order.

        With ``mpt.short-term-strength`` set: last row of (5-row rolling mean / 10-row rolling mean - 1).
        Otherwise: column means of :attr:`rtn_matrix` multiplied by 12.
        """
        if self.get_config('mpt.short-term-strength'):
            result = self.navs.rolling(window=5).mean() / self.navs.rolling(window=10).mean() - 1
            result.dropna(inplace=True)
            return list(result.iloc[-1])
        return list(self.rtn_matrix.mean() * 12)

    @property
    def sigma(self):
        """Covariance of one-row NAV returns, multiplied by 252."""
        rtn = (self.navs / self.navs.shift(1) - 1)[1:]
        return rtn.cov() * 252

    @property
    def risk_parity_sigma(self):
        """Covariance used by :meth:`solve_risk_parity`.

        20-row return covariance * 252 when ``mpt.jf_etf`` is set, otherwise the covariance of raw NAV levels.
        """
        if self.get_config('mpt.jf_etf'):
            rtn = (self.navs / self.navs.shift(20) - 1)[1:]
            return rtn.cov() * 252
        return self.navs.cov()

    @property
    def rtn_history(self):
        """``rtn_matrix * 12`` as a numpy array (rows = history points, columns = assets)."""
        result = self.rtn_matrix * 12
        return result.values

    @property
    def beta(self):
        """CVaR tail fraction from ``mpt.cvar-beta``."""
        return self.get_config('mpt.cvar-beta')

    @property
    def k_beta(self):
        """Number of worst history rows averaged by the CVaR (rounds len * beta up, via +0.499999)."""
        return round(len(self.rtn_history) * self.beta + 0.499999)

    @property
    def quantile(self):
        """Interpolation factor ``mpt.quantile`` between the min and max return targets."""
        return self.get_config('mpt.quantile')

    @property
    def category(self):
        return self._category

    @property
    def transfer_type(self):
        """The ``normal-ratio`` config entry (also cached on ``self._transfer_type``)."""
        self._transfer_type = self.get_config("normal-ratio")
        return self._transfer_type

    def set_navs(self, navs):
        self.__navs = navs

    def set_category(self, category):
        self._category = category

    def solve_max_rtn(self):
        """Maximise the expected portfolio return.

        Returns:
            tuple: ``(max_rtn, max_var, minCVaR_whenMaxR)`` of the solved portfolio.
        """
        model = self.create_model()
        rtn = self.rtn_annualized
        model.objective = Objective(expr=sum([model.w[i] * rtn[i] for i in model.indices]), sense=maximize)
        self._solver.solve(model)
        self.debug_solve_result(model)
        max_rtn = self.calc_port_rtn(model)
        max_var = self.calc_port_var(model)
        minCVaR_whenMaxR = self.calc_port_cvar(model)
        logger.debug({
            'max_rtn': max_rtn,
            'max_var': max_var,
            'minCVaR_whenMaxR': minCVaR_whenMaxR,
        })
        return max_rtn, max_var, minCVaR_whenMaxR

    def solve_min_rtn(self):
        """Minimise the portfolio variance ``w' * sigma * w``.

        Returns:
            tuple: ``(min_rtn, min_var, maxCVaR_whenMinV)`` of the solved portfolio.
        """
        model = self.create_model()
        sigma = self.sigma
        model.objective = Objective(
            expr=sum([model.w[i] * model.w[j] * sigma.iloc[i, j] for i in model.indices for j in model.indices]),
            sense=minimize)
        self._solver.solve(model)
        self.debug_solve_result(model)
        min_rtn = self.calc_port_rtn(model)
        min_var = self.calc_port_var(model)
        maxCVaR_whenMinV = self.calc_port_cvar(model)
        logger.debug({
            'min_rtn': min_rtn,
            'min_var': min_var,
            'maxCVaR_whenMinV': maxCVaR_whenMinV,
        })
        return min_rtn, min_var, maxCVaR_whenMinV

    def solve_mpt(self, min_rtn, max_rtn):
        """Minimise variance subject to return >= ``min_rtn + quantile * (max_rtn - min_rtn)``.

        Returns:
            tuple: ``(weights, cvar)``, or ``(None, None)`` when the solver reports infeasibility.
        """
        logger.debug(f'...... ...... ...... ...... ...... ...... ...... ...... '
                     f'MPT ... sub risk : pct_value = {self.quantile}')
        big_y = min_rtn + self.quantile * (max_rtn - min_rtn)
        logger.debug(f'big_Y = target_Return = {big_y}')
        model = self.create_model()
        rtn = self.rtn_annualized
        sigma = self.sigma
        model.cons_rtn = Constraint(expr=sum([model.w[i] * rtn[i] for i in model.indices]) >= big_y)
        model.objective = Objective(
            expr=sum([model.w[i] * model.w[j] * sigma.iloc[i, j] for i in model.indices for j in model.indices]),
            sense=minimize)
        result = self._solver.solve(model)
        if result.solver.termination_condition == TerminationCondition.infeasible:
            logger.debug('...... MPT: Infeasible Optimization Problem.')
            return None, None
        logger.debug('...... MPT: Has solution.')
        self.debug_solve_result(model)
        return self.calc_port_weight(model), self.calc_port_cvar(model)

    def solve_poem(self, min_rtn, max_rtn, base_cvar, max_cvar):
        """Find a portfolio meeting both a return target and a CVaR target.

        The return target interpolates between ``min_rtn`` and ``max_rtn`` by ``quantile``. The CVaR target
        interpolates between ``base_cvar`` and ``max_cvar`` by ``poem.cvar-scale-factor * quantile``.
        NOTE(review): no objective is set, so this is solved as a feasibility problem -- confirm intended.

        Returns:
            tuple: ``(weights, cvar)``, or ``(None, None)`` when the solver reports infeasibility.
        """
        rtn_history = self.rtn_history
        k_history = len(rtn_history)
        quantile = self.quantile
        logger.debug(f'...... ...... ...... ...... ...... ...... ...... ...... '
                     f'POEM With CVaR constraints ... sub risk : pct_value = {quantile}')
        big_y = min_rtn + quantile * (max_rtn - min_rtn)
        small_y = base_cvar + (max_cvar - base_cvar) * self.get_config('poem.cvar-scale-factor') * quantile
        logger.debug(f'big_Y = target_Return = {big_y} | small_y = target_cvar = {small_y}')
        model = self.create_model()
        rtn = self.rtn_annualized
        k_beta = self.k_beta
        model.alpha = Var(domain=Reals)
        model.x = Var(range(k_history), domain=NonNegativeReals)
        # Linearised CVaR: x[k] >= alpha - (portfolio return in history row k), x[k] >= 0.
        model.cons_cvar_aux = Constraint(range(k_history), rule=lambda m, k: m.x[k] >= m.alpha - sum(
            [m.w[i] * rtn_history[k][i] for i in m.indices]))
        model.cons_rtn = Constraint(expr=sum([model.w[i] * rtn[i] for i in model.indices]) >= big_y)
        model.cons_cvar = Constraint(
            expr=model.alpha - (1 / k_beta) * sum([model.x[k] for k in range(k_history)]) >= small_y)
        result = self._solver.solve(model)
        if result.solver.termination_condition == TerminationCondition.infeasible:
            logger.debug('...... POEM: Infeasible Optimization Problem.')
            return None, None
        logger.debug('...... POEM: Has solution.')
        self.debug_solve_result(model)
        return self.calc_port_weight(model), self.calc_port_cvar(model)

    def solve_risk_parity(self):
        """Minimise sum over (i, j) of (z[i]*w[i]*(sigma[i] . w) - z[j]*w[j]*(sigma[j] . w)) ** 2.

        Returns:
            dict: fund id -> formatted weight (see :meth:`calc_port_weight`).
        """
        model = self.create_model()
        sigma = self.risk_parity_sigma
        model.objective = Objective(expr=sum(
            [(model.z[i] * model.w[i] * (sigma.iloc[i] @ model.w) - model.z[j] * model.w[j] * (
                    sigma.iloc[j] @ model.w)) ** 2
             for i in model.indices for j in model.indices]), sense=minimize)
        self._solver.solve(model)
        return self.calc_port_weight(model)

    def calc_port_weight(self, model):
        """Return ``{fund id: weight}`` for assets with non-zero ``w * z``, formatted by ``format_weight``."""
        weight_list = [model.w[i].value * model.z[i].value for i in model.indices]
        df_w = pd.DataFrame(data=weight_list, index=self.navs.columns, columns=['weight'])
        # Drop unselected assets (exact zero weight).
        df_w.replace(0, np.nan, inplace=True)
        df_w.dropna(axis=0, inplace=True)
        df_w['weight'] = pd.Series(format_weight(dict(df_w['weight']), self.get_weight()))
        return df_w.to_dict()['weight']

    def calc_port_rtn(self, model):
        """Expected return of the solved weights: sum of w[i] * rtn_annualized[i]."""
        rtn = self.rtn_annualized
        return sum([model.w[i].value * rtn[i] for i in model.indices])

    def calc_port_var(self, model):
        """Variance of the solved weights: w' * sigma * w."""
        sigma = self.sigma
        return sum([model.w[i].value * model.w[j].value * sigma.iloc[i, j] for i in model.indices for j in
                    model.indices])

    def calc_port_cvar(self, model):
        """Mean of the ``k_beta`` lowest historical portfolio returns (using w * z as weights)."""
        rtn_history = self.rtn_history
        k_beta = self.k_beta
        weights = [model.w[i].value * model.z[i].value for i in model.indices]
        port_r_hist = sorted(sum([weights[i] * row[i] for i in model.indices]) for row in rtn_history)
        return sum(port_r_hist[0: k_beta]) / k_beta

    def get_weight(self):
        """Total weight configured for the current category (first entry of ``normal-ratio[category]``)."""
        # Look up the config by the corresponding key from asset-include.
        return self.transfer_type[self.category][0]

    def create_model(self):
        """Build the base model: weights summing to the category weight, asset-count and per-asset bounds.

        ``asset-count`` is either a single count or ``[min, max]``; the minimum is clamped to the number
        of assets. A selected asset (z[i] == 1) has weight in ``[mpt.low-weight, get_weight()]``, an
        unselected one has weight 0.
        """
        count = self.get_config('asset-count')
        min_count = count[0] if isinstance(count, list) else count
        max_count = count[1] if isinstance(count, list) else count
        min_count = min(min_count, len(self.rtn_annualized))

        low_weight = self.get_config('mpt.low-weight')
        high_weight = self.get_weight()
        model = ConcreteModel()
        model.indices = range(0, len(self.navs.columns))
        model.w = Var(model.indices, domain=NonNegativeReals)
        model.z = Var(model.indices, domain=Binary)
        model.cons_sum_weight = Constraint(expr=sum([model.w[i] for i in model.indices]) == high_weight)
        model.cons_num_asset = Constraint(
            expr=inequality(min_count, sum([model.z[i] for i in model.indices]), max_count, strict=False))
        model.cons_bounds_low = Constraint(model.indices, rule=lambda m, i: m.z[i] * low_weight <= m.w[i])
        model.cons_bounds_up = Constraint(model.indices, rule=lambda m, i: m.z[i] * high_weight >= m.w[i])
        return model

    def reset_navs(self, day):
        """Load and clean NAV histories for the asset pool on ``day``, grouped by category.

        Assets are grouped by the datum field named by the first key of
        ``asset-pool.asset-optimize.asset-include``. For each group, NAVs within ``navs.range`` before
        ``day`` are restricted to weekdays and pivoted to a date x fund-id frame. Columns with at least
        ``navs.max-nan.asset`` missing values are dropped, then days whose missing ratio reaches
        ``navs.max-nan.day``. Remaining gaps are forward-filled, and back-filled if the first row is
        still incomplete.

        Returns:
            dict: ``{category value: DataFrame}``; the same dict is stored as :attr:`navs`.
        """
        asset_ids = self._assets.get_pool(day)
        datum = self._datum.get_datums(type=DatumType.FUND, datum_ids=asset_ids)
        category_field = list(get_config('asset-pool')['asset-optimize']['asset-include'].keys())[0]
        asset_ids_group = {k: [d['id'] for d in datum if d[category_field] == k]
                           for k in set(d[category_field] for d in datum)}
        min_date = day - relativedelta(**self.get_config('navs.range'))
        navs_group = {}
        for group, group_ids in asset_ids_group.items():
            navs = pd.DataFrame(self._navs.get_fund_navs(fund_ids=group_ids, max_date=day, min_date=min_date))
            # Convert first: the .dt accessor below only works on datetime-like columns.
            navs['nav_date'] = pd.to_datetime(navs['nav_date'])
            navs = navs[navs['nav_date'].dt.day_of_week < 5]
            navs = navs.pivot_table(index='nav_date', columns='fund_id', values='nav_cal')
            navs = navs.sort_index()

            navs_nan = navs.isna().sum()
            navs.drop(columns=[x for x in navs_nan.index if navs_nan.loc[x] >= self.get_config('navs.max-nan.asset')],
                      inplace=True)
            navs_nan = navs.apply(lambda r: r.isna().sum() / len(r), axis=1)
            navs.drop(index=[x for x in navs_nan.index if navs_nan.loc[x] >= self.get_config('navs.max-nan.day')],
                      inplace=True)
            navs = navs.ffill()
            if not navs.empty and navs.iloc[0].isna().any():
                navs = navs.bfill()
            navs_group[group] = navs
        self.__navs = navs_group
        return navs_group

    def get_config(self, name):
        """Look up a dotted config path such as ``'mpt.low-weight'``.

        For non-NORMAL portfolio types, the section keyed by ``type.value`` is searched first. If nothing
        is found there (or the type is NORMAL), the module's root config is searched. If the value found
        is a dict containing ``ft<risk.value>``, that entry is returned instead. Returns ``None`` when the
        path is absent.
        """
        def load_config(config):
            for key in name.split('.'):
                if key in config:
                    config = config[key]
                else:
                    return None
            return config

        value = load_config(self._config[self.type.value] if self.type is not PortfoliosType.NORMAL else self._config)
        if value is None:
            value = load_config(self._config)
        risk_key = f'ft{self.risk.value}'
        if value and isinstance(value, dict) and risk_key in value:
            return value[risk_key]
        return value

    def debug_solve_result(self, model):
        """Log the selected assets, their weights and summary statistics at DEBUG level."""
        if not logger.isEnabledFor(DEBUG):
            return
        logger.debug('===============================')
        logger.debug('solution: id  |  w(id)')
        w_sum = 0
        for i in model.indices:
            if model.z[i].value == 1:
                logger.debug(f'{self.navs.columns[i]}  |  {model.w[i].value}')
                w_sum += model.w[i].value
        logger.debug(f'w_sum = {w_sum}')
        logger.debug({
            'beta': self.beta,
            'kbeta': self.k_beta,
            'port_R': self.calc_port_rtn(model),
            'port_V': self.calc_port_var(model),  # portfolio variance w' * sigma * w
            'port_CVaR': self.calc_port_cvar(model)
        })
        logger.debug('-------------------------------')


class ARCSolver(DefaultSolver):
    """Solver whose weights sum to 1 and, optionally, respect per-asset-class (ARC) bounds.

    When the ``arc`` config flag is truthy, the total weight of the assets whose fund datum has
    ``customType == c`` must lie within ``[LARC[c - 1], UARC[c - 1]]``.
    """

    def __init__(self, type: PortfoliosType, risk: PortfoliosRisk, assets: AssetPool = None, navs: Navs = None,
                 datum: Datum = None):
        # Forward only collaborators the caller actually supplied; those left as None are omitted so that
        # DefaultSolver's @autowired __init__ can still inject them (the original relied on that for all three).
        supplied = {key: value for key, value in (('assets', assets), ('navs', navs), ('datum', datum))
                    if value is not None}
        super().__init__(type, risk, **supplied)
        self.__date = None

    @property
    def date(self):
        """``day`` from the last :meth:`reset_navs` call, passed through ``filter_weekend`` (None before)."""
        return self.__date

    def calc_port_weight(self, model):
        """Return ``{fund id: weight}`` for assets with non-zero ``w * z``, formatted by ``format_weight``."""
        weight_list = [model.w[i].value * model.z[i].value for i in model.indices]
        df_w = pd.DataFrame(data=weight_list, index=self.navs.columns, columns=['weight'])
        df_w.replace(0, math.nan, inplace=True)
        df_w.dropna(axis=0, inplace=True)
        df_w['weight'] = pd.Series(format_weight(dict(df_w['weight'])))
        return df_w.to_dict()['weight']

    @property
    def max_count(self):
        """Maximum number of selected assets (``asset-count`` or its second element)."""
        count = self.get_config('asset-count')
        return count[1] if isinstance(count, list) else count

    @property
    def min_count(self):
        """Minimum number of selected assets, clamped to the number of assets."""
        count = self.get_config('asset-count')
        return min(count[0] if isinstance(count, list) else count, len(self.rtn_annualized))

    @property
    def LARC(self):
        """Lower weight bound per asset class (index ``customType - 1``).

        When ``larc-index`` is configured (>= 0), that entry is replaced by :attr:`compute_cash_uarc`.
        A copy is returned so the shared config list is never mutated.
        """
        larc = list(self._config['LARC'])
        larc_index = self._config.get('larc-index', -1)
        if larc_index < 0:
            return larc
        larc[larc_index] = self.compute_cash_uarc
        return larc

    @property
    def compute_cash_uarc(self):
        """Bound derived from the latest FDTR value and the mean of the last two CPI YOY releases.

        ``round(fix-w + fdtr-w * FDTR + |CPI - cpi-expect| * cpiyoy-w, 2)``, capped at ``max-w``.
        """
        ecos = self._datum.get_datums(ticker=['CPI YOY Index', 'FDTR Index'])
        cpi_id = [data['id'] for data in ecos if data['bloombergTicker'] == 'CPI YOY Index']
        fdtr_id = [data['id'] for data in ecos if data['bloombergTicker'] == 'FDTR Index']
        cpi = self._navs.get_last_eco_values(max_date=self.date, datum_id=cpi_id, count=2, by_release_date=True)
        cpi = statistics.mean([c['indicator'] for c in cpi])
        fdtr = self._navs.get_last_eco_values(max_date=self.date, datum_id=fdtr_id, by_release_date=True)['indicator']
        cash_uarc = round(
            self._config['fix-w'] + self._config['fdtr-w'] * fdtr + abs(cpi - self._config['cpi-expect']) *
            self._config['cpiyoy-w'], 2)
        return self._config['max-w'] if cash_uarc > self._config['max-w'] else cash_uarc

    @property
    def UARC(self):
        """Upper weight bound per asset class (index ``customType - 1``).

        When ``uarc-index`` is configured (>= 0), that entry is replaced by :attr:`compute_cash_uarc`, raised
        to at least ``mpt.low-weight``. A copy is returned so the shared config list is never mutated.
        """
        uarc = list(self._config['UARC'])
        uarc_index = self._config.get('uarc-index', -1)
        if uarc_index < 0:
            return uarc
        cash_uarc = self.compute_cash_uarc
        low_weight = self.get_config('mpt.low-weight')
        uarc[uarc_index] = low_weight if cash_uarc < low_weight else cash_uarc
        return uarc

    def create_model(self):
        """Build the model: weights sum to 1, asset-count and per-asset bounds, plus optional ARC bounds.

        ``mpt.high-weight`` may be a list; the per-asset cap is then the entry at
        ``min(#assets, min_count, len(list)) - 1``.
        """
        low_weight = self.get_config('mpt.low-weight')
        high_weight = self.get_config('mpt.high-weight')
        columns = list(self.navs.columns)
        min_count = self.min_count
        max_count = self.max_count
        if isinstance(high_weight, list):
            high_weight = high_weight[min(len(columns), min_count, len(high_weight)) - 1]

        model = ConcreteModel()

        model.indices = range(0, len(columns))
        model.w = Var(model.indices, domain=NonNegativeReals)
        model.z = Var(model.indices, domain=Binary)
        model.cons_sum_weight = Constraint(expr=sum([model.w[i] for i in model.indices]) == 1)
        model.cons_num_asset = Constraint(
            expr=inequality(min_count, sum([model.z[i] for i in model.indices]), max_count, strict=False))
        model.cons_bounds_low = Constraint(model.indices, rule=lambda m, i: m.z[i] * low_weight <= m.w[i])
        model.cons_bounds_up = Constraint(model.indices, rule=lambda m, i: m.z[i] * high_weight >= m.w[i])
        if self._config['arc']:
            larc = self.LARC
            uarc = self.UARC
            num_arc = len(larc)  # number of asset classes (M)

            datums = self._datum.get_datums(type=DatumType.FUND, datum_ids=columns)
            # Map by id: nothing here guarantees get_datums returns rows in navs.columns order,
            # and each class must be paired with model.w[j] (= columns[j]).
            arc_by_id = {x['id']: x['customType'] for x in datums}
            asset_arc = np.array([arc_by_id[fund_id] for fund_id in columns], dtype=int)

            # a[c, j] == 1 iff asset j belongs to class c + 1.
            a = np.zeros((num_arc, len(columns)), dtype=int)
            for j, arc in enumerate(asset_arc):
                a[arc - 1, j] = 1
            model.cons_arc_low = Constraint(range(num_arc),
                                            rule=lambda m, c: larc[c] <= sum([a[c, j] * m.w[j] for j in m.indices]))
            model.cons_arc_up = Constraint(range(num_arc),
                                           rule=lambda m, c: uarc[c] >= sum([a[c, j] * m.w[j] for j in m.indices]))
        return model

    def reset_navs(self, day):
        """Load and clean the NAV history of the pool on ``day`` and store it via ``set_navs``.

        Pool assets are filtered by ``navs.risk`` and ``navs.exclude-asset-type``. Their NAVs within
        ``navs.range`` are restricted to weekdays and pivoted to a date x fund-id frame. Sparse columns
        and days are dropped, and remaining gaps are forward- (then back-) filled.
        """
        self.__date = filter_weekend(day)
        asset_ids = self._assets.get_pool(self.date)
        asset_risk = self.get_config('navs.risk')
        datum = self._datum.get_datums(type=DatumType.FUND, datum_ids=asset_ids, risk=asset_risk)
        exclude = self.get_config('navs.exclude-asset-type') or []
        asset_ids = list(set(asset_ids) & set([x['id'] for x in datum if x['assetType'] not in exclude]))

        min_date = self.date - relativedelta(**self.get_config('navs.range'))
        navs = pd.DataFrame(self._navs.get_fund_navs(fund_ids=asset_ids, max_date=self.date, min_date=min_date))
        # Convert first: the .dt accessor below only works on datetime-like columns.
        navs['nav_date'] = pd.to_datetime(navs['nav_date'])
        navs = navs[navs['nav_date'].dt.day_of_week < 5]
        navs = navs.pivot_table(index='nav_date', columns='fund_id', values='nav_cal')
        navs = navs.sort_index()

        navs_nan = navs.isna().sum()
        navs.drop(columns=[x for x in navs_nan.index if navs_nan.loc[x] >= self.get_config('navs.max-nan.asset')],
                  inplace=True)
        navs_nan = navs.apply(lambda r: r.isna().sum() / len(r), axis=1)
        navs.drop(index=[x for x in navs_nan.index if navs_nan.loc[x] >= self.get_config('navs.max-nan.day')],
                  inplace=True)
        navs = navs.ffill()
        if not navs.empty and navs.iloc[0].isna().any():
            navs = navs.bfill()
        self.set_navs(navs)
        self.set_navs(navs)


class PRRSolver(ARCSolver):
    """ARC solver with additional risk-rating (RR) constraints.

    Each fund's ``risk`` datum field is loaded in :meth:`reset_navs` and used as its risk rating.
    """

    def __init__(self, type: PortfoliosType, risk: PortfoliosRisk, assets: AssetPool = None, navs: Navs = None,
                 datum: Datum = None):
        # Forward only collaborators the caller actually supplied so @autowired can still inject the rest.
        supplied = {key: value for key, value in (('assets', assets), ('navs', navs), ('datum', datum))
                    if value is not None}
        super().__init__(type, risk, **supplied)
        self.__risk = None

    def create_model(self):
        """Extend the ARC model with risk-rating constraints.

        - Weighted-average risk rating <= ``trr``.
        - Total weight in assets with rating <= ``trr`` is >= ``0.7 + brr``.
        - If ``trr < 5``: #selected(rating == 5) <= max_count * #selected(1 < rating < 5).
        """
        model = super(PRRSolver, self).create_model()
        # Rating per model index, looked up by fund id so rr[i] pairs with model.w[i] (navs.columns[i])
        # regardless of the order in which the datums were returned.
        rr = [self.risks[fund_id] for fund_id in self.navs.columns]

        min_rr_weight_within_trr = 0.7 + self._config['brr']
        trr = self._config['trr']

        model.cons_TRR = Constraint(expr=sum([model.w[i] * rr[i] for i in model.indices]) <= trr)

        rr_le_trr = [1 if r <= trr else 0 for r in rr]
        rr_in_1_5 = [1 if 1 < r < 5 else 0 for r in rr]
        rr_eq_5 = [1 if r == 5 else 0 for r in rr]

        model.cons_RR_LE_TRR = Constraint(
            expr=sum([model.w[i] * rr_le_trr[i] for i in model.indices]) >= min_rr_weight_within_trr)
        if trr < 5:
            max_count = self.max_count
            model.cons_RR_in_1_5 = Constraint(
                expr=sum([model.z[i] * (rr_in_1_5[i] * max_count - rr_eq_5[i]) for i in model.indices]) >= 0)
        return model

    def reset_navs(self, day):
        """Reload NAVs (see ``ARCSolver.reset_navs``), then cache ``{fund id: risk}`` for the kept columns."""
        super(PRRSolver, self).reset_navs(day=day)
        datums = self._datum.get_datums(type=DatumType.FUND, datum_ids=list(self.navs.columns))
        self.__risk = {x['id']: x['risk'] for x in datums}

    @property
    def risks(self):
        """``{fund id: risk}`` from the last :meth:`reset_navs` call (None before that)."""
        return self.__risk
