######################################################################################
# OS_Figure1 -- Reads and analyzes the results generated from OS_run.py
# and plots Figure 1 in the following:
#
# Reference: Sadeh and Rotter 2015.
# "Orientation selectivity in inhibition-dominated networks of spiking neurons:
# effect of single neuron properties and network dynamics" PLOS Computational Biology.
#
# Author: Sadra Sadeh <s.sadeh@ucl.ac.uk> // Created: 2014-2015
######################################################################################

from imp import reload
import OS_params; reload(OS_params); from OS_params import *

import OS_functions; reload(OS_functions); from OS_functions import *

from mpl_toolkits.axes_grid.inset_locator import inset_axes

cwd = code_path

################################################################################
################################################################################

## for the first time you run the file, you need to make these data (xxx_make = 1)
## for the subsequent runs, you can turn the flag off, the data is now loaded
spike_data_make = 1#0
isi_cv_make = 1#0

#################################

## simulation folder for the results
sim_folder = 'N-5000_pif_delayType-random_g-4'

## figure name
fig_name = 'Figure1'

################################################################################
################################################################################

ti = time.time()   

#######################################
### reading simulation params & results

os.chdir(res_path+sim_folder)

fl = open('info', 'rb')
infor = cPickle.load(fl)
fl.close()

po_init = infor['po_init']
inputs = [infor['b']]

print(infor)

# params
N = infor['N']
NE = int(.8*N)
NI = N - NE

stim_range = infor['stim_range'] 
stim_range_deg = stim_range*180./np.pi 
simtime = infor['simtime']
t_trans = infor['t_trans']
tr_time = simtime - t_trans

CE = infor['eps_exc']
CI = infor['eps_inh']
f = infor['exc_inh']
g = infor['g']

tau_m = infor['tauMem']/1000.
V_th = infor['theta']
V_r = 0.
t_ref = 2./1000

Js = infor['J_ffw']
J_ext = infor['J_ext']
sb = infor['b']
se = infor['e']
J = infor['J'][0]
m = infor['m'] 
sm = m*sb
trial_no = infor['trial_no']
tr_no = infor['trial_no']
contrast = infor['contrast']

n_smpl = infor['n_smpl']

fl = open('results', 'rb')
results = cPickle.load(fl)
fl.close()

# results
tc_trans = results['tc_trans']
tc = results['tc']

tc_mean = np.mean(tc, 1)
tc_std = np.std(tc, 1)

tc_f0 = results['tc_f0']
tc_f1 = results['tc_f1']

vm_tc = results['vm_tc']

VM_fit = results['VM_fit']
err_fit = results['err_fit']
scs_fit = results['scs_fit']
TW_out = results['TW_out']
PO_out = results['PO_out']
OSI_out = results['OSI_out']

################################################################################
os.chdir(code_path)

################################################################################
### spike data all
################################################################################

## spont
os.chdir(res_path+sim_folder)

spd_all_sp = {}
for st in range(len(stim_range)):
    #spd_all[st] = []
    for tr in range(tr_no): 
        spd_all_sp[tr] = []       
        spflno = 2*N+ 1+tr + tr_no
        spike_file = 'spikes-all-spont-tr'+str(tr)+'-'+str(spflno)+'-0.gdf'
        spd_tr = pl.loadtxt(spike_file)
        #spd += spd_tr.tolist()
        spd_all_sp[tr] = spd_tr

os.chdir(cwd)	

if spike_data_make == 1:
    os.chdir(res_path+sim_folder)
    
    spd_all = {}
    for st in range(len(stim_range)):
        spd_all[st] = []
        for tr in range(tr_no):        
            spflno = 2*N + 1+tr + tr_no
            spike_file = 'spikes-all-st'+str(st)+'-tr'+str(tr)+'-'+str(spflno)+'-0.gdf'
            spd_tr = pl.loadtxt(spike_file)
            #spd += spd_tr.tolist()
            spd_all[st].append(spd_tr)
    
    os.chdir(cwd)	
    
    fl = open('spd_all', 'wb')
    cPickle.dump(spd_all, fl)
    fl.close()
else:
    os.chdir(cwd)
    fl = open('spd_all', 'rb')
    spd_all = cPickle.load(fl)
    fl.close()

os.chdir(cwd)

################################################################################
### ISI CV
################################################################################

if isi_cv_make == 1:
    print('### ISI CV ###')
    ### ISI CV
    #ISI, rr = {}, {}
    CV = {}
    for tr in range(trial_no):
        print('tr: '+str(tr))
        CV[tr] = []
        for i in range(1, N+1):
            spt = spd_all[4][tr][:,1][np.where(spd_all[4][tr][:,0] == i)]
            isi = np.diff(spt)
            cv = np.std(isi)/np.mean(isi)
            if not np.isnan(cv) and not np.isinf(cv) and len(spt) > 10:
               CV[tr].append(cv)
        CV[tr] = np.array(CV[tr])

    isi_cv = CV
    
    fl = open('isi_cv', 'wb')
    cPickle.dump(isi_cv, fl)
    fl.close()
else:
    fl = open('isi_cv', 'rb')
    isi_cv = cPickle.load(fl)
    fl.close()

################################################################################
################ raster plots for population
################################################################################

al = [.25, .5, 1.]
fg_lb_sz = 20

pl.figure(figsize= (15,9))

########## ########## ########## ########## ########## 
########## raster plot for the middle contrast (spont)
########## ########## ########## ########## ########## 

tr = 1

zz_sp = spd_all_sp[tr]
    
excid_sp = np.where(zz_sp[:,0] <= NE)[0]
inhid_sp = np.where(zz_sp[:,0] > NE)[0]

ax = pl.subplot(3, 3, 1)
pl.title(r'Spont. spiking activity, contrast C = 2', size=15) 

ax.text(-.15, 1., 'A1', size=fg_lb_sz, transform = ax.transAxes)

ax.plot(zz_sp[excid_sp, 1], zz_sp[excid_sp, 0], 'r.', ms=2)
ax.plot(zz_sp[inhid_sp, 1], zz_sp[inhid_sp, 0], 'b.', ms=2)
#
t1 = tr*(tr_time+t_trans) + 500
dt = 60
t2 = t1 + dt 

xtk = np.array([t1, t1+20, t1+40, t2])
ax.set_xticks(xtk)
ax.set_xticklabels(xtk - simtime)

ax.set_yticks([1, 1000, 2000, 3000, 4000, 5000])
ax.set_yticklabels([1, '', '', '', '4k', '5k'])

ax.set_xlim(t1-5, t2+5)
ax.set_ylim(0-100, N + 100)

#ax.set_xlabel('time (ms)')
ax.set_ylabel('Neuron #')

########## ########## ########## ########## ########## 
ax = pl.subplot(6, 3, 2*3+1)

binw = 5.
bins = simtime/binw

ax.hist(zz_sp[excid_sp, 1], bins = bins, weights=1000./binw*np.ones(len(zz_sp[excid_sp, 1]))/NE, color='r', alpha=.75)
ax.hist(zz_sp[inhid_sp, 1], bins = bins, histtype='step', weights=1000./binw*np.ones(len(zz_sp[inhid_sp, 1]))/NI, color='b', lw=2, alpha=1)

ax.set_xticks(xtk)
ax.set_xticklabels(xtk - simtime)
ax.set_xlim(t1-5, t2+5)

ax.set_ylim(0-1, 40+1)

ax.set_xlabel('Time (ms)')
ax.set_ylabel('Population rate \n (spikes/sec)', size=10)

ax.text(.7, .8, 'binw: '+str(int(binw))+' ms', transform = ax.transAxes)

########## ########## ########## ########## ########## 
########## raster plot for the middle contrast (evoked)
########## ########## ########## ########## ########## 

tr = 1

zz = spd_all[4][tr]
    
excid = np.where(zz[:,0] <= NE)[0]
inhid = np.where(zz[:,0] > NE)[0]

ax = pl.subplot(3, 3, 2)
pl.title(r'Evoked spiking activity, stim. orient. $\theta: 0^\circ$, C = 2', size=12.5) 

ax.text(-.15, 1., 'A2', size=fg_lb_sz, transform = ax.transAxes)

