"""
Copyright 2019 Toshitake Asabuki
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import division
import numpy as np
import pylab as pl
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy.matlib
import os
import shutil
import sys
import matplotlib.cm as cm
# Global figure styling.
mpl.rcParams['svg.fonttype'] = 'none'  # keep text as text (not paths) in SVG output
mpl.rcParams['font.sans-serif'] = 'Arial'
mpl.rcParams['pdf.fonttype'] = 42  # embed TrueType (Type 42) fonts in PDF output
# Optional rc overrides. They are not applied by the code in this script; pass
# them to ``mpl.rcParams.update(params)`` to use them. The font size key is
# 'font.size': the former 'text.fontsize' key is rejected by modern matplotlib
# as an invalid rc parameter, which would make that update() call fail.
params = {'backend': 'ps',
          'axes.labelsize': 11,
          'font.size': 11,
          'legend.fontsize': 11,
          'xtick.labelsize': 11,
          'ytick.labelsize': 11,
          'text.usetex': False,
          'figure.figsize': [10 / 2.54, 6 / 2.54]}  # 10 cm x 6 cm, in inches
"""
In this demo, the true source signals are periodic, but the network can also be trained on much more complex signals.
Inhibitory STDP can be introduced for the lateral inhibition in order to train with an unknown number of sources.
"""
"""
activation function
"""
beta = 5  # slope of the sigmoid; the weight-update rule reuses it as well


def g(x):
    """Sigmoidal activation function.

    Evaluates ``1 / (1 + alpha * exp(beta * (theta - x)))`` with
    ``alpha = 1``, ``theta = 0.5`` and the module-level ``beta``.
    Operates elementwise, so ``x`` may be a scalar or a NumPy array.
    """
    theta, alpha = 0.5, 1
    exponent = beta * (-x + theta)
    return 1 / (1 + alpha * np.exp(exponent))
"""
parameters
"""
# Network size and simulation clock.
N = 2            # number of output neurons
dt = 1           # integration step
nsecs = 5000000  # length of the training run, in units of dt
simtime = np.arange(0, nsecs, dt)
simtime_len = len(simtime)

# Learning rate and time constants.
eps = 0.5 * 10**(-5)
tau, tau_syn = 15, 5  # membrane and synaptic time constants

# Presynaptic inputs and their filtered traces.
n_in = 500
PSP, I_syn = np.zeros(n_in), np.zeros(n_in)

# Conductances: leak (1/tau) and dendro-somatic coupling.
g_L = 1 / tau
g_d = 0.7

# Feed-forward dendritic weights, Gaussian with variance 1/n_in.
w = np.random.randn(n_in, N) / np.sqrt(n_in)

t0 = 9000    # warm-up steps; also the length of the somatic-potential history
gain = 10    # multiplier applied to the normalized input rates
noise = 0.5  # std of the Gaussian noise added to the outputs during learning
"""
mixing matrix
"""
# Two fixed, unit-norm mixtures of the sources. Row k of Q weights source k
# by 1 and the other source by q_cross, and is then normalized by denom.
q_cross = 0.5
denom = np.sqrt(1 + q_cross**2)
Q = np.array([[1, q_cross],
              [q_cross, 1]]) / denom
# Each input neuron receives one row of Q, chosen uniformly at random
# (one randint draw per input, in input order).
row_choice = [np.random.randint(N) for _ in range(n_in)]
mixing_mat = Q[row_choice, :]
"""
fixed lateral inhibition.
"""
# Fixed (non-plastic) lateral inhibition: each output neuron inhibits every
# other one with strength w_inh_max; there is no self-inhibition.
w_inh_max = 0.4
mask = 1.0 - np.eye(N)  # 1 off the diagonal, 0 on it
w_inh = np.clip(np.ones((N, N)) * w_inh_max, 0, w_inh_max) * mask

# State of the output neurons.
V_som_list = np.zeros((N, t0))  # sliding history of somatic potentials
V_dend = np.zeros(N)
V_som = np.zeros(N)
f = np.zeros(N)  # output activity
"""
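Compute, for each input neuron, the min and max of its mixed input rate over the first 10000 steps; these are used to rescale the input rates.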
"""
# Evaluate both true sources over the first 10000 steps. Then record, for
# every input neuron, the extremes of its mixed input; the simulation loops
# below use them to rescale input rates.
freq = 0.4
time = np.arange(10000)
source = np.vstack((
    (-np.sin(2*np.pi*time/1000*1.2*freq+100)-np.sin(2*np.pi*time/500*1.2*freq+100)*0.3)*2,
    np.sin(2*np.pi*time/1000*freq+2000)+np.sin(2*np.pi*time/500*freq+2000)*2,
))
rate_in = np.dot(mixing_mat, source)
rate_min = rate_in.min(axis=1)
rate_max = rate_in.max(axis=1)
print("")
print("***********")
print("Learning... ")
print("***********")
# Scratch buffer holding the current value of the two source signals.
source=np.zeros(2)
for i in range(simtime_len):
    # Progress report: print once each time the integer percentage reaches a
    # new multiple of 5.
    if int(i / simtime_len * 100) % 5 == 0.0:
        if int(i / simtime_len * 100) > int((i - 1) / simtime_len * 100):
            print(" " + str(int(i / simtime_len * 100)) + "% ")
    # True source signals at step i (same formulas as the min/max window).
    source[0] = (-np.sin(2*np.pi*i/1000*1.2*freq+100)-np.sin(2*np.pi*i/500*1.2*freq+100)*0.3)*2
    source[1] = np.sin(2.*np.pi*i/1000*freq+2000)+np.sin(2.*np.pi*i/500*freq+2000)*2
    # Mixed input rates, min-max rescaled per input and multiplied by gain.
    rate_in = (np.dot(mixing_mat,source)-rate_min)/(rate_max-rate_min)*gain
    # Per-step spike probability; the 1e-3 factor presumably converts a rate
    # in Hz to a probability per ms step -- TODO confirm units.
    prate = dt*rate_in*(10**-3)
    # Bernoulli input spikes: boolean mask over the n_in inputs.
    id = np.random.rand(n_in)<prate
    # Synaptic current: exponential decay plus a jump for each spiking input.
    I_syn = (1.0 - dt / tau_syn) * I_syn
    I_syn[id]+=1/tau/tau_syn
    # PSP: leaky integration of the synaptic current with time constant tau.
    PSP = (1.0 - dt / tau) * PSP + I_syn
    PSP_unit=PSP*25
    # Dendritic potential of each output neuron: weighted sum of scaled PSPs.
    V_dend = np.dot(w.T,PSP_unit)
    # Shift the somatic-potential history one step left; the newest value is
    # written into the last column below.
    V_som_list = np.roll(V_som_list, -1,axis=1)
    # Somatic potential: leak, dendro-somatic coupling, and lateral inhibition
    # driven by the previous step's outputs f.
    V_som = (1.0-dt*g_L)*V_som +g_d*(V_dend-V_som)+np.dot(-w_inh,f)
    V_som_list[:,-1] = V_som
    # NOTE(review): indentation was lost in this copy of the file; the weight
    # update and decay below are assumed to belong to this `i > t0` branch --
    # confirm against the original source.
    if i>t0:
        # Output: sigmoid of the somatic potential standardized over the last
        # t0 steps, plus Gaussian noise, clipped to [0, 1].
        f = np.clip(g((V_som-np.mean(V_som_list,axis=1)) / np.std(V_som_list,axis=1))+np.random.randn(N)*noise,0,1)
        # Weight update driven by the mismatch between the somatic output f
        # and the dendritic prediction g(u), with u = V_dend*g_d/(g_d+g_L).
        # u is the fixed point of the somatic update above when inhibition is
        # ignored (dt = 1). The factor beta*(1-g(u)) equals g'(u)/g(u) for
        # this sigmoid.
        w += eps *np.outer((f-g(V_dend*g_d/(g_d+g_L))) , PSP_unit).T*beta*(1-g(V_dend*g_d/(g_d+g_L)))
        # Weight decay.
        w-=eps*w*0.5
print("")
print("***********")
print("Testing... ")
print("***********")
# Evaluate the trained network. Run `loop` independent trials of `test_len`
# steps each, starting every trial from a reset state, and record both
# outputs per step. The weights w are not modified here.
test_len = 10000
loop = 20
f_list = np.zeros((N,test_len*loop))  # NOTE(review): allocated but not filled below
out1_lists = np.zeros((loop,test_len))
out2_lists = np.zeros((loop,test_len))
for j in range(loop):
    # Reset all neuron state at the start of each trial. The output f is
    # reset too (to its initial value, zeros). Previously it carried over from
    # the end of the preceding trial, or from training, and leaked into the
    # first step through lateral inhibition.
    PSP = np.zeros(n_in)
    I_syn = np.zeros(n_in)
    V_dend = np.zeros(N)
    V_som = np.zeros(N)
    f = np.zeros(N)
    for i in range(test_len):
        # True sources at step i, mixed and rescaled as during training.
        source[0] = (-np.sin(2*np.pi*i/1000*1.2*freq+100)-np.sin(2*np.pi*i/500*1.2*freq+100)*0.3)*2
        source[1] = np.sin(2.*np.pi*i/1000*freq+2000)+np.sin(2.*np.pi*i/500*freq+2000)*2
        rate_in = (np.dot(mixing_mat,source)-rate_min)/(rate_max-rate_min)*gain
        prate = dt*rate_in*(10**-3)
        # Bernoulli input spikes for this step (boolean mask over inputs).
        spikes = np.random.rand(n_in)<prate
        I_syn = (1.0 - dt / tau_syn) * I_syn
        I_syn[spikes]+=1/tau/tau_syn
        PSP = (1.0 - dt / tau) * PSP + I_syn
        PSP_unit=PSP*25
        V_dend = np.dot(w.T,PSP_unit)
        V_som = (1.0-dt*g_L)*V_som +g_d*(V_dend-V_som)+np.dot(-w_inh,f)
        # NOTE(review): unlike training, the output here is g(V_som) with no
        # running standardization and no noise -- confirm this is intended.
        f = g(V_som)
        out1_lists[j,i]=f[0]
        out2_lists[j,i]=f[1]
"""
Plot the true, mixed, and output signals.
"""
def _plot_pair(top, bottom, filename):
    """Plot two traces in stacked panels sharing one y-label; save as PDF.

    ``top`` is drawn in orange-red in the upper panel and ``bottom`` in lime
    green in the lower panel. The figure is written to ``filename``.

    Returns:
        ``(fig, ax, ax_top, ax_bottom)``, where ``ax`` is the invisible
        full-figure axes that only carries the shared y-label.
    """
    fig = plt.figure(figsize=(8, 3))
    # Invisible axes spanning the whole figure. It exists only so that one
    # y-label can be centred across both panels.
    ax = fig.add_subplot(111)
    for side in ('top', 'bottom', 'left', 'right'):
        ax.spines[side].set_color('none')
    # Use booleans, not the strings 'off'. Recent matplotlib no longer
    # translates 'on'/'off', and any non-empty string is truthy, so 'off'
    # would switch these ticks on.
    ax.tick_params(labelcolor='w', top=False, bottom=False, left=False, right=False)

    ax_top = fig.add_subplot(211)
    ax_top.plot(top, lw=1.5, c='orangered')
    ax_top.xaxis.set_ticks_position('bottom')
    ax_top.yaxis.set_ticks_position('left')
    ax_top.spines['right'].set_color('none')
    ax_top.spines['top'].set_color('none')
    ax_top.xaxis.set_major_locator(pl.NullLocator())  # no x ticks on the upper panel

    ax_bottom = fig.add_subplot(212)
    ax_bottom.plot(bottom, lw=1.5, c='limegreen')
    ax_bottom.xaxis.set_ticks_position('bottom')
    ax_bottom.yaxis.set_ticks_position('left')
    ax_bottom.spines['right'].set_color('none')
    ax_bottom.spines['top'].set_color('none')
    ax_bottom.set_xlabel("Time [ms]", fontsize=11)

    ax.set_ylabel("Activity", fontsize=11)
    fig.subplots_adjust(bottom=0.15, left=0.1, right=0.95, hspace=0.3)
    # `format` is the savefig keyword. The former `fmt` is not recognized,
    # and recent matplotlib raises TypeError for it.
    fig.savefig(filename, format='pdf', dpi=350)
    return fig, ax, ax_top, ax_bottom


# True sources over one test trial, using the same formulas as the simulation.
test_time = np.arange(test_len)
true_sources = np.zeros((2, test_len))
true_sources[0, :] = (-np.sin(2*np.pi*test_time/1000*1.2*freq+100)-np.sin(2*np.pi*test_time/500*1.2*freq+100)*0.3)*2
true_sources[1, :] = np.sin(2.*np.pi*test_time/1000*freq+2000)+np.sin(2.*np.pi*test_time/500*freq+2000)*2
# The two mixtures given by the rows of Q. Every input neuron receives one
# of them.
mixed_sources = np.dot(Q, true_sources)

_plot_pair(true_sources[0, :], true_sources[1, :], 'true.pdf')
_plot_pair(mixed_sources[0, :], mixed_sources[1, :], 'mixed.pdf')
# Trial-averaged outputs of the two neurons.
fig, ax, ax2, ax3 = _plot_pair(np.mean(out1_lists, axis=0),
                               np.mean(out2_lists, axis=0),
                               'outputs.pdf')