import functools
import threading
from enum import Enum

import pymysql
from pymysql.converters import escape_string
from pymysql.cursors import DictCursor

from .__env_config import get_config
from .date_utils import format_date, datetime


class DatabaseError(Exception):
    """Error raised by this module (e.g. when no database config is found)."""

    def __init__(self, msg):
        # Forward msg to Exception so ``e.args`` and pickling behave normally.
        super().__init__(msg)
        self.__msg = msg

    def __str__(self):
        return str(self.__msg)


class Database:
    """Context manager owning one pymysql connection and one DictCursor.

    ``with Database(config) as db:`` connects on entry and closes the cursor
    and the connection on exit. Nothing is committed implicitly: callers
    commit or roll back through ``db.connect``.

    :param config: mapping read for ``host``, ``user``, ``password``,
        ``dbname`` and an optional ``port`` (default 3306). When falsy, the
        config is looked up with ``get_config(__name__)``.
    :raises DatabaseError: if no config can be found.
    """

    def __init__(self, config=None):
        self._config = config or get_config(__name__)
        if self._config is None:
            raise DatabaseError("database config is not found.")

    def __enter__(self):
        config = self._config
        port = config['port'] if 'port' in config else 3306
        connection = pymysql.connect(host=config['host'],
                                     user=config['user'],
                                     port=port,
                                     password=str(config['password']),
                                     database=config['dbname'])
        try:
            cursor = connection.cursor(DictCursor)
        except Exception:
            # Don't leak the freshly opened connection if the cursor fails.
            connection.close()
            raise
        self.__connect = connection
        self.__cursor = cursor
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always close the connection, even if closing the cursor raises.
        try:
            self.cursor.close()
        finally:
            self.connect.close()

    @property
    def connect(self):
        """The underlying pymysql connection (set by ``__enter__``)."""
        return self.__connect

    @property
    def cursor(self):
        """The DictCursor opened on :attr:`connect` (set by ``__enter__``)."""
        return self.__cursor


# Per-thread slot holding the Database opened by @transaction. While it is
# set, @read / @write reuse ``__local__.db`` instead of opening a connection.
__local__ = threading.local()


def read(func=None, config=None, one=False):
    """Decorate a function that returns a query string, and execute it.

    Usable bare (``@read``) or with options (``@read(one=True)``). The
    decorated function's return value is executed as SQL and the wrapper
    returns the fetched rows (dicts, via DictCursor).

    :param config: connection config for :class:`Database`. Ignored when
        called inside a ``@transaction`` on the same thread, whose connection
        is reused instead.
    :param one: return only the first row, or ``None`` if there are no rows.
    """
    if func is None:
        return functools.partial(read, config=config, one=one)

    def execute(db, sql):
        db.cursor.execute(sql)
        rows = db.cursor.fetchall()
        if one:
            return rows[0] if rows else None
        return rows

    @functools.wraps(func)
    def wraps(*args, **kwargs):
        sql = func(*args, **kwargs)
        if hasattr(__local__, 'db'):
            return execute(__local__.db, sql)
        with Database(config) as db:
            return execute(db, sql)

    return wraps


def write(func=None, config=None):
    """Decorate a function that returns write SQL, and execute it.

    The decorated function may return one statement or a list of statements,
    executed in order. Outside a ``@transaction`` each call opens its own
    connection, commits on success and rolls back (then re-raises) on error.
    Inside a ``@transaction`` on the same thread, the statements run on the
    shared connection and committing is left to the transaction. The wrapper
    returns ``None``.

    :param config: connection config for :class:`Database`; ignored inside a
        ``@transaction``.
    """
    if func is None:
        return functools.partial(write, config=config)

    def execute(db, sqls):
        statements = sqls if isinstance(sqls, list) else [sqls]
        for sql in statements:
            db.cursor.execute(sql)

    @functools.wraps(func)
    def wraps(*args, **kwargs):
        sqls = func(*args, **kwargs)
        if hasattr(__local__, 'db'):
            execute(__local__.db, sqls)
            return
        with Database(config) as db:
            try:
                execute(db, sqls)
                db.connect.commit()
            except Exception:
                db.connect.rollback()
                raise

    return wraps


def transaction(func=None, config=None):
    """Run the decorated function inside a single database transaction.

    Opens a :class:`Database`, publishes it as the thread-local
    ``__local__.db`` so nested ``@read`` / ``@write`` calls share it, commits
    when the function returns and rolls back (then re-raises) if it raises
    an ``Exception``. A nested ``@transaction`` simply joins the outer one
    (its own ``config`` is then ignored).
    """
    if func is None:
        return functools.partial(transaction, config=config)

    @functools.wraps(func)
    def wraps(*args, **kwargs):
        if hasattr(__local__, 'db'):
            return func(*args, **kwargs)
        with Database(config) as db:
            __local__.db = db
            try:
                result = func(*args, **kwargs)
                db.connect.commit()
                return result
            except Exception:
                db.connect.rollback()
                raise
            finally:
                del __local__.db

    return wraps


def _quote(text) -> str:
    """Return *text* as a single-quoted, escaped MySQL string literal."""
    # escape_string escapes quotes and backslashes; this assumes the server's
    # default sql_mode (NO_BACKSLASH_ESCAPES off).
    return "'" + escape_string(str(text)) + "'"


def _sql_literal(value) -> str:
    """Render one Python value as a SQL literal for a ``where`` condition."""
    if value is None:
        return 'NULL'
    if isinstance(value, str):
        return _quote(f"{value}")
    if isinstance(value, datetime):
        return _quote(format_date(value))
    if isinstance(value, Enum):
        return _quote(f"{value.value}")
    # Numbers, bools and other values are inserted unquoted via format().
    return f"{value}"


def where(*args, **kwargs) -> str:
    """Build a ``where ...`` clause, or ``''`` when there are no conditions.

    Each keyword ``column=value`` becomes one condition; conditions are
    joined with ``and``:

    * ``None`` values are skipped (no filter on that column);
    * ``str``, ``datetime`` and ``Enum`` values become escaped, single-quoted
      literals (``datetime`` via ``format_date``, ``Enum`` via ``.value``);
    * a list/tuple becomes ``column = x`` for one element or
      ``column in (x, y, ...)`` for several, each element rendered like a
      scalar; an EMPTY list/tuple is skipped, i.e. it does not filter rows;
    * any other value is inserted unquoted with ``format()``.

    Positional ``args`` are raw SQL fragments appended after the keyword
    conditions; falsy fragments are dropped.

    NOTE: column names and positional fragments are NOT escaped and must
    come from trusted code.

    >>> where(id=[1, 2], name='bob')
    "where id in (1, 2) and name = 'bob'"
    """
    conditions = []
    for column, value in kwargs.items():
        if value is None:
            continue
        if isinstance(value, (tuple, list)) and not isinstance(value, Enum):
            if not value:
                continue
            if len(value) == 1:
                # A 1-tuple's "(x,)" is not valid SQL, so use equality.
                conditions.append(f"{column} = {_sql_literal(value[0])}")
            else:
                items = ', '.join(_sql_literal(item) for item in value)
                conditions.append(f"{column} in ({items})")
        else:
            conditions.append(f"{column} = {_sql_literal(value)}")
    conditions.extend(fragment for fragment in args if fragment)
    return f"where {' and '.join(conditions)}" if conditions else ''