Commit 294dd165 authored by jichao's avatar jichao

创建

parent e08c69f0
from utils import * from utils import *
import utils as ut
table = 'robo_fund_info'
columns = 'id, FT_Ticker, Lipper_ID, Name_CH, wind_ticker, Bloomberg_Ticker, FT_TW_RISK, RISK, ROBO, TEST, Asset, index_id'
@read
def getSyncFundInfos():
return f"select {columns} from {table} where TEST = 1 or ROBO = 1 ;"
@transaction
def test():
return getSyncFundInfos()
if __name__ == '__main__': if __name__ == '__main__':
print(dir()) print(test())
print(dir(ut))
import os
import re import re
from functools import partial from functools import partial
from .base import * from .base import *
...@@ -71,24 +72,21 @@ def env_var_constructor(loader, node, raw=False): ...@@ -71,24 +72,21 @@ def env_var_constructor(loader, node, raw=False):
return value if raw else yaml.safe_load(value) return value if raw else yaml.safe_load(value)
def _setup_yaml_parser(): def get_config(config_name=None):
yaml.add_constructor('!env_var', env_var_constructor, yaml.SafeLoader)
yaml.add_constructor(
'!raw_env_var',
partial(env_var_constructor, raw=True),
yaml.SafeLoader
)
yaml.add_implicit_resolver(
'!env_var', IMPLICIT_ENV_VAR_MATCHER, Loader=yaml.SafeLoader
)
def _get_config(config_name=None):
CONFIG_NAME = config_name or 'config.yml' CONFIG_NAME = config_name or 'config.yml'
path = f'{get_project_path()}{os.path.sep}config.yml' path = f'{get_project_path()}{os.path.sep}config.yml'
with open(path, 'r', encoding='utf-8') as f: with open(path, 'r', encoding='utf-8') as f:
return yaml.safe_load(f) return yaml.safe_load(f)
if __name__ == '__main__': yaml.add_constructor('!env_var', env_var_constructor, yaml.SafeLoader)
_get_config() yaml.add_constructor(
'!raw_env_var',
partial(env_var_constructor, raw=True),
yaml.SafeLoader
)
yaml.add_implicit_resolver(
'!env_var', IMPLICIT_ENV_VAR_MATCHER, Loader=yaml.SafeLoader
)
config = get_config()
from .date import * from .date_utils import *
from .base import * from .base import *
from .__env_config import _setup_yaml_parser, _get_config from .datebase import read, write, transaction
from .__logger import _setup_logger from .__env_config import config, get_config
from .__logger import build_logger, logger
_setup_yaml_parser()
config = _get_config()
if 'logger' in config:
logger = _setup_logger(config['logger'])
import os
from logging import config as cf, getLogger from logging import config as cf, getLogger
from .base import * from .__env_config import config
from .base import get_project_path
def _setup_logger(config, name='base'): def build_logger(config, name='base'):
if 'handlers' in config and 'file' in config['handlers']: if 'handlers' in config and 'file' in config['handlers']:
file = config['handlers']['file'] file = config['handlers']['file']
path = os.path.join(get_project_path(), file["filename"]) path = os.path.join(get_project_path(), file["filename"])
...@@ -11,3 +13,7 @@ def _setup_logger(config, name='base'): ...@@ -11,3 +13,7 @@ def _setup_logger(config, name='base'):
cf.dictConfig(config) cf.dictConfig(config)
return getLogger(name) return getLogger(name)
if 'logger' in config:
logger = build_logger(config['logger'])
import os import os
__all__ = ['get_project_path', 'deep_dict_update']
def get_project_path(): def get_project_path():
for anchor in ['.idea', '.git', 'config.yml', 'requirements.txt']: for anchor in ['.idea', '.git', 'config.yml', 'requirements.txt']:
...@@ -27,8 +29,3 @@ def deep_dict_update(d1, d2): ...@@ -27,8 +29,3 @@ def deep_dict_update(d1, d2):
for key in d2: for key in d2:
if key not in d1: if key not in d1:
d1[key] = d2[key] d1[key] = d2[key]
if __name__ == '__main__':
print(get_project_path())
# print(os.path.dirname('/123123/123123'))
import calendar import calendar
from datetime import timedelta from datetime import timedelta
__all__ = ['filter_weekend', 'next_workday', 'is_workday']
def filter_weekend(day): def filter_weekend(day):
while calendar.weekday(day.year, day.month, day.day) in [5, 6]: while calendar.weekday(day.year, day.month, day.day) in [5, 6]:
......
...@@ -2,12 +2,13 @@ import functools ...@@ -2,12 +2,13 @@ import functools
import pymysql import pymysql
import threading import threading
from pymysql.cursors import DictCursor from pymysql.cursors import DictCursor
from utils import config from .__env_config import config as default_config
from functools import partial
class Database: class Database:
def __init__(self, conf=None): def __init__(self, config):
self.config = conf or config['database'] self.config = config or default_config['database']
def __enter__(self): def __enter__(self):
port = 3306 port = 3306
...@@ -34,9 +35,9 @@ class Database: ...@@ -34,9 +35,9 @@ class Database:
__local__ = threading.local() __local__ = threading.local()
def read(func=None, one=False): def read(func=None, config=None, one=False):
if func is None: if func is None:
return functools.partial(read, one=one) return functools.partial(read, config=config, one=one)
def execute(db, sql): def execute(db, sql):
db.cursor.execute(sql) db.cursor.execute(sql)
...@@ -52,14 +53,14 @@ def read(func=None, one=False): ...@@ -52,14 +53,14 @@ def read(func=None, one=False):
if hasattr(__local__, 'db'): if hasattr(__local__, 'db'):
return execute(__local__.db, sql) return execute(__local__.db, sql)
else: else:
with Database() as db: with Database(config) as db:
return execute(db, sql) return execute(db, sql)
return wraps return wraps
def write(func=None): def write(func=None, config=None):
if func is None: if func is None:
return functools.partial(func) return functools.partial(write, func=func, config=config)
def execute(db, sqls): def execute(db, sqls):
if isinstance(sqls, list): if isinstance(sqls, list):
...@@ -74,7 +75,7 @@ def write(func=None): ...@@ -74,7 +75,7 @@ def write(func=None):
if hasattr(__local__, 'db'): if hasattr(__local__, 'db'):
execute(__local__.db, sqls) execute(__local__.db, sqls)
else: else:
with Database() as db: with Database(config) as db:
try: try:
execute(db, sqls) execute(db, sqls)
db.connect.commit() db.connect.commit()
...@@ -84,15 +85,15 @@ def write(func=None): ...@@ -84,15 +85,15 @@ def write(func=None):
return wraps return wraps
def transaction(func=None): def transaction(func=None, config=None):
if func is None: if func is None:
return functools.partial(func) return functools.partial(transaction, func=func, config=config)
@functools.wraps(func) @functools.wraps(func)
def wraps(*args, **kwargs): def wraps(*args, **kwargs):
if hasattr(__local__, 'db'): if hasattr(__local__, 'db'):
return func(*args, **kwargs) return func(*args, **kwargs)
with Database() as db: with Database(config) as db:
__local__.db = db __local__.db = db
try: try:
result = func(*args, **kwargs) result = func(*args, **kwargs)
...@@ -101,4 +102,8 @@ def transaction(func=None): ...@@ -101,4 +102,8 @@ def transaction(func=None):
except Exception as e: except Exception as e:
db.connect.rollback() db.connect.rollback()
raise e raise e
finally:
del __local__.db
return wraps return wraps
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment