/*
error is weighted sum of normalized error in each region. Region i
has weight[i] and domain boundary[i] < x < boundary[i+1].
Normalized region error is an approximation to the integral of
(y(x) - ydat(x))^2 over the region, divided by the width of that region.

The boundary and weight vectors are the only substantive variables owned
by this class.
Usage:
1) if using a gui, map a vbox by calling map(). The Graph is in the public
variable g of this instance.
Note that this graph menu contains some class dependent functionality.
Although you may attach and detach this graph to the standard run system
and graph some variable or expression, it is faster and probably
better to plot the y model vector and flush the graph after a call to efun.
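2) specify xdat and ydat by calling set_data(x, y)
3) specify boundaries and weights either by gui or by assigning the public
boundary and weight vectors and then calling set_w() (the set_region(boundary,
weight) call mentioned in older notes is not defined in this template).
weight[1] is the weight for the region between boundary[0] and boundary[1].
Therefore weight[0] is unused.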
*/

parmfitness_efun_append("RegionFitness")

begintemplate RegionFitness
// Names reachable from outside the template: the error function, the data and
// region vectors, gui handling, file and context save/restore, and helpers.
public efun, g, boundary, weight, set_data, set_modelx, rfile, wfile, dw_,ydat_
public unmap, map, vbox, use_x, tag, xvec, yvec, xdat, ydat, xdat_, have_data
public clone, scale, build, get_stable_context, before_run, mintstop
public save_context, restore_context
public clipboard_data, set_w
// last value returned by efun (scale times the weighted mean square error)
public error

// top-level names used inside the template:
// stdrun_quiet is tested in efun before redrawing; hoc_obj_ is used as the
// Vector clipboard (see copy_weights, paste_weights, clipboard_data).
external stdrun_quiet, hoc_obj_

// region boundaries and per-region weights (weight.x[0] is unused)
objref boundary, weight

objref xdat, ydat	// authoritative data
objref xvec, yvec, xdat_, ydat_, dw_ // interpolated data consistent with model x values
// g: Graph, vbox: enclosing VBox, tobj: scratch object, this: self reference
objref g, vbox, tobj, this
// tstr1: scratch string; mserrlabel/scalelabel: text shown in the panel;
// tag: class name of this instance without the "[index]" suffix (see init)
strdef tstr1, mserrlabel, scalelabel, tag

i=0

// Hook with an empty body: this fitness function does nothing before a run.
proc before_run() {}

// The least squares error function.
// $o1: model y vector
// $o2: model x vector (only read when use_x is not 1)
// Returns scale times the weighted mean square error between ydat_ and the
// model, or 0 when no data is available. The same value is stored in the
// public variable "error".
// If xdat_ does not correspond to the $o1 x values then ydat_, xdat_ and dw_
// should be re-interpolated/re-calculated before calling.
func efun() { local mse
	mse = 0
	if (have_data && ydat_.size > 0) {
		if (use_x != 1) {
			// NOTE(review): these two assignments overwrite the Vector
			// clipboard (hoc_obj_[0] and hoc_obj_[1]) on every evaluation;
			// looks like debugging residue -- confirm before removing.
			hoc_obj_ = $o1.c
			hoc_obj_[1] = $o2.c.interpolate(xdat_,$o2)
			// NOTE(review): interpolate is called on $o1 itself (not on a
			// copy), so the caller's vector is presumably resampled onto
			// xdat_ in place; redraw below plots $o1 against xdat_.
			mse = ydat_.meansqerr($o1.interpolate(xdat_,$o2), dw_)
		}else{
			// model y is taken as already sampled at the xdat_ points
			mse = ydat_.meansqerr($o1, dw_)
		}
	}
	if (use_gui) {
		sprint(mserrlabel, "%g", mse)
		if (stdrun_quiet == 0) {
			redraw(xdat_, $o1)
		}
	}

	error = scale*mse

	return error
}

// Pack the state of this instance into $o1 (an object providing pack();
// presumably a ParallelContext -- confirm against the caller).
// Order: tag, scale, ydat, ydat's label, xdat, boundary, weight.
// restore_context() must unpack in exactly this order.
proc save_context() {
	$o1.pack(tag, scale, ydat, ydat.label, xdat, boundary, weight)
}

// Rebuild the state of this instance from $o1, reading the items in the
// order written by save_context().
proc restore_context() {
	$o1.unpack(tag, &scale, ydat)
	// the label travels as a separate string and is re-attached to ydat
	$o1.unpack(tstr1, xdat)
	ydat.label(tstr1)
	// set_data re-initializes the regions, so boundary and weight are
	// unpacked only afterwards
	set_data(xdat, ydat)
	$o1.unpack(boundary, weight)
	// recompute the per-point weights for the restored regions
	set_w()
}

// Constructor: start with no data, no gui, scale 1 and a single region
// of weight 1.
proc init() {
	// empty data and model vectors
	xdat = new Vector(0)
	ydat = new Vector(0)
	xvec = new Vector(0)
	yvec = new Vector(0)
	// region vectors; init_weights below resizes them
	boundary = new Vector(10)
	weight = new Vector(10)

	// state flags
	have_data = 0
	have_modelx = 0
	use_gui = 0
	use_x = 0
	scale=1

	// tag is the printed name of this object with the "[index]" part removed
	sprint(tag, "%s", this)
	sscanf(tag, "%[^[]", tag)

	init_weights(1)
}

// Create a copy of this fitness object in $o1 (an objref supplied by the
// caller). The copy gets its own copies of the data, the region boundaries,
// the region weights and the scale.
proc clone() {
	$o1 = new RegionFitness()
	$o1.have_data = have_data
	if (have_data) {
		// set_data() always re-initializes the regions to a single default
		// region (it calls init_weights(1)), so it has to run BEFORE the
		// region vectors are copied; otherwise the copied boundaries and
		// weights are overwritten and the clone loses its regions.
		$o1.set_data(xdat, ydat)
	}
	$o1.boundary = boundary.c
	$o1.weight = weight.c
	$o1.scale = scale
	// recompute the clone's per-point weights for the copied regions
	$o1.set_w()
	sprint(scalelabel, "scale=%g ", scale)
}

// Redraw the graph: the data curve, one vertical line per region boundary
// and, when called as redraw(x, y), the model curve y plotted against x.
// Does nothing when the gui has not been built.
proc redraw() {
	if (!use_gui) return
	g.erase()
	if (have_data) {
		g.label(.8, .95)
		ydat.plot(g, xdat, 2, 1)
		// vertical boundary markers spanning g.size(3) .. g.size(4)
		j = 0
		while (j < boundary.size()) {
			g.beginline(3, 1)
			g.line(boundary.x[j], g.size(3))
			g.line(boundary.x[j], g.size(4))
			j += 1
		}
		if (numarg() == 2) {
			$o2.line(g, $o1)
		}
	}
	g.flush()
}

// Returns the last element of xdat_, the x values of the data points that
// carry weight (see set_w).
func mintstop() {local last
	last = xdat_.size - 1
	return xdat_.x[last]
}

// Install new data. $o1: x values, $o2: y values.
// Both are copied (the y copy is made with cl()). Every call resets the
// regions to a single region (init_weights(1)) and recomputes the
// per-point weights.
proc set_data() {
	have_data = 0
	xdat = $o1.c()
	ydat = $o2.cl()
	// data is now present; start over with one region
	have_data = 1
	init_weights(1)

	// working copies used by efun; set_w() below prunes them
	xdat_ = xdat.c
	ydat_ = ydat.c
	dw_ = new Vector(xdat.size)
	// xvec and yvec refer to the same Vector objects as xdat and ydat
	xvec = xdat
	yvec = ydat
	have_modelx = 1

	if (use_gui) {
		g.size(xdat.min(), xdat.max(), ydat.min(), ydat.max())
		redraw()
	}
	set_w()
}
// Use $o1 as the model x values and linearly interpolate the data (xdat, ydat)
// onto them, storing the result in ydat_. xdat_ refers to $o1 itself (no copy).
// Model points outside the range of xdat keep the value 0.
// NOTE(review): the set_w() call at the end re-derives xdat_ and ydat_ from
// xdat/ydat via index(), so confirm that callers still need the interpolated
// values computed here.
proc set_modelx() {local n, i, j, dx, theta
	xdat_ = $o1
	have_modelx = 1
	n = xdat_.size
	ydat_ = new Vector(n)
	j = 0
	// model points before and after the data range will have weight 0 anyway.
	// Guard against empty vectors: indexing element 0 would be an error.
	if (n > 0 && xdat.size > 0) {
		while (xdat_.x[j] < xdat.x[0]) {
			j += 1
			if (j >= n) break
		}
	}
	for i=1, xdat.size-1 {
		if (j >= n) break
		// width of the data interval [xdat.x[i-1], xdat.x[i]]
		dx = xdat.x[i] - xdat.x[i-1]
		while (xdat_.x[j] >= xdat.x[i-1] && xdat_.x[j] <= xdat.x[i]){
			// linear interpolation: theta is the fractional position of the
			// model point inside the DATA interval, so the offset is taken
			// from xdat.x[i-1] (the data x), not from the model x vector.
			if (dx > 0) {
				theta = (xdat_.x[j] - xdat.x[i-1])/dx
			}else{
				// repeated data x value: avoid dividing by zero
				theta = 0
			}
			ydat_.x[j] = (1-theta)*ydat.x[i-1] + theta*ydat.x[i]
			j += 1
			if (j >= n) break
		}
	}
	set_w()
}
	
// Take the data from the Vector clipboard: hoc_obj_[1] is passed as x and
// hoc_obj_[0] as y to set_data. The call goes through execute1 so that a
// failure can be reported with a dialog.
proc clipboard_data() {local ok
	sprint(tstr1, "%s.set_data(hoc_obj_[1], hoc_obj_[0])", this)
	ok = execute1(tstr1)
	if (ok == 0) {
		continue_dialog("No data in the Vector clipboard. Select a Graph line first")
	}
}

// Build the gui once: a VBox holding a panel (Regions menu, scale label and
// mean square error label) and the Graph g. Later calls do nothing.
proc build() {
	if (!use_gui) {
		use_gui = 1
		vbox = new VBox()
		vbox.ref(this)
		// closing the window calls this instance's unmap()
		sprint(tstr1, "execute(\"%s.unmap()\")", this)
		vbox.dismiss_action(tstr1)
		// NOTE(review): no save() procedure is visible in this template --
		// confirm that session saving works.
		vbox.save("save()")

		// everything created while intercept is on goes into the vbox
		vbox.intercept(1)
		g = new Graph(0)
		xpanel("", 1)
		g.menu_tool("Adjust", "adjust_weights")
		xmenu("Regions")
			xbutton("Data from Clipboard", "clipboard_data()")
			xbutton("Weight panel", "weight_panel()")
			xbutton("Copy weights", "copy_weights()")
			xbutton("Paste weights", "paste_weights()")
		xmenu()
		// placeholder text; efun overwrites mserrlabel with the error value
		mserrlabel="MeanSqErr xxxxxxxxxxx"
		sprint(scalelabel, "scale=   %g  ", scale)
		xvarlabel(scalelabel)
		xvarlabel(mserrlabel)
		xpanel()
		g.view(0, -80, 5, 40, 0,300,0,200)
		if (have_data) {
			g.size(xdat.min(), xdat.max(), ydat.min(), ydat.max())
		}
		vbox.intercept(0)
	}
}

// Map the vbox on the screen, building the gui first if necessary.
// map()                     : default placement
// map(title)                : $s1 is passed on to vbox.map
// map(title, $2, $3, $4, $5): all five arguments are passed on to vbox.map
proc map() {local nargs
	if (!use_gui) build()
	nargs = numarg()
	if (nargs == 0) {
		vbox.map()
	}else if (nargs == 1) {
		vbox.map($s1)
	}else{
		vbox.map($s1, $2, $3, $4, $5)
	}
	redraw()
}

// Called by the dismiss action installed in build(). The body is empty, so
// nothing is done here when the window is closed.
proc unmap() {
}

// Put copies of the region vectors on the Vector clipboard:
// boundaries in hoc_obj_[1], weights in hoc_obj_[0].
proc copy_weights() {
	hoc_obj_[1] = boundary.c
	hoc_obj_[0] = weight.c
}

// Replace the region vectors by copies of the Vector clipboard contents
// (boundaries from hoc_obj_[1], weights from hoc_obj_[0]) and recompute
// the per-point weights.
proc paste_weights() {
	boundary = hoc_obj_[1].c
	weight = hoc_obj_[0].c
	set_w()
}

// Write this instance to the file object $o1: a header line with a line
// count and the scale, the ydat label between '|' characters, the number of
// data points followed by xdat and ydat, then the number of boundaries
// followed by boundary and weight.
proc wfile() {local nlines
	nlines = 4 + 2*xdat.size + 2*boundary.size
	$o1.printf("RegionFitness xdat ydat boundary weight (lines=%d) %g\n", nlines, scale)
	$o1.printf("|%s|\n", ydat.label)
	// data block
	$o1.printf("%d\n", xdat.size)
	xdat.printf($o1)
	ydat.printf($o1)
	// region block
	$o1.printf("%d\n", boundary.size)
	boundary.printf($o1)
	weight.printf($o1)
}

// Read this instance from the file object $o1, in the layout written by
// wfile: scale, label line, data count with xdat and ydat, then boundary
// count with boundary and weight.
proc rfile() {local npts, nbound
	scale = $o1.scanvar
	sprint(scalelabel, "scale=%g", scale)

	// label is stored between '|' characters
	$o1.gets(tstr1)
	if (sscanf(tstr1, "|%[^|]", tstr1) == 1) {
		ydat.label(tstr1)
	}

	npts = $o1.scanvar
	if (npts > 0) {
		xdat.resize(npts)
		ydat.resize(npts)
		xdat.scanf($o1, npts)
		ydat.scanf($o1, npts)
		set_data(xdat, ydat)
	}

	// regions are read after set_data, which resets them
	nbound = $o1.scanvar
	boundary.resize(nbound)
	boundary.scanf($o1, nbound)
	weight.resize(nbound)
	weight.scanf($o1, nbound)
	set_w()
}

// Set the number of regions and, when data is present, spread the
// boundaries evenly over the data range.
// init_weights()   : one region
// init_weights(0)  : one more region than at present
// init_weights(-1) : one fewer region than at present (at least one)
// init_weights(n)  : n regions
// Only the weight of the last region is set to 1; other entries keep
// whatever value the resize leaves in them.
proc init_weights() {local nregion, lo, hi
	nregion = 1
	if (numarg() == 1) {
		nregion = $1
		if ($1 == 0) {
			// boundary holds one more element than there are regions
			nregion = boundary.size
		}
		if ($1 == -1) {
			nregion = boundary.size - 2
			if (nregion < 1) nregion = 1
		}
	}
	boundary.resize(nregion + 1)
	weight.resize(nregion + 1)
	weight.x[nregion] = 1
	if (have_data) {
		lo = xdat.x[0]
		hi = xdat.x[xdat.size-1]
		boundary.indgen(lo, (hi - lo)/nregion)
		set_w()
	}
}
	

// Pop up a panel with field editors for the scale, the start point of the
// first interval and, for every interval, its end point and weight.
// Changing a boundary or weight calls set_w().
proc weight_panel() {local k
	xpanel("data weights")
	xpvalue("Total weight (scale)", &scale, 1, "sprint(scalelabel, \"scale=%g \", scale)")
	xpvalue("interval 1 startpoint", &boundary.x[0], 1, "set_w()")
	k = 1
	while (k <= boundary.size() - 1) {
		sprint(tstr1, "interval %d endpoint", k)
		xpvalue(tstr1, &boundary.x[k], 1, "set_w()")
		sprint(tstr1, "interval %d weight", k)
		xpvalue(tstr1, &weight.x[k], 1, "set_w()")
		k += 1
	}
	xpanel()
}

// Recompute the per-point weight vector dw_ from boundary and weight, and
// rebuild xdat_ and ydat_ so that they hold only the data points with a
// positive weight. Finishes with a redraw().
// Side effect: boundary values are modified in place (clamped to the data
// range and moved between data points, see below).
proc set_w() {local i, j, t, w, d, tmin, tmax, n, res
    if (have_data){
	// make sure weight regions are within boundaries
	tmin = xdat.x[0]
	tmax = xdat.x[xdat.size - 1]
	n = boundary.size()
	for i=0, n-1 {
		t = boundary.x[i]
		if (t < tmin) {
			boundary.x[i] = tmin
		}
		if (t > tmax) {
			boundary.x[i] = tmax
		}
	}
    }
    // NOTE(review): n is assigned only in the have_data branch above but is
    // used below -- confirm have_modelx is never set without have_data.
    if (have_modelx) {
	dw_.resize(xdat.size)
	// j counts the boundaries passed so far; tmax is the next boundary to
	// cross; w is the weight density of the current region (0 before the
	// first boundary)
	j = 0
	tmax = boundary.x[0]
	w = 0
	for i=0, dw_.size - 1 {
		t = xdat.x[i]
		// advance while the data point lies beyond the next boundary
		// (or sits exactly on it for the very first point)
		while ((t > tmax && j < n) || (t == tmax && i == 0)) {
			j += 1
			// move the boundary just crossed to the midpoint between this
			// data point and the previous one (or onto the first point)
			if (i > 0) {
				boundary.x[j-1] = (t + xdat.x[i-1])/2
			}else{
				boundary.x[j-1] = t
			}
			// past the last boundary: all remaining points get weight 0
			if (j >= n) {
				tmax = 1e9
				w = 0
				break
			}
			tmax = boundary.x[j]
			d = boundary.x[j] - boundary.x[j-1]
			// zero-width region: keep the previous w and test again
			if (d <= 0) {
				continue
			}
			// region weight divided by region width
			w = weight.x[j]/d
		}
		dw_.x[i] = w
		//printf("dw_.x[%d]=%f\n",i,dw_.x[i])
	}
	// throw out the 0 weight points.
	tobj = dw_.c.indvwhere(">", 0)
	dw_.index(dw_.c, tobj)
	xdat_.index(xdat, tobj)
	ydat_.index(ydat, tobj)
	// scale the weights so that they sum to the number of remaining points
	w = dw_.sum()
	if (w > 0) {
		dw_.mul(dw_.size/w)
	}
    }
	redraw()
}

// Mouse handler for the "Adjust" graph tool (installed in build()).
// $1: event type (2 press, 1 drag, 3 release)
// $2: x position of the mouse, used here as a boundary value
// press   : select or create a boundary (see pick_weight); its index is
//           kept in "adjust"
// drag    : move the selected boundary, not beyond the next one
// release : when more than two boundaries exist, remove the first or last
//           boundary if the mouse was released left or right of the view,
//           or remove the selected boundary if it was dropped onto another
//           one; then sort the boundaries and recompute the weights.
proc adjust_weights() {local x
//print $1, $2, $3
	if ($1 == 2) { // press
		adjust = pick_weight($2)
		set_w()
	}
	if (adjust == -1) {
		return
	}
	if ($1 == 1) { // drag
		boundary.x[adjust] = $2
		// nested ifs: both operands of && would be evaluated, and
		// boundary.x[adjust+1] does not exist for the last boundary
		if (adjust < boundary.size-1) if ($2 > boundary.x[adjust+1]){
			boundary.x[adjust] = boundary.x[adjust+1]
		}
		set_w()
	}
	if ($1 == 3) { // release
		// position of the release point relative to the view; values below 0
		// or above 1 are treated as outside the view
		x = g.view_info(g.view_info, 11, $2)
		if (boundary.size > 2) {
			if (x < 0) {
				boundary.remove(0)
				weight.remove(1)
			}else if (x > 1) {
				boundary.remove(boundary.size-1)
				weight.remove(weight.size-1)
			}else{
				x = boundary.x[adjust]
				tobj = boundary.c.indvwhere("==", x)
				if (tobj.size > 1) {
					// the selected boundary coincides with another one:
					// merge them. remove is a method and must be called
					// with parentheses (the original used [adjust]).
					boundary.remove(adjust)
					if (tobj.x[0] == adjust) {
						weight.remove(adjust+1)
					}else{
						weight.remove(adjust)
					}
				}
			}
		}
		boundary.sort()
		set_w()
		adjust = -1
//print boundary.size, " boundaries"
//print "boundary" boundary.printf
//print "weight" weight.printf
	}
}

// Find the boundary selected by a mouse press at x value $1, creating a new
// boundary when none is close enough. Returns the index of the selected or
// newly created boundary.
// Distances are compared in the units returned by g.view_info(view, 13, x)
// (presumably screen points -- confirm); 12 is the pick tolerance.
func pick_weight() {local view, px, dist, k, wi
	view = g.view_info
	px = g.view_info(view, 13, $1)
	// stop at the first boundary that is not clearly left of the mouse
	for k=0, boundary.size() - 1 {
		dist = px - g.view_info(view, 13, boundary.x[k])
		if (dist < 12) { // before or on k
			break
		}
	}
	if (k == boundary.size()) {
		// mouse is right of every boundary: add a new last boundary that
		// inherits the weight of the previous last region
		boundary.append($1)
		weight.append(weight.x[k-1])
		return k
	}
	if (dist <= -12) {
		// clearly left of boundary k: insert a new boundary in front of it,
		// duplicating the weight of the region being split
		boundary.insrt(k, $1)
		wi = k
		if (wi == 0) {
			// weight.x[0] is unused, so the first real weight is at index 1
			wi = 1
		}
		weight.insrt(wi, weight.x[wi])
	}
	// otherwise the existing boundary k itself was picked
	return k
}

endtemplate RegionFitness