Commit 8166842f authored by 纪超's avatar 纪超

完善依赖注入可配置

parent f72a7872
......@@ -5,12 +5,9 @@ framework:
user: ${MYSQL_USER:root}
password: ${MYSQL_PWD:123456}
dbname: ${MYSQL_DBNAME:jftech_robo}
database_from:
host: ${MYSQL_HOST:127.0.0.1}
port: ${MYSQL_PORT:3306}
user: ${MYSQL_USER:root}
password: ${MYSQL_PWD:123456}
dbname: ${MYSQL_DBNAME:robo_pmpt}
injectable:
types:
api.PortfoliosBuilder: portfolios.builder.PoemPortfoliosBuilder
email:
server: smtphz.qiye.163.com
user: jft-ra@thizgroup.com
......@@ -19,7 +16,7 @@ framework:
max-workers: 8
logger:
version: 1
use: prod
use: ${LOG_NAME}
formatters:
brief:
format: "%(asctime)s - %(levelname)s - %(message)s"
......@@ -46,7 +43,7 @@ framework:
level: INFO
propagate: no
portfolios:
level: DEBUG
level: INFO
root:
level: INFO
handlers: [ console ]
......
......@@ -5,11 +5,18 @@ from inspect import signature, Parameter
from typing import get_origin, get_args
from framework.base import get_project_path
from framework.env_config import get_config
config = get_config(__name__)
types_config = config['types'] if 'types' in config and config['types'] else {}
names_config = config['names'] if 'names' in config and config['names'] else {}
__COMPONENT_CLASS = []
__NAME_COMPONENT = {}
__COMPONENT_INSTANCE = {}
class_name = lambda cls: f'{cls.__module__}.{cls.__name__}'
class InjectableError(Exception):
def __init__(self, msg):
......@@ -26,7 +33,10 @@ def component(cls=None, bean_name=None):
__COMPONENT_CLASS.append(cls)
if bean_name:
if bean_name in __NAME_COMPONENT:
raise InjectableError(f"bean name[{bean_name}] is already defined.")
if bean_name not in names_config or names_config[bean_name] is None:
raise InjectableError(f"bean name[{bean_name}] is already defined.")
if class_name(cls) != names_config[bean_name]:
return cls
__NAME_COMPONENT[bean_name] = cls
return cls
......@@ -55,8 +65,16 @@ def autowired(func=None, names=None):
kwargs[p_name] = instances
else:
components = [x for x in __COMPONENT_CLASS if issubclass(x, p_type.annotation)]
if len(components) > 0:
if len(components) == 1:
kwargs[p_name] = get_instance(components[0])
elif len(components) > 1:
cls = components[0]
if class_name(p_type.annotation) in types_config:
target_name = types_config[class_name(p_type.annotation)]
find_cls = [x for x in components if class_name(x) == target_name]
if find_cls:
cls = get_instance(find_cls[0])
kwargs[p_name] = get_instance(cls)
func(*args, **kwargs)
return wrap
......@@ -70,10 +88,11 @@ def get_instance(t):
return __COMPONENT_INSTANCE[t]
def init_injectable(path=get_project_path()):
for f in os.listdir(path):
if os.path.isdir(f) and os.path.exists(os.path.join(f, '__init__.py')):
init_injectable(path=os.path.join(path, f))
def init_injectable(root=get_project_path()):
for f in os.listdir(root):
path = os.path.join(root, f)
if os.path.isdir(path) and os.path.exists(os.path.join(path, '__init__.py')):
init_injectable(root=path)
if f.endswith('.py') and f != '__init__.py':
py = os.path.relpath(os.path.join(path, f), get_project_path())[:-3]
py = os.path.relpath(path, get_project_path())[:-3]
import_module('.'.join(py.split(os.path.sep)))
......@@ -16,4 +16,5 @@ if __name__ == '__main__':
logger.warning('warning')
logger.error('error')
logger.critical('critical')
print(start.__module__)
# start()
import unittest
from framework import autowired, parse_date, get_logger
from api import PortfoliosBuilder, PortfoliosType
class PortfoliosTest(unittest.TestCase):
logger = get_logger(__name__)
@autowired(names={'builder': 'mpt'})
def test_portfolio_builder(self, builder: PortfoliosBuilder = None):
result, detail = builder.build_portfolio(parse_date('2011-11-07'), PortfoliosType.NORMAL)
self.logger.info("portfolios: ")
for risk, portfolio in result.items():
self.logger.info(risk.name)
self.logger.info(portfolio)
self.logger.info(detail[risk])
if __name__ == '__main__':
unittest.main()
from api import PortfoliosBuilder, PortfoliosType
from framework import autowired, parse_date, get_logger
logger = get_logger('test')
@autowired(names={'builder': 'poem'})
def test_portfolio_builder(builder: PortfoliosBuilder = None):
result, detail = builder.build_portfolio(parse_date('2011-11-07'), PortfoliosType.NORMAL)
logger.info("portfolios: ")
for risk, portfolio in result.items():
logger.info(risk.name)
logger.info(portfolio)
logger.info(detail[risk])
if __name__ == '__main__':
test_portfolio_builder()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment