import numpy as np
import sys
sys.path.append('../')
from transfer_functions.tf_simulation import single_experiment_2,single_experiment
from transfer_functions.load_config import load_transfer_functions
from scipy.integrate import odeint
from single_cell_models.cell_library import get_neuron_params
from synapses_and_connectivity.syn_and_connec_library import get_connectivity_and_synapses_matrix
from transfer_functions.tf_simulation import reformat_syn_parameters
import matplotlib.pylab as plt

from scipy.optimize import fsolve
from scipy.optimize import minimize
import math


# --- Model configuration ----------------------------------------------------
# Network configuration name plus the neuron models of the two populations
# (NRN1 -> TF1 / `params`, NRN2 -> TF2 / `paramsinh`).
NTWK = 'CONFIG1'
NRN1 = NRN2 = 'HH_RS'

# Transfer functions fitted for the (NRN1, NRN2) pair in network NTWK.
TF1, TF2 = load_transfer_functions(NRN1, NRN2, NTWK)

# Connectivity/synapse matrix and per-neuron parameter dicts in SI units;
# reformat_syn_parameters merges the synaptic entries of M into each dict.
M = get_connectivity_and_synapses_matrix(NTWK, SI_units=True)
params = get_neuron_params(NRN1, SI_units=True)
reformat_syn_parameters(params, M)
paramsinh = get_neuron_params(NRN2, SI_units=True)
reformat_syn_parameters(paramsinh, M)

# Scalars used by the analytic membrane-potential formulas below, all read
# from the (reformatted) NRN1 parameter dict, in the original lookup order.
(Qe, Te, Ee,
 Qi, Ti, Ei,
 Gl, Cm, El,
 pconnec, Ntot, gei) = (params[name] for name in (
     'Qe', 'Te', 'Ee',
     'Qi', 'Ti', 'Ei',
     'Gl', 'Cm', 'El',
     'pconnec', 'Ntot', 'gei'))
ext_drive = M[0, 0]['ext_drive']






# --- Sweep over the excitatory input rate -----------------------------------
# For each excitatory rate in `freqs` (inhibitory rate fixed at `finhib`):
#   frespthEx : output of TF1(fe, finhib)
#   muVexcth  : mean membrane potential from the analytic formula below
#   stdexcth  : standard deviation of the membrane potential (formula below)
N = 100
freqs = np.linspace(3, 7, N)
finhib = 8.

# TF1 is evaluated point by point: whether it accepts array arguments is not
# visible from here, so it is not vectorized.
frespthEx = np.zeros_like(freqs)
for i, fe in enumerate(freqs):
    frespthEx[i] = TF1(fe, finhib)

# The remaining quantities are closed-form, so they are evaluated over the
# whole `freqs` array at once (same elementwise operations, same order).
# Mean conductances from the excitatory and inhibitory inputs.
muGe = Qe*Te*freqs*(1.-gei)*pconnec*Ntot
muGi = Qi*Ti*finhib*gei*pconnec*Ntot
muG = Gl+muGe+muGi

# Mean potential: conductance-weighted average of the reversal potentials.
muVexcth = (muGe*Ee+muGi*Ei+Gl*El)/muG

# Membrane time constant under the total mean conductance muG.
Tm = Cm/muG

# Amplitudes of the excitatory / inhibitory events at the mean potential.
Ue, Ui = Qe/muG*(Ee-muVexcth), Qi/muG*(Ei-muVexcth)

# Each input contributes rate*(U*T)**2 / (2*(T+Tm)) to the variance.
stdexcth = np.sqrt(
    (1.-gei)*pconnec*Ntot*freqs*(Ue*Te)**2/2./(Te+Tm)
    + gei*pconnec*Ntot*finhib*(Ti*Ui)**2/2./(Ti+Tm))



# --- Persist the theoretical TF and compare with the experimental data ------
# Theory is saved as a 2 x N array: row 0 = input rates, row 1 = TF1 output.
np.save("TF_B_HHRS_th.npy", np.array([freqs, frespthEx]))

# The experimental file is unpacked row by row into rates, TF output,
# mean potential and potential std (assumes a 4-row array — TODO confirm).
experiment = np.load("TF_B_HHRS.npy")
freqsexp, frespEx, muVexcexp, stdexcexp = experiment

# One figure per quantity: experimental points ('go') vs theory ('g-').
for measured, predicted in ((frespEx, frespthEx),
                            (muVexcexp, muVexcth),
                            (stdexcexp, stdexcth)):
    plt.plot(freqsexp, measured, 'go', freqs, predicted, 'g-')
    plt.show()