#cp sim_mmns_savespikesonly.py sim_fI.py
#cp sim_mmns_savespikesonly.py sim_mmns_sep_savespikesonly.py #Add slowly excited population

from pylab import *
import scipy.io
import scipy.stats
import time
from os.path import exists
import mytools
import pickle
from matplotlib.collections import PatchCollection

# Sorted, de-duplicated list of the tau values for which fI result files exist.
# (unique() already returns its result in ascending order.)
#ls -l fI_tau*amp200.0.mat|cut -c 57-|cut -f1 -d_|while read F; do printf "$F, ";done
taus = unique([
    0.1, 0.2, 0.5, 1.0, 10.0, 10.0, 10.0, 10.25, 10.5, 10.75, 100.0, 1000.0, 10000.0,
    11.0, 11.25, 11.5, 11.75, 12.0, 12.25, 12.5, 12.75, 125.0,
    13.0, 13.25, 13.5, 13.75, 14.0, 14.25, 14.5, 14.75,
    15.0, 15.25, 15.5, 15.75, 150.0, 16.0, 16.5, 17.0, 17.5, 175.0,
    18.0, 18.5, 19.0, 19.5, 2.0, 2.25, 2.5, 2.75, 20.0, 200.0, 2000.0, 225.0,
    25.0, 250.0, 275.0, 3.0, 3.25, 3.5, 3.75, 300.0, 325.0, 350.0, 375.0,
    4.0, 4.25, 4.5, 4.75, 400.0, 425.0, 450.0, 475.0,
    5.0, 5.25, 5.5, 5.75, 50.0, 500.0, 5000.0, 6.0, 6.25, 6.5, 6.75, 600.0,
    7.0, 7.25, 7.5, 7.75, 700.0, 75.0, 8.0, 8.25, 8.5, 8.75, 800.0,
    9.0, 9.25, 9.5, 9.75, 900.0,
])

def mybar(ax,x,y,facecolor=None,linewidth=0.1,w=0.4):
  """Draw a box plot of the samples ``y`` centred at ``x`` on ``ax``.

  A filled box spans the 25th-75th percentiles and is ``2*w`` wide. A single
  black polyline draws horizontal bars at the minimum, median and maximum,
  joined by a vertical whisker.

  Args:
    ax: matplotlib Axes to draw on.
    x: horizontal centre of the box.
    y: sequence of sample values.
    facecolor: fill colour of the box. ``None`` (the default) or an empty
      list (the former default, still accepted) leaves the collection's
      default face colour untouched.
    linewidth: width of the min/median/max/whisker polyline. The box edge
      width is fixed at 0.3.
    w: half-width of the box and of the horizontal bars.

  Returns:
    ``[p, a2]``: the PatchCollection holding the box, and the list of lines
    returned by ``ax.plot`` for the bars/whisker.
  """
  qs = quantile(y, [0,0.25,0.5,0.75,1])  # min, Q1, median, Q3, max
  box = Polygon(array([[x-w,x+w,x+w,x-w],[qs[1],qs[1],qs[3],qs[3]]]).T)
  p = PatchCollection([box], cmap=matplotlib.cm.jet)
  # Only override the fill when a colour was actually supplied; None replaces
  # the old mutable ``[]`` default, and ``[]`` keeps meaning "not set".
  if facecolor is not None and not (isinstance(facecolor, list) and len(facecolor) == 0):
    p.set_facecolor(facecolor)
  p.set_edgecolor('#000000')
  p.set_linewidth(0.3)
  ax.add_collection(p)
  # One polyline: min bar -> whisker up to median bar -> whisker up to max bar.
  a2 = ax.plot([x-w,x+w,x,x,x-w,x+w,x,x,x-w,x+w],
               [qs[0],qs[0],qs[0],qs[2],qs[2],qs[2],qs[2],qs[4],qs[4],qs[4]],
               'k-',lw=linewidth)
  return [p,a2]
def boxoff(ax,whichxoff='top'):
  """Hide the ``whichxoff`` spine and the right spine of ``ax``.

  Also puts x ticks on the bottom and y ticks on the left only.

  Args:
    ax: matplotlib Axes, modified in place.
    whichxoff: key of the additional spine to hide (default ``'top'``).
  """
  for spine_name in (whichxoff, 'right'):
    ax.spines[spine_name].set_visible(False)
  xaxis = ax.get_xaxis()
  xaxis.tick_bottom()
  yaxis = ax.get_yaxis()
  yaxis.tick_left()


tau_controls = [10.0, 250.0]  # not referenced in the plotting code below

areas = ['PFC', 'ACC']
# x-values of the f-I curves; filled in from the loaded save files.
Is = []

# Colour tables indexed as: 0 = HC, 1 = SCZ in PFC, 2 = SCZ in ACC.
dimcols = ['#AAAAAA', '#FF88FF', '#99FF99']  # faint colours: per-subject curves, box fills
cols = ['#000000', '#770077', '#009900']     # strong colours: group-median curves

# 2x2 grid flattened row-major: [0], [1] = f-I panels; [2], [3] = AUC insets.
f, axs = subplots(2, 2)
axarr = axs.ravel().tolist()
for panel in axarr:
  panel.tick_params(axis='both', which='major', labelsize=4, direction='out', width=0.4, length=2)
  boxoff(panel)
for iarea in range(0,2):
  # Top-row axis: f-I curves for this area. Bottom-row axis: repositioned as a
  # small inset over the top panel, holding the HC-vs-SCZ AUC box plots.
  axarr[iarea].set_position([0.1+0.48*iarea,0.84,0.4,0.12])
  axarr[2+iarea].set_position([0.1+0.48*iarea+0.05,0.84+0.03,0.03,0.08])
  axarr[2+iarea].set_xticks([])

  area = areas[iarea]

  #From spineNg5d/drawfig2.py
  # 0-based subject indices (save file i holds subject i-1) of the HC and SCZ groups.
  if area == 'ACC':
    isubjs_HC = [1, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 29, 30, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 73, 75, 76, 77, 78, 79, 80, 83, 84, 86, 90, 91, 92, 95, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 124, 126, 127, 129, 130, 131, 132, 133, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 152, 154, 159, 161, 163, 164, 167, 169, 170, 172, 174, 175, 176, 177, 182, 199, 208, 235, 240, 243, 249, 253, 256, 257, 260, 261, 262, 263, 265, 266, 267, 268, 271, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 329, 330, 331, 332, 343, 349, 350, 352, 353, 355, 360, 366, 369, 370, 371, 372, 374, 375, 376, 378, 381, 382, 383, 385, 388, 389, 392, 393, 395, 396, 397, 398, 401, 402, 409, 410, 414, 416, 418, 421, 422, 426, 427, 431, 432, 433, 434, 435, 437, 438, 442, 443, 444, 446, 447, 448, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480]
    isubjs_SCZ = [0, 2, 3, 12, 17, 28, 31, 58, 69, 70, 71, 72, 74, 81, 82, 85, 87, 88, 89, 93, 94, 96, 97, 98, 99, 100, 116, 117, 118, 119, 120, 121, 122, 123, 125, 134, 151, 155, 156, 157, 158, 160, 162, 165, 166, 168, 171, 173, 178, 179, 180, 181, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 200, 201, 202, 203, 204, 205, 206, 207, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 236, 237, 238, 239, 241, 242, 244, 245, 246, 247, 248, 250, 251, 252, 254, 255, 258, 259, 264, 269, 270, 272, 273, 274, 275, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 344, 345, 346, 347, 348, 351, 354, 356, 357, 358, 359, 361, 362, 363, 364, 365, 367, 368, 373, 377, 379, 380, 384, 386, 387, 390, 391, 394, 399, 400, 403, 404, 405, 406, 407, 411, 412, 413, 415, 417, 419, 420, 423, 424, 425, 428, 429, 430, 436, 439, 440, 441, 445, 449, 450, 451, 452]
  elif area == 'PFC':
    isubjs_HC = [0, 1, 2, 3, 5, 7, 8, 9, 10, 11, 12, 13, 14, 16, 18, 19, 20, 22, 23, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 39, 41, 42, 43, 47, 51, 52, 53, 54, 63, 64, 66, 69, 73, 80, 81, 82, 83, 84, 87, 88, 89, 90, 92, 93, 94, 96, 97, 100, 105, 107, 111, 112, 118, 119, 120, 125, 127, 128, 135, 139, 140, 142, 143, 145, 146, 152, 154, 156, 165, 168, 176, 181, 182, 187, 192, 194, 201, 203, 204, 205, 207, 210, 211, 217, 220, 222, 223, 225, 226, 232, 235, 236, 240, 241, 242, 243, 244, 247, 248, 249, 251, 252, 253, 254, 257, 258, 262, 263, 265, 266, 267, 268, 269, 270, 271, 272, 283, 290, 291, 293, 294, 295, 296, 297, 298, 299, 300, 301, 305, 306, 308, 313, 314, 315, 317, 319, 320, 321, 323, 324, 325, 326, 329, 330, 332, 333, 335, 336, 339, 340, 341, 342, 343, 344, 345, 347, 351, 352, 354, 355, 356, 359, 361, 362, 368, 371, 373, 374, 375, 376, 377, 380, 381, 382, 383, 384, 386, 387, 388, 389, 390, 393, 396, 397, 398, 399, 400, 402, 403, 405, 406, 407, 408, 415, 416, 417, 420, 421, 422, 423, 424, 425, 426, 427, 428]
    isubjs_SCZ = [4, 6, 15, 17, 21, 24, 26, 29, 30, 38, 40, 44, 45, 46, 48, 49, 50, 55, 56, 57, 58, 59, 60, 61, 62, 65, 67, 68, 70, 71, 72, 74, 75, 76, 77, 78, 79, 85, 86, 91, 95, 98, 99, 101, 102, 103, 104, 106, 108, 109, 110, 113, 114, 115, 116, 117, 121, 122, 123, 124, 126, 129, 130, 131, 132, 133, 134, 136, 137, 138, 141, 144, 147, 148, 149, 150, 151, 153, 155, 157, 158, 159, 160, 161, 162, 163, 164, 166, 167, 169, 170, 171, 172, 173, 174, 175, 177, 178, 179, 180, 183, 184, 185, 186, 188, 189, 190, 191, 193, 195, 196, 197, 198, 199, 200, 202, 206, 208, 209, 212, 213, 214, 215, 216, 218, 219, 221, 224, 227, 228, 229, 230, 231, 233, 234, 237, 238, 239, 245, 246, 250, 255, 256, 259, 260, 261, 264, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 285, 286, 287, 288, 289, 292, 302, 303, 304, 307, 309, 310, 311, 312, 316, 318, 322, 327, 328, 334, 337, 338, 346, 348, 349, 350, 353, 357, 358, 360, 363, 364, 365, 366, 367, 369, 370, 372, 378, 379, 385, 392, 394, 395, 401, 404, 409, 410, 411, 412, 413, 414, 418, 419]

  # Set for O(1) membership tests in the per-subject plotting loop.
  isubjs_SCZ_set = set(isubjs_SCZ)

  # All subjects are loaded, so spFreqs[i] / AUCdata[i] belong to subject i.
  isubjs = list(range(0,481 if area == 'ACC' else 429))
  AUCdata = []
  spFreqs = []
  for isubj in isubjs:
    filename = 'hay/saves_'+area+'/'+area+'_patientID_'+str(isubj+1)+'.sav'
    print('Loading '+filename)
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    # encoding='bytes' lets Python 3 read 8-bit strings pickled by Python 2.
    with open(filename, 'rb') as unpicklefile:
      unpickledlist = pickle.load(unpicklefile,encoding='bytes')
    # unpickledlist[0]: spike frequencies (row 0 is plotted against Is);
    # unpickledlist[-1]: the Is grid, expected to be identical in every file.
    newIs = unpickledlist[-1]
    # Compare the full grids: the old element-wise check ranged over
    # len(unpickledlist[0]) instead of the grid length, so it compared the
    # wrong number of entries and could raise IndexError.
    if len(Is) > 0 and not array_equal(Is, newIs):
      print('Mismatch len(is)')
    Is = newIs
    # `sum` is numpy's (via the pylab star import), so this sums every value.
    # The 0.1 factor is presumably the Is step size -- TODO confirm.
    AUCdata.append(sum(unpickledlist[0])*0.1)
    spFreqs.append(unpickledlist[0])

  # Thin per-subject curves (SCZ in the area colour, everyone else grey),
  # then thick group-median curves on top.
  for isamp, freqs in enumerate(spFreqs):
    subjcol = dimcols[iarea+1] if isamp in isubjs_SCZ_set else dimcols[0]
    axarr[iarea].plot(Is,freqs[0],'-',lw=0.1,color=subjcol)
  axarr[iarea].plot(Is,median(array([spFreqs[isamp][0] for isamp in isubjs_HC]),axis=0),'-',color=cols[0],lw=1.0)
  axarr[iarea].plot(Is,median(array([spFreqs[isamp][0] for isamp in isubjs_SCZ]),axis=0),'-',color=cols[iarea+1],lw=1.0)

  AUC_HC = [AUCdata[isamp] for isamp in isubjs_HC]
  AUC_SCZ = [AUCdata[isamp] for isamp in isubjs_SCZ]
  mybar(axarr[2+iarea],0.5,AUC_HC,facecolor=dimcols[0])
  mybar(axarr[2+iarea],1.5,AUC_SCZ,facecolor=dimcols[1+iarea])
  # Wilcoxon rank-sum test HC vs SCZ; threshold 0.05/2, presumably a
  # Bonferroni correction for the two areas.
  pval = scipy.stats.ranksums(AUC_HC,AUC_SCZ)[1]
  print('pval = '+str(pval))
  yshift = 6 if iarea == 0 else 0  # the PFC inset uses a lower y-range
  if pval < 0.05/2:
    axarr[2+iarea].plot([0.5,0.5,1.5,1.5],[x-yshift for x in [21,22,22,21]],'k-',lw=0.1)
    axarr[2+iarea].text(1,22-yshift,'*',ha='center',va='bottom',fontsize=6)
  axarr[2+iarea].set_ylim([0,24-yshift])
  
# Put panel letters 'A' and 'B' just left of the top-left corner of the two
# f-I panels, then write the figure to disk.
panel_letters = 'AB'
for iax in range(len(panel_letters)):
  pos = axarr[iax].get_position()
  x_label = pos.x0 - 0.05
  y_label = pos.y1 - 0.01
  f.text(x_label, y_label, panel_letters[iax], fontsize=11)

f.savefig("fig_SCZ_allcortical_a.pdf")