import numpy as np
from scipy.signal import cwt, morlet2
import matplotlib.pyplot as plt
from funcs import *


# --- Analysis parameters -------------------------------------------------

# Number of networks; the rate files are indexed cm0 .. cm(Nnets-1).
Nnets = 30
# Number of theta cycles taken from the end of every rate trace.
Ncycs = 30

# Fills the "%dHz" field of the data/figure file names
# (presumably the frequency of the external input in Hz -- confirm).
fsin = 8
# Fills the "gChR%d" field of the data/figure file names.
gsin = 7.

# Sampling step of the rate traces. It is handed to compPWT as dt*1e-3,
# so it is presumably expressed in ms -- confirm against funcs.compPWT.
dt = .1

# Values that fill the "Esm%d" field of the data file names (sign flipped).
# An earlier two-value variant was [-75., -55.].
Ess = [-55.]

# Overlay, for several scaling factors, (top panel) the per-network maximum
# wavelet power and (bottom panel) the frequency at which that maximum occurs,
# then save the combined figure as EPS and PNG.


def _show_only_left_bottom_spines(ax):
    """Hide the top/right spines of *ax* and make sure left/bottom are shown."""
    for side, visible in (("top", False), ("right", False),
                          ("bottom", True), ("left", True)):
        ax.spines[side].set_visible(visible)


# Factors to overlay; each one selects a "_Fact%d" data file, a colour and
# a legend entry.
factors = [1, 2, 3, 5]
color = ["b", "g", "r", "y", "c"]  # indexed with factor-1

# Everything below is independent of the factor, so it is computed once.
# 125 matches one period at 8 Hz if dt is in ms -- presumably tied to fsin;
# confirm before changing fsin.
Ntcyc = int(np.round(125./dt))  # Number of points per theta cycle
# Values of the phase within one cycle: -PI and +PI for the minima of the
# external input, 0 for the maximum. (Not used further in this script.)
phases = np.linspace(-np.pi, +np.pi, Ntcyc)
# Frequencies handed to compPWT; also used to translate a row index of the
# returned power array back into a frequency.
fs = np.arange(50., 450.1, 3.)

# Intended figure size in inches. Currently unused because no explicit
# plt.figure(figsize=[width, height]) call is made.
width = 5.2
height = 7.

for factor in factors:
    series_color = color[factor - 1]

    # rateses[kEs][kcm]: rate trace loaded from disk for Ess[kEs], network kcm.
    rateses = [
        [
            np.loadtxt("RatesFig7/NonPhysrates_Esm%d_cm%d_%dHzgChR%d_Fact%d.dat"
                       % (int(np.round(-Es)), kcm, int(fsin), int(gsin), int(factor)))
            for kcm in range(Nnets)
        ]
        for Es in Ess
    ]

    # Wavelet power of the last Ncycs cycles of every trace. compPWT comes
    # from the funcs star import; the indexing below treats its result as a
    # 2-D array with one row per entry of fs and one column per time sample.
    cwtPowses = [
        [compPWT(rates[-Ncycs*Ntcyc:], dt=dt*1.e-3, fs=fs) for rates in rates_Es]
        for rates_Es in rateses
    ]

    # Global maximum of the power for every (Es, network).
    max_powers = [[np.max(power) for power in powers_Es] for powers_Es in cwtPowses]

    # maxPowses[kEs, kcm, kcyc]: maximum power inside cycle kcyc.
    maxPowses = np.array([
        [
            [np.max(power[:, kcyc*Ntcyc:(kcyc+1)*Ntcyc]) for kcyc in range(Ncycs)]
            for power in powers_Es
        ]
        for powers_Es in cwtPowses
    ])

    # The "_hyp" quantities always refer to the first entry of Ess.
    # A cycle counts as active when its peak power exceeds 30% of the
    # network's global peak; the per-cycle maxima are reused from maxPowses.
    kcycs_act_hyp = [
        [kcyc for kcyc in range(Ncycs)
         if maxPowses[0][kcm][kcyc] > .3*max_powers[0][kcm]]
        for kcm in range(Nnets)
    ]

    # For every active cycle: the lowest frequency whose power lies within a
    # relative 1e-6 of that cycle's maximum.
    fpmaxes_hyp = []
    for kcm in range(Nnets):
        power = cwtPowses[0][kcm]
        peak_freqs = []
        for kcyc in kcycs_act_hyp[kcm]:
            window = power[:, kcyc*Ntcyc:(kcyc+1)*Ntcyc]
            near_max_rows = np.where(window > (1.-1.e-6)*maxPowses[0][kcm][kcyc])[0]
            peak_freqs.append(np.min(fs[near_max_rows]))
        fpmaxes_hyp.append(peak_freqs)

    # Mean / std of the peak frequency over the active cycles of each network.
    # A network without any active cycle yields NaN (what np.mean([]) would
    # give as well, but without the empty-slice RuntimeWarning).
    meansfs_hyp = [np.mean(freqs) if freqs else np.nan for freqs in fpmaxes_hyp]
    stdsfs_hyp = [np.std(freqs) if freqs else np.nan for freqs in fpmaxes_hyp]

    # Plot the maximum wavelet powers (mean +/- std over cycles).
    for kEs in range(len(Ess)):
        ax_pow = plt.subplot(2, 1, kEs+1)
        plt.errorbar(range(Nnets), np.mean(maxPowses[kEs], axis=1),
                     yerr=np.std(maxPowses[kEs], axis=1),
                     marker="o", color=series_color, ls="")
        _show_only_left_bottom_spines(ax_pow)
        plt.ylabel("$P_{max}$ (spike$^2$/s$^2$)")
        ax_pow.set_xticks([10, 20, 30])
        ax_pow.set_xticklabels([])
        ax_pow.set_yticks([0., 1.e+6, 2.e+6, 3.e+6, 4.e+6])
        ax_pow.set_yticklabels(["0", "", "", "", "4"])

    # Plot the frequencies with maximum wavelet power for Ess[0].
    # The label ties the legend entry to this series instead of relying on
    # the order in which artists were added.
    ax_freq = plt.subplot(2, 1, 2)
    plt.errorbar(range(len(meansfs_hyp)), meansfs_hyp, yerr=stdsfs_hyp,
                 marker="o", color=series_color, ls="", label="x%d" % factor)
    plt.ylim(0., 450.)
    _show_only_left_bottom_spines(ax_freq)
    plt.ylabel("$F_{max}$ (Hz)")
    plt.xlabel("network index")
    ax_freq.set_xticks([10, 20, 30])
    ax_freq.set_xticklabels([])
    ax_freq.set_yticks([0, 200, 400])

    # NOTE: the analogous analysis for a second entry of Ess (index 1,
    # labelled "shunting synapses" in earlier versions of this script) is
    # disabled because Ess currently holds a single value.

# Legend on the current (bottom) axes, one entry per labelled factor series.
plt.legend()

# NOTE(review): `factor` here is the last loop value, so the files are named
# "...Factor5" although the figure overlays all factors -- confirm intended.
fig_stem = "Figures/Supp4_%dHzgChR%dFactor%d" % (int(fsin), int(gsin), int(factor))
for ext in (".eps", ".png"):
    plt.savefig(fig_stem + ext)