: $Id: matrix.mod,v 1.25 2000/12/15 20:09:51 billl Exp $
:* COMMENT
COMMENT
NB: only minimal error checking done
NB: no temporary storage is allocated, so in-place calls give wrong results - eg m1.transpose(m1,rows,cols)
NB: matrix and vec sizes must be correct before using: use .resize()
================ USAGE ================
objref mat
mat = new Vector(rows*cols)
mat.mprintf(M,N) // print out as M rows and N columns
mat2.transpose(mat,rows,cols) // mat2 = transpose of the rows x cols matrix mat
y.mmult(mat,x) // y = mat*x
y.spmult(pre,post,wt,x[,flag]) // y = W*x, W a sparse matrix given as pre/post/wt vecs; with flag, add into y instead of clearing it first
w.spget(pre,post,row,col) // returns the weight; NB row is matched against post and col against pre
wt.mkspcp/chkspcp(pre,post) // copy the indices into integer arrays
mat.outprod(x,y) // mat = outer product of vectors x and y
mat.mget(i,j,cols) // i=row#; j=col#
mat.mset(i,j,cols,val)
y.mrow(mat,i,cols)
y.mcol(mat,j,cols)
================================================================
ENDCOMMENT
NEURON {
: NOTE(review): "SUFFIX nothing" presumably tells nocmodl not to create an
: insertable density mechanism, so this file only contributes the
: install_matrix() procedure and its VERBATIM C helpers -- confirm against
: the NMODL documentation.
SUFFIX nothing
}
VERBATIM
#ifndef NRN_VERSION_GTEQ_8_2_0
/* hoc_call_func is used by spltp() below. It is declared here only when the
 * NEURON version macro for >= 8.2 is absent, presumably because newer
 * NEURON headers already provide the declaration -- confirm. */
extern double hoc_call_func(Symbol*, int narg);
#endif
ENDVERBATIM
:* mat.outprod(x,y) // mat = outer product of vectors x and y
VERBATIM
/* mat.outprod(a, b): fill the instance vector with the outer product of a and b,
 * stored row-major as a size(a) x size(b) matrix:
 *     mat[r*size(b) + c] = a[r] * b[c]
 * The instance vector must already hold exactly size(a)*size(b) elements.
 * Returns the number of elements in the instance vector. */
static double outprod(void* vv) {
  double *out, *rowv, *colv;
  int nout = vector_instance_px(vv, &out);
  int nrows = vector_arg_px(1, &rowv);  /* first argument indexes the rows */
  int ncols = vector_arg_px(2, &colv);  /* second argument indexes the columns */

  if (nout != nrows*ncols) {
    hoc_execerror("Vector size mismatch", 0);
  }

  /* walk the destination sequentially; row-major order matches the loop nest */
  double *cell = out;
  for (int r = 0; r < nrows; r++) {
    for (int c = 0; c < ncols; c++) {
      *cell++ = rowv[r] * colv[c];
    }
  }
  return nout;
}
ENDVERBATIM
:* mmult
VERBATIM
/* y.mmult(mat, v): y = mat * v, with mat stored row-major as a
 * size(y) x size(v) matrix. Sizes must already agree
 * (size(mat) == size(y)*size(v)). No temporary is used, so y must not be
 * the same vector as v.
 * Returns size(y). */
static double mmult(void* vv) {
  double *out, *m, *v;
  int nrows = vector_instance_px(vv, &out);
  int nm = vector_arg_px(1, &m);
  int ncols = vector_arg_px(2, &v);

  if (nm != nrows*ncols) {
    hoc_execerror("Vector size mismatch", 0);
  }

  const double *rowp = m;  /* start of the current matrix row */
  for (int r = 0; r < nrows; r++, rowp += ncols) {
    out[r] = 0.;
    for (int c = 0; c < ncols; c++) {
      out[r] += rowp[c]*v[c];
    }
  }
  return nrows;
}
ENDVERBATIM
:* ST[PO].spltp(pr,po,wt,ST[PRE])
VERBATIM
/* ST[PO].spltp(pr, po, wt, ST[PRE]): pass selected sparse-matrix weights
 * through the hoc function ltp().
 *   instance vector - one flag per postsynaptic cell; 1.0 selects that cell
 *   pr, po, wt      - sparse matrix as parallel vectors of presynaptic index,
 *                     postsynaptic index and weight; they must have the same
 *                     length, and po is assumed sorted ascending (the single
 *                     forward scan below relies on it)
 *   ST[PRE] (arg 4) - one flag per presynaptic cell; 1.0 means it spiked
 * For every entry whose postsynaptic cell is selected, ltp(dir, w) is called
 * with dir = +1.0 if the presynaptic cell is flagged and -1.0 otherwise, and
 * the weight is replaced by ltp's return value.
 * Returns (number of +1 calls) - (number of -1 calls). */
static double spltp(void* vv) {
  int ii, jj, nstpr, nstpo, nw, npr, npo, cnt, pre;
  double *stpr, *stpo, *w, *pr, *po;
  /* writable buffer: older hoc_lookup prototypes take a non-const char* */
  char func[4] = "ltp";
  Symbol* s = hoc_lookup(func);
  if (! s) { hoc_execerror("Can't find ltp() func", 0); }
  nstpo = vector_instance_px(vv, &stpo);
  npr = vector_arg_px(1, &pr);
  npo = vector_arg_px(2, &po);
  nw = vector_arg_px(3, &w);
  nstpr = vector_arg_px(4, &stpr);
  /* same contract as spmult(); pr[jj] and w[jj] are indexed with po's index */
  if (nw!=npr || nw!=npo) {
    hoc_execerror("Sparse mat must have 3 identical size vecs for pre/post/wt", 0);
  }
  for (ii=0,jj=0,cnt=0;ii<nstpo;ii++) {
    if (stpo[ii]==1.0) { /* connections to these will be changed */
      /* move forward to the first entry for cell ii; both scans are bounded
       * by npo so a run ending at the last entry cannot read past po */
      while (jj<npo && po[jj]<ii) jj++;
      for (; jj<npo && po[jj]==ii; jj++) { /* move through these po's */
        pre = (int)pr[jj];
        if (pre<0 || pre>=nstpr) {
          hoc_execerror("spltp: presynaptic index out of range", 0);
        }
        if (stpr[pre]==1.) { /* did the presyn spike? */
          cnt++; hoc_pushx(1.0);
        } else {
          cnt--; hoc_pushx(-1.0);
        }
        hoc_pushx(w[jj]);
        w[jj]=hoc_call_func(s, 2);
      }
    }
  }
  return cnt;
}
ENDVERBATIM
VERBATIM
/* Maintain a parallel vector of ints to avoid the slowness of repeated casts in spmult */
/* pr_int / po_int: int copies of the pre/post index vectors. They are built
 * by mkspcp(pr,po) and released either by mkspcp() with no arguments or by
 * chkspcp() when the copies no longer match. There is a single cache shared
 * by all vectors. As file-scope statics they start out NULL. */
static int *pr_int;
static int *po_int;
/* cpfl: number of cached entries. 0 means "no cache", and spmult() then casts
 * the double indices directly. */
static int cpfl=0;
ENDVERBATIM
:* wt.mkspcp(pr,po)
VERBATIM
/* Release the cached int index copies and mark the cache empty so spmult()
 * falls back to casting the double indices. free(NULL) is a no-op. */
static void free_spcp(void) {
  cpfl=0;
  free(po_int); free(pr_int);
  po_int=(int *)NULL; pr_int=(int *)NULL;
}

/* wt.mkspcp(pr, po): build int copies of the sparse-matrix pre/post index
 * vectors so spmult() can skip a double->int cast per entry.
 * wt.mkspcp() with no arguments just discards the cached copies.
 * pr and po must have the same length as wt (the same contract spmult()
 * enforces). Any previously cached copies are released first.
 * Returns the number of cached entries (0 when clearing). */
static double mkspcp(void* vv) {
  int j, nw, npr, npo;
  double *w, *pr, *po;
  if (! ifarg(1)) {
    free_spcp();
    return 0;
  }
  nw = vector_instance_px(vv, &w);
  npr = vector_arg_px(1, &pr);
  npo = vector_arg_px(2, &po);
  /* the copy loop below reads nw entries from pr and po */
  if (npr!=nw || npo!=nw) {
    hoc_execerror("Sparse mat must have 3 identical size vecs for pre/post/wt", 0);
  }
  /* drop any earlier copy, otherwise a repeated call leaks it */
  free_spcp();
  /* explicit casts keep this valid if the VERBATIM code is compiled as C++ */
  pr_int=(int *)ecalloc(nw, sizeof *pr_int);
  po_int=(int *)ecalloc(nw, sizeof *po_int);
  for (j=0;j<nw;j++) {
    po_int[j]=(int)po[j];
    pr_int[j]=(int)pr[j];
  }
  cpfl=nw;
  return cpfl;
}
ENDVERBATIM
:* wt.chkspcp(pr,po)
VERBATIM
/* wt.chkspcp(pr, po): check whether the int index copies built by mkspcp()
 * still match pr/po for this weight vector.
 * Returns 1 if they match. Returns 0 if there is no cache, if the cached size
 * differs from size(wt), if pr/po are shorter than wt, or if any index
 * differs. In the mismatch cases the stale copies are freed and cpfl reset,
 * so spmult() falls back to casting the double indices. */
static double chkspcp(void* vv) {
  int j, nw, npr, npo, flag;
  double *w, *pr, *po;
  nw = vector_instance_px(vv, &w);
  npr = vector_arg_px(1, &pr);
  npo = vector_arg_px(2, &po);
  if (po_int==NULL || pr_int==NULL) { cpfl=0; return 0; }
  /* size checks first; they also keep the loop below from reading past pr/po */
  flag = (cpfl==nw && npr>=nw && npo>=nw);
  /* stop at the first differing index, since one mismatch decides the result */
  for (j=0; flag && j<nw; j++) {
    if (po_int[j]!=(int)po[j] || pr_int[j]!=(int)pr[j]) { flag=0; }
  }
  if (flag==0) {
    cpfl=0; free(po_int); free(pr_int);
    po_int=(int *)NULL; pr_int=(int *)NULL;
  }
  return flag;
}
ENDVERBATIM
:* y.spmult(pr,po,wt,x[,flag])
: y=W*x, y will be the product of matrix w with pre/post indices and vec x
: optional flag (5th arg present) - do not clear dest vector initially
VERBATIM
/* y.spmult(pr, po, wt, x[, flag]): y = W*x for the sparse matrix W given as
 * parallel vectors of presynaptic index (pr), postsynaptic index (po) and
 * weight (wt). Each entry k adds wt[k]*x[pr[k]] into y[po[k]].
 * y is zeroed first unless a 5th argument is present, in which case the
 * products are added to y's existing contents.
 * If mkspcp() has cached int copies of pr/po (cpfl != 0), those are used
 * instead, and entries whose x value is 0 are skipped.
 * Indices are not range-checked.
 * Returns size(x). */
static double spmult(void* vv) {
  double *dst, *pre, *post, *wt, *src;
  int ndst = vector_instance_px(vv, &dst);
  int npre = vector_arg_px(1, &pre);
  int npost = vector_arg_px(2, &post);
  int nwt = vector_arg_px(3, &wt);
  int nsrc = vector_arg_px(4, &src);
  int accumulate = ifarg(5);

  if (npre != nwt || npost != nwt) {
    hoc_execerror("Sparse mat must have 3 identical size vecs for pre/post/wt", 0);
  }
  if (!accumulate) {
    for (int k = 0; k < ndst; k++) {
      dst[k] = 0.;
    }
  }

  if (cpfl == 0) {
    /* no cached int indices: convert the double indices per entry */
    for (int k = 0; k < nwt; k++) {
      dst[(int)post[k]] += src[(int)pre[k]]*wt[k];
    }
  } else if (cpfl != nwt) {
    hoc_execerror("cpfl!=nw in spmult", 0);
  } else {
    for (int k = 0; k < nwt; k++) {
      double xv = src[pr_int[k]];
      if (xv != 0) {
        dst[po_int[k]] += xv*wt[k];
      }
    }
  }
  return nsrc;
}
ENDVERBATIM
:* wt.spget(pr,po,row,col) returns weight value
VERBATIM
/* wt.spget(pr, po, row, col): look up one sparse-matrix weight.
 * row is matched against po (postsynaptic index) and col against pr
 * (presynaptic index). Returns the weight of the first matching entry, or 0
 * if there is none. The lookup is a linear scan over all entries.
 * pr, po and wt must have the same length, as in spmult(). */
static double spget(void* vv) {
  int j, nw, npr, npo;
  double *w, *pr, *po, row, col;
  nw = vector_instance_px(vv, &w);
  npr = vector_arg_px(1, &pr);
  npo = vector_arg_px(2, &po);
  row = *getarg(3);
  col = *getarg(4);
  /* the scan reads nw entries from pr and po */
  if (nw!=npr || nw!=npo) {
    hoc_execerror("Sparse mat must have 3 identical size vecs for pre/post/wt", 0);
  }
  for (j=0;j<nw;j++) {
    if (row==po[j] && col==pr[j]) { return w[j]; }
  }
  return 0.;
}
ENDVERBATIM
:* transpose
VERBATIM
/* mat2.transpose(mat, rows, cols): mat2 = transpose of mat, where mat is a
 * rows x cols matrix stored row-major; mat2 is written as cols x rows.
 * Both vectors must have the same length, and rows*cols must not exceed it;
 * only the first rows*cols elements are used.
 * No temporary is used, so mat2 must not be the same vector as mat.
 * Returns the vector length. */
static double transpose(void* vv) {
  int i, j, nx, ny, rows, cols;
  double *x, *y;
  nx = vector_instance_px(vv, &x);
  ny = vector_arg_px(1, &y);
  rows = (int)*getarg(2);
  cols = (int)*getarg(3);
  /* equal lengths alone are not enough: if rows*cols exceeded them, the loops
   * below would index past both vectors. The product is taken in long long
   * so large dimensions cannot overflow int. */
  if (ny != nx || (long long)rows*cols > nx) {
    hoc_execerror("Vector size mismatch", 0);
  }
  for (i=0;i<rows;i++) {
    for (j=0;j<cols;j++) {
      x[j*rows+i] = y[i*cols+j];
    }
  }
  return nx;
}
ENDVERBATIM
:* mprintf
VERBATIM
/* mat.mprintf(rows, cols): print the vector to stdout as a rows x cols
 * row-major matrix: each entry is printed with %g and followed by a tab, and
 * each row ends with a newline.
 * Errors unless the vector holds exactly rows*cols elements.
 * Returns the number of elements. */
static double mprintf(void* vv) {
  double *m;
  int n = vector_instance_px(vv, &m);
  int nr = (int)*getarg(1);
  int nc = (int)*getarg(2);

  if (nr*nc != n) {
    hoc_execerror("Vector size mismatch", 0);
  }

  /* row-major storage means the entries are printed in memory order */
  const double *p = m;
  for (int r = 0; r < nr; r++) {
    for (int c = 0; c < nc; c++) {
      printf("%g\t", *p++);
    }
    printf("\n");
  }
  return n;
}
ENDVERBATIM
:* mget(i,j,cols)
VERBATIM
/* mat.mget(i, j, cols): return element (row i, column j) of a row-major
 * matrix with cols columns.
 * Errors if i or j is negative, if j >= cols (which would silently address
 * the next row), or if the element lies past the end of the vector. */
static double mget(void* vv) {
  int i, j, nx, cols;
  double *x;
  nx = vector_instance_px(vv, &x);
  i = (int)*getarg(1);
  j = (int)*getarg(2);
  cols = (int)*getarg(3);
  /* both ends are checked: a negative index would read before the vector */
  if (i<0 || j<0 || j>=cols || i*cols+j >= nx) {
    hoc_execerror("Indices out of bounds", 0);
  }
  return x[i*cols+j];
}
ENDVERBATIM
:* mrow(mat,i,cols)
VERBATIM
/* y.mrow(mat, i, cols): copy row i of the row-major matrix mat (cols columns)
 * into y. y must already hold exactly cols elements, and i must be a valid
 * row index, i.e. 0 <= i < size(mat)/cols.
 * Returns size(y). */
static double mrow(void* vv) {
  int i, j, nx, ny, cols;
  double *x, *y;
  nx = vector_instance_px(vv, &x);
  ny = vector_arg_px(1, &y);
  i = (int)*getarg(2);
  cols = (int)*getarg(3);
  /* cols<=0 is rejected before it is used as a divisor, and a negative row
   * would read before the start of mat */
  if (cols<=0 || cols!=nx || i<0 || i>=ny/cols) {
    hoc_execerror("Indices out of bounds", 0);
  }
  for (j=0;j<nx;j++) { x[j] = y[i*cols+j]; }
  return nx;
}
ENDVERBATIM
:* mcol(mat,j,cols)
VERBATIM
/* y.mcol(mat, j, cols): copy column j of the row-major matrix mat into y.
 * size(y) gives the number of rows, and cols must equal size(mat)/size(y)
 * (integer division); j must satisfy 0 <= j < cols.
 * Returns size(y). */
static double mcol(void* vv) {
  int i, j, nx, ny, cols;
  double *x, *y;
  nx = vector_instance_px(vv, &x);
  ny = vector_arg_px(1, &y);
  j = (int)*getarg(2);
  cols = (int)*getarg(3);
  /* an empty y is rejected before size(y) is used as a divisor, and a
   * negative column would read before the start of mat */
  if (nx<=0 || cols!=ny/nx || j<0 || j>=cols) {
    hoc_execerror("Indices out of bounds", 0);
  }
  for (i=0;i<nx;i++) { x[i] = y[i*cols+j]; }
  return nx;
}
ENDVERBATIM
:* mset(i,j,cols,val)
VERBATIM
/* mat.mset(i, j, cols, val): set element (row i, column j) of a row-major
 * matrix with cols columns to val, and return val.
 * Errors if i or j is negative, if j >= cols (which would silently address
 * the next row), or if the element lies past the end of the vector. */
static double mset(void* vv) {
  int i, j, nx, cols;
  double *x, val;
  nx = vector_instance_px(vv, &x);
  i = (int)*getarg(1);
  j = (int)*getarg(2);
  cols = (int)*getarg(3);
  val = *getarg(4);
  /* both ends are checked: a negative index would write before the vector */
  if (i<0 || j<0 || j>=cols || i*cols+j >= nx) {
    hoc_execerror("Indices out of bounds", 0);
  }
  return (x[i*cols+j]=val);
}
ENDVERBATIM
:* PROCEDURE install_matrix()
PROCEDURE install_matrix() {
: Makes each C function defined in the VERBATIM blocks above callable from hoc
: as a Vector method under the quoted name (e.g. y.mmult(mat,x)).
: NOTE(review): presumably must be called from hoc before any of these
: methods are used -- confirm with the calling code.
VERBATIM
/* the list of additional methods */
/* dense row-major matrix operations */
install_vector_method("outprod", outprod);
install_vector_method("mmult", mmult);
/* sparse matrix operations (parallel pre/post/weight vectors) */
install_vector_method("spmult", spmult);
install_vector_method("spget", spget);
install_vector_method("mkspcp", mkspcp);
install_vector_method("chkspcp", chkspcp);
install_vector_method("spltp", spltp);
/* dense row-major matrix operations */
install_vector_method("transpose", transpose);
install_vector_method("mprintf", mprintf);
install_vector_method("mget", mget);
install_vector_method("mset", mset);
install_vector_method("mrow", mrow);
install_vector_method("mcol", mcol);
ENDVERBATIM
}