"""
Copyright 2019 Toshitake Asabuki
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
import os, sys
import matplotlib.pyplot as plt
import pylab as pl
import matplotlib as mpl
import shutil
mpl.rcParams['svg.fonttype'] = 'none'
mpl.rcParams['font.sans-serif'] = 'Arial'
mpl.rcParams['pdf.fonttype'] = 42
params = {'backend': 'ps',
'axes.labelsize': 11,
'text.fontsize': 11,
'legend.fontsize': 11,
'xtick.labelsize': 11,
'ytick.labelsize': 11,
'text.usetex': False,
'figure.figsize': [10 / 2.54, 6 / 2.54]}
beta = 5
def g(x):
alpha = 1
theta= 0.5
ans = 1/(1+alpha*np.exp(beta*(-x+theta)))
return ans
width = 50
mean_rate = 5
r_sig = mean_rate
r_noise = mean_rate - r_sig
gain=30
n_in=500
dt=1
pat1 = np.zeros((n_in,width),dtype=bool)
pat2 = np.zeros((n_in,width),dtype=bool)
pat3 = np.zeros((n_in,width),dtype=bool)
for i in range(n_in):
for j in range(width):
if np.random.rand()<r_sig*dt*10**(-3):
pat1[i,j]=1
if np.random.rand()<r_sig*dt*10**(-3):
pat2[i,j]=1
if np.random.rand()<r_sig*dt*10**(-3):
pat3[i,j]=1
N = 10
nsecs = width *15000
simtime = np.arange(0, nsecs, dt)
simtime_len = len(simtime)
tau =15
tau_syn = 5
n_syn = n_in
PSP = np.zeros(n_in)
I_syn = np.zeros(n_in)
g_L=1/tau
g_d=0.7
w = np.random.randn(n_syn,N)/np.sqrt(n_syn)
eps = 10**(-6)
p_connect = 1
w_inh_max =0.1
w_inh =np.ones((N,N))*w_inh_max
spike_time = np.zeros(N)
mask = np.zeros((N,N))
for i in range(N):
for j in range(N):
if i!=j:
if np.random.rand()<p_connect:
mask[i,j]=1
w_inh*=mask
w_inh[w_inh<0] = 0
w_inh[w_inh>w_inh_max] = w_inh_max
window = 300
V_som_list=np.zeros((N,window*width))
V_dend = np.zeros(N)
V_som = np.zeros(N)
connection_list = np.zeros((N,n_in),dtype=bool)
for i in range(N):
connection_list[i,np.random.choice(np.arange(n_in), n_syn, replace = False)]=1
f = np.zeros(N)
cA_p = -1*0.105*0.1
cA_d = 0.525*0.1*0.1
tau_p=20
tau_d=40
A_pre_p=np.zeros(N)
A_pre_d=np.zeros(N)
A_post_p=np.zeros(N)
A_post_d=np.zeros(N)
spike_mat=np.zeros((n_in,simtime_len),dtype=bool)
random_start=0
print("")
print("***********")
print("Learning... ")
print("***********")
for i in range(simtime_len):
if int(i / simtime_len * 100) % 5 == 0.0:
if int(i / simtime_len * 100) > int((i - 1) / simtime_len * 100):
print(" " + str(int(i / simtime_len * 100)) + "% ")
if i == random_start:
type='random'
random_width=np.random.randint(1*width,3*width)
pat_start=i+random_width
if i==pat_start:
random_start=i+width
dice=np.random.rand()
if dice<1/3:
type='pat1'
elif dice<2/3:
type='pat2'
else:
type='pat3'
if type=='random':
spike_mat = np.zeros(n_in,dtype=bool)
spike_mat[np.random.rand(n_in)<mean_rate*dt*10**(-3)]=1
elif type=='pat1':
spike_mat =pat1[:,i-pat_start]
elif type=='pat2':
spike_mat =pat2[:,i-pat_start]
elif type=='pat3':
spike_mat =pat3[:,i-pat_start]
I_syn = (1.0 - dt / tau_syn) * I_syn
I_syn[spike_mat]+=1/tau/tau_syn
PSP = (1.0 - dt / tau) * PSP + I_syn
PSP_unit=PSP*25
V_dend = np.dot(w.T,PSP_unit)
V_som_list = np.roll(V_som_list, -1,axis=1)
V_som = (1.0-dt*g_L)*V_som +g_d*(V_dend-V_som)+np.dot(-w_inh,f)
V_som_list[:,-1] = V_som
if i>width*window:
f = g((V_som-np.mean(V_som_list,axis=1)) / np.std(V_som_list,axis=1))
w += eps *np.outer((f-g(V_dend*g_d/(g_d+g_L))) , PSP_unit).T*beta*(1-g(V_dend*g_d/(g_d+g_L)))
w-=eps*w*5
A_pre_p = (1.0 - dt / tau_p) * A_pre_p
A_pre_d = (1.0 - dt / tau_d) * A_pre_d
A_post_p = (1.0 - dt / tau_p) * A_post_p
A_post_d = (1.0 - dt / tau_d) * A_post_d
for k in range(N):
if np.random.rand()<dt*f[k]*gain*(10**-3):
A_pre_p[k]+=cA_p
A_pre_d[k]+=cA_d
A_post_p[k]+=cA_p
A_post_d[k]+=cA_d
w_inh[k,:]+=(A_pre_p+A_pre_d)*w_inh_max
w_inh[:,k]+=(A_post_p+A_post_d)*w_inh_max
spike_time[k]=i
w_inh*=mask
w_inh[w_inh<0] = 0
w_inh[w_inh>w_inh_max] = w_inh_max
print("")
print("***********")
print("Testing... ")
print("***********")
test_len=60*width
plot_len=1500
synaptic_input_matrix=np.zeros((n_in*N,test_len))
spike_mat=np.zeros((n_in,test_len),dtype=bool)
PSP = np.zeros(n_in)
I_syn = np.zeros(n_in)
synaptic_input_matrix=np.zeros((n_in*N,test_len))
V_dend_list =np.zeros((N,test_len))
V_dend = np.zeros(N)
V_som = np.zeros(N)
f_list = np.zeros((N,test_len))
pat1_start=[]
pat2_start=[]
pat3_start=[]
random_start=0
random_mat=np.zeros((n_in,test_len))
pat1_mat=np.zeros((n_in,test_len))
pat2_mat=np.zeros((n_in,test_len))
pat3_mat=np.zeros((n_in,test_len))
for i in range(test_len):
if i == random_start:
random_width=np.random.randint(1*width,3*width)
pat_start=i+random_width
spike_mat[:,i:min(i+random_width,test_len)] = np.zeros((n_in,min(test_len-i,random_width)),dtype=bool)
for j in range(n_in):
for k in range(min(random_width,test_len-i)):
if np.random.rand()<mean_rate*dt*10**(-3):
spike_mat[j,k+i]=1
random_mat[:,i:min(i+random_width,test_len)]=spike_mat[:,i:min(i+random_width,test_len)]
if i==pat_start:
random_start=i+width
p_pat1=1/3
p_pat2=2/3
if pat1_start==[]:
p_pat1=1
if pat2_start==[]:
p_pat1=0
p_pat2=1
if pat3_start==[]:
p_pat1=0
p_pat2=0
dice=np.random.rand()
if dice<p_pat1:
pat1_start.append(i)
spike_mat[:,i:min(i+width,test_len)]=pat1[:,0:min(test_len-i,width)]
elif dice<p_pat2:
pat2_start.append(i)
spike_mat[:,i:min(i+width,test_len)]=pat2[:,0:min(test_len-i,width)]
else:
pat3_start.append(i)
spike_mat[:,i:min(i+width,test_len)]=pat3[:,0:min(width,test_len-i)]
for i in range(test_len):
if i in pat1_start:
pat1_mat[:,i:min(i+width,test_len)]=spike_mat[:,i:min(i+width,test_len)]
if i in pat2_start:
pat2_mat[:,i:min(i+width,test_len)]=spike_mat[:,i:min(i+width,test_len)]
if i in pat3_start:
pat3_mat[:,i:min(i+width,test_len)]=spike_mat[:,i:min(i+width,test_len)]
for i in range(test_len):
I_syn = (1.0 - dt / tau_syn) * I_syn
I_syn[spike_mat[:,i]]+=1/tau/tau_syn
PSP = (1.0 - dt / tau) * PSP + I_syn
PSP_unit=PSP*25
for l in range(N):
synaptic_input_matrix[l*n_in:(l+1)*n_in,i]=PSP_unit*w[:,l]
V_dend = np.dot(w.T,PSP_unit)
V_som = (1.0-dt*g_L)*V_som +g_d*(V_dend-V_som)+np.dot(-w_inh,f)
for k in range(N):
f[k] = g(V_som[k])
f_list[:,i]=f
nspk_random,tspk_random = pl.nonzero(random_mat[0:200,:]==1)
nspk1,tspk1 = pl.nonzero(pat1_mat[0:200,:]==1)
nspk2,tspk2 = pl.nonzero(pat2_mat[0:200,:]==1)
nspk3,tspk3 = pl.nonzero(pat3_mat[0:200,:]==1)
fig = plt.figure(figsize=(7, 2))
ax = fig.add_subplot(111)
for i in pat1_start:
plt.vlines([i], 0, n_in, "dodgerblue", linestyles='dashed',lw=1)
if i+width<plot_len:
plt.vlines([i+width], 0, n_in, "dodgerblue", linestyles='dashed',lw=1)
plt.hlines([200+15], i, min(i+width,plot_len), "dodgerblue", linestyles='solid',lw=3)
for i in pat2_start:
plt.vlines([i], 0, n_in, "orangered", linestyles='dashed',lw=1)
if i+width<plot_len:
plt.vlines([i+width], 0, n_in, "orangered", linestyles='dashed',lw=1)
plt.hlines([200+15], i, min(i+width,plot_len), "orangered", linestyles='solid',lw=3)
for i in pat3_start:
plt.vlines([i], 0, n_in, "limegreen", linestyles='dashed',lw=1)
if i+width<plot_len:
plt.vlines([i+width], 0, n_in, "limegreen", linestyles='dashed',lw=1)
plt.hlines([200+15], i, min(i+width,plot_len), "limegreen", linestyles='solid',lw=3)
pl.plot(tspk_random,nspk_random,'k.',markersize=2)
pl.plot(tspk1,nspk1,'b.',markersize=2)
pl.plot(tspk2,nspk2,'r.',markersize=2)
pl.plot(tspk3,nspk3,'g.',markersize=2)
plt.ylabel("Neuron id", fontsize=11)
fig.subplots_adjust(bottom=0.25, left=0.1)
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')
pl.xlim([0,plot_len])
pl.ylim([0,200+20])
plt.savefig('raster.pdf', fmt='pdf', dpi=350)
sample_len = test_len
max1 = np.zeros(N)
min1 = np.zeros(N)
for i in range(N):
max1[i] = np.max(f_list[i,0:sample_len])
min1[i] = np.min(f_list[i,0:sample_len])
avg_norm1 = np.zeros((N,sample_len))
for i in range(N):
avg_norm1[i,:] = (f_list[i,0:sample_len]-min1[i])/(max1[i]-min1[i])
t = np.zeros(N)
for j in range(N):
arg = np.angle(np.dot(avg_norm1[j,:],np.exp(np.arange(sample_len)/(sample_len)*2*np.pi*1j))/sum(avg_norm1[j,:]))
if arg<0:
arg += 2*np.pi
t[j] = sample_len/(2*np.pi)*arg
index = np.zeros(N)
index = np.argsort(t)
avg_sorted = np.zeros((N,sample_len))
for i in range(N):
avg_sorted[i,:] = avg_norm1[int(index[i]),:]
fig, ax = plt.subplots(figsize=(4,3))
cax=plt.imshow(avg_sorted, interpolation='nearest', aspect="auto",cmap='jet')
cbar = fig.colorbar(cax, ticks=[0, 1], orientation='vertical')
cbar.ax.set_yticklabels(['min', 'max'],fontsize=10)
plt.xlabel("Time [ms]",fontsize=10)
plt.ylabel("Neurons (sorted)",fontsize=10)
plt.yticks([0,N-1],["1","%d"%N],fontsize=10)
ax.tick_params(length=1.3, width=0.05, labelsize=10)
ax.xaxis.set_ticks_position('none')
ax.yaxis.set_ticks_position('none')
plt.ylim([-0.5,N-0.5])
pl.xlim([0,plot_len])
fig.subplots_adjust(left=0.15,bottom=0.25,right=1)
plt.savefig('activity_map.pdf', fmt='pdf',dpi=350)
weight_sorted_row =np.zeros((N,N))
weight_sorted_column =np.zeros((N,N))
for i in range(N):
weight_sorted_row[i,:] = w_inh[int(index[i]),:]
for i in range(N):
weight_sorted_column[:,i] =weight_sorted_row[:,int(index[i])]
fig, ax = plt.subplots(figsize=(4,3))
cax=plt.imshow(weight_sorted_column, interpolation='nearest', aspect="auto",cmap='jet')
cbar = fig.colorbar(cax, orientation='vertical')
plt.xlabel("Presynaptic neuron",fontsize=10)
plt.ylabel("Postsynaptic neuron",fontsize=10)
ax.tick_params(length=1.3, width=0.05, labelsize=11)
ax.xaxis.set_ticks_position('none')
ax.yaxis.set_ticks_position('none')
plt.xticks([0,1,2,3,4,5,6,7,8,9],['1','2','3','4','5','6','7','8','9','10'],fontsize=11)
plt.yticks([0,1,2,3,4,5,6,7,8,9],['1','2','3','4','5','6','7','8','9','10'],fontsize=11)
fig.subplots_adjust(left=0.15,bottom=0.25,right=0.8)
plt.savefig('Winh_map.pdf', fmt='pdf',dpi=350)
target1=np.zeros(test_len)
target2=np.zeros(test_len)
target3=np.zeros(test_len)
for i in (pat1_start):
target1[i:min(test_len,i+width)]=1
for i in (pat2_start):
target2[i:min(test_len,i+width)]=1
for i in (pat3_start):
target3[i:min(test_len,i+width)]=1
chunk_corr = np.zeros((N,3))
for i in range(N):
chunk_corr[i,0]=np.corrcoef(f_list[i,:],target1)[0,1]
chunk_corr[i,1]=np.corrcoef(f_list[i,:],target2)[0,1]
chunk_corr[i,2]=np.corrcoef(f_list[i,:],target3)[0,1]
performance = 0
for i in range(N):
performance+=max(chunk_corr[i,:])
performance/=N
print(performance)
fig = plt.figure(figsize=(7, 3))
ax = fig.add_subplot(111)
ax.spines['top'].set_color('none')
ax.spines['bottom'].set_color('none')
ax.spines['left'].set_color('none')
ax.spines['right'].set_color('none')
ax.tick_params(labelcolor='w', top='off', bottom='off', left='off', right='off')
ax1 = fig.add_subplot(311)
pl.plot(f_list[np.argmax(chunk_corr[:,0]),:],lw=1.5,c='k')
for i in (pat1_start):
pl.axvspan(i, i + width, facecolor='dodgerblue', alpha=0.3,linewidth=0)
pl.xlim([0,plot_len])
ax1.xaxis.set_ticks_position('bottom')
ax1.yaxis.set_ticks_position('left')
ax1.spines['right'].set_color('none')
ax1.spines['top'].set_color('none')
ax1.xaxis.set_major_locator(pl.NullLocator())
ax2 = fig.add_subplot(312)
pl.plot(f_list[np.argmax(chunk_corr[:,1]),:],lw=1.5,c='k')
for i in (pat2_start):
pl.axvspan(i, i + width, facecolor='orangered', alpha=0.3,linewidth=0)
pl.xlim([0,plot_len])
ax2.xaxis.set_ticks_position('bottom')
ax2.yaxis.set_ticks_position('left')
ax2.spines['right'].set_color('none')
ax2.spines['top'].set_color('none')
ax2.xaxis.set_major_locator(pl.NullLocator())
ax3 = fig.add_subplot(313)
pl.plot(f_list[np.argmax(chunk_corr[:,2]),:],lw=1.5,c='k')
for i in (pat3_start):
pl.axvspan(i, i + width, facecolor='limegreen', alpha=0.3,linewidth=0)
pl.xlim([0,plot_len])
ax3.xaxis.set_ticks_position('bottom')
ax3.yaxis.set_ticks_position('left')
ax3.spines['right'].set_color('none')
ax3.spines['top'].set_color('none')
plt.xlabel("Time [ms]", fontsize=11)
fig.subplots_adjust(bottom=0.15, left=0.1)
plt.savefig('activities.pdf', fmt='pdf', dpi=350)