"""
Copyright 2019 Toshitake Asabuki

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
import os, sys
import matplotlib.pyplot as plt
import pylab as pl
import matplotlib as mpl
import shutil


# Global font handling for exported figures: keep text as text in SVG and
# embed TrueType (type 42) fonts in PDF output.
mpl.rcParams['svg.fonttype'] = 'none'
mpl.rcParams['font.sans-serif'] = 'Arial'
mpl.rcParams['pdf.fonttype'] = 42
# Figure-style presets (sizes in points; figsize in inches from cm).
# NOTE: this dict is not applied anywhere in this file (there is no
# rcParams.update(params)). 'font.size' replaces the removed 'text.fontsize'
# key, which would make mpl.rcParams.update(params) raise KeyError.
params = {'backend': 'ps',
    'axes.labelsize': 11,
    'font.size': 11,
    'legend.fontsize': 11,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'text.usetex': False,
    'figure.figsize': [10 / 2.54, 6 / 2.54]}

beta = 5  # slope of the nonlinearity g; also scales the weight update


def g(x):
    """Return the logistic nonlinearity 1 / (1 + alpha*exp(beta*(theta - x))).

    Works element-wise on numpy arrays as well as on scalars. ``alpha`` = 1
    and the midpoint ``theta`` = 0.5 are fixed; the slope is the module-level
    ``beta``.
    """
    theta = 0.5  # input value at which g(x) == 0.5
    alpha = 1
    denominator = 1 + alpha * np.exp(beta * (theta - x))
    return 1 / denominator


# ---------------------------------------------------------------------------
# Model and stimulus parameters. Rates are in Hz and dt in ms: the per-step
# spike probability used throughout is rate * dt * 1e-3.
# ---------------------------------------------------------------------------
width = 50  # length of one input pattern, in time steps


mean_rate = 5  # firing rate of every input neuron [Hz]
r_sig = mean_rate
r_noise = mean_rate - r_sig
gain = 30  # output spike rate [Hz] corresponding to f == 1
n_in = 500  # number of input neurons
dt = 1  # simulation time step [ms]

# Three frozen input patterns: each input neuron fires independently with
# probability r_sig*dt*1e-3 in every one of the `width` time bins.
p_sig = r_sig * dt * 10 ** (-3)
pat1 = np.random.rand(n_in, width) < p_sig
pat2 = np.random.rand(n_in, width) < p_sig
pat3 = np.random.rand(n_in, width) < p_sig

N = 10  # number of output neurons

nsecs = width * 15000
simtime = np.arange(0, nsecs, dt)
simtime_len = len(simtime)

tau = 15  # membrane time constant
tau_syn = 5  # synaptic current time constant
n_syn = n_in

PSP = np.zeros(n_in)
I_syn = np.zeros(n_in)
g_L = 1 / tau  # leak conductance
g_d = 0.7  # dendrite-to-soma coupling
w = np.random.randn(n_syn, N) / np.sqrt(n_syn)  # feed-forward weights (n_in, N)

eps = 10 ** (-6)  # learning rate for w

p_connect = 1  # connection probability of the lateral inhibition

w_inh_max = 0.1  # upper bound of each inhibitory weight

w_inh = np.ones((N, N)) * w_inh_max
spike_time = np.zeros(N)

# Random lateral connectivity with probability p_connect and no self-connections.
mask = ((np.random.rand(N, N) < p_connect) & ~np.eye(N, dtype=bool)).astype(float)
w_inh *= mask
w_inh[w_inh < 0] = 0
w_inh[w_inh > w_inh_max] = w_inh_max

# Rolling buffer of the last window*width somatic potentials per neuron; the
# learning loop standardises V_som against it.
window = 300
V_som_list = np.zeros((N, window * width))

V_dend = np.zeros(N)

V_som = np.zeros(N)

# With n_syn == n_in every neuron receives every input.
connection_list = np.zeros((N, n_in), dtype=bool)
for i in range(N):
    connection_list[i, np.random.choice(np.arange(n_in), n_syn, replace=False)] = 1

f = np.zeros(N)  # output activity of each neuron

# Amplitudes and time constants of the traces driving inhibitory plasticity.
cA_p = -1 * 0.105 * 0.1
cA_d = 0.525 * 0.1 * 0.1
tau_p = 20
tau_d = 40
A_pre_p = np.zeros(N)
A_pre_d = np.zeros(N)
A_post_p = np.zeros(N)
A_post_d = np.zeros(N)
# Input spikes of the current time step. The learning loop overwrites this
# with a fresh (n_in,) vector on every step, so a 1-D buffer is enough. The
# previous (n_in, simtime_len) allocation was ~375 MB and never read.
spike_mat = np.zeros(n_in, dtype=bool)

random_start = 0



print("")
print("***********")
print("Learning... ")
print("***********")
# Learning phase: stream random input segments interleaved with the three
# frozen patterns. Update the feed-forward weights w from the mismatch between
# the somatic output f and the dendritic prediction. Update the lateral
# inhibitory weights w_inh at the output spikes, using decaying traces.
for i in range(simtime_len):
    
    # Print progress each time the integer percentage crosses a multiple of 5.
    if int(i / simtime_len * 100) % 5 == 0.0:
        if int(i / simtime_len * 100) > int((i - 1) / simtime_len * 100):
            print(" " + str(int(i / simtime_len * 100)) + "% ")
    # Start of a random segment lasting random_width steps, in [width, 3*width).
    # NOTE(review): `type` shadows the builtin of the same name at module level.
    if i == random_start:
        type='random'

        random_width=np.random.randint(1*width,3*width)
        pat_start=i+random_width

    # Start of a pattern segment. One of the three patterns is chosen with equal
    # probability, and the next random segment begins `width` steps later.
    if i==pat_start:

        random_start=i+width
        dice=np.random.rand()

        if dice<1/3:

            type='pat1'
        
        elif dice<2/3:

            type='pat2'

        else:

            type='pat3'
    # Input spikes for this step. In random segments they are fresh Bernoulli
    # draws with probability mean_rate*dt*1e-3. In pattern segments they are the
    # column of the frozen pattern at offset i - pat_start.
    if type=='random':
        spike_mat = np.zeros(n_in,dtype=bool)
        spike_mat[np.random.rand(n_in)<mean_rate*dt*10**(-3)]=1
    elif type=='pat1':
        spike_mat =pat1[:,i-pat_start]
    elif type=='pat2':
        spike_mat =pat2[:,i-pat_start]

    elif type=='pat3':
        spike_mat =pat3[:,i-pat_start]
    # Two-stage exponential filtering of the input spikes: synaptic current
    # (tau_syn), then postsynaptic potential (tau).
    I_syn = (1.0 - dt / tau_syn) * I_syn
    I_syn[spike_mat]+=1/tau/tau_syn
    PSP = (1.0 - dt / tau) * PSP + I_syn
    PSP_unit=PSP*25
    # Dendritic potential of each output neuron: weighted sum of input PSPs.
    V_dend = np.dot(w.T,PSP_unit)
    # Shift the history buffer one step left so the newest V_som goes last.
    V_som_list = np.roll(V_som_list, -1,axis=1)

    # Somatic potential: leak (g_L), coupling to the dendrite (g_d), and lateral
    # inhibition np.dot(-w_inh, f). Row k of w_inh therefore holds the weights
    # onto neuron k.
    V_som = (1.0-dt*g_L)*V_som +g_d*(V_dend-V_som)+np.dot(-w_inh,f)
    V_som_list[:,-1] = V_som

    # Only after the history buffer has filled (f keeps its initial value
    # until then):
    if i>width*window:

        # Output activity is g of V_som standardised by its recent mean/std.
        f = g((V_som-np.mean(V_som_list,axis=1)) / np.std(V_som_list,axis=1))
       
        # Move w by eps*beta*(f - p)*(1 - p)*PSP_unit per synapse. Here
        # p = g(V_dend*g_d/(g_d+g_L)) is the dendritic prediction of f. The line
        # after applies a weight decay of eps*5*w.
        w += eps  *np.outer((f-g(V_dend*g_d/(g_d+g_L))) , PSP_unit).T*beta*(1-g(V_dend*g_d/(g_d+g_L)))
        w-=eps*w*5

    # Exponential decay of the plasticity traces. The A_pre_* and A_post_*
    # traces receive identical updates in this loop.
    A_pre_p = (1.0 - dt / tau_p) * A_pre_p
    A_pre_d = (1.0 - dt / tau_d) * A_pre_d
    A_post_p = (1.0 - dt / tau_p) * A_post_p
    A_post_d = (1.0 - dt / tau_d) * A_post_d

    # Output neuron k spikes with probability dt*f[k]*gain*1e-3 (gain in Hz).
    for k in range(N):
        if np.random.rand()<dt*f[k]*gain*(10**-3):

                A_pre_p[k]+=cA_p
                A_pre_d[k]+=cA_d
                A_post_p[k]+=cA_p
                A_post_d[k]+=cA_d
                
                # Weights onto k (row k) move by the traces of all neurons;
                # weights from k (column k) likewise. Both are scaled by w_inh_max.
                w_inh[k,:]+=(A_pre_p+A_pre_d)*w_inh_max
                w_inh[:,k]+=(A_post_p+A_post_d)*w_inh_max
                # Recorded but not read anywhere in this file.
                spike_time[k]=i



    # Re-apply the connectivity mask (no self-inhibition) and clip each weight
    # to [0, w_inh_max].
    w_inh*=mask
    w_inh[w_inh<0] = 0
    w_inh[w_inh>w_inh_max] = w_inh_max

print("")
print("***********")
print("Testing... ")
print("***********")

# Testing phase: build a fresh input sequence of random segments interleaved
# with the three learned patterns. Run the trained network on it with no
# weight updates, and record each neuron's output.
test_len = 60 * width
plot_len = 1500
p_spike = mean_rate * dt * 10 ** (-3)  # per-step spike probability of an input

spike_mat = np.zeros((n_in, test_len), dtype=bool)
PSP = np.zeros(n_in)
I_syn = np.zeros(n_in)
# Weighted PSP of every synapse; neuron l occupies rows l*n_in:(l+1)*n_in.
# (It was previously allocated twice, ~120 MB each.)
synaptic_input_matrix = np.zeros((n_in * N, test_len))

V_dend_list = np.zeros((N, test_len))

V_dend = np.zeros(N)

V_som = np.zeros(N)
f_list = np.zeros((N, test_len))
pat1_start = []
pat2_start = []
pat3_start = []
random_start = 0

# Input spikes split by segment type (used for the colour-coded raster).
random_mat = np.zeros((n_in, test_len))
pat1_mat = np.zeros((n_in, test_len))
pat2_mat = np.zeros((n_in, test_len))
pat3_mat = np.zeros((n_in, test_len))

# --- Build the input sequence ---------------------------------------------
for i in range(test_len):
    if i == random_start:
        # Random segment of random_width steps in [width, 3*width), clipped
        # at test_len.
        random_width = np.random.randint(1 * width, 3 * width)
        pat_start = i + random_width
        seg_end = min(i + random_width, test_len)
        spike_mat[:, i:seg_end] = np.random.rand(n_in, seg_end - i) < p_spike
        random_mat[:, i:seg_end] = spike_mat[:, i:seg_end]
    if i == pat_start:
        random_start = i + width
        # Make sure every pattern appears at least once (pat3 first, then
        # pat2, then pat1); after that choose uniformly among the three.
        if not pat3_start:
            p_pat1, p_pat2 = 0, 0
        elif not pat2_start:
            p_pat1, p_pat2 = 0, 1
        elif not pat1_start:
            p_pat1, p_pat2 = 1, 1
        else:
            p_pat1, p_pat2 = 1 / 3, 2 / 3
        dice = np.random.rand()
        seg_end = min(i + width, test_len)
        if dice < p_pat1:
            pat1_start.append(i)
            spike_mat[:, i:seg_end] = pat1[:, 0:seg_end - i]
        elif dice < p_pat2:
            pat2_start.append(i)
            spike_mat[:, i:seg_end] = pat2[:, 0:seg_end - i]
        else:
            pat3_start.append(i)
            spike_mat[:, i:seg_end] = pat3[:, 0:seg_end - i]

# Copy each pattern occurrence into its own matrix.
for starts, pat_mat in ((pat1_start, pat1_mat), (pat2_start, pat2_mat),
                        (pat3_start, pat3_mat)):
    for start in starts:
        stop = min(start + width, test_len)
        pat_mat[:, start:stop] = spike_mat[:, start:stop]

# --- Run the trained network (weights are not updated) -----------------------
for i in range(test_len):
    I_syn = (1.0 - dt / tau_syn) * I_syn
    I_syn[spike_mat[:, i]] += 1 / tau / tau_syn
    PSP = (1.0 - dt / tau) * PSP + I_syn
    PSP_unit = PSP * 25
    # Row l of (w.T * PSP_unit) is PSP_unit * w[:, l]; ravel stacks them.
    synaptic_input_matrix[:, i] = (w.T * PSP_unit).ravel()
    V_dend = np.dot(w.T, PSP_unit)

    V_som = (1.0 - dt * g_L) * V_som + g_d * (V_dend - V_som) + np.dot(-w_inh, f)
    # NOTE(review): unlike the learning phase, the output here is g(V_som)
    # without the running mean/std standardisation. Confirm this is intended.
    f = g(V_som)

    f_list[:, i] = f


# Raster of the first n_show input neurons during the test sequence, coloured
# by segment: black = random, blue/red/green = pattern 1/2/3.
n_show = 200
nspk_random, tspk_random = np.nonzero(random_mat[0:n_show, :] == 1)
nspk1, tspk1 = np.nonzero(pat1_mat[0:n_show, :] == 1)
nspk2, tspk2 = np.nonzero(pat2_mat[0:n_show, :] == 1)
nspk3, tspk3 = np.nonzero(pat3_mat[0:n_show, :] == 1)


fig = plt.figure(figsize=(7, 2))
ax = fig.add_subplot(111)
for starts, color in ((pat1_start, "dodgerblue"),
                      (pat2_start, "orangered"),
                      (pat3_start, "limegreen")):
    for i in starts:
        # Dashed lines mark pattern onset/offset; a thick bar above the raster
        # marks the pattern's extent.
        plt.vlines([i], 0, n_in, color, linestyles='dashed', lw=1)
        if i + width < plot_len:
            plt.vlines([i + width], 0, n_in, color, linestyles='dashed', lw=1)
        plt.hlines([n_show + 15], i, min(i + width, plot_len), color,
                   linestyles='solid', lw=3)
pl.plot(tspk_random, nspk_random, 'k.', markersize=2)
pl.plot(tspk1, nspk1, 'b.', markersize=2)
pl.plot(tspk2, nspk2, 'r.', markersize=2)
pl.plot(tspk3, nspk3, 'g.', markersize=2)
plt.ylabel("Neuron id", fontsize=11)

fig.subplots_adjust(bottom=0.25, left=0.1)
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')
pl.xlim([0, plot_len])
pl.ylim([0, n_show + 20])
# `format` (not `fmt`) is the savefig keyword; recent matplotlib rejects `fmt`.
plt.savefig('raster.pdf', format='pdf', dpi=350)


sample_len = test_len

# Min-max normalise each neuron's test-phase activity trace to [0, 1].
max1 = np.max(f_list[:, 0:sample_len], axis=1)
min1 = np.min(f_list[:, 0:sample_len], axis=1)
span = max1 - min1
span[span == 0] = 1  # a constant trace maps to all zeros instead of NaN
avg_norm1 = (f_list[:, 0:sample_len] - min1[:, None]) / span[:, None]

# Preferred time of each neuron: the angle of the activity-weighted mean of
# exp(2*pi*i*t/sample_len), mapped back into [0, sample_len).
phasor = np.exp(np.arange(sample_len) / (sample_len) * 2 * np.pi * 1j)
t = np.zeros(N)
for j in range(N):
    total = np.sum(avg_norm1[j, :])
    if total == 0:
        t[j] = np.inf  # flat trace has no preferred time: sort it last
        continue
    arg = np.angle(np.dot(avg_norm1[j, :], phasor) / total)
    if arg < 0:
        arg += 2 * np.pi
    t[j] = sample_len / (2 * np.pi) * arg

# Neuron order by preferred time; shared by both maps below.
index = np.argsort(t)
avg_sorted = avg_norm1[index, :]

fig, ax = plt.subplots(figsize=(4, 3))

cax = plt.imshow(avg_sorted, interpolation='nearest', aspect="auto", cmap='jet')

cbar = fig.colorbar(cax, ticks=[0, 1], orientation='vertical')
cbar.ax.set_yticklabels(['min', 'max'], fontsize=10)

plt.xlabel("Time [ms]", fontsize=10)
plt.ylabel("Neurons (sorted)", fontsize=10)
plt.yticks([0, N - 1], ["1", "%d" % N], fontsize=10)
ax.tick_params(length=1.3, width=0.05, labelsize=10)
ax.xaxis.set_ticks_position('none')
ax.yaxis.set_ticks_position('none')
plt.ylim([-0.5, N - 0.5])
pl.xlim([0, plot_len])

fig.subplots_adjust(left=0.15, bottom=0.25, right=1)

plt.savefig('activity_map.pdf', format='pdf', dpi=350)


# Lateral inhibitory weights with rows (postsynaptic) and columns
# (presynaptic) permuted into the same preferred-time order.
weight_sorted_row = w_inh[index, :]
weight_sorted_column = weight_sorted_row[:, index]
fig, ax = plt.subplots(figsize=(4, 3))

cax = plt.imshow(weight_sorted_column, interpolation='nearest', aspect="auto", cmap='jet')

cbar = fig.colorbar(cax, orientation='vertical')

plt.xlabel("Presynaptic neuron", fontsize=10)
plt.ylabel("Postsynaptic neuron", fontsize=10)

ax.tick_params(length=1.3, width=0.05, labelsize=11)
ax.xaxis.set_ticks_position('none')
ax.yaxis.set_ticks_position('none')
# One tick per neuron, labelled 1..N (previously hard-coded for N == 10).
neuron_ticks = np.arange(N)
neuron_labels = [str(k + 1) for k in range(N)]
plt.xticks(neuron_ticks, neuron_labels, fontsize=11)
plt.yticks(neuron_ticks, neuron_labels, fontsize=11)
fig.subplots_adjust(left=0.15, bottom=0.25, right=0.8)

plt.savefig('Winh_map.pdf', format='pdf', dpi=350)

# Binary target traces: 1 while the corresponding pattern is being presented.
target1 = np.zeros(test_len)
target2 = np.zeros(test_len)
target3 = np.zeros(test_len)
for i in (pat1_start):
    target1[i:min(test_len, i + width)] = 1
for i in (pat2_start):
    target2[i:min(test_len, i + width)] = 1
for i in (pat3_start):
    target3[i:min(test_len, i + width)] = 1

# Pearson correlation of every neuron's output with each pattern's target.
chunk_corr = np.zeros((N, 3))
for i in range(N):
    for p, target in enumerate((target1, target2, target3)):
        chunk_corr[i, p] = np.corrcoef(f_list[i, :], target)[0, 1]
# A constant output trace gives NaN from corrcoef (zero variance). Treat it as
# uncorrelated. Otherwise builtin max() becomes order-dependent and np.argmax
# below would select the silent neuron.
chunk_corr[np.isnan(chunk_corr)] = 0
# Performance: mean over neurons of each neuron's best pattern correlation.
performance = np.mean(np.max(chunk_corr, axis=1))
print(performance)

fig = plt.figure(figsize=(7, 3))
# Invisible full-figure axes behind the three panels.
ax = fig.add_subplot(111)
for side in ('top', 'bottom', 'left', 'right'):
    ax.spines[side].set_color('none')
# Booleans, not 'off' strings, which current matplotlib no longer accepts.
ax.tick_params(labelcolor='w', top=False, bottom=False, left=False, right=False)

# One panel per pattern: trace of the neuron best correlated with that
# pattern, with the pattern presentations shaded.
pattern_starts = (pat1_start, pat2_start, pat3_start)
pattern_colors = ('dodgerblue', 'orangered', 'limegreen')
pattern_axes = []
for p in range(3):
    axp = fig.add_subplot(3, 1, p + 1)
    pattern_axes.append(axp)
    pl.plot(f_list[np.argmax(chunk_corr[:, p]), :], lw=1.5, c='k')
    for i in pattern_starts[p]:
        pl.axvspan(i, i + width, facecolor=pattern_colors[p], alpha=0.3, linewidth=0)
    pl.xlim([0, plot_len])
    axp.xaxis.set_ticks_position('bottom')
    axp.yaxis.set_ticks_position('left')
    axp.spines['right'].set_color('none')
    axp.spines['top'].set_color('none')
    if p < 2:
        # Only the bottom panel keeps x tick labels.
        axp.xaxis.set_major_locator(pl.NullLocator())
ax1, ax2, ax3 = pattern_axes
plt.xlabel("Time [ms]", fontsize=11)

fig.subplots_adjust(bottom=0.15, left=0.1)
plt.savefig('activities.pdf', format='pdf', dpi=350)