#drawobjective_evolution.py
#A script that illustrates how quickly the model fitting objective functions decrease
#Tuomo Maki-Marttunen, 2015-2016
#                               
#Running the script:            
#  python drawobjective_evolution.py ITER
#                               
#Arguments:                     
#  ITER                     
#    The iteration number. The script supposes that snmf_withmids_combfs.py has been run with the same iteration number,
#    and hence the corresponding "pars_withmids_combfs" .sav file and the related _tmpXX files have been saved. See below
#    for details.
#                               
#Files needed:            
#                               
#  Python files:
#    snmf_protocols.py:           
#      Fitting protocols (which quantities are fitted at each step; what stimuli are for each objective function; etc.)     
#    mytools.py:                  
#      General tools for e.g. spike detection
#  Data files:
#    pars_withmids_combfs_final.sav: The final parameter set for which objective functions are plotted for comparison
#    If ITER = 0, the following files should exist:
#      pars_withmids_combfs_1a.sav    #The best parameter set after one generation (in terms of second objective).
#                                     #This file is only needed for ITER=0, as only the second step rewrites parameters.
#                                     #NB: If the final parameterset is not of the branch corresponding to 
#                                     #pars_withmids_combfs_1a_?a_?a_?a.sav, then this should be changed!
#      pars_withmids_combfs_tmp0.sav  #The population after one generation
#       ...
#      pars_withmids_combfs_tmp19.sav #The population after twenty generations
#    If 0 < ITER <= 3, the following files should exist ("?" referring to number 0-2 depending on the value of ITER): 
#      pars_withmids_combfs_?a_tmp0.sav  #The population after one generation
#       ...
#      pars_withmids_combfs_?a_tmp19.sav #The population after twenty generations
#    If 3 < ITER <= 12, the following files should exist ("?"s referring to numbers 0-2 and 0-2 depending on the value of ITER):
#      pars_withmids_combfs_?a_?a_tmp0.sav  #The population after one generation
#       ...
#      pars_withmids_combfs_?a_?a_tmp19.sav #The population after twenty generations
#    If 12 < ITER <= 66, the following files should exist ("?"s referring to numbers 0-2, 0-2 and 0-5 depending on the value of ITER):
#      pars_withmids_combfs_?a_?a_?a_tmp0.sav  #The population after one generation
#       ...
#      pars_withmids_combfs_?a_?a_?a_tmp19.sav #The population after twenty generations
#Output files:                  
#
#    If ITER = 0, the following picture will be saved:
#      pars_withmids_combfs_objective_evolution.eps 
#    If 0 < ITER <= 3, the following picture will be saved:
#      pars_withmids_combfs_?a_objective_evolution.eps #("?" referring to number 0-2 depending on the value of ITER)
#    ... and so on for ITER = 4,...,66; see above
#

import matplotlib
matplotlib.use('Agg')
import numpy as np
import emoo
from pylab import *
import pickle
import snmf_protocols
from neuron import h
import mytools
import time
import random
from os.path import exists


# Draw a single box plot with minimum, maximum, and 25-, 50-, and 75-percentiles.
def plotmybox(ax,ys,x=0,w=0.5):
  """Draw one hand-made box-and-whisker glyph with a single black line plot.

  ax: object with a matplotlib-style .plot(xs, ys, fmt) method (an Axes in this script)
  ys: [minimum, 25th percentile, median, 75th percentile, maximum]
  x:  horizontal centre of the box
  w:  half-width of the box and of the whisker caps

  NaN entries split the single polyline into three pieces: the lower cap, the
  lower whisker and the box outline; the median bar; and the upper whisker and
  upper cap.
  """
  ymin, q1, med, q3, ymax = ys[0], ys[1], ys[2], ys[3], ys[4]
  left, right = x - w, x + w
  # Lower cap -> lower whisker -> box outline, ending at the bottom centre of the box
  lower = [(left, ymin), (right, ymin), (x, ymin), (x, q1), (left, q1),
           (left, q3), (right, q3), (right, q1), (x, q1)]
  median_bar = [(left, med), (right, med)]
  # Upper whisker -> upper cap
  upper = [(x, q3), (x, ymax), (left, ymax), (right, ymax)]
  gap = [(nan, nan)]
  points = lower + gap + median_bar + gap + upper
  ax.plot([p[0] for p in points], [p[1] for p in points], 'k-')


import sys  # Used for sys.argv and sys.exit; do not rely on "from pylab import *" happening to export it

ITER = 0     # Fitting-task index (see the header comment); may be overridden by the first command-line argument
myseed = 1

#Read arguments:
argv_i_ITER = 1

if len(sys.argv) > argv_i_ITER:
  ITER = int(float(sys.argv[argv_i_ITER]))

maxPerGroup = 1 # How many best ones shall we consider per objective group

# Emoo parameters (N_samples and N_generations are used below to read the saved populations):
N_samples = [1000, 1000, 1000, 2000] #size of population (for first, second, third and fourth steps)
C_samples = [2000, 2000, 2000, 4000] #population capacity
N_generations = [20, 20, 20, 10]     #number of generations
eta_m_0s = [20, 20, 20, 20]          #mutation strength parameter
eta_c_0s = [20, 20, 20, 20]          #crossover strength parameter
p_ms = [0.5, 0.5, 0.5, 0.5]          #probability of mutation

#Stimulus protocols and quantities recorded at each step:
VARIABLES = snmf_protocols.get_snmf_variables()
STIMULI = snmf_protocols.get_stimuli()
SETUP = snmf_protocols.get_setup()
STIMTYPES = snmf_protocols.get_stimulus_types()
DATASTORTYPES = snmf_protocols.get_data_storage_types()

#Target data: a pickled list [data, apical recording distances, basal recording distances].
#NB: pickle.load executes arbitrary code from the file -- only load trusted, locally produced files.
with open('originalrun.sav', 'rb') as unpicklefile:  # binary mode: required for pickles in general
  unpickledlist = pickle.load(unpicklefile)
ORIGINALDATA = unpickledlist[0]
ORIGINALDISTS = [unpickledlist[1],unpickledlist[2]]
# Same convention as the model's xs: apical distances positive, basal distances negated
ORIGINALXS = unpickledlist[1]+[-x for x in unpickledlist[2]]

# Objectives of each step, as seen by the optimizer:
OBJECTIVES = [['f0_0', 'f0_1'],
              ['f1_0', 'f1_1'],
              ['f2_0', 'f2_1', 'f2_2'],
              ['f3_0', 'f3_1', 'f3_2']]

# Sub-objectives 'f<step>_<protocol>_<amplitude>_<recording>' computed by subfunc_to_optimize. Originally these
# were the objectives; now they are weighted by objective_group_coeffs and summed, group by group, into OBJECTIVES:
OBJECTIVES_SUBFUNC = [
  ['f0_0_0_0', 'f0_1_0_0', 'f0_2_0_0', 'f0_2_1_0', 'f0_2_2_0'],
  ['f1_0_0_0', 'f1_0_1_0', 'f1_0_2_0', 'f1_1_0_0', 'f1_1_1_0'],
  ['f2_0_0_0', 'f2_0_1_0', 'f2_1_0_0', 'f2_1_0_1', 'f2_1_1_0', 'f2_1_1_1', 'f2_2_0_0', 'f2_2_0_1', 'f2_2_1_0', 'f2_2_1_1'],
  ['f3_0_0_0', 'f3_0_1_0', 'f3_1_0_0', 'f3_2_0_0', 'f3_3_0_0', 'f3_3_1_0', 'f3_3_2_0']]

# For each step, the indices (into OBJECTIVES_SUBFUNC[step]) that are summed into each objective:
objective_groups = [
  [[0,1], [2,3,4]],               # voltage distributions; traces
  [[0,1,2], [3,4]],               # voltage distributions; traces
  [[0,1], [2,4,6,8], [3,5,7,9]],  # traces; voltage distributions; calcium concentration distributions
  [[0,1], [2,3], [4,5,6]]]        # voltage traces; numbers of spikes in long runs; BAC firing alone
# Weights of the sub-objectives, with the same nesting as objective_groups:
objective_group_coeffs = [
  [[1,5], [1,1,1]],
  [[1,1,1], [1,1]],
  [[1,1], [1,1,1,1], [1,1,1,1]],
  [[1,1], [1,1], [1,1,1]]]
lens_objective_groups = list(map(len, objective_groups))

# Index sets for 1..4 objectives: every single objective and every pair of objectives
groups_with_midpoints = [
  [[0]],
  [[0],[1],[0,1]],
  [[0],[1],[2],[0,1],[0,2],[1,2]],
  [[0],[1],[2],[3],[0,1],[0,2],[0,3],[1,2],[1,3],[2,3]]]
objective_groups_with_midpoints = [groups_with_midpoints[ngroups-1] for ngroups in lens_objective_groups]
lens_objective_groups_with_midpoints = list(map(len, objective_groups_with_midpoints))

# Number of fitting tasks (ITER values) per step, preceded by 0 for the single step-0 task (ITER=0)
saved_until = [0] + list(maxPerGroup*cumprod(lens_objective_groups_with_midpoints))
# saved_until_cums[k] is the LAST ITER that belongs to step k (with maxPerGroup=1: 0 | 1-3 | 4-12 | 13-66)
saved_until_cums = cumsum(saved_until)

nspike_restrictions = snmf_protocols.get_nspike_restrictions()
stop_if_no_spikes = nspike_restrictions[0] #Whether to stop without checking all objectives if for some stimulus a spike was not induced (although induced in the target data)
stop_if_nspikes = nspike_restrictions[1]   #Whether to stop without checking all objectives if for some stimulus too many spikes were induced

# Which step we are performing: the first step whose last ITER (saved_until_cums) is >= ITER.
istep = next((i for i,x in enumerate(saved_until_cums) if x >= ITER), len(saved_until_cums))
if istep >= len(OBJECTIVES):
  # Only steps 0..len(OBJECTIVES)-1 exist; fail early with a clear message instead of an obscure error later
  raise ValueError("ITER="+str(ITER)+" is out of range; expected ITER <= "+str(saved_until_cums[len(OBJECTIVES)-1]))

# Decode which branch (objective group or midpoint) must have been chosen at each previous step.
# Floor division keeps the indices integral; a plain "/" relies on Python 2 semantics and gives floats
# (and hence broken file names such as "_0.0a") under Python 3.
stepsChosen = []                                                    #Which steps must have been performed previously
divisorNow = 1
for istep2 in range(0,istep):
  stepsChosen.append((ITER-saved_until_cums[istep-1]-1)//maxPerGroup//divisorNow%lens_objective_groups_with_midpoints[istep2])
  divisorNow = divisorNow * lens_objective_groups_with_midpoints[istep2]
iord = (ITER-1)%maxPerGroup   # Rank of the chosen candidate within its group (always 0 when maxPerGroup == 1)
print("ITER="+str(ITER)+", steps="+str(stepsChosen))

# Distances of the recording locations from the soma; filled in by initialize_model() and setparams()
dists_apical = [] # along the apical dendrite (apic[0] and apic[1])
dists_basal = []  # along the basal dendrite (dend)

FIGUREIND = 0 # Running counter used in the names of the EPS files saved by subfunc_to_optimize

nseg = 20            # Number of segments per compartment
nrecsperseg = 20     # Number of recording locations per compartment
# Relative positions of the recording sites, spread evenly inside (0, 1)
xsperseg = [0.5] if nrecsperseg == 1 else [1.0*k/(nrecsperseg+1) for k in range(1, nrecsperseg+1)]

BACdt = 5.0 # ISI between apical and somatic stimulus (needed in one of the fourth step objectives)


###################### Function definitions ###############################


#initialize_model(): a function that defines the morphology, stimuli, recordings and initial values
def initialize_model():
  """Build the reduced model in NEURON and set up the recordings needed at step `istep`.

  Loads models/fourcompartment.hoc, creates the somatic IClamp (st1) and the apical epsp
  point process (syn1, with imax = 0), and sets segment counts, tstop, initial calcium and
  v_init. If some stimulus protocol of the current step (SETUP[istep]) has a recording type
  > 0, voltage is also recorded at nrecsperseg sites in each of apic[0], apic[1] and dend,
  and calcium at the apical sites.

  Side effects: rebinds the globals dists_apical (apic[0] sites, then apic[1] sites) and
  dists_basal (dend sites) to the distances of the recording sites from the soma.
  """
  global dists_apical, dists_basal
  v0 = -80        # v_init
  ca0 = 0.0001    # cai0_ca_ion
  tstop = 15000.0

  # Dendritic recordings are needed only if some protocol of this step records more than the soma (type 0)
  anyRecsRecorded = 0
  for istims in range(0,len(SETUP[istep])):
    stims = SETUP[istep][istims]
    irecs = stims[2]
    if type(irecs) is not list:
      irecs = [irecs]
    if any([x > 0 for x in irecs]):
      anyRecsRecorded = 1

  h("""
load_file("stdlib.hoc")
load_file("stdrun.hoc")
objref cvode
cvode = new CVode()
cvode.active(1)
cvode.atol(0.00005)
load_file("models/fourcompartment.hoc")
objref L5PC
L5PC = new fourcompartment()
access L5PC.soma
distance()
objref st1
st1 = new IClamp(0.5)
L5PC.soma st1
objref vsoma, vdend, vdend2
vsoma = new Vector()
vdend = new Vector()
vdend2 = new Vector()
objref syn1, tvec, sl
tvec = new Vector()
sl = new List()
double siteVec[2]
siteVec[0] = 0
siteVec[1] = 0.5
L5PC.apic[0] cvode.record(&v(siteVec[1]),vdend,tvec)
L5PC.apic[0] syn1 = new epsp(siteVec[1])
syn1.imax = 0
L5PC.apic[0] nseg = """+str(nseg)+"""
L5PC.apic[1] nseg = """+str(nseg)+"""
L5PC.soma cvode.record(&v(0.5),vsoma,tvec)
L5PC.soma nseg = """+str(nseg)+"""
L5PC.dend nseg = """+str(nseg)+"""
objref vrecs_apical["""+str(2*nrecsperseg)+"""], vrecs_basal["""+str(nrecsperseg)+"""], carecs_apical["""+str(2*nrecsperseg)+"""]
tstop = """+str(tstop)+"""
cai0_ca_ion = """+str(ca0)+"""
v_init = """+str(v0)+"""
""")

  # Distances are listed in the same order as the recording vectors below (and as setparams() recomputes
  # them): all apic[0] sites, then all apic[1] sites. Previously the no-recordings branch interleaved them.
  dists_apical = [h.distance(xsperseg[j],sec=h.L5PC.apic[0]) for j in range(0,nrecsperseg)] + \
                 [h.distance(xsperseg[j],sec=h.L5PC.apic[1]) for j in range(0,nrecsperseg)]
  dists_basal = [h.distance(xsperseg[j],sec=h.L5PC.dend) for j in range(0,nrecsperseg)]

  if anyRecsRecorded:
    for j in range(0,nrecsperseg):
      h("L5PC.apic[0] vrecs_apical["+str(j)+"] = new Vector()")
      h("L5PC.apic[0] cvode.record(&v("+str(xsperseg[j])+"),vrecs_apical["+str(j)+"],tvec)")
      h("L5PC.apic[0] carecs_apical["+str(j)+"] = new Vector()")
      h("L5PC.apic[0] cvode.record(&cai("+str(xsperseg[j])+"),carecs_apical["+str(j)+"],tvec)")
    for j in range(0,nrecsperseg):
      h("L5PC.apic[1] vrecs_apical["+str(nrecsperseg+j)+"] = new Vector()")
      h("L5PC.apic[1] cvode.record(&v("+str(xsperseg[j])+"),vrecs_apical["+str(nrecsperseg+j)+"],tvec)")
      h("L5PC.apic[1] carecs_apical["+str(nrecsperseg+j)+"] = new Vector()")
      h("L5PC.apic[1] cvode.record(&cai("+str(xsperseg[j])+"),carecs_apical["+str(nrecsperseg+j)+"],tvec)")
    for j in range(0,nrecsperseg):
      h("L5PC.dend vrecs_basal["+str(j)+"] = new Vector()")
      h("L5PC.dend cvode.record(&v("+str(xsperseg[j])+"),vrecs_basal["+str(j)+"],tvec)")
  else:
    print("No vrecs and carecs recorded!")

#setparams(params, istep): a function that sets the conductances of different species, axial resistances, capacitances or lengths for each compartment
def setparams(params, istep):
  """Apply a parameter dictionary to the NEURON model.

  params: dict mapping "<hoc variable>_<section>" (e.g. "L_soma") to a value. A section of
    "*" applies the value in every section (forall), e.g. for ehcn.
  istep: fitting step (0-3). Conductance parameters of the form g...bar... that belong to
    the steps after istep (snmf_protocols.get_variable_params) are set to zero.

  If any length ("L_...") parameter is present, all of L_soma, L_dend, L_apic[0] and
  L_apic[1] must be present. In that case:
    - the diameters are rescaled so that L*diam (i.e. membrane area) of each compartment
      stays at its reference value;
    - the globals dists_apical and dists_basal are recomputed;
    - syn1 is moved to the point 620 (length units, presumably um) along apic[0] and then
      apic[1], or to the distal end of apic[1] if their total length is at most 620.
  """
  global dists_apical, dists_basal

  lengthChanged = False
  for name, value in params.items():
    if name[0:2] == "L_":
      lengthChanged = True
    cut = name.rfind('_')
    hocvar = name[:cut]
    section = name[cut+1:]
    if section == "*":
      h("forall "+hocvar+" = "+str(value))
    else:
      h("L5PC."+section+" "+hocvar+" = "+str(value))

  if lengthChanged:
    # Membrane area (proportional to L*diam) of each compartment is conserved
    area_constants = (("soma", "L_soma", 360.132),
                      ("dend", "L_dend", 2821.168),
                      ("apic[0]", "L_apic[0]", 4244.628),
                      ("apic[1]", "L_apic[1]", 2442.848))
    for secname, lengthkey, area_constant in area_constants:
      h("L5PC."+secname+" diam = "+str(area_constant/params[lengthkey]))

    dists_apical = [h.distance(xsperseg[j],sec=h.L5PC.apic[0]) for j in range(nrecsperseg)] + \
                   [h.distance(xsperseg[j],sec=h.L5PC.apic[1]) for j in range(nrecsperseg)]
    dists_basal = [h.distance(xsperseg[j],sec=h.L5PC.dend) for j in range(nrecsperseg)]

    len_apic0 = params['L_apic[0]']
    if len_apic0 > 620:
      h("L5PC.apic[0] syn1.loc("+str(620.0/len_apic0)+")")
    elif len_apic0 + params['L_apic[1]'] > 620:
      h("L5PC.apic[1] syn1.loc("+str((620.0-len_apic0)/params['L_apic[1]'])+")")
    else:
      h("L5PC.apic[1] syn1.loc(1.0)")

  # Conductances fitted only at later steps are switched off
  for laterstep in range(istep+1,4):
    for varname in snmf_protocols.get_variable_params(laterstep):
      if varname[0]=='g' and 'bar' in varname:
        cut = varname.rfind('_')
        h("L5PC."+varname[cut+1:]+" "+varname[:cut]+" = 0")

#run_model(istep,saveFig,stop_if_needed): a function that runs the model using the stimuli of a certain step
def _first_index_at_or_after(times, t0):
  """Return the first index i with times[i] >= t0 (raises StopIteration if there is none)."""
  return next(i for i,t in enumerate(times) if t >= t0)


def run_model(istep,saveFig="",stop_if_needed=True):
  """Run the model with every stimulus protocol of step `istep` and collect the recorded data.

  Each protocol stims = SETUP[istep][istims] is used as: stims[0] stimulus ids (into STIMULI),
  stims[1] amplitudes, stims[2] recording type(s) (0 soma voltage, 1 dendritic voltage,
  2 apical calcium), stims[3] data storage type (key into DATASTORTYPES).

  Input:
    istep (0 to 3): Which step to run
    saveFig: Prefix of the EPS files to save (empty if no need to plot the time courses)
    stop_if_needed: Boolean telling whether we can quit the function after the first stimulus
      condition that produces too many or too few spikes (stop_if_nspikes / stop_if_no_spikes).
      This is relevant only in the fourth step, where relatively heavy simulations are performed.
  Output:
    myValsAllAll: myValsAllAll[istims][iamp][iirec] holds the data extracted for protocol istims,
      amplitude iamp and recording iirec. Conditions skipped because of an early stop stay nan.
  """
  global STIMULI, SETUP, STIMTYPES, DATASTORTYPES, dists_apical, dists_basal
  time_to_quit = False

  myValsAllAll = [nan]*len(SETUP[istep])
  for istims in range(0,len(SETUP[istep])):
    stims = SETUP[istep][istims]
    myValsAll = [nan]*len(stims[1])
    for iamp in range(0,len(stims[1])):
      # Configure every stimulus of this protocol; successive stimuli are delayed by BACdt each
      for istim in range(0,len(stims[0])):
        stimulus = STIMULI[stims[0][istim]]
        st = STIMTYPES[stimulus[0]]
        if type(stims[1][iamp]) is list:
          myamp = stims[1][iamp][istim]
        else:
          myamp = stims[1][iamp]
        h(st[0]+"."+st[4]+" = "+str(myamp))
        for iprop in range(0,len(stimulus[1])):
          thisPropVal = stimulus[1][iprop][1]
          if stimulus[1][iprop][0] == "onset" or stimulus[1][iprop][0] == "del":
            thisPropVal = thisPropVal + istim*BACdt
          h(st[0]+"."+stimulus[1][iprop][0]+" = "+str(thisPropVal))
          print(st[0]+"."+stimulus[1][iprop][0]+" = "+str(thisPropVal))
      h.init()
      h.run()

      irecs = stims[2]
      if type(irecs) is not list:
        irecs = [irecs]

      times=np.array(h.tvec)
      Vsoma=np.array(h.vsoma)
      if any([x > 0 for x in irecs]):
        Vdend=np.concatenate((np.array(h.vrecs_apical),np.array(h.vrecs_basal)))
        Cadend=np.array(h.carecs_apical)
      else:
        Vdend=[]
        Cadend=[]

      spikes = mytools.spike_times(times,Vsoma,-35,-45)
      if len(saveFig) > 0:
        close("all")
        f,axarr = subplots(2,2)
        axarr[0,0].plot(times,Vsoma)
        axarr[0,0].set_xlim([9990,11000])
        axarr[0,0].set_title("nSpikes="+str(len(spikes)))
        late = [i for i,t in enumerate(times) if t >= 9000]  # samples over which the plotted maxima are taken
        if len(Vdend) > 0:
          axarr[0,1].plot(dists_apical+[-x for x in dists_basal], [max([x[i] for i in late]) for x in Vdend], 'bx')
        if len(Cadend) > 0:
          axarr[1,0].plot(dists_apical, [max([x[i] for i in late]) for x in Cadend], 'b.')
        f.savefig(saveFig+"_step"+str(istep)+"_istims"+str(istims)+"_iamp"+str(iamp)+".eps")
      if stop_if_needed and ((stop_if_no_spikes[istep][istims][iamp] > 0 and len(spikes) == 0) or (stop_if_nspikes[istep][istims][iamp] > 0 and len(spikes) >= stop_if_nspikes[istep][istims][iamp])):
        time_to_quit = True
        break

      myVals = []
      for iirec in range(0,len(irecs)):
        irec = irecs[iirec]
        if irec == 0:
          myData = [Vsoma]
        elif irec == 1:
          myData = Vdend
        elif irec == 2:
          myData = Cadend
        else:
          print("Unknown recording type: "+str(irec))
          continue
        storage = DATASTORTYPES[stims[3]]   # [storage type name, time (window) parameter]
        kind = storage[0]
        if kind == "fixed":
          # Value at the last sample before the given time (index computed once instead of three times)
          ifixed = _first_index_at_or_after(times, storage[1])
          print("ind="+str(ifixed)+", val="+str(myData[0][ifixed-1]))
          myVals.append([x[ifixed-1] for x in myData])
        elif kind == "max":
          istart = _first_index_at_or_after(times, storage[1][0])
          iend = _first_index_at_or_after(times, storage[1][1])
          myVals.append([max(x[istart:iend]) for x in myData])
        elif kind == "trace" or kind == "highrestrace":
          myVals.append([mytools.interpolate_extrapolate_constant(times,x,storage[1]) for x in myData])
        elif kind == "highrestraceandspikes":
          myVals.append([[mytools.interpolate_extrapolate_constant(times,x,storage[1]) for x in myData],spikes])
        elif kind == "nspikes":
          myVals.append(sum([1 for x in spikes if x >= storage[1][0] and x < storage[1][1]]))
        else:
          print("Unknown data storage type: "+kind)
          continue
      myValsAll[iamp] = myVals[:]
    myValsAllAll[istims] = myValsAll[:]

    # Switch this protocol's stimuli off before the next protocol (or before returning early)
    for istim in range(0,len(stims[0])):
      stimulus = STIMULI[stims[0][istim]]
      st = STIMTYPES[stimulus[0]]
      h(st[0]+"."+st[4]+" = 0")
    if time_to_quit:
      break

  return myValsAllAll[:]

#distdiff(xs, fs, xs_ref, fs_ref): a function to calculate the difference of membrane potential distribution along the dendrites
def distdiff(xs, fs, xs_ref, fs_ref):
  """Mean scaled distance from the model's (x, f) points to the nearest reference point.

  Input:
    xs: Distances of recording locations from soma, along the dendrites of the reduced model
      (negative for basal, positive for apical)
    fs: Values (e.g. maximum membrane potential during or following a stimulus) at xs
    xs_ref: Distances of recording locations from soma in the full model (same sign convention)
    fs_ref: Values at xs_ref in the full model. The ranges of xs_ref and fs_ref are used to
      scale the two dimensions (space vs. e.g. membrane potential) so that both are a priori
      equally relevant. A zero range is left unscaled, to avoid dividing by zero.
  Output:
    The mean, over the model points whose x lies within [min(xs_ref), max(xs_ref)], of the
    scaled Euclidean distance to the nearest reference point.
  Raises:
    ValueError if no model point lies within the x-range of the reference.
  """
  xmin = min(xs_ref)
  xmax = max(xs_ref)
  dist_diff_max = (xmax - xmin) or 1.0
  f_diff_max = (max(fs_ref) - min(fs_ref)) or 1.0

  dist = 0
  n = 0
  for ix, x0 in enumerate(xs):
    if x0 > xmax or x0 < xmin:
      continue  # outside the reference range: no meaningful nearest neighbour
    f0 = fs[ix]
    dists2 = [(1.0*(x0-x)/dist_diff_max)**2 + (1.0*(f0-y)/f_diff_max)**2 for x,y in zip(xs_ref, fs_ref)]
    dist = dist + np.sqrt(min(dists2))
    n = n+1
  if n == 0:
    raise ValueError("distdiff: no point of xs lies within the range of xs_ref")
  return 1.0*dist/n

#highrestraceandspikesdiff(data, dataref, coeff_mV=1.0/12, coeff_ms=1.0/20): a function to calculate difference between spike trains.
def highrestraceandspikesdiff(data, dataref, coeff_mV=1.0/12, coeff_ms=1.0/20):
  """Difference between two recordings that combines traces, spike timings and spike counts.

  By default, a mean difference of 12 mV in membrane potential is penalized as much as a
  summed distance of 20 ms between nearest spikes, and as much as a difference of one spike.

  Input:
    data: A list [traces, spikes] of the reduced model; traces is a list of voltage traces
    dataref: A list [traces, spikes] of the full model
    coeff_mV: Penalty per 1 mV mean difference between model and target membrane potential,
      relative to a difference of one spike. A list or tuple gives one coefficient per trace;
      a scalar is applied to the first trace only.
    coeff_ms: Penalty per 1 ms summed difference in spike timings (each model spike is
      compared with its nearest target spike), relative to a difference of one spike
  Output:
    (difference in spike count) + coeff_ms*(summed spike-timing difference)
      + coeff_mV*(mean membrane potential difference)
  """
  trace1 = data[0]
  spikes1 = data[1]
  traceref = dataref[0]
  spikesref = dataref[1]

  meantracediffs = [1.0*np.mean([abs(x-y) for x,y in zip(thistrace, thistraceref)]) for thistrace, thistraceref in zip(trace1, traceref)]
  sp_N_err = abs(len(spikesref)-len(spikes1))
  sp_t_err = 0
  if len(spikesref) > 0:  # with no target spikes there is no nearest spike to compare with
    for spike in spikes1:
      sp_t_err = sp_t_err + min([abs(spike - x) for x in spikesref])
  if isinstance(coeff_mV, (list, tuple)): # One coefficient per trace
    return sum([x*y for x,y in zip(meantracediffs, coeff_mV)]) + sp_t_err * coeff_ms + sp_N_err
  else:
    return meantracediffs[0] * coeff_mV + sp_t_err * coeff_ms + sp_N_err


#func_to_optimize(parameters,istep, saveFig=False, filename="FIGUREWITHMIDSCOMBFS",stop_if_needed=True): the function which is to be minimized.
def func_to_optimize(parameters,istep, saveFig=False, filename="FIGUREWITHMIDSCOMBFS",stop_if_needed=True):
  """Objective function for step `istep`.

  Computes the sub-objectives with subfunc_to_optimize (same arguments) and combines them:
  objective j is the sum over k of
    objective_group_coeffs[istep][j][k] * subobjective[OBJECTIVES_SUBFUNC[istep][objective_groups[istep][j][k]]].

  Returns a dict mapping 'f<istep>_<j>' to the combined value, for j in range(len(OBJECTIVES[istep])).
  """
  subvalues = subfunc_to_optimize(parameters,istep, saveFig, filename, stop_if_needed)
  subnames = OBJECTIVES_SUBFUNC[istep]
  combined = {}
  for igroup in range(len(OBJECTIVES[istep])):
    members = objective_groups[istep][igroup]
    weights = objective_group_coeffs[istep][igroup]
    total = 0
    for k, member in enumerate(members):
      total = total + weights[k]*subvalues[subnames[member]]
    combined['f'+str(istep)+'_'+str(igroup)] = total
  return combined

#subfunc_to_optimize(parameters,istep, saveFig, filename,stop_if_needed): the function which is to be minimized.
def subfunc_to_optimize(parameters,istep, saveFig, filename, stop_if_needed):
  """Set `parameters`, run the protocols of step `istep` and compute all sub-objective values.

  Input:
    parameters: A dictionary for the parameters. See snmf_protocols for the variable names
    istep: The step to perform (0 to 3)
    saveFig: Boolean telling whether to plot how well the parameters are fitted to the target data
    filename: The name (prefix) of the EPS file to save if any
    stop_if_needed: Boolean telling whether the simulations may stop after the first stimulus
      condition that produces too many or too few spikes (see run_model). Sub-objectives of
      the conditions skipped that way get the value MAXERR.
  Output:
    mydict: Dictionary mapping 'f<istep>_<ifun>_<iamp>_<iirec>' to the sub-objective value
  Side effects: changes the model parameters (setparams); if saveFig, saves an EPS figure and
  increments the global FIGUREIND.
  """
  global ORIGINALDATA, SETUP, dists_apical, dists_basal, FIGUREIND

  MAXERR = 1e8 # Value of the sub-objectives whose stimulus conditions were skipped by run_model
  setparams(parameters,istep)
  # setparams() may have recomputed the recording-site distances, so read them only now
  xs = dists_apical + [-x for x in dists_basal]
  A = run_model(istep,"",stop_if_needed)

  if saveFig:
    close("all")
    f,axarr = subplots(5,2)
    axs = [axarr[irow,icol] for icol in range(2) for irow in range(5)]
    for iplot in range(0,10):
      iplotx = iplot%2
      iploty = iplot//2 # Integer row index; a plain "/" would give a float (misplaced axes) under Python 3
      axs[iplot].set_position([0.08+0.5*iplotx,0.86-0.185*iploty,0.4,0.11])
    saveInd = 0

  mydict = {}
  for ifun in range(0,len(SETUP[istep])):
    irecs = SETUP[istep][ifun][2]
    istor = SETUP[istep][ifun][3]
    if type(irecs) is not list:
      irecs = [irecs]
    for iamp in range(0,len(SETUP[istep][ifun][1])):
      # Conditions that run_model did not simulate (early stop) are left as nan there
      if type(A[ifun]) is not list or (type(A[ifun][iamp]) is not list and isnan(A[ifun][iamp])):
        for iirec in range(0,len(irecs)):
          mydict['f'+str(istep)+'_'+str(ifun)+'_'+str(iamp)+'_'+str(iirec)] = MAXERR
        continue

      for iirec in range(0,len(irecs)):
        irec = irecs[iirec]
        key = 'f'+str(istep)+'_'+str(ifun)+'_'+str(iamp)+'_'+str(iirec)
        if istor == 0 or istor == 1: # Maxima or steady-state values across the dendrite(s)
          if irec == 1: #Voltage, whole dendritic tree
            mydict[key] = distdiff(xs, A[ifun][iamp][iirec], ORIGINALXS, ORIGINALDATA[istep][ifun][iamp][iirec])
          if irec == 2: #Calcium, only apical dendrite
            mydict[key] = distdiff(dists_apical, A[ifun][iamp][iirec], ORIGINALDISTS[0], ORIGINALDATA[istep][ifun][iamp][iirec])
          if saveFig:
            if irec == 1: #Voltage, whole dendritic tree
              axs[saveInd].plot(xs, A[ifun][iamp][iirec],'bx')
              axs[saveInd].plot(ORIGINALXS, ORIGINALDATA[istep][ifun][iamp][iirec],'g.')
            if irec == 2: #Calcium, only apical dendrite
              axs[saveInd].plot(dists_apical, A[ifun][iamp][iirec],'bx')
              axs[saveInd].plot(ORIGINALDISTS[0], ORIGINALDATA[istep][ifun][iamp][iirec],'g.')
            axs[saveInd].set_title(key+"="+str(mydict[key]),fontsize=7)
            if saveInd < len(axs)-1:
              saveInd = saveInd+1
        elif istor == 2 or istor == 3: # Time series
          model = A[ifun][iamp][iirec]
          sumval = 0
          for irecloc in range(0,len(model)): # Usually only time course of soma, but technically allowed for dendrites as well
            target = ORIGINALDATA[istep][ifun][iamp][iirec]
            sumval = sumval + sum([1.0*abs(x-y) for x,y in zip(model[irecloc], target[irecloc])])
            if saveFig:
              axs[saveInd].plot(model[irecloc])
              axs[saveInd].plot(target[irecloc])
          # Summed absolute difference scaled by 1/250; storage type 2 is weighted 5 times more
          if istor == 2:
            mydict[key] = sumval*5/250.0
          else:
            mydict[key] = sumval/250.0
          if saveFig:
            axs[saveInd].set_title(key+"="+str(mydict[key]),fontsize=7)
            if saveInd < len(axs)-1:
              saveInd = saveInd+1
        elif istor == 4: # Time series with spike time precision
          model = A[ifun][iamp][iirec]
          target = ORIGINALDATA[istep][ifun][iamp][iirec]
          mydict[key] = highrestraceandspikesdiff(model, target)
          if saveFig:
            for irecloc in range(0,len(model[0])): # Usually only time course of soma, but technically allowed for dendrites as well
              axs[saveInd].plot(model[0][irecloc])
              axs[saveInd].plot(target[0][irecloc])
            # Both nSp and nSpref are spike counts (nSpref used to show the whole reference spike list)
            axs[saveInd].set_title(key+"="+str(mydict[key])+
                                   ', nSp='+str(len(model[1]))+', nSpref='+str(len(target[1])),fontsize=8)
            if saveInd < len(axs)-1:
              saveInd = saveInd+1
        else: # Nspikes
          mydict[key] = 1.0*(A[ifun][iamp][iirec]-ORIGINALDATA[istep][ifun][iamp][iirec])**2
          if saveFig:
            axs[saveInd].set_title(key+"="+str(mydict[key]),fontsize=7)
            if saveInd < len(axs)-1:
              saveInd = saveInd+1

  if saveFig:
    for ax in axs:
      for tick in ax.yaxis.get_major_ticks()+ax.xaxis.get_major_ticks():
        tick.label1.set_fontsize(5)  # label1 is the primary tick label ("label" was removed in newer matplotlib)
    # The x-labels of these axes list the parameters of steps 0..3, five per line
    write_params_here = [6,7,8,9]
    for j in range(0,4):
      myText = ''
      varnames = [x[0] for x in VARIABLES[j]]
      nplaced = 0
      for ivarname in range(0,len(varnames)):
        varname = varnames[ivarname]
        if varname in parameters:
          myText = myText + varname + " = " + str(parameters[varname])
        else:
          # Not a fitted parameter: show the value currently set in the model instead
          underscoreind = varname.rfind('_')
          section = varname[underscoreind+1:len(varname)]
          h("tmpvar = 999999")
          if section == "*": # If a parameter has to be same in all sections (such as ehcn)
            h("tmpvar = L5PC.soma."+varname[0:underscoreind])
          else:
            h("tmpvar = L5PC."+section+"."+varname[0:underscoreind])
          myText = myText + "No " + varname + " (" + str(h.tmpvar) + ")"
        nplaced = nplaced + 1
        if ivarname < len(varnames)-1:
          myText = myText + ", "
        if nplaced > 4:
          nplaced = 0
          myText = myText + '\n'
      myText = myText.replace('CaDynamics_E2','CaD')
      axs[write_params_here[j]].set_xlabel(myText,fontsize=2.2)
    if FIGUREIND==-1:
      f.savefig(filename+".eps")
    else:
      f.savefig(filename+str(istep)+"_"+str(FIGUREIND)+".eps")
    FIGUREIND = FIGUREIND+1
  return mydict


########################## The main code ##################################


initialize_model() # The model initialization is needed in order to be able to calculate the objective functions for the final parameter sets. All other
                   # objective function calculations are assumed to be done (and the values are assumed to be saved to _tmpXX-files)

ext = chr(ord('a')+iord)  # Letter of the chosen candidate within its group ('a' when maxPerGroup == 1)
filename = "pars_withmids_combfs"
par_names = []
par_values = []
paramdict = {}

#Final parameter set (pickled list [parameter names, parameter values]; trusted local file):
with open(filename+"_final.sav", 'rb') as unpicklefile:
  unpickledlist = pickle.load(unpicklefile)
par_names_final = unpickledlist[0]
par_values_final = unpickledlist[1]
paramdict_final = dict(zip(par_names_final, par_values_final))

#Calculate the objective functions for the final parameter set:
fvals_final = func_to_optimize(paramdict_final,istep,False,'')
fvals_final_step0 = [] # This is only needed for step 0 where gpas values are replaced by the step 1 fit!

if istep > 0:
  #Determine the name of the file to load (to get the parameters from the previous step fittings)
  for chosen in stepsChosen:
    filename = filename+"_"+str(chosen)+"a"
  filename = filename[0:-1] + ext
else:
  #Load the parameters of the first step, optimized for the second objective function. These should be the
  #same as the first step parameters in pars_withmids_combfs_final.sav, except that the passive conductances
  #are refitted in pars_withmids_combfs_final.sav and hence represent a worse fit (red dashed line) to the
  #first step objective function data
  with open(filename+"_1a.sav", 'rb') as unpicklefile:
    unpickledlist = pickle.load(unpicklefile)
  par_names_final_step0 = unpickledlist[0]
  par_values_final_step0 = unpickledlist[1]
  paramdict_final_step0 = dict(zip(par_names_final_step0, par_values_final_step0))
  fvals_final_step0 = func_to_optimize(paramdict_final_step0,istep,False,'')

print("filename: "+filename)
f,axarr = subplots(len(OBJECTIVES[istep]),1)
fvals_all = []
medians_all = []
prc25s_all = []
prc75s_all = []
mins_all = []
maxs_all = []

maxs_ylims = [[1e4]*3, [1e4]*3, [1e5]*3, [1e1, 1e1, 1e4]] #Cut the figures after these values (anyway, the MAXERR=1e8 otherwise shown in some cases is an arbitrary value)

nvars = len(VARIABLES[istep])
nobj = len(OBJECTIVES[istep])
for igen in range(0,N_generations[istep]):
  popfile = filename+"_tmp"+str(igen)+".sav"
  if not exists(popfile):
    print(popfile+" not found. Run the fittings first.")
    sys.exit(1)  # non-zero status: this is an error, not a normal end of the script
  with open(popfile, 'rb') as unpicklefile:
    unpickledlist = pickle.load(unpicklefile)
  # In each population row, the first nvars entries are skipped (presumably the parameters);
  # the next nobj entries are the objective values
  fvals = [[unpickledlist[0][j][i] for i in range(nvars,nvars+nobj)] for j in range(0,N_samples[istep])]
  fvals_all.append(fvals)
  columns = [[row[ifval] for row in fvals] for ifval in range(0,nobj)]  # values of each objective over the population
  medians_all.append([median(col) for col in columns])
  mins_all.append([min(col) for col in columns])
  maxs_all.append([max(col) for col in columns])
  prc25s_all.append([percentile(col,25.) for col in columns])
  prc75s_all.append([percentile(col,75.) for col in columns])
  print(str(unpickledlist[1]))
  print(str(len(fvals)))

# NOTE(review): exposes the file name to hoc; nothing visible here reads myfilename -- possibly a leftover
h("strdef myfilename")
h("myfilename=\""+filename+"\"")

nobj = len(OBJECTIVES[istep])
ngen = N_generations[istep]
for ifval in range(0,nobj):
  ax = axarr[ifval]
  objname = OBJECTIVES[istep][ifval]
  # Reference lines: red = final parameter set; blue = step-0 parameters before the passive refit (step 0 only)
  ax.plot([0,ngen+1],[fvals_final[objname]]*2,'r--',linewidth=2,dashes=(4,1.8))
  if len(fvals_final_step0) > 0:
    ax.plot([0,ngen+1],[fvals_final_step0[objname]]*2,'b--',linewidth=2,dashes=(4,2))
  for igen in range(0,ngen):
    plotmybox(ax,[mins_all[igen][ifval],prc25s_all[igen][ifval],medians_all[igen][ifval],prc75s_all[igen][ifval],maxs_all[igen][ifval]],igen+1,0.3)
  ax.set_xlim([0.5,ngen+0.5])
  try:
    ax.set_yscale("log", nonposy='clip')
  except TypeError:
    # matplotlib >= 3.3 renamed "nonposy" to "nonpositive" and later removed the old name
    ax.set_yscale("log", nonpositive='clip')
  if max([maxs_all[igen][ifval] for igen in range(0,ngen)]) > maxs_ylims[istep][ifval]:
    ax.set_ylim([ax.get_ylim()[0], maxs_ylims[istep][ifval]])
  ax.set_ylabel(objname)
  if ngen < 20:
    # Shrink the axes so that one generation takes the same width as in the 20-generation figures
    pos = ax.get_position()
    ax.set_position([pos.x0, pos.y0, pos.width*ngen/20, pos.height])
axarr[0].set_title("Step "+str(istep+1),fontweight="bold")
axarr[nobj-1].set_xlabel("generation")
f.savefig(filename+"_objective_evolution.eps")