#######################################
# membrane_potential_heat_plot.py
#
# Andrew Knox 2014
#
# Python script that makes a heat plot from a data file created by Fspikewave.oc
#
# The first line holds the number of time points and the second the number of
# cells. Subsequent lines hold the membrane potentials, one value per line,
# listed cell by cell for each population in turn (TC, RE, PY, IN).
########################################
from numpy import *
import numpy as np
from matplotlib import *
from matplotlib import pyplot as plt
import matplotlib as mpl
from matplotlib import gridspec
# Font sizes used throughout the figure, in the order:
#   title, subtitle, axis label, tick label, colour-bar tick label.
# Values for a 60x40 inch render (pair with figsize=[60,40]):
#   tf, stf, af, atf, satf = 88, 44, 36, 30, 20
# Values for the roughly 18x12 cm publication figure (active):
tf, stf, af, atf, satf = 9, 6, 6, 6, 6
def compress_array(inarray, outarray, irows):
    """Down-sample *inarray* along its first axis by max-pooling.

    Consecutive groups of ``irows`` rows are collapsed into a single row that
    holds their element-wise maximum. An input with ``n`` rows therefore
    yields ``n // irows`` rows. Trailing rows that do not fill a complete
    group are dropped, instead of triggering a broadcasting error.

    Parameters
    ----------
    inarray : array_like
        Input with at least one dimension; pooling is along axis 0.
    outarray : ignored
        Kept only so existing call sites keep working. The result is always
        a new array and is never written into this argument.
    irows : int
        Group size (compression factor); must be at least 1.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n // irows,) + inarray.shape[1:]``.

    Raises
    ------
    ValueError
        If ``irows`` is smaller than 1.
    """
    if irows < 1:
        raise ValueError("irows must be >= 1, got %r" % (irows,))
    values = np.asarray(inarray)
    ngroups = values.shape[0] // irows
    # Row k of ``grouped`` contains input rows k*irows ... k*irows + irows - 1,
    # which is exactly the set the original strided slices combined.
    grouped = values[:ngroups * irows].reshape((ngroups, irows) + values.shape[1:])
    return grouped.max(axis=1)
def _read_population(stream, ncells, ntimepoints, nkeep):
    """Read one population's membrane-potential traces from *stream*.

    Each cell contributes ``ntimepoints`` lines of one float each, followed
    by one extra line that is consumed and discarded. Only the first
    ``nkeep`` values of every cell are stored. The remaining values are
    still read so the stream stays aligned for the next cell.

    Returns a float array of shape ``(nkeep, ncells)``: row = time point,
    column = cell.
    """
    traces = np.zeros((nkeep, ncells))
    for cell in range(ncells):
        for t in range(ntimepoints):
            line = stream.readline()
            if t < nkeep:
                traces[t, cell] = float(line)
        # NOTE(review): assumes one separator line after each cell's trace
        # (not one per population) -- confirm against Fspikewave.oc output.
        stream.readline()
    return traces


# Optional max-pooling factor along time (see compress_array); 1 = disabled.
compress_factor = 1

# ``with`` guarantees the file is closed even if a value fails to parse.
with open('membrane_data.txt', 'r') as data:
    ntimepoints = int(data.readline())
    ncells = int(data.readline())
    # Same output as the old Python 2 print statement, valid in Python 2 and 3.
    print("data: %d , %d" % (ntimepoints, ncells))
    # Trim the time series to a multiple of the compression factor.
    mod_ntimepoints = ntimepoints - ntimepoints % compress_factor
    # Populations appear in the file in this order: TC, RE, PY, IN.
    tcvdata = _read_population(data, ncells, ntimepoints, mod_ntimepoints)
    revdata = _read_population(data, ncells, ntimepoints, mod_ntimepoints)
    pyvdata = _read_population(data, ncells, ntimepoints, mod_ntimepoints)
    invdata = _read_population(data, ncells, ntimepoints, mod_ntimepoints)
# Plot at most the first MAX_PLOT_POINTS time points of each displayed
# population (the IN population is read but not plotted).
MAX_PLOT_POINTS = 30000
pyvdata = pyvdata[:MAX_PLOT_POINTS, :]
#invdata = invdata[:MAX_PLOT_POINTS, :]
tcvdata = tcvdata[:MAX_PLOT_POINTS, :]
revdata = revdata[:MAX_PLOT_POINTS, :]
# To max-pool along time, set compress_factor > 1 and enable these lines:
#pyvdata = compress_array(pyvdata,pyvdata,compress_factor)
#invdata = compress_array(invdata,invdata,compress_factor)
#revdata = compress_array(revdata,revdata,compress_factor)
#tcvdata = compress_array(tcvdata,tcvdata,compress_factor)
# pcolormesh expects rows = cells (y) and columns = time points (x).
pyvdata = pyvdata.T
#invdata = invdata.T
revdata = revdata.T
tcvdata = tcvdata.T
# Build the axes from the data that is actually plotted, so their lengths
# always match the array columns and rows even after truncation.
nplotted = pyvdata.shape[1]
# NOTE(review): 10000 samples per second of simulated time assumed
# (0.1 ms step) -- confirm against the simulation settings.
# The float literal avoids Python 2 integer division truncating the time span.
xaxis = np.linspace(0, nplotted * compress_factor / 10000.0, num=nplotted)
yaxis = np.linspace(1, ncells, num=ncells)
# 6.7 x 4.47 inch figure; for the 60x40 inch version use figsize=[60,40].
fig = plt.figure(figsize=[6.7,4.47])
plt.suptitle("Raster Plot of Simulation with Baseline Parameters (Spindle Oscillation)",fontsize=tf)
# Three stacked panels in the wide left column; the narrow right column
# (2% of the width) holds a colour bar shared by every panel.
gs = gridspec.GridSpec(3,2,width_ratios=[98,2])
# Limits follow the axes actually passed to pcolormesh, so the panels stay
# correct if the time window or the number of cells changes.
plot_limits = [xaxis[0], xaxis[-1], yaxis[0], yaxis[-1]]
# (row of the grid, voltage data, panel title); the IN panel is disabled.
panels = [
    (pyvdata, "Cortical Pyramidal Neurons (PY)"),
    (revdata, "Thalamic Reticular Nucleus Neurons (RE)"),
    (tcvdata, "Thalamocortical Neurons (TC)"),
]
for row, (vdata, title) in enumerate(panels):
    plt.subplot(gs[row, 0])
    # Identical cmap and -100..0 mV range on every panel, so one colour bar
    # applies to all of them.
    heatmap = plt.pcolormesh(xaxis,yaxis,vdata,cmap=mpl.cm.jet,vmin=-100,vmax=0)
    plt.axis(plot_limits)
    plt.title(title, fontsize=stf, fontweight='bold')
    plt.ylabel("Neuron Index",fontsize=af)
    plt.tick_params(labelsize=atf)
# Only the bottom (current) panel carries the time label.
plt.xlabel("Time (s)",fontsize=af)
# First layout pass; the top 3% of the figure is reserved for the suptitle.
gs.tight_layout(fig,rect=[0,0,1,0.97],h_pad=0.1)
axes = plt.subplot(gs[:,1])
# Pass the mappable explicitly instead of relying on the "current image".
cb = plt.colorbar(heatmap, cax=axes)
cb.set_label("mV",fontsize=satf,labelpad=-1)
cb.ax.tick_params(labelsize=satf)
# Second layout pass, now that the colour-bar axes exist.
gs.tight_layout(fig,rect=[0,0,1,0.97],h_pad=0.1)
plt.savefig('figure 2b',dpi=1200)