#! /usr/bin/env python
#
# test_connectapi.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
"""
UnitTests for the PyNEST connect API.
"""
import unittest
import nest
import sys
class Bench(object):
    """Tiny timing helper: run a code string repeatedly and print its mean run time.

    The code is executed in the global/local namespaces supplied at
    construction, so it can refer to (and modify) names living there.
    """

    def __init__(self, g, l):
        """Remember the namespaces used to execute benchmarked commands.

        g -- mapping passed to ``exec`` as the global namespace
        l -- mapping passed to ``exec`` as the local namespace
        """
        self.g = g
        self.l = l

    def __call__(self, cmd, repeat=5):
        """Execute ``cmd`` ``repeat`` times and print the average elapsed time.

        cmd    -- source string (or code object) handed to ``exec``
        repeat -- number of executions to average over; a value of 0 raises
                  ZeroDivisionError when the mean is computed

        Nothing is returned; the result is written to stdout.
        """
        from time import time

        total = 0
        for _ in range(repeat):
            start = time()
            # Tuple form of exec: valid as the Python 2 statement and as the
            # Python 3 builtin, equivalent to "exec cmd in self.g, self.l".
            exec(cmd, self.g, self.l)
            total += time() - start
        # Single parenthesised argument prints identically on Python 2 and 3.
        print('Executed "%s". Elapsed = %f s' % (cmd, total / repeat))
class ConnectAPITestCase(unittest.TestCase):
    """Tests of the Connect API

    Every test resets the NEST kernel first, builds a small network and then
    inspects the connections reported by ``nest.FindConnections`` /
    ``nest.GetStatus``.
    """

    def _fresh_populations(self, n_pre, n_post):
        """Reset the kernel and create two ``iaf_neuron`` populations.

        Returns the tuple ``(pre, post)`` as returned by ``nest.Create`` for
        ``n_pre`` and ``n_post`` neurons respectively.
        """
        nest.ResetKernel()
        pre = nest.Create("iaf_neuron", n_pre)
        post = nest.Create("iaf_neuron", n_post)
        return pre, post

    def _assert_connect_fails(self, source, target, error_name):
        """Check that ``nest.Connect(source, target)`` raises a specific NESTError.

        The test fails if no exception is raised, or if the text of the raised
        ``nest.NESTError`` does not contain ``error_name``.  Exceptions of any
        other type are not caught here, so unittest reports them as errors
        together with their original traceback.
        """
        try:
            nest.Connect(source, target)
        except nest.NESTError:
            message = str(sys.exc_info()[1])
            if error_name not in message:
                self.fail('NESTError does not mention "%s": %s'
                          % (error_name, message))
        else:
            # Kept outside the try block so this failure is never mistaken
            # for (or swallowed as) an exception raised by Connect itself.
            self.fail('Connect did not raise a NESTError ("%s" expected)'
                      % error_name)

    def test_ConnectPrePost(self):
        """Connect pre to post

        Asserts that the targets of the connections found for ``pre`` equal
        ``post``.
        """
        # Connect([pre], [post])
        pre, post = self._fresh_populations(2, 2)
        nest.Connect(pre, post)
        connections = nest.FindConnections(pre)
        targets = nest.GetStatus(connections, "target")
        self.assertEqual(targets, post)

    def test_ConnectPrePostParams(self):
        """Connect pre to post with a params dict

        Covers a single dict, a one-element list of dicts and one dict per
        connection; in each case the resulting weights are checked.
        """
        # Connect([pre], [post], params)
        pre, post = self._fresh_populations(2, 2)
        nest.Connect(pre, post, {"weight": 2.0})
        connections = nest.FindConnections(pre)
        weights = nest.GetStatus(connections, "weight")
        self.assertEqual(weights, [2.0, 2.0])

        # Connect([pre], [post], [params])
        pre, post = self._fresh_populations(2, 2)
        nest.Connect(pre, post, [{"weight": 2.0}])
        connections = nest.FindConnections(pre)
        weights = nest.GetStatus(connections, "weight")
        self.assertEqual(weights, [2.0, 2.0])

        # Connect([pre], [post], [params, params])
        pre, post = self._fresh_populations(2, 2)
        nest.Connect(pre, post, [{"weight": 2.0}, {"weight": 3.0}])
        connections = nest.FindConnections(pre)
        weights = nest.GetStatus(connections, "weight")
        self.assertEqual(weights, [2.0, 3.0])

    def test_ConnectPrePostWD(self):
        """Connect pre to post with a weight and delay

        Covers scalar values, one-element lists and one value per connection;
        both the resulting weights and delays are checked.
        """
        # Connect([pre], [post], w, d)
        pre, post = self._fresh_populations(2, 2)
        nest.Connect(pre, post, 2.0, 2.0)
        connections = nest.FindConnections(pre)
        weights = nest.GetStatus(connections, "weight")
        delays = nest.GetStatus(connections, "delay")
        self.assertEqual(weights, [2.0, 2.0])
        # The delay is verified here as well, like in the list forms below.
        self.assertEqual(delays, [2.0, 2.0])

        # Connect([pre], [post], [w], [d])
        pre, post = self._fresh_populations(2, 2)
        nest.Connect(pre, post, [2.0], [2.0])
        connections = nest.FindConnections(pre)
        weights = nest.GetStatus(connections, "weight")
        delays = nest.GetStatus(connections, "delay")
        self.assertEqual(weights, [2.0, 2.0])
        self.assertEqual(delays, [2.0, 2.0])

        # Connect([pre], [post], [w, w], [d, d])
        pre, post = self._fresh_populations(2, 2)
        nest.Connect(pre, post, [2.0, 3.0], [2.0, 3.0])
        connections = nest.FindConnections(pre)
        weights = nest.GetStatus(connections, "weight")
        delays = nest.GetStatus(connections, "delay")
        self.assertEqual(weights, [2.0, 3.0])
        self.assertEqual(delays, [2.0, 3.0])

    def test_ConvergentConnect(self):
        """ConvergentConnect pre to post

        Asserts that every connection found for the three ``pre`` neurons
        targets the single ``post`` neuron.
        """
        pre, post = self._fresh_populations(3, 1)
        nest.ConvergentConnect(pre, post)
        expected_targets = [post[0] for _ in pre]
        connections = nest.FindConnections(pre)
        targets = nest.GetStatus(connections, "target")
        self.assertEqual(expected_targets, targets)

    def test_ConvergentConnectWD(self):
        """ConvergentConnect pre to post with weight and delay"""
        pre, post = self._fresh_populations(3, 1)
        nest.ConvergentConnect(pre, post,
                               weight=[2.0, 2.0, 2.0],
                               delay=[1.0, 2.0, 3.0])
        connections = nest.FindConnections(pre)
        weights = nest.GetStatus(connections, "weight")
        delays = nest.GetStatus(connections, "delay")
        self.assertEqual(weights, [2.0, 2.0, 2.0])
        self.assertEqual(delays, [1.0, 2.0, 3.0])

    def test_DivergentConnect(self):
        """DivergentConnect pre to post

        Asserts that the targets of the connections found for the single
        ``pre`` neuron equal ``post``.
        """
        pre, post = self._fresh_populations(1, 3)
        nest.DivergentConnect(pre, post)
        connections = nest.FindConnections(pre)
        targets = nest.GetStatus(connections, "target")
        self.assertEqual(targets, post)

    def test_DivergentConnectWD(self):
        """DivergentConnect pre to post with weight and delay"""
        pre, post = self._fresh_populations(1, 3)
        nest.DivergentConnect(pre, post,
                              weight=[2.0, 2.0, 2.0],
                              delay=[1.0, 2.0, 3.0])
        connections = nest.FindConnections(pre)
        weights = nest.GetStatus(connections, "weight")
        delays = nest.GetStatus(connections, "delay")
        self.assertEqual(weights, [2.0, 2.0, 2.0])
        self.assertEqual(delays, [1.0, 2.0, 3.0])

    def test_WrongConnection(self):
        """Wrong Connections

        Connecting an ``iaf_neuron`` to a ``voltmeter`` must raise a NESTError
        mentioning ``IllegalConnection``.
        """
        nest.ResetKernel()
        n = nest.Create('iaf_neuron')
        vm = nest.Create('voltmeter')
        # Still created (same nodes as before) although it is not connected.
        nest.Create('spike_detector')
        self._assert_connect_fails(n, vm, "IllegalConnection")

    def test_UnexcpectedEvent(self):
        """Unexpected Event

        Connecting a ``spike_detector`` to an ``iaf_neuron`` must raise a
        NESTError mentioning ``UnexpectedEvent``.
        """
        nest.ResetKernel()
        n = nest.Create('iaf_neuron')
        # Still created (same nodes, same order as before) although unused.
        nest.Create('voltmeter')
        sd = nest.Create('spike_detector')
        self._assert_connect_fails(sd, n, "UnexpectedEvent")
def suite():
    """Return a TestSuite with every ``test*`` method of ConnectAPITestCase."""
    # loadTestsFromTestCase collects methods using the loader's default
    # 'test' prefix, matching the former makeSuite(..., 'test') call;
    # unittest.makeSuite was deprecated in Python 3.11 and removed in 3.13.
    return unittest.TestLoader().loadTestsFromTestCase(ConnectAPITestCase)
if __name__ == "__main__":
    # Executed as a script: run the complete suite, reporting each test
    # individually (verbosity level 2).
    unittest.TextTestRunner(verbosity=2).run(suite())