#drawfig9abc.py: Draws the model outputs for the parameters fitted to cortical LTP/LTD data.
#Tuomo Maki-Marttunen, 2019-2020
#CC BY 4.0
import matplotlib
matplotlib.use('Agg')
from pylab import *
import mytools
import pickle
import protocols_many
import protocols_many_78withoutCK
import protocols_many_78withoutCK_1withCK
from os.path import exists
import time
import scipy.stats
def plotmybox(ax,ys,x=0,w=0.5,lw=0.5,col='#000000'):
    """Draw a box-and-whisker glyph for a pre-computed five-number summary.

    Parameters:
      ax  -- object with a matplotlib-style ``plot`` method (e.g. an Axes).
      ys  -- sequence whose first 5 elements are: min, 25th percentile,
             median, 75th percentile, max.
      x   -- horizontal centre of the box.
      w   -- half-width of the box, the median bar and the whisker caps.
      lw  -- line width.
      col -- line colour.

    The whole glyph is one polyline drawn with a single ``plot`` call; NaN
    entries break the line so that the median bar is a separate segment.
    """
    ymin, q25, med, q75, ymax = ys[0], ys[1], ys[2], ys[3], ys[4]
    xs = [x-w, x+w, x, x, x-w, x-w, x+w, x+w, x, nan,
          x-w, x+w, nan,
          x, x, x-w, x+w]
    yv = [ymin, ymin, ymin, q25, q25, q75, q75, q25, q25, nan,  # lower cap, lower whisker, box
          med, med, nan,                                        # median bar
          q75, ymax, ymax, ymax]                                # upper whisker, upper cap
    # Only the line style goes into the format string; the colour comes from
    # ``col`` alone (combining 'k-' with color= is redundant and makes
    # matplotlib warn, with the keyword winning anyway).
    ax.plot(xs, yv, '-', linewidth=lw, color=col)
# Fitted model parameters as [name, lower bound, upper bound]; the bounds are
# used to normalise accepted parameter values to [0, 1].
VARIABLES = [
    ["Caflux", 0, 5000],  # the upper limit of Caflux will be changed according to imeas
    ["Lflux", 0.0, 5.0],
    ["Gluflux", 0, 200],
    ["GluR1_ratio", 0.0, 1.0],
    ["IC_MGluRM1GqPLC", 0.0, 2.0],
    ["IC_RGsAC1AC8", 0.0, 2.0],
    ["IC_CaMCK", 0.0, 2.0],
    ["IC_NCX", 0.0, 2.0],
    ["IC_PKC", 0.0, 5.0],
    ["IC_PKA", 0.0, 2.0],
    ["IC_PP1PP2B", 0.0, 2.0],
    ["IC_PDE1PDE4", 0.0, 2.0],
]
# Upper limit of Caflux per data set (indexed by imeas).
Caflux_limits = [20000, 20000, 13000, 13000, 50000, 10000, 40000, 20000, 20000, 16000, 16000] #Planned so that [Ca flux]*T_total_input is around 2e6, but for imeas=7,8, two different protocols used - something in the middle taken
# Panel index iimeas -> data set index imeas: panels 0-10 show data sets 0-10,
# panels 11 and 12 show data sets 7 and 8 a second time.
imeass = list(range(0,11)) + [7,8]
# Caption per data set (indexed by imeas).
captions = ['EC-1','EC-2','PFC-1','PFC-2','BC','ACC','PFC-3','VC-1','VC-2','AC-1','AC-2']
# Per-panel settings (indexed by iimeas) that select which fit files are read.
myseeds = [1]*7 + [30]*2 + [1]*4
exts = ['fewer']*7 + ['manyb']*2 + ['fewer']*4
rundexts = exts[:1] + ['fewerCK1imeas'] + exts[2:]
isamps = [0]*13
Nsamps = [1000]*7 + [500]*2 + [1000]*4
# Measurement protocols of the three variants, loaded in this order. Element
# [0] of each holds the measurements and element [7] their standard deviations
# (this is how they are indexed below).
_protocol_modules = (protocols_many, protocols_many_78withoutCK, protocols_many_78withoutCK_1withCK)
Measurement_protocol, Measurement_protocol_78withoutCK, Measurement_protocol_78withoutCK_1withCK = \
    [protmodule.get_measurement_protocol() for protmodule in _protocol_modules]
# maxerr scales the summed measurement SDs into per-objective acceptance
# thresholds; maxcaerr is the threshold appended for the last objective.
maxerr, maxcaerr = 1.0, 0.0
def clamp(x):
    """Convert a colour channel in [0, 1] to an integer byte in [0, 255].

    The value is scaled by 256 and truncated with int(), so 1.0 maps to 256 and
    is capped at 255; anything outside the range is clipped to 0 or 255.
    """
    scaled = int(x*256)
    if scaled < 0:
        return 0
    if scaled > 255:
        return 255
    return scaled
def col2hexcol(rgb,brightness=1.0,dim=0.0):
    """Return a '#rrggbb' string for the first three channels of *rgb*.

    Each channel c (expected in [0, 1]) becomes
    (c*(1-dim) + grey*dim)*brightness, where grey is the mean of the three
    channels, so dim blends towards grey and brightness scales the result.
    Any further entries of rgb (e.g. alpha) are ignored. Channels are converted
    to bytes with clamp().
    """
    grey = mean(rgb[0:3])
    channels = [(c*(1-dim)+grey*dim)*brightness for c in rgb[0:3]]
    return "#{0:02x}{1:02x}{2:02x}".format(*[clamp(c) for c in channels])
# Four line colours sampled from the viridis colormap at 0, 0.31, 0.62 and
# 0.93; the hard-coded hex list is the fallback when no colormap can be found.
try:
    if hasattr(matplotlib, 'colormaps'):
        # Colormap registry (Matplotlib >= 3.5); matplotlib.cm.get_cmap was
        # removed in Matplotlib 3.9, so without this branch newer versions
        # would always fall through to the hard-coded colours.
        cmap = matplotlib.colormaps['viridis']
    else:
        cmap = matplotlib.cm.get_cmap('viridis')
    colors = [col2hexcol(cmap(0.31*i)) for i in range(0,4)]
except (AttributeError, KeyError, ValueError):
    # Only a missing API / unknown colormap triggers the fallback (the former
    # bare except also swallowed KeyboardInterrupt and genuine bugs).
    colors = ['#440154', '#33628d', '#26ad81', '#d3e21b']
# Create 16 axes, thin their frames, drop the top/right spines and shrink the
# ticks, then place them manually: axes 0-6 and 9-10 form the left column of
# small panels, axes 7-8 the middle column and axes 11-12 the right column of
# tall panels.
f,axarr = subplots(16,1)
for thisax in axarr:
    for side in ('top','bottom','left','right'):
        thisax.spines[side].set_linewidth(0.5)
    thisax.spines['top'].set_visible(False)
    thisax.spines['right'].set_visible(False)
    thisax.tick_params(width=0.2,length=2.0,labelsize=5)
for irow in range(7):
    axarr[irow].set_position([0.05, 0.86-0.1*irow, 0.11, 0.08])
for irow in range(2):
    axarr[9+irow].set_position([0.05, 0.16-0.1*irow, 0.11, 0.08])
    axarr[7+irow].set_position([0.23, 0.75-0.19*irow, 0.11, 0.17])
    axarr[11+irow].set_position([0.41, 0.75-0.19*irow, 0.11, 0.17])
# Trace labels for the four-condition panels and the control-only panels.
labels78 = ['control (HFS)','CaMKII blocked (HFS)', 'control (LFS)','CaMKII blocked (LFS)']
labels78_B = ['control (HFS)','control (LFS)']
def _load_pickle(path):
    """Load and return the object pickled in the file at *path*.

    The file is opened in binary mode ('rb'): Python 3's pickle requires a
    binary file object, and text mode can corrupt binary-protocol pickles on
    Windows. The file is closed even if unpickling raises.
    NOTE(review): pickle can execute arbitrary code while loading; the fit
    files must come from a trusted source.
    """
    with open(path, 'rb') as unpicklefile:
        return pickle.load(unpicklefile)

# One pass per panel iimeas, each showing data set imeas = imeass[iimeas].
for iimeas in range(0,13):
    imeas = imeass[iimeas]

    # --- Protocol used at the fitting stage: defines the acceptance thresholds ---
    if iimeas == 7 or iimeas == 8:
        MeasurementsAll = Measurement_protocol[0]
        Measurements_stdsAll = Measurement_protocol[7]
    else:
        MeasurementsAll = Measurement_protocol_78withoutCK[0]
        Measurements_stdsAll = Measurement_protocol_78withoutCK[7]
    print("iimeas = "+str(iimeas))
    Measurements = MeasurementsAll[imeas]
    targetTs = Measurements[1]
    targetVals = Measurements[2]
    Measurement_stds = Measurements_stdsAll[imeas]
    OBJECTIVES = ['f'+str(i) for i in range(0,len(Measurements[0])+1)]
    VARIABLES[0][2] = Caflux_limits[imeas]
    mylw = 0.6
    myms = 1.1
    # Threshold per objective: maxerr times the summed SDs, plus maxcaerr for the last objective.
    finalThrsAbsolute = [maxerr*nansum(Measurement_stds[i]) for i in range(0,len(Measurement_stds))]+[maxcaerr]

    # --- Collect accepted, de-duplicated parameter sets from every saved generation ---
    # NOTE(review): goodparams, gooddata, fitnesses and myord are not used
    # anywhere further down in this script.
    goodparams = []
    gooddata = []
    IDs = set()  # random-projection keys of accepted parameter sets (set: O(1) duplicate test)
    coeffs = rand(len(VARIABLES),)
    Nall = 0
    filename = 'fitfiles/'+exts[iimeas]+str(imeas)+'_seed'+str(myseeds[iimeas])+'_N'+str(Nsamps[iimeas])
    fitnesses = []
    for gen in range(24,0,-1):
        genfile = filename+'_tmp'+str(gen)+'.sav'
        if not exists(genfile):
            continue
        gensdone = gen
        print('loading '+genfile)
        unpickledlist = _load_pickle(genfile)
        params_all = unpickledlist[0]
        columns = unpickledlist[1]
        # Each row of params_all: the len(VARIABLES) parameter values followed
        # by one error value per objective (as indexed below).
        for iparam in range(0,params_all.shape[0]):
            Nall = Nall + 1
            isbelowMed = True
            fitness = 0
            for iobj in range(0,len(OBJECTIVES)):
                objerr = params_all[iparam,len(VARIABLES)+iobj]
                if finalThrsAbsolute[iobj] > 0:
                    fitness = fitness + objerr/finalThrsAbsolute[iobj]
                if objerr > finalThrsAbsolute[iobj]:
                    isbelowMed = False
                    break
            if not isbelowMed:
                continue
            myID = sum([coeffs[i]*params_all[iparam,i] for i in range(0,len(VARIABLES))])
            if myID in IDs:
                continue
            gooddata.append(params_all[iparam,:])
            goodparams.append([(params_all[iparam,i] - VARIABLES[i][1])/(VARIABLES[i][2] - VARIABLES[i][1]) for i in range(0,len(VARIABLES))])
            IDs.add(myID)
            fitnesses.append(fitness)
    myord = [i[0] for i in sorted(enumerate(fitnesses), key=lambda x:x[1])]

    # --- Simulated time courses stored for this panel ---
    filename = 'fitfiles/rungiven_'+rundexts[iimeas]+str(imeas)+'_seed'+str(myseeds[iimeas])+'_N'+str(Nsamps[iimeas])+'_maxerr1.0_maxcaerr0.0_'+str(isamps[iimeas])+'.sav'
    print(filename)
    if not exists(filename):
        print(filename+' does not exist')
        time.sleep(0.02)
        continue
    print('loading '+filename)
    unpickledlist = _load_pickle(filename)
    mydict = unpickledlist[0]
    A = unpickledlist[1]
    timesAll = A[0]
    timeCoursesAll = A[1]
    maxValsAll = A[2]

    # --- Protocol whose data are plotted ---
    if iimeas == 1 or iimeas == 7 or iimeas == 8:
        # iimeas == 1: Do not plot the data for the neglected objective function
        MeasurementsAll = Measurement_protocol[0]
        Measurements_stdsAll = Measurement_protocol[7]
    else:
        MeasurementsAll = Measurement_protocol_78withoutCK_1withCK[0]
        Measurements_stdsAll = Measurement_protocol_78withoutCK_1withCK[7]
    #All data sets have same checking protocol file
    Measurements = MeasurementsAll[imeas]
    targetTs = Measurements[1]
    targetVals = Measurements[2]
    Measurement_stds = Measurements_stdsAll[imeas]
    OBJECTIVES = ['f'+str(i) for i in range(0,len(Measurements[0])+1)]
    errSum = 0
    for iobj in range(0,len(targetVals)):
        # Panels 11/12 hold only the control traces: colours 0 and 2 match the
        # control entries of the four-condition legend.
        mycolor = colors[2*iobj] if (iimeas == 11 or iimeas == 12) else colors[iobj]

        # Summed absolute error between the normalised simulation (sample
        # nearest each target time) and the non-NaN targets.
        errThis = 0
        for itarget in range(0,len(targetTs)):
            if isnan(targetVals[iobj][itarget]):
                continue
            itime = argmin(abs(timesAll[iobj]-targetTs[itarget]))
            myval = timeCoursesAll[iobj][itime]/timeCoursesAll[iobj][0]
            errThis = errThis + abs(targetVals[iobj][itarget] - myval)
        errSum = errSum + errThis
        print("imeas = "+str(imeas)+", ext="+exts[iimeas]+",iobj = "+str(iobj)+", errSum = "+str(errSum))

        if iimeas == 7 or iimeas == 8:
            mylabel = labels78[iobj]
        elif iimeas == 11 or iimeas == 12:
            # Control-only panels: labels78[1] would wrongly read 'CaMKII blocked (HFS)'.
            mylabel = labels78_B[iobj]
        else:
            mylabel = '{:.3f}'.format(errThis)+', '+'{:.4f}'.format(max(maxValsAll))
        # Simulated course normalised to its first sample, preceded by a baseline
        # point (t=3e6, value 1.0). The colour is given only via color=.
        axarr[iimeas].plot([3e6]+timesAll[iobj].tolist(),[1.0]+[timeCoursesAll[iobj][k]/timeCoursesAll[iobj][0] for k in range(0,len(timesAll[iobj]))],'-',color=mycolor,lw=mylw,label=mylabel)
        # Measured data, offset horizontally by -10000*(iobj-1) per objective
        # (presumably so markers of different experiments do not overlap).
        axarr[iimeas].plot([x-10000*(iobj-1) for x in targetTs],Measurements[2][iobj],'.',color=mycolor,mew=myms,ms=myms)
        if imeas == 7 or imeas == 8:
            # Vertical bars of +/- one SD around each data point (VC data sets only).
            for itarget in range(0,len(targetTs)):
                xbar = targetTs[itarget]-10000*(iobj-1)
                yval = Measurements[2][iobj][itarget]
                ysd = Measurement_stds[iobj][itarget]
                axarr[iimeas].plot([xbar,xbar],[yval-ysd,yval+ysd],'-',color=mycolor,lw=mylw)
def _add_frameless_legend(pos, styles, labels, cols):
    """Create a legend axes on figure f with mytools.mylegend at *pos* and hide all four of its spines."""
    newleg = mytools.mylegend(f,pos,styles,labels,1,2,0.5,0.35,cols,dashes=[],linewidths=[],myfontsize=4.5)
    for side in ('top','bottom','left','right'):
        newleg.spines[side].set_visible(False)
    return newleg

# Legend for the left column (three experiments), the four-condition panels and
# the control-only panels, created in that order.
legax = _add_frameless_legend([0.04,0.945,0.15,0.037],['b-','b-','b-'],['1st experiment', '2nd experiment', '3rd experiment'],colors[0:3])
legax = _add_frameless_legend([0.22,0.93,0.15,0.05],['b-','b-','b-','b-'],labels78,colors)
legax = _add_frameless_legend([0.4,0.93,0.1,0.027],['b-','b-'],labels78_B,[colors[i] for i in [0,2]])
# Panel letters.
for letterx, letter in ((0.005,'A'), (0.175,'B'), (0.375,'C')):
    f.text(letterx,0.95,letter,fontsize=12)
# Y-axis limits per panel (indexed by iimeas, 13 entries).
ylims = [[0.98,1.7],[0.98,1.7],[0.98,2.05],[0.98,1.74],[0.98,1.6],[0.9,1.6],[0.98,1.44],[0.7,1.65],[0.7,1.45],[0.51,2.15],[0.51,2.15],[0.4,1.4],[0.4,1.4]]
# Y tick positions per data set (indexed by imeas, 11 entries); panels 11/12
# reuse the ticks of data sets 7/8.
yticks = [[1.0,1.25,1.5],[1.0,1.25,1.5],[1.0,1.4,1.8],[1.0,1.3,1.6],[1.0,1.25,1.5],[1.0,1.25,1.5],[1.0,1.25],[0.75,1.0,1.25],[0.75,1.0,1.25],[0.6,1.0,1.4,1.8],[0.6,1.0,1.4,1.8]]
for iimeas in range(0,13):
    imeas = imeass[iimeas]
    axarr[iimeas].set_xlim([3.5e6,5.3e6])
    # Ticks are 6e5 time units apart and labelled -10..20 min (0 min at 4.04e6);
    # presumably the simulation time unit is ms -- TODO confirm.
    axarr[iimeas].set_xticks([3.44e6, 4.04e6, 4.64e6, 5.24e6])
    # Only the bottom panel of each column (data sets 8 and 10) gets tick labels.
    if imeas in [8,10]:
        axarr[iimeas].set_xticklabels(['-10', '0', '10', '20'],fontsize=5)
    else:
        axarr[iimeas].set_xticklabels([])
    axarr[iimeas].set_ylim(ylims[iimeas])
    axarr[iimeas].set_yticks(yticks[imeas])
    axarr[iimeas].set_yticklabels([str(x) for x in yticks[imeas]],fontsize=5)
    # Final tick-label size (overrides the fontsize=5 above). Tick.label1 is
    # what the old Tick.label alias pointed to; Tick.label was deprecated in
    # Matplotlib 3.1 and removed in 3.8, while label1 exists in all versions.
    for tick in axarr[iimeas].xaxis.get_major_ticks() + axarr[iimeas].yaxis.get_major_ticks():
        tick.label1.set_fontsize(3.5)
    # Data-set caption near the top-left corner of the panel.
    axarr[iimeas].text(3.48e6,ylims[iimeas][0]*0.01+ylims[iimeas][1]*0.99,captions[imeas],fontsize=6,va='top')
axarr[10].set_xlabel('time (min)',fontsize=5)
axarr[8].set_xlabel('time (min)',fontsize=5)
axarr[12].set_xlabel('time (min)',fontsize=5)
axarr[4].set_ylabel('relative conductance',fontsize=5)
axarr[8].set_ylabel(' relative conductance',fontsize=5)
axarr[12].set_ylabel(' relative conductance',fontsize=5)
# Axes 13-15 were created by subplots(16,1) but are not used.
axarr[13].set_visible(False)
axarr[14].set_visible(False)
axarr[15].set_visible(False)
f.savefig("fig9abc.eps")