# cp ../spineM9l/drawfig6varcontnm.py drawfigstdp.py
#drawfig6.py: Draws the STDP curves
#Tuomo Maki-Marttunen, 2019-2020
#CC BY 4.0
import matplotlib
matplotlib.use('Agg')
from pylab import *
import scipy.io
import sys
import itertools
from os.path import exists
import mytools
import pickle
from matplotlib.collections import PatchCollection
import calcconds
from scipy.stats.stats import pearsonr
TRAINISISALL=[-230., -220., -210., -200., -180., -160., -140., -120., -100., -80., -60., -50., -40., -30., -25., -20., -15., -10., -5., -2.5, 0., 2.5, 5., 10., 15., 20., 25., 30., 40., 50., 60., 80., 100., 120., 140., 160., 180., 200. ]
Econ = 0.00015
pulseamp = 10.0
gNap = 0.03
V_min_peak=-40
V_max_valley=0
givenT = 16
A = scipy.io.loadmat('nrn_tstop25840000_GluR1,GluR1_memb,GluR2,GluR2_memb,Cax0.5,0.5,1.5,1.5,1.0_n120_freq1.0_dur3.0_Lflux0.05_Gluflux0.0_AChflux0.05_Ntrains1_trainT100000.0_pair-100.0_icell1_pulseamp10.0_Nsyn1_Econ0.00015_wNMDA1.0_gNap0.03_Npulses4_apic250-300.mat') #Load this to find the time index
timevec = A['DATA'][0]
indT = argmin(abs(24040000+givenT*60000 - timevec))
print(str(givenT)+" min, indT = "+str(indT))
toBeRemovedIfNecessary = ['_CaCoeff1.0','_tol1e-06','_tstop15000000','3560000_600000','_Ninputs1','_tstop25000000','_onset24040000.0','_pulseamp5.0']
def boxoff(ax,whichxoff='top'):
ax.spines[whichxoff].set_visible(False)
ax.spines['right'].set_visible(False)
ax.get_xaxis().tick_bottom()
ax.get_yaxis().tick_left()
def get_stdp_curve(filename):
if filename.find('paired') > -1:
filenames = [filename.replace('0.0.mat',str(x)+'.mat') for x in TRAINISISALL]
else:
filenames = [filename.replace('pair0.0','pair'+str(x)) for x in TRAINISISALL]
ISIs = []
conds = []
for ifile in range(0,len(filenames)):
if exists(filenames[ifile]):
print('Loading '+filenames[ifile])
cond, times = calcconds.calcconds_nrn(filenames[ifile])
ISIs.append(TRAINISISALL[ifile])
conds.append(cond[indT]/cond[0])
else:
myfilename = filenames[ifile]
for i in range(0,len(toBeRemovedIfNecessary)):
if len(myfilename) > 254:
myfilename = myfilename.replace(toBeRemovedIfNecessary[i],'')
#print("myfilename not short enough previously, now "+myfilename)
if exists(myfilename):
print('Loading '+myfilename)
cond, times = calcconds.calcconds_nrn(myfilename)
ISIs.append(TRAINISISALL[ifile])
conds.append(cond[indT]/cond[0])
else:
print(filenames[ifile]+' does not exist')
if len(ISIs) > 20 and ISIs[20] != 0.0:
print ("Something wrong here")
if len(conds) == 0:
print(filename+": max cond not determined, len=0")
else:
print(filename+": max cond = "+str(max(conds)))
return [conds, ISIs]
def get_fi_curve(filename):
spfreqs = []
amps_saved = []
amps=[0.0, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.175, 0.2, 0.225, 0.25]
filenames = [filename.replace('stimAmp0.0','stimAmp'+str(x)) for x in amps]
for iamp in range(0,len(amps)):
if exists(filenames[iamp]):
A = scipy.io.loadmat(filenames[iamp])
sps = mytools.spike_times(A['times'][0],A['vsoma'][0],V_min_peak=V_min_peak,V_max_valley=V_max_valley)
spfreqs.append(len(sps)/19.0) #[x/19.0 for x in Nsps]
amps_saved.append(amps[iamp])
else:
print(filenames[iamp]+' does not exist')
if len(amps_saved) > 0 and amps_saved[-1] != 0.25:
print("Something wrong here, spfreqs")
return [spfreqs, amps_saved]
f,axs = subplots(5,2)
axarr = axs.reshape(prod(axs.shape),).tolist()
for iax in range(0,len(axarr)):
boxoff(axarr[iax])
if iax in [0,1,2,3,4]:
axarr[iax].set_position([0.08+0.19*iax,0.6,0.14,0.25])
axarr[iax].set_xlim([-90,170])
axarr[iax].set_ylim([0.5,2.62])
axarr[iax].plot([-250,250],[1.0,1.0],'k--',dashes=[4,2],lw=0.3)
elif iax in [5,6,7,8,9]:
axarr[iax].set_position([0.08+0.19*(iax-5),0.27,0.14,0.25])
axarr[iax].set_xlim([0,0.25])
axarr[iax].set_xticks([0,0.1,0.2])
axarr[iax].set_ylim([0.0,5.0])
for tick in axarr[iax].xaxis.get_major_ticks()+axarr[iax].yaxis.get_major_ticks():
tick.label.set_fontsize(5)
for axis in ['top','bottom','left','right']:
axarr[iax].spines[axis].set_linewidth(0.3)
axarr[iax].tick_params(axis='x',direction='out',width=0.4,length=2)
axarr[iax].tick_params(axis='y',direction='out',width=0.4,length=2)
#for iax in range(0,2):
# axarr[1+iax].set_position([0.27+0.38*iax,0.81,0.32,0.126])
#for iax in range(0,3):
# axarr[3+iax].set_position([0.07+0.32*iax,0.62,0.26,0.126])
# axarr[6+iax].set_position([0.07+0.32*iax,0.43,0.26,0.126])
#for iax in range(0,2):
# axarr[9+iax].set_position([0.07+0.49*iax,0.24,0.42,0.126])
#for iax in range(0,2):
# axarr[11+iax].set_position([0.08+0.21*iax,0.05,0.19,0.126])
# axarr[13+iax].set_position([0.58+0.21*iax,0.01,0.19,0.166])
axinsets = []
for iax in range(0,len(axarr)):
axarr[iax].tick_params(axis='both', which='major', pad=2.0)
if iax < 5:
axinsets.append(f.add_axes([axarr[iax].get_position().x0+0.13,axarr[iax].get_position().y0+0.17,0.02,0.056]))
else:
axinsets.append(f.add_axes([axarr[iax].get_position().x0+0.13,axarr[iax].get_position().y0+0.04,0.02,0.056]))
for tick in axinsets[iax].xaxis.get_major_ticks()+axinsets[iax].yaxis.get_major_ticks():
tick.label.set_fontsize(5)
for axis in ['top','bottom','left','right']:
axinsets[iax].spines[axis].set_linewidth(0.3)
if iax not in [5,6,7,8,9]: #insets for stdp curves
boxoff(axinsets[iax]) #,'bottom')
axinsets[iax].set_ylim([0,160])
axinsets[iax].set_yticks([50,100,150]); axinsets[iax].set_yticklabels(['+50%','+100%','+150%']);
else: #insets for f-I curves
boxoff(axinsets[iax],'top')
axinsets[iax].set_ylim([0,125])
axinsets[iax].set_yticks([0,50,100]); axinsets[iax].set_yticklabels(['0','50%','100%']);
axinsets[iax].set_xticks([])
axinsets[iax].tick_params(axis='y',direction='out',length=1,width=0.3)
axinsets[iax].tick_params(axis='both', which='major', pad=0.1)
#for iax in list(range(0,3))+list(range(11,15)):
# axinsets[iax].set_visible(False)
conds,ISIs = get_stdp_curve('nrn_tstop25840000_GluR1,GluR1_memb,GluR2,GluR2_memb,Cax0.5,0.5,1.5,1.5,1.0_onset24040000.0_n120_freq1.0_dur3.0_Lflux0.05_Gluflux0.0_AChflux0.05_Ntrains1_trainT100000.0_pair0.0_icell1_pulseamp'+str(pulseamp)+'_Nsyn1_Econ'+str(Econ)+'_wNMDA1.0_gNap0.03_Npulses4_apic250-300.mat')
ISIs_control = ISIs[:]
conds_control = conds[:]
spfreqs_control,amps_control = get_fi_curve('../l23pc/somaticDC_icell0_imut0_stimAmp0.0.mat')
if len(conds) > 20:
print("control, max = "+str(max(conds))+", min = "+str(min(conds))+", conds[20] = "+str(conds[20])+", min(conds) = "+str(min(conds))+", imax = "+str(argmax(conds))+", imin = "+str(argmin(conds)))
cols = ['#666666', '#AA0000', '#FF00FF','#888800']
dimcols = ['#AAAAAA','#FF8888','#FF88FF','#AAAA44']
blockedSet = [10,11]
#Mutations (ion channel genes):
mutArray = ['gCa_HVAbar', 'gCa_LVAstbar', 'gIhbar', 'gK_Pstbar', 'gSK_E2bar', 'g', 'gNaTs2_tbar','gImbar','gK_Tstbar','gSKv3_1bar','decay', 'gamma', 'gNap_Et2bar']
mutSuffs = ['Ca_HVA', 'Ca_LVAst', 'Ih', 'K_Pst', 'SK_E2', 'pas', 'NaTs2_t', 'Im', 'K_Tst', 'SKv3_1', 'CaDynamics_E2', 'CaDynamics_E2', 'Nap_Et2']
titles = ['HVA Ca$^{2+}$', 'LVA Ca$^{2+}$', 'HCN', 'Slow K$^+$', 'SK', 'Leak', 'Fast Na$^+$', 'M-type K$^+$', 'Fast K$^+$', 'Kv3.1-type K$^+$', 'Ca$^{2+}$ decay', 'Ca$^{2+}$ fraction', 'Slow Na$^+$']
mutCoeffs = [0.75, 0.8, 0.9, 1.1, 1.2, 1.25]
#imutArray = [0, 1, 2, 3, 7, 6, 5]
imutArray = [7, 5, 6, 3, 1]
data_for_corr = []
data_for_corr_max = []
for iimut in range(0,len(imutArray)):
imut = imutArray[iimut]
for imutcoeff in range(0,2):
imutc = 6*imut+(2+3*imutcoeff)
conds,ISIs = get_stdp_curve('nrn_tstop25840000_GluR1,GluR1_memb,GluR2,GluR2_memb,Cax0.5,0.5,1.5,1.5,1.0_onset24040000.0_n120_freq1.0_dur3.0_Lflux0.05_Gluflux0.0_AChflux0.05_Ntrains1_trainT100000.0_pair0.0_icell1_imutc'+str(imutc)+'_pulseamp'+str(pulseamp)+'_Nsyn1_Econ'+str(Econ)+'_wNMDA1.0_gNap0.03_Npulses4_apic250-300.mat')
spfreqs,amps = get_fi_curve('../l23pc/somaticDC_icell0_imutc'+str(imutc)+'_stimAmp0.0.mat')
if len(conds) > 0:
axarr[iimut].plot([x+30 for x in ISIs],conds,'k.-',mew=0.8,ms=0.8,color=cols[imutcoeff],lw=0.3, label = '-20%' if imutcoeff==1 else '+20%')
axinsets[iimut].bar(imutcoeff*2,(max(conds)-1)*100,color=cols[imutcoeff])
axarr[5+iimut].plot(amps,spfreqs,'k.-',mew=0.8,ms=0.8,color=cols[imutcoeff],lw=0.3, label = '-20%' if imutcoeff==1 else '+20%')
if len(spfreqs) < 2:
print('Something wrong with ../l23pc/somaticDC_icell0_imutc'+str(imutc)+'_stimAmp0.0.mat, skipping...')
continue
axinsets[5+iimut].bar(imutcoeff*2,spfreqs[-1]/spfreqs_control[-1]*100,color=cols[imutcoeff])
tosum = [conds[i] for i in range(0,len(conds)) if [x+30 for x in ISIs][i] >= -90 and [x+30 for x in ISIs][i] <= 170]
ts = [ISIs[i] for i in range(0,len(conds)) if [x+30 for x in ISIs][i] >= -90 and [x+30 for x in ISIs][i] <= 170]
data_for_corr.append([sum([0.5*(tosum[i]+tosum[i+1])*(ts[i+1]-ts[i]) for i in range(0,len(tosum)-1)]),sum(spfreqs)*0.25])
if len(conds) > 0:
data_for_corr_max.append([max(conds),max(spfreqs)*0.25])
else:
data_for_corr_max.append([nan,nan])
axarr[iimut].plot([x+30 for x in ISIs_control],conds_control,'k.-',mew=0.8,ms=0.8,lw=0.3,label='control')
axarr[5+iimut].plot(amps_control,spfreqs_control,'k.-',mew=0.8,ms=0.8,lw=0.3,label='control')
axinsets[5+iimut].bar(1,100,color='#000000')
if len(conds_control) > 0:
axinsets[iimut].bar(1,(max(conds_control)-1)*100,color='#000000')
axarr[0].set_ylabel('Rel. cond. (A.U.)',fontsize=6.5)
axarr[5].set_ylabel('firing rate (spikes/s)',fontsize=6.5)
for iax in [0,1,2,3,4]:
axarr[iax].set_xlabel('ISI (ms)',fontsize=6.5,labelpad=1.0)
for iax in [5,6,7,8,9]:
axarr[iax].set_xlabel('DC amplitude ($\mu$A)',fontsize=6.5,labelpad=1.0)
for i in range(0,5):
axarr[i].set_title(titles[imutArray[i]],fontsize=6.5,pad=1.0)
axarr[5+i].set_title(titles[imutArray[i]],fontsize=6.5,pad=1.0)
for iax in range(0,10):
f.text(axarr[iax].get_position().x0-0.015,axarr[iax].get_position().y1+0.009-0.03*(iax==0),chr(ord('A')+iax),fontsize=8)
f.savefig('figstdp_supp.eps')