#
# simulation_fig_model.py
#
# Main simulation run: model description figures
#
# Copyright (C) 2012 Lukas Solanka <l.solanka@sms.ed.ac.uk>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
from matplotlib.backends.backend_pdf import PdfPages
from optparse import OptionParser
from brian import *
from parameters import getOptParser, setOptionDictionary
from grid_cell_network_brian import BrianGridCellNetwork
from custombrian import ExtendedSpikeMonitor
from tools import butterHighPass, spikePhaseTrialRaster, \
phaseCWT
from plotting import phaseFigTemplate, rasterPhasePlot, \
raster_bin_size
import time
import numpy as np
import logging as lg
lg.basicConfig(level=lg.DEBUG)
# Project option parser plus one script-specific option: the simulation time
# (multiplied by msecond below, i.e. given in ms) from which the theta_*
# monitors start keeping data.
parser = getOptParser()
parser.add_option("--theta_start_mon_t", type="float", help="theta start monitoring time")
(options, args) = parser.parse_args()
# The option has no default. Without it the script would only fail with a
# TypeError at the first net.run() call, after the costly network setup, so
# reject the command line up front (optparse prints usage and exits).
if options.theta_start_mon_t is None:
    parser.error("--theta_start_mon_t is required")
################################################################################
# Network setup
################################################################################
print "Starting network and connections initialization..."
start_time=time.time()
total_start_t = time.time()
options.ndim = 'twisted_torus'
ei_net = BrianGridCellNetwork(options, simulationOpts=None)
ei_net.uniformInhibition()
ei_net.uniformExcitation()
ei_net.setConstantCurrent()
ei_net.setStartCurrent()
ei_net.setThetaCurrentStimulation()
duration=time.time()-start_time
print "Network setup time:",duration,"seconds"
# End Network setup
################################################################################
# Recording setup. Every monitor below runs on the network's simulation clock.
simulationClock = ei_net._getSimulationClock()

# Whole-population spike monitors (used for the firing-rate snapshots).
nrecSpike_e = ei_net.Ne_x*ei_net.Ne_y
nrecSpike_i = ei_net.Ni_x*ei_net.Ni_y

# Two cells per population are traced. Assuming row-major indexing with
# N*_x cells per row (as the later reshape to (N*_y, N*_x) implies), index 0
# is the bottom-centre cell and index 1 the cell in the middle of the sheet.
state_record_e = [ei_net.Ne_x/2 - 1, ei_net.Ne_y/2*ei_net.Ne_x + ei_net.Ne_x/2 - 1]
state_record_i = [ei_net.Ni_x/2 - 1, ei_net.Ni_y/2*ei_net.Ni_x + ei_net.Ni_x/2 - 1]

spikeMon_e = ExtendedSpikeMonitor(ei_net.E_pop[0:nrecSpike_e])
spikeMon_i = ExtendedSpikeMonitor(ei_net.I_pop[0:nrecSpike_i])


def _recent_mon(population, var_name, record):
    """Brian RecentStateMonitor of var_name over the last options.stateMonDur ms."""
    return RecentStateMonitor(population, var_name,
                              duration=options.stateMonDur*ms,
                              record=record, clock=simulationClock)


def _full_mon(population, var_name, record):
    """Brian StateMonitor of var_name (no duration window)."""
    return StateMonitor(population, var_name, record=record,
                        clock=simulationClock)


stateMon_e = _recent_mon(ei_net.E_pop, 'vm', state_record_e)
stateMon_i = _recent_mon(ei_net.I_pop, 'vm', state_record_i)
stateMon_ge_e = _recent_mon(ei_net.E_pop, 'ge', state_record_e)
stateMon_Iclamp_e = _recent_mon(ei_net.E_pop, 'Iclamp', state_record_e)
stateMon_Iclamp_i = _recent_mon(ei_net.I_pop, 'Iclamp', state_record_i)
stateMon_Iext_e = _recent_mon(ei_net.E_pop, 'Iext', state_record_e)
stateMon_Iext_i = _recent_mon(ei_net.I_pop, 'Iext', state_record_i)

# Theta-phase analysis: the middle cell (state_record_*[1]) plus
# theta_n_it_range/2 neighbouring indices on each side.
theta_n_it_range = 2
_theta_half = theta_n_it_range/2
theta_state_record_e = range(state_record_e[1] - _theta_half,
                             state_record_e[1] + _theta_half + 1)
theta_state_record_i = range(state_record_i[1] - _theta_half,
                             state_record_i[1] + _theta_half + 1)

theta_spikeMon_e = ExtendedSpikeMonitor(ei_net.E_pop)
theta_spikeMon_i = ExtendedSpikeMonitor(ei_net.I_pop)
# NOTE(review): theta_stateMon_e, theta_stateMon_i and theta_stateMon_Iext_e
# are only reinit'ed, never read, in the visible part of this script.
theta_stateMon_e = _full_mon(ei_net.E_pop, 'vm', theta_state_record_e)
theta_stateMon_Iclamp_e = _full_mon(ei_net.E_pop, 'Iclamp', theta_state_record_e)
theta_stateMon_i = _full_mon(ei_net.I_pop, 'vm', theta_state_record_i)
theta_stateMon_Iclamp_i = _full_mon(ei_net.I_pop, 'Iclamp', theta_state_record_i)
theta_stateMon_Iext_e = _full_mon(ei_net.E_pop, 'Iext', theta_state_record_e)

# Register every monitor with the explicit network object.
ei_net.net.add(spikeMon_e, spikeMon_i,
               stateMon_e, stateMon_i, stateMon_Iclamp_e, stateMon_Iclamp_i,
               stateMon_ge_e,
               stateMon_Iext_e, stateMon_Iext_i,
               theta_spikeMon_e, theta_spikeMon_i, theta_stateMon_Iclamp_e,
               theta_stateMon_Iclamp_i, theta_stateMon_Iext_e,
               theta_stateMon_e, theta_stateMon_i)

# x-axis window (s) for the time-series figures: the last second of the run
# (options.time is in ms).
x_lim = [options.time/1e3 - 1, options.time/1e3]
################################################################################
# Main cycle
################################################################################
print "Simulation running..."
start_time=time.time()
print " Network initialisation..."
# Phase 1: simulate up to options.theta_start_mon_t (ms). All monitors are
# recording during this phase.
ei_net.net.run(options.theta_start_mon_t*msecond, report='stdout')
# Discard what the theta_* monitors captured during phase 1, so the theta
# analysis only sees data from theta_start_mon_t onwards. The other spike and
# state monitors are not reset here and keep their phase-1 data.
theta_spikeMon_e.reinit()
theta_spikeMon_i.reinit()
theta_stateMon_e.reinit()
theta_stateMon_i.reinit()
theta_stateMon_Iclamp_e.reinit()
theta_stateMon_Iclamp_i.reinit()
theta_stateMon_Iext_e.reinit()
print " Theta stimulation..."
# Phase 2: simulate the remainder so the total simulated time is options.time ms.
ei_net.net.run((options.time - options.theta_start_mon_t)*msecond, report='stdout')
duration=time.time()-start_time
print "Simulation time:",duration,"seconds"
# Common prefix for every output file of this job: <output_dir>/<prefix>jobNNNN
output_fname = "{0}/{1}job{2:04}".format(options.output_dir,
        options.fileNamePrefix, options.job_num)

# Sliding-window population firing rates over the whole run. Times are in
# seconds (options.time is in ms): window step F_dt, window length F_winLen.
F_tstart = 0
F_tend = options.time*1e-3
F_dt = 0.05
F_winLen = 0.25
Fe, Fe_t = spikeMon_e.getFiringRate(F_tstart, F_tend, F_dt, F_winLen)
Fi, Fi_t = spikeMon_i.getFiringRate(F_tstart, F_tend, F_dt, F_winLen)


def _save_ei_traces(t_e, y_e, label_e, t_i, y_i, label_i, fname):
    """Plot an E trace above an I trace with a shared time axis and save it.

    The x range is limited to x_lim. The figure is closed after saving so it
    does not stay in memory during the long analysis that follows.
    """
    figure()
    ax = subplot(211)
    plot(t_e, y_e)
    ylabel(label_e)
    subplot(212, sharex=ax)
    plot(t_i, y_i)
    xlabel('Time (s)')
    ylabel(label_i)
    xlim(x_lim)
    tight_layout()
    savefig(fname)
    close()


def _save_rate_snapshot(rates, rate_t, n_y, n_x, label, fname):
    """Save the population firing rate at the middle time bin as an n_y x n_x sheet.

    The figure is closed after saving.
    """
    figure()
    pcolormesh(np.reshape(rates[:, len(rate_t)//2], (n_y, n_x)))
    xlabel(label)
    ylabel(label)
    colorbar()
    axis('equal')
    tight_layout()
    savefig(fname)
    close()


# Membrane potentials of the two traced cells (bottom-centre and middle of the sheet).
_save_ei_traces(stateMon_e.times, stateMon_e.values[:, 0:2]/mV, 'E cell $V_m$ (mV)',
                stateMon_i.times, stateMon_i.values[:, 0:2]/mV, 'I cell $V_m$ (mV)',
                output_fname + '_Vm.pdf')

# Post-synaptic currents of the same cells.
_save_ei_traces(stateMon_Iclamp_e.times, stateMon_Iclamp_e.values[:, 0:2]/pA,
                'E cell $I_{syn}$ (pA)',
                stateMon_Iclamp_i.times, stateMon_Iclamp_i.values[:, 0:2]/pA,
                'I cell $I_{syn}$ (pA)',
                output_fname + '_Isyn.pdf')

# External current input, sign-flipped for display.
# NOTE(review): column 1 (middle cell) is plotted for E but column 0
# (bottom-centre cell) for I -- confirm this asymmetry is intended.
_save_ei_traces(stateMon_Iext_e.times, -stateMon_Iext_e.values[:, 1]/pA,
                'E cell $I_{ext}$ (pA)',
                stateMon_Iext_i.times, -stateMon_Iext_i.values[:, 0]/pA,
                'I cell $I_{ext}$ (pA)',
                output_fname + '_Iext.png')

# Snapshots of the population firing rate on the twisted torus.
_save_rate_snapshot(Fe, Fe_t, ei_net.Ne_y, ei_net.Ne_x, 'E neuron #',
                    output_fname + '_firing_snapshot_e.png')
_save_rate_snapshot(Fi, Fi_t, ei_net.Ni_y, ei_net.Ni_x, 'I neuron #',
                    output_fname + '_firing_snapshot_i.png')
######################################################################
# Wavelet and raster analysis
######################################################################
print "Wavelet analysis..."
wavelet_sig_pp = PdfPages(output_fname + '_phase_sig_e.pdf')
high_pass_freq = 40.
maxF = 200
for ei_it in [0]:
if ei_it == 0:
print ' E neurons...'
wavelet_sig_pp = PdfPages(output_fname + '_phase_sig_e.pdf')
wavelet_sig_fname = output_fname + '_phase_wavelet_e'
sig_epochs_fname = output_fname + '_sig_epochs_e'
tmp_stateMon = theta_stateMon_Iclamp_e
range_n_it = theta_state_record_e
else:
print ' I neurons...'
wavelet_sig_pp = PdfPages(output_fname + '_phase_sig_i.pdf')
wavelet_sig_fname = output_fname + '_phase_wavelet_i'
sig_epochs_fname = output_fname + '_sig_epochs_i'
tmp_stateMon = theta_stateMon_Iclamp_i
range_n_it = theta_state_record_i
for n_it in range(len(range_n_it)):
neuron_no = range_n_it[n_it]
print(' Neuron no. ' + str(neuron_no))
cwt_phases, sig_cwt, freq, sig_ph = \
phaseCWT(butterHighPass(tmp_stateMon[neuron_no].T/pA,
options.sim_dt*msecond, high_pass_freq), 1./options.theta_freq, options.sim_dt*1e-3, maxF)
# Wavelet plot
f = phaseFigTemplate()
PH, F = np.meshgrid(cwt_phases, freq)
pcolormesh(PH, F, sig_cwt, edgecolors='None', cmap=get_cmap('jet'))
ylabel('F (Hz)')
ylim([0, maxF])
savefig(wavelet_sig_fname + '{0}.png'.format(n_it),
dpi=300)
close()
# pcolor of signals over theta epochs
figure(figsize=(12, 6))
PH, T = np.meshgrid(cwt_phases, np.arange(1, len(sig_ph)+1))
pcolormesh(PH, T, sig_ph, edgecolor='None')
xlabel('Theta phase')
ylabel('Theta epoch')
xlim([-np.pi, np.pi])
ylim([1, len(sig_ph)])
xticks([-np.pi, 0, np.pi], ('$-\pi$', '', '$\pi$'), fontsize=25)
yticks([1, len(sig_ph)])
savefig(sig_epochs_fname + '{0}.png'.format(n_it),
dpi=300)
# Average signal plot
f = phaseFigTemplate()
mn = np.mean(sig_ph, 0)
st = np.std(sig_ph, 0)
gca().fill_between(cwt_phases, mn+st, mn-st, facecolor='black', alpha=0.1, zorder=0)
plot(cwt_phases, mn, 'k')
ylabel('I (pA)')
wavelet_sig_pp.savefig()
close()
wavelet_sig_pp.close()
print "Done"
# Raster plots
for ei_it in [0]:
if ei_it == 0:
raster_pp = PdfPages(output_fname + '_phase_raster_e.pdf')
range_n_it = theta_state_record_e
tmp_spikeMon = theta_spikeMon_e
else:
raster_pp = PdfPages(output_fname + '_phase_raster_i.pdf')
range_n_it = theta_state_record_i
tmp_spikeMon = theta_spikeMon_i
for n_it in range(len(range_n_it)):
neuron_no = range_n_it[n_it]
print('Saving rasters for neuron no. ' + str(neuron_no))
phases, times, trials = spikePhaseTrialRaster(tmp_spikeMon[neuron_no],
options.theta_freq, options.theta_start_mon_t*msecond)
phases -= np.pi
# Raster plots (single cell over 'theta' epochs)
ntrials = np.ceil((options.time - options.theta_start_mon_t) * msecond * options.theta_freq)
f = rasterPhasePlot(phases, trials, ntrials)
raster_pp.savefig()
close()
raster_pp.close()
# End main cycle
################################################################################
total_time = time.time()-total_start_t
print "Overall time: ", total_time, " seconds"