'''

Script to combine M1 experimental long-range connectivity data from multiple papers into a single conn matrix for the model

smat - strength matrix; smat = pmat * imat
cmat - convergence matrix
wmat - weight matrix

'''

import numpy as np
from scipy.io import loadmat, savemat
from pprint import pprint
from scipy import interpolate
from pylab import *
from pprint import pprint
from collections import OrderedDict

# Alternative RGB palette (kept for reference; not used by the plotting functions in this file).
colorListRGB = [[0, 0, 0], [0.42, 0.67, 0.84], [0.90, 0.76, 0.00], [0.42, 0.83, 0.59], [0.90, 0.32, 0.00],
                [0.34, 0.67, 0.67], [0.90, 0.59, 0.00], [0.42, 0.82, 0.83], [1.00, 0.85, 0.00],
                [0.33, 0.67, 0.47], [1.00, 0.38, 0.60], [0.57, 0.67, 0.33], [0.5, 0.2, 0.0],
                [0.71, 0.82, 0.41], [0.0, 0.2, 0.5], [0.70, 0.32, 0.10]]

# Palette used by plotMats (polygon face colours) and plotPies (slice colours).
# NOTE: 'b' appears twice (indices 2 and 4).
colorList = ['r', 'g', 'b', 'y', 'b']

def plotMats(saveFig=False):
    """Plot 3D ridge plots of long-range convergence versus postsynaptic depth.

    One 3D subplot is drawn per postsynaptic cell type in a fixed list. In each
    subplot every long-range input population contributes one filled polygon:
    its outline is the convergence (vertical axis, from the module-level
    ``cmat``) against postsynaptic NCD (horizontal axis, bin centres from the
    module-level ``bins``). Populations are stacked along the y axis.

    Reads module-level globals: ``cmat``, ``bins``, ``colorList``,
    ``fixedNumSyns`` and ``dataFolder``.

    Args:
        saveFig: if True, save each figure as a PNG in ``dataFolder`` before
            showing it. Default False (display only).

    Returns:
        The last 3D axes created.
    """
    # Importing Axes3D registers the '3d' projection on older matplotlib versions.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    from matplotlib.collections import PolyCollection
    # Only interp1d is needed (scipy.interpolate.spline no longer exists in SciPy >= 1.3).
    from scipy.interpolate import interp1d

    longPops = ['TPO', 'TVL', 'S1', 'S2', 'cM1', 'M2', 'OC']
    cellTypes = ['CT', 'PV', 'SOM']  # alternative selection: ['IT', 'PT', 'CT']

    ax = None
    for EorI in ['exc']:  # cmat also holds 'inh' entries; only 'exc' is plotted
        fig = figure(figsize=(20, 8))
        fig.suptitle('Convergence matrices', fontsize=16)

        # One y position per polygon: slot 0 is an empty placeholder, pops occupy 1..len(longPops)
        zs = np.arange(0, len(longPops) + 1, 1.0)

        for i, ct in enumerate(cellTypes):
            ax = fig.add_subplot(1, len(cellTypes), i + 1, projection='3d')
            title('Long-range ' + EorI + ' inputs -> ' + ct)
            ax.elev = 30
            ax.azim = -45

            # Zero-area placeholder at y=0 so the first population is not on the axis edge.
            # verts must have exactly len(zs) entries (one polygon per z position).
            verts = [[(0, 0), (0, 0)]]
            for longPop in longPops:
                binEdges = bins[(longPop, ct)]
                xconv = cmat[(longPop, ct, EorI)]
                ybins = [mean(edges) for edges in binEdges]
                if len(xconv) == 1:
                    # A single bin has only one centre, which cannot be interpolated:
                    # use both bin edges with the same convergence (flat-topped bar).
                    ybins = [float(edge) for edges in binEdges for edge in edges]
                    xconv = [float(conv) for conv in xconv for _ in (0, 1)]

                yLow, yHigh = min(ybins), max(ybins)
                ySmooth = linspace(yLow, yHigh, 50)
                xSmooth = interp1d(ybins, xconv, kind='linear')(ySmooth)
                # Close the outline at zero convergence just outside the data range.
                outline = [(yLow - 0.03, 0)] + list(zip(ySmooth, xSmooth)) + [(yHigh + 0.03, 0)]
                verts.append(outline)

            poly = PolyCollection(verts, facecolors=colorList * 4)
            poly.set_alpha(0.7)
            ax.add_collection3d(poly, zs=zs, zdir='y')

            ax.set_xlabel('postsynaptic NCD')
            ax.set_xlim3d(0, 1)
            ax.set_zlabel('convergence')
            ax.set_zlim3d(0, max(point[1] for outline in verts for point in outline))
            ax.set_ylabel('Long range input')
            ax.set_ylim3d(0, max(zs))
            ax.set_yticklabels([' '] + longPops)
            tight_layout()
            subplots_adjust(top=0.9)

        numSynsLabel = 'fixedNumSyns' if fixedNumSyns else 'varNumSyns'
        filename = dataFolder + 'conn_long_cmat_' + numSynsLabel + '_3D_' + cellTypes[0] + EorI + '.png'
        if saveFig:
            # save before show(): the figure may be closed once show() returns
            fig.savefig(filename, dpi=600, bbox_inches='tight')
        show()

    return ax


def plotPies():
    """Draw, save and show pie charts of the M1 input fractions in ``fracSyns['M1']``.

    Four pies are produced:

    - 'EI': exc vs inh
    - 'ELL': exc long-range vs local
    - 'ILL': inh long-range vs local
    - 'L': long-range inputs by source region

    Each pie is saved to ``dataFolder + key + '_frac.png'``.

    Reads module-level globals ``fracSyns``, ``colorList`` and ``dataFolder``.
    Sets ``font.size`` / ``font.weight`` in matplotlib's global rcParams.
    """
    m1Fracs = fracSyns['M1']
    fracLabels = {
        'EI': ['Exc', 'Inh'],
        'ELL': ['Long (exc)', 'Local (exc)'],
        'ILL': ['Long (inh)', 'Local (inh)'],
        'L': ['TPO', 'TVL', 'S1', 'S2', 'cM1', 'M2', 'OC', 'Other'],
    }
    # fracSyns keys read for each pie (for 'L' the slice labels are the keys themselves)
    fracKeys = {
        'EI': ['exc', 'inh'],
        'ELL': ['exc_long', 'exc_local'],
        'ILL': ['inh_long', 'inh_local'],
        'L': fracLabels['L'],
    }
    fracs = {k: [m1Fracs[key] for key in fracKeys[k]] for k in fracKeys}
    fracsTitles = {'EI': 'Exc vs inh inputs', 'ELL': 'Exc long-range vs local inputs',
                   'ILL': 'Inh long-range vs local inputs', 'L': 'Long-range inputs'}

    print(fracs)

    # Font settings are global and loop-invariant: set them once before drawing.
    mpl.rcParams['font.size'] = 20.0
    mpl.rcParams['font.weight'] = 'bold'

    for k in fracs:
        print(k)
        frac = fracs[k]
        # New square figure per pie. Reusing figure(1) with axes() at the same rect could make
        # later pies draw over earlier ones when show() does not close the figure
        # (e.g. with a non-interactive backend), polluting the saved PNGs.
        fig = figure(figsize=(9, 9))
        axes([0.1, 0.1, 0.8, 0.8])

        if k == 'EI':
            colors = ['r', 'b']  # RED/BLUE for E/I
        else:
            # NOTE(review): colorList has 5 entries, so 'L' (8 slices) gets only 4 colours here;
            # matplotlib's pie is expected to cycle them -- confirm the result is readable.
            colors = colorList[1:len(frac) + 1]

        # labeldistance/radius chosen so the slice labels are not cut off
        pie(frac, labels=fracLabels[k], autopct='%1.0f%%', pctdistance=0.7, labeldistance=1.05,
            radius=0.8, shadow=True, startangle=0, colors=colors)
        title(fracsTitles[k])
        subplots_adjust(top=0.9)
        savefig(dataFolder + k + '_frac.png')
        show()
        close(fig)



# --------------------------------------------------------------------------------------------- #
# MAIN SCRIPT
# --------------------------------------------------------------------------------------------- #

# set folder paths for source data and output data
rootFolder = '../../'  # should point to root repo folder (m1) -- this script lives in m1/sim/conn
dataFolder = rootFolder + 'data/conn/'
outFolder = rootFolder + 'sim/conn/'

# containers filled in by the sections below
smat = {}  # dict for strength matrices (smat = cmat * wmat) - not used here
cmat = {}  # convergence: (inputPop, cellType, 'exc'|'inh') -> list with one value per postsynaptic bin
wmat = {}  # weight per input population
bins = {}  # bin intervals [start, end] in normalized depth, keyed by dataset name or (inputPop, cellType)
numSyns = {}  # number of syns per cell, keyed by dataset ('M1', 'Meyer') then population
synsPerConn = {}  # syns per connection for each input population
numCells = {}  # number of artificial spike generators (NetStims) per presyn population
delays = {}  # delay for each input population
rates = {}  # spontaneous firing rate of each presyn pop

## cell types
cellTypes = ['IT', 'PT', 'CT', 'SOM', 'PV']

## General coordinate system: normalized depth of each layer boundary (0 = pia, 1.0 = L6-WM border)
bins['layerLabels'] = ['pia', 'L1-L2 border', 'L2/3-L4 border', 'L4-L5A border', 'L5A-L5B border', 'L5B-L6 border', 'L6-WM border']
bins['layers'] = [0, 0.12, 0.31, 0.42, 0.52, 0.77, 1.0]

# ------------------------------------------------------------------------------------------------------------------
# 1) Set total num of long range syns/cell:
#  a) Start from Ben's estimate (dend length and 1.5 spines/um) for our M1 PT cell = ~16400
#  b) Estimate num for other cell types / layers based on Markram,2015/Meyer 2012 (rat S1), scaled so that
#     Meyer's PT5B count maps onto our M1 PT5B estimate (~16400)
#  c) Make sure consistent with Shuz 1989 (~7500); not sure what to make of DeFelipe 2002 (21983) - ignore since estimate ?
#
# Note: Maybe set same num syns for all pops so the %-based syn input distributions from papers have the desired effect ?!

synDensity = 1.5  # spines/um
PTlength = 10933  # total spiny dend length (um)
numSyns['M1'] = {}
numSyns['M1']['PT5B'] = int(round(PTlength * synDensity))  # total number of syns (16400)

## Meyer et al, 2010
numSyns['Meyer'] = {'IT2': 2155, 'IT3': 2973, 'IT4': 1777, 'STL4': 1275, 'IT5': 2887, 'PT5B': 5178, 'IT6': 1954}

## scale factor between mouse M1 PT cell and Meyer 2010 rat S1 PT5B
## (computed before the fixedNumSyns step below replaces the Meyer values)
numSynsFactor = float(numSyns['M1']['PT5B']) / numSyns['Meyer']['PT5B']

## Use fixed (mean) or variable (layer-specific) number of syns
fixedNumSyns = True  # True: every pop gets the mean of the Meyer counts; False: use the layer-specific Meyer counts
if fixedNumSyns:
    # list() so numpy's mean also works on Python 3 dict views
    meanNumSyns = mean(list(numSyns['Meyer'].values()))
    numSyns['Meyer'] = {k: meanNumSyns for k in numSyns['Meyer']}

## tot num syns exc pops (each M1 pop -> corresponding Meyer S1 pop)
## NOTE: this also overwrites numSyns['M1']['PT5B'] set above (scaled Meyer value)
m1EPops = ['IT2', 'IT4', 'IT5A', 'IT5B', 'PT5B', 'IT6', 'CT6']
s1EPops = ['IT2', 'IT4', 'IT5', 'IT5', 'PT5B', 'IT6', 'IT6']

for m1pop, s1pop in zip(m1EPops, s1EPops):
    numSyns['M1'][m1pop] = int(round(numSyns['Meyer'][s1pop] * numSynsFactor))

## tot num syns inh pops (each uses the Meyer count of the IT pop in the same layer)
m1IPops = ['PV2', 'SOM2', 'PV5A', 'SOM5A', 'PV5B', 'SOM5B', 'PV6', 'SOM6']
s1IPops = ['IT2', 'IT2', 'IT5', 'IT5', 'IT5', 'IT5', 'IT6', 'IT6']

numSynsFactorInh = 1.0  # Potjans & Diesman used ~0.8
for m1pop, s1pop in zip(m1IPops, s1IPops):
    numSyns['M1'][m1pop] = int(round(numSyns['Meyer'][s1pop] * numSynsFactor * numSynsFactorInh))

# ------------------------------------------------------------------------------------------------------------------
# 2) Set num of total syns/cell for each input region:
#  a) Set % of exc vs inh, and long-range vs local
#  overall long-range:local = 70-30% (0.8*0.8 exc + 0.2*0.3 inh = 0.7 long-range)
fracSyns = {'M1': {
    'exc': 0.8,  # exc/inh inputs = 84/16% (DeFelipe 2002); 75%/25% (Markram 2015)
    'inh': 0.2,
    'exc_long': 0.8,  # local/long exc = 20-80% (Markram 2015; Stepanyants 2009); 25-75% (Suter 2013)
    'exc_local': 0.2,
    'inh_long': 0.3,  # local/long inh = 70-30% (Stepanyants 2009)
    'inh_local': 0.7,
}}

#  b) fraction of the long-range syns contributed by each input region (these sum to 1.0)
fracSyns['M1'].update({
    ## Thalamus VL and PO
    ##  mouse TVAL->M1 PT = 18% (Suter, 2013); rat VPM,POm->S1 = 7.4%, 4% (Meyer, 2010b);
    ##  rat Thal -> L4 S1 rat (Bruno 2006): ~15%; PO=0.7, VAL=3.5 (Oh 2014)
    'TVL': 0.15,  # Thalamus, motor-related, ventro-lateral, core, Cerebellar-relaying
    'TPO': 0.10,  # Thalamus, sensory-related, medial posterior, matrix; fewer syns than VL (Meyer 2011; Zhang 2016; Oh 2014; Bopp 2017)
    ## S1: upper-limb=1.16, barrel=0.86 (Oh 2014, norm strength); upper-limb=8.61%, barrel=29% (Zhang 2016, % presyn cells)
    'S1': 0.15,
    ## S2: 0.83 (Oh 2014, norm strength); ~1/12th long-range syns (Suter 2013 notes)
    'S2': 0.10,
    ## M2: 0.72 (Oh 2014, norm strength)
    'M2': 0.10,
    ## contralateral M1: 1.27 (Oh 2014, norm strength)
    'cM1': 0.15,
    ## OC: 0.32+0.36 (Oh 2014, norm strength)
    'OC': 0.10,
    ## Other (not modeled)
    'Other': 0.15,
})


# ------------------------------------------------------------------------------------------------------------------
# 3) Set convergence based on layer/yfrac
#
# For each input region, normInput[(dataset, cellType)] holds one relative input strength per postsynaptic bin,
# and bins[(inputPop, cellType)] holds the matching [start, end] depth interval of each bin.

#############
## Thalamus 
# thalamic input pops handled in this section (reused by the normalization and convergence loops)
inputPops = ['TPO', 'TVL']
# normInput[('Yama15', ct)] maps each thalamic pop to its list of per-bin strengths for cell type ct
normInput = {'Yama15elife': {}, ('Yama15', 'IT'): {}, ('Yama15', 'PT'): {}, ('Yama15', 'CT'): {}, ('Yama15', 'SOM'): {}, ('Yama15', 'PV'): {}}

## Yamawaki,2015 elife (forelimb)
## one value per 0.1-wide yfrac bin from 0.1 to 0.9 (8 bins, see bins['Yama15elife'])
normInput['Yama15elife']['TPO'] = [0.26, 0.72, 0.72, 0.58, 0.25, 0.090, 0.044, 0.022]
normInput['Yama15elife']['TVL'] = [0.30, 0.43, 0.64, 0.24, 0.34, 0.77, 0.69, 0.061] 
bins['Yama15elife'] = [[0.1,0.2], [0.2,0.3], [0.3,0.4], [0.4,0.5], [0.5,0.6], [0.6,0.7], [0.7,0.8], [0.8,0.9]]

## yfrac based bins but corrected so each bin is within a single layer (L2/3 has 2 bins; L5B has 3 bins; L6 has 2 bins; rest 1 bin)
## corresponds approximately to original bin ranges; boundaries match bins['layers'] (0.12, 0.31, 0.42, 0.52, 0.77, 1.0)
bins['Yama15_labels'] = ['L23u', 'L23l', 'L4', 'L5A', 'L5Bu', 'L5Bm', 'L5Bl', 'L6u', 'L6l']
bins['Yama15'] = [[0.12, 0.12+(0.31-0.12)/2], [0.12+(0.31-0.12)/2, 0.31], [0.31, 0.42], [0.42, 0.52], \
            [0.52, 0.52+(0.77-0.52)/3], [0.52+(0.77-0.52)/3, 0.52+2*(0.77-0.52)/3], [0.52+2*(0.77-0.52)/3, 0.77], [0.77, 0.77+(1.0-0.77)/2], [0.77+(1.0-0.77)/2, 1.0]]

## Extend based on Yamawaki,2015 jns (TVL)
## CT/PT = 0.05; CT/IT5B = 0.16; CT/IT6 (yfrac<0.85) = 0.33; CT=IT6 (yfrac>0.85) = 0.2
## 3 eqn system: CT/PT5B = 0.05; CT/IT5B = 0.16; (PT5B+IT5B)/2 = 0.6 -> solutions: CT=0.0457, PT=0.914, IT5B=0.286
## To match Yamawaki jns ratios for IT/PT/CT, Yamawaki elife values were adapted as follows: 
## PT = (0.914/0.6)*L5B; IT5B = (0.286/0.6)*L5B; CT=0.0457; IT6 upper = 3*CT; IT6 lower = CT  
## TPO keep same values but adapted to use same structure as VL 
## Naka16 confirmed in Table 1: VL/PO->IT=low,high; VL/PO->PT=high,low 
## PT: elife values 4-6 (yfrac 0.5-0.8) are assigned to the three L5B bins of bins['Yama15']
normInput[('Yama15', 'PT')]['TPO'] = normInput['Yama15elife']['TPO'][4:6+1]  
normInput[('Yama15', 'PT')]['TVL'] = [x*(0.914/0.6) for x in normInput['Yama15elife']['TVL'][4:6+1]]  # PT = (0.914/0.6)*L5B
bins[('TVL', 'PT')] = bins[('TPO', 'PT')] = bins['Yama15'][4:6+1]

## CT: the two L6 bins; TPO uses the last elife value (yfrac 0.8-0.9) for both
normInput[('Yama15', 'CT')]['TPO'] = [normInput['Yama15elife']['TPO'][7]] * 2  # add extra L6 
normInput[('Yama15', 'CT')]['TVL'] = [0.0457, 0.0457] # CT=0.0457 (derived from 3 eqn system above based on yamawaki jns ratios)
bins[('TVL', 'CT')] = bins[('TPO', 'CT')] = bins['Yama15'][7:8+1]

## IT: all 9 bins; the last elife value is repeated for the extra L6 bin, then the TVL L5B/L6 entries are rescaled
normInput[('Yama15', 'IT')]['TPO'] = normInput['Yama15elife']['TPO'] + [normInput['Yama15elife']['TPO'][-1]]  # add extra L6 
normInput[('Yama15', 'IT')]['TVL'] = normInput['Yama15elife']['TVL'] + [normInput['Yama15elife']['TVL'][-1]]  # add extra L6 bin 
normInput[('Yama15', 'IT')]['TVL'][4:6+1] = [x*0.286/0.6 for x in normInput[('Yama15', 'IT')]['TVL'][4:6+1]] # IT5B = (0.286/0.6)*L5B
normInput[('Yama15', 'IT')]['TVL'][7] = normInput[('Yama15', 'CT')]['TVL'][0]*3 # IT6 upper = 3*CT;
normInput[('Yama15', 'IT')]['TVL'][8] = normInput[('Yama15', 'CT')]['TVL'][1] # IT6 lower = CT 
bins[('TVL', 'IT')] = bins[('TPO', 'IT')] = bins['Yama15'][0:8+1]

## PV and SOM share the IT bins (same list objects)
bins[('TVL', 'PV')] = bins[('TVL', 'SOM')] = bins[('TVL', 'IT')]
bins[('TPO', 'PV')] = bins[('TPO', 'SOM')] =  bins[('TPO', 'IT')]

## For inhib cells assume original Yamawaki elife yfrac-based values
normInput[('Yama15', 'SOM')]['TPO'] = normInput['Yama15elife']['TPO'] + [normInput['Yama15elife']['TPO'][-1]] # add extra L6 
normInput[('Yama15', 'PV')]['TPO'] = normInput['Yama15elife']['TPO'] + [normInput['Yama15elife']['TPO'][-1]]
normInput[('Yama15', 'SOM')]['TVL'] = normInput['Yama15elife']['TVL'] + [normInput['Yama15elife']['TVL'][-1]]
normInput[('Yama15', 'PV')]['TVL'] = normInput['Yama15elife']['TVL'] + [normInput['Yama15elife']['TVL'][-1]]

## normalize so max=1.0 -- per thalamic pop, dividing by the max across ALL cell types
## so the relative strengths between cell types are preserved
for inputPop in inputPops:
    maxValue = max([x for ct in cellTypes for x in normInput[('Yama15', ct)][inputPop]])
    for ct in cellTypes:
        normInput[('Yama15', ct)][inputPop] = [x/maxValue for x in normInput[('Yama15', ct)][inputPop]]

## syns per conn (= synaptic contacts per connection)
## (Bruno 2006, in vivo rat S1) = 7 syns/con; (Markram 2015, Rat S1)  = 8.1+-4.2 syns/con
synsPerConn['TPO'] = 5
synsPerConn['TVL'] = 5

## calculate convergence
M1BinPops = {}  # M1 pop whose numSyns applies to each thalamic bin (one entry per bin in bins[(inputPop, ct)])
M1BinPops['IT'] = ['IT2', 'IT2', 'IT4', 'IT5A', 'IT5B', 'IT5B', 'IT5B', 'IT6', 'IT6']
M1BinPops['PT'] = ['PT5B', 'PT5B', 'PT5B']
M1BinPops['CT'] = ['CT6', 'CT6']
M1BinPops['PV'] = ['PV2', 'PV2', 'PV5A', 'PV5A', 'PV5B', 'PV5B', 'PV5B', 'PV6', 'PV6']
M1BinPops['SOM'] = ['SOM2', 'SOM2', 'SOM5A', 'SOM5A', 'SOM5B', 'SOM5B', 'SOM5B', 'SOM6', 'SOM6']

for inputPop in inputPops:
    for ct in cellTypes:
        for EorI in ['exc', 'inh']:
            # num of syns = total syns of the bin's pop * frac EorI * frac EorI long-range * fraction for this thalamic input
            nsyns = [int(round(numSyns['M1'][M1BinPop] * fracSyns['M1'][EorI] * fracSyns['M1'][EorI+'_long'] * fracSyns['M1'][inputPop])) for M1BinPop in M1BinPops[ct]]
            # convergence = num syns / syns per conn * normalized input
            # float() prevents Python 2 integer (floor) division from truncating nsyn/synsPerConn before scaling
            cmat[(inputPop, ct, EorI)] = [float(round(float(nsyn) / synsPerConn[inputPop] * ninput)) for nsyn, ninput in zip(nsyns, normInput[('Yama15', ct)][inputPop])]


## Check Yamawaki 2015 jns ratios in terms of the number of syns per cell type (uncomment to print)
# print(cmat)
# print('Ratio of num of syns for comparison with Yamawaki 2015 jns (note tot syns different for each cell type):')
# print('CT/PT (Yamawaki=0.05) = ', cmat[('TVL','CT','exc')][0] / mean(cmat[('TVL','PT','exc')]))
# print('CT/IT5B (Yamawaki=0.16) ', cmat[('TVL','CT','exc')][0] / mean(cmat[('TVL','IT','exc')][4:7]))
# print('CT/IT6upper (Yamawaki=0.33) ', cmat[('TVL','CT','exc')][0] / cmat[('TVL','IT','exc')][7])
# print('CT/IT6lower (Yamawaki=1) ', cmat[('TVL','CT','exc')][0] / cmat[('TVL','IT','exc')][8])


#############
## S1
## L23, 0.7, L5A, 1.0, L5B, 0.14, L6, 0.1 (Mao 2011)
normInput['Mao11'] = [0.7, 1.0, 0.14, 0.1]
bins['Mao11'] = [[0.12,0.31], [0.31,0.52], [0.52, 0.77], [0.77,1.0]]

## Adapt to use in model by setting specific bins for each cell type
## (the [0.31,0.52] bin is split at the L4-L5A border; both halves keep the value 1.0)
normInput[('Mao11', 'IT')] = [0.7, 1.0, 1.0, 0.14, 0.1]
bins[('S1', 'IT')] = [[0.12,0.31], [0.31,0.42], [0.42, 0.52], [0.52, 0.77], [0.77,1.0]]

normInput[('Mao11', 'PT')] = [normInput[('Mao11', 'IT')][3]]  # L5B bin only
bins[('S1', 'PT')] = [bins[('S1', 'IT')][3]]

normInput[('Mao11', 'CT')] = [normInput[('Mao11', 'IT')][4]]  # L6 bin only
bins[('S1', 'CT')] = [bins[('S1', 'IT')][4]]

normInput[('Mao11', 'PV')] = normInput[('Mao11', 'IT')]
bins[('S1', 'PV')] = bins[('S1', 'IT')]

normInput[('Mao11', 'SOM')] = normInput[('Mao11', 'IT')]  # input? yes (Wall 2016) vs no (Harris & Shepherd 2015)
bins[('S1', 'SOM')] = bins[('S1', 'IT')]

## syns per conn
synsPerConn['S1'] = 5

## calculate convergence
M1BinPops = {}  # M1 pop whose numSyns applies to each bin in bins[('S1', ct)]
M1BinPops['IT'] = ['IT2', 'IT4', 'IT5A', 'IT5B', 'IT6']
M1BinPops['PT'] = ['PT5B']
M1BinPops['CT'] = ['CT6']
M1BinPops['PV'] = ['PV2', 'PV5A', 'PV5A', 'PV5B', 'PV6']
M1BinPops['SOM'] = ['SOM2', 'SOM5A', 'SOM5A', 'SOM5B', 'SOM6']

inputPop = 'S1'
for ct in cellTypes:
    for EorI in ['exc', 'inh']:
        # num of syns = total syns of the bin's pop * frac EorI * frac EorI long-range * fraction for S1
        nsyns = [int(round(numSyns['M1'][M1BinPop] * fracSyns['M1'][EorI] * fracSyns['M1'][EorI+'_long'] * fracSyns['M1'][inputPop])) for M1BinPop in M1BinPops[ct]]

        # convergence = num syns / syns per conn * normalized input
        # float() prevents Python 2 integer (floor) division from truncating nsyn/synsPerConn before scaling
        cmat[(inputPop, ct, EorI)] = [float(round(float(nsyn) / synsPerConn[inputPop] * ninput)) for nsyn, ninput in zip(nsyns, normInput[('Mao11', ct)])]


#############
## S2
## L23=1.0, L5A=0.46, L5B=0.35, L6=0.20; in L5B decreases with slope=-5.6 (Suter 2015)
## divide L5B into 3 bins and calculate strength of each bin based on slope
## bins: (0.77-0.52)/3 = 0.083 -> [0.52, 0.603], [0.603, 0.687], [0.687, 0.77]
## values: L5B centre x=0.645, y=0.35; slope = -5.6*0.35 (rescaled) = -1.96
## -> x=0.5615, y=0.35 + (-0.083)*(-1.96) = 0.35+0.1626 = 0.51
## -> x=0.7285, y=0.35 + (+0.083)*(-1.96) = 0.35-0.1626 = 0.187
## use values from L5A (0.46) and L6 (0.20) for the upper/lower L5B bins since consistent with calculated ones.
## Also add layer 4 as subset of L5A
normInput['Suter15'] = [1.0, 0.46, 0.46, 0.46, 0.35, 0.20, 0.20]
bins['Suter15'] = [[0.12,0.31], [0.31,0.42], [0.42,0.52], [0.52, 0.603], [0.603, 0.687], [0.687, 0.77], [0.77,1.0]]

## Adapt to use in model by setting specific bins for each cell type
normInput[('Suter15', 'IT')] = normInput['Suter15']
bins[('S2', 'IT')] = bins['Suter15']

normInput[('Suter15', 'PT')] = normInput['Suter15'][3:5+1]  # the three L5B bins
bins[('S2', 'PT')] = bins['Suter15'][3:5+1]

normInput[('Suter15', 'CT')] = [normInput['Suter15'][6]]  # L6 bin only
bins[('S2', 'CT')] = [bins['Suter15'][6]]

normInput[('Suter15', 'PV')] = normInput['Suter15']
bins[('S2', 'PV')] = bins['Suter15']

normInput[('Suter15', 'SOM')] = normInput['Suter15']  # input? yes (Wall 2016) vs no (Harris & Shepherd 2015)
bins[('S2', 'SOM')] = bins['Suter15']

## syns per conn
synsPerConn['S2'] = 5

## calculate convergence
M1BinPops = {}  # M1 pop whose numSyns applies to each bin in bins[('S2', ct)]
M1BinPops['IT'] = ['IT2', 'IT4', 'IT5A', 'IT5B', 'IT5B', 'IT5B', 'IT6']
M1BinPops['PT'] = ['PT5B', 'PT5B', 'PT5B']
M1BinPops['CT'] = ['CT6']
M1BinPops['PV'] = ['PV2', 'PV5A', 'PV5A', 'PV5B', 'PV5B', 'PV5B', 'PV6']
M1BinPops['SOM'] = ['SOM2', 'SOM5A', 'SOM5A', 'SOM5B', 'SOM5B', 'SOM5B', 'SOM6']

inputPop = 'S2'
for ct in cellTypes:
    for EorI in ['exc', 'inh']:
        # num of syns = total syns of the bin's pop * frac EorI * frac EorI long-range * fraction for S2
        nsyns = [int(round(numSyns['M1'][M1BinPop] * fracSyns['M1'][EorI] * fracSyns['M1'][EorI+'_long'] * fracSyns['M1'][inputPop])) for M1BinPop in M1BinPops[ct]]

        # convergence = num syns / syns per conn * normalized input
        # float() prevents Python 2 integer (floor) division from truncating nsyn/synsPerConn before scaling
        cmat[(inputPop, ct, EorI)] = [float(round(float(nsyn) / synsPerConn[inputPop] * ninput)) for nsyn, ninput in zip(nsyns, normInput[('Suter15', ct)])]


