"""MySQL access helpers built on pymysql.

Provides a connection context manager (``Database``), decorators that execute
the SQL returned by the decorated function (``read`` / ``write``), a
thread-local transaction scope (``transaction``) and small SQL fragment
builders (``where`` / ``mapper_columns``).
"""
import functools
import json
import threading
from enum import Enum

import pymysql
from pymysql.converters import escape_string
from pymysql.cursors import DictCursor

from framework.date_utils import format_date, datetime
from framework.env_config import get_config


class DatabaseError(Exception):
    """Raised when the database layer cannot be configured."""

    def __init__(self, msg):
        # Forward msg to Exception so ``args``, ``repr`` and pickling carry it too.
        super().__init__(msg)
        self.__msg = msg

    def __str__(self):
        return self.__msg


class Database:
    """Context manager owning one pymysql connection and one DictCursor.

    ``config`` must provide ``host``, ``user``, ``password`` and ``dbname`` and
    may provide ``port`` (3306 when absent). When ``config`` is falsy it is
    looked up with ``get_config(__name__)``.

    :raises DatabaseError: if no configuration can be found.
    """

    def __init__(self, config=None):
        self._config = config or get_config(__name__)
        if self._config is None:
            raise DatabaseError("database config is not found.")

    def __enter__(self):
        port = 3306
        if 'port' in self._config:
            port = self._config['port']
        self.__connect = pymysql.connect(host=self._config['host'],
                                         user=self._config['user'],
                                         port=port,
                                         password=str(self._config['password']),
                                         database=self._config['dbname'])
        try:
            self.__cursor = self.__connect.cursor(DictCursor)
        except Exception:
            # __exit__ is not invoked when __enter__ raises, so close the
            # freshly opened connection here instead of leaking it.
            self.__connect.close()
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always close the connection, even if closing the cursor fails.
        try:
            self.cursor.close()
        finally:
            self.connect.close()

    @property
    def connect(self):
        """The pymysql connection opened by ``__enter__``."""
        return self.__connect

    @property
    def cursor(self):
        """The DictCursor opened on ``connect`` by ``__enter__``."""
        return self.__cursor


# Per-thread slot holding the Database of the currently running ``transaction``;
# ``read`` / ``write`` reuse ``__local__.db`` when it is set on this thread.
__local__ = threading.local()


def read(func=None, config=None, one=False):
    """Decorator: execute the SQL string returned by ``func`` and return rows.

    Usable bare (``@read``) or with options (``@read(config=..., one=True)``).

    :param func: function returning the query to run.
    :param config: passed to ``Database`` when no transaction is active on
        this thread.
    :param one: when True, return only the first row, or ``None`` if the
        query produced no rows.
    :returns: the rows fetched through a DictCursor, or a single row / None.
    """
    if func is None:
        return functools.partial(read, config=config, one=one)

    def execute(db, sql):
        db.cursor.execute(sql)
        rows = db.cursor.fetchall()
        if one:
            return rows[0] if rows else None
        return rows

    @functools.wraps(func)
    def wraps(*args, **kwargs):
        sql = func(*args, **kwargs)
        if hasattr(__local__, 'db'):
            # Inside ``transaction``: share its connection.
            return execute(__local__.db, sql)
        with Database(config) as db:
            return execute(db, sql)

    return wraps


def write(func=None, config=None):
    """Decorator: execute the SQL statement(s) returned by ``func``.

    ``func`` may return one SQL string or a list/tuple of them. For a
    statement containing ``insert into`` (case-insensitive) the result is the
    connection's ``insert_id()``; otherwise it is the value returned by
    ``cursor.execute`` (pymysql returns the affected-row count). A sequence of
    statements yields a list of per-statement results, in order.

    Outside a ``transaction`` a new connection is opened, committed on success
    and rolled back on error. Inside a ``transaction`` the shared connection is
    used and commit/rollback is left to the transaction.

    :param func: function returning the statement(s) to run.
    :param config: passed to ``Database`` when no transaction is active.
    """
    if func is None:
        return functools.partial(write, config=config)

    def get_result(db, sql, res):
        if 'insert into' in sql.lower():
            return db.connect.insert_id()
        return res

    def execute(db, sqls):
        if isinstance(sqls, (list, tuple)):
            # Run every statement in order and collect each one's result.
            return [get_result(db, sql, db.cursor.execute(sql)) for sql in sqls]
        return get_result(db, sqls, db.cursor.execute(sqls))

    @functools.wraps(func)
    def wraps(*args, **kwargs):
        sqls = func(*args, **kwargs)
        if hasattr(__local__, 'db'):
            # Inside ``transaction``: it owns commit/rollback.
            return execute(__local__.db, sqls)
        with Database(config) as db:
            try:
                result = execute(db, sqls)
                db.connect.commit()
            except Exception:
                db.connect.rollback()
                raise
            return result

    return wraps


def transaction(func=None, config=None):
    """Decorator: run ``func`` inside a single connection-wide transaction.

    The opened ``Database`` is stored on ``__local__`` for the duration of the
    call, so ``read`` / ``write`` functions called on the same thread share
    it. The transaction is committed when ``func`` returns and rolled back if
    it raises. A nested ``transaction`` call on the same thread simply runs
    inside the outer one; its own ``config`` is ignored in that case.

    :param func: function to run transactionally.
    :param config: passed to ``Database`` for the outermost call.
    """
    if func is None:
        return functools.partial(transaction, config=config)

    @functools.wraps(func)
    def wraps(*args, **kwargs):
        if hasattr(__local__, 'db'):
            # Already inside a transaction on this thread: join it.
            return func(*args, **kwargs)
        with Database(config) as db:
            __local__.db = db
            try:
                result = func(*args, **kwargs)
                db.connect.commit()
            except Exception:
                db.connect.rollback()
                raise
            finally:
                del __local__.db
            return result

    return wraps


def _quote(value) -> str:
    """Render ``value`` as a single-quoted, escaped MySQL string literal."""
    return f"'{escape_string(str(value))}'"


def _scalar_literal(value) -> str:
    """Render one value for an ``=`` comparison in ``where``.

    Enums are unwrapped first (so str-mixin Enums use their value, not
    ``str(member)``); bools become unquoted 1/0; datetimes are formatted with
    ``format_date``; everything else becomes a quoted, escaped string.
    """
    if isinstance(value, Enum):
        value = value.value
    if isinstance(value, bool):
        return '1' if value else '0'
    if isinstance(value, datetime):
        return _quote(format_date(value))
    return _quote(value)


def _in_item_literal(value) -> str:
    """Render one element of an ``in (...)`` list in ``where``.

    Numbers stay unquoted, None becomes NULL, anything else follows
    ``_scalar_literal``.
    """
    if isinstance(value, Enum):
        value = value.value
    if value is None:
        return 'NULL'
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return repr(value)
    return _scalar_literal(value)


def where(*args, **kwargs) -> str:
    """Build a ``where ...`` clause joined with ``and``.

    Each keyword ``column=value`` becomes a condition:

    - None: skipped.
    - list/tuple of several values: ``column in (v1, v2, ...)``.
    - list/tuple of one value: ``column = <value>``.
    - anything else: ``column = <value>``.

    Values are escaped. Column names and the raw SQL fragments passed
    positionally in ``args`` (falsy ones are skipped) are inserted verbatim,
    so they must come from trusted code.

    :returns: the clause, or ``''`` when there are no conditions.
    """
    conditions = []
    for column, value in kwargs.items():
        if value is None:
            continue
        if isinstance(value, (tuple, list)):
            if len(value) > 1:
                items = ', '.join(_in_item_literal(x) for x in value)
                conditions.append(f"{column} in ({items})")
            elif len(value) == 1:
                conditions.append(f"{column} = {_scalar_literal(value[0])}")
            # NOTE(review): an empty sequence adds no condition, which widens
            # the query to all rows rather than none — confirm this is intended.
        else:
            conditions.append(f"{column} = {_scalar_literal(value)}")
    conditions.extend(x for x in args if x)
    return f"where {' and '.join(conditions)}" if conditions else ''


def mapper_columns(datas: dict, columns: dict) -> dict:
    """Map ``datas`` onto column names and convert values for SQL.

    :param datas: source values keyed by field name.
    :param columns: ``{column_name: field_name}``. A column is emitted only
        when ``datas`` contains that field with a non-None value.
    :returns: ``{column_name: value}`` in ``columns`` order. Datetimes are
        formatted with ``format_date``, Enums are replaced by ``.value``,
        dicts are JSON-encoded and bools become 1/0. Other values pass
        through unchanged.
    """
    mapped = {}
    for column, field in columns.items():
        if field not in datas or datas[field] is None:
            continue
        value = datas[field]
        if isinstance(value, datetime):
            value = format_date(value)
        elif isinstance(value, Enum):
            value = value.value
        elif isinstance(value, dict):
            value = json.dumps(value)
        elif isinstance(value, bool):
            value = 1 if value else 0
        mapped[column] = value
    return mapped