load_file("CrossingFinder.hoc")
load_file("Integrator.hoc")
load_file("bpap-dendburst.hoc")
load_file("bpap-record.hoc")
load_file("bpap-data.hoc")

// Top-level declarations of the result objects that run_bpap() fills in,
// and registration of their names with datalist. datalist is defined
// outside this file (presumably in bpap-data.hoc -- TODO confirm); the
// append* calls below only pass it the variable names as strings.

// Random number generator; created in run_bpap() and used to draw the
// presynaptic spike times
objref r

// Per-run quantities: run_bpap() allocates each of these as a Matrix
// with one row per run
objref vsynmax, vsynint, vsyndel
objref vtreemax, vtreeint
objref casynmax, casyndel, casynint
objref catreemax
objref vsynwidth

// Means
objref vsrimax_mean, vsriint_mean, vsridel_mean
objref vtreemax_mean, vtreeint_mean
objref casrimax_mean, casriint_mean, casridel_mean
objref catreemax_mean
objref vsriwidth_mean

// Stderrs
objref vsrimax_stderr, vsriint_stderr, vsridel_stderr
objref vtreemax_stderr, vtreeint_stderr
objref casrimax_stderr, casriint_stderr, casridel_stderr
objref catreemax_stderr
objref vsriwidth_stderr

// Mapping from syns to sris
objref syn_sri
// Per-run value of record_apc.n (set in run_bpap())
objref apc

// Per-run matrices registered with datalist; the commented-out lines are
// matrices that are computed by run_bpap() but currently not registered
datalist.appendMatrix("vtreemax")
//datalist.appendMatrix("casynmax")
//datalist.appendMatrix("vtreeint")
//datalist.appendMatrix("casynint")
datalist.appendMatrix("vsyndel")
//datalist.appendMatrix("casyndel")
//datalist.appendMatrix("catreemax")

// Record of mean quantities at synapses located in a particular sri
datalist.appendVector("vsrimax_mean")
datalist.appendVector("vsriint_mean")
datalist.appendVector("vsridel_mean")
datalist.appendVector("vtreemax_mean")
datalist.appendVector("vtreeint_mean")
datalist.appendVector("vsriwidth_mean")

datalist.appendVector("casrimax_mean") 
datalist.appendVector("casriint_mean")
datalist.appendVector("casridel_mean")
datalist.appendVector("catreemax_mean")

// Record of stderrs at synapses located in a particular sri
datalist.appendVector("vsrimax_stderr")
datalist.appendVector("vsriint_stderr")
datalist.appendVector("vsridel_stderr")
datalist.appendVector("vtreemax_stderr")
datalist.appendVector("vtreeint_stderr")
datalist.appendVector("vsriwidth_stderr")

datalist.appendVector("casrimax_stderr") 
datalist.appendVector("casriint_stderr")
datalist.appendVector("casridel_stderr")
datalist.appendVector("catreemax_stderr")


datalist.appendMatrix("syn_sri") // Record of where the synapses were located on each run

//datalist.appendVectorList("catree")

// Recordings of the time course of various quantities at the "screc"
// (Schaffer collateral recorded) synapses, and at the soma. The screc_*
// lists are assigned at the end of run_bpap().

objref screc_vsyn, screc_casyn, screc_ica_nmdasyn, screc_ica_ampasyn, screc_icasyn, screc_catree

datalist.appendVectorList("screc_vsyn")
datalist.appendVectorList("screc_casyn")
datalist.appendVectorList("screc_icasyn")
datalist.appendVectorList("screc_ica_ampasyn")
datalist.appendVectorList("screc_ica_nmdasyn")
datalist.appendVectorList("screc_catree")

// tsoma and vsoma are not declared in this file (presumably set up by
// record_vsoma() in bpap-record.hoc -- TODO confirm)
datalist.appendVector("tsoma")
datalist.appendVector("vsoma")
datalist.appendVector("apc")

// Recordings in full detail.
// vtreerecinds holds the indices into vtree whose complete voltage traces
// run_bpap() stores on every run; it starts empty, so no full traces are
// stored unless indices are added to it before run_bpap() is called.
objref vtreerecinds
vtreerecinds = new Vector()
// Lists filled by run_bpap(): one Matrix of traces (vtreerec) and one copy
// of tsoma (tsomarec) per run
objref vtreerec
objref tsomarec
// datalist.appendList("vtreerec")
// datalist.appendList("tsomarec")
// datalist.appendVector("vtreerecinds")

//
// The main function to run a series of simulations, and extract
// statistics from them.
//
// run_bpap()
//
// Performs nruns simulations. On each run the inputs are re-created and
// re-distributed, the presynaptic spike times are drawn, the simulation
// is run to tstop, and one row of each per-run Matrix (vsynmax, vsynint,
// vsyndel, vtreemax, vtreeint, vsynwidth, catreemax, casynmax, casynint,
// casyndel, syn_sri) and one element of apc is filled in.  After the
// last run the *_mean and *_stderr vectors are computed and the screc_*
// lists are set up.
//
// Reads globals defined outside this file, among them: nruns, nsyn,
// start, jitter, jitter_type, tstop, erest, dendburst_state,
// spiketime_state, scstim_file, scall_inputs, scall_inputs_sri,
// screc_inputs, and the recording lists vtree, catree, vsyn, casyn,
// tsyn, icasyn, ica_nmdasyn and ica_ampasyn.
//
// Note that vhalf is not declared local, so it is a global variable.
//
objref tpos, tneg
proc run_bpap() { local tmax, n, ind, i
    
    // Record the impedance as a measure of synaptic size
    cell_measure_impedance()
    
    // New random object (for jitter): uniform on [start, start+jitter] if
    // jitter_type is 1, otherwise normal with arguments
    // (start+jitter, jitter)
    r = new Random()
    if (jitter_type==1) {
        r.uniform(start,start+jitter)
    } else {
        r.normal(start+jitter,jitter)
    }
    
    // Record voltage and calcium from tree
    record_vtree()
    record_vsoma()
    record_catree()
    record_apcount_soma()
    
    // Matrix to store values at each segment or synapse for each run
    vsynmax   = new Matrix(nruns, nsyn)
    vsynint   = new Matrix(nruns, nsyn)
    vsyndel   = new Matrix(nruns, nsyn)
    vtreemax  = new Matrix(nruns, vtree.count())
    vtreeint  = new Matrix(nruns, vtree.count())
    vsynwidth = new Matrix(nruns, nsyn)
    catreemax = new Matrix(nruns, catree.count())
    casynmax  = new Matrix(nruns, nsyn)
    casynint  = new Matrix(nruns, nsyn)
    casyndel  = new Matrix(nruns, nsyn)
    syn_sri   = new Matrix(nruns, nsyn)
    apc       = new Vector(nruns)
    
    vtreerec = new List()
    tsomarec = new List()
    
    // Do nruns runs
    for n=0,nruns-1 { 
        // Create spines with synapses and distribute them randomly
        create_and_distribute_inputs()
        
        // Setup the soma iclamp - this is only active if soma_iclamp_state=1
        setup_soma_iclamp()
        
        // Record from the spines
        record_casyn()
        record_vsyn()
        record_icasyn(screc_inputs)
        record_ica_nmdasyn(screc_inputs)
        record_ica_ampasyn(screc_inputs)
        record_tprespike()
        
        print "Run ", n+1, "of " , nruns
        print "Running simulation..."
        
        // Make the random presynaptic spike times
        tprespike = new Vector(scall_inputs.count())
        if (dendburst_state) {
            tprespike.setrand(r)
            for i=0, nsyn-1 {
                scall_inputs.object(i).netstim.start = tprespike.x(i)
            } // should only be one spike
        }
        
        // We initialise before reading in the spike times -- 
        // otherwise they aren't read in since t > the spiketimes we're adding
        stdinit()
        if (spiketime_state) {
            bpap_spiketrain_set_spiketrain(scstim_file)
        }
        
        // Run the simulation
        running_ = 1
        continuerun(tstop)
        
        // Analysis
        //    vsomamax = new Vector(vsoma.count())
        print "Computing maximum voltage..."
        for i=0, vtree.count()-1 {
            vtreemax.x[n][i] = vtree.object(i).max()
        }
        
        // With the variable time step method active the samples are not
        // evenly spaced, so the integral is computed against a time
        // vector; otherwise it is the sum of the samples times dt.
        // NOTE(review): this loop pairs vtree.object(i) with tsyn.object(i),
        // whereas the spine integrals below all use tsyn.object(0) --
        // confirm that tsyn has an entry for every tree index.
        print "Computing integral of voltage..."
        for i=0, vtree.count()-1 {
            if (cvode.active()) {
                vtreeint.x[n][i] = integrator.integrate(vtree.object(i),tsyn.object(i))
            } else {
                vtreeint.x[n][i] = vtree.object(i).sum() * dt
            }
        }
        
        print "Computing maximum voltage in each spine..."
        for i=0, vsyn.count()-1 {
            vsynmax.x[n][i] = vsyn.object(i).max()
        }
        
        print "Computing integral of voltage in each spine..."
        for i=0, vsyn.count()-1 {
            if (cvode.active()) {
                vsynint.x[n][i] = integrator.integrate(vsyn.object(i),tsyn.object(0))
            } else {
                vsynint.x[n][i] = vsyn.object(i).sum() * dt
            }
        }
        
        // The delay is the time of the voltage peak in the tree at the
        // location index of synapse i, minus that synapse's presynaptic
        // spike time.
        // NOTE(review): the peak is taken from vtree, not vsyn, and the
        // time vector is tsyn.object(ind) -- confirm this is intended.
        print "Computing delay of voltage..."
        for i=0, vsyndel.ncol()-1 { 
            ind = scall_inputs_sri.x(i)
            tmax = tsyn.object(ind).x(vtree.object(ind).max_ind())
            vsyndel.x[n][i] = tmax - tprespike.x(i)
        }
        
        print "Computing maximum Ca in each spine..."
        for i=0, casyn.count()-1 {
            casynmax.x[n][i] = casyn.object(i).max()
        }
        
        print "Computing integral of Ca in each spine..."
        for i=0, casyn.count()-1 {
            if (cvode.active()) {
                casynint.x[n][i] = integrator.integrate(casyn.object(i),tsyn.object(0))
            } else {
                casynint.x[n][i] = casyn.object(i).sum() * dt
            }
        }
        
        // Also store where each synapse was located on this run
        print "Computing delay to peak Ca in each spine..."
        for i=0, casyn.count()-1 {
            tmax = tsyn.object(0).x(casyn.object(i).max_ind())
            casyndel.x[n][i] = tmax - tprespike.x(i)
            syn_sri.x[n][i]  = scall_inputs_sri.x(i)
        }
        
        print "Computing maximum Ca in tree..."
        for i=0, catree.count()-1 {
            catreemax.x[n][i] = catree.object(i).max()
        }
        
        // Keep the complete traces of the tree locations listed in
        // vtreerecinds (one Matrix row per location) and a copy of tsoma
        print "Recording voltage traces in tree..."
        if (vtreerecinds.size() > 0) {
            vtreerec.append(new Matrix(vtreerecinds.size(), vtree.o(0).size()))
            tsomarec.append(tsoma.c)
            for i=0, vtreerecinds.size()-1 {
                vtreerec.o(vtreerec.count()-1).setrow(i, vtree.o(vtreerecinds.x[i]))
            }
        }
        
        // The width is measured at the level halfway between the peak
        // spine voltage and erest: the upward crossing times (tpos) are
        // subtracted from the downward crossing times (tneg) and the
        // differences are summed.
        print "Computing half width in each spine..."
        for i=0, vsyn.count()-1 {
            vhalf = (vsyn.object(i).max() + erest)/2
            // vhalf = mg_NmdaSyn/K0_NmdaSyn)/z_NmdaSyn/delta_NmdaSyn/FARADAY/0.001*R*(celsius+273)
            // vhalf = -13 // Half-height of NMDA block function
            //vhalf = -25
            tpos = find_crossings(vsyn.object(i), vhalf, tsyn.object(0), 1)
            tneg = find_crossings(vsyn.object(i), vhalf, tsyn.object(0), 2)
            if (tpos.size() > 1) {
                utils.warning("More than one crossing")
                tpos.printf
                tneg.printf
                // NOTE(review): uses tsyn.object(i) here but tsyn.object(0)
                // in the crossing search above -- confirm
                print tsyn.object(i).x(vsyn.object(i).max_ind())
            }
            if (tpos.size() == 0) {
                utils.warning("No crossings")
            } else {
                tneg.sub(tpos)
                vsynwidth.x[n][i] = tneg.sum
            }
        }

        
        apc.x(n) = record_apc.n 
        
        // Tidy up the recordings
        record_remove(casyn)
        record_remove(vsyn)
        record_remove(icasyn)
        record_remove(ica_nmdasyn)
        record_remove(ica_ampasyn)
    }
    
    // Compute means and standard errors of quantities in the tree.
    // Each *_stderr vector is computed from the same matrix as the
    // corresponding *_mean vector.
    print "Computing means and standard errors..."
    utils.mean(vtreemax_mean, vtreemax)
    utils.mean(vtreeint_mean, vtreeint)
    utils.mean(catreemax_mean, catreemax)
    utils.stderr(vtreemax_stderr,  vtreemax)
    utils.stderr(vtreeint_stderr,  vtreeint)
    utils.stderr(catreemax_stderr, catreemax)
    
    // For calcium peak and voltage delay at synapses, we need to
    // compute the mean and stderr of the maximum averaged over all
    // spines at a particular segment. This means we need to convert
    // from spine indices to segment indices.
    run_quantities_at_syn_to_quantity_at_seg(vsrimax_mean,  vsrimax_stderr,  vsynmax,  syn_sri) 
    run_quantities_at_syn_to_quantity_at_seg(vsriint_mean,  vsriint_stderr,  vsynint,  syn_sri) 
    run_quantities_at_syn_to_quantity_at_seg(vsridel_mean,  vsridel_stderr,  vsyndel,  syn_sri) 
    run_quantities_at_syn_to_quantity_at_seg(casrimax_mean, casrimax_stderr, casynmax, syn_sri) 
    run_quantities_at_syn_to_quantity_at_seg(casriint_mean, casriint_stderr, casynint, syn_sri) 
    run_quantities_at_syn_to_quantity_at_seg(casridel_mean, casridel_stderr, casyndel, syn_sri) 
    run_quantities_at_syn_to_quantity_at_seg(vsriwidth_mean, vsriwidth_stderr, vsynwidth, syn_sri) 

    
    // Output the last time course of various quantities just at the
    // screc synapses
    screc_icasyn      = icasyn
    screc_ica_nmdasyn = ica_nmdasyn
    screc_ica_ampasyn = ica_ampasyn
    
    // We already have vsyn - we just want to get the first few items from it
    screc_vsyn       = new List()
    for i=0, screc_inputs.count()-1 {
        screc_vsyn.append(vsyn.object(i))
    }
    
    // We already have casyn - we just want to get the first few items from it
    screc_casyn       = new List()
    for i=0, screc_inputs.count()-1 {
        screc_casyn.append(casyn.object(i))
    }
    
    // Copies of the tree calcium traces at the locations that the screc
    // synapses occupied on the first run (row 0 of syn_sri)
    screc_catree      = new List()
    for i=0, screc_inputs.count()-1 {
        screc_catree.append(catree.object(syn_sri.x[0][i]).c)
     }

    // Tidy up the recordings
    record_remove(vtree)
    record_remove(catree)
    
    print "Done."
}

//
// run_epsp()
//
// Calls measure_epspsizes() (defined outside this file), stores
// epsp_soma_amp in measure[1] and sets epsps_measured to 1.
//
proc run_epsp() {
    // Measure the EPSP sizes; presumably this is what sets
    // epsp_soma_amp -- TODO confirm against measure_epspsizes()
    measure_epspsizes()
    measure[1] = epsp_soma_amp 
    // Flag that the measurement has been made
    epsps_measured = 1
}

// run_quantities_at_syn_to_quantity_at_seg(Vector quantity_mean,
// Vector quantity_stderr, Matrix quantity, Matrix syn_sri)
// 
// Function to compute the mean and stderr of a QUANTITY averaged over
// all spines at a particular segment. This means we need to convert
// from spine indices to segment indices.  Each row of QUANTITY is a
// run; the columns refer to the quantity at each spine. The matrix
// SYN_SRI gives the index in segreflist in which each synapse was
// inserted.
//
// The results are returned through the first two arguments:
// QUANTITY_MEAN and QUANTITY_STDERR are set to new Vectors with one
// element per entry of segreflist.srl. An element of QUANTITY_MEAN is
// left at 0 if no synapse was ever inserted at that segment, and an
// element of QUANTITY_STDERR is left at 0 if fewer than two values were
// collected there.
//
// The loop counters i and n are declared local so that calling this
// procedure does not overwrite global variables of the same name.
proc run_quantities_at_syn_to_quantity_at_seg() { local i, n localobj quant_seg_mean, quant_seg_stderr, quant, syn_sri, quant_seg, syninds
    quant_seg_mean =   new Vector(segreflist.srl.count())
    quant_seg_stderr = new Vector(segreflist.srl.count())
    quant = $o3
    syn_sri = $o4
    syninds = new Vector()
    
    // Go through each segref
    for i=0, segreflist.srl.count()-1 {
        // Values of the quantity at segref i, pooled over all runs
        quant_seg = new Vector()
        // Go through each run
        for n=0, quant.nrow() - 1 {
            // Find inds of synapses which were inserted in segref i
            // on this run
            syninds.indvwhere(syn_sri.getrow(n), "==", i) 
            quant_seg.append(quant.getrow(n).ind(syninds))
        }
        // The mean needs at least one value and the stderr at least two
        if (quant_seg.size() > 0) {
            quant_seg_mean.x(i) = quant_seg.mean()
        }
        if (quant_seg.size() > 1) {
            quant_seg_stderr.x(i) = quant_seg.stderr()
        }
    }
    $o1 = quant_seg_mean
    $o2 = quant_seg_stderr
}