/*
 *  communicator_impl.h
 *
 *  This file is part of NEST.
 *
 *  Copyright (C) 2004 The NEST Initiative
 *
 *  NEST is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  NEST is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with NEST.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

#include "communicator.h"
#include "network.h"

#include "config.h"

/* To avoid problems on BlueGene/L, mpi.h MUST be the
   first included file after config.h.
 */
#ifdef HAVE_MPI
#include <mpi.h>
#endif /* #ifdef HAVE_MPI */

#ifdef HAVE_MPI

// Variable to hold the MPI communicator to use.
// It is only declared here (extern) and is not defined in this file.
// The declared type depends on the build configuration: MPI::Intracomm
// when HAVE_MUSIC is defined, MPI_Comm otherwise.
#ifdef HAVE_MUSIC
extern MPI::Intracomm comm;
#else /* #ifdef HAVE_MUSIC */
extern MPI_Comm comm;
#endif /* #ifdef HAVE_MUSIC */


/* ------------------------------------------------------
   The following datatypes are defined here in communicator_impl.h
   instead of as static class members, to avoid inclusion
   of mpi.h in the .h file. This is necessary, because on
   BlueGene/L mpi.h MUST be included FIRST. Having mpi.h in
   the .h file would lead to requirements on include-order
   throughout the NEST code base and is not acceptable.
   Reported by Mikael Djurfeldt.
   Hans Ekkehard Plesser, 2010-01-28
 */
/**
 * Associates an element type T with the MPI datatype handle that is
 * passed to MPI calls in this file when buffers of T are communicated.
 * Only the static member is declared here; its definition is not part
 * of this file.
 */
template < typename T >
struct MPI_Type
{
  static MPI_Datatype type;
};

/**
 * Gather variable-sized contributions from all processes via MPI_Allgatherv
 * on the communicator `comm`.
 *
 * @param send_buffer    Elements contributed by this process; may be empty.
 * @param recv_buffer    Destination of the gathered elements. It is not
 *                       resized here, so the caller has to size it beforehand.
 * @param displacements  Handed to MPI_Allgatherv as the per-process offsets
 *                       into recv_buffer; must not be empty.
 * @param recv_counts    Handed to MPI_Allgatherv as the per-process element
 *                       counts; must not be empty.
 *
 * The element type T is translated into an MPI datatype through
 * MPI_Type<T>::type. The return code of MPI_Allgatherv is not inspected.
 */
template <typename T>
void nest::Communicator::communicate_Allgatherv(std::vector<T>& send_buffer,
                                                std::vector<T>& recv_buffer,
                                                std::vector<int>& displacements,
                                                std::vector<int>& recv_counts)
{
  // Taking the address of element 0 of an empty vector is undefined
  // behaviour. A process without local data legitimately arrives here with
  // an empty send buffer, so pass a null pointer (with a count of zero)
  // instead of indexing into the empty vector.
  T* const send_ptr = send_buffer.empty() ? static_cast<T*>(0) : &send_buffer[0];
  T* const recv_ptr = recv_buffer.empty() ? static_cast<T*>(0) : &recv_buffer[0];

  // MPI expects the send count as int; make the narrowing explicit.
  const int send_count = static_cast<int>(send_buffer.size());

  MPI_Allgatherv(send_ptr, send_count, MPI_Type<T>::type,
                 recv_ptr, &recv_counts[0], &displacements[0], MPI_Type<T>::type, comm);
}

/**
 * Collect addressing data (gid, parent gid, vp) for a list of nodes.
 *
 * If more than one process is in use and remote is true, the data of
 * local_nodes is flattened into triples, the triples of all processes are
 * gathered, appended to all_nodes, and all_nodes is sorted with duplicate
 * entries removed. If no process contributes any data, all_nodes is left
 * untouched.
 *
 * Otherwise only the data of local_nodes is appended and all_nodes is
 * sorted; duplicates are not removed in this case.
 *
 * @param local_nodes  List of nodes on this process. Its elements are
 *                     dereferenced as pointers offering get_gid(),
 *                     get_parent() and get_vp().
 * @param all_nodes    Output. New entries are appended; entries already
 *                     present are kept and take part in the sorting.
 * @param remote       If true, include the nodes of the other processes.
 */
template <typename NodeListType>
void nest::Communicator::communicate(const NodeListType& local_nodes,
                                     vector<NodeAddressingData>& all_nodes,
                                     bool remote)
{
  const size_t np = Communicator::num_processes_;

  if (not (np > 1 && remote))
  {
    // On one proc or not including remote nodes: local data only.
    for ( typename NodeListType::iterator n = local_nodes.begin(); n != local_nodes.end(); ++n )
      all_nodes.push_back(NodeAddressingData((*n)->get_gid(),
                          ((*n)->get_parent())->get_gid(),
                          (*n)->get_vp()));
    std::sort(all_nodes.begin(), all_nodes.end());
    return;
  }

  // Flatten the local node data into consecutive (gid, parent gid, vp) triples.
  vector<long_t> localnodes;
  for ( typename NodeListType::iterator n = local_nodes.begin(); n != local_nodes.end(); ++n )
  {
    localnodes.push_back((*n)->get_gid());
    localnodes.push_back(((*n)->get_parent())->get_gid());
    localnodes.push_back((*n)->get_vp());
  }

  // Get size of buffers: every process enters its own count, and
  // communicate() is expected to fill in the counts of the other processes.
  std::vector<nest::int_t> n_nodes(np);
  n_nodes[Communicator::rank_] = localnodes.size();
  communicate(n_nodes);

  // Set up displacements vector: the data of process i starts where the
  // data of process i-1 ends.
  std::vector<int> displacements(np,0);
  for ( size_t i = 1; i < np; ++i )
    displacements.at(i) = displacements.at(i-1)+n_nodes.at(i-1);

  // Calculate total number of node data items to be gathered.
  const size_t n_globals =
      std::accumulate(n_nodes.begin(),n_nodes.end(), 0);
  assert(n_globals % 3 == 0);

  // Nothing to gather anywhere: leave all_nodes as it is.
  if (n_globals == 0)
    return;

  vector<long_t> globalnodes;
  globalnodes.resize(n_globals,0L);
  communicate_Allgatherv<nest::long_t>(localnodes, globalnodes, displacements, n_nodes);

  // Create unflattened vector. The bound is written as i+2 < n_globals so
  // that it cannot wrap around for small n_globals when the assertion above
  // is compiled out.
  for ( size_t i = 0; i + 2 < n_globals; i += 3 )
    all_nodes.push_back(NodeAddressingData(globalnodes[i],globalnodes[i+1],globalnodes[i+2]));

  // Get rid of any multiple entries.
  std::sort(all_nodes.begin(), all_nodes.end());
  all_nodes.erase(std::unique(all_nodes.begin(), all_nodes.end()), all_nodes.end());
}


template <typename NodeListType>
void nest::Communicator::communicate(const NodeListType& local_nodes,
                                     vector<NodeAddressingData>& all_nodes,
                                     Network& net, DictionaryDatum params, 
				     bool remote)
{
  size_t np = Communicator::num_processes_;

  if ( np > 1 && remote)
  {
    vector<long_t> localnodes;
    if (params->empty())
    {
      for ( typename NodeListType::iterator n = local_nodes.begin(); n != local_nodes.end(); ++n )
      {
        localnodes.push_back((*n)->get_gid());
        localnodes.push_back(((*n)->get_parent())->get_gid());
        localnodes.push_back((*n)->get_vp());
      }
    } else {
      for ( typename NodeListType::iterator n = local_nodes.begin(); n != local_nodes.end(); ++n )
      {
        //select those nodes fulfilling the key/value pairs of the dictionary
        bool match = true;
        index gid = (*n)->get_gid();
        DictionaryDatum node_status = net.get_status(gid);
        for (Dictionary::iterator i = params->begin(); i != params->end(); ++i)
        {
          if (node_status->known(i->first))
          {
            const Token token = node_status->lookup(i->first);
            if (not ( token == i->second || token.matches_as_string(i->second) ))
            {
              match = false;
              break;
            }
          }
        }
        if (match)
        {
          localnodes.push_back(gid);
          localnodes.push_back(((*n)->get_parent())->get_gid());
          localnodes.push_back((*n)->get_vp());
        }
      }
    }

    //get size of buffers
    std::vector<nest::int_t> n_nodes(np);
    n_nodes[Communicator::rank_] = localnodes.size();
    communicate(n_nodes);

    // Set up displacements vector.
    std::vector<int> displacements(np,0);

    for ( size_t i = 1; i < np; ++i )
      displacements.at(i) = displacements.at(i-1)+n_nodes.at(i-1);

    // Calculate sum of global connections.
    size_t n_globals =
        std::accumulate(n_nodes.begin(),n_nodes.end(), 0);
    assert(n_globals % 3 == 0);
    vector<long_t> globalnodes;
    if (n_globals != 0)
    {
      globalnodes.resize(n_globals,0L);
      communicate_Allgatherv<nest::long_t>(localnodes, globalnodes, displacements, n_nodes);
  
      //Create unflattened vector
      for ( size_t i = 0; i < n_globals -2; i +=3)
        all_nodes.push_back(NodeAddressingData(globalnodes[i],globalnodes[i+1],globalnodes[i+2]));

      //get rid of any multiple entries
      std::sort(all_nodes.begin(), all_nodes.end());
      vector<NodeAddressingData>::iterator it;
      it = std::unique(all_nodes.begin(), all_nodes.end());
      all_nodes.resize(it - all_nodes.begin());
    }
  }
  else   //on one proc or not including remote nodes
  {
    if (params->empty())
    {
      for ( typename NodeListType::iterator n = local_nodes.begin(); n != local_nodes.end(); ++n )
        all_nodes.push_back(NodeAddressingData((*n)->get_gid(), ((*n)->get_parent())->get_gid(), (*n)->get_vp()));
    }
    else {
      //select those nodes fulfilling the key/value pairs of the dictionary
      for ( typename NodeListType::iterator n = local_nodes.begin(); n != local_nodes.end(); ++n )
      {
        bool match = true;
        index gid = (*n)->get_gid();
        DictionaryDatum node_status = net.get_status(gid);
        for (Dictionary::iterator i = params->begin(); i != params->end(); ++i)
        {
          if (node_status->known(i->first))
          {
            const Token token = node_status->lookup(i->first);
            if (not ( token == i->second || token.matches_as_string(i->second) ))
            {
              match = false;
              break;
            }
          }
        }
        if (match)
          all_nodes.push_back(NodeAddressingData((*n)->get_gid(), ((*n)->get_parent())->get_gid(), (*n)->get_vp()));
      }
    }
    std::sort(all_nodes.begin(),all_nodes.end());
  }
}



#else //HAVE_MPI

/**
 * Variant for builds without MPI: append the addressing data
 * (gid, parent gid, vp) of every node in local_nodes to all_nodes and
 * sort all_nodes afterwards. The unnamed bool argument is ignored.
 */
template <typename NodeListType>
void nest::Communicator::communicate(const NodeListType& local_nodes, vector<NodeAddressingData>& all_nodes, bool)
{
  typename NodeListType::iterator node = local_nodes.begin();
  while ( node != local_nodes.end() )
  {
    all_nodes.push_back(
      NodeAddressingData((*node)->get_gid(), ((*node)->get_parent())->get_gid(), (*node)->get_vp()));
    ++node;
  }

  // Entries already present in all_nodes take part in the sorting as well.
  std::sort(all_nodes.begin(), all_nodes.end());
}

template <typename NodeListType>
void nest::Communicator::communicate(const NodeListType& local_nodes, vector<NodeAddressingData>& all_nodes,
                                     Network& net, DictionaryDatum params, bool)
{

  if (params->empty())
  {
    for ( typename NodeListType::iterator n = local_nodes.begin(); n != local_nodes.end(); ++n )
      all_nodes.push_back(NodeAddressingData((*n)->get_gid(), ((*n)->get_parent())->get_gid(), (*n)->get_vp()));
  }
  else {
    //select those nodes fulfilling the key/value pairs of the dictionary
    for ( typename NodeListType::iterator n = local_nodes.begin(); n != local_nodes.end(); ++n )
    {
      bool match = true;
      index gid = (*n)->get_gid();
      DictionaryDatum node_status = net.get_status(gid);
      node_status->info(std::cout);
      for (Dictionary::iterator i = params->begin(); i != params->end(); ++i)
      {
        if (node_status->known(i->first))
        {
          const Token token = node_status->lookup(i->first);
          if (not ( token == i->second || token.matches_as_string(i->second) ))
          {
            match = false;
            break;
          }
        }
      }
      if (match)
        all_nodes.push_back(NodeAddressingData((*n)->get_gid(), ((*n)->get_parent())->get_gid(), (*n)->get_vp()));
    }
  }
  std::sort(all_nodes.begin(),all_nodes.end());
}

#endif