import numpy as np
import matplotlib.pyplot as plt
import FSIN as sim
import os
import brian2 as b2
# Machine-specific Brian2/Cython cache configuration.
# Set TigerFish to False when running on any other machine.
TigerFish = True
if TigerFish:  # For Tigerfish use only
    # Use a Cython cache directory whose name contains this process's PID,
    # so that each process gets a directory of its own.
    cache_dir = os.path.expanduser(
        '~/.cython/brian-pid-{}'.format(os.getpid())
    )
    b2.prefs.codegen.runtime.cython.cache_dir = cache_dir
    # NOTE(review): presumably the multiprocess-safe handling is not needed
    # once every process has a private cache directory -- confirm.
    b2.prefs.codegen.runtime.cython.multiprocess_safe = False
# ---- Numerical integration -------------------------------------------------
method = "rk4"  # methods "rk4" and "gsl_rk8pd" give similar results
dt = 4.0e-3

# ---- Scaling factors and time constants forwarded to sim.gewnet ------------
factor = 1
gm_scale = factor
tau_fall = 2.0
tau_rise = 0.3
FactorScaleKV3 = 1.0

# Index for the connectivity matrix to be used among the ones predefined
# (loaded from file in the params folder).
connectivity_matrix = 0

# One simulation is run per entry of Ess (each entry is passed as `Es`).
Ess = [-75.0, -55.0]

fsin = 8.0
gsin = 7.0
std = True

# Total simulated time, expressed as a number of theta cycles.
# The theta frequency is 8 Hz, i.e. one period lasts 125 ms, so the value
# below corresponds to 6 simulated cycles. Only the last two cycles are
# plotted, which removes the effect of initial transients.
sim_time = 6.0 * 125.0
dt_rec = 0.01
# Keyword arguments that are identical in all three families of runs below.
_shared_kwargs = dict(
    std=std,
    FactorScaleKV3=FactorScaleKV3,
    tau_fall=tau_fall,
    read_gms=False,
    scale_gm=False,
    gm_scale=gm_scale,
    read_delays=False,
    dmin=1.6,
    dmax=1.6,
    gsin=gsin,
    fsin=fsin,
    sim_time=sim_time,
    method=method,
    dt=dt,
    connectivity_matrix=connectivity_matrix,
)

# Each family runs one simulation per value of Es in Ess, in that order.

# 1) Gap junctions disabled (gjs=False).
# NOTE(review): unlike the two families below, this one does not pass
# tau_rise, so sim.gewnet's own default is used -- confirm this is intended.
datas_nogjs = [
    sim.gewnet(mod_gL=False, gjs=False, Es=Es, **_shared_kwargs)
    for Es in Ess
]

# 2) Gap junctions enabled with mod_gL=True; this is the only family that
#    also passes dt_rec and return_state=True.
datas_physiogjs = [
    sim.gewnet(
        dt_rec=dt_rec,
        return_state=True,
        tau_rise=tau_rise,
        mod_gL=True,
        gjs=True,
        Es=Es,
        **_shared_kwargs
    )
    for Es in Ess
]

# 3) Gap junctions enabled with an explicit ggaps=2. (read_ggaps=False).
datas_unphysiogjs = [
    sim.gewnet(
        tau_rise=tau_rise,
        mod_gL=False,
        gjs=True,
        read_ggaps=False,
        ggaps=2.,
        Es=Es,
        **_shared_kwargs
    )
    for Es in Ess
]

# Spike times of every run, made unitless by dividing by sim.ms.
# Element 1 of each result is presumably a Brian2 spike monitor -- confirm
# against sim.gewnet.
spikeses_nogjs = [result[1].t / sim.ms for result in datas_nogjs]
spikeses_physiogjs = [result[1].t / sim.ms for result in datas_physiogjs]
spikeses_unphysiogjs = [result[1].t / sim.ms for result in datas_unphysiogjs]
def _draw_spike_histogram(ax, spike_times, bin_edges, x_range, y_max,
                          y_ticks, title, hide_x_ticks):
    """Draw one spike-count histogram panel on *ax*.

    Args:
        ax: matplotlib Axes to draw on.
        spike_times: sequence of spike times to histogram.
        bin_edges: histogram bin edges.
        x_range: (left, right) limits of the x axis.
        y_max: upper limit of the y axis (the lower limit is 0).
        y_ticks: positions of the y ticks.
        title: panel title.
        hide_x_ticks: if True, remove the x ticks and their labels.
    """
    ax.hist(spike_times, bin_edges, color='k')
    ax.set_xlim(x_range[0], x_range[1])
    ax.set_ylim(0., y_max)
    # Only the left spine stays visible.
    for side in ('top', 'right', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.spines['left'].set_visible(True)
    if hide_x_ticks:
        ax.set_xticks([])
        ax.set_xticklabels([])
    ax.set_yticks(y_ticks)
    ax.set_ylabel("Counts", fontsize=11)
    ax.set_title(title, fontsize=10)


Title = ["A", "B"]  # column letter used as prefix of the panel titles
width = 5.2  # in inches
height = 7.  # in inches

# Plot only the last two 125-long cycles of the simulation, using 250
# equal-width bins (i.e. width 1 in the units of the spike times).
x_range = (sim_time - 2. * 125., sim_time)
bin_edges = np.linspace(x_range[0], x_range[1], 2 * 125 + 1)

# One entry per row of panels:
# (spike times per Es, y-axis maximum, y ticks, title template, hide x ticks)
panel_rows = [
    (spikeses_nogjs, 55., [0, 50],
     "%s1. No Gap Junctions ($E_{syn}$=%d)", True),
    # The middle row is panel 2; it was previously mislabelled "1".
    (spikeses_physiogjs, 85., [0, 50],
     "%s2. Physiological $g_{GAP}$ ($E_{syn}$=%d)", True),
    (spikeses_unphysiogjs, 105., [0, 50, 100],
     "%s3. Non-physiological $g_{GAP}$ ($E_{syn}$=%d)", False),
]

n_rows = len(panel_rows)
n_cols = len(Ess)

plt.figure(figsize=[width, height])
for kEs, Es in enumerate(Ess):
    is_last_column = (kEs == n_cols - 1)
    for row, (spikeses, y_max, y_ticks, title_fmt, hide_x) in enumerate(panel_rows):
        ax = plt.subplot(n_rows, n_cols, row * n_cols + kEs + 1)
        _draw_spike_histogram(
            ax, spikeses[kEs], bin_edges, x_range, y_max, y_ticks,
            title_fmt % (Title[kEs], Es), hide_x,
        )
        if is_last_column:
            # NOTE(review): the y label and y tick labels are stripped from
            # the last column only (the ticks themselves are kept). This
            # reads the original clean-up statements as running after the
            # column loop -- confirm against the intended figure.
            ax.set_ylabel("")
            ax.set_yticklabels([])

plt.tight_layout(pad=0.0, w_pad=2.0, h_pad=1.0)

# Make sure the output directory exists so that the simulation results are
# not lost to a failing savefig call.
if not os.path.isdir("Figures"):
    os.makedirs("Figures")
plt.savefig("Figures/FigS3_f{f}gChR{g}.eps".format(f=fsin, g=gsin), dpi=300)  # To save plot in eps file
plt.savefig("Figures/FigS3_f{f}gChR{g}.png".format(f=fsin, g=gsin), dpi=300)  # To save plot in png file
plt.clf()