



import re
from math import ceil
from sys import argv

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Command line:  fname nx [col3] [name=value ...]
#   fname  data file to load with np.loadtxt
#   nx     number of heatmap columns (the file must hold nx equal blocks of rows)
#   col3   optional index of the column plotted as colour (default 7)
if len(argv) < 3:
    print('expects fname, nx[, col3] [name=value ...]')
    raise SystemExit(1)

SKIP = 0        # number of leading plot rows to drop
fname = argv[1]
col1 = 0        # column whose values label the heatmap columns (x axis)
col2 = 1        # column whose values label the heatmap rows (y axis)
col3 = 7        # column plotted as colour (overridable, see below)
xpr = 2         # decimal places for x tick labels (0 -> integer labels)
ypr = 2         # decimal places for y tick labels
nx = int(argv[2])
remove_cont = 0  # 0: keep all points; 1 / -1: blank points above / below EDGE
# col3 is optional: previously argv[3] was read unconditionally, which crashed
# with IndexError when only fname and nx were supplied.  A name=value argument
# in this slot is left for the override loop instead of being parsed as int.
if len(argv) > 3 and '=' not in argv[3]:
    col3 = int(argv[3])
MAXROWS = 50

# Guess axis labels from the file name, expected to look like
# "<x>_vs_<y>_<prefix>.dat"; shorter names get placeholder labels.
name_parts = fname.split('/')[-1].split('_')

print(name_parts)

has_names = len(name_parts) > 2
xlabel = name_parts[0] if has_names else 'xlabel'
ylabel = name_parts[2] if has_names else 'ylabel'

title = ''
zlabel = ''

ysc = 1      # multiplier applied to y tick-label values
xsc = 1      # multiplier applied to x tick-label values
offset = 0   # added to every plotted value
EDGE = 4000  # threshold on column 4 used when remove_cont != 0

# Apply "name=value" overrides from the command line (e.g. `vmx=20`,
# `SKIP=3`, `title=My Plot`).  The value is evaluated as a Python expression
# so numbers and simple arithmetic work; if evaluation fails it is stored as a
# plain string.  Assigning through globals() (instead of exec'ing a
# re-quoted string) keeps values containing quotes or '=' intact, and
# arguments whose left side is not an identifier (e.g. a path containing
# '=') are ignored instead of crashing the script.
# NOTE(review): eval() runs arbitrary code taken from the command line; that
# is acceptable for a local plotting script but never pass untrusted input.
for arg in argv[1:]:
    key, sep, value = arg.partition('=')
    key = key.strip()
    if not sep or not re.match(r'[A-Za-z_]\w*$', key):
        continue
    try:
        globals()[key] = eval(value.strip())
    except Exception:
        globals()[key] = value

# Friendly axis labels and tick-value scaling for known file-name tokens.
if xlabel == 'iamp':
    xlabel, xsc = 'Applied Current (pA)', 1000

if ylabel == 'slow':
    ylabel, ysc = 'Peak I1-I2 Rate (1/s)', 20

# Colour-scale presets keyed on the plotted column.  They are applied after
# the name=value overrides, so for these columns they replace any vmn, vmx,
# offset, remove_cont or zlabel given on the command line.
if col3 == 2:
    vmn, vmx = 0, 10
    zlabel = 'Burst Duration (Spikes)'
elif col3 == 3:
    vmn = 0
    remove_cont = -1  # vmx is deliberately not preset for this column
    zlabel = 'Peak Freq.(Downramp) (Hz)'
elif col3 == 4:
    vmn, vmx, offset = 0, 4000, -2000
    zlabel = 'Time of Last Spike (ms)'
elif col3 == 5:
    vmn, vmx = 0, 1
    zlabel = 'Spike Amp. Ratio (Last/First)'
elif col3 == 8:
    vmn, vmx = -50, -30
    zlabel = 'Dep. Block Voltage (mV)'


# Load the sweep: col1/col2 give the tick-label values, col3 the colour values,
# and column 4 is the quantity tested against EDGE by the remove_cont mask.
yd, xd, zd, qd = np.loadtxt(fname, usecols=(col1, col2, col3, 4), unpack=True)
# nx must split the rows into equal blocks (nx <= 0 would otherwise raise
# ZeroDivisionError below).
if nx <= 0 or len(zd) % nx != 0:
    print('problem with data length')
    raise SystemExit(1)

zd = zd + offset
# Optionally blank points by their column-4 value: remove_cont == 1 zeroes
# points with qd > EDGE, remove_cont == -1 zeroes points with qd < EDGE.
if remove_cont == 1:
    zd[qd > EDGE] = 0
elif remove_cont == -1:
    zd[qd < EDGE] = 0

# The file holds nx consecutive blocks of ny rows; after the transpose each
# block is one heatmap column.  Floor division keeps ny an int on Python 3,
# and reshape (unlike np.resize) can never silently repeat or drop data.
ny = len(zd) // nx
zp = zd.reshape(nx, ny).T

if SKIP > 0:
    # Drop the first SKIP rows of the plot.
    zp = zp[SKIP:]
    ny -= SKIP

# Colour limits not fixed by a preset or a name=value override fall back to
# the range of the data actually plotted: zp (after SKIP rows are dropped) so
# hidden rows cannot stretch the scale, ignoring NaNs.
try:
    vmx
except NameError:
    vmx = np.nanmax(zp)
try:
    vmn
except NameError:
    vmn = np.nanmin(zp)



# 11 evenly spaced (possibly fractional) tick positions per axis; each label
# is taken from the heatmap cell containing its position.
xt = np.linspace(0, nx - 1, 11)
yt = np.linspace(0, ny - 1, 11)

# Rows per block in the raw (unskipped) arrays.  `ny` may already have been
# reduced by SKIP, so it must not be used as the stride into yd/xd.
stride = len(yd) // nx


def _tick_label(value, places):
    """Round *value* to *places* decimals; with places == 0 return an int.

    int(round(...)) rather than int(...) so float noise such as
    28.999999999999996 becomes 29, not 28.
    """
    if places == 0:
        return int(round(value))
    return round(value, places)


# Row labels: plot row r is raw row r + SKIP because the first SKIP rows were
# dropped.  Values come from the first block (presumably col2 repeats
# identically in every block, as the original labelling assumed).
yp = [_tick_label(ysc * xd[int(pos) + SKIP], ypr) for pos in yt]
# Column labels: the col1 value at the start of the block shown in column c.
xp = [_tick_label(xsc * yd[int(pos) * stride], xpr) for pos in xt]

ax = sns.heatmap(zp, linewidths=0, vmin=vmn, vmax=vmx,
                 xticklabels=xp, yticklabels=yp,
                 cbar_kws={'label': zlabel})
ax.set_xticks(xt)
ax.set_yticks(yt)
ax.set_xlabel(xlabel, fontsize=20)
ax.set_ylabel(ylabel, fontsize=20)
ax.set_title(title, fontsize=20)
ax.tick_params(labelsize=15)

cb = ax.collections[0].colorbar
cb.ax.tick_params(labelsize=15)
cb.set_label(zlabel, size=20)

plt.show()
