from abc import ABC
from datetime import datetime

from lightgbm import LGBMClassifier
from sklearn import svm
from sklearn.ensemble import AdaBoostClassifier
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.neighbors import KNeighborsClassifier

from ai import reporter
from ai.config import LABEL_RANGE


class ModelTrainer(ABC):
    """
    Model training class.

    Fits several classifiers on the same train/test split and, unless running
    in forecast mode, prints test-set evaluation metrics for each of them.

    Args:
        toForecast: When truthy, models are only fitted; the evaluation done
            by ``test_model`` is skipped.
        pid: Identifier written as ``rbd_id`` into every backtest report record.
    """

    def __init__(self, toForecast, pid) -> None:
        super().__init__()
        self._toForecast = toForecast
        self._pid = pid

    ###################
    # Step 3: Train the model

    def _fit_and_evaluate(self, strMethod, classifier, X_train, y_train, X_test, y_test, date_index):
        """Fit ``classifier`` on the training data and evaluate it unless forecasting.

        Shared body of every ``train_*`` method (and ``ensemble_model``) so they
        all follow exactly the same fit/evaluate/return sequence.

        Args:
            strMethod: Display name passed through to ``test_model``.
            classifier: An unfitted estimator exposing ``fit`` and ``predict``.
            X_train, y_train: Training features and labels.
            X_test, y_test: Test features and labels (used only when not forecasting).
            date_index: Dates aligned with ``X_test`` rows, passed to ``test_model``.

        Returns:
            The same ``classifier`` object, now fitted.
        """
        classifier.fit(X_train, y_train)
        if not self._toForecast:
            self.test_model(strMethod, classifier, X_test, y_test, date_index)
        return classifier

    def test_model(self, strMethod, classifier, X_test, y_test, date_index):
        """Print evaluation metrics for a fitted classifier on the test set.

        Prints the confusion matrix (rows/columns ordered by the keys of
        ``LABEL_RANGE``, reversed), the classification report and the accuracy.
        For the ensemble model only, and only when ``doReport`` is truthy, one
        record per test prediction is written to ``Backtest_Report_chu.xlsx``
        through ``reporter.do_reporter2``.

        Args:
            strMethod: Display name used as a print prefix; the exact value
                ``"Ensemble Model"`` additionally enables the report.
            classifier: A fitted estimator exposing ``predict``.
            X_test: Test features.
            y_test: True test labels.
            date_index: Dates paired positionally with the predictions in the
                report. ``zip`` stops at the shorter sequence, so a length
                mismatch silently drops records.
        """
        print(strMethod + " ====== test results ======")
        y_pred = classifier.predict(X_test)

        labels = list(LABEL_RANGE.keys())[::-1]
        result0 = confusion_matrix(y_test, y_pred, labels=labels)
        print(strMethod + " Confusion Matrix:", result0)

        # zero_division=1.0: metrics whose denominator is zero are reported as
        # 1.0, and no UndefinedMetricWarning is emitted.
        result1 = classification_report(y_test, y_pred, zero_division=1.0)
        print(strMethod + " Classification Report:", result1)
        result2 = accuracy_score(y_test, y_pred)
        print(strMethod + " Accuracy:", result2)

        # Imported lazily rather than at module level, presumably to avoid a
        # circular import with ai.EstimateMarketTrendV20 (TODO confirm).
        from ai.EstimateMarketTrendV20 import doReport
        if doReport and strMethod == "Ensemble Model":
            # Take one timestamp for the whole batch so every row of this
            # report carries the same create_time.
            create_time = datetime.now()
            datas = [
                {'predict': predict, 'date': date, 'rbd_id': self._pid, 'create_time': create_time}
                for predict, date in zip(y_pred, date_index)
            ]
            reporter.do_reporter2(records=datas, excel_name='Backtest_Report_chu.xlsx')

        # NOTE: optional plot of the confusion matrix; re-enabling it requires
        # importing ConfusionMatrixDisplay (sklearn.metrics) and matplotlib.pyplot as plt.
        # cm_display = ConfusionMatrixDisplay(confusion_matrix=result0, display_labels=labels)
        # cm_display.plot()
        # plt.title(strMethod + ' Accuracy: ' + f'{result2:.0%}')
        # plt.show()

    def train_random_forest(self, X_train, y_train, X_test, y_test, date_index):
        """Train (and optionally evaluate) a default RandomForestClassifier."""
        return self._fit_and_evaluate('Random Forest', RandomForestClassifier(),
                                      X_train, y_train, X_test, y_test, date_index)

    def train_GBT(self, X_train, y_train, X_test, y_test, date_index):
        """Train (and optionally evaluate) a default LightGBM gradient-boosted tree classifier."""
        return self._fit_and_evaluate('Gradient Boosted Tree', LGBMClassifier(),
                                      X_train, y_train, X_test, y_test, date_index)

    def train_SVC(self, X_train, y_train, X_test, y_test, date_index):
        """Train (and optionally evaluate) a default support vector classifier."""
        return self._fit_and_evaluate('Support Vector Machines', svm.SVC(),
                                      X_train, y_train, X_test, y_test, date_index)

    def train_nearest_neighbors(self, X_train, y_train, X_test, y_test, date_index):
        """Train (and optionally evaluate) a default k-nearest-neighbours classifier."""
        return self._fit_and_evaluate('K-Nearest Neighbors', KNeighborsClassifier(),
                                      X_train, y_train, X_test, y_test, date_index)

    def train_AdaBoost(self, X_train, y_train, X_test, y_test, date_index):
        """Train (and optionally evaluate) a default AdaBoost classifier."""
        return self._fit_and_evaluate('AdaBoost', AdaBoostClassifier(),
                                      X_train, y_train, X_test, y_test, date_index)

    def ensemble_model(self, rf_model, gbt_model, svc_model, knn_model,
                       ada_model, X_train, y_train, X_test, y_test, date_index):
        """Train (and optionally evaluate) a hard-voting ensemble of the given models.

        Note: ``VotingClassifier.fit`` fits clones of the supplied estimators,
        so the already-fitted models passed in are retrained from scratch.
        The originals themselves are left untouched.

        Returns:
            The fitted ``VotingClassifier``.
        """
        # (name, estimator) pairs, in the form VotingClassifier expects.
        estimators = [('rf', rf_model), ('gbt', gbt_model), ('svc', svc_model),
                      ('knn', knn_model), ('AdaBoost', ada_model)]
        # Hard voting: majority vote over the predicted class labels.
        ensemble = VotingClassifier(estimators, voting='hard')
        return self._fit_and_evaluate('Ensemble Model', ensemble,
                                      X_train, y_train, X_test, y_test, date_index)