#############
## M2 and contralateral M1
## L23: 0.17, L5A: 0.11, L5B: 1.0, L6: 0.44; L5Blower significantly higher than upper (Hooks 2013)
## divide L5B in upper vs lower: L23: 0.17, L5A: 0.11, L5Bu: 0.8, L5Bl: 1.2, L6: 0.44
## Also add layer 4 as subset of L5A
normInput['Hooks13_M2'] = [x/1.2 for x in [0.17, 0.11, 0.11, 0.8, 1.2, 0.44]]  # rescale so max=1
# bins tile the layer boundaries of bins['layers'] without gaps (L4-L5A border at 0.42)
bins['Hooks13_M2'] = [[0.12,0.31], [0.31,0.42], [0.42,0.52], [0.52, 0.645], [0.645, 0.77], [0.77,1.0]]

## Adapt to use in model by setting specific bins for each cell type
normInput[('Hooks13_M2', 'IT')] = normInput['Hooks13_M2']
bins[('M2', 'IT')] = bins['Hooks13_M2']

normInput[('Hooks13_M2', 'PT')] = normInput['Hooks13_M2'][3:4+1]  # L5B upper and lower bins
bins[('M2', 'PT')] = bins['Hooks13_M2'][3:4+1]

normInput[('Hooks13_M2', 'CT')] = [normInput['Hooks13_M2'][5]]  # L6 bin only
bins[('M2', 'CT')] = [bins['Hooks13_M2'][5]]

normInput[('Hooks13_M2', 'PV')] = normInput['Hooks13_M2']
bins[('M2', 'PV')] = bins['Hooks13_M2']

normInput[('Hooks13_M2', 'SOM')] = normInput['Hooks13_M2']  # input? yes (Wall 2016) vs no (Harris & Shepherd 2015)
bins[('M2', 'SOM')] = bins['Hooks13_M2']

# contralateral M1 uses the same bins (and, below, the same normalized input) as M2
for ct in cellTypes:
    bins[('cM1', ct)] = bins[('M2', ct)]

## syns per conn
synsPerConn['cM1'] = 5
synsPerConn['M2'] = 5

## calculate convergence
M1BinPops = {}  # M1 pop whose numSyns applies to each bin in bins[('M2', ct)] / bins[('cM1', ct)]
M1BinPops['IT'] = ['IT2', 'IT4', 'IT5A', 'IT5B', 'IT5B', 'IT6']
M1BinPops['PT'] = ['PT5B', 'PT5B']
M1BinPops['CT'] = ['CT6']
M1BinPops['PV'] = ['PV2', 'PV5A', 'PV5A', 'PV5B', 'PV5B', 'PV6']
M1BinPops['SOM'] = ['SOM2', 'SOM5A', 'SOM5A', 'SOM5B', 'SOM5B', 'SOM6']


for pop in ['cM1', 'M2']:
    for ct in cellTypes:
        for EorI in ['exc', 'inh']:
            # num of syns = total syns of the bin's pop * frac EorI * frac EorI long-range * fraction for this input
            nsyns = [int(round(numSyns['M1'][M1BinPop] * fracSyns['M1'][EorI] * fracSyns['M1'][EorI+'_long'] * fracSyns['M1'][pop])) for M1BinPop in M1BinPops[ct]]

            # convergence = num syns / syns per conn * normalized input
            # float() prevents Python 2 integer (floor) division from truncating nsyn/synsPerConn before scaling
            cmat[(pop, ct, EorI)] = [float(round(float(nsyn) / synsPerConn[pop] * ninput)) for nsyn, ninput in zip(nsyns, normInput[('Hooks13_M2', ct)])]


#############
## OC
## L23: 0.19, L5A: 0.30, L5B, 0.31, L6, 1.0 (Hooks 2013)
## Add layer 4 (same value as L5A)
normInput['Hooks13_OC'] = [0.19, 0.30, 0.30, 0.31, 1.0]
bins['Hooks13_OC'] = [[0.12,0.31], [0.31,0.42], [0.42,0.52], [0.52, 0.77], [0.77,1.0]]

## Adapt to use in model by setting specific bins for each cell type
normInput[('Hooks13_OC', 'IT')] = normInput['Hooks13_OC']
bins[('OC', 'IT')] = bins['Hooks13_OC']

normInput[('Hooks13_OC', 'PT')] = [normInput['Hooks13_OC'][3]]  # L5B bin only
bins[('OC', 'PT')] = [bins['Hooks13_OC'][3]]

normInput[('Hooks13_OC', 'CT')] = [normInput['Hooks13_OC'][4]]  # L6 bin only
bins[('OC', 'CT')] = [bins['Hooks13_OC'][4]]

normInput[('Hooks13_OC', 'PV')] = normInput['Hooks13_OC']
bins[('OC', 'PV')] = bins['Hooks13_OC']

normInput[('Hooks13_OC', 'SOM')] = normInput[('Hooks13_OC', 'IT')]  # input? yes (Wall 2016) vs no (Harris & Shepherd 2015)
bins[('OC', 'SOM')] = bins['Hooks13_OC']

## syns per conn
synsPerConn['OC'] = 5

## calculate convergence
M1BinPops = {}  # M1 pop whose numSyns applies to each bin in bins[('OC', ct)]
M1BinPops['IT'] = ['IT2', 'IT4', 'IT5A', 'IT5B', 'IT6']
M1BinPops['PT'] = ['PT5B']
M1BinPops['CT'] = ['CT6']
M1BinPops['PV'] = ['PV2', 'PV5A', 'PV5A', 'PV5B', 'PV6']
M1BinPops['SOM'] = ['SOM2', 'SOM5A', 'SOM5A', 'SOM5B', 'SOM6']

inputPop = 'OC'
for ct in cellTypes:
    for EorI in ['exc', 'inh']:
        # num of syns = total syns of the bin's pop * frac EorI * frac EorI long-range * fraction for OC
        nsyns = [int(round(numSyns['M1'][M1BinPop] * fracSyns['M1'][EorI] * fracSyns['M1'][EorI+'_long'] * fracSyns['M1'][inputPop])) for M1BinPop in M1BinPops[ct]]

        # convergence = num syns / syns per conn * normalized input
        # float() prevents Python 2 integer (floor) division from truncating nsyn/synsPerConn before scaling
        cmat[(inputPop, ct, EorI)] = [float(round(float(nsyn) / synsPerConn[inputPop] * ninput)) for nsyn, ninput in zip(nsyns, normInput[('Hooks13_OC', ct)])]


# pprint(normInput)
# pprint(cmat)#[('OC', 'PT')])
# pprint(bins)

# ------------------------------------------------------------------------------------------------------------------
# 4) Weights (one value per input population)
## Thalamus; Hu et al 2016 ~0.6mV, Constantinople et al 2013 E+FS= ~0.57mV, LTS=1.44mV
wmat.update(dict.fromkeys(['TPO', 'TVL'], 0.6))

## S1, S2, contraM1, M2, OC share a common weight
wmat.update(dict.fromkeys(['S1', 'S2', 'cM1', 'M2', 'OC'], 0.5))

# ------------------------------------------------------------------------------------------------------------------
# 5) delays (one value per input population)
## (Hu, 2016, mouse S1) = 2.2ms; (Constantinople, 2013, in vivo rat S1) = 11,8,10,10,11 (range 10-30) ms
delays.update(dict.fromkeys(['TPO', 'TVL'], 5))

## S1, S2, contraM1, M2, OC
delays.update(dict.fromkeys(['S1', 'S2', 'cM1', 'M2', 'OC'], 5))

# ------------------------------------------------------------------------------------------------------------------
# 6) Num cells in each regions
## TPO and TVL: L5B PT ~ 44% probability (Constantinople 2013) and 42.5% (Bruno 2006)
## set to ~2x max convergence = 600; tradeoff between reducing correlation vs computation time
maxConv = 300

## TPO, TVL, S1, S2, contraM1, M2, OC all get 2x the max convergence
numCells.update(dict.fromkeys(['TPO', 'TVL', 'S1', 'S2', 'cM1', 'M2', 'OC'], 2 * maxConv))

# ------------------------------------------------------------------------------------------------------------------
# 7) avg spontaneous firing rates (each pop gets its own [0, 5] list object)
## TPO and TVL
## LGN ~= 1-10 Hz (rat); VB ~= 2.9 +- 0.7 Hz (Hirata, 2006, rat); Thal = 8 Hz (Potjans 2013, rat/cat S1/V1);
rates.update({thalPop: [0, 5] for thalPop in ('TPO', 'TVL')})

## S1, S2, contraM1, M2, OC
## in vivo mouse S1->M1 L23 = 0.1 Hz (Yamashita 2013); M1 IT = 0-10 Hz (Isomura, 2009; Jacob 2012; Li 2016)
rates.update({ctxPop: [0, 5] for ctxPop in ('S1', 'S2', 'cM1', 'M2', 'OC')})

# ------------------------------------------------------------------------------------------------------------------
# save matrices (set a flag to 1 to enable)
savePickle = 0
saveMat = 0

data = {'cmat': cmat, 'wmat': wmat, 'bins': bins, 'numSyns': numSyns, 'synsPerConn': synsPerConn,
        'delays': delays, 'numCells': numCells, 'rates': rates}


def _toMatCompatible(obj):
    """Return a copy of ``obj`` whose (nested) dict keys are strings usable as MATLAB field names.

    scipy's savemat only writes dict entries whose keys are strings, so the
    tuple keys used by cmat and bins (e.g. ('TPO', 'IT', 'exc')) would be
    dropped. Tuple keys are joined with '_' (-> 'TPO_IT_exc'); non-dict
    values are returned unchanged.
    """
    if isinstance(obj, dict):
        return {('_'.join(key) if isinstance(key, tuple) else key): _toMatCompatible(value)
                for key, value in obj.items()}
    return obj


if savePickle:
    import pickle
    with open(outFolder + 'conn_long.pkl', 'wb') as fileObj:
        pickle.dump(data, fileObj)

if saveMat:
    # same base name as the pickle output
    savemat(outFolder + 'conn_long.mat', _toMatCompatible(data))



# plot convergence matrices and/or input-fraction pies (set a flag to 1 to enable)
plotMat = 1
plotPie = 0

if plotMat:
    ax = plotMats()

if plotPie:
    ax = plotPies()