import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sys import argv

# Plot a heatmap of per-row column-mean differences read from one or two
# whitespace-delimited text files.
#
# Usage:
#   script FILE1          -> heatmap of mean(cols 0-4) - mean(cols 61-65)
#   script FILE1 FILE2    -> heatmap of (FILE2 result) - (FILE1 result)
#
# Each file must contain GRID_SHAPE[0] * GRID_SHAPE[1] rows (420 by default);
# row i becomes cell i of the grid, filled row-major.

# Columns averaged to form each row's baseline value.
BASE_COLS = (0, 1, 2, 3, 4)
# Columns averaged and subtracted from the baseline.
LATE_COLS = (61, 62, 63, 64, 65)
# Grid (rows, columns) the per-row differences are laid out on for plotting.
GRID_SHAPE = (21, 20)


def row_mean_difference(filename, base_cols=BASE_COLS, late_cols=LATE_COLS):
    """Return mean(base_cols) - mean(late_cols) for every row of *filename*.

    The file is parsed once with ``np.loadtxt``, reading only the requested
    columns.

    Args:
        filename: path to a whitespace-delimited numeric text file.
        base_cols: column indices averaged to form each row's baseline.
        late_cols: column indices averaged and subtracted from the baseline.

    Returns:
        1-D float array with one entry per row of the file.
    """
    base_cols = tuple(base_cols)
    late_cols = tuple(late_cols)
    # ndmin=2 keeps a single-row file 2-D so the axis=1 means still work.
    table = np.loadtxt(filename, usecols=base_cols + late_cols, ndmin=2)
    n_base = len(base_cols)
    baseline = table[:, :n_base].mean(axis=1)
    late = table[:, n_base:].mean(axis=1)
    return baseline - late


def to_grid(values, shape=GRID_SHAPE):
    """Reshape the 1-D *values* row-major into *shape*.

    Raises:
        ValueError: if the number of values does not fill *shape* exactly.
    """
    values = np.asarray(values)
    expected = int(np.prod(shape))
    if values.size != expected:
        raise ValueError(
            f"expected {expected} values to fill a grid of shape "
            f"{tuple(shape)}, got {values.size}"
        )
    return values.reshape(shape)


def main(args=None):
    """Load the file(s) named in *args* (default: command line) and plot them."""
    args = argv[1:] if args is None else list(args)
    if not args:
        raise SystemExit(f"usage: {argv[0]} FILE1 [FILE2]")

    diff = row_mean_difference(args[0])
    if len(args) > 1:
        # Comparison mode: change between the two files, on a symmetric scale.
        diff = row_mean_difference(args[1]) - diff
        vmin, vmax = -0.1, 0.1
    else:
        vmin, vmax = 0, 0.4

    plt.figure(figsize=(10, 8))
    cmap = sns.color_palette("Spectral_r", as_cmap=True)
    ax = sns.heatmap(to_grid(diff), cmap=cmap, vmin=vmin, vmax=vmax)
    # Draw grid row 0 at the bottom rather than seaborn's default top.
    ax.invert_yaxis()
    plt.show()


if __name__ == "__main__":
    main()

