"""
Copyright 2019 Toshitake Asabuki

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
import os, sys
import matplotlib.pyplot as plt
import pylab as pl
import matplotlib as mpl
import shutil
import copy


# Font handling for exported figures.  Per Matplotlib's rcParams documentation,
# 'svg.fonttype': 'none' keeps SVG text as text and 'pdf.fonttype': 42 selects
# TrueType fonts in PDF output.
for rc_key, rc_value in (
    ('svg.fonttype', 'none'),
    ('font.sans-serif', 'Arial'),
    ('pdf.fonttype', 42),
):
    mpl.rcParams[rc_key] = rc_value

# Further figure settings gathered in a plain dict.
# NOTE(review): nothing in the visible code hands this dict to Matplotlib, so
# these values currently have no effect -- confirm whether it should be applied.
params = {
    'backend': 'ps',
    'axes.labelsize': 11,
    'text.fontsize': 11,
    'legend.fontsize': 11,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'text.usetex': False,
    'figure.figsize': [10 / 2.54, 6 / 2.54],
}

# Steepness of the sigmoid below; the weight update in the learning loop also
# multiplies by this value.
beta = 5


def g(x):
    """Return the sigmoid 1 / (1 + alpha * exp(beta * (theta - x))).

    ``alpha`` is fixed to 1 and ``theta`` to 1.7, so ``g(1.7) == 0.5``.
    ``beta`` is read from module scope.  ``x`` may be a scalar or a NumPy
    array; the function is applied element-wise.
    """
    alpha = 1
    theta = 1.7
    exponent = beta * (-x + theta)
    return 1 / (1 + alpha * np.exp(exponent))



# --- Stimulus parameters ----------------------------------------------------
width = 50  # number of time steps (columns) in each frozen pattern

# rate * dt * 1e-3 is used throughout as the per-step spike probability.
mean_rate = 5
r_sig = mean_rate
r_noise = mean_rate - r_sig
gain = 10
n_in = 2000  # number of input channels (rows of every spike matrix)
dt = 1

# --- Frozen spike patterns ---------------------------------------------------
# Each entry of each pattern is an independent Bernoulli draw with probability
# r_sig * dt * 1e-3.  A single vectorized call replaces n_in * width * 3 scalar
# draws.  The (n_in, width, 3) layout is filled in C order, so the random
# numbers are consumed in the same sequence as a nested loop that draws
# pat1, pat2, pat3 for each (row, column) in turn; a seeded run is unchanged.
pattern_draws = np.random.rand(n_in, width, 3) < r_sig * dt * 10 ** (-3)

# Copy each slice so every pattern is its own contiguous (n_in, width) bool array.
pat1 = pattern_draws[:, :, 0].copy()
pat2 = pattern_draws[:, :, 1].copy()
pat3 = pattern_draws[:, :, 2].copy()

# --- Scalar settings ----------------------------------------------------------
N = 1                       # number of output units (columns of w)
window = 300                # the V_som history spans window * width steps
tau = 15                    # decay constant of PSP (PSP *= 1 - dt / tau)
tau_syn = 5                 # decay constant of I_syn
n_syn = n_in                # synapses per unit
g_L = 1 / tau               # leak term in the V_som update
g_d = 0.7                   # coupling of V_dend into V_som
eps = 10 ** (-6)            # learning rate
p_connect = 1

# --- Time axis of the learning phase -----------------------------------------
nsecs = width * 10000
simtime = np.arange(0, nsecs, dt)
simtime_len = simtime.size

# --- State vectors -------------------------------------------------------------
PSP = np.zeros(n_in)
I_syn = np.zeros(n_in)
spike_time = np.zeros(N)
V_dend = np.zeros(N)
V_som = np.zeros(N)
f = np.zeros(N)

# History of V_som, one row per unit and one column per remembered step.
V_som_list = np.zeros((N, window * width))

# Initial weights: standard-normal draws scaled by 1 / sqrt(n_syn).
w = np.random.randn(n_syn, N) / np.sqrt(n_syn)

# Boolean connectivity mask.  n_syn distinct inputs are marked per unit; with
# n_syn == n_in every entry ends up True.
connection_list = np.zeros((N, n_in), dtype=bool)
for unit_mask in connection_list:
    chosen_inputs = np.random.choice(np.arange(n_in), n_syn, replace=False)
    unit_mask[chosen_inputs] = 1

# --- Learning phase -----------------------------------------------------------
# The input alternates between a background segment of random length
# (width .. 3*width - 1 steps of independent random spikes) and one of the
# three frozen patterns (width steps), chosen with probability 1/3 each.

random_start = 0  # step at which the next background segment begins

# Stimulus currently presented: 'random', 'pat1', 'pat2' or 'pat3'.
# Named stim_type (not type) so the builtin is not shadowed.
stim_type = 'random'

# Frozen pattern for each label; columns are indexed by the step within it.
patterns = {'pat1': pat1, 'pat2': pat2, 'pat3': pat3}

print("")
print("***********")
print("Learning... ")
print("***********")
for i in range(simtime_len):
    if i == random_start:
        # A background segment starts now; the next pattern follows it.
        # random_start is 0 on entry, so pat_start is always set at i == 0.
        stim_type = 'random'
        random_width = np.random.randint(1 * width, 3 * width)
        pat_start = i + random_width

    if i == pat_start:
        # A pattern starts now; background resumes once it has played.
        random_start = i + width
        dice = np.random.rand()
        if dice < 1 / 3:
            stim_type = 'pat1'
        elif dice < 2 / 3:
            stim_type = 'pat2'
        else:
            stim_type = 'pat3'

    # Boolean vector of the inputs that spike on this step.
    if stim_type == 'random':
        spike_mat = np.random.rand(n_in) < mean_rate * dt * 10 ** (-3)
    else:
        spike_mat = patterns[stim_type][:, i - pat_start]

    # Progress report whenever the integer percentage reaches a multiple of 5.
    percent = int(i / simtime_len * 100)
    if percent % 5 == 0 and percent > int((i - 1) / simtime_len * 100):
        print(" " + str(percent) + "% ")

    # Synaptic current and PSP: exponential decay plus a jump on each spike.
    I_syn = (1.0 - dt / tau_syn) * I_syn
    I_syn[spike_mat] += 1 / tau / tau_syn
    PSP = (1.0 - dt / tau) * PSP + I_syn
    PSP_unit = PSP * 25

    # Weighted input, then the V_som update driven by it.
    V_dend = np.dot(w.T, PSP_unit)
    V_som = (1.0 - dt * g_L) * V_som + g_d * (V_dend - V_som)

    # Shift the history one step left and store the newest V_som at the end.
    V_som_list = np.roll(V_som_list, -1, axis=1)
    V_som_list[:, -1] = V_som

    # Weights are only updated once i exceeds the history length.
    if i > width * window:
        # Output: g of V_som standardised by the mean/std of its history.
        f = g((V_som - np.mean(V_som_list, axis=1)) / np.std(V_som_list, axis=1))

        # g of the scaled V_dend, evaluated once and reused in both factors.
        dend_rate = g(V_dend * g_d / (g_d + g_L))

        w += eps * np.outer(f - dend_rate, PSP_unit).T * beta * (1 - dend_rate)
        w -= eps * w * 5  # weight decay


print("")
print("***********")
print("Testing... ")
print("***********")

# --- Test-phase sizes -----------------------------------------------------------
test_len = 60 * width * 5   # number of simulated test steps
plot_len = 700              # number of steps shown in the raster figure

# --- Fresh dynamic state (weights w are kept from the learning phase) -----------
PSP = np.zeros(n_in)
I_syn = np.zeros(n_in)
V_dend = np.zeros(N)
V_som = np.zeros(N)

# --- Recording buffers ------------------------------------------------------------
# Input spikes for the whole test run, one column per step.
spike_mat = np.zeros((n_in, test_len), dtype=bool)
# Per-synapse weighted input, stacked unit after unit.  Allocated exactly once;
# a second identical allocation only discarded the first.
synaptic_input_matrix = np.zeros((n_in * N, test_len))
V_dend_list = np.zeros((N, test_len))
f_list = np.zeros((N, test_len))

# Onset steps of every presentation of each pattern.
pat1_start = []
pat2_start = []
pat3_start = []
random_start = 0  # step at which the next background segment begins

# Copies of the spike train split by source (background / pattern 1-3).
random_mat = np.zeros((n_in, test_len))
pat1_mat = np.zeros((n_in, test_len))
pat2_mat = np.zeros((n_in, test_len))
pat3_mat = np.zeros((n_in, test_len))
# Build the test spike train: background segments of random length
# (width .. 3*width - 1 steps) alternate with frozen patterns (width steps).
# Segments that would run past test_len are truncated.
spike_prob = mean_rate * dt * 10 ** (-3)  # per-step spike probability
for i in range(test_len):
    if i == random_start:
        # Background segment starting at i; the next pattern follows it.
        random_width = np.random.randint(1 * width, 3 * width)
        pat_start = i + random_width
        seg_end = min(i + random_width, test_len)

        # One vectorized Bernoulli draw for the whole segment.  The array is
        # filled in C order (input index outer, step inner), the same sequence
        # a nested per-entry loop would consume, so seeded runs are unchanged.
        # The target slice was all False, so direct assignment is sufficient.
        spike_mat[:, i:seg_end] = np.random.rand(n_in, seg_end - i) < spike_prob
        random_mat[:, i:seg_end] = spike_mat[:, i:seg_end]

    if i == pat_start:
        random_start = i + width

        # Thresholds on a uniform draw: dice < p_pat1 -> pattern 1,
        # dice < p_pat2 -> pattern 2, otherwise pattern 3.  The overrides
        # force any pattern not shown yet: pattern 3 while it is missing,
        # else pattern 2, else pattern 1.  Once all three have appeared the
        # choice is uniform.
        p_pat1 = 1 / 3
        p_pat2 = 2 / 3
        if not pat1_start:
            p_pat1 = 1
        if not pat2_start:
            p_pat1 = 0
            p_pat2 = 1
        if not pat3_start:
            p_pat1 = 0
            p_pat2 = 0

        dice = np.random.rand()
        if dice < p_pat1:
            chosen_starts, chosen_pattern = pat1_start, pat1
        elif dice < p_pat2:
            chosen_starts, chosen_pattern = pat2_start, pat2
        else:
            chosen_starts, chosen_pattern = pat3_start, pat3

        # Record the onset and copy the (possibly truncated) pattern in.
        seg_end = min(i + width, test_len)
        chosen_starts.append(i)
        spike_mat[:, i:seg_end] = chosen_pattern[:, :seg_end - i]

# Copy the spikes of every pattern presentation into that pattern's own
# matrix so each pattern can be drawn in its own colour later.  Windows that
# reach past the end of the test run are truncated at test_len.
for onsets, pattern_mat in (
    (pat1_start, pat1_mat),
    (pat2_start, pat2_mat),
    (pat3_start, pat3_mat),
):
    for onset in onsets:
        stop = min(onset + width, test_len)
        pattern_mat[:, onset:stop] = spike_mat[:, onset:stop]
# Run the model over the test spike train without changing w and record the
# output g(V_som) of every unit at every step.
syn_decay = 1.0 - dt / tau_syn   # per-step decay factor of I_syn
psp_decay = 1.0 - dt / tau       # per-step decay factor of PSP
spike_jump = 1 / tau / tau_syn   # increment of I_syn for each input spike
som_decay = 1.0 - dt * g_L       # per-step leak factor of V_som

for i in range(test_len):
    I_syn = syn_decay * I_syn
    I_syn[spike_mat[:, i]] += spike_jump
    PSP = psp_decay * PSP + I_syn
    PSP_unit = PSP * 25

    # Per-synapse weighted input of each unit, stacked unit after unit.
    for unit in range(N):
        first_row = unit * n_in
        synaptic_input_matrix[first_row:first_row + n_in, i] = PSP_unit * w[:, unit]

    V_dend = w.T.dot(PSP_unit)
    V_som = som_decay * V_som + g_d * (V_dend - V_som)

    # Here g is applied to the raw V_som of each unit.
    for unit, v_unit in enumerate(V_som):
        f[unit] = g(v_unit)

    f_list[:, i] = f


# (row, column) indices of every spike, separately for background spikes and
# for each pattern; rows are input indices and columns are time steps.
nspk_random, tspk_random = pl.nonzero(random_mat == 1)
nspk1, tspk1 = pl.nonzero(pat1_mat == 1)
nspk2, tspk2 = pl.nonzero(pat2_mat == 1)
nspk3, tspk3 = pl.nonzero(pat3_mat == 1)

# --- Figure: output activity during the test run ------------------------------
fig = plt.figure(figsize=(7, 2))
ax = plt.subplot(1, 1, 1)

# One black trace per output unit.
for unit_trace in f_list:
    plt.plot(unit_trace, lw=1.5, c='k')

# Shade every pattern presentation in that pattern's colour.
for onsets, colour in (
    (pat1_start, 'orangered'),
    (pat2_start, 'dodgerblue'),
    (pat3_start, 'limegreen'),
):
    for onset in onsets:
        plt.axvspan(onset, onset + width, facecolor=colour, alpha=0.3, linewidth=0)

plt.xlim([0, 1400])
plt.ylim([-0.1, 1.1])
plt.xlabel("Time [ms]", fontsize=11)
plt.ylabel("Activity", fontsize=11)

# Leave room for the axis labels and keep only the left and bottom spines.
fig.subplots_adjust(bottom=0.25, left=0.1)
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')

# savefig's keyword for the output type is `format`; `fmt` is not a savefig
# parameter and is rejected as an unexpected keyword by recent Matplotlib.
plt.savefig('activity.pdf', format='pdf', dpi=350)



# Indicator time series per pattern: 1 while the pattern is presented, else 0.
target1 = np.zeros(test_len)
target2 = np.zeros(test_len)
target3 = np.zeros(test_len)
for indicator, onsets in (
    (target1, pat1_start),
    (target2, pat2_start),
    (target3, pat3_start),
):
    for onset in onsets:
        indicator[onset:min(test_len, onset + width)] = 1

# Pearson correlation between a unit's output trace and each indicator.
# NOTE(review): chunk_corr has a single slot per pattern, so with N > 1 only
# the values of the last unit remain after the loop.
chunk_corr = np.zeros(3)
for i in range(N):
    for slot, indicator in enumerate((target1, target2, target3)):
        chunk_corr[slot] = np.corrcoef(f_list[i, :], indicator)[0, 1]

# --- Figure: raster of the test input -------------------------------------------
fig = plt.figure(figsize=(7, 2))
ax = fig.add_subplot(111)

# Upper edge of the plot: the n_in input rows plus 150 rows of headroom, where
# the pattern bars are drawn (2150 for n_in == 2000).
raster_top = n_in + 150

# Mark every pattern presentation: dashed lines at onset and offset, and a
# thick bar along the top edge, all in the pattern's colour.
for onsets, colour in (
    (pat1_start, "orangered"),
    (pat2_start, "dodgerblue"),
    (pat3_start, "limegreen"),
):
    for onset in onsets:
        plt.vlines([onset], 0, raster_top, colour, linestyles='dashed', lw=1)
        # The offset line is drawn only when it lies inside the plotted range.
        if onset + width < plot_len:
            plt.vlines([onset + width], 0, raster_top, colour,
                       linestyles='dashed', lw=1)
        plt.hlines([raster_top], onset, min(onset + width, plot_len), colour,
                   linestyles='solid', lw=5)

# Spikes as dots: background in black, pattern spikes in the pattern colours.
for spike_times, spike_rows, colour in (
    (tspk_random, nspk_random, 'k'),
    (tspk1, nspk1, 'orangered'),
    (tspk2, nspk2, 'dodgerblue'),
    (tspk3, nspk3, 'limegreen'),
):
    plt.plot(spike_times, spike_rows, c=colour, marker='.', lw=0, markersize=1.5)

plt.ylabel("Neuron id", fontsize=11)

# Leave room for the axis labels and keep only the left and bottom spines.
fig.subplots_adjust(bottom=0.25, left=0.1)
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')
plt.xlim([0, plot_len])
plt.ylim([0, raster_top])

# savefig's keyword for the output type is `format`; `fmt` is not a savefig
# parameter and is rejected as an unexpected keyword by recent Matplotlib.
plt.savefig('raster.pdf', format='pdf', dpi=350)