# -*- coding: utf-8 -*-
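"""
Regression analysis: fit several scikit-learn regressors that predict each
target column (ginmult, nmmult, SMULTSTRT, SMULTEND, Weight Ratio) from
summary features loaded from data_machine_learning.pkl, then save error
tables (MAE / MSE / R2) and diagnostic plots.

Created on Wed May 20 02:55:11 2020

@author: Emre

Updated Tue May 26 2022, Matthieu, Curtis
"""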


import matplotlib.pyplot as plt
# import seaborn as sns; sns.set()
import seaborn as sns
import pickle
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression, ElasticNet, Lasso, Ridge
from sklearn.svm import SVR, LinearSVR
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import mean_squared_error , mean_absolute_error, r2_score
import os


def calculate_data(data, columns):
    """Add mean-minus-std and mean-plus-std columns for each requested feature.

    For every name in ``columns`` two new columns are written into ``data``
    under the same top-level key, in this order:
    ``(name, 'mean_minus_std')`` = mean - std and
    ``(name, 'mean_plus_std')`` = mean + std.

    Parameters
    ----------
    data : pandas.DataFrame
        Frame whose columns are addressed by ``(name, stat)`` tuples
        (presumably a two-level MultiIndex). It must already contain
        ``(name, 'mean')`` and ``(name, 'std')`` for every requested name.
        It is modified in place.
    columns : iterable
        Top-level column names to process.

    Returns
    -------
    pandas.DataFrame
        The same ``data`` object, with the new columns added.
    """
    for feature in columns:
        key = f'{feature}'
        centre = data[(key, 'mean')]
        spread = data[(key, 'std')]
        data[(key, 'mean_minus_std')] = centre - spread
        data[(key, 'mean_plus_std')] = centre + spread
    return data



def _plot_weight_ratio(data):
    """Scatter SMULTSTRT and SMULTEND against 'Weight Ratio' and save the plot.

    Writes 'Weight Ratio Values.pdf' and 'Weight Ratio Values.svg' to the
    current working directory, shows the figure, then closes it.
    """
    weight_ratio = data[('Weight Ratio', '')]
    plt.scatter(data[('SMULTSTRT', '')], weight_ratio, label='SMULTSTRT')
    plt.scatter(data[('SMULTEND', '')], weight_ratio, label='SMULTEND')
    plt.xlim((0, 2.6))
    plt.ylim((0, 2.6))
    plt.xlabel('Values (SMULTSTRT or SMULTEND)')
    plt.ylabel('Weight Ratio (SMULTSTRT / SMULTEND)')
    plt.legend()
    plt.title('MN Weight Ratio')
    plt.savefig('Weight Ratio Values.pdf')
    plt.savefig('Weight Ratio Values.svg')
    plt.show()
    plt.close()


def _build_models():
    """Return fresh, unfitted models as (table column label, plot label, estimator)."""
    return [
        ('Linear Regression', 'Lin. Reg.', LinearRegression()),
        ('Lasso', 'Lasso', Lasso(alpha=0.001)),
        ('Ridge', 'Ridge', Ridge(alpha=1)),
        ('Elastic Net', 'ElasticNet', ElasticNet(alpha=0.01, l1_ratio=0.1)),
        # gamma is ignored by scikit-learn for the linear kernel, so it is omitted here.
        ('Linear SVM', 'Linear SVM', SVR(kernel='linear', C=10, verbose=True)),
        ('RBF SVM', 'RBF SVM', SVR(kernel='rbf', gamma=1, C=10, verbose=True)),
    ]


def _plot_model_results(y_true, predictions, title):
    """Draw prediction and residual panels (one row per model) on a new figure.

    Parameters
    ----------
    y_true : numpy.ndarray
        Test-set target values.
    predictions : list of (str, numpy.ndarray)
        Plot label and test-set predictions for each model.
    title : str
        Figure super-title.

    Returns
    -------
    matplotlib.figure.Figure
        The populated figure. The caller is responsible for saving and closing it.
    """
    marker_size = 50
    line_width = 3
    line_color = 'red'

    fig = plt.figure(figsize=(21, 21))
    fig.subplots_adjust(hspace=0.5, wspace=0.5)
    fig.suptitle(title)

    n_rows = len(predictions)
    for row, (label, y_pred) in enumerate(predictions):
        # Left column: predictions vs targets, with the y = x reference line.
        ax_pred = fig.add_subplot(n_rows, 2, 2 * row + 1)
        sns.scatterplot(x=y_true, y=y_pred, ax=ax_pred, s=marker_size).set(
            title=f'Pred vs Target for {label}', xlabel='Target Values', ylabel='Predictions')
        x_vals = np.array(ax_pred.get_xlim())
        ax_pred.plot(x_vals, x_vals, '--', linewidth=line_width, color=line_color)

        # Right column: signed residuals (target - prediction) vs targets.
        ax_res = fig.add_subplot(n_rows, 2, 2 * row + 2)
        sns.scatterplot(x=y_true, y=y_true - y_pred, ax=ax_res, s=marker_size).set(
            title=f'Res vs Target for {label}', xlabel='Target Values', ylabel='Residuals')
    return fig


def main():
    """Fit and evaluate several regressors for each target column.

    Steps:

    1. Load ``data_machine_learning.pkl``. Presumably this is a pandas
       DataFrame whose columns are ``(name, stat)`` tuples; verify against
       the producer.
    2. Plot and save the SMULTSTRT/SMULTEND vs. Weight Ratio scatter.
    3. For each target in ``y_data_list``:
       - split the data 70/30 with a fixed seed;
       - standardize the features on the training split;
       - fit the models from ``_build_models``;
       - record MAE, MSE and R2 on the test split;
       - save a figure of prediction and residual plots under
         ``figure/regression/w_brace_height/<stats>``.
    4. Write one Excel table per metric (one row per target) into the same
       directory.
    """
    # NOTE: pickle.load can execute arbitrary code; only load a trusted file.
    with open("data_machine_learning.pkl", "rb") as fh:
        data = pickle.load(fh)

    # Input features (brace_height included). Other features used in earlier
    # runs: 'saturation_rate', 'accel_slope', 'pic_angle'.
    columns_list = ['delta_F', 'trec', 'range', 'activation_duration', 'tderec', 'brace_height', 'ratemod_slope']
    # Per-feature statistics to use. 'mean_plus_std', 'mean_minus_std' and
    # 'median' have also been used in earlier runs.
    stats_list = ['mean']

    _plot_weight_ratio(data)

    X_data = data.loc[:, [(col, sta) for col in columns_list for sta in stats_list]]
    # (column name in data, human-readable name used in tables/titles)
    y_data_list = [('ginmult', 'Inhibition'), ('nmmult', 'Neuromodulation'), ('SMULTSTRT', 'Weight Start'),
                   ('SMULTEND', 'Weight End'), ('Weight Ratio', 'Weight Ratio')]

    stats_path = 'all_stats' if len(stats_list) > 1 else stats_list[0]
    dir_path = f'figure/regression/w_brace_height/{stats_path}'
    os.makedirs(dir_path, exist_ok=True)

    metric_columns = ['Name'] + [column for column, _, _ in _build_models()]
    scorers = {'MAE': mean_absolute_error, 'MSE': mean_squared_error, 'R2': r2_score}
    metric_rows = {key: [] for key in scorers}

    for y_data_name, y_data_save in y_data_list:
        y_data = data.loc[:, (y_data_name, '')]

        # NOTE(review): stratify=y_data treats every distinct target value as a
        # class. train_test_split raises if any value occurs only once, so this
        # assumes the targets take a small set of discrete values. Confirm.
        X_train, X_test, y_train, y_test = train_test_split(
            X_data, y_data, random_state=25, test_size=0.3, stratify=y_data)

        # Fit the scaler on the training split only, to avoid test-set leakage.
        scaler = StandardScaler().fit(X_train)
        X_train_scaled = scaler.transform(X_train)
        X_test_scaled = scaler.transform(X_test)
        y_true = np.asarray(y_test)

        predictions = []
        for column_label, plot_label, model in _build_models():
            model.fit(X_train_scaled, y_train)
            predictions.append((column_label, plot_label, model.predict(X_test_scaled)))

        for key, scorer in scorers.items():
            metric_rows[key].append(
                [y_data_save] + [scorer(y_true, y_pred) for _, _, y_pred in predictions])

        fig = _plot_model_results(
            y_true,
            [(plot_label, y_pred) for _, plot_label, y_pred in predictions],
            f'Results for {y_data_save} - {stats_path}')
        # Save before showing, then close so figures don't accumulate across targets.
        fig.savefig(os.path.join(dir_path, f'{y_data_name}_2_sec_w_brace_height.pdf'))
        fig.savefig(os.path.join(dir_path, f'{y_data_name}_2_sec_w_brace_height.svg'))
        plt.show()
        plt.close(fig)

    for key, rows in metric_rows.items():
        pd.DataFrame(rows, columns=metric_columns).to_excel(
            os.path.join(dir_path, f'{key}_{stats_path}_2_sec_w_brace_height.xlsx'))

if __name__ == '__main__':
    # Run the analysis only when executed as a script, not when imported.
    main()