Commit 228349ea authored by jichao's avatar jichao

依赖注入实现中

parent 4175cece
from abc import ABCMeta, abstractmethod
class BaseBean(metaclass=ABCMeta):
@abstractmethod
def do_something(self):
pass
class ServiceBean(metaclass=ABCMeta):
@abstractmethod
def do_service(self):
pass
database:
framework:
database:
host: ${MYSQL_HOST:127.0.0.1}
port: ${MYSQL_PORT:3306}
user: ${MYSQL_USER:root}
password: ${MYSQL_PWD:123456}
dbname: ${MYSQL_DBNAME:jftech_robo}
database_from:
database_from:
host: ${MYSQL_HOST:127.0.0.1}
port: ${MYSQL_PORT:3306}
user: ${MYSQL_USER:root}
password: ${MYSQL_PWD:123456}
dbname: ${MYSQL_DBNAME:robo_pmpt}
email:
email:
server: smtphz.qiye.163.com
user: jft-ra@thizgroup.com
password: 5dbb#30ec6d3
logger:
logger:
version: 1
use: ${LOG_NAME:root}
formatters:
......
from utils import read, write, config, format_date, parse_date, where
from framework import read, write, config, format_date, parse_date, where
import json
from datetime import datetime
from datas.datum.enums import DatumType
......
from api import BaseBean
from framework import component
@component(bean_name="one")
class OneImplBean(BaseBean):
def do_something(self):
print("one bean")
@component(bean_name="two")
class TwoImplBean(BaseBean):
def do_something(self):
print("two bean")
\ No newline at end of file
import pandas as _pd
from datas.navs import robo_exrate as _re, robo_fund_navs as _navs
from datas import datum as _datum, DatumType
from utils import config, to_bool
from framework import config, to_bool
from datetime import timedelta
navs_config = config['datas']['navs'] if 'datas' in config and 'navs' in config['datas'] else {}
......@@ -27,6 +27,6 @@ def get_navs(fund_id=None, min_date=None, max_date=None):
if __name__ == '__main__':
from utils import parse_date
from framework import parse_date
print(get_navs(min_date=parse_date('2022-11-01')))
from utils import read, where, format_date
from framework import read, where, format_date
@read
......@@ -21,6 +21,6 @@ def get_exrate(ticker, date):
if __name__ == '__main__':
from utils import parse_date
from framework import parse_date
print(get_exrate(date=parse_date('2022-11-01'), ticker='EURUSD BGN Curncy'))
from utils import read, where, format_date
from framework import read, where, format_date
@read
......@@ -16,6 +16,6 @@ def get_navs(fund_id=None, min_date=None, max_date=None):
if __name__ == '__main__':
from utils import parse_date
from framework import parse_date
navs = get_navs(fund_id=1, min_date=parse_date('2022-11-01'))
print(navs)
......@@ -3,6 +3,8 @@ from .base import *
from .datebase import read, write, transaction, where
from .__env_config import config, get_config
from .__logger import build_logger, logger
from .injectable import component
from .injectable import component, autowired, get_instance, init_injectable as _init_injectable
del injectable, __logger, __env_config, datebase, base, date_utils
_init_injectable()
del injectable, __logger, __env_config, datebase, base, date_utils, _init_injectable
......@@ -14,5 +14,5 @@ def build_logger(config, name='root'):
return getLogger(name)
if 'logger' in config:
logger = build_logger(config['logger'], name=config['logger']['use'])
if 'framework' in config and 'logger' in config['framework']:
logger = build_logger(config['framework']['logger'], name=config['framework']['logger']['use'])
......@@ -2,14 +2,27 @@ import functools
import pymysql
import threading
from pymysql.cursors import DictCursor
from .__env_config import config as default_config
from .__env_config import config as global_config
from .date_utils import format_date, datetime
from enum import Enum
_CONFIG = global_config['framework']['database'] if 'framework' in global_config and 'database' in global_config['framework'] else None
class DatabaseError(Exception):
def __init__(self, msg):
self.__msg = msg
def __str__(self):
return self.__msg
class Database:
def __init__(self, config):
self.config = config or default_config['database']
self._config = config or _CONFIG
if self._config is None:
raise DatabaseError("database config is not found.")
def __enter__(self):
port = 3306
......
import abc
import inspect
import functools
import os, sys
from importlib import import_module
from inspect import signature, Parameter
from functools import partial, wraps
from typing import List, get_origin, get_args
from types import GenericAlias
from framework.base import get_project_path
__COMPONENT_CLASS = []
__NAME_COMPONENT = {}
......@@ -19,7 +21,7 @@ class InjectableError(Exception):
def component(cls=None, bean_name=None):
if cls is None:
return functools.partial(component, bean_name=bean_name)
return partial(component, bean_name=bean_name)
__COMPONENT_CLASS.append(cls)
if bean_name:
......@@ -29,19 +31,23 @@ def component(cls=None, bean_name=None):
return cls
def autowired(func=None):
def autowired(func=None, names=None):
if func is None:
return functools.partial(autowired)
return partial(autowired, names=names)
@functools.wraps(func)
@wraps(func)
def wrap(*args, **kwargs):
if func.__name__ == '__init__':
self_type = type(args[0])
if self_type in __COMPONENT_CLASS and self_type not in __COMPONENT_INSTANCE:
__COMPONENT_INSTANCE[self_type] = args[0]
for p_name, p_type in inspect.signature(func).parameters.items():
if p_name == 'self' or p_type == inspect.Parameter.empty:
for p_name, p_type in signature(func).parameters.items():
if p_name == 'self' or p_type == Parameter.empty or p_name in kwargs:
continue
if get_origin(p_type.annotation) is list:
if names is not None and p_name in names:
if names[p_name] in __NAME_COMPONENT:
kwargs[p_name] = get_instance(__NAME_COMPONENT[names[p_name]])
elif get_origin(p_type.annotation) is list:
inject_types = get_args(p_type.annotation)
if len(inject_types) > 0:
instances = [get_instance(x) for x in __COMPONENT_CLASS if issubclass(x, inject_types)]
......@@ -64,37 +70,10 @@ def get_instance(t):
return __COMPONENT_INSTANCE[t]
class BaseBean(metaclass=abc.ABCMeta):
@abc.abstractmethod
def do_something(self):
pass
@component(bean_name="one")
class OneImplBean(BaseBean):
def do_something(self):
print("one bean")
@component(bean_name="two")
class TwoImplBean(BaseBean):
def do_something(self):
print("two bean")
class ServiceBean:
@autowired
def __init__(self, base: BaseBean = None, bases: List[BaseBean] = None):
base.do_something()
for b in bases:
b.do_something()
print(__COMPONENT_CLASS, __NAME_COMPONENT)
if __name__ == '__main__':
ServiceBean()
def init_injectable(path=get_project_path()):
for f in os.listdir(path):
if os.path.isdir(f) and os.path.exists(os.path.join(f, '__init__.py')):
init_injectable(path=os.path.join(path, f))
if f.endswith('.py') and f != '__init__.py':
py = os.path.relpath(os.path.join(path, f), get_project_path())[:-3]
import_module('.'.join(py.split(os.path.sep)))
import pandas as pd
from utils import filter_weekend, config, dict_remove
from framework import filter_weekend, config, dict_remove
from datas import navs
from dateutil.relativedelta import relativedelta
from empyrical import sortino_ratio
......
from api import BaseBean, ServiceBean
from framework import autowired, component
@component
class ServiceImpl(ServiceBean):
@autowired
def __init__(self, base: BaseBean = None):
self._base = base
def do_service(self):
self._base.do_something()
\ No newline at end of file
from utils import logger
from datas import datum
from framework import logger, autowired
from api import ServiceBean
@autowired
def test(service: ServiceBean = None):
service.do_service()
if __name__ == '__main__':
logger.info(dir())
logger.info(datum.get_fund_datums())
test()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment