import functools
import pymysql
import threading
from pymysql.cursors import DictCursor
from .__env_config import get_config
from .date_utils import format_date, datetime
from enum import Enum


class DatabaseError(Exception):
    """Error raised by this module's database helpers.

    Args:
        msg: Description of the failure. Any object is accepted; it is
            converted with ``str()`` when the exception is rendered.
    """

    def __init__(self, msg):
        super().__init__(msg)
        self.__msg = msg

    def __str__(self):
        # ``__str__`` must return a real ``str``; returning a non-string
        # payload (for example a wrapped exception) raises ``TypeError``.
        return str(self.__msg)


class Database:
    """Context manager that opens a pymysql connection with a ``DictCursor``.

    The connection and cursor are created in ``__enter__`` and closed in
    ``__exit__``. No commit or rollback is issued here; that is left to the
    code using the instance.
    """

    def __init__(self, config=None):
        """Store the connection settings.

        Args:
            config: Mapping read for ``host``, ``user``, ``password``,
                ``dbname`` and, optionally, ``port``. When falsy, the
                settings are looked up with ``get_config(__name__)``.

        Raises:
            DatabaseError: If no configuration could be resolved.
        """
        self._config = config or get_config(__name__)
        if self._config is None:
            raise DatabaseError("database config is not found.")

    def __enter__(self):
        """Open the connection and its cursor, then return ``self``."""
        settings = self._config
        # 3306 is used when the configuration has no ``port`` entry.
        port = settings['port'] if 'port' in settings else 3306
        self.__connect = pymysql.connect(host=settings['host'], user=settings['user'], port=port,
                                         password=str(settings['password']), database=settings['dbname'])
        try:
            self.__cursor = self.__connect.cursor(DictCursor)
        except Exception:
            # ``__exit__`` is never called when ``__enter__`` fails, so the
            # freshly opened connection has to be released here.
            self.__connect.close()
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the cursor and the connection; exceptions are not suppressed."""
        try:
            self.cursor.close()
        finally:
            # Close the connection even if closing the cursor raised.
            self.connect.close()

    @property
    def connect(self):
        """The open pymysql connection (available after ``__enter__``)."""
        return self.__connect

    @property
    def cursor(self):
        """The ``DictCursor`` created in ``__enter__``."""
        return self.__cursor


# Thread-local storage shared by the decorators below. ``transaction`` stores
# the active ``Database`` on it as ``db`` while the decorated call runs, and
# ``read``/``write`` reuse that connection whenever the attribute is present.
__local__ = threading.local()


def read(func=None, config=None, one=False):
    """Decorator that runs the SQL returned by ``func`` and returns the rows.

    Usable bare (``@read``) or with keyword options (``@read(one=True)``).

    Args:
        func: Function returning the SQL statement to execute.
        config: Settings passed to ``Database`` when no transaction-bound
            connection exists on the current thread.
        one: When true, return only the first row, or ``None`` if the query
            produced no rows; otherwise return everything from ``fetchall()``.

    Returns:
        The wrapped function, or a partial awaiting ``func`` when used with
        options only.
    """
    if func is None:
        return functools.partial(read, config=config, one=one)

    def fetch(database, statement):
        cur = database.cursor
        cur.execute(statement)
        rows = cur.fetchall()
        if not one:
            return rows
        return rows[0] if rows else None

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Build the statement first so a failing ``func`` never opens a connection.
        statement = func(*args, **kwargs)
        if not hasattr(__local__, 'db'):
            with Database(config) as database:
                return fetch(database, statement)
        # Inside ``transaction``: reuse the connection bound to this thread.
        return fetch(__local__.db, statement)

    return wrapper


def write(func=None, config=None):
    """Decorator that executes the SQL statement(s) returned by ``func``.

    Usable bare (``@write``) or with keyword options (``@write(config=...)``).

    When a transaction-bound connection exists on the current thread, the
    statements run on it and nothing is committed here. Otherwise a new
    ``Database`` is opened, the statements are executed and committed, and
    any exception triggers a rollback before being re-raised.

    Args:
        func: Function returning one SQL string, or a list/tuple of SQL
            strings that are executed in order.
        config: Settings passed to ``Database`` when a new connection is
            needed.

    Returns:
        The wrapped function (which itself returns ``None``), or a partial
        awaiting ``func`` when used with options only.
    """
    if func is None:
        return functools.partial(write, config=config)

    def execute(db, sqls):
        # Tuples are treated as a batch just like lists; handing a tuple to
        # ``cursor.execute`` as a single query cannot work.
        if isinstance(sqls, (list, tuple)):
            for sql in sqls:
                db.cursor.execute(sql)
        else:
            db.cursor.execute(sqls)

    @functools.wraps(func)
    def wraps(*args, **kwargs):
        sqls = func(*args, **kwargs)
        if hasattr(__local__, 'db'):
            # Commit/rollback is left to whoever bound the connection.
            execute(__local__.db, sqls)
        else:
            with Database(config) as db:
                try:
                    execute(db, sqls)
                    db.connect.commit()
                except Exception:
                    db.connect.rollback()
                    raise
    return wraps


def transaction(func=None, config=None):
    """Decorator that runs ``func`` inside a single database transaction.

    Usable bare (``@transaction``) or with options (``@transaction(config=...)``).

    The outermost decorated call opens a ``Database``, publishes it on the
    thread-local store as ``db``, commits when ``func`` returns and rolls back
    when it raises. A decorated call made while ``db`` is already set simply
    runs ``func`` on the existing connection.

    Args:
        func: Function to run inside the transaction.
        config: Settings passed to ``Database`` when a connection is opened.

    Returns:
        The wrapped function, or a partial awaiting ``func`` when used with
        options only.
    """
    if func is None:
        return functools.partial(transaction, config=config)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        already_bound = hasattr(__local__, 'db')
        if already_bound:
            # Nested call: the outer transaction owns commit and rollback.
            return func(*args, **kwargs)

        with Database(config) as database:
            __local__.db = database
            try:
                outcome = func(*args, **kwargs)
                database.connect.commit()
            except Exception as error:
                database.connect.rollback()
                raise error
            finally:
                # Always unbind so later calls on this thread start fresh.
                del __local__.db
            return outcome

    return wrapper


# Same character escapes that pymysql applies to string literals.
_SQL_ESCAPES = str.maketrans({
    "\0": "\\0",
    "\n": "\\n",
    "\r": "\\r",
    "\x1a": "\\Z",
    '"': '\\"',
    "'": "\\'",
    "\\": "\\\\",
})


def _sql_literal(value) -> str:
    """Render one Python value as a SQL literal for a ``where`` condition."""
    # Enum is checked before str so that str-mixin enums contribute their
    # value rather than whatever their string formatting happens to be.
    if isinstance(value, Enum):
        return f"'{str(value.value).translate(_SQL_ESCAPES)}'"
    if isinstance(value, str):
        return f"'{value.translate(_SQL_ESCAPES)}'"
    if isinstance(value, datetime):
        return f"'{format_date(value)}'"
    if value is None:
        return 'NULL'
    return f"{value}"


def where(*args, **kwargs) -> str:
    """Build a SQL ``where`` clause from raw conditions and column filters.

    Args:
        *args: Ready-made condition strings; falsy entries are ignored.
        **kwargs: ``column=value`` equality filters.
            ``None`` and empty lists/tuples add no condition.
            A list/tuple with one element becomes ``column = x``; longer ones
            become ``column in (x, y, ...)``.
            Strings, enum values and formatted datetimes are quoted; strings
            and enum values are also escaped, so pass them raw.
            Any other value is interpolated with its default formatting.

    Returns:
        ``"where <cond> and <cond> ..."`` with the keyword filters first and
        the positional conditions after them, or ``''`` when there is nothing
        to filter on.

    Note:
        Column names and positional conditions are inserted verbatim and must
        come from trusted code.
    """
    conditions = []
    for column, value in kwargs.items():
        if value is None:
            continue
        if isinstance(value, (tuple, list)):
            if not value:
                continue
            # Format every element like a scalar so strings are quoted and
            # datetimes/enums are rendered properly inside the list as well.
            literals = [_sql_literal(item) for item in value]
            if len(literals) == 1:
                conditions.append(f"{column} = {literals[0]}")
            else:
                conditions.append(f"{column} in ({', '.join(literals)})")
        else:
            conditions.append(f"{column} = {_sql_literal(value)}")
    conditions.extend(cond for cond in args if cond)
    return f"where {' and '.join(conditions)}" if conditions else ''