// Pull in the generic run machinery that this scaling simulation builds on
load_file("bpap-run.hoc")

// Weight history of the scaling simulation
objref synweightmatrix

// File handles used when writing results out
objref savesynweight, savevsynmax, savecasrimax, savedistances
objref savetransimp_start, savetransimp_end

// Per-run means and standard errors of quantities at each segment
objref vsridel_mean_perrun, vsridel_stderr_perrun
objref casrimax_mean_perrun, casrimax_stderr_perrun
objref casriint_mean_perrun, casriint_stderr_perrun
objref casridel_mean_perrun, casridel_stderr_perrun

// Mean peak calcium per segment for one individual run
objref casrimax_mean_indrun

objref vsomamax

// Register the names of the results with datalist
// (presumably so that they are written out with the other data;
// datalist is defined outside this file)
datalist.appendMatrix("synweightmatrix")
datalist.appendMatrix("casrimax_mean_perrun")
datalist.appendVector("vsomamax")

// Flag that is set to 1 by run_bpapscaling()
scalingsim = 0

// Default target Ca; bpap-sims may overwrite this with a new value.
// The number is the median of the scaled synapses simulation
// dendburst-s240-j00t1-a200-n45-bv-r170-sc1-Ra050-nr0100-cv1.R
targetCa = 0.047012

// Run the synaptic scaling simulation.
//
// Performs nruns simulation runs. On each run new synapses are
// created and distributed, their weights are set from row n of
// synweightmatrix (one column per segment), the simulation is run,
// and the voltage and calcium recordings are analysed. After each
// run, row n+1 of synweightmatrix is derived from row n: the weight
// at every stimulated segment is multiplied by
//     1 + 0.1*(targetCa - ca)/targetCa
// where ca is the mean peak spine calcium at that segment in this
// run. When all runs are done, means and standard errors are
// computed and vsynmax, casrimax_mean_perrun, synweightmatrix and
// distances are written to files in datastore/.
//
// Takes no arguments; reads and writes many global variables.
proc run_bpapscaling(){  local tmax, n, ind, i
    // Record the impedance as a measure of synaptic size
    cell_measure_impedance()
    
    scalingsim=1
    
    // New random object (for jitter of the presynaptic spike times)
    r = new Random()
    if (jitter_type==1) {
        r.uniform(start,start+jitter)
    } else {
        r.normal(start+jitter,jitter)
    }
    
    // Record voltage and calcium from tree
    record_vtree()
    record_vsoma()
    record_catree()
    record_apcount_soma()
    
    // Matrices to store values at each segment or synapse for each
    // run (one row per run)
    vsynmax   = new Matrix(nruns, nsyn)
    vsynint   = new Matrix(nruns, nsyn)
    vsyndel   = new Matrix(nruns, nsyn)
    vtreemax  = new Matrix(nruns, vtree.count())
    vtreeint  = new Matrix(nruns, vtree.count())
    catreemax = new Matrix(nruns, catree.count())
    casynmax  = new Matrix(nruns, nsyn)
    casynint  = new Matrix(nruns, nsyn)
    casyndel  = new Matrix(nruns, nsyn)
    syn_sri   = new Matrix(nruns, nsyn)
    apc       = new Vector(nruns)
    
    vtreerec = new List()
    tsomarec = new List()
    
    // Matrix to store the synaptic weights defined for each segment
    // (columns), for each run (rows). Note that the last row of the
    // matrix (index nruns) is based on the peak calcium of the last
    // run, but this weight value is not used in the simulation.
    synweightmatrix = new Matrix(nruns + 1, catree.count()) 
    for i=0, nruns {
        // Fill the whole matrix with the default synaptic strength
        synweightmatrix.setrow(i, gsbar_ampa) 
    } 
    
    // Do nruns runs
    for n=0,nruns-1 { 
        // Create and distribute spines with synapses on randomly
        // selected locations
        create_and_distribute_inputs()
        
        // Set weights of these new synapses according to weight
        // list. scall_inputs_sri refers to the *location* of synapse
        // i. Thus, although new synapses have been created on each
        // run, the weight is remembered from the last time a synapse
        // was there.
        // The loop starts at 0 so that the first synapse gets its
        // remembered weight too (it used to start at 1, which left
        // synapse 0 untouched).
        for i=0, scall_inputs.count()-1 {
            scall_inputs.object(i).netcon.weight = synweightmatrix.x[n][scall_inputs_sri.x(i)]
        }
        
        // Setup the soma iclamp - this is only active if soma_iclamp_state=1
        setup_soma_iclamp()
        
        // Record from the spines
        record_casyn()
        record_vsyn()
        record_icasyn(screc_inputs)
        record_ica_nmdasyn(screc_inputs)
        record_ica_ampasyn(screc_inputs)
        record_tprespike()
        
        print "Run ", n+1, "of " , nruns
        print "Running simulation..."
        
        // Make the random presynaptic spike times
        tprespike = new Vector(scall_inputs.count())
        if (dendburst_state) {
            tprespike.setrand(r)
            for i=0, nsyn-1 {
                scall_inputs.object(i).netstim.start = tprespike.x(i)
            } // should only be one spike
        }
        
        // We initialise before reading in the spike times -- 
        // otherwise they aren't read in since t > the spiketimes we're adding
        stdinit()
        if (spiketime_state) {
            bpap_spiketrain_set_spiketrain(scstim_file)
        }
        
        // Run the simulation
        running_ = 1
        continuerun(tstop)
        
        // Analysis
        //    vsomamax = new Vector(vsoma.count())
        print "Computing maximum voltage..."
        for i=0, vtree.count()-1 {
            vtreemax.x[n][i] = vtree.object(i).max()
        }
        
        print "Computing integral of voltage..."
        for i=0, vtree.count()-1 {
            if (cvode.active()) {
                // NOTE(review): the time vector is tsyn.object(i) here
                // but tsyn.object(0) in the spine integrals below --
                // confirm that tsyn has an entry for every tree segment
                vtreeint.x[n][i] = integrator.integrate(vtree.object(i),tsyn.object(i))
            } else {
                // Fixed time step: the integral is the sum times dt
                vtreeint.x[n][i] = vtree.object(i).sum() * dt
            }
        }
        
        print "Computing maximum voltage in each spine..."
        for i=0, vsyn.count()-1 {
            vsynmax.x[n][i] = vsyn.object(i).max()
        }
        
        print "Computing integral of voltage in each spine..."
        for i=0, vsyn.count()-1 {
            if (cvode.active()) {
                vsynint.x[n][i] = integrator.integrate(vsyn.object(i),tsyn.object(0))
            } else {
                vsynint.x[n][i] = vsyn.object(i).sum() * dt
            }
        }
        
        print "Computing delay of voltage..."
        for i=0, vsyndel.ncol()-1 { 
            // The delay is measured from the presynaptic spike time
            // to the peak of the tree voltage at the segment where
            // synapse i sits.
            // NOTE(review): this uses vtree and tsyn indexed by the
            // segment index rather than vsyn.object(i) -- confirm
            // that this is intended
            ind = scall_inputs_sri.x(i)
            tmax = tsyn.object(ind).x(vtree.object(ind).max_ind())
            vsyndel.x[n][i] = tmax - tprespike.x(i)
        }
        
        print "Computing maximum Ca in each spine..."
        for i=0, casyn.count()-1 {
            casynmax.x[n][i] = casyn.object(i).max()
        }
        
        print "Computing integral of Ca in each spine..."
        for i=0, casyn.count()-1 {
            if (cvode.active()) {
                casynint.x[n][i] = integrator.integrate(casyn.object(i),tsyn.object(0))
            } else {
                casynint.x[n][i] = casyn.object(i).sum() * dt
            }
        }
        
        print "Computing delay to peak Ca in each spine..."
        for i=0, casyn.count()-1 {
            tmax = tsyn.object(0).x(casyn.object(i).max_ind())
            casyndel.x[n][i] = tmax - tprespike.x(i)
            // Remember at which segment synapse i was on this run
            syn_sri.x[n][i]  = scall_inputs_sri.x(i)
        }
        
        print "Computing maximum Ca in tree..."
        for i=0, catree.count()-1 {
            catreemax.x[n][i] = catree.object(i).max()
        }
        
        print "Recording voltage traces in tree..."
        if (vtreerecinds.size() > 0) {
            // One matrix per run, with one row for each requested
            // tree index, plus a copy of the soma time vector
            vtreerec.append(new Matrix(vtreerecinds.size(), vtree.o(0).size()))
            tsomarec.append(tsoma.c)
            for i=0, vtreerecinds.size()-1 {
                vtreerec.o(vtreerec.count()-1).setrow(i, vtree.o(vtreerecinds.x[i]))
            }
        }
        
        apc.x(n) = record_apc.n 
        
        // Tidy up the recordings
        record_remove(casyn)
        record_remove(vsyn)
        record_remove(icasyn)
        record_remove(ica_nmdasyn)
        record_remove(ica_ampasyn)
        
        // Adjust the synaptic weights depending on peak calcium value
        // for the next simulation
        run_quantities_at_syn_to_quantity_at_seg_ind_run(casrimax_mean_indrun, casynmax, syn_sri, n) 
        for i=0, casrimax_mean_indrun.size()-1 {
            print casrimax_mean_indrun.x[i]
            if (casrimax_mean_indrun.x[i]==0) {
                // This segment is not stimulated in this run or there
                // is a subthreshold response, so leave the synaptic
                // strength the same.
                synweightmatrix.x[n+1][i] = synweightmatrix.x[n][i]  
                
            } else {
                // This segment has been stimulated, so adjust the
                // synaptic strength according to peak calcium: the
                // weight changes by 10% of the relative deviation of
                // the peak calcium from targetCa
                synweightmatrix.x[n+1][i] = synweightmatrix.x[n][i]*(1 + 0.1*(targetCa-casrimax_mean_indrun.x[i])/targetCa)
            }
        }
    }
    
    // Compute means and standard errors of quantities in the tree
    print "Computing means and standard errors..."
    utils.mean(vtreemax_mean, vtreemax)
    utils.mean(vtreeint_mean, vtreeint)
    utils.mean(catreemax_mean, catreemax)
    utils.stderr(vtreemax_stderr,  vtreemax)
    // The standard error of the voltage integral comes from vtreeint
    // (this used to be computed from vtreemax by mistake)
    utils.stderr(vtreeint_stderr,  vtreeint)
    utils.stderr(catreemax_stderr, catreemax)
    
    // For calcium peak and voltage delay at synapses, we need to
    // compute the mean and stderr of the maximum averaged over all
    // spines at a particular segment. This means we need to convert
    // from spine indices to segment indices.
    run_quantities_at_syn_to_quantity_at_seg(vsrimax_mean,  vsrimax_stderr,  vsynmax,  syn_sri) 
    run_quantities_at_syn_to_quantity_at_seg(vsriint_mean,  vsriint_stderr,  vsynint,  syn_sri) 
    run_quantities_at_syn_to_quantity_at_seg(vsridel_mean,  vsridel_stderr,  vsyndel,  syn_sri) 
    run_quantities_at_syn_to_quantity_at_seg(casrimax_mean, casrimax_stderr, casynmax, syn_sri) 
    run_quantities_at_syn_to_quantity_at_seg(casriint_mean, casriint_stderr, casynint, syn_sri) 
    run_quantities_at_syn_to_quantity_at_seg(casridel_mean, casridel_stderr, casyndel, syn_sri) 
    
    // Output the last time course of various quantities just at the
    // screc synapses
    screc_icasyn      = icasyn
    screc_ica_nmdasyn = ica_nmdasyn
    screc_ica_ampasyn = ica_ampasyn
    
    // We already have vsyn - we just want to get the first few items
    // from it
    screc_vsyn       = new List()
    for i=0, screc_inputs.count()-1 {
        screc_vsyn.append(vsyn.object(i))
    }
    
    // We already have casyn - we just want to get the first few items
    // from it
    screc_casyn       = new List()
    for i=0, screc_inputs.count()-1 {
        screc_casyn.append(casyn.object(i))
    }
    
    // Copies of the tree calcium traces at the segments where the
    // screc synapses were on the first run (row 0 of syn_sri)
    screc_catree      = new List()
    for i=0, screc_inputs.count()-1 {
        screc_catree.append(catree.object(syn_sri.x[0][i]).c)
    }
    
    // Tidy up the recordings
    record_remove(vtree)
    record_remove(catree)
    
    // For calcium peak and voltage delay at synapses, we need to
    // compute the mean and stderr of the maximum averaged over all
    // spines at a particular segment per run. This means we need to
    // convert from spine indices to segment indices.
    run_quantities_at_syn_to_quantity_at_seg_per_run(vsridel_mean_perrun,  vsridel_stderr_perrun,  vsyndel,  syn_sri) 
    run_quantities_at_syn_to_quantity_at_seg_per_run(casrimax_mean_perrun, casrimax_stderr_perrun, casynmax, syn_sri) 
    run_quantities_at_syn_to_quantity_at_seg_per_run(casriint_mean_perrun, casriint_stderr_perrun, casynint, syn_sri) 
    run_quantities_at_syn_to_quantity_at_seg_per_run(casridel_mean_perrun, casridel_stderr_perrun, casyndel, syn_sri) 
    
    // Save vsynmax
    savevsynmax = new File()
    savevsynmax.wopen("datastore/vsynmax.dat")
    vsynmax.fprint(0, savevsynmax, "%-20g")
    savevsynmax.close()
    
    // Save casrimax_mean_perrun
    savecasrimax = new File()
    savecasrimax.wopen("datastore/casrimax.dat")
    casrimax_mean_perrun.fprint(0, savecasrimax, "%-20g")
    savecasrimax.close()  
    
    // Save the synaptic weights to datastore/synweight.dat
    savesynweight = new File()
    savesynweight.wopen("datastore/synweight.dat")
    synweightmatrix.fprint(0, savesynweight, "%-20g")
    savesynweight.close()
    
    // Save distances
    savedistances = new File()
    savedistances.wopen("datastore/distances.dat")
    distances.printf(savedistances," %-20g")
    savedistances.close() 
    
    print "Done."
}

// Same procedure as run_quantities_at_syn_to_quantity_at_seg(), but
// averaged per run (timestep in scaling simulation), instead of
// averaging over all runs (as for other simulations).
//
// Arguments:
//   $o1  output: set to a Matrix (nruns rows, one column per segref)
//        holding the mean of the quantity over the synapses at each
//        segref on each run
//   $o2  output: set to a Matrix of the same size holding the
//        corresponding standard errors
//   $o3  Matrix of the quantity; row n holds the values for the
//        synapses on run n
//   $o4  Matrix in which row n holds the segref index of each
//        synapse on run n
//
// Entries of $o1 for which no synapse was at the segref, and entries
// of $o2 with fewer than two synapses, are left at the value they
// have in a newly created Matrix.
//
// The loop variables i and n are declared local so that calling this
// procedure does not overwrite the global i and n.
proc run_quantities_at_syn_to_quantity_at_seg_per_run() { local i, n localobj quant_seg_mean, quant_seg_stderr, quant, syn_sri, quant_seg, syninds
    quant_seg_mean =   new Matrix(nruns,segreflist.srl.count())
    quant_seg_stderr = new Matrix(nruns,segreflist.srl.count())
    quant = $o3
    syn_sri = $o4
    syninds = new Vector()
    
    // Go through each segref
    for i=0, segreflist.srl.count()-1 {
        // Go through each run
        for n=0, nruns - 1 {
            quant_seg = new Vector()
            // Find inds of synapses which were inserted in segref i on this run
            syninds.indvwhere(syn_sri.getrow(n), "==", i) 
            // Pick out the values of the quantity at those synapses
            quant_seg.append(quant.getrow(n).ind(syninds))

            // A mean needs at least one value
            if (quant_seg.size() > 0) {
                quant_seg_mean.x[n][i] = quant_seg.mean()
            }
            // A standard error needs at least two values
            if (quant_seg.size() > 1) {
                quant_seg_stderr.x[n][i] = quant_seg.stderr()
            }
        }
    }
    $o1 = quant_seg_mean
    $o2 = quant_seg_stderr
}

// Compute, for one individual run, the mean of a quantity over the
// synapses at each segref.
//
// Arguments:
//   $o1  output: set to a Vector with one element per segref holding
//        the mean of the quantity over the synapses at that segref
//        on run $4
//   $o2  Matrix of the quantity; row n holds the values for the
//        synapses on run n
//   $o3  Matrix in which row n holds the segref index of each
//        synapse on run n
//   $4   index of the run to analyse
//
// Elements for segrefs without any synapse on this run are left at
// the value they have in a newly created Vector.
//
// i and n are declared local so that calling this procedure does not
// overwrite the global i and n.
proc run_quantities_at_syn_to_quantity_at_seg_ind_run() { local i, n localobj quant_seg_mean_indrun, quant, syn_sri, quant_seg, syninds
    quant_seg_mean_indrun =   new Vector(segreflist.srl.count())
    quant = $o2
    syn_sri = $o3
    n = $4
    syninds = new Vector()
    
    // Go through each segref
    for i=0, segreflist.srl.count()-1 {
        quant_seg = new Vector()
        // Find inds of synapses which were inserted in segref i on this run
        syninds.indvwhere(syn_sri.getrow(n), "==", i) 
        // Pick out the values of the quantity at those synapses
        quant_seg.append(quant.getrow(n).ind(syninds))
        
        // A mean needs at least one value
        if (quant_seg.size() > 0) {
            quant_seg_mean_indrun.x[i] = quant_seg.mean()
        }
    }
    $o1 = quant_seg_mean_indrun
}