# -*- coding: utf-8 -*-
"""
Created on Tue Dec 10 17:37:26 2024

@author: Caterina
"""
import os

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
import seaborn as sns


# Two separate plots (INT and PYR); each figure is saved to a file.
def create2Plots(listS, cap, sol, simName, outputFolder, *, close=False):
    """Plot the INT and PYR curves of every slice in two separate figures.

    The figures are saved as ``<simName>_int.png`` and ``<simName>_pyr.png``
    inside *outputFolder*.

    Parameters
    ----------
    listS : sized
        Collection of slices; only its length ``S`` is used.
    cap : object
        Unused; kept so existing callers keep working.
    sol : object
        Solver result exposing ``sol.t`` (time points, x axis) and ``sol.y``
        (one row per state variable) -- presumably the result of
        ``scipy.integrate.solve_ivp``; TODO confirm. The first ``2*S`` rows
        are treated as the INT group and the remaining rows as the PYR group;
        for slice ``h`` row ``2*h+1`` of each group is plotted (what the
        even rows hold is not visible here).
    simName : str
        Simulation name, used in the titles and output file names.
    outputFolder : str or path-like
        Directory the PNG files are written to.
    close : bool, optional (keyword-only)
        If True, close both figures after saving so repeated calls do not
        accumulate open figures. Defaults to False, which keeps the figures
        open (e.g. for interactive display) as before.

    Returns
    -------
    tuple
        ``(fig_int, fig_pyr)``, the two matplotlib figures.
    """
    S = len(listS)
    sol_int = sol.y[0:2 * S]
    sol_pyr = sol.y[2 * S:]

    # os.path.join instead of a hard-coded '\\' so the files land inside
    # outputFolder on every platform, not only on Windows.
    fig_int = _plot_slice_series(
        sol.t, sol_int, S, simName + '\n INT',
        os.path.join(outputFolder, simName + '_int.png'))
    fig_pyr = _plot_slice_series(
        sol.t, sol_pyr, S, simName + '\n PYR',
        os.path.join(outputFolder, simName + '_pyr.png'))

    if close:
        plt.close(fig_int)
        plt.close(fig_pyr)
    return fig_int, fig_pyr


def _plot_slice_series(t, series, S, title, path, fontsize=14):
    """Plot row ``2*h+1`` of *series* against *t* for each slice, style the axes, save to *path* and return the figure."""
    fig, ax = plt.subplots(figsize=(9, 6))
    for h in range(S):
        ax.plot(t, series[2 * h + 1], linestyle='-',
                label="slice" + str(h + 1), linewidth=3)

    # Minor x ticks every 1/25 of the final time. MultipleLocator needs a
    # positive base, so fall back to automatic minor ticks otherwise.
    step = max(t) / 25
    ax.xaxis.set_minor_locator(
        MultipleLocator(step) if step > 0 else AutoMinorLocator())
    ax.tick_params(axis='x', which='minor', direction='in')
    ax.tick_params(axis='x', length=5, direction='in')

    # Labels, title and legend (legend placed outside, to the right).
    text_style = dict(fontsize=fontsize, weight='bold', fontname='Arial')
    ax.set_xlabel('time', **text_style)
    ax.set_ylabel('counts', **text_style)
    ax.set_title(title, **text_style)
    ax.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left', fontsize=fontsize)
    plt.setp(ax.get_xticklabels(), fontsize=fontsize, weight='bold')
    plt.setp(ax.get_yticklabels(), fontsize=fontsize, weight='bold')

    for pos in ('right', 'top'):
        ax.spines[pos].set_visible(False)

    # Lay out only after all font sizes are final.
    fig.tight_layout()
    fig.savefig(path, bbox_inches='tight')
    return fig


    
# One plot containing both the PYR and the INT curves.
def create1Plot(listS, scalatura, sol, simName, outputFolder, *, close=False):
    """Plot the INT and PYR curves of every slice in a single figure.

    INT curves use a red palette and PYR curves a blue palette. The figure
    is saved as ``<simName>_int-pyr.png`` inside *outputFolder*. No legend
    is drawn.

    Parameters
    ----------
    listS : sized
        Collection of slices; only its length ``S`` is used.
    scalatura : number
        Divisor applied to ``sol.t`` before plotting (rescales the time axis).
    sol : object
        Solver result exposing ``sol.t`` (time points) and ``sol.y`` (one row
        per state variable) -- presumably the result of
        ``scipy.integrate.solve_ivp``; TODO confirm. The first ``2*S`` rows
        are treated as the INT group and the remaining rows as the PYR group;
        for slice ``h`` row ``2*h+1`` of each group is plotted.
    simName : str
        Simulation name, used as the title and in the output file name.
    outputFolder : str or path-like
        Directory the PNG file is written to.
    close : bool, optional (keyword-only)
        If True, close the figure after saving. Defaults to False, which
        keeps the figure open as before.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.
    """
    myFontsize = 22
    S = len(listS)
    colorsInt = ['indianred', 'red', 'darkorange', 'firebrick']
    colorsPyr = ['royalblue', 'cornflowerblue', 'lightsteelblue', 'b']

    sol_int = sol.y[0:2 * S]
    sol_pyr = sol.y[2 * S:]
    t = sol.t / scalatura

    fig, ax = plt.subplots(figsize=(9, 6))

    # Slice h gets colour h of its palette, cycling when there are more
    # slices than colours (so any number of slices can be plotted).
    for h in range(S):
        ax.plot(t, sol_int[2 * h + 1], linestyle='-',
                label="slice" + str(h + 1), linewidth=3,
                color=colorsInt[h % len(colorsInt)])
    for h in range(S):
        ax.plot(t, sol_pyr[2 * h + 1], linestyle='-',
                label="slice" + str(h + 1), linewidth=3,
                color=colorsPyr[h % len(colorsPyr)])

    # Five minor intervals per major tick on both axes, ticks pointing inward.
    ax.xaxis.set_minor_locator(AutoMinorLocator(5))
    ax.yaxis.set_minor_locator(AutoMinorLocator(5))
    ax.tick_params(axis='x', which='minor', direction='in')
    ax.tick_params(axis='x', length=5, direction='in')
    ax.tick_params(axis='y', which='minor', direction='in')
    ax.tick_params(axis='y', length=5, direction='in')

    text_style = dict(fontsize=myFontsize, weight='bold', fontname='Arial')
    ax.set_xlabel('time', **text_style)
    ax.set_ylabel('count', **text_style)
    ax.set_title(simName, **text_style)
    plt.setp(ax.get_xticklabels(), **text_style)
    plt.setp(ax.get_yticklabels(), **text_style)

    for pos in ('right', 'top'):
        ax.spines[pos].set_visible(False)

    # Lay out only after all font sizes are final; join the path portably
    # instead of hard-coding the Windows separator.
    fig.tight_layout()
    fig.savefig(os.path.join(outputFolder, simName + '_int-pyr.png'),
                bbox_inches='tight')

    if close:
        plt.close(fig)
    return fig