Commit 8166842f authored by 纪超's avatar 纪超

完善依赖注入可配置

parent f72a7872
...@@ -5,12 +5,9 @@ framework: ...@@ -5,12 +5,9 @@ framework:
user: ${MYSQL_USER:root} user: ${MYSQL_USER:root}
password: ${MYSQL_PWD:123456} password: ${MYSQL_PWD:123456}
dbname: ${MYSQL_DBNAME:jftech_robo} dbname: ${MYSQL_DBNAME:jftech_robo}
database_from: injectable:
host: ${MYSQL_HOST:127.0.0.1} types:
port: ${MYSQL_PORT:3306} api.PortfoliosBuilder: portfolios.builder.PoemPortfoliosBuilder
user: ${MYSQL_USER:root}
password: ${MYSQL_PWD:123456}
dbname: ${MYSQL_DBNAME:robo_pmpt}
email: email:
server: smtphz.qiye.163.com server: smtphz.qiye.163.com
user: jft-ra@thizgroup.com user: jft-ra@thizgroup.com
...@@ -19,7 +16,7 @@ framework: ...@@ -19,7 +16,7 @@ framework:
max-workers: 8 max-workers: 8
logger: logger:
version: 1 version: 1
use: prod use: ${LOG_NAME}
formatters: formatters:
brief: brief:
format: "%(asctime)s - %(levelname)s - %(message)s" format: "%(asctime)s - %(levelname)s - %(message)s"
...@@ -46,7 +43,7 @@ framework: ...@@ -46,7 +43,7 @@ framework:
level: INFO level: INFO
propagate: no propagate: no
portfolios: portfolios:
level: DEBUG level: INFO
root: root:
level: INFO level: INFO
handlers: [ console ] handlers: [ console ]
......
...@@ -5,11 +5,18 @@ from inspect import signature, Parameter ...@@ -5,11 +5,18 @@ from inspect import signature, Parameter
from typing import get_origin, get_args from typing import get_origin, get_args
from framework.base import get_project_path from framework.base import get_project_path
from framework.env_config import get_config
config = get_config(__name__)
types_config = config['types'] if 'types' in config and config['types'] else {}
names_config = config['names'] if 'names' in config and config['names'] else {}
__COMPONENT_CLASS = [] __COMPONENT_CLASS = []
__NAME_COMPONENT = {} __NAME_COMPONENT = {}
__COMPONENT_INSTANCE = {} __COMPONENT_INSTANCE = {}
class_name = lambda cls: f'{cls.__module__}.{cls.__name__}'
class InjectableError(Exception): class InjectableError(Exception):
def __init__(self, msg): def __init__(self, msg):
...@@ -26,7 +33,10 @@ def component(cls=None, bean_name=None): ...@@ -26,7 +33,10 @@ def component(cls=None, bean_name=None):
__COMPONENT_CLASS.append(cls) __COMPONENT_CLASS.append(cls)
if bean_name: if bean_name:
if bean_name in __NAME_COMPONENT: if bean_name in __NAME_COMPONENT:
raise InjectableError(f"bean name[{bean_name}] is already defined.") if bean_name not in names_config or names_config[bean_name] is None:
raise InjectableError(f"bean name[{bean_name}] is already defined.")
if class_name(cls) != names_config[bean_name]:
return cls
__NAME_COMPONENT[bean_name] = cls __NAME_COMPONENT[bean_name] = cls
return cls return cls
...@@ -55,8 +65,16 @@ def autowired(func=None, names=None): ...@@ -55,8 +65,16 @@ def autowired(func=None, names=None):
kwargs[p_name] = instances kwargs[p_name] = instances
else: else:
components = [x for x in __COMPONENT_CLASS if issubclass(x, p_type.annotation)] components = [x for x in __COMPONENT_CLASS if issubclass(x, p_type.annotation)]
if len(components) > 0: if len(components) == 1:
kwargs[p_name] = get_instance(components[0]) kwargs[p_name] = get_instance(components[0])
elif len(components) > 1:
cls = components[0]
if class_name(p_type.annotation) in types_config:
target_name = types_config[class_name(p_type.annotation)]
find_cls = [x for x in components if class_name(x) == target_name]
if find_cls:
cls = get_instance(find_cls[0])
kwargs[p_name] = get_instance(cls)
func(*args, **kwargs) func(*args, **kwargs)
return wrap return wrap
...@@ -70,10 +88,11 @@ def get_instance(t): ...@@ -70,10 +88,11 @@ def get_instance(t):
return __COMPONENT_INSTANCE[t] return __COMPONENT_INSTANCE[t]
def init_injectable(path=get_project_path()): def init_injectable(root=get_project_path()):
for f in os.listdir(path): for f in os.listdir(root):
if os.path.isdir(f) and os.path.exists(os.path.join(f, '__init__.py')): path = os.path.join(root, f)
init_injectable(path=os.path.join(path, f)) if os.path.isdir(path) and os.path.exists(os.path.join(path, '__init__.py')):
init_injectable(root=path)
if f.endswith('.py') and f != '__init__.py': if f.endswith('.py') and f != '__init__.py':
py = os.path.relpath(os.path.join(path, f), get_project_path())[:-3] py = os.path.relpath(path, get_project_path())[:-3]
import_module('.'.join(py.split(os.path.sep))) import_module('.'.join(py.split(os.path.sep)))
...@@ -16,4 +16,5 @@ if __name__ == '__main__': ...@@ -16,4 +16,5 @@ if __name__ == '__main__':
logger.warning('warning') logger.warning('warning')
logger.error('error') logger.error('error')
logger.critical('critical') logger.critical('critical')
print(start.__module__)
# start() # start()
import unittest
from framework import autowired, parse_date, get_logger
from api import PortfoliosBuilder, PortfoliosType
class PortfoliosTest(unittest.TestCase):
logger = get_logger(__name__)
@autowired(names={'builder': 'mpt'})
def test_portfolio_builder(self, builder: PortfoliosBuilder = None):
result, detail = builder.build_portfolio(parse_date('2011-11-07'), PortfoliosType.NORMAL)
self.logger.info("portfolios: ")
for risk, portfolio in result.items():
self.logger.info(risk.name)
self.logger.info(portfolio)
self.logger.info(detail[risk])
if __name__ == '__main__':
unittest.main()
from api import PortfoliosBuilder, PortfoliosType
from framework import autowired, parse_date, get_logger
logger = get_logger('test')
@autowired(names={'builder': 'poem'})
def test_portfolio_builder(builder: PortfoliosBuilder = None):
result, detail = builder.build_portfolio(parse_date('2011-11-07'), PortfoliosType.NORMAL)
logger.info("portfolios: ")
for risk, portfolio in result.items():
logger.info(risk.name)
logger.info(portfolio)
logger.info(detail[risk])
if __name__ == '__main__':
test_portfolio_builder()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment