"""
Copyright 2019 Toshitake Asabuki

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# coding: UTF-8
from __future__ import division
import numpy as np
import pylab as pl
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy.matlib
import os
import shutil
import sys
import matplotlib.cm as cm
import sklearn.decomposition
from mpl_toolkits.mplot3d import Axes3D
from scipy import stats

# Global Matplotlib output settings: keep text as real text in SVG output,
# use Arial as the sans-serif font, and embed TrueType (Type 42) fonts in PDFs.
mpl.rcParams.update({
    'svg.fonttype': 'none',
    'font.sans-serif': 'Arial',
    'pdf.fonttype': 42,
})
# Figure-style preset for small (10 cm x 6 cm) figures with 11 pt text.
# NOTE: this dict is never applied in this script (there is no
# ``rcParams.update(params)`` call). 'font.size' replaces the legacy
# 'text.fontsize' key, which current Matplotlib rejects as an unknown rcParam.
params = {'backend': 'ps',
    'axes.labelsize': 11,
    'font.size': 11,
    'legend.fontsize': 11,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'text.usetex': False,
    'figure.figsize': [10 / 2.54, 6 / 2.54]}



beta = 5  # gain of the sigmoid g; the learning rule below reuses it


def g(x):
    """Logistic activation ``1 / (1 + alpha * exp(beta * (theta - x)))``.

    Uses the module-level gain ``beta`` with fixed ``alpha = 1`` and
    threshold ``theta = 0.5``, so ``g(0.5) == 0.5``. Applies element-wise
    to NumPy arrays as well as to scalars.
    """
    alpha, theta = 1, 0.5
    exponent = beta * (theta - x)
    return 1 / (1 + alpha * np.exp(exponent))

# --- Network and simulation parameters ---------------------------------------
N = 10  # number of output neurons (columns of w, rows/columns of w_inh)
width = 30  # time steps each symbol is presented before the next one
dt = 1  # integration step; the plots label time in ms

# Length of the learning run, in time steps.
nsecs = width *10000
simtime = np.arange(0, nsecs, dt)
simtime_len = len(simtime)

# Membrane (tau) and synaptic (tau_syn) time constants, same time unit as dt.
tau =15
tau_syn = 5
n_syn = 1000  # synapses per neuron
n_in =  n_syn  # number of input channels (equal to n_syn, see connection_list)
PSP = np.zeros(n_in)  # filtered input trace, one per input channel
I_syn = np.zeros(n_in)  # synaptic current, one per input channel
# g_L (= 1/tau) and the dendro-somatic coupling g_d both enter the dendritic
# prediction factor g_d/(g_d+g_L) used by the learning rule; g_d also couples
# V_dend into V_som.
g_L=1/tau
g_d=0.7
# Dendritic weights: w[j, k] weights the j-th input connected to neuron k
# (see PSP_mat / V_dend in the loops); initialized from N(0, 1/n_syn).
w  = np.random.randn(n_syn,N)/np.sqrt(n_syn)*1

# Input rates (converted to a per-step probability via dt*rate*1e-3, i.e. Hz
# if dt is 1 ms): inputs tuned to the current symbol fire at poisson_signal,
# all others at poisson_noise.
poisson_signal =10
poisson_noise = 0.
# Learning rate of the dendritic weights w (also scales their decay).
eps = 10**(-4)#/poisson_signal

# Inhibitory plasticity traces: each output spike adds cA_p (< 0) to a fast
# trace decaying with tau_p and cA_d (> 0) to a slower trace decaying with
# tau_d; the trace sums drive the w_inh updates in the learning loop.
cA_p = -1*0.105#*0.1
cA_d = 0.525*0.1#*0.1
tau_p=20
tau_d=40
A_pre_p=np.zeros(N)
A_pre_d=np.zeros(N)
A_post_p=np.zeros(N)
A_post_d=np.zeros(N)

p_connect = 1  # probability of each non-self inhibitory connection
max_rate = poisson_signal  # scales activity f into a per-step output spike probability
# Upper bound of every inhibitory weight (weights are clipped to [0, w_inh_max]).
w_inh_max = 0.1#5/np.sqrt(N)#/max_rate

spike_time = np.zeros(N)  # step of each neuron's latest output spike (written during learning)

# Lateral inhibition: w_inh[k, j] is subtracted as w_inh[k, j] * f[j] from
# V_som[k], i.e. row = postsynaptic neuron, column = presynaptic neuron.
w_inh =np.ones((N,N))*w_inh_max

# Connectivity mask: no self-inhibition; every other pair is connected with
# probability p_connect (all of them when p_connect == 1).
mask = np.zeros((N,N))
for i in range(N):
    for j in range(N):
        if i!=j:
            if np.random.rand()<p_connect:
                mask[i,j]=1
w_inh*=mask
# The clipping below is a no-op at initialization (entries are 0 or
# w_inh_max); the same statements re-impose the bounds after every learning step.
w_inh[w_inh<0] = 0
w_inh[w_inh>w_inh_max] = w_inh_max



# --- State variables ---------------------------------------------------------
# Sliding history of the last window*width somatic potentials per neuron; the
# learning loop z-scores V_som against it before applying g.
window = 300
V_som_list=np.zeros((N,window*width))

V_dend = np.zeros(N)  # dendritic potentials

V_som = np.zeros(N)  # somatic potentials

# connection_list[k, i] is True iff input i projects to neuron k. Because
# n_syn == n_in, each neuron samples all inputs, so every entry is True.
connection_list = np.zeros((N,n_in),dtype=bool)
for i in range(N):
    connection_list[i,np.random.choice(np.arange(n_in), n_syn, replace = False)]=1

f = np.zeros(N)  # output activity of each neuron (output of g)


# Task definition: the input stream is a random sequence of chunks; each chunk
# is a fixed sequence of symbols, each symbol presented for ``width`` steps.
chunk_list = [['a', 'b', 'c', 'd'],['e', 'f', 'g', 'h'],['i', 'j', 'k', 'l']]
n_chunk = len(chunk_list)

# All distinct symbols in order of first appearance. Derived from chunk_list
# (instead of a second hand-written list) so the two cannot drift apart.
sym_list = []
for syms in chunk_list:
    for sym in syms:
        if sym not in sym_list:
            sym_list.append(sym)

# Chunk presented first during learning; m indexes the current symbol in it.
chunk = chunk_list[np.random.randint(n_chunk)]

sample_num = 10
# Steps spanned by sample_num consecutive chunks of len(chunk_list[0])
# symbols; the test run has the same length.
sample_len = width * len(chunk_list[0]) * sample_num
test_len = sample_len

m = 0

# Per-neuron view of the input PSPs, refilled every time step.
PSP_mat = np.zeros((N, n_syn))
# symbol_pat[i, s] is True iff input i is driven by symbol sym_list[s]; each
# input is tuned to exactly one randomly chosen symbol.
symbol_pat = np.zeros((n_in, len(sym_list)), dtype=bool)
for i in range(n_in):
    symbol_pat[i, np.random.choice(np.arange(len(sym_list)), 1, replace=False)] = 1
print("")
print("***********")
print("Learning... ")
print("***********")

# Learning phase: present a random chunk sequence for simtime_len steps while
# updating the dendritic weights w and the lateral inhibitory weights w_inh.
for i in range(simtime_len):
    # Every width/dt steps move to the next symbol; after the last symbol of
    # the current chunk, draw the next chunk uniformly at random.
    if (i % (width/dt) == 0 and i > 0):

        if m == len(chunk) - 1:
            chunk = chunk_list[np.random.randint(n_chunk)]
        
            m = 0
        
        else:
            
            m += 1

    # Inputs tuned to the current symbol fire at poisson_signal, the others at
    # poisson_noise; prate is the resulting per-step spike probability.
    rate_in = np.ones(n_in)*poisson_noise

    input_id =symbol_pat[:,sym_list.index(chunk[m])]
    
    rate_in[input_id] = poisson_signal
    prate = dt*rate_in*(10**-3)
    
    # Progress report whenever the completed percentage reaches a new multiple of 5.
    if int(i / simtime_len * 100) % 5 == 0.0:
        if int(i / simtime_len * 100) > int((i - 1) / simtime_len * 100):
            print(" " + str(int(i / simtime_len * 100)) + "% ")

    # Bernoulli input spikes for this step (NOTE: ``id`` shadows the builtin).
    # Each spike adds 1/tau/tau_syn to the decaying current I_syn, which is
    # then low-pass filtered into PSP.
    id = np.random.rand(n_in)<prate
    I_syn = (1.0 - dt / tau_syn) * I_syn
    I_syn[id]+=1/tau/tau_syn
    PSP = (1.0 - dt / tau) * PSP + I_syn
    # Fixed gain of 25 on the PSPs (its origin is not documented in this file).
    PSP_unit=PSP*25
    
    # Row k of PSP_mat holds the PSPs of the inputs connected to neuron k.
    for k in range(N):
        PSP_mat[k,:] = PSP_unit[connection_list[k]]
        
    # V_dend[k] = PSP_mat[k] . w[:, k]. NOTE(review): this forms the full N x N
    # product and keeps only its diagonal; np.einsum('ij,ji->i', PSP_mat, w)
    # would compute only the needed entries.
    V_dend = np.diag(np.dot(PSP_mat,w))

    # Shift the history left by one column; the newest V_som goes into the
    # last column below. NOTE(review): np.roll copies the whole
    # (N, window*width) buffer on every step.
    V_som_list = np.roll(V_som_list, -1,axis=1)
    # Leaky somatic integration, driven by the dendrite through g_d and
    # inhibited by the other neurons' activity f from the previous step.
    for k in range(N):
        V_som[k] = (1.0-dt/tau)*V_som[k] +g_d*(V_dend[k]-V_som[k])+np.dot(-w_inh[k,:],f)

    V_som_list[:,-1] = V_som
    
    # Activity and w-plasticity start once the history window has been filled.
    # Until then f keeps its initial zeros, so no output spikes (and thus no
    # inhibition and no w_inh changes) occur either.
    if i>width*window:
        
        for k in range(N):
            # Output activity: g of the somatic potential z-scored against its
            # recent history.
            f[k]=g((V_som[k]-np.mean(V_som_list[k,:])) / np.std(V_som_list[k,:]))#*max_rate

   
            # Move the dendritic prediction g(V_dend*g_d/(g_d+g_L)) toward the
            # somatic output f[k]; the update is scaled by the input PSPs and
            # by beta*(1 - g(...)).
            w[:,k] += eps  *(f[k]-g(V_dend[k]*g_d/(g_d+g_L))*1) * PSP_unit[connection_list[k]]*beta*(1-g(V_dend[k]*g_d/(g_d+g_L)))

        # Multiplicative decay of all dendritic weights.
        w-=eps*w*0.05
    # Exponential decay of the inhibitory-plasticity traces.
    A_pre_p = (1.0 - dt / tau_p) * A_pre_p
    A_pre_d = (1.0 - dt / tau_d) * A_pre_d
    A_post_p = (1.0 - dt / tau_p) * A_post_p
    A_post_d = (1.0 - dt / tau_d) * A_post_d

    # Neuron k emits an output spike with probability dt*f[k]*max_rate*1e-3.
    # On a spike: bump k's traces, then update the inhibition onto k (row k)
    # from the presynaptic traces and the inhibition from k (column k) from the
    # postsynaptic traces. The pre and post traces receive identical increments
    # and decay, so they are always equal.
    for k in range(N):
        if np.random.rand()<dt*f[k]*max_rate*(10**-3):
     
                A_pre_p[k]+=cA_p
                A_pre_d[k]+=cA_d
                A_post_p[k]+=cA_p
                A_post_d[k]+=cA_d
                
                w_inh[k,:]+=(A_pre_p+A_pre_d)*w_inh_max
                w_inh[:,k]+=(A_post_p+A_post_d)*w_inh_max
                spike_time[k]=i

    # Keep w_inh on the connectivity mask and within [0, w_inh_max].
    w_inh*=mask
    w_inh[w_inh<0] = 0
    w_inh[w_inh>w_inh_max] = w_inh_max


print("")
print("***********")
print("Testing... ")
print("***********")

# Testing phase: run the trained network (w and w_inh are frozen) on a fresh
# random chunk sequence of test_len steps and record its activity.
# The network state is reset so the run does not start from the final
# training state; f is reset too, so the first step is not inhibited by
# activity left over from training.
PSP = np.zeros(n_in)
I_syn = np.zeros(n_in)

V_dend_list = np.zeros((N, test_len))

V_dend = np.zeros(N)

V_som = np.zeros(N)
f = np.zeros(N)

chunk_id = np.random.randint(n_chunk)
chunk = chunk_list[chunk_id]

m = 0

f_list = np.zeros((N, test_len))

# chunk_start[c] collects the time steps at which an occurrence of chunk c starts.
chunk_start = [[] for _ in range(n_chunk)]

chunk_start[chunk_id].append(0)

# Input spike raster: input_spikes[t, i] is True iff input i spiked at step t.
input_spikes = np.zeros((test_len, n_in), dtype=bool)

for i in range(test_len):
    # Advance to the next symbol every width/dt steps; draw a new chunk after
    # the last symbol of the current one and record where it starts.
    if i % (width / dt) == 0 and i > 0:
        if m == len(chunk) - 1:
            chunk_id = np.random.randint(n_chunk)
            chunk = chunk_list[chunk_id]
            chunk_start[chunk_id].append(i)
            m = 0
        else:
            m += 1

    rate_in = np.ones(n_in) * poisson_noise

    input_id = symbol_pat[:, sym_list.index(chunk[m])]
    rate_in[input_id] = poisson_signal
    prate = dt * rate_in * (10**-3)

    input_spikes[i, :] = np.random.rand(n_in) < prate

    I_syn = (1.0 - dt / tau_syn) * I_syn
    I_syn[input_spikes[i, :]] += 1 / tau / tau_syn
    PSP = (1.0 - dt / tau) * PSP + I_syn
    PSP_unit = PSP * 25
    for k in range(N):
        PSP_mat[k, :] = PSP_unit[connection_list[k]]

    # V_dend[k] = PSP_mat[k] . w[:, k]
    V_dend = np.diag(np.dot(PSP_mat, w))

    # Somatic update uses the activity f of the previous step for inhibition,
    # so all V_som entries are updated before f is recomputed.
    for k in range(N):
        V_som[k] = (1.0 - dt / tau) * V_som[k] + g_d * (V_dend[k] - V_som[k]) + np.dot(-w_inh[k, :], f)
    V_dend_list[:, i] = V_dend
    # NOTE(review): learning applies g to the z-scored somatic potential,
    # whereas the raw V_som is used here -- confirm this is intended.
    for k in range(N):
        f[k] = g(V_som[k])

    f_list[:, i] = f

chunk1_start = np.array(chunk_start[0])
chunk2_start = np.array(chunk_start[1])
chunk3_start = np.array(chunk_start[2])
# (time step, input index) of every input spike in the test run.
tspk, nspk = np.nonzero(input_spikes)



###################
##
##  Plotting
##
###################



# Activity map of the first sample_len test steps. Each neuron's trace is
# min-max normalized to [0, 1]. Neurons are ordered by the phase of their
# activity-weighted circular mean over that window (when they tend to be
# active), and the result is drawn as a heat map.
sample_activity = f_list[:, :sample_len]
max1 = sample_activity.max(axis=1)
min1 = sample_activity.min(axis=1)
activity_range = max1 - min1
# A neuron with constant activity over the window would give 0/0 = NaN; use a
# unit range instead so its normalized trace is all zeros.
activity_range[activity_range == 0] = 1
avg_norm1 = (sample_activity - min1[:, np.newaxis]) / activity_range[:, np.newaxis]

# Map each time step onto the unit circle (one full turn over the window) and
# take the phase of the activity-weighted sum. Dividing by the non-negative
# total activity would not change the phase, so it is skipped; this also keeps
# all-zero rows at phase 0 instead of NaN. np.mod maps the phase to [0, 2*pi).
phasors = np.exp(np.arange(sample_len) / sample_len * 2 * np.pi * 1j)
phase = np.mod(np.angle(avg_norm1 @ phasors), 2 * np.pi)
t = sample_len / (2 * np.pi) * phase

index = np.argsort(t)
avg_sorted = avg_norm1[index, :]

fig, ax = plt.subplots(figsize=(4, 3))

cax = plt.imshow(avg_sorted, interpolation='nearest', aspect="auto", cmap='jet')

cbar = fig.colorbar(cax, ticks=[0, 1], orientation='vertical')
cbar.ax.set_yticklabels(['min', 'max'], fontsize=10)

plt.xlabel("Time [ms]", fontsize=10)
plt.ylabel("Neurons (sorted)", fontsize=10)
plt.yticks([0, N - 1], ["1", "%d" % N], fontsize=10)
ax.tick_params(length=1.3, width=0.05, labelsize=10)
ax.xaxis.set_ticks_position('none')
ax.yaxis.set_ticks_position('none')
plt.ylim([-0.5, N - 0.5])
plt.xlim([0, sample_len])

fig.subplots_adjust(left=0.15, bottom=0.25, right=1)
# Grey separators every width*len(chunk_list[0]) steps: these are the chunk
# boundaries, since all chunks in chunk_list have that length. ymin/ymax are
# axes fractions, so 0..1 spans the full height.
chunk_steps = width * len(chunk_list[0])
for l in range(sample_num):
    ax.axvline(x=chunk_steps * (l + 1), ymin=0, ymax=1, color='gray', linewidth=1)

plt.savefig('activity_map.pdf', format='pdf', dpi=350)


# Inhibitory weight matrix with both rows (postsynaptic neurons) and columns
# (presynaptic neurons) reordered by the peak-time ordering ``index``.
weight_sorted_row = w_inh[index, :]
weight_sorted_column = weight_sorted_row[:, index]
fig, ax = plt.subplots(figsize=(4, 3))

cax = plt.imshow(weight_sorted_column, interpolation='nearest', aspect="auto", cmap='jet')

cbar = fig.colorbar(cax, orientation='vertical')

plt.xlabel("Presynaptic neuron", fontsize=10)
plt.ylabel("Postsynaptic neuron", fontsize=10)

ax.tick_params(length=1.3, width=0.05, labelsize=11)
ax.xaxis.set_ticks_position('none')
ax.yaxis.set_ticks_position('none')
# One tick per neuron, labelled 1..N (works for any network size N).
neuron_ticks = list(range(N))
neuron_labels = [str(k + 1) for k in neuron_ticks]
plt.xticks(neuron_ticks, neuron_labels, fontsize=11)
plt.yticks(neuron_ticks, neuron_labels, fontsize=11)
fig.subplots_adjust(left=0.15, bottom=0.25, right=0.8)

plt.savefig('Winh_map.pdf', format='pdf', dpi=350)



def _plot_activity_panel(axis, trace, show_xticks):
    """Plot one neuron's test activity trace with chunk occurrences shaded.

    Occurrences of chunk_list[0], [1] and [2] (start steps in chunk1_start,
    chunk2_start, chunk3_start) are shaded orange-red, lime-green and
    dodger-blue respectively; each span covers width * len(chunk) steps.

    Args:
        axis: Matplotlib axes to draw into.
        trace: 1-D activity trace over the test run (length test_len).
        show_xticks: when False the x ticks are removed (all but the bottom panel).
    """
    axis.plot(trace, lw=1.5, c='k')
    shading = zip((chunk1_start, chunk2_start, chunk3_start),
                  chunk_list,
                  ('orangered', 'limegreen', 'dodgerblue'))
    for starts, chunk_syms, color in shading:
        span_len = width * len(chunk_syms)
        for start in starts:
            axis.axvspan(start, start + span_len, facecolor=color, alpha=0.3, linewidth=0)
    axis.set_xlim([0, test_len])
    axis.set_ylim([-0.01, 1.01])
    axis.set_yticks(np.arange(0, 1.01, 0.5))
    axis.xaxis.set_ticks_position('bottom')
    axis.yaxis.set_ticks_position('left')
    axis.spines['right'].set_color('none')
    axis.spines['top'].set_color('none')
    if not show_xticks:
        axis.xaxis.set_major_locator(pl.NullLocator())


fig = plt.figure(figsize=(8, 3))
# Invisible full-figure axes that only carries the shared "Activity" y label.
ax = fig.add_subplot(111)
for side in ('top', 'bottom', 'left', 'right'):
    ax.spines[side].set_color('none')
# Booleans, not the strings 'off': any non-empty string is truthy, so 'off'
# would leave the tick marks visible in current Matplotlib.
ax.tick_params(labelcolor='w', top=False, bottom=False, left=False, right=False)

# Three stacked panels: the earliest-, middle- and latest-peaking neuron in the
# peak-time order ``index`` (ranks 0, 4 and 9 for N == 10).
panel_ranks = (0, (N - 1) // 2, N - 1)
for row, rank in enumerate(panel_ranks):
    panel = fig.add_subplot(len(panel_ranks), 1, row + 1)
    _plot_activity_panel(panel, f_list[int(index[rank]), :],
                         show_xticks=(row == len(panel_ranks) - 1))

panel.set_xlabel("Time [ms]", fontsize=11)
ax.set_ylabel("Activity", fontsize=11)

fig.subplots_adjust(bottom=0.15, left=0.1, right=0.95, hspace=0.3)

plt.savefig('activities.pdf', format='pdf', dpi=350)