



import os
from math import ceil
from sys import argv

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Usage: <script> fname zindex [NAME=value ...]
if len(argv) < 3:
	# SystemExit with a message prints it to stderr and exits with status 1, so a
	# shell/caller sees the failure (quit() exited with status 0 and relies on the
	# `site` module being loaded).
	raise SystemExit('expects fname, zindex')
# Defaults for the tunable options. Each can be replaced from the command line
# with a NAME=value argument (see the override loop further down).
SKIP = 0          # number of leading y-rows to drop from the heatmap
MASK = 0          # if set, build a mask from |z| vs NSIG * q for sns.heatmap
NSIG = 2          # threshold multiplier used by MASK
SCALE = 0         # if set, multiply each z by its own q
NULLSCALE = 0     # if set (and SCALE is not), multiply every z by the first q
CENTER = 0        # `center` argument of sns.heatmap
PRINT = 0         # if set, also save the figure as EPS to outname

# Data-file columns: x, y, z and the companion value q.
fname = argv[1]
col1, col2 = 1, 0
col3 = 7          # placeholder, replaced by the zindex argument below
col4 = col3 + 1   # NOTE(review): derived from the placeholder (so 8), not from argv[2];
                  # only MASK=1 later re-derives it from the real col3
col3 = int(argv[2])

# Not read anywhere below in this script.
xpr = ypr = 3
remove_cont = 0
MAXROWS = 50

# Derive default axis names from the data-file name: the basename is split on
# '_' and fields 1 and 3 become the x and y names (so a name shaped like
# <prefix>_<x>_vs_<y>_... yields x and y). NOTE(review): an older comment
# described the format as a_vs_b_prefix.dat, which does not match these indices.
temp = os.path.basename(fname).split('_')

output = 'temp.eps'  # currently unused

if len(temp) > 3:
	xlabel = temp[1]
	ylabel = temp[3]
else:
	xlabel = 'xlabel'
	ylabel = 'ylabel'

# Output name: the data-file path minus its final extension, plus the z column.
# os.path.splitext only strips the last extension, so a dot in a directory name
# (e.g. ./runs/a.dat) no longer truncates the path to ''.
outname = os.path.splitext(fname)[0]

outname += '_%d' % col3

title = ''
zlabel = ''

ysc = 1
xsc = 1
offset = 0

# Command-line overrides: every NAME=value argument after the two positional
# arguments (fname, zindex) is applied to this module's globals and appended to
# outname. Positions 0-2 are skipped so a data path containing '=' is never run.
# SECURITY NOTE: the argument is passed to exec, so arbitrary Python can run from
# the command line; tolerable only because the invoking user controls argv.
# NOTE: a bare value that names an existing variable is evaluated rather than
# stored as text (e.g. xlabel=offset sets xlabel to 0).
for arg in argv[3:]:
	if '=' not in arg:
		continue
	key, value = arg.split('=', 1)  # keep any further '=' inside the value
	outname += '_%s=%s' % (key, value)
	try:
		# Numbers and Python expressions (vmx=5, ysc=1000) become real values.
		exec(arg)
	except Exception:
		# Not valid Python (e.g. xlabel=iamp): keep the raw text. Assigning via
		# globals() avoids re-quoting, which broke on values containing a quote.
		globals()[key] = value

# With MASK on, q is read from the column right after the requested z column.
if MASK:
	col4 = col3 + 1

outname += '.eps'

print(outname)
# Turn short axis names (from the data-file name or an xlabel=/ylabel= override)
# into readable axis titles, and set the factor the tick values are multiplied
# by (xsc / ysc). Names not listed keep their text and their current factor.
X_AXIS_PRESETS = (
	('iamp', 1000, 'Applied Current (pA)'),
	('kchip', 1, 'KChIP4a Expression'),
	('offset', 1, 'Pulse Offset (ms)'),
	('fon', 1, 'Avg. Pulse IPSC Freq. (hz)'),
)
Y_AXIS_PRESETS = (
	('fchip', 1, 'Fast Kv4 Expression'),
	('i2amp', 1000, 'Pulse Amplitude (pA)'),
	('slow', 20, 'Peak I1-I2 Rate (1/s)'),
	('won', 1, 'Pulse Duration (ms)'),
)

for _short, _factor, _pretty in X_AXIS_PRESETS:
	if xlabel == _short:
		xsc, xlabel = _factor, _pretty
		break

for _short, _factor, _pretty in Y_AXIS_PRESETS:
	if ylabel == _short:
		ysc, ylabel = _factor, _pretty
		break

# Disabled per-z-column presets (kept for reference; not executed):
#   col3 == 6:  vmx = 50
#   col3 == 9:  vmn, vmx = 0, 1;  zlabel = 'Pause Accuracy'
#   col3 == 7:  vmn, vmx = 0, 6;  zlabel = 'Average Freq. (hz)'
#   col3 == 13: vmn = 0  (vmx = 3000, zlabel = 'Rebound Delay (ms)' were commented out)
#   col3 == 2 / 3: no-ops (vmx = 5 / vmx = 1 were commented out)

# Load x (col1), y (col2), z (col3) and the companion column q (col4).
xd, yd, zd, qd = np.loadtxt(fname, usecols=(col1, col2, col3, col4), unpack=True)

# Optional rescaling of z: SCALE multiplies each z by its own q, NULLSCALE
# multiplies every z by the first q. SCALE wins when both are set.
if SCALE:
	zd *= qd
elif NULLSCALE:
	zd *= qd[0]

if MASK:
	# m1 is 0.0 where |z| > NSIG * q and 1.0 otherwise (a NaN comparison is False,
	# so NaN cells also get 1.0, exactly as the element-wise loop this replaces).
	# After reshaping it is passed to sns.heatmap as `mask`; seaborn hides cells
	# where the mask is truthy.
	m1 = np.where(np.abs(zd) > NSIG * qd, 0.0, 1.0)
# Infer the grid shape. Rows are x-major: the first ny rows share the first x
# value, the next ny rows share the second, and so on.
startx = xd[0]
changes = np.nonzero(xd != startx)[0]
if len(changes):
	ny = int(changes[0])
	xscale = abs(xd[ny] - startx)  # x grid step (only for debugging)
else:
	ny = len(xd)  # every row shares one x value: a single column
	xscale = 1
# y grid step (only for debugging); guarded so single-row data does not crash.
yscale = yd[1] - yd[0] if len(yd) > 1 else 1

# Floor division keeps nx an int under Python 2 and 3 (np.resize rejects a
# float shape).
nx = len(xd) // ny
if nx * ny != len(xd):
	# np.resize would silently repeat or drop values to fill (nx, ny).
	print('warning: %d rows do not form a complete grid of %d-row columns' % (len(xd), ny))
#print(nx, xscale, ny, yscale)

# Disabled: for col3 == 9, flip z to 1 - z wherever q > 0.
# Reshape the flat z column into an image. Rows are x-major (ny consecutive
# rows share one x), so resize to (nx, ny) and transpose: rows = y, columns = x.
zp = np.resize(zd, (nx, ny)).transpose()

if MASK:
	mask = np.resize(m1, (nx, ny)).transpose()

print(MASK)

# Keep the full grid height: the flat xd/yd arrays are still indexed with it.
ny_full = ny
skipped = 0
if SKIP > 0:
	skipped = SKIP
	# Drop the first SKIP y-rows from the image AND from the mask, so the mask
	# still matches zp's shape when handed to sns.heatmap.
	zp = np.delete(zp, range(SKIP), 0)
	if MASK:
		mask = np.delete(mask, range(SKIP), 0)
	ny -= SKIP

# Colour limits default to the data range unless vmx / vmn were set on the
# command line.
try:
	vmx
except NameError:
	vmx = max(zd)
	print(vmx)
try:
	vmn
except NameError:
	vmn = min(zd)

# Debug output: first two x grid values, first two y values, grid size.
# Slicing instead of indexing avoids an IndexError for a single x column/y row.
print(' '.join(str(v) for v in xd[:2 * ny_full:ny_full]))
print(' '.join(str(v) for v in yd[:2]))
print('%s %s' % (nx, ny))

# Five evenly spaced ticks at cell centres. Each tick is mapped back to the data
# value of its cell: column c starts at flat index c * ny_full in xd, and row r
# of the (possibly trimmed) image has y value yd[r + skipped].
xt = np.linspace(0.5, nx - 0.5, 5)
yt = np.linspace(0.5, ny - 0.5, 5)

xp = [xsc * xd[int(pos - 0.5) * ny_full] for pos in xt]
yp = [ysc * yd[int(pos - 0.5) + skipped] for pos in yt]

print('%s %s' % (xp, yp))

# Without MASK nothing is masked.
mask = mask if MASK else None

heatmap_options = dict(
	linewidth=0,
	vmin=vmn,
	vmax=vmx,
	xticklabels=xp,
	yticklabels=yp,
	cbar_kws={'label': zlabel},
	cmap="PiYG",
	center=CENTER,
	mask=mask,
)
ax = sns.heatmap(zp, **heatmap_options)

# Move the ticks to the precomputed cell-centre positions and label the axes.
ax.set_xticks(xt)
ax.set_yticks(yt)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_title(title)

colorbar = ax.collections[0].colorbar
colorbar.ax.tick_params(labelsize=15)
colorbar.set_label(zlabel, size=20)

# Enlarge the title/axis labels, then the tick labels.
for text in (ax.title, ax.xaxis.label, ax.yaxis.label):
	text.set_fontsize(20)
for text in ax.get_xticklabels() + ax.get_yticklabels():
	text.set_fontsize(15)

# Flip the y axis (sns.heatmap draws the first row at the top).
ax.invert_yaxis()
plt.subplots_adjust(left=0.2, bottom=0.3, right=0.8, top=0.9)
if PRINT:
	plt.savefig(outname, format='eps')
plt.show()
