Commit 228349ea authored by jichao's avatar jichao

依赖注入实现中

parent 4175cece
from abc import ABCMeta, abstractmethod
class BaseBean(metaclass=ABCMeta):
@abstractmethod
def do_something(self):
pass
class ServiceBean(metaclass=ABCMeta):
@abstractmethod
def do_service(self):
pass
database: framework:
host: ${MYSQL_HOST:127.0.0.1} database:
port: ${MYSQL_PORT:3306} host: ${MYSQL_HOST:127.0.0.1}
user: ${MYSQL_USER:root} port: ${MYSQL_PORT:3306}
password: ${MYSQL_PWD:123456} user: ${MYSQL_USER:root}
dbname: ${MYSQL_DBNAME:jftech_robo} password: ${MYSQL_PWD:123456}
database_from: dbname: ${MYSQL_DBNAME:jftech_robo}
host: ${MYSQL_HOST:127.0.0.1} database_from:
port: ${MYSQL_PORT:3306} host: ${MYSQL_HOST:127.0.0.1}
user: ${MYSQL_USER:root} port: ${MYSQL_PORT:3306}
password: ${MYSQL_PWD:123456} user: ${MYSQL_USER:root}
dbname: ${MYSQL_DBNAME:robo_pmpt} password: ${MYSQL_PWD:123456}
email: dbname: ${MYSQL_DBNAME:robo_pmpt}
server: smtphz.qiye.163.com email:
user: jft-ra@thizgroup.com server: smtphz.qiye.163.com
password: 5dbb#30ec6d3 user: jft-ra@thizgroup.com
logger: password: 5dbb#30ec6d3
version: 1 logger:
use: ${LOG_NAME:root} version: 1
formatters: use: ${LOG_NAME:root}
brief: formatters:
format: "%(asctime)s - %(levelname)s - %(message)s" brief:
simple: format: "%(asctime)s - %(levelname)s - %(message)s"
format: "%(asctime)s - %(filename)s - %(levelname)s - %(message)s" simple:
handlers: format: "%(asctime)s - %(filename)s - %(levelname)s - %(message)s"
console: handlers:
class: logging.StreamHandler console:
formatter: simple class: logging.StreamHandler
formatter: simple
level: INFO
stream: ext://sys.stdout
file:
class: logging.handlers.TimedRotatingFileHandler
level: INFO
formatter: brief
filename: logs/info.log
interval: 1
backupCount: 30
encoding: utf8
when: D
loggers:
prod:
handlers: [console,file]
level: INFO
propagate: 0
root:
level: INFO level: INFO
stream: ext://sys.stdout handlers: [console]
file:
class: logging.handlers.TimedRotatingFileHandler
level: INFO
formatter: brief
filename: logs/info.log
interval: 1
backupCount: 30
encoding: utf8
when: D
loggers:
prod:
handlers: [console,file]
level: INFO
propagate: 0
root:
level: INFO
handlers: [console]
datas: datas:
navs: navs:
exrate: exrate:
......
from utils import read, write, config, format_date, parse_date, where from framework import read, write, config, format_date, parse_date, where
import json import json
from datetime import datetime from datetime import datetime
from datas.datum.enums import DatumType from datas.datum.enums import DatumType
......
from api import BaseBean
from framework import component
@component(bean_name="one")
class OneImplBean(BaseBean):
def do_something(self):
print("one bean")
@component(bean_name="two")
class TwoImplBean(BaseBean):
def do_something(self):
print("two bean")
\ No newline at end of file
import pandas as _pd import pandas as _pd
from datas.navs import robo_exrate as _re, robo_fund_navs as _navs from datas.navs import robo_exrate as _re, robo_fund_navs as _navs
from datas import datum as _datum, DatumType from datas import datum as _datum, DatumType
from utils import config, to_bool from framework import config, to_bool
from datetime import timedelta from datetime import timedelta
navs_config = config['datas']['navs'] if 'datas' in config and 'navs' in config['datas'] else {} navs_config = config['datas']['navs'] if 'datas' in config and 'navs' in config['datas'] else {}
...@@ -27,6 +27,6 @@ def get_navs(fund_id=None, min_date=None, max_date=None): ...@@ -27,6 +27,6 @@ def get_navs(fund_id=None, min_date=None, max_date=None):
if __name__ == '__main__': if __name__ == '__main__':
from utils import parse_date from framework import parse_date
print(get_navs(min_date=parse_date('2022-11-01'))) print(get_navs(min_date=parse_date('2022-11-01')))
from utils import read, where, format_date from framework import read, where, format_date
@read @read
...@@ -21,6 +21,6 @@ def get_exrate(ticker, date): ...@@ -21,6 +21,6 @@ def get_exrate(ticker, date):
if __name__ == '__main__': if __name__ == '__main__':
from utils import parse_date from framework import parse_date
print(get_exrate(date=parse_date('2022-11-01'), ticker='EURUSD BGN Curncy')) print(get_exrate(date=parse_date('2022-11-01'), ticker='EURUSD BGN Curncy'))
from utils import read, where, format_date from framework import read, where, format_date
@read @read
...@@ -16,6 +16,6 @@ def get_navs(fund_id=None, min_date=None, max_date=None): ...@@ -16,6 +16,6 @@ def get_navs(fund_id=None, min_date=None, max_date=None):
if __name__ == '__main__': if __name__ == '__main__':
from utils import parse_date from framework import parse_date
navs = get_navs(fund_id=1, min_date=parse_date('2022-11-01')) navs = get_navs(fund_id=1, min_date=parse_date('2022-11-01'))
print(navs) print(navs)
...@@ -3,6 +3,8 @@ from .base import * ...@@ -3,6 +3,8 @@ from .base import *
from .datebase import read, write, transaction, where from .datebase import read, write, transaction, where
from .__env_config import config, get_config from .__env_config import config, get_config
from .__logger import build_logger, logger from .__logger import build_logger, logger
from .injectable import component from .injectable import component, autowired, get_instance, init_injectable as _init_injectable
del injectable, __logger, __env_config, datebase, base, date_utils _init_injectable()
del injectable, __logger, __env_config, datebase, base, date_utils, _init_injectable
...@@ -14,5 +14,5 @@ def build_logger(config, name='root'): ...@@ -14,5 +14,5 @@ def build_logger(config, name='root'):
return getLogger(name) return getLogger(name)
if 'logger' in config: if 'framework' in config and 'logger' in config['framework']:
logger = build_logger(config['logger'], name=config['logger']['use']) logger = build_logger(config['framework']['logger'], name=config['framework']['logger']['use'])
...@@ -2,14 +2,27 @@ import functools ...@@ -2,14 +2,27 @@ import functools
import pymysql import pymysql
import threading import threading
from pymysql.cursors import DictCursor from pymysql.cursors import DictCursor
from .__env_config import config as default_config from .__env_config import config as global_config
from .date_utils import format_date, datetime from .date_utils import format_date, datetime
from enum import Enum from enum import Enum
_CONFIG = global_config['framework']['database'] if 'framework' in global_config and 'database' in global_config['framework'] else None
class DatabaseError(Exception):
def __init__(self, msg):
self.__msg = msg
def __str__(self):
return self.__msg
class Database: class Database:
def __init__(self, config): def __init__(self, config):
self.config = config or default_config['database'] self._config = config or _CONFIG
if self._config is None:
raise DatabaseError("database config is not found.")
def __enter__(self): def __enter__(self):
port = 3306 port = 3306
......
import abc import os, sys
import inspect from importlib import import_module
import functools from inspect import signature, Parameter
from functools import partial, wraps
from typing import List, get_origin, get_args from typing import List, get_origin, get_args
from types import GenericAlias from types import GenericAlias
from framework.base import get_project_path
__COMPONENT_CLASS = [] __COMPONENT_CLASS = []
__NAME_COMPONENT = {} __NAME_COMPONENT = {}
...@@ -19,7 +21,7 @@ class InjectableError(Exception): ...@@ -19,7 +21,7 @@ class InjectableError(Exception):
def component(cls=None, bean_name=None): def component(cls=None, bean_name=None):
if cls is None: if cls is None:
return functools.partial(component, bean_name=bean_name) return partial(component, bean_name=bean_name)
__COMPONENT_CLASS.append(cls) __COMPONENT_CLASS.append(cls)
if bean_name: if bean_name:
...@@ -29,19 +31,23 @@ def component(cls=None, bean_name=None): ...@@ -29,19 +31,23 @@ def component(cls=None, bean_name=None):
return cls return cls
def autowired(func=None): def autowired(func=None, names=None):
if func is None: if func is None:
return functools.partial(autowired) return partial(autowired, names=names)
@functools.wraps(func) @wraps(func)
def wrap(*args, **kwargs): def wrap(*args, **kwargs):
self_type = type(args[0]) if func.__name__ == '__init__':
if self_type in __COMPONENT_CLASS and self_type not in __COMPONENT_INSTANCE: self_type = type(args[0])
__COMPONENT_INSTANCE[self_type] = args[0] if self_type in __COMPONENT_CLASS and self_type not in __COMPONENT_INSTANCE:
for p_name, p_type in inspect.signature(func).parameters.items(): __COMPONENT_INSTANCE[self_type] = args[0]
if p_name == 'self' or p_type == inspect.Parameter.empty: for p_name, p_type in signature(func).parameters.items():
if p_name == 'self' or p_type == Parameter.empty or p_name in kwargs:
continue continue
if get_origin(p_type.annotation) is list: if names is not None and p_name in names:
if names[p_name] in __NAME_COMPONENT:
kwargs[p_name] = get_instance(__NAME_COMPONENT[names[p_name]])
elif get_origin(p_type.annotation) is list:
inject_types = get_args(p_type.annotation) inject_types = get_args(p_type.annotation)
if len(inject_types) > 0: if len(inject_types) > 0:
instances = [get_instance(x) for x in __COMPONENT_CLASS if issubclass(x, inject_types)] instances = [get_instance(x) for x in __COMPONENT_CLASS if issubclass(x, inject_types)]
...@@ -64,37 +70,10 @@ def get_instance(t): ...@@ -64,37 +70,10 @@ def get_instance(t):
return __COMPONENT_INSTANCE[t] return __COMPONENT_INSTANCE[t]
class BaseBean(metaclass=abc.ABCMeta): def init_injectable(path=get_project_path()):
@abc.abstractmethod for f in os.listdir(path):
def do_something(self): if os.path.isdir(f) and os.path.exists(os.path.join(f, '__init__.py')):
pass init_injectable(path=os.path.join(path, f))
if f.endswith('.py') and f != '__init__.py':
py = os.path.relpath(os.path.join(path, f), get_project_path())[:-3]
@component(bean_name="one") import_module('.'.join(py.split(os.path.sep)))
class OneImplBean(BaseBean):
def do_something(self):
print("one bean")
@component(bean_name="two")
class TwoImplBean(BaseBean):
def do_something(self):
print("two bean")
class ServiceBean:
@autowired
def __init__(self, base: BaseBean = None, bases: List[BaseBean] = None):
base.do_something()
for b in bases:
b.do_something()
print(__COMPONENT_CLASS, __NAME_COMPONENT)
if __name__ == '__main__':
ServiceBean()
import pandas as pd import pandas as pd
from utils import filter_weekend, config, dict_remove from framework import filter_weekend, config, dict_remove
from datas import navs from datas import navs
from dateutil.relativedelta import relativedelta from dateutil.relativedelta import relativedelta
from empyrical import sortino_ratio from empyrical import sortino_ratio
......
from api import BaseBean, ServiceBean
from framework import autowired, component
@component
class ServiceImpl(ServiceBean):
@autowired
def __init__(self, base: BaseBean = None):
self._base = base
def do_service(self):
self._base.do_something()
\ No newline at end of file
from utils import logger from framework import logger, autowired
from datas import datum from api import ServiceBean
@autowired
def test(service: ServiceBean = None):
service.do_service()
if __name__ == '__main__': if __name__ == '__main__':
logger.info(dir()) logger.info(dir())
logger.info(datum.get_fund_datums()) test()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment