# coding: utf-8
# In[1]:
get_ipython().magic('matplotlib inline')
from fns import *
from fns.functionsTF import *
# In[2]:
import gc

config = load_config()
res = []
# simulation length (also the length of the generated input signal)
T = 5000

## network sizes, shared by the input network and the gamma network
# number of excitatory neurons per network
NE = 800
# number of inhibitory neurons per network
NI = 200

### input to the input-network
seed = 0
# input amplitude
A = 400
# The input is built once, before the loop, so both simulations are driven by
# the same signal and the same per-neuron gains: only sG differs between runs.
# The gains used to come from the unseeded global NumPy RNG and were redrawn
# on every iteration (unless generateInput2 reseeds that RNG internally, which
# is not visible here), so the two runs were not guaranteed to share their
# input. A dedicated RandomState seeded with `seed` makes them reproducible.
gain_rng = np.random.RandomState(seed)
x = generateInput2(seed, T, tau=100)[np.newaxis, :]
# one random gain in [0, 2*A) for every neuron of both networks ...
w0 = gain_rng.rand(1, 2 * (NE + NI)) * 2 * A
# ... zeroed for the second half, so only the input network (the first NE+NI
# neurons) receives the x-dependent drive
w0 = w0 * np.concatenate([np.ones((1, NE + NI)), np.zeros((1, NE + NI))], axis=1)

# run two simulations, first without shared gap junctions, then with 40
# shared (cross-network) gap junctions
for sG in [0, 40]:
    gpu = TfConnEvolveNet(config=config, T=T)
    # number of cross-network gap junctions
    gpu.sG = sG
    ## input network
    # number of excitatory neurons
    gpu.NE1 = NE
    # number of inhibitory neurons
    gpu.NI1 = NI
    ## gamma network
    # number of excitatory neurons
    gpu.NE2 = NE
    # number of inhibitory neurons
    gpu.NI2 = NI
    # feed input to network: x-driven term for the input network plus a
    # constant offset of 200 added to every neuron (for the gamma network,
    # whose gains are zero, this offset is all it receives via `input`).
    # Computed per run so each simulation owns its own array, as before.
    gpu.input = w0.T @ x + 200
    # choose hardware
    gpu.device = '/gpu:0'
    # mean initial gap junction coupling (suffixes 1/2 presumably follow the
    # NE1/NE2 convention: input network / gamma network)
    gpu.g1 = 5.5
    gpu.g2 = 5.5
    # do not save the spikes
    gpu.spikeMonitor = False
    # do not save the individual voltages, currents, etc.
    gpu.monitor_single = False
    # NOTE(review): the original comment only said "iteration"; np.inf
    # presumably disables any stabilization cut-off -- confirm in TfConnEvolveNet
    gpu.stabTime = np.inf
    # rule: g0 = 0 for no bound rule, g0 = 10 for softbound rule
    gpu.g0 = 10
    gpu.runTFSimul()
    res.append(gpu)
    # drop the loop name and collect garbage before the next simulation is
    # built (the finished network itself stays alive in `res`)
    del gpu
    gc.collect()
# ## Reconstruct the input from the output network's activity
# In[3]:
# time window (in simulation steps) used for the reconstruction
s = 0
e = T


def _rescale01(v):
    """Return a new float array with `v` shifted and scaled to span [0, 1].

    Unlike in-place ``-=``/``/=``, this never mutates the caller's array and
    also works for integer input. A constant signal is only shifted (to all
    zeros) instead of being divided by zero.
    """
    v = np.asarray(v, dtype=float) - np.min(v)
    peak = np.max(v)
    return v / peak if peak > 0 else v


for net in res:
    f(5, 4)
    # input signal to the input network (row 0 of the drive matrix),
    # normalized by its maximum before decoding
    raw_inp = net.input[0, s:e]
    inp = raw_inp / np.max(raw_inp)
    # population activity of the inhibitory neurons of the output network
    ifr = net.vvmI2[s:e]
    _, xdec, ydec, _ = decode(inp, ifr)
    # rescale both signals to [0, 1] for plotting
    inp = _rescale01(inp)
    ydec = _rescale01(ydec)
    # plot input
    plt.plot(xdec, inp, color='r', label='input')
    # plot decoded input
    plt.plot(xdec, ydec, color=snCol, label='decoded input')
    plt.xticks([])
    plt.yticks([])
    plt.legend(fontsize=18, loc='best', handlelength=1)
    plt.xlabel('Time [%d ms]' % int((e - s) * net.dt))
    plt.title('Input vs Decoded Input')
    # derive the title from the run's actual sG so it stays correct if the
    # list of sG values in the simulation cell changes
    if net.sG == 0:
        title = r'\textbf{No cross-network GJs}'
    else:
        title = r'\textbf{%d cross-network GJs}' % net.sG
    plt.suptitle(title, y=1.05, fontsize=22)
    # Pearson correlation between decoded and original input; unaffected by
    # the affine rescaling above
    print('sG = %d: corr(decoded, input) = %.4f'
          % (net.sG, np.corrcoef(ydec, raw_inp)[0, 1]))
# In[ ]: