import pandas as pd
from py_jftech import autowired, get_config

from api import DatumType, Datum

# Module-level cache mapping fund id -> fund 'risk' value; populated lazily by build_risk_dict().
risk_dict = {}


@autowired
def build_risk_dict(datum: Datum = None):
    """
    Fill the module-level ``risk_dict`` cache (fund id -> fund ``risk``) on first use.

    When the cache already holds entries this is a no-op. Otherwise every datum of
    type ``DatumType.FUND`` is fetched from ``datum`` and indexed by its ``id``.
    Note: if the fetch yields no funds the cache stays empty, so the next call
    fetches again.

    @param datum: data source; presumably injected by ``@autowired`` when the
        caller omits it — confirm against py_jftech.
    """
    global risk_dict
    if risk_dict:
        # Already cached; nothing to do.
        return
    fund_rows = datum.get_datums(type=DatumType.FUND)
    risk_dict = {row['id']: row['risk'] for row in fund_rows}


def format_weight(weight: dict, to=1) -> dict:
    """
    Round each weight to two decimals, then adjust one fund so the total equals ``to``.

    Rounding the weights one by one can leave their sum slightly above or below
    ``to``. The whole residual is moved onto a single fund chosen by risk level.
    Fund ids are ordered by their value in the module-level ``risk_dict``:

    * total below ``to``: the lowest-risk fund whose weight is below the
      ``portfolios.solver.mpt.high-weight`` bound receives the shortfall;
    * total above ``to``: the highest-risk fund whose weight is above the
      ``portfolios.solver.mpt.low-weight`` bound absorbs the excess.

    @param weight: mapping of fund id -> weight. Ids must be accepted by ``int()``
        because they are looked up in ``risk_dict`` as ints. Missing (NaN/None)
        weights are treated as 0.
    @param to: target total of the weights (default 1).
    @return: dict of fund id -> weight rounded to two decimals.
    @raise IndexError: if the residual must be moved but no fund satisfies the
        bound for the required direction.
    """
    # Sums of already-rounded floats are rarely bit-exact (e.g. 0.9999999999999999),
    # so "already adds up" is decided with a tolerance instead of ==. With ==, a
    # correct vector was sent into the adjustment path, which can raise IndexError.
    tolerance = 1e-9

    build_risk_dict()
    weight_series = pd.Series(weight).fillna(0)
    # Element-wise built-in round() (not Series.round) keeps the original rounding semantics.
    weight_series = weight_series.apply(lambda x: round(x, 2))
    residual = to - weight_series.sum()
    if abs(residual) < tolerance:
        return dict(weight_series)

    # Fund ids ordered from lowest to highest risk.
    # NOTE(review): an id missing from risk_dict gets the key None, and sorted()
    # raises TypeError when comparing None with other keys — confirm every id is known.
    id_sort = sorted(weight_series.index, key=lambda x: risk_dict.get(int(x)))

    # Only the candidate for the direction actually needed is looked up. Computing
    # both eagerly made an empty candidate list for the unused direction raise.
    if residual > 0:
        # The high-weight setting is read as a sequence; only its first element is used.
        high = get_config('portfolios.solver.mpt.high-weight')[0]
        # Lowest-risk fund that still has room below the upper bound.
        target = [i for i in id_sort if weight_series[i] < high][0]
    else:
        low = get_config('portfolios.solver.mpt.low-weight')
        # Highest-risk fund that is still above the lower bound.
        target = [i for i in id_sort if weight_series[i] > low][-1]

    weight_series[target] += residual
    return dict(weight_series.apply(lambda x: round(float(x), 2)))


if __name__ == '__main__':
    # Manual smoke run: the rounded weights below sum to 1.01, which exercises the
    # "total above target" adjustment path. Print the result so the run is visible.
    print(format_weight({"5": 0.35, "6": 0.35, "10": 0.1, "11": 0.16, "22": 0.05}))