#! /usr/bin/env python
#
# test_threads.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
"""
Unit tests for multithreaded PyNEST.
"""
import unittest
import nest
import sys
class ThreadTestCase(unittest.TestCase):
    """Tests of PyNEST behaviour when the kernel runs with several threads."""

    def nest_multithreaded(self):
        """Return True if NEST has threading enabled, False otherwise.

        Runs an SLI expression that compares ``statusdict/threading`` with
        ``(no)`` and pops the resulting boolean from the SLI stack.
        """
        nest.sr("statusdict/threading :: (no) eq not")
        return nest.spp()

    def _require_multithreaded(self):
        """Skip the calling test when NEST has no thread support.

        Returning early instead would make the test report as passed even
        though nothing was checked.
        """
        if not self.nest_multithreaded():
            self.skipTest("NEST was compiled without multithreading support")

    def test_Threads(self):
        """Eight nodes created with 8 local threads occupy VPs 0..7."""
        self._require_multithreaded()
        nest.ResetKernel()
        self.assertEqual(nest.GetKernelStatus()['local_num_threads'], 1)
        nest.SetKernelStatus({'local_num_threads': 8})
        nodes = nest.Create('iaf_neuron', 8)
        # sorted() works whether GetStatus yields a list or a tuple;
        # list.sort() would fail on the latter.
        vps = sorted(nest.GetStatus(nodes, 'vp'))
        self.assertEqual(vps, list(range(8)))

    def test_ThreadsFindConnections(self):
        """FindConnections reports every target of a divergent connection."""
        self._require_multithreaded()
        nest.ResetKernel()
        nest.SetKernelStatus({'local_num_threads': 8})
        pre = nest.Create("iaf_neuron")
        post = nest.Create("iaf_neuron", 6)
        nest.DivergentConnect(pre, post)
        conn = nest.FindConnections(pre)
        # Because of threading, targets may be reported in a different order
        # than in `post`, so compare sorted lists. Sorting both sides also
        # makes the comparison independent of list-vs-tuple return types.
        targets = sorted(nest.GetStatus(conn, "target"))
        self.assertEqual(targets, sorted(post))

    def test_ThreadsGetEvents(self):
        """Recorded event counts do not depend on the number of threads."""
        self._require_multithreaded()
        threads = [1, 2, 4, 8]
        n_events_sd = []
        n_events_vm = []
        n_neurons = 128
        sim_time = 1000.
        for t in threads:
            nest.ResetKernel()
            nest.SetKernelStatus({'local_num_threads': t})
            # Large constant input current to force a lot of spike events.
            neurons = nest.Create('iaf_psc_alpha', n_neurons, {'I_e': 2000.})
            sd = nest.Create('spike_detector')
            vm = nest.Create('voltmeter')
            nest.ConvergentConnect(neurons, sd)
            nest.DivergentConnect(vm, neurons)
            nest.Simulate(sim_time)
            n_events_sd.append(nest.GetStatus(sd, 'n_events')[0])
            n_events_vm.append(nest.GetStatus(vm, 'n_events')[0])

        # Expected voltmeter count: (sim_time - 1) samples per neuron.
        # NOTE(review): presumably one sample per ms starting at t = 1 ms
        # (voltmeter default interval) -- confirm.
        ref_vm = n_neurons * (sim_time - 1)
        # Spike counts have no closed form, so the single-threaded run is
        # the reference every other thread count must reproduce.
        ref_sd = n_events_sd[0]
        for t, count in zip(threads, n_events_vm):
            self.assertEqual(count, ref_vm,
                             "voltmeter event count with %d thread(s)" % t)
        for t, count in zip(threads, n_events_sd):
            self.assertEqual(count, ref_sd,
                             "spike_detector event count with %d thread(s)" % t)
def suite():
    """Return a TestSuite with every ``test*`` method of ThreadTestCase."""
    # unittest.makeSuite was deprecated in Python 3.11 and removed in 3.13.
    # TestLoader.loadTestsFromTestCase is the supported equivalent and also
    # collects methods using the default "test" prefix.
    return unittest.TestLoader().loadTestsFromTestCase(ThreadTestCase)
if __name__ == "__main__":
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite())
    # Report failures through the process exit status so that scripts and
    # CI jobs running this file directly can detect them.
    sys.exit(0 if result.wasSuccessful() else 1)