Commit 4fa60d2f authored by jichao's avatar jichao

依赖注入实现中

parent 27cd5b86
...@@ -43,12 +43,6 @@ class SolveType(Enum): ...@@ -43,12 +43,6 @@ class SolveType(Enum):
POEM = 2 POEM = 2
@unique
class DriftType(Enum):
PORTFOLIO_HIGH_WEIGHT = 1
DATE_CURVE = 2
@unique @unique
class SignalType(Enum): class SignalType(Enum):
INIT = 0 INIT = 0
...@@ -363,26 +357,12 @@ class DriftSolver(ABC): ...@@ -363,26 +357,12 @@ class DriftSolver(ABC):
''' '''
@abstractmethod @abstractmethod
def get_drift(self, day): def get_drift(self, day, risk: PortfoliosRisk):
'''
获取指定日期得漂移计算结果
:param day:
:return:
'''
pass
class DriftFactory(ABC):
'''
漂移解算器工厂
'''
@abstractmethod
def create_drift_solver(self, type: DriftType, **kwargs) -> DriftSolver:
''' '''
创建指定类型的漂移解算器 获取指定日期,指定风险等级的漂移计算结果
:param type: 指定的解算器类型 :param day: 指定日期
:param kwargs: 解算器需要的参数字典 :param risk: 指定风险等级
:return: 漂移解算器 :return: 漂移计算结果
''' '''
pass pass
......
...@@ -113,10 +113,12 @@ portfolios: ...@@ -113,10 +113,12 @@ portfolios:
mpt: mpt:
quantile: 0.1 quantile: 0.1
rebalance: rebalance:
drift: drift-solver:
date-curve: date-curve:
diff-threshold: 0.4 diff-threshold: 0.4
init-factor: 0.000000002 init-factor: 0.000000002
high-weight:
coef: 0.2
builder: builder:
disable-period: #自然日 disable-period: #自然日
normal: 10 normal: 10
...@@ -146,7 +148,6 @@ rebalance: ...@@ -146,7 +148,6 @@ rebalance:
diff-threshold: 0.4 diff-threshold: 0.4
init-factor: 0.000000002 init-factor: 0.000000002
high-low-buy: high-low-buy:
drift-coef: 0.2
threshold: [ 0.5, 0.8 ] threshold: [ 0.5, 0.8 ]
......
from api import DriftSolver, DriftFactory, DriftType from api import DriftSolver, PortfoliosRisk
from framework import component, autowired, get_config from framework import component, autowired, get_config
from rebalance.dao import robo_rebalance_signal as rrs from rebalance.dao import robo_rebalance_signal as rrs
@component def get_start_date():
class DefaultDriftFactory(DriftFactory): config = get_config('main')
return filter_weekend(config['start-date'])
def create_drift_solver(self, type: DriftType, **kwargs) -> DriftSolver:
if type is DriftType.DATE_CURVE:
return DateCurveDriftSolver()
elif type is DriftType.PORTFOLIO_HIGH_WEIGHT:
pass
return None
@component(bean_name='date-curve')
class DateCurve(DriftSolver): class DateCurve(DriftSolver):
def __init__(self): def __init__(self):
...@@ -27,13 +22,41 @@ class DateCurve(DriftSolver): ...@@ -27,13 +22,41 @@ class DateCurve(DriftSolver):
def init_factor(self): def init_factor(self):
return self._config['init-factor'] return self._config['init-factor']
def get_drift(self, day): def get_drift(self, day, risk: PortfoliosRisk):
last_re = rrs.get_last_one(max_date=day, risk=risk, effective=True) last_re = rrs.get_last_one(max_date=day, risk=risk, effective=True)
result = self.diff_threshold - self.init_factor * (day - last_re['date']).days ** 4 result = self.diff_threshold - self.init_factor * (day - last_re['date']).days ** 4
return max(0, result) return max(0, result)
@component(bean_name='high-weight')
class PortfolioHighWeight(DriftSolver): class PortfolioHighWeight(DriftSolver):
def get_drift(self, day): @autowired
pass def __init__(self, builder: PortfoliosBuilder = None, datum: Datum = None):
self._builder = builder
self._datum = datum
self._config = get_config(__name__)['high-weight']
@property
def drift_coef(self):
return self._config['coef']
def get_drift(self, day, risk: PortfoliosRisk):
drift = rwd.get_one(day, risk)
if not drift:
datum_ids = [x['id'] for x in self._datum.get_high_risk_datums(risk)]
last_one = rwd.get_last_one(max_date=day, risk=risk)
start = (next_workday(last_one['date'])) if last_one else get_start_date()
last_drift = last_one['drift'] if last_one else 0
for date in [x for x in pd.date_range(start, day, freq='D') if is_workday(x)]:
portfolio = self._builder.get_portfolios(date, risk)
weight = round(sum([x[1] for x in portfolio.items() if x[0] in datum_ids]), 2)
last_drift = (weight * self.drift_coef + (1 - self.drift_coef) * last_drift) if last_drift else weight
rwd.insert({
'date': date,
'risk': risk,
'weight': weight,
'drift': last_drift,
})
drift = rwd.get_last_one(day, risk)
return drift['drift']
from api import PortfoliosRisk, SignalType, Datum, PortfoliosHolder from api import PortfoliosRisk, SignalType, Datum, PortfoliosHolder, DriftSolver
from framework import component, autowired, get_config from framework import component, autowired, get_config
from rebalance.base_signal import BaseSignalBuilder from rebalance.base_signal import BaseSignalBuilder
from rebalance.dao import robo_rebalance_signal as rrs from rebalance.dao import robo_rebalance_signal as rrs
...@@ -7,11 +7,12 @@ from rebalance.dao import robo_rebalance_signal as rrs ...@@ -7,11 +7,12 @@ from rebalance.dao import robo_rebalance_signal as rrs
@component(bean_name='curve-drift') @component(bean_name='curve-drift')
class CurveDrift(BaseSignalBuilder): class CurveDrift(BaseSignalBuilder):
@autowired @autowired(names={'solver': 'date-curve'})
def __init__(self, datum: Datum = None, hold: PortfoliosHolder = None): def __init__(self, datum: Datum = None, hold: PortfoliosHolder = None, solver: DriftSolver = None):
super().__init__() super().__init__()
self._datum = datum self._datum = datum
self._hold = hold self._hold = hold
self._solver = solver
self._config = get_config(__name__) self._config = get_config(__name__)
@property @property
...@@ -34,8 +35,7 @@ class CurveDrift(BaseSignalBuilder): ...@@ -34,8 +35,7 @@ class CurveDrift(BaseSignalBuilder):
normal_weight = round(sum([x[1] for x in normal_portfolio.items() if x[0] in datum_ids]), 2) normal_weight = round(sum([x[1] for x in normal_portfolio.items() if x[0] in datum_ids]), 2)
hold_portfolio = self._hold.get_portfolios_weight(day, risk) hold_portfolio = self._hold.get_portfolios_weight(day, risk)
hold_weight = round(sum([x[1] for x in normal_portfolio.items() if x[0] in datum_ids]), 2) hold_weight = round(sum([x[1] for x in normal_portfolio.items() if x[0] in datum_ids]), 2)
threshold = self.diff_threshold - self.init_factor * (day - last_re['date']).days ** 4 return normal_weight - hold_weight >= self._solver.get_drift(day, risk)
return normal_weight - hold_weight >= max(0, threshold)
@property @property
def diff_threshold(self): def diff_threshold(self):
......
import pandas as pd import pandas as pd
from api import PortfoliosBuilder, SignalType, PortfoliosRisk, Datum from api import PortfoliosBuilder, SignalType, PortfoliosRisk, Datum, DriftSolver
from framework import component, autowired, get_config, filter_weekend, next_workday, is_workday from framework import component, autowired, get_config, filter_weekend, next_workday, is_workday
from rebalance.base_signal import BaseSignalBuilder from rebalance.base_signal import BaseSignalBuilder
from rebalance.dao import robo_weight_drift as rwd, robo_rebalance_signal as rrs from rebalance.dao import robo_weight_drift as rwd, robo_rebalance_signal as rrs
def get_start_date(): @component(bean_name='high-buy')
config = get_config('main') class HighBuySignal(BaseSignalBuilder):
return filter_weekend(config['start-date'])
class DriftSupport:
@autowired @autowired(names={'solver': 'high-weight'})
def __init__(self, builder: PortfoliosBuilder = None, datum: Datum = None): def __init__(self, solver: DriftSolver = None):
self._builder = builder super().__init__()
self._datum = datum
self._config = get_config(__name__) self._config = get_config(__name__)
self._solver = solver
@property
def drift_coef(self):
return self._config['drift-coef']
def get_threshold(self, risk: PortfoliosRisk):
threshold = self._config['threshold']
if isinstance(threshold, dict):
threshold = threshold[f'ft{risk.value}']
return threshold
def get_drift(self, day, risk: PortfoliosRisk):
drift = rwd.get_one(day, risk)
if not drift:
datum_ids = [x['id'] for x in self._datum.get_high_risk_datums(risk)]
last_one = rwd.get_last_one(max_date=day, risk=risk)
start = (next_workday(last_one['date'])) if last_one else get_start_date()
last_drift = last_one['drift'] if last_one else 0
for date in [x for x in pd.date_range(start, day, freq='D') if is_workday(x)]:
portfolio = self._builder.get_portfolios(date, risk)
weight = round(sum([x[1] for x in portfolio.items() if x[0] in datum_ids]), 2)
last_drift = (weight * self.drift_coef + (1 - self.drift_coef) * last_drift) if last_drift else weight
rwd.insert({
'date': date,
'risk': risk,
'weight': weight,
'drift': last_drift,
})
drift = rwd.get_last_one(day, risk)
return drift['drift']
@component(bean_name='high-buy')
class HighBuySignal(BaseSignalBuilder, DriftSupport):
@property @property
def include_last_type(self): def include_last_type(self):
...@@ -66,17 +28,23 @@ class HighBuySignal(BaseSignalBuilder, DriftSupport): ...@@ -66,17 +28,23 @@ class HighBuySignal(BaseSignalBuilder, DriftSupport):
def signal_type(self) -> SignalType: def signal_type(self) -> SignalType:
return SignalType.HIGH_BUY return SignalType.HIGH_BUY
def get_threshold(self, risk: PortfoliosRisk):
threshold = self._config['threshold']
if isinstance(threshold, dict):
threshold = threshold[f'ft{risk.value}']
return threshold
def is_trigger(self, day, risk: PortfoliosRisk) -> bool: def is_trigger(self, day, risk: PortfoliosRisk) -> bool:
last_re = rrs.get_last_one(max_date=day, risk=risk, effective=True) last_re = rrs.get_last_one(max_date=day, risk=risk, effective=True)
if last_re is None or SignalType(last_re['type']) not in self.include_last_type: if last_re is None or SignalType(last_re['type']) not in self.include_last_type:
return False return False
drift = self.get_drift(day, risk) drift = self._solver.get_drift(day, risk)
threshold = self.get_threshold(risk) threshold = self.get_threshold(risk)
return drift > threshold[1] return drift > threshold[1]
@component(bean_name='low-buy') @component(bean_name='low-buy')
class LowBuySignal(BaseSignalBuilder, DriftSupport): class LowBuySignal(HighBuySignal):
@property @property
def include_last_type(self): def include_last_type(self):
...@@ -94,6 +62,6 @@ class LowBuySignal(BaseSignalBuilder, DriftSupport): ...@@ -94,6 +62,6 @@ class LowBuySignal(BaseSignalBuilder, DriftSupport):
last_re = rrs.get_last_one(max_date=day, risk=risk, effective=True) last_re = rrs.get_last_one(max_date=day, risk=risk, effective=True)
if last_re is None or SignalType(last_re['type']) not in self.include_last_type: if last_re is None or SignalType(last_re['type']) not in self.include_last_type:
return False return False
drift = self.get_drift(day, risk) drift = self._solver.get_drift(day, risk)
threshold = self.get_threshold(risk) threshold = self.get_threshold(risk)
return threshold[0] < drift < threshold[1] return threshold[0] < drift < threshold[1]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment