import numpy as np
from scipy.signal import cwt, morlet2
import matplotlib.pyplot as plt
from funcs import *

"""#Uncomment for redo simulations
Nnets = 15;
Ncycs = 30;
for fsin in [4,8]:
	v=[]
	for gsin in [0.5,1,2,3,4,5,6,7,8,9,10]:

		dt = .1;

		Ess = [-75.]#[-75.,-55.];

		Ntcyc = int(np.round(125./dt)); # Number of points per theta cycle
		phases = np.linspace(-np.pi,+np.pi,Ntcyc); # Values of phases within ones cycle. -PI and +PI for minima of externl input, 0 for the maxima.
		fs = np.arange(50.,450.1,3.);

		rateses = [ [ np.loadtxt("RatesFig7/rates_Esm%d_cm%d_%dHzgChR%d.dat" % (int(np.round(-Es)),kcm,int(fsin),int(gsin))) for kcm in range(Nnets) ] for Es in Ess ];
		cwtPowses = [ [ compPWT(rateses[kEs][kcm][-Ncycs*Ntcyc:],dt=dt*1.e-3,fs=fs) for kcm in range(Nnets) ] for kEs in range(len(Ess)) ];
		max_powers = [ [ np.max(cwtPowses[kEs][kcm]) for kcm in range(Nnets) ] for kEs in range(len(Ess)) ];
		maxPowses = np.array([ np.array([ np.array([ np.max( cwtPowses[kEs][kcm][:,kcyc*Ntcyc:kcyc*Ntcyc+Ntcyc] ) for kcyc in range(Ncycs) ]) for kcm in range(Nnets) ]) for kEs in range(len(Ess)) ]);
		kcycs_act_hyp = [ [ kcyc for kcyc in range(Ncycs) if np.max(cwtPowses[0][kcm][:,kcyc*Ntcyc:kcyc*Ntcyc+Ntcyc])>.3*max_powers[0][kcm] ] for kcm in range(Nnets) ];


		fpmaxes_hyp = [ [ np.min( fs[np.where( cwtPowses[0][kcm][:,kcyc*Ntcyc:(kcyc+1)*Ntcyc] > (1.-1.e-6)*np.max(cwtPowses[0][kcm][:,kcyc*Ntcyc:(kcyc+1)*Ntcyc])  )[0]] ) for kcyc in kcycs_act_hyp[kcm] ] for kcm in range(Nnets) ];
		#fpmaxes_shu = [ [ np.min( fs[np.where( cwtPowses[1][kcm][:,kcyc*Ntcyc:(kcyc+1)*Ntcyc] > (1.-1.e-6)*np.max(cwtPowses[1][kcm][:,kcyc*Ntcyc:(kcyc+1)*Ntcyc])  )[0]] ) for kcyc in range(Ncycs) ] for kcm in range(Nnets) ];

		meansfs_hyp = [ np.mean([ fpmaxes_hyp[kcm][kcyc] for kcyc in range(len(kcycs_act_hyp[kcm])) ]) for kcm in range(Nnets) ];
		stdsfs_hyp = [ np.std([ fpmaxes_hyp[kcm][kcyc] for kcyc in range(len(kcycs_act_hyp[kcm])) ]) for kcm in range(Nnets) ];

		#meansfs_shu = [ np.mean([ fpmaxes_shu[kcm][kcyc] for kcyc in range(Ncycs) ]) for kcm in range(Nnets) ];
		#stdsfs_shu = [ np.std([ fpmaxes_shu[kcm][kcyc] for kcyc in range(Ncycs) ]) for kcm in range(Nnets) ];

		v.append([gsin,np.mean(maxPowses[0],axis=1)[0], np.std(maxPowses[0],axis=1)[0], meansfs_hyp[0], stdsfs_hyp[0] ])

	np.savetxt("PowFreq_%dHz.dat" % int(fsin), v)
"""

es = 0
for fsin in [8,4]:
	Vector = np.loadtxt("PowFreq_%dHz.dat" % int(fsin))       
	tmp = plt.subplot(2,2,1+2*es);
	tmp = plt.title("Freq %d Hz" % int(fsin))
	tmp = plt.errorbar(Vector[:,0],Vector[:,3],yerr=Vector[:,4],fmt="x-");
	tmp = plt.xlabel("gChr (nS)"); tmp = plt.ylabel("Max Amplitude Network Frequency (Hz)");
	tmp = plt.ylim([0,300])
	tmp = plt.subplot(2,2,2+2*es);
	tmp = plt.title("Freq %d Hz" % int(fsin))
	tmp = plt.errorbar(Vector[:,0],Vector[:,1],yerr=Vector[:,2],fmt="x-");
	tmp = plt.xlabel("gChr (nS)"); tmp = plt.ylabel("Max Amplitude Power (u.a.)");
	tmp = plt.ylim([0,2.5e+6])
	es+=1
tmp = plt.tight_layout()

tmp = plt.savefig("Figures/Fig7BC.png")
tmp = plt.savefig("Figures/Fig7BC.eps")