/**
* "FNS" (Firnet NeuroScience), ver.3.x
*
* FNS is an event-driven Spiking Neural Network framework, oriented
* to data-driven neural simulations.
*
* (c) 2020, Gianluca Susi, Emanuele Paracone, Mario Salerno,
* Alessandro Cristini, Fernando Maestú.
*
* CITATION:
* When using FNS for scientific publications, cite us as follows:
*
* Gianluca Susi, Pilar Garcés, Alessandro Cristini, Emanuele Paracone,
* Mario Salerno, Fernando Maestú, Ernesto Pereda (2020).
* "FNS: an event-driven spiking neural network simulator based on the
* LIFL neuron model".
* Laboratory of Cognitive and Computational Neuroscience, UPM-UCM
* Centre for Biomedical Technology, Technical University of Madrid;
* University of Rome "Tor Vergata".
* Paper under review.
*
* FNS is free software: you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 3 as
* published by the Free Software Foundation.
*
* FNS is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
* or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with FNS. If not, see <http://www.gnu.org/licenses/>.
*
* -----------------------------------------------------------
*
* Website: http://www.fnsneuralsimulator.org
*
* Contacts: fnsneuralsimulator (at) gmail.com
* gianluca.susi82 (at) gmail.com
* emanuele.paracone (at) gmail.com
*
*
* -----------------------------------------------------------
* -----------------------------------------------------------
**/
package spiking.node;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Random;
import spiking.controllers.node.NodeThread;
import spiking.internode.InterNodeSpike;
import spiking.simulator.SpikingNeuralSimulator;
import utils.exceptions.BadCurveException;
import utils.statistics.StatisticsCollector;
import utils.tools.IntegerCouple;
import org.apache.commons.math3.distribution.GammaDistribution;
/**
 * Manages the node threads of a simulation and the inter-node connections
 * between them.
 * <p>
 * Node threads are registered with {@link #addNodeThread(NodeThread)}, which
 * also accumulates the global neuron counts. Inter-node connections are added
 * with {@link #addInterNodeConnection} or
 * {@link #addInterNodeConnectionParameters}. Each such call generates the
 * individual inter-node synapses and registers every synapse on both the
 * source and the target node thread. During a run,
 * {@link #runNewSplit(double)} first delivers the inter-node spikes pulled
 * from every node thread and then starts the next split on each of them.
 */
public class NodesManager implements Serializable {
    /** Initial value of the minimum tract length, before any connection is added. */
    public static final Double MAX_TRACT_LENGTH = 100000.0;
    private final static long serialVersionUID = -4781040115396333609L;
    private final static String TAG = "[Nodes Manager] ";
    private final static Boolean verbose = true;
    // number of length-sampling attempts from which the simulator is notified
    // (sim.setBadCurve()) that the length curve is bad
    private final static Integer goodCurveGuessThreshold = 5;
    // maximum number of length-sampling attempts per synapse
    private final static Integer badCurveGuessThreshold = 15;
    private Boolean debug = false;
    // node threads; the inter-node connection code looks them up using the node
    // id as the list index (NOTE(review): assumes nodes are added in id order)
    private ArrayList<NodeThread> nodeThreads = new ArrayList<NodeThread>();
    // inter-node connections, keyed by (source node id, target node id)
    private HashMap<IntegerCouple, NodesInterconnection> nodesConnections =
            new HashMap<IntegerCouple, NodesInterconnection>();
    // total number of neurons over all nodes
    private Long n = 0L;
    // maximum number of neurons within a single node
    private Long maxN = 0L;
    private double compressionFactor = 1.0;
    // total number of excitatory neurons over all nodes
    private Long excitatory = 0L;
    // total number of inhibitory neurons over all nodes
    private Long inhibithory = 0L;
    // total number of external inputs over all nodes
    private Integer externalInputs = 0;
    // ids of the nodes that have external inputs
    private ArrayList<Integer> regsWithExternalInputs = new ArrayList<Integer>();
    // total number of generated inter-node synapses
    private Long inter_node_conns_num = 0L;
    // if set, addNodeThread() ignores further nodes
    // (NOTE(review): never set to true within this class)
    private Boolean initialized = false;
    private SpikingNeuralSimulator sim;
    // smallest mean tract length seen among the added connections
    private Double minTractLength = MAX_TRACT_LENGTH;
    private Random randGen = new Random(System.currentTimeMillis());
    // 1.0 is the "not computed" value: getMinTractLength() then falls back to
    // minTractLength. Otherwise it holds the gamma inverse CDF evaluated at
    // (1 - bop_conservative_p) for the most recently added connection.
    private Double gammaInverseCumulativeProbX = 1.0;
    private Double bop_conservative_p = null;

    /**
     * @param sim the simulator notified about bad length curves and split completion
     * @param bop_conservative_p probability used to compute
     *        {@link #getGammaInverseCumulativeProbX()}; may be null
     */
    public NodesManager(
            SpikingNeuralSimulator sim,
            Double bop_conservative_p) {
        this.sim = sim;
        this.bop_conservative_p = bop_conservative_p;
    }

    /**
     * Registers a node thread and adds its neuron and external-input counts to
     * the global totals. Ignored once {@code initialized} is set.
     */
    public void addNodeThread(NodeThread regT) {
        if (initialized) {
            return;
        }
        nodeThreads.add(regT);
        n += regT.getN();
        // keep track of the largest node
        if (regT.getN() > maxN) {
            maxN = regT.getN();
        }
        println("adding node:" + regT.getNodeId() + ", n:" + regT.getN() + "\n");
        excitatory += regT.getExcitatory();
        inhibithory += regT.getInhibithory();
        if (regT.hasExternalInput()) {
            externalInputs += regT.getExternalInputs();
            regsWithExternalInputs.add(regT.getNodeId());
        }
    }

    /**
     * Records a connection from {@code node1} to {@code node2} and generates
     * its inter-node synapses.
     *
     * @param Ne_xn_ratio multiplier applied to the source population size to
     *        obtain the number of synapses
     * @param mu_omega mean synaptic weight
     * @param sigma_omega weight standard deviation; null means a fixed weight
     * @param mu_lambda mean tract length
     * @param alpha_lambda gamma shape of the tract length; null means a fixed length
     * @param inter_node_conn_type one of the NodesInterconnection *2* connection types
     */
    public void addInterNodeConnection(
            NodeThread node1,
            NodeThread node2,
            Double Ne_xn_ratio,
            Double mu_omega,
            Double sigma_omega,
            Double mu_lambda,
            Double alpha_lambda,
            Integer inter_node_conn_type) {
        if (mu_lambda == null) {
            println("length null 1...");
        }
        NodesInterconnection n_conn = new NodesInterconnection(
                node1,
                node2,
                Ne_xn_ratio,
                mu_omega,
                sigma_omega,
                mu_lambda,
                alpha_lambda,
                inter_node_conn_type);
        nodesConnections.put(
                new IntegerCouple(node1.getNodeId(), node2.getNodeId()),
                n_conn);
        _addInterNodeConnection(
                node1.getNodeId(),
                node2.getNodeId(),
                Ne_xn_ratio,
                mu_omega,
                sigma_omega,
                mu_lambda,
                alpha_lambda,
                inter_node_conn_type);
    }

    /**
     * Records a connection between two node ids and generates its inter-node
     * synapses. The parameters are forwarded positionally to the same
     * generator used by {@link #addInterNodeConnection}: {@code weight} as the
     * synapse-count multiplier, {@code amplitude}/{@code amplitudeStdVariation}
     * as weight mean/std, {@code length}/{@code lengthShapeParameter} as
     * tract-length mean/gamma shape.
     */
    public void addInterNodeConnectionParameters(
            Integer reg1Id,
            Integer reg2Id,
            Double weight,
            Double amplitude,
            Double amplitudeStdVariation,
            Double length,
            Double lengthShapeParameter,
            Integer inter_node_conn_type) {
        if (length == null) {
            println("length null 2...");
        }
        NodesInterconnection regi = new NodesInterconnection(reg1Id, reg2Id, weight);
        regi.setLength(length);
        nodesConnections.put(new IntegerCouple(reg1Id, reg2Id), regi);
        _addInterNodeConnection(
                reg1Id,
                reg2Id,
                weight,
                amplitude,
                amplitudeStdVariation,
                length,
                lengthShapeParameter,
                inter_node_conn_type);
    }

    /**
     * Updates the minimum tract length and generates the synapses of one
     * connection. A {@link BadCurveException} is reported and the remaining
     * synapses of that connection are not generated; the simulator has already
     * been notified through {@code sim.setBadCurve()} before it is thrown.
     */
    private void _addInterNodeConnection(
            Integer node1id,
            Integer node2id,
            Double weight,
            Double amplitude,
            Double amplitudeStdVariation,
            Double length,
            Double lengthShapeParameter,
            Integer inter_node_conn_type) {
        if (length == null) {
            println("length null...");
        }
        if (minTractLength == null) {
            println("min tract length null...");
        }
        // NOTE(review): a null length is reported above but still fails here
        // with a NullPointerException when unboxed for the comparison.
        if (length < minTractLength) {
            minTractLength = length;
        }
        try {
            __addInterNodeConnection(
                    node1id,
                    node2id,
                    weight,
                    amplitude,
                    amplitudeStdVariation,
                    length,
                    lengthShapeParameter,
                    inter_node_conn_type);
        } catch (BadCurveException e) {
            e.printStackTrace();
        }
    }

    /**
     * Generates the individual synapses of an inter-node connection.
     * <p>
     * The number of synapses is the size of the source population (selected
     * by the connection type) multiplied by {@code Ne_xn_ratio}, truncated to
     * a long. For each synapse a random source and target neuron are picked in
     * the populations selected by the connection type. A weight is drawn as
     * |N(mu_omega, sigma_omega)|, negated when mu_omega is negative, or set to
     * mu_omega when sigma_omega is null. A length is drawn from a gamma
     * distribution, or set to mu_lambda when alpha_lambda is null.
     *
     * <pre>
     * The schema for internode connections
     *
     *         \      |       |       |       |
     *          \ to  | mixed |  exc  |  inh  |
     *     from  \    |       |       |       |
     *    ------------------------------------
     *     mixed      |   0   |   1   |   2   |
     *    ------------------------------------
     *     exc        |   3   |   4   |   5   |
     *    ------------------------------------
     *     inh        |   6   |   7   |   8   |
     *    ------------------------------------
     * </pre>
     *
     * @throws BadCurveException if no non-negative length could be obtained
     *         within {@code badCurveGuessThreshold} attempts
     */
    private void __addInterNodeConnection(
            Integer reg1Id,
            Integer reg2Id,
            Double Ne_xn_ratio,
            Double mu_omega,
            Double sigma_omega,
            Double mu_lambda,
            Double alpha_lambda,
            Integer inter_node_conn_type) throws BadCurveException {
        // the connection type does not change inside the loop: classify it once
        final boolean fromExc = isFromExcitatory(inter_node_conn_type);
        final boolean fromInh = isFromInhibitory(inter_node_conn_type);
        final boolean toExc = isToExcitatory(inter_node_conn_type);
        final boolean toInh = isToInhibitory(inter_node_conn_type);
        final NodeThread srcNode = nodeThreads.get(reg1Id);

        final long srcPopulation;
        if (fromExc) {
            srcPopulation = srcNode.getExcitatory();
        } else if (fromInh) {
            srcPopulation = srcNode.getInhibithory();
        } else {
            srcPopulation = srcNode.getN();
        }
        final long connectionsNum = (long) (srcPopulation * Ne_xn_ratio);

        // shape alpha_lambda and scale mu_lambda/alpha_lambda, so that the
        // distribution mean (shape * scale) is mu_lambda
        GammaDistribution gd = (alpha_lambda != null)
                ? new GammaDistribution(alpha_lambda, mu_lambda / alpha_lambda)
                : null;
        // NOTE(review): overwritten on every connection, so it only reflects
        // the most recently added one.
        gammaInverseCumulativeProbX =
                ((alpha_lambda != null) && (bop_conservative_p != null))
                        ? gd.inverseCumulativeProbability(1.0 - bop_conservative_p)
                        : 1.0;

        final NodeThread dstNode = nodeThreads.get(reg2Id);
        for (long i = 0; i < connectionsNum; ++i) {
            Long i1 = pickNeuron(srcNode, fromExc, fromInh);
            Long i2 = pickNeuron(dstNode, toExc, toInh);

            Double tmp_mu_w = (sigma_omega != null)
                    ? Math.abs(randGen.nextGaussian() * sigma_omega + mu_omega)
                    : mu_omega;
            // keep the sign of the mean weight (e.g. inhibitory connections)
            if (mu_omega < 0 && tmp_mu_w > 0) {
                tmp_mu_w = -tmp_mu_w;
            }

            // draw until a non-negative length is obtained or attempts run out
            Double tmpLength = -1.0;
            int lengthGuesses = 0;
            while ((tmpLength < 0) && (lengthGuesses < badCurveGuessThreshold)) {
                tmpLength = (gd != null) ? gd.sample() : mu_lambda;
                ++lengthGuesses;
            }
            if (lengthGuesses >= goodCurveGuessThreshold) {
                sim.setBadCurve();
                // a negative length here means all badCurveGuessThreshold attempts failed
                if (tmpLength < 0) {
                    throw new BadCurveException((alpha_lambda != null)
                            ? "the gamma curve G(" + alpha_lambda + ", "
                                    + (alpha_lambda / mu_lambda)
                                    + ") has a shape which is not compliant with firnet scope."
                            : "the tract length " + mu_lambda
                                    + " is negative, which is not compliant with firnet scope.");
                }
            }

            // the synapse is registered on both the source and the target node
            srcNode.addInterNodeSynapse(
                    reg1Id,
                    i1,
                    reg2Id,
                    i2,
                    srcNode.getExcitatoryPresynapticWeight(),
                    tmp_mu_w,
                    tmpLength);
            dstNode.addInterNodeSynapse(
                    reg1Id,
                    i1,
                    reg2Id,
                    i2,
                    srcNode.getExcitatoryPresynapticWeight(),
                    tmp_mu_w,
                    tmpLength);
            ++inter_node_conns_num;
        }
    }

    /** Returns true for the EXC2* connection types (excitatory source population). */
    private static boolean isFromExcitatory(Integer type) {
        return (type == NodesInterconnection.EXC2MIXED)
                || (type == NodesInterconnection.EXC2EXC)
                || (type == NodesInterconnection.EXC2INH);
    }

    /** Returns true for the INH2* connection types (inhibitory source population). */
    private static boolean isFromInhibitory(Integer type) {
        return (type == NodesInterconnection.INH2MIXED)
                || (type == NodesInterconnection.INH2EXC)
                || (type == NodesInterconnection.INH2INH);
    }

    /** Returns true for the *2EXC connection types (excitatory target population). */
    private static boolean isToExcitatory(Integer type) {
        return (type == NodesInterconnection.MIXED2EXC)
                || (type == NodesInterconnection.EXC2EXC)
                || (type == NodesInterconnection.INH2EXC);
    }

    /** Returns true for the *2INH connection types (inhibitory target population). */
    private static boolean isToInhibitory(Integer type) {
        return (type == NodesInterconnection.MIXED2INH)
                || (type == NodesInterconnection.EXC2INH)
                || (type == NodesInterconnection.INH2INH);
    }

    /**
     * Picks a random neuron index within {@code node}. Excitatory neurons
     * occupy indices [0, excitatory); inhibitory neurons follow them. With
     * neither flag set, the index is drawn over the whole node.
     */
    private static long pickNeuron(NodeThread node, boolean excitatory, boolean inhibitory) {
        if (excitatory) {
            return (long) (Math.random() * node.getExcitatory());
        }
        if (inhibitory) {
            return node.getExcitatory()
                    + ((long) (Math.random() * node.getInhibithory()));
        }
        return (long) (Math.random() * node.getN());
    }

    /** @return the total number of neurons over all registered nodes */
    public Long getTotalN() {
        return n;
    }

    public double getCompressionFactor() {
        return compressionFactor;
    }

    public void setCompressionFactor(double compressionFactor) {
        this.compressionFactor = compressionFactor;
    }

    /** @return the number of registered node threads */
    public int getNodeThreadsNum() {
        return nodeThreads.size();
    }

    public Boolean getInitialized() {
        return initialized;
    }

    /**
     * @return {@link #getGammaInverseCumulativeProbX()} when it has been
     *         computed (i.e. differs from 1.0), otherwise the smallest mean
     *         tract length among the added connections
     */
    public Double getMinTractLength() {
        return (gammaInverseCumulativeProbX == 1.0)
                ? minTractLength
                : gammaInverseCumulativeProbX;
    }

    public Double getGammaInverseCumulativeProbX() {
        return gammaInverseCumulativeProbX;
    }

    /** @return the number of registered node threads (same as {@link #getNodeThreadsNum()}) */
    public Integer getnSms() {
        return nodeThreads.size();
    }

    public NodeThread getNodeThread(int index) {
        return nodeThreads.get(index);
    }

    /**
     * @return the maximum number of neurons within a single node
     */
    public long getMaxN() {
        return maxN;
    }

    /**
     * @return the total number of generated inter-node synapses
     */
    public Long getTotalInterNodeConnectionsNumber() {
        return inter_node_conns_num;
    }

    public void setDebug(Boolean debug) {
        this.debug = debug;
    }

    private void println(String s) {
        if (verbose) {
            System.out.println(TAG + s);
        }
    }

    private void debprintln(String s) {
        if (verbose && debug) {
            System.out.println(TAG + "[debug] " + s);
        }
    }

    /** Prints the global neuron and external-input counters. */
    public void printNodeFields() {
        println("n:\t\t" + n);
        println("excitatory:\t" + excitatory);
        println("inhibithory:\t" + inhibithory);
        println("external inputs:\t" + externalInputs);
    }

    /** Starts every registered node thread. */
    public void startAll() {
        for (int i = 0; i < nodeThreads.size(); ++i) {
            nodeThreads.get(i).start();
        }
    }

    /**
     * Delivers the pending inter-node spikes, then passes {@code newStopTime}
     * to {@code runNewSplit} of every node thread.
     */
    public void runNewSplit(double newStopTime) {
        updateInterNodeSpikes();
        for (int i = 0; i < nodeThreads.size(); ++i) {
            nodeThreads.get(i).runNewSplit(newStopTime);
        }
    }

    /** Calls {@code kill()} on every registered node thread. */
    public void killAll() {
        for (int i = 0; i < nodeThreads.size(); ++i) {
            nodeThreads.get(i).kill();
        }
    }

    /** Forwards the split-completion notification of a node to the simulator. */
    public synchronized Boolean splitComplete(Integer nodeId) {
        return sim.splitComplete(nodeId);
    }

    /** Pulls the inter-node spikes from every node thread and delivers each of them. */
    private void updateInterNodeSpikes() {
        for (int i = 0; i < nodeThreads.size(); ++i) {
            ArrayList<InterNodeSpike> tmpInterNodeSpikes = nodeThreads.get(i).pullInternodeFires();
            for (int j = 0; j < tmpInterNodeSpikes.size(); ++j) {
                deliverInterNodeSpike(tmpInterNodeSpikes.get(j));
            }
        }
    }

    /** Hands a spike to the node thread whose id is the burning node id of its synapse. */
    private void deliverInterNodeSpike(InterNodeSpike irs) {
        nodeThreads.get(irs.getSyn().getBurningNodeId()).burnInterNodeSpike(irs);
    }
}