import sys, os
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)),os.path.pardir))
from graphs.my_graph import *
import matplotlib.pylab as plt
from scipy.stats.stats import pearsonr
def cross_correl_plot(data, FIGSIZE=(7,7), wspace=.5, hspace=.5, right=0.98, left=0.1,\
many_data=False):
"""
'data' should be an array of dictionaries with keys 'vec' and labels 'label'
"""
fig, AX = plt.subplots(len(data)-1, len(data)-1, figsize=FIGSIZE)
plt.subplots_adjust(wspace=wspace, hspace=hspace, right=right, left=left)
mymap = get_linear_colormap(color1='white', color2='gray')
significance = np.array([1e-3, 2e-2, 1e-1, 1])
for i in range(len(data)-1):
for j in range(i+1, len(data)):
AX[j-1,i].plot(data[i]['vec'], data[j]['vec'], 'ko')
if not many_data:
set_plot(AX[j-1,i], xlabel=data[i]['label'], ylabel=data[j]['label'],
num_xticks=4, num_yticks=3)
else:
if ((i==0) and (j==len(data)-1)):
set_plot(AX[j-1,i], xlabel=data[i]['label'], ylabel=data[j]['label'],
num_xticks=4, num_yticks=3)
elif (j==len(data)-1):
set_plot(AX[j-1,i], xlabel=data[i]['label'], yticks_labels=[],
num_xticks=4, num_yticks=3)
elif (i==0):
set_plot(AX[j-1,i], ylabel=data[j]['label'], xticks_labels=[],
num_xticks=4, num_yticks=3)
else:
set_plot(AX[j-1,i], xticks_labels=[], yticks_labels=[],
num_xticks=4, num_yticks=3)
cc, pp = pearsonr(data[i]['vec'], data[j]['vec'])
x = np.linspace(data[i]['vec'].min(), data[i]['vec'].max())
AX[j-1, i].plot(x,\
np.polyval(np.polyfit(np.array(data[i]['vec'], dtype='f8'),\
np.array(data[j]['vec'], dtype='f8'), 1), x),\
'k--', lw=.5)
ii = np.arange(len(significance))[significance-pp>=0][0]
color = -1.*(ii-len(significance)+1)/(len(significance)-1)
xmin, xmax = AX[j-1, i].get_xaxis().get_view_interval()
ymin, ymax = AX[j-1, i].get_yaxis().get_view_interval()
AX[j-1, i].add_patch(plt.Rectangle((xmin, ymin),\
xmax-xmin, ymax-ymin, color=mymap(1.*color,1)))
ax = plt.axes([.7,.7,.02,.2])
build_bar_legend(np.arange(len(significance)+1),\
ax, mymap,\
ticks_labels=['n.s.', '$<$0.1', '$<$0.02', '$<$0.001'],
label='Significance \n \n (p, Pearson correl.)')
for ax in AX.flatten():
if ax.get_xaxis().get_view_interval()[1]==1.:
ax.axis('off')
return fig
if __name__=='__main__':
data = []
import numpy as np
for i in range(5):
data.append({'vec':np.random.randn(10), 'label':'label'+str(i+1)})
cross_correl_plot(data, many_data=True)
plt.show()