ax.plot(zz[excid, 1], zz[excid, 0], 'r.', ms=2)
ax.plot(zz[inhid, 1], zz[inhid, 0], 'b.', ms=2)

xtk = np.array([t1, t1+20, t1+40, t2])
ax.set_xticks(xtk)
ax.set_xticklabels(xtk - simtime)

ax.set_yticks([1, 1000, 2000, 3000, 4000, 5000])
ax.set_yticklabels([1, '', '', '', '4k', '5k'])

ax.set_xlim(t1-5, t2+5)
ax.set_ylim(0-100, N + 100)

#ax.set_xlabel('time (ms)')
ax.set_ylabel('Neuron #')

########## ########## ########## ########## ########## 
ax = pl.subplot(6, 3, 2*3+2)

binw = 5.
bins = simtime/binw

ax.hist(zz[excid, 1], bins = bins, weights=1000./binw*np.ones(len(zz[excid, 1]))/NE, color='r', alpha=.75)
ax.hist(zz[inhid, 1], bins = bins, histtype='step', weights=1000./binw*np.ones(len(zz[inhid, 1]))/NI, color='b', lw=2, alpha=1)

ax.set_xticks(xtk)
ax.set_xticklabels(xtk - simtime)

#ax.set_xlim(0-10, tr_time+10)
ax.set_xlim(t1-5, t2+5)

ax.set_ylim(0-1, 40+1)

ax.set_xlabel('Time (ms)')
ax.set_ylabel('Population rate \n (spikes/sec)', size=10)

ax.text(.7, .8, 'binw: '+str(int(binw))+' ms', transform = ax.transAxes)

########## ########## ########## ########## ########## 
########## sorted raster plots for all contrasts
########## ########## ########## ########## ########## 

for tr in range(trial_no):
    zz = spd_all[4][tr]
    
    excid = np.where(zz[:,0] <= NE)[0]
    inhid = np.where(zz[:,0] > NE)[0]

    #
    t1 = tr*(tr_time+t_trans) + 500
    dt = 60
    t2 = t1 + dt 

    ###
    ax = pl.subplot(6, 3, 3+3*tr)
    #pl.title(r'$C = $'+str(contrast[tr])) 
    ax.text(.05, 1.05, 'C = '+str(int(contrast[tr])), transform = ax.transAxes, fontsize=15)#, fontstyle='bold')
    if tr == 0: ax.text(-.1, 1., 'B', size=fg_lb_sz, transform = ax.transAxes)

    ax.plot(zz[excid, 1], po_init[(zz[excid, 0]-1).astype('int')], 'r.', ms=2.)
    ax.plot(zz[inhid, 1], po_init[(zz[inhid, 0]-1).astype('int')], 'b.', ms=2.)
    
    ax.set_xticks(xtk)
    ax.set_xticklabels([])#xtk - tr*(simtime - t_trans))
    ax.set_yticks([0, np.pi/4, np.pi/2, 3*np.pi/4, np.pi])
    ax.set_yticklabels([])

    ax.set_xlim(t1-5, t2+5)
    ax.set_ylim(0-.1, np.pi+.1)    

    if tr == 1: 
       ax.set_yticklabels([0, 45, 90, 135, 180])
       ax.set_ylabel('Input PO (deg)')

xtk = np.array([t1, t1+20, t1+40, t2])

ax.set_xticks(xtk)
ax.set_xticklabels(xtk - tr*(simtime))

ax.set_xlabel('Time (ms)')

########## ########## ########## ########## ########## 
########## dist. of firing rates
########## ########## ########## ########## ########## 

ax = pl.subplot(2, 4, 5)
#adjust_spines(ax,['left', 'bottom'], outward=0, s=0)
ax.text(-.2, .975, 'C', size=fg_lb_sz, transform = ax.transAxes)

for tr in range(trial_no):
    ax.hist(tc[4, tr, :], 50, histtype = 'step', color='k', lw=2, alpha=al[tr], label=str(int(contrast[tr])))

ax.set_yscale('log')

pl.legend(title = 'C', loc=2, frameon=False, prop={'size':12})

ax.set_ylim(0, 10000)
ax.set_xticklabels([])
ax.set_ylabel('#')

ax.set_xticks([0, 20, 40, 60, 80])
ax.set_xticklabels([0, 20, 40, 60, 80])

ax.set_xlabel('Firing rate (spikes/sec)')
#ia.set_ylabel('#')

########## dist. of ISI CV
########## ########## ########## ########## ########## 
### isi cv in the inset
########## ########## ########## ########## ########## 
#ax = pl.subplot(2, 4, 6)
ia = inset_axes(ax, width="30%", height="30%", loc=1)
#adjust_spines(ia,['left', 'bottom'], outward=0, s=0)

for tr in range(trial_no):
    ia.hist(isi_cv[tr], 50, histtype = 'step', color='k', lw=3, alpha=al[tr], label=str(int(contrast[tr])))

#pl.legend(title = 'C', )

ia.set_yticks([0, 100, 200, 300, 400, 500, 600])
ia.set_yticklabels([0, '', '', '', '', 500, ''])
ia.set_ylim(0-5, 600)

ia.set_xticks([0, .5, 1, 1.5, 2])
ia.set_xticklabels([0, '', 1, '', 2])

ia.set_xlim(0, 2.)

ia.set_xlabel('CV[ISI]')
ia.set_ylabel('#')


########## ########## ########## ########## ########## 
########## population tuning curves
########## ########## ########## ########## ########## 

#al = [.25, .5, 1.]
ax = pl.subplot(2, 4, 6)
#adjust_spines(ax,['left', 'bottom'], outward=0, s=0)
#tr = 1
ax.text(-.15, .975, 'D', size=fg_lb_sz, transform = ax.transAxes)

pl.title('Network output tuning curve')

ax.text(.1, .9, 'C = 2', size=15, transform = ax.transAxes)

ax.plot(po_init[0:NE], tc[4, tr, 0:NE], 'ro', ms=3.5, alpha=1, label='Exc.')
ax.plot(po_init[NE:], tc[4, tr, NE:], 'bo', ms=3.5, alpha=1, label='Inh.')

inp_mean = np.mean(tc[4,tr,:])
po_rng = np.arange(0, np.pi, .1)
ax.plot(po_rng, inp_mean*(1 + m*np.cos(2*(np.pi/2-po_rng))), 'g-', lw=2, label='Inp.')

pl.legend(loc=1, frameon=False, numpoints=1, markerscale=1.5, prop={'size':12.5}) 

ax.set_xlabel('Input PO (deg)')
ax.set_ylabel('Firing rate (spikes/sec)')

ax.set_xticks([0, np.pi/4, np.pi/2, 3*np.pi/4, np.pi])
ax.set_xticklabels([0, 45, 90, 135, 180])

ax.set_xlim(0-.1, np.pi+.1)
ax.set_ylim(0, 70+1)

########## ########## ########## ########## ########## 
########## mean population tuning curves
########## ########## ########## ########## ########## 

ax = pl.subplot(2, 4, 7)
#adjust_spines(ax,['left', 'bottom'], outward=0, s=0)
ax.text(-.15, .975, 'E', size=fg_lb_sz, transform = ax.transAxes)

pl.title('Avg. output tuning curve')

dxx = np.pi/100.
xx = np.arange(0, np.pi, dxx)

for tr in range(trial_no):
    zz = tc[4, tr, :]     
    zz_mean = np.zeros(len(xx))
    zz_std = np.zeros(len(xx))
    for ii, xi in enumerate(xx):
        ids = np.where( (po_init > xi) * (po_init < xi+dxx) == True)
        zz_mean[ii] = np.mean(zz[ids])
        zz_std[ii] = np.std(zz[ids])

    p, ft, tw, sc = OS_functions.vonMises(xx, zz_mean)
    print(tw)

    ax.plot(xx, zz_mean, 'k-', lw=2, alpha = al[tr], label=str(np.round(tw,1))+r'$^\circ$')
    ax.fill_between(xx, zz_mean - zz_std, zz_mean + zz_std, color='k', alpha=.1)

    inp_mean = np.mean(zz_mean)
    ax.plot(xx, inp_mean*(1 + m*np.cos(2*(np.pi/2-xx))), 'g-', lw=2, alpha = al[tr])#, label='Inp.')

pl.legend(frameon=False, title='output TW', prop={'size':12.5})

ax.set_xlabel('Input PO (deg)')

ax.set_xticks([0, np.pi/4, np.pi/2, 3*np.pi/4, np.pi])
ax.set_xticklabels([0, 45, 90, 135, 180])

ax.set_xlim(0-.1, np.pi+.1)
ax.set_ylim(0, 60+1)


########## ########## ########## ########## ########## 
########## linear prediction
########## ########## ########## ########## ########## 

os.chdir(code_path)

inv_make = 1

fl = open('con', 'rb')
con_loc = cPickle.load(fl)
fl.close()

con_exc, con_inh = np.array(con_loc['exc']), np.array(con_loc['inh'])

##############

print('### building weight matrix W ###')

Je = J
Ji = -g*J

W = np.zeros((N,N))
for ne in range(NE):    
    W[ne][con_exc[ne]-1] = Je 
    W[ne][con_inh[ne]-1] = Ji 
for ni in range(NE,N):    
    W[ni][con_exc[ni]-1] = Je 
    W[ni][con_inh[ni]-1] = Ji 

##############

if inv_make == 1:     
   print('### computing the inverse A = (1- W/Vth)^-1 ###')
   
   A = np.linalg.inv(np.eye(N) - W/V_th)
   
   fl = open('A', 'wb')
   cPickle.dump(A, fl)
   fl.close()

elif inv_read == 1:
   print('### reading the inverse A = (1- W/Vth)^-1 ###')
   
   fl = open('A', 'rb')
   A = cPickle.load(fl)
   fl.close()

################################################################################
print('### computing the linear prediction r = A s ###')
################################################################################

def _rect_(zz): return zz*(zz>0)

####
tc_out_L = []
for i in range(len(stim_range)):
    
    tc_out = []
    #tc_out_mod = []
    
    for ic, ct in enumerate(contrast):
        inp_b = ct * sb * Js/V_th
        inp_e = se * J_ext/V_th
        out_b = np.sum(A[0]* (inp_b+inp_e) )
        
        inp_m = ct * sm*np.cos(2*(stim_range[i]-po_init)) * Js/V_th      
        
        inp_tot = inp_b + inp_e + inp_m       
 
        r = np.array( (np.matrix(A) * np.matrix(inp_tot).T).T)[0]
        
        tc_out.append(r)    
    tc_out_L.append(tc_out)
tc_out_L = np.array(tc_out_L)

tc_out_LR = _rect_(tc_out_L)

######### plot it

ax = pl.subplot(2, 4, 8)
ax.text(-.15, .975, 'F', size=fg_lb_sz, transform = ax.transAxes)

pl.title('Predicting network responses')

tc_out_LR_rp = tc_out_LR/(1. + tc_out_LR*t_ref)

ax.plot(tc[4, 1, :], tc_out_LR_rp[4, 1, :], 'k.')#, ms=1)

ax.plot([0, 60], [0, 60], 'k--')

ax.set_xlabel('Simulated firing rate (spikes/sec)')
ax.set_ylabel('Linear prediction (spikes/sec)')

ia = inset_axes(ax, width="30%", height="30%", loc=2)
ia.yaxis.set_label_position('right')
ia.yaxis.tick_right()

for i, ct in enumerate(contrast):
    ia.hist(tc[4, i, :] - tc_out_LR_rp[4, i, :], bins=100, histtype='step', alpha=al[i], color='k')

ia.set_yscale('log')
ia.set_xticks([-10, -5, 0, 5, 10])
ia.set_xticklabels(['', -5, 0, 5, ''])
ia.set_xlim(-10, 10)
ia.set_ylabel('#')
ia.set_xlabel('Rate Diff')

#######################

pl.subplots_adjust(left=.05, right=.975, bottom=.075, top=.95, wspace=.25, hspace=.45)

pl.savefig(fig_name+'.png')


pl.show()

tf = time.time()   

print('time: ', np.round((tf - ti)/60., 1), ' min')