######################################################################################
# OS_results -- Analyzes the results of simulations generated by OS_run.py
# Needed for plotting the figures, as in OS_Figure1.py
#
# Reference: Sadeh and Rotter 2015.
# "Orientation selectivity in inhibition-dominated networks of spiking neurons:
# effect of single neuron properties and network dynamics" PLOS Computational Biology.
#
# Author: Sadra Sadeh <s.sadeh@ucl.ac.uk> // Created: 2014-2015
######################################################################################

from imp import reload
import OS_params; reload(OS_params); from OS_params import *

import OS_functions; reload(OS_functions); from OS_functions import *

################################################################################
################################################################################

def _spike_rates(fname, duration_ms, N):
    """Return firing rates (spikes/s) of neurons 1..N read from one spike file.

    Column 0 of the file is taken as the neuron id and every row counts as one
    spike of that neuron; ids outside 1..N are ignored.

    Parameters
    ----------
    fname : str
        Spike file in the current working directory (read with np.loadtxt).
    duration_ms : float
        Length of the counting window in ms; counts are divided by
        ``duration_ms / 1000``.
    N : int
        Number of neurons.

    Returns
    -------
    numpy.ndarray
        Array of length N; entry k is the rate of neuron id k + 1.
    """
    # ndmin=2 keeps a file holding a single spike (one row) 2-D, so that
    # ``z[:, 0]`` is always valid.
    z = np.loadtxt(fname, ndmin=2)
    if z.size == 0:
        counts = np.zeros(N)
    else:
        # One bincount pass instead of N separate ``np.where`` scans.
        # NOTE(review): assumes non-negative integer ids (NEST gids) -- confirm.
        counts = np.bincount(z[:, 0].astype(int), minlength=N + 1)[1:N + 1]
    return counts / (duration_ms / 1000)


def _vm_stats(z, n_smpl):
    """Return per-sample means and histograms of column 2 of a Vm recording.

    Rows ``i, i + n_smpl, i + 2*n_smpl, ...`` of ``z`` are grouped together for
    each ``i`` in ``range(n_smpl)`` (presumably one recorded neuron each --
    confirm against the recorder setup in OS_run.py).

    Returns
    -------
    (list, list)
        Mean of each group, and ``np.histogram(..., bins=100, density=True)``
        of each group, both in group order.
    """
    means, hists = [], []
    for i in range(n_smpl):
        zz = z[i::n_smpl][:, 2]
        means.append(np.mean(zz))
        # ``density`` replaces ``normed``, which NumPy 1.24 removed; both give
        # the same result for the equal-width bins used here.
        hists.append(np.histogram(zz, bins=100, density=True))
    return means, hists


def sim_results(sim_folder):
    """Analyze the raw output of one simulation folder and pickle the results.

    Changes into ``res_path + sim_folder``, reads the ``info`` pickle and the
    spike / membrane-potential files written there, computes tuning curves and
    derived measures, and writes them as a dict to a pickle named ``results``
    in the same folder. The working directory is reset to ``code_path``
    afterwards, also when an error occurs.

    Keys of the pickled dict:
      'tc', 'tc_trans'   firing rates of shape (n_stim, trial_no, N) from the
                         'spikes-all-st*' and 'spikes-all-trans-st*' files
      'vm_tc'            per stimulus: mean Vm of the sampled exc. recordings,
                         followed by the inh. ones
      'vm_hist'          {stimulus index: list of np.histogram outputs}, same order
      'tc_f0', 'tc_f1'   outputs of OS_functions._fft_ per contrast and neuron
      'OSI_out', 'PO_out'  outputs of OS_functions._osv_ (PO divided by 2)
      'VM_fit', 'TW_out', 'scs_fit'  items 1, 2, 3 of OS_functions.vonMises
      'err_fit'          relative L2 error of the von Mises fit, in percent

    Parameters
    ----------
    sim_folder : str
        Folder name, appended to ``res_path``.
    """
    t_start = time.time()
    results = {}

    ### reading data
    os.chdir(res_path + sim_folder)
    try:
        with open('info', 'rb') as fl:
            info = cPickle.load(fl)

        stim_range = info['stim_range']
        trial_no = info['trial_no']
        N = info['N']
        simtime = info['simtime']
        t_trans = info['t_trans']
        n_smpl = info['n_smpl']
        contrast = info['contrast']

        # Offset used to build the numeric part of the data file names
        # (presumably the first recorder gid -- confirm against OS_run.py).
        id_no = 1

        ### spike data and tuning curves
        print('### spikes ###')

        tc_trans, tc = [], []
        for st in range(len(stim_range)):
            print('#### stim: ', str(st))
            t0 = time.time()

            tc_st_trans, tc_st = [], []
            for tr in range(trial_no):
                print(' # trial: ', str(tr))

                # transient part, rates normalized by t_trans
                yyy = str(2*N + id_no + tr)
                n0 = 'spikes-all-trans-st'+str(st)+'-tr'+str(tr)+'-'+yyy+'-0.gdf'
                tc_st_trans.append(_spike_rates(n0, t_trans, N))

                # stationary part, rates normalized by simtime - t_trans
                yyy = str(2*N + id_no + tr + trial_no)
                n0 = 'spikes-all-st'+str(st)+'-tr'+str(tr)+'-'+yyy+'-0.gdf'
                tc_st.append(_spike_rates(n0, simtime - t_trans, N))

            print('### %.2f s' % (time.time() - t0))

            tc.append(tc_st)
            tc_trans.append(tc_st_trans)

        tc = np.array(tc)
        results['tc'] = tc
        results['tc_trans'] = np.array(tc_trans)

        ### membrane potential data and tuning curves
        print('### mem. pot. ###')

        vm_data_no = 2*N + id_no + 2*trial_no
        vm_tc, vm_hist = [], {}
        for st in range(len(stim_range)):
            print('#### stim: ', str(st))

            ### stationary
            zex = np.loadtxt('vm-exc-st'+str(st)+'-'+str(vm_data_no)+'-0.dat')
            zin = np.loadtxt('vm-inh-st'+str(st)+'-'+str(vm_data_no + 1)+'-0.dat')

            means_ex, hists_ex = _vm_stats(zex, n_smpl)
            means_in, hists_in = _vm_stats(zin, n_smpl)
            vm_tc.append(means_ex + means_in)
            vm_hist[st] = hists_ex + hists_in

        results['vm_tc'] = np.array(vm_tc)
        results['vm_hist'] = vm_hist

        ### F0 and F2 components
        # NOTE(review): the heading says F2 but the second value is stored as
        # 'tc_f1'; check OS_functions._fft_ for which component it returns.
        # NOTE(review): ``ic`` indexes the trial axis of ``tc`` while looping
        # over ``contrast`` -- assumes one trial per contrast level.
        tc_f0, tc_f1 = [], []
        for ic in range(len(contrast)):
            tc_f0_tmp, tc_f1_tmp = [], []
            for i in range(N):
                f0, f1 = OS_functions._fft_(tc[:, ic, i])
                tc_f0_tmp.append(f0)
                tc_f1_tmp.append(f1)
            tc_f0.append(tc_f0_tmp)
            tc_f1.append(tc_f1_tmp)

        results['tc_f0'] = np.array(tc_f0)
        results['tc_f1'] = np.array(tc_f1)

        ### osi, vm fit and tw
        TW_out_tot, VM_fit_tot, scs_fit_tot, err_fit_tot = [], [], [], []
        PO_out_tot, OSI_out_tot = [], []
        for ic in range(len(contrast)):
            TW_out, VM_fit, scs_fit, err_fit = [], [], [], []
            PO_out, OSI_out = [], []
            for i in range(N):
                zz = tc[:, ic, i]
                if i % 1000 == 0:
                    print(i)

                osi_out, po_out = OS_functions._osv_(zz, stim_range)
                OSI_out.append(osi_out)
                # halved; presumably _osv_ works on doubled angles -- confirm
                PO_out.append(po_out/2.)

                vm_fit = OS_functions.vonMises(stim_range, zz)
                VM_fit.append(vm_fit[1])
                TW_out.append(vm_fit[2])
                scs_fit.append(vm_fit[3])

                # Relative L2 error of the fit in percent; not finite for a
                # silent neuron (zero tuning-curve norm).
                residual = zz - vm_fit[1]
                err_fit.append(100 * np.sqrt(np.sum(residual**2))
                               / np.sqrt(np.sum(zz**2)))

            TW_out_tot.append(TW_out)
            VM_fit_tot.append(VM_fit)
            scs_fit_tot.append(scs_fit)
            err_fit_tot.append(err_fit)
            PO_out_tot.append(PO_out)
            OSI_out_tot.append(OSI_out)

        results['VM_fit'] = np.array(VM_fit_tot)
        results['err_fit'] = np.array(err_fit_tot)
        results['scs_fit'] = np.array(scs_fit_tot)
        results['TW_out'] = np.array(TW_out_tot)
        results['PO_out'] = np.array(PO_out_tot)
        results['OSI_out'] = np.array(OSI_out_tot)

        ### writing the results
        with open('results', 'wb') as fl:
            cPickle.dump(results, fl, 2)

        print('### took : %.2f s' % (time.time() - t_start))
    finally:
        # Always return to the code folder, even if reading/fitting failed.
        os.chdir(code_path)

###### Do it for all the simulation results

# Simulation folders (relative to res_path) to analyze; append more names to
# process several simulations in one run. sim_results changes into each
# folder itself and returns to code_path when done.
sim_folders = ['N-5000_pif_delayType-random_g-4']

for sim_folder in sim_folders:
    print('########################################')
    print('Processing data ...')
    print(sim_folder)
    sim_results(sim_folder)
    print('########################################')

os.chdir(code_path)