##################################################################################
# plot_figures.py -- Reads and analyzes the results generated from network_run.py
# and plots Figures 1 and 3 in:
#
# Ref: Sadeh, Clopath and Rotter (PLOS Computational Biology, 2015).
# Emergence of Functional Specificity in Balanced Networks with Synaptic Plasticity.
#
# Author: Sadra Sadeh <s.sadeh@ucl.ac.uk> // Created: 2014-2015
##################################################################################

import numpy as np
import pylab as pl
from imp import reload
import params; reload(params); from params import *
import pickle as cPickle

from mpl_toolkits.axes_grid.inset_locator import inset_axes
from mpl_toolkits.axes_grid1 import make_axes_locatable

################################################################################
# --- read the results

# Load the pickled results dictionary (generated by network_run.py, see the
# file header). The context manager closes the file even if unpickling fails.
# NOTE: unpickling can execute arbitrary code -- only load a 'results' file
# that was produced by a trusted run.
with open('results', 'rb') as fl:
    results = cPickle.load(fl)

# Initial weight matrix.
W0 = results['W0']

# Spike data; spk_bp, spk_wp_tot and spk_ap are shown in Figure 1 as the
# 'Before', 'Beginning/End of' and 'After' plasticity panels.
spk_bp = results['spk_bp']
spk_wp_tot = results['spk_wp_tot']
spk_ap = results['spk_ap']
spk_sp_tot = results['spk_sp_tot']
spk_cd_tot = results['spk_cd_tot']

# Weight matrices and stimulus sequences stored per block / condition.
W_blk_tot = results['W_blk_tot']
stim_rng_tot = results['stim_rng_tot']
W_sp_tot = results['W_sp_tot']
W_cd_tot = results['W_cd_tot']
stim_rng_tot_cd = results['stim_rng_tot_cd']

# Final weights: the last entry of the per-block weight list.
Wf = W_blk_tot[-1]


################################################################################
### Figure 1
#########

# Marker size shared by the raster dots and the weight scatter plots.
mksz = 2.
# sim_time divided by 1000; spike counts are divided by Ts to obtain the rates
# labelled 'Hz' below (presumably sim_time is in ms -- confirm in params.py).
Ts = sim_time/1000.

def _temp_plot_(spk, ax, stim=0, yy=0):
    """Draw a spike raster on `ax` together with two marginal plots.

    A per-neuron rate profile is appended to the right of the raster and a
    population spike-count trace (excitatory and inhibitory) below it.

    Parameters
    ----------
    spk : indexable pair
        spk[0] holds the neuron index of every spike and spk[1] the matching
        spike time. Both are indexed with index arrays, so they are
        presumably NumPy arrays -- confirm against network_run.py.
    ax : matplotlib Axes
        Axes that receives the raster.
    stim : int, optional
        Only affects the horizontal position of the mean-rate text:
        0 places it at x=.85 (axes fraction), any other value at x=.9.
    yy : int, optional
        When 1, axis labels, y tick labels and a legend are added to the
        marginal plots.

    Notes
    -----
    Relies on the module-level names ne, ni, n, sim_time, dt, Ts and mksz.
    Neurons with index < ne are treated as excitatory, the rest as inhibitory.
    """
    exid = np.where(spk[0] < ne)[0]
    inid = np.where(spk[0] >= ne)[0]

    # np.histogram requires an integer bin count; sim_time/10 is a float
    # under Python 3 true division and would raise a TypeError.
    nbins = int(sim_time // 10)
    # NOTE(review): the raster below plots spk[1]*dt while these histograms
    # bin the raw spk[1]; the two agree only if dt == 1 -- confirm.
    htex = np.histogram(spk[1][exid], bins=nbins, range=(0, sim_time))
    htin = np.histogram(spk[1][inid], bins=nbins, range=(0, sim_time))

    # Spike count per neuron (one bin per neuron index).
    hr = np.histogram(spk[0], bins=n, range=(0, n))

    # --- raster
    ax.plot(spk[1][exid]*dt, spk[0][exid], 'r.', markersize=mksz,
            label='Exc: ' + str(np.round(len(exid)/ne)))
    # The inhibitory count is normalised by ni (the original divided by ne).
    ax.plot(spk[1][inid]*dt, spk[0][inid], 'b.', markersize=mksz,
            label='Inh: ' + str(np.round(len(inid)/ni)))

    ax.set_yticks([0, 99, 199, 299, ne-1, n-1])
    ax.set_yticklabels([])
    ax.set_ylim(0-10, n+10)
    ax.set_xlim([0-10, sim_time+10])
    ax.set_xticklabels([])

    divider = make_axes_locatable(ax)

    # --- right margin: per-neuron spike count divided by Ts
    axHisty = divider.append_axes("right", size=.5, pad=0.1)
    axHisty.plot(hr[0]/(Ts), hr[1][0:-1], color='k', lw=2)

    # Population-averaged rates, written in axes coordinates.
    x_txt = .85 if stim == 0 else .9
    axHisty.text(x_txt, .5, str(np.round(len(exid)/ne / Ts, 1)) + ' Hz',
                 transform=axHisty.transAxes, color='r')
    axHisty.text(x_txt, .85, str(np.round(len(inid)/ni / Ts, 1)) + ' Hz',
                 transform=axHisty.transAxes, color='b')

    axHisty.set_yticks([0, 99, 199, 299, ne-1, n-1])
    axHisty.set_yticklabels([])
    axHisty.set_xticks([0, 10])
    axHisty.set_ylim(0-10, n+10)

    # --- bottom margin: population spike counts per time bin
    axHistx = divider.append_axes("bottom", 1.2, pad=0.3)
    axHistx.plot(htex[1][0:-1], htex[0], color='r', lw=2, label='Exc')
    axHistx.plot(htin[1][0:-1], htin[0], color='b', lw=2, label='Inh')

    axHistx.set_yticks([0, 50, 100, 150])
    axHistx.set_yticklabels([])

    if yy == 1:
        axHistx.set_xlabel('Time (ms)')
        axHistx.set_ylabel('Population spike count')
        axHistx.set_yticklabels([0, 50, 100, 150])
        # axHistx is the most recently added (current) axes, so this is the
        # axes the original pl.legend() call acted on.
        axHistx.legend(loc=1, frameon=False, prop={'size': 12.5})

        axHisty.set_xlabel('Firing rate \n (spikes/s)', size=10)


fig = pl.figure(figsize=(16, 8))

# Throw-away image: it only supplies an 'hsv' mappable over [0, 1] for the
# colorbar created next to the second panel; the figure is cleared right away.
mycl = pl.imshow(np.random.uniform(0, 1, (100, 100)), cmap='hsv', vmin=0, vmax=1)
pl.clf()


def _stim_bars_(axis, angle_at):
    """Draw one thick bar per stimulus at y=-5 on `axis`.

    `angle_at(k)` gives the value for the k-th stimulus; it is divided by pi
    and mapped through the hsv colormap to colour the bar.
    """
    for k in range(int(stim_no)):
        t_on, t_off = k*trial_time, (k+1)*trial_time
        axis.plot([t_on, t_off], [-5, -5], '-', color=pl.cm.hsv(angle_at(k)/np.pi), lw=10)


## panel 1
ax1 = pl.subplot(141)
ax1.set_title('Before Plasticity')
_temp_plot_(spk=spk_bp, ax=ax1, yy=1)
ax1.set_ylabel('Neuron #')
ax1.set_yticklabels([1, 100, 200, 300, ne, n])
_stim_bars_(ax1, lambda k: th)

## panel 2
ax2 = pl.subplot(142)
ax2.set_title('Beginning of Plasticity')
_temp_plot_(spk=spk_wp_tot[0], ax=ax2, stim=1)
_stim_bars_(ax2, lambda k: stim_rng_tot[0][k])

# Small colorbar; tick positions 0..1 are labelled 0..180.
cax = fig.add_axes([.485, .25, .01, .1])
clbr = pl.colorbar(mycl, cax=cax, orientation='vertical')
clbr.set_ticks([0, .25, .5, .75, 1])
clbr.set_ticklabels([0, 45, 90, 135, 180])

## panel 3
ax3 = pl.subplot(143)
ax3.set_title('End of Plasticity')
_temp_plot_(spk=spk_wp_tot[block_no-1], ax=ax3, stim=2)
_stim_bars_(ax3, lambda k: stim_rng_tot[-1][k])

## panel 4
ax4 = pl.subplot(144)
ax4.set_title('After Plasticity')
_temp_plot_(spk=spk_ap, ax=ax4)
_stim_bars_(ax4, lambda k: th)

ax4.text(.25, .1, 'Sparser activity', size=15, transform=ax4.transAxes)

fig.subplots_adjust(left=.05, right=.95, bottom=.075, top=.95, wspace=.25)

fig.savefig('Fig1')

################################################################################
### Figure 3 (A-C)
###########

pl.figure(figsize=(14, 5))


def _weight_panel_(pos, title, mat, ticks, vmin=None, vmax=None):
    """Show `mat` as an image in subplot `pos` and return its colorbar.

    `ticks` is used both for the colorbar tick positions and their labels;
    vmin/vmax are forwarded to imshow (None keeps imshow's own scaling).
    """
    pl.subplot(pos)
    pl.title(title)
    pl.imshow(mat, vmin=vmin, vmax=vmax)
    bar = pl.colorbar(shrink=.75)
    bar.set_ticks(ticks)
    bar.set_ticklabels(ticks)
    return bar


# A: initial weights (only this panel carries axis labels)
clb = _weight_panel_(131, 'Initial Weights (W0)', W0, [-4, 0, .5])
pl.xlabel('Post-synaptic #')
pl.ylabel('Pre-synaptic #')

# B: final weights
clb = _weight_panel_(132, 'Final Weights (Wf)', Wf,
                     [-5, -4, -3, -2, -1, 0, 1, 2], vmin=-5, vmax=2)

# C: difference between final and initial weights
clb = _weight_panel_(133, 'Weight Changes (Wf - W0)', Wf - W0,
                     [-1.5, -1, -.5, 0, .5, 1, 1.5], vmin=-1.5, vmax=1.5)

pl.savefig('Fig3A')

###

## bid and fs
# Pairwise quantities over the first ne x ne block of the weight matrices
# (neurons with index < ne), flattened row by row so that entry ii*ne + jj
# corresponds to the pair (ii, jj).
dW = Wf - W0

W0_po = np.asarray(W0)[:ne, :ne].flatten()
Wf_po = np.asarray(Wf)[:ne, :ne].flatten()
# Weight change per pair; this array was previously declared but never filled.
dw_po = np.asarray(dW)[:ne, :ne].flatten()

# dpo[ii*ne + jj] = po_init[ii] - po_init[jj]
_po_sub = np.asarray(po_init)[:ne]
dpo = (_po_sub[:, None] - _po_sub[None, :]).flatten()

# Index sets for three bins of |dpo|. Each bin is the union of a band near 0
# and the mirrored band near pi. All bounds are strict, so a value lying
# exactly on a bin edge belongs to no bin.
_adpo = np.abs(dpo)
dpo_id1 = np.where((_adpo < np.pi/6) | (_adpo > np.pi-np.pi/6))
dpo_id2 = np.where(((_adpo < 2*np.pi/6) & (_adpo > np.pi/6))
                   | ((_adpo > np.pi-2*np.pi/6) & (_adpo < np.pi-np.pi/6)))
dpo_id3 = np.where(((_adpo < 3*np.pi/6) & (_adpo > 2*np.pi/6))
                   | ((_adpo > np.pi-3*np.pi/6) & (_adpo < np.pi-2*np.pi/6)))

################################################################################
### Figure 3 (D-G)
###########

# New figure holding the four panels saved as 'Fig3B' further down.
pl.figure(figsize=(14,5))

#pl.title('aligned weights')
def _plot_alignw_(rng1=range(0,ne), rng2=range(ne,n)):
    """Scatter weights against preferred-orientation differences.

    For every row index in rng1, the entries of that row at the column
    indices rng2 are plotted against po_init[row] - po_init[rng2] on the
    module-level axes `ax`. Final weights (Wf) are drawn first as black dots,
    then initial weights (W0) as red dots on top of them.
    """
    # Pass 1: final weights.
    for row in rng1:
        delta_po = po_init[row] - po_init[rng2]
        ax.plot(delta_po, Wf[row][rng2], 'k.', ms=mksz, alpha=.5)

    # Pass 2: initial weights. The first row uses small, fully opaque dots;
    # all other rows use the common marker size at half opacity.
    for row in rng1:
        delta_po = po_init[row] - po_init[rng2]
        if row == rng1[0]:
            size, opacity = 1, 1.
        else:
            size, opacity = mksz, .5
        ax.plot(delta_po, W0[row][rng2], 'r.', ms=size, alpha=opacity)


# x axis shared by the three scatter panels: ticks at multiples of pi/2,
# labelled -90 / 0 / 90 (the outer two ticks are left blank).
_dpo_ticks = [-np.pi, -np.pi/2, 0, np.pi/2, np.pi]
_dpo_labels = ['', -90, 0, 90, '']

_exc_ids, _inh_ids = range(0, ne), range(ne, n)
_exc_yticks = [0, .5, 1, 1.5, 2]
_inh_yticks = [-6, -5.5, -5, -4.5, -4, -3.5, -3]

# (subplot, title, rng1, rng2, y ticks, y limits)
_panels = (
    (141, 'Exc to Exc', _exc_ids, _exc_ids, _exc_yticks, [0-.1, 2+.1]),
    (142, 'Exc to Inh', _exc_ids, _inh_ids, _exc_yticks, [0-.1, 2+.1]),
    (143, 'Inh to Exc', _inh_ids, _exc_ids, _inh_yticks, [-5.5-.1, -3+.1]),
)

for _pos, _title, _rows, _cols, _yticks, _ylim in _panels:
    # `ax` must be the module-level name: _plot_alignw_ draws on it.
    ax = pl.subplot(_pos)
    pl.title(_title)
    _plot_alignw_(rng1=_rows, rng2=_cols)
    ax.set_xticks(_dpo_ticks)
    ax.set_xticklabels(_dpo_labels)
    ax.set_yticks(_yticks)
    ax.set_yticklabels(_yticks)
    ax.set_ylim(_ylim)
    ax.set_ylabel('Final Weights (mV)')
    if _pos == 141:
        # Only the first panel carries the x-axis label.
        ax.set_xlabel('Pre-syn. PO - Post-syn. PO (deg)')

# Fourth panel: mean weight per PO-difference bin (filled in below).
ax = pl.subplot(144)
pl.title('Connection Specificity')

# x positions of the three bins
xrng = [1, 2, 3]

def _part_dpo_(www, cl='k', lbl=[]):
    """Plot the mean of `www` within each of the three PO-difference bins.

    Parameters
    ----------
    www : array
        Values to average; indexed with the module-level index sets
        dpo_id1, dpo_id2 and dpo_id3.
    cl : matplotlib color, optional
        Line and marker colour.
    lbl : optional
        Forwarded unchanged as the line's `label`; it is never mutated.

    The three means are drawn as a connected line with round markers on the
    module-level axes `ax`, at the x positions given by `xrng`.
    """
    # The standard deviations computed here before were never used and have
    # been dropped; only the bin means are plotted.
    ymrng = np.array([np.mean(www[sel]) for sel in (dpo_id1, dpo_id2, dpo_id3)])

    ax.plot(xrng, ymrng, '-o', lw=2, color=cl, label=lbl)

# Mean weight per bin: initial weights in red, final weights in the default
# colour of _part_dpo_.
_part_dpo_(W0_po, cl='r', lbl='Initial')
_part_dpo_(Wf_po, lbl='Final')

# `ax` (subplot 144) is still the current axes here.
ax.legend(title='Exc to Exc', frameon=False, numpoints=1)

# x axis: one tick per bin
ax.set_xlim(0, 4)
ax.set_xticks([1, 2, 3])
ax.set_xticklabels(['0-30', '30-60', '60-90'], rotation=0)

# y axis
_avg_ticks = [0, .2, .4, .6, .8, 1]
ax.set_yticks(_avg_ticks)
ax.set_yticklabels(_avg_ticks)
ax.set_ylim(0, 1)

ax.set_xlabel('dpo range (deg)')
ax.set_ylabel('Average Weight (mV)')

pl.subplots_adjust(left=.05, right=.97, bottom=.15, top=.925, wspace=.45)

pl.savefig('Fig3B')

pl.show()

# ----------