#Simplified neuron model fitter
#_withmids: consider not only best parameter sets for each objective but also "midoptimal" parameter sets that
#  perform well in two objectives
#_combfs: Combine certain objective functions to reduce the number of considered parameter set tree branches
#Tuomo Maki-Marttunen, 2015-2016
#
#Emoo code based on scripts from https://projects.g-node.org/emoo/
#  (Bahl et al. 2012, Automated optimization of a reduced layer 5 pyramidal cell model based on experimental data)
#
#
#
#Running the script:
#  python snmf_withmids_combfs.py ITER myseed
#or in parallel with e.g. 100 CPUs:
#  mpirun -np 100 nrniv -python -mpi snmf_withmids_combfs.py ITER myseed
#
#Arguments:
#  ITER:
#    Fitting task number, see below
#  myseed:
#    RNG seed (used by Emoo)
#
#Files needed:
#
#  Python files:
#
#    snmf_protocols.py:
#      Fitting protocols (which quantities are fitted at each step; what stimuli are for each objective function; etc.)
#    mytools.py:
#      General tools for e.g. spike detection
#
#  Data files:
#    originalrun.sav
#      Target data (generated from the full model based on snmf_protocols.py)
#    If 0 < ITER <= 3, one of the following files is needed (the parameter values obtained from the first step fitting)
#      pars_withmids_combfs_0a.sav
#      pars_withmids_combfs_1a.sav
#      pars_withmids_combfs_2a.sav
#    If 3 < ITER <= 12, one of the following files is needed (the parameter values obtained from the first and second step fitting)
#      pars_withmids_combfs_0a_0a.sav
#        ...
#      pars_withmids_combfs_2a_2a.sav
#    If 12 < ITER <= 66, one of the following files is needed (the parameter values obtained from the first, second and third step fitting)
#      pars_withmids_combfs_0a_0a_0a.sav
#       ...
#      pars_withmids_combfs_2a_2a_5a.sav
#
#Output files:
#
#  Data files:
#    If ITER = 0, the following files are saved, each of which contains an (optimal) parameter set for the first step fit:
#      pars_withmids_combfs_0a.sav (Best parameters for fulfilling first objective)
#      pars_withmids_combfs_1a.sav (Best parameters for fulfilling second objective)
#      pars_withmids_combfs_2a.sav (Parameters performing well in both objectives)
#      pars_withmids_combfs_wholepop.sav (All parameter sets and the corresponding objective function values)
#    If 0 < ITER <= 3, the following files are saved, each of which contains an (optimal) parameter set for the second step fit
#      (in addition to the parameters fixed during the first step, "?" referring to number 0-2 depending on the value of ITER):
#      pars_withmids_combfs_?a_0a.sav (Best parameters for fulfilling first objective)
#      pars_withmids_combfs_?a_1a.sav (Best parameters for fulfilling second objective)
#      pars_withmids_combfs_?a_2a.sav (Parameters performing well in both objectives)
#      pars_withmids_combfs_?a_wholepop.sav (All parameter sets and the corresponding objective function values)
#    If 3 < ITER <= 12, the following files are saved, each of which contains an (optimal) parameter set for the third step fit
#      (in addition to the parameters fixed during the first two steps):
#      pars_withmids_combfs_?a_?a_0a.sav (Best parameters for fulfilling first objective)
#      pars_withmids_combfs_?a_?a_1a.sav (Best parameters for fulfilling second objective)
#      pars_withmids_combfs_?a_?a_2a.sav (Best parameters for fulfilling third objective)
#      pars_withmids_combfs_?a_?a_3a.sav (Parameters performing well in first and second objectives)
#      pars_withmids_combfs_?a_?a_4a.sav (Parameters performing well in first and third objectives)
#      pars_withmids_combfs_?a_?a_5a.sav (Parameters performing well in second and third objectives)
#      pars_withmids_combfs_?a_?a_wholepop.sav (All parameter sets and the corresponding objective function values)
#    If 12 < ITER <= 66, the following files are saved, each of which contains an (optimal) parameter set for the fourth step fit
#      (in addition to the parameters fixed during the first three steps):
#      pars_withmids_combfs_?a_?a_?a_0a.sav (Best parameters for fulfilling first objective)
#      pars_withmids_combfs_?a_?a_?a_1a.sav (Best parameters for fulfilling second objective)
#      pars_withmids_combfs_?a_?a_?a_2a.sav (Best parameters for fulfilling third objective)
#      pars_withmids_combfs_?a_?a_?a_3a.sav (Parameters performing well in first and second objectives)
#      pars_withmids_combfs_?a_?a_?a_4a.sav (Parameters performing well in first and third objectives)
#      pars_withmids_combfs_?a_?a_?a_5a.sav (Parameters performing well in second and third objectives)
#      pars_withmids_combfs_?a_?a_?a_wholepop.sav (All parameter sets and the corresponding objective function values)
#
#  Images: Some EPS files illustrating the fit to the objective functions


import matplotlib
matplotlib.use('Agg')
import numpy as np
import emoo
from pylab import *
import pickle
import snmf_protocols
from neuron import h
import mytools
import time
import random

import sys  # Explicit import; sys.argv was previously reachable only through the `from pylab import *` star import

ITER = 0
myseed = 1

#Read arguments. The positions of ITER and myseed depend on how the script is launched:
#  python snmf_withmids_combfs.py ITER myseed                  -> sys.argv[1], sys.argv[2]
#  nrniv -python -mpi snmf_withmids_combfs.py ITER myseed      -> sys.argv[4], sys.argv[5]
#    (detected by "mpi" appearing in sys.argv[1] or sys.argv[2])
argv_i_myseed = 2
argv_i_ITER = 1
if len(sys.argv) > 2 and ("mpi" in sys.argv[1] or "mpi" in sys.argv[2]):
  argv_i_myseed = 5
  argv_i_ITER = 4
# Missing arguments keep the defaults above; int(float(...)) also accepts values such as "3.0"
if len(sys.argv) > argv_i_ITER:
  ITER = int(float(sys.argv[argv_i_ITER]))
if len(sys.argv) > argv_i_myseed:
  myseed = int(float(sys.argv[argv_i_myseed]))


random.seed(myseed) #random RNG
seed(myseed)        #pylab RNG
maxPerGroup = 1 # How many best ones shall we consider per objective group

# Emoo parameters: each list holds one entry per fitting step and is indexed with istep when Emoo is initialized
# (scalars instead of lists would be used for every step):
N_samples = [1000, 1000, 1000, 2000] #size of population (for first, second, third and fourth steps)
C_samples = [2000, 2000, 2000, 4000] #population capacity
N_generations = [20, 20, 20, 10]     #number of generations
eta_m_0s = [20, 20, 20, 20]          #mutation strength parameter
eta_c_0s = [20, 20, 20, 20]          #crossover strength parameter
p_ms = [0.5, 0.5, 0.5, 0.5]          #probability of mutation


#Stimulus protocols and quantities recorded at each step:
VARIABLES = snmf_protocols.get_snmf_variables()   # Per step: list of fitted variables; element [0] of each entry is the parameter name
STIMULI = snmf_protocols.get_stimuli()             # Stimulus definitions: [stimulus type index, [[property name, value], ...]]
SETUP = snmf_protocols.get_setup()                 # SETUP[istep][istims] = [stimulus indices, amplitudes, recording type(s), storage type index]
STIMTYPES = snmf_protocols.get_stimulus_types()    # Per stimulus type: element [0] is the hoc object name, [4] its amplitude property
DATASTORTYPES = snmf_protocols.get_data_storage_types() # Per storage type: [kind ("fixed", "max", "trace", ...), time(s)]

#Target data:
unpicklefile = open('originalrun.sav', 'r')
unpickledlist = pickle.load(unpicklefile)
unpicklefile.close()
# Target values, indexed like the output of run_model: [istep][istims][iamp][iirec]
ORIGINALDATA = unpickledlist[0]
# Recording distances in the full model: presumably [apical, basal] (element 0 is compared with the apical distances)
ORIGINALDISTS = [unpickledlist[1],unpickledlist[2]]
# Signed distances: the first list as is, the second negated (positive apical, negative basal)
ORIGINALXS = unpickledlist[1]+[-x for x in unpickledlist[2]]

# Define the list of objectives (the names Emoo sees; one list per fitting step):
OBJECTIVES = [['f0_0', 'f0_1'],
              ['f1_0', 'f1_1'],
              ['f2_0', 'f2_1', 'f2_2'],
              ['f3_0', 'f3_1', 'f3_2']]

# Originally, these were the objectives, but they are weighted by objective_group_coeffs and summed together to form the real objectives.
# Names are 'f<istep>_<istims>_<iamp>_<iirec>', as produced by subfunc_to_optimize:
OBJECTIVES_SUBFUNC = [['f0_0_0_0', 'f0_1_0_0', 'f0_2_0_0', 'f0_2_1_0', 'f0_2_2_0'],
                      ['f1_0_0_0', 'f1_0_1_0', 'f1_0_2_0', 'f1_1_0_0', 'f1_1_1_0'],
                      ['f2_0_0_0', 'f2_0_1_0', 'f2_1_0_0', 'f2_1_0_1', 'f2_1_1_0', 'f2_1_1_1', 'f2_2_0_0', 'f2_2_0_1', 'f2_2_1_0', 'f2_2_1_1'],
                      ['f3_0_0_0', 'f3_0_1_0', 'f3_1_0_0', 'f3_2_0_0', 'f3_3_0_0', 'f3_3_1_0', 'f3_3_2_0']]
# objective_groups[istep][j] lists the indices (into OBJECTIVES_SUBFUNC[istep]) summed to form objective OBJECTIVES[istep][j]
objective_groups = [ [ [0,1], [2,3,4] ], # Group voltage distributions together, and traces together
                     [ [0,1,2], [3,4] ], # Group voltage distributions together, and traces together
                     [ [0,1], [2,4,6,8], [3,5,7,9] ], # Group traces together, voltage distributions together, and calcium concentration distributions together
                     [ [0,1], [2,3], [4,5,6] ] ] # Group voltage traces together, numbers of spikes in long runs together, and BAC firing alone
# Weights of the summed sub-objectives, same nesting as objective_groups
objective_group_coeffs = [ [ [1,5], [1,1,1] ],
                              [ [1,1,1], [1,1] ],
                              [ [1,1], [1,1,1,1], [1,1,1,1] ],
                              [ [1,1], [1,1], [1,1,1] ] ]
lens_objective_groups = [ len(x) for x in objective_groups ]
# For n objectives: every single objective plus every pair of objectives ("midpoints"), as index lists
groups_with_midpoints = [ [[0]], [[0],[1],[0,1]], [[0],[1],[2],[0,1],[0,2],[1,2]], [[0],[1],[2],[3],[0,1],[0,2],[0,3],[1,2],[1,3],[2,3]] ] 
objective_groups_with_midpoints = [ groups_with_midpoints[i-1] for i in lens_objective_groups ] 
lens_objective_groups_with_midpoints = [ len(x) for x in objective_groups_with_midpoints ]
saved_until = [0]+[maxPerGroup*x for x in cumprod(lens_objective_groups_with_midpoints)] #Number of fittings to be done for each step
saved_until_cums = cumsum(saved_until)                                                   #ITERs corresponding to the first fitting tasks of each step; see comments in the beginning

nspike_restrictions = snmf_protocols.get_nspike_restrictions() 
stop_if_no_spikes = nspike_restrictions[0] #Whether to stop without checking all objectives if for some stimulus a spike was not induced (although induced in the target data)
stop_if_nspikes = nspike_restrictions[1]   #Whether to stop without checking all objectives if for some stimulus too many spikes were induced

istep = next(i for i,x in enumerate(saved_until_cums) if x >= ITER) #Which step we are performing
stepsChosen = []                                                    #Which steps must have been performed previously
# Decode the position of ITER within this step's tasks as mixed-radix digits, one digit (parameter-set choice) per previous step.
# Relies on Python 2 integer division for "/".
divisorNow = 1
for istep2 in range(0,istep):
  stepsChosen.append(((ITER-saved_until_cums[istep-1]-1)/maxPerGroup/divisorNow%lens_objective_groups_with_midpoints[istep2]))
  divisorNow = divisorNow * lens_objective_groups_with_midpoints[istep2]
iord = (ITER-1)%maxPerGroup # Rank within the group (always 0 while maxPerGroup == 1)
print "ITER="+str(ITER)+", steps="+str(stepsChosen)

dists_apical = [] #distances of recording locations along apical dendrite will be saved here
dists_basal = []  #distances of recording locations along basal dendrite will be saved here

FIGUREIND = 0 # A global variable for determining the names of EPS files

nseg = 20            #Number of segments per compartment
nrecsperseg = 20     #Number of recording locations per compartment
# Recording locations as fractions of the section length: the midpoint if only one,
# otherwise nrecsperseg evenly spaced interior points (the ends 0 and 1 excluded)
if nrecsperseg == 1:
  xsperseg = [0.5]
else:
  xsperseg = [1.0*x/(nrecsperseg+1) for x in range(1,nrecsperseg+1)]

BACdt = 5.0 #ISI between apical and somatic stimulus (needed in one of the fourth step objectives)


###################### Function definitions ###############################


#initialize_model(): a function that defines the morphology, stimuli, recordings and initial values
def initialize_model():
  """Create the reduced cell in NEURON and set up its stimuli and recordings.

  Loads models/fourcompartment.hoc and creates the L5PC cell. It adds:
    - an IClamp (st1) in the middle of the soma;
    - an epsp point process (syn1, with imax = 0) in the middle of apic[0];
    - CVode recordings of the somatic (vsoma) and apic[0] (vdend) membrane potential,
      with the sample times in tvec.
  It also sets nseg of every section, tstop, the initial cai and v_init.

  If any protocol of the current step records something other than the soma (a
  recording index > 0 in SETUP[istep]), extra recordings are made at the
  nrecsperseg locations xsperseg:
    - v on apic[0] and apic[1] (vrecs_apical) and on dend (vrecs_basal);
    - cai on the apical locations (carecs_apical).

  Side effect: rebuilds the globals dists_apical and dists_basal, which hold the path
  distances of those locations from the origin set by distance() at the soma.
  """
  global dists_apical, dists_basal
  v0 = -80
  ca0 = 0.0001
  fs = 8         # NOTE(review): not used in this function
  tstop = 15000.0
  icell = 0      # NOTE(review): not used in this function

  # Dendritic recordings are only needed if some protocol of this step records a non-somatic quantity (index > 0)
  anyRecsRecorded = 0
  for istims in range(0,len(SETUP[istep])):
    stims = SETUP[istep][istims]
    irecs = stims[2]
    if type(irecs) is not list:
      irecs = [irecs]
    if any([x > 0 for x in irecs]):
      anyRecsRecorded = 1

  h("""
load_file("stdlib.hoc")
load_file("stdrun.hoc")
objref cvode
cvode = new CVode()
cvode.active(1)
cvode.atol(0.00005)
load_file("models/fourcompartment.hoc")
objref L5PC
L5PC = new fourcompartment()
access L5PC.soma
distance()
objref st1
st1 = new IClamp(0.5)
L5PC.soma st1
objref vsoma, vdend, vdend2
vsoma = new Vector()
vdend = new Vector()
vdend2 = new Vector()
objref syn1, tvec, sl
tvec = new Vector()
sl = new List()
double siteVec[2]
siteVec[0] = 0
siteVec[1] = 0.5
L5PC.apic[0] cvode.record(&v(siteVec[1]),vdend,tvec)
L5PC.apic[0] syn1 = new epsp(siteVec[1])
syn1.imax = 0
L5PC.apic[0] nseg = """+str(nseg)+"""
L5PC.apic[1] nseg = """+str(nseg)+"""
L5PC.soma cvode.record(&v(0.5),vsoma,tvec)
L5PC.soma nseg = """+str(nseg)+"""
L5PC.dend nseg = """+str(nseg)+"""
objref vrecs_apical["""+str(2*nrecsperseg)+"""], vrecs_basal["""+str(nrecsperseg)+"""], carecs_apical["""+str(2*nrecsperseg)+"""]
tstop = """+str(tstop)+"""
cai0_ca_ion = """+str(ca0)+"""
v_init = """+str(v0)+"""
""")

  dists_apical = []
  dists_basal = []
  if anyRecsRecorded:
    # apic[0] locations occupy indices 0..nrecsperseg-1 of the apical arrays
    for j in range(0,nrecsperseg):
      dists_apical.append(h.distance(xsperseg[j],sec=h.L5PC.apic[0]))
      h("L5PC.apic[0] vrecs_apical["+str(j)+"] = new Vector()")
      h("L5PC.apic[0] cvode.record(&v("+str(xsperseg[j])+"),vrecs_apical["+str(j)+"],tvec)")
      h("L5PC.apic[0] carecs_apical["+str(j)+"] = new Vector()")
      h("L5PC.apic[0] cvode.record(&cai("+str(xsperseg[j])+"),carecs_apical["+str(j)+"],tvec)")
    # apic[1] locations occupy indices nrecsperseg..2*nrecsperseg-1
    for j in range(0,nrecsperseg):
      dists_apical.append(h.distance(xsperseg[j],sec=h.L5PC.apic[1]))
      h("L5PC.apic[1] vrecs_apical["+str(nrecsperseg+j)+"] = new Vector()")
      h("L5PC.apic[1] cvode.record(&v("+str(xsperseg[j])+"),vrecs_apical["+str(nrecsperseg+j)+"],tvec)")
      h("L5PC.apic[1] carecs_apical["+str(nrecsperseg+j)+"] = new Vector()")
      h("L5PC.apic[1] cvode.record(&cai("+str(xsperseg[j])+"),carecs_apical["+str(nrecsperseg+j)+"],tvec)")
    # Basal dendrite: voltage only
    for j in range(0,nrecsperseg):
      dists_basal.append(h.distance(xsperseg[j],sec=h.L5PC.dend))
      h("L5PC.dend vrecs_basal["+str(j)+"] = new Vector()")
      h("L5PC.dend cvode.record(&v("+str(xsperseg[j])+"),vrecs_basal["+str(j)+"],tvec)")
  else:
    # NOTE(review): here dists_apical interleaves apic[0] and apic[1] distances per j, whereas the branch
    # above (and setparams) lists all apic[0] distances before the apic[1] ones - confirm this is harmless
    for j in range(0,nrecsperseg):
      dists_apical.append(h.distance(xsperseg[j],sec=h.L5PC.apic[0]))
      dists_apical.append(h.distance(xsperseg[j],sec=h.L5PC.apic[1]))
      dists_basal.append(h.distance(xsperseg[j],sec=h.L5PC.dend))
    print "No vrecs and carecs recorded!"

#setparams(params, istep): a function that sets the conductances of different species, axial resistances, capacitances or lengths for each compartment
def setparams(params, istep):
  """Write a parameter dictionary into the hoc model.

  Input:
    params: dictionary whose keys have the form "<hoc variable>_<section>". The part
      after the last underscore names an L5PC section (e.g. "soma", "apic[0]"), or
      "*" to set the value in all sections with forall (e.g. ehcn). See
      snmf_protocols for the variable names.
    istep: index of the current fitting step (0-3). The conductances ("g...bar...")
      fitted in the later steps istep+1 .. 3 are set to zero. For istep = 3 nothing
      is zeroed.

  If any length ("L_<section>") is given, all four lengths L_soma, L_dend, L_apic[0]
  and L_apic[1] must be present. In that case:
    - the diameters are set to (constant / length), so each compartment keeps its
      membrane area;
    - the globals dists_apical and dists_basal are recomputed;
    - syn1 is moved 620 length units (presumably um) along apic[0]/apic[1], or to the
      end of apic[1] if the apical dendrite is shorter than that.
  """
  global dists_apical, dists_basal

  def split_key(name):
    # "<variable>_<section>" -> (variable, section), split at the last underscore
    cut = name.rfind('_')
    return name[0:cut], name[cut+1:]

  length_changed = False
  for key in params:
    if key[0:2] == "L_":
      length_changed = True
    variable, section = split_key(key)
    value_text = str(params[key])
    if section == "*": # Same value in every section
      h("forall " + variable + " = " + value_text)
    else:
      h("L5PC." + section + " " + variable + " = " + value_text)

  if length_changed:
    # diam*L stays constant, i.e. the membrane area of each compartment is conserved
    h("L5PC.soma diam = " + str(360.132/params['L_soma']))
    h("L5PC.dend diam = " + str(2821.168/params['L_dend']))
    h("L5PC.apic[0] diam = " + str(4244.628/params['L_apic[0]']))
    h("L5PC.apic[1] diam = " + str(2442.848/params['L_apic[1]']))

    # Recording distances: all apic[0] locations, then all apic[1] locations; basal separately
    dists_apical = ([h.distance(xsperseg[j], sec=h.L5PC.apic[0]) for j in range(0, nrecsperseg)] +
                    [h.distance(xsperseg[j], sec=h.L5PC.apic[1]) for j in range(0, nrecsperseg)])
    dists_basal = [h.distance(xsperseg[j], sec=h.L5PC.dend) for j in range(0, nrecsperseg)]

    first_len = params['L_apic[0]']
    second_len = params['L_apic[1]']
    if first_len > 620:
      h("L5PC.apic[0] syn1.loc(" + str(620.0/first_len) + ")")
    elif first_len + second_len > 620:
      h("L5PC.apic[1] syn1.loc(" + str((620.0-first_len)/second_len) + ")")
    else:
      h("L5PC.apic[1] syn1.loc(1.0)")

  # Conductances that belong to the later fitting steps are switched off for now
  for later_step in range(istep+1, 4):
    for var in snmf_protocols.get_variable_params(later_step):
      if var[0] == 'g' and 'bar' in var:
        variable, section = split_key(var)
        h("L5PC." + section + " " + variable + " = 0")

#run_model(istep,saveFig,stop_if_needed): a function that runs the model using the stimuli of a certain step
#Input:
#  istep (0 to 3): Which step to run
#  saveFig: Name prefix of the EPS files to save (empty if no need to plot the time course)
#  stop_if_needed: Boolean telling whether we can quit the function after the first stimulus condition that produces too many or too
#    few spikes. This is relevant only in the fourth step, where relatively heavy simulations are performed.
#Output:
#  myValsAllAll: A data structure containing all the data needed for determining the objective function values
def run_model(istep,saveFig="",stop_if_needed=True):
  """Simulate every stimulus protocol of step istep and extract the quantities to be fitted.

  Returns a nested list indexed [istims][iamp][iirec]:
    - istims: protocol index in SETUP[istep];
    - iamp: amplitude index of that protocol;
    - iirec: position in the protocol's list of recording types.
  Protocols and amplitudes that were not simulated because of an early stop are left as nan.
  """
  global STIMULI, SETUP, STIMTYPES, DATASTORTYPES, dists_apical, dists_basal
  time_to_quit = False
  
  myValsAllAll = [nan]*len(SETUP[istep])
  for istims in range(0,len(SETUP[istep])):
    # stims = [stimulus indices into STIMULI, amplitudes, recording type(s), index into DATASTORTYPES]
    stims = SETUP[istep][istims]
    myValsAll = [nan]*len(stims[1])
    for iamp in range(0,len(stims[1])):
      # Configure every stimulus of this protocol for the current amplitude
      for istim in range(0,len(stims[0])):
        stimulus = STIMULI[stims[0][istim]]
        st = STIMTYPES[stimulus[0]]
        # A list-valued amplitude holds one amplitude per stimulus of the protocol
        if type(stims[1][iamp]) is list:
          myamp = stims[1][iamp][istim]
        else:
          myamp = stims[1][iamp]
        h(st[0]+"."+st[4]+" = "+str(myamp))
        for iprop in range(0,len(stimulus[1])):
          thisPropVal = stimulus[1][iprop][1]
          # Each successive stimulus of the protocol is delayed by a further BACdt
          if stimulus[1][iprop][0] == "onset" or stimulus[1][iprop][0] == "del":
            thisPropVal = thisPropVal + istim*BACdt
          h(st[0]+"."+stimulus[1][iprop][0]+" = "+str(thisPropVal))
      h.init()
      h.run()

      irecs = stims[2]
      if type(irecs) is not list:
        irecs = [irecs]

      times=np.array(h.tvec)
      Vsoma=np.array(h.vsoma)
      # Dendritic recordings exist only if initialize_model created them (some irec > 0 in this step)
      if any([x > 0 for x in irecs]):
        Vdend=np.concatenate((np.array(h.vrecs_apical),np.array(h.vrecs_basal)))
        Cadend=np.array(h.carecs_apical)
      else:
        Vdend=[]
        Cadend=[]

      # Somatic spike times (the two numbers are arguments of mytools.spike_times, presumably thresholds in mV - confirm)
      spikes = mytools.spike_times(times,Vsoma,-35,-45)
      if len(saveFig) > 0:
        close("all")
        f,axarr = subplots(2,2)
        axarr[0,0].plot(times,Vsoma)
        axarr[0,0].set_xlim([9990,11000])
        axarr[0,0].set_title("nSpikes="+str(len(spikes)))
        if len(Vdend) > 0:
          axarr[0,1].plot(dists_apical+[-x for x in dists_basal], [max([x[i] for i,t in enumerate(times) if t >= 9000]) for x in Vdend], 'bx')
        if len(Cadend) > 0:
          axarr[1,0].plot(dists_apical, [max([x[i] for i,t in enumerate(times) if t >= 9000]) for x in Cadend], 'b.')
        f.savefig(saveFig+"_step"+str(istep)+"_istims"+str(istims)+"_iamp"+str(iamp)+".eps")
      # Early stop: no spike where one is required, or at least the allowed maximum number of spikes
      if stop_if_needed and ((stop_if_no_spikes[istep][istims][iamp] > 0 and len(spikes) == 0) or (stop_if_nspikes[istep][istims][iamp] > 0 and len(spikes) >= stop_if_nspikes[istep][istims][iamp])):
        time_to_quit = True
        break

      myVals = []
      for iirec in range(0,len(irecs)):
        irec = irecs[iirec]
        # Recording type: 0 = somatic voltage, 1 = dendritic voltages, 2 = apical calcium
        if irec == 0:
          myData = [Vsoma]
        elif irec == 1:
          myData = Vdend
        elif irec == 2:
          myData = Cadend
        else:
          print "Unknown recording type: "+str(irec)
          continue
        if DATASTORTYPES[stims[3]][0] == "fixed":
          # Value at the last sample before the given time (assuming increasing times)
          # NOTE(review): diagnostic print executed for every evaluation
          print "ind="+str(next(i for i,t in enumerate(times) if t >= DATASTORTYPES[stims[3]][1]))+", val="+str(myData[0][next(i for i,t in enumerate(times) if t >= DATASTORTYPES[stims[3]][1])-1])
          myVals.append([x[next(i for i,t in enumerate(times) if t >= DATASTORTYPES[stims[3]][1])-1] for x in myData])
        elif DATASTORTYPES[stims[3]][0] == "max":
          # Maximum over the sample window [t0, t1)
          myVals.append([max(x[next(i for i,t in enumerate(times) if t >= DATASTORTYPES[stims[3]][1][0]):next(i for i,t in enumerate(times) if t >= DATASTORTYPES[stims[3]][1][1])]) for x in myData])
        elif DATASTORTYPES[stims[3]][0] == "trace" or DATASTORTYPES[stims[3]][0] == "highrestrace":
          # Traces resampled at the given time points
          myVals.append([mytools.interpolate_extrapolate_constant(times,x,DATASTORTYPES[stims[3]][1]) for x in myData])
        elif DATASTORTYPES[stims[3]][0] == "highrestraceandspikes":
          # [resampled traces, somatic spike times]
          myVals.append([[mytools.interpolate_extrapolate_constant(times,x,DATASTORTYPES[stims[3]][1]) for x in myData],spikes])
        elif DATASTORTYPES[stims[3]][0] == "nspikes":
          # Number of somatic spikes in [t0, t1)
          myVals.append(sum([1 for x in spikes if x >= DATASTORTYPES[stims[3]][1][0] and x < DATASTORTYPES[stims[3]][1][1]]))
        else:
          print "Unknown data storage type: "+DATASTORTYPES[stims[3]][0]
          continue
      myValsAll[iamp] = myVals[:]
      # NOTE(review): time_to_quit is only set right before the break above, so this check never fires
      if time_to_quit:
        break
    myValsAllAll[istims] = myValsAll[:]

    # Switch this protocol's stimuli off before the next protocol
    for istim in range(0,len(stims[0])):
      stimulus = STIMULI[stims[0][istim]]
      st = STIMTYPES[stimulus[0]]
      h(st[0]+"."+st[4]+" = 0")
    if time_to_quit:
      break

  return myValsAllAll[:]

#distdiff(xs, fs, xs_ref, fs_ref): a function to calculate the difference of membrane potential distribution along the dendrites
def distdiff(xs, fs, xs_ref, fs_ref):
  """Mean scaled distance from each model point (x, f) to the nearest reference point.

  Input:
    xs: Distances of recording locations from soma, along the dendrites of the reduced
      model (negative for basal, positive for apical)
    fs: Values (e.g. maximum membrane potential during or following a stimulus) at xs
    xs_ref: Distances of recording locations from soma, along the dendrites of the full
      model (negative for basal, positive for apical)
    fs_ref: Values at xs_ref in the full model

  Both dimensions are divided by the range (max - min) of the reference data, so that
  space and value are a priori equally relevant. Model points whose x lies outside
  [min(xs_ref), max(xs_ref)] are ignored.

  Output:
    The mean, over the considered model points, of the Euclidean distance (in scaled
    units) to the closest reference point.

  Raises:
    ValueError: if no model point lies within the range of xs_ref (the mean would be undefined).
  """
  # Loop invariants, computed once instead of once per model point
  x_lo = min(xs_ref)
  x_hi = max(xs_ref)
  dist_diff_max = x_hi - x_lo
  f_diff_max = max(fs_ref) - min(fs_ref)

  dist = 0
  n = 0
  for ix in range(0, len(xs)):
    if xs[ix] > x_hi or xs[ix] < x_lo:
      continue # Outside the span covered by the reference data
    dists2 = [(1.0*(xs[ix]-x)/dist_diff_max)**2 + (1.0*(fs[ix]-y)/f_diff_max)**2 for x, y in zip(xs_ref, fs_ref)]
    dist = dist + sqrt(min(dists2))
    n = n + 1
  if n == 0:
    raise ValueError("distdiff: no model recording location lies within the range of the reference distances")
  return 1.0*dist/n

#highrestraceandspikesdiff(data, dataref, coeff_mV=1.0/12, coeff_ms=1.0/20): a function to calculate difference between spike trains.
def highrestraceandspikesdiff(data, dataref, coeff_mV=1.0/12, coeff_ms=1.0/20):
  """Combined error of membrane potential traces, spike timing and spike count.

  With the defaults, each of the following costs as much as a difference of one spike:
    - a mean difference of 12 mV in membrane potential;
    - a summed nearest-spike distance of 20 ms.

  Input:
    data: [traces, spike times] of the reduced model
    dataref: [traces, spike times] of the full model
    coeff_mV: weight of the mean absolute voltage difference. If it is a list, it
      holds one weight per trace; otherwise only the first trace is used.
    coeff_ms: weight of the summed spike-timing difference. For each model spike,
      the distance to the nearest reference spike is taken. This term is 0 when the
      reference has no spikes.

  Output:
    coeff_mV*(mean voltage difference) + coeff_ms*(summed spike-timing difference)
    + |difference in number of spikes|
  """
  traces, spike_times = data[0], data[1]
  ref_traces, ref_spike_times = dataref[0], dataref[1]

  # Mean absolute voltage difference for every trace pair
  trace_errors = []
  for trace, ref_trace in zip(traces, ref_traces):
    abs_diffs = [abs(v - v_ref) for v, v_ref in zip(trace, ref_trace)]
    trace_errors.append(1.0*mean(abs_diffs))

  count_error = abs(len(ref_spike_times) - len(spike_times))

  timing_error = 0
  if len(ref_spike_times) > 0:
    for t in spike_times:
      nearest_gap = min([abs(t - t_ref) for t_ref in ref_spike_times])
      timing_error = timing_error + nearest_gap

  if type(coeff_mV) is list:
    voltage_term = sum([err*weight for err, weight in zip(trace_errors, coeff_mV)])
  else:
    voltage_term = trace_errors[0] * coeff_mV
  return voltage_term + timing_error * coeff_ms + count_error


#func_to_optimize(parameters,istep, saveFig=False, filename="FIGUREWITHMIDSCOMBFS",stop_if_needed=True): the function which is to be minimized.
def func_to_optimize(parameters,istep, saveFig=False, filename="FIGUREWITHMIDSCOMBFS",stop_if_needed=True):
  """Return the Emoo objective values of step istep for one parameter set.

  The sub-objectives computed by subfunc_to_optimize (same arguments) are combined
  group-wise. Objective 'f<istep>_<j>' is the weighted sum of the sub-objectives
  OBJECTIVES_SUBFUNC[istep][m] for m in objective_groups[istep][j], with weights
  objective_group_coeffs[istep][j].
  """
  subvalues = subfunc_to_optimize(parameters,istep, saveFig, filename, stop_if_needed)
  subnames = OBJECTIVES_SUBFUNC[istep]
  objectives = {}
  for j in range(len(OBJECTIVES[istep])):
    members = objective_groups[istep][j]
    weights = objective_group_coeffs[istep][j]
    total = 0
    for k, member in enumerate(members):
      total = total + weights[k]*subvalues[subnames[member]]
    objectives['f%s_%s' % (istep, j)] = total
  return objectives

#subfunc_to_optimize(parameters,istep, saveFig, filename,stop_if_needed): computes the sub-objective values combined by func_to_optimize
def subfunc_to_optimize(parameters,istep, saveFig, filename, stop_if_needed):
  """Set the parameters, run the protocols of step istep and compare them with the target data.

  Input:
    parameters: A dictionary for the parameters. See snmf_protocols for the variable names
    istep: The step to perform (0 to 3)
    saveFig: Boolean telling whether to plot how well the parameters are fitted to the target data
    filename: The name (prefix) of the EPS file to save if any
    stop_if_needed: Boolean telling whether the simulations may stop after the first stimulus
      condition that produces too many or too few spikes (passed on to run_model)

  Output:
    mydict: Dictionary of the (sub-)objective function values, keyed
      'f<istep>_<istims>_<iamp>_<iirec>'. Stimuli that were not simulated get MAXERR.

  Side effects: updates the model through setparams. When saveFig is true, also writes
  an EPS file and increments the global FIGUREIND.
  """
  global ORIGINALDATA, SETUP, dists_apical, dists_basal, FIGUREIND

  MAXERR = 1e8 # Objective value given to stimuli that run_model skipped (possible when stop_if_needed is true)
  setparams(parameters,istep)
  # Read the recording distances only after setparams, which recomputes them when lengths change
  xs = dists_apical + [-x for x in dists_basal]
  A = run_model(istep,"",stop_if_needed)

  if saveFig:
    close("all")
    f,axarr = subplots(5,2)
    axs = [axarr[0,0],axarr[1,0],axarr[2,0],axarr[3,0],axarr[4,0],axarr[0,1],axarr[1,1],axarr[2,1],axarr[3,1],axarr[4,1]]
    # Reposition the axes so that consecutive indices fill the grid row by row
    for iplot in range(0,10):
      iplotx = iplot%2
      iploty = iplot//2 # Floor division (what "/" did on ints in Python 2)
      axs[iplot].set_position([0.08+0.5*iplotx,0.86-0.185*iploty,0.4,0.11])
    saveInd = 0 # Index of the axis the next sub-objective is drawn on

  mydict = {}
  for ifun in range(0,len(SETUP[istep])):
    for iamp in range(0,len(SETUP[istep][ifun][1])):
      irecs = SETUP[istep][ifun][2]
      istor = SETUP[istep][ifun][3]
      if type(irecs) is not list:
        irecs = [irecs]
      keyprefix = 'f'+str(istep)+'_'+str(ifun)+'_'+str(iamp)+'_'

      # Not simulated (run_model stopped early and left nan): assign the maximal error
      if type(A[ifun]) is not list or (type(A[ifun][iamp]) is not list and isnan(A[ifun][iamp])):
        for iirec in range(0,len(irecs)):
          mydict[keyprefix+str(iirec)] = MAXERR
        continue

      for iirec in range(0,len(irecs)):
        irec = irecs[iirec]
        key = keyprefix+str(iirec)
        if istor == 0 or istor == 1: # Maxima or steady-state values across the dendrite(s)
          if irec == 1: #Voltage, whole dendritic tree
            mydict[key] = distdiff(xs, A[ifun][iamp][iirec], ORIGINALXS, ORIGINALDATA[istep][ifun][iamp][iirec])
          if irec == 2: #Calcium, only apical dendrite
            mydict[key] = distdiff(dists_apical, A[ifun][iamp][iirec], ORIGINALDISTS[0], ORIGINALDATA[istep][ifun][iamp][iirec])
          if saveFig:
            if irec == 1: #Voltage, whole dendritic tree
              axs[saveInd].plot(xs, A[ifun][iamp][iirec],'bx')
              axs[saveInd].plot(ORIGINALXS, ORIGINALDATA[istep][ifun][iamp][iirec],'g.')
            if irec == 2: #Calcium, only apical dendrite
              axs[saveInd].plot(dists_apical, A[ifun][iamp][iirec],'bx')
              axs[saveInd].plot(ORIGINALDISTS[0], ORIGINALDATA[istep][ifun][iamp][iirec],'g.')
            axs[saveInd].set_title(key+"="+str(mydict[key]),fontsize=7)
            if saveInd < len(axs)-1:
              saveInd = saveInd+1
        elif istor == 2 or istor == 3: # Time series
          sumval = 0
          for irecloc in range(0,len(A[ifun][iamp][iirec])): # Usually only the time course of the soma, but technically allowed for dendrites as well
            sumval = sumval + sum([1.0*abs(x-y) for x,y in zip(A[ifun][iamp][iirec][irecloc], ORIGINALDATA[istep][ifun][iamp][iirec][irecloc])])
            if saveFig:
              axs[saveInd].plot(A[ifun][iamp][iirec][irecloc])
              axs[saveInd].plot(ORIGINALDATA[istep][ifun][iamp][iirec][irecloc])
          # Summed absolute difference, scaled by 5/250 for storage type 2 and by 1/250 for storage type 3
          if istor == 2:
            mydict[key] = sumval*5/250.0
          else:
            mydict[key] = sumval/250.0
          if saveFig:
            axs[saveInd].set_title(key+"="+str(mydict[key]),fontsize=7)
            if saveInd < len(axs)-1:
              saveInd = saveInd+1
        elif istor == 4: # Time series with spike time precision
          mydict[key] = highrestraceandspikesdiff(A[ifun][iamp][iirec], ORIGINALDATA[istep][ifun][iamp][iirec])
          if saveFig:
            for irecloc in range(0,len(A[ifun][iamp][iirec][0])): # Usually only the time course of the soma, but technically allowed for dendrites as well
              axs[saveInd].plot(A[ifun][iamp][iirec][0][irecloc])
              axs[saveInd].plot(ORIGINALDATA[istep][ifun][iamp][iirec][0][irecloc])
            # Both spike counts are shown as numbers (the reference previously printed its whole spike-time list)
            axs[saveInd].set_title(key+"="+str(mydict[key])+
                                   ', nSp='+str(len(A[ifun][iamp][iirec][1]))+', nSpref='+str(len(ORIGINALDATA[istep][ifun][iamp][iirec][1])),fontsize=8)
            if saveInd < len(axs)-1:
              saveInd = saveInd+1
        else: # Nspikes
          mydict[key] = 1.0*(A[ifun][iamp][iirec]-ORIGINALDATA[istep][ifun][iamp][iirec])**2
          if saveFig:
            axs[saveInd].set_title(key+"="+str(mydict[key]),fontsize=7)
            if saveInd < len(axs)-1:
              saveInd = saveInd+1
  if saveFig:
    for iax in range(0,len(axs)):
      for tick in axs[iax].yaxis.get_major_ticks()+axs[iax].xaxis.get_major_ticks():
        tick.label.set_fontsize(5)
    # The parameters of step j (0..3) are listed as the x-label of axis write_params_here[j]
    write_params_here = [6,7,8,9]
    for j in range(0,4):
      myText = ''
      varnames = [x[0] for x in VARIABLES[j]]
      nplaced = 0
      for ivarname in range(0,len(varnames)):
        if varnames[ivarname] in parameters:
          myText = myText + varnames[ivarname] + " = " + str(parameters[varnames[ivarname]])
        else:
          # Not among the given parameters: show the value currently set in the hoc model
          underscoreind = varnames[ivarname].rfind('_')
          section = varnames[ivarname][underscoreind+1:len(varnames[ivarname])]
          h("tmpvar = 999999")
          if section == "*": # A parameter that is the same in all sections (such as ehcn) is read from the soma
            h("tmpvar = L5PC.soma."+varnames[ivarname][0:underscoreind])
          else:
            h("tmpvar = L5PC."+section+"."+varnames[ivarname][0:underscoreind])
          myText = myText + "No " + varnames[ivarname] + " (" + str(h.tmpvar) + ")"
        nplaced = nplaced + 1
        if ivarname < len(varnames)-1:
          myText = myText + ", "
        if nplaced > 4: # Line break after every five parameters
          nplaced = 0; myText = myText + '\n'
      myText = myText.replace('CaDynamics_E2','CaD')
      axs[write_params_here[j]].set_xlabel(myText,fontsize=2.2)
    if FIGUREIND==-1:
      f.savefig(filename+".eps")
    else:
      f.savefig(filename+str(istep)+"_"+str(FIGUREIND)+".eps")
    FIGUREIND = FIGUREIND+1
  return mydict


# After each generation this function is called
def checkpopulation(population, columns, gen, istep):
  """Save the current Emoo population as a checkpoint file.

  Pickles [population, columns] to '<h.myfilename>_tmp<gen>.sav' and prints a
  progress message.

  Input:
    population: the population passed in by Emoo
    columns: passed in by Emoo together with the population (presumably the column
      description of the population array - confirm)
    gen: generation number, used in the file name
    istep: fitting step (not used here; kept for the call signature)
  """
  picklelist = [population, columns]
  # 'with' closes the checkpoint file even if pickling fails
  with open(h.myfilename+'_tmp'+str(gen)+'.sav', 'w') as checkpoint_file:
    pickle.dump(picklelist, checkpoint_file)
  print("Generation %d done!"%gen)


########################## The main code ##################################

initialize_model()

# Rank letter of the parameter files: 'a' + iord (always 'a' while maxPerGroup == 1)
ext = chr(ord('a')+iord)
# Also used below as the prefix of every output file of this run
filename = "pars_withmids_combfs"
par_names = []
par_values = []
paramdict = {}
if istep > 0:
  #Determine the name of the file to load (to get the parameters from the previous step fittings),
  #e.g. "pars_withmids_combfs_0a_2a" for stepsChosen = [0, 2]
  for ichosen in range(0,len(stepsChosen)):
    filename = filename+"_"+str(stepsChosen[ichosen])+"a"
  filename = filename[0:-1] + ext # Replace the final "a" by the rank letter
  unpicklefile = open(filename+".sav", 'r')
  unpickledlist = pickle.load(unpicklefile)
  unpicklefile.close()
  par_names = unpickledlist[0]
  par_values = unpickledlist[1]

  #Make the dictionary and set the parameters:
  for i in range(0,len(par_names)):
    paramdict[par_names[i]] = par_values[i]  
  # With istep-1, setparams also zeroes the conductances of steps istep..3 (including those fitted now)
  setparams(paramdict, istep-1)

print "filename: "+filename

#Initialize Emoo (the Emoo settings are either per-step lists indexed by istep, or scalars used for every step):
FIGUREIND = 0
if type(N_samples) is list:
  myemoo = emoo.Emoo(N = N_samples[istep], C = C_samples[istep], variables = VARIABLES[istep], objectives = OBJECTIVES[istep])
  myemoo.setup(eta_m_0 = eta_m_0s[istep], eta_c_0 = eta_c_0s[istep], p_m = p_ms[istep])
else:
  myemoo = emoo.Emoo(N = N_samples, C = C_samples, variables = VARIABLES[istep], objectives = OBJECTIVES[istep])
  myemoo.setup(eta_m_0 = eta_m_0s, eta_c_0 = eta_c_0s, p_m = p_ms)
# Objective evaluation used by Emoo (no figures are saved during the evolution)
myemoo.get_objectives_error = lambda params: func_to_optimize(params,istep,False)
# Per-generation checkpointing
myemoo.checkpopulation = lambda population, columns, gen: checkpopulation(population, columns, gen, istep)
# Store the output prefix in hoc so that checkpopulation can read it as h.myfilename
h("strdef myfilename")
h("myfilename=\""+filename+"\"")

#Run Emoo:
if type(N_samples) is list:
  myemoo.evolution(generations = N_generations[istep])
else:
  myemoo.evolution(generations = N_generations)

Results = [] # NOTE(review): not used in this section

#Save the best parameters for each objective (and the parameters that perform well in two distinct objectives) to a .sav file and plot the corresponding results
if myemoo.master_mode: # Presumably true only on the MPI master process - confirm with Emoo
  params_all = myemoo.getpopulation_unnormed()
  # Each row of params_all: the values of VARIABLES[istep], followed by the objective values

  picklelist = [N_samples, C_samples, N_generations, par_names, par_values, params_all.tolist()]
  file=open(filename+'_wholepop.sav', 'w')
  pickle.dump(picklelist,file)
  file.close()

  # Columns holding the objective values
  param_fdims = range(len(VARIABLES[istep]), len(VARIABLES[istep])+len(OBJECTIVES[istep]))
  # Normalize each objective by its median over the population, so that objectives can be summed
  medians = [median([y[i] for y in params_all]) for i in param_fdims]
  params_f = [[1.0*y[param_fdims[j]]/medians[j] for j in range(0,len(param_fdims))] for y in params_all]
  # Score of each row for each group: a single objective, or the sum of a pair of objectives ("midpoint")
  fvals = [[sum([y[j] for j in objective_groups_with_midpoints[istep][i]]) for i in range(0,lens_objective_groups_with_midpoints[istep])] for y in params_f]
  objs = objective_groups_with_midpoints[istep]

  for iobj_group in range(0,len(objs)):
    # Row indices sorted by this group's score, lowest (best) first
    myord = [i[0] for i in sorted(enumerate([y[iobj_group] for y in fvals]), key=lambda x:x[1])]
    # Drop rows whose parameter values equal those of the preceding row, keeping at least maxPerGroup rows
    j = 1
    while j < len(myord) and len(myord) > maxPerGroup:
      if all([x==y for x,y in zip(params_all[myord[j]][0:len(VARIABLES[istep])],params_all[myord[j-1]][0:len(VARIABLES[istep])])]):
        myord.pop(j)
      else:
        j = j+1
    for isamp in range(0,maxPerGroup):
      print str(iobj_group)+"_"+str(isamp)
      ext = chr(ord('a')+isamp)
      # Overwrite this step's variables in paramdict; parameters loaded from earlier steps are kept
      for iparam in range(0,len(VARIABLES[istep])):
        paramdict[VARIABLES[istep][iparam][0]] = params_all[myord[isamp]][iparam]
      par_names = paramdict.keys()
      par_values = [paramdict[key] for key in par_names]
      picklelist = [par_names, par_values, N_samples, C_samples, N_generations, myord]
      file=open(filename+'_'+str(iobj_group)+ext+'.sav', 'w')
      pickle.dump(picklelist,file)
      file.close()

      # Re-evaluate the saved parameters; FIGUREIND = -1 makes the figure name filename+'_<group><letter>.eps'
      FIGUREIND = -1
      func_to_optimize(paramdict,istep,True,filename+'_'+str(iobj_group)+ext)