# from net import *
from mysetup import *
import numpy
# Declarations
g = h.Graph()
# objref g,rdm[2],ind,XO,YO,animv[2],tmpobj,tvec,vec[10],scr[2],tmpvec,veclist,ind
# objref netcon
# objref sl,gx[2],drv
ncell=100
cells = [] # h.List()
nclist = h.List()
vec = []
rdm = []
scr=[]
animv=[]
for ii in xrange(0,2): rdm.append(h.Random())
for ii in xrange(0,10): vec.append(h.Vector())
for ii in xrange(0,2): scr.append(h.Vector())
for ii in xrange(0,2): animv.append(h.Vector())
tvec=h.Vector()
ind=h.Vector()
veclist=h.List()
sl=h.SectionList() # empty list
drv=h.Vector() # drawing vector
Gshp=0
tveclen=0
# Cell template
class Cell:
def __init__ (self,_id,_x,_y,_z):
self.pp = h.IntervalFire()
self.pp.tau = 10
self.pp.invl = 20
self.id = _id
self.x = _x
self.y = _y
self.z = _z
def is_art ():
return 1
def connect2target (targ,nc):
nc=h.NetCon(pp, targ)
return nc
def M ():
return pp.M()
# Network specification
# createnet()
def createnet ():
global cells,nclist,netcon
netcon = h.nil
cells = []
for i in xrange(0,ncell): cells.append(Cell(i,i,0,0))
wire()
# wire():: full non-self connectivity
# artificial cell templates have obj.pp
# params: ncell
# creates nclist: list of NetCons
def wire ():
nclist.remove_all()
for pre in cells:
for post in cells:
if pre!=post: nclist.append(h.NetCon(pre.pp,post.pp))
# ranwire()
def ranwire ():
nclist.remove_all()
rdm[0].discunif(0,ncell-1)
for i in xrange(0,ncell):
proj=rdm[0].repick()
if proj<ncell/2:
beg=proj
end=proj+int(ncell/2.0)-1
else:
beg=proj-int(ncell/2.0)+1
end=proj
for j in xrange(beg,end+1):
if i!=j:
netcon = h.NetCon(cells[i].pp,cells[j].pp)
nclist.append(netcon)
# parameters
h.tstop = 500
ncell = 10
ta=10
w=-0.01
delay=4
low=10
high=11
# routines: weight(),delay(),tau(),interval()
def weight (w=-0.3):
vv=h.Vector()
if w == -1: # a bad choice of code value
for i in xrange(0,int(nclist.count())):
vv.append(nclist.o(i).weight[0])
print vv.min,vv.max,vv.mean,vv.stdev
else:
for n in nclist: n.weight[0] = w
# weight2(WT,EXCLUDE_VEC) :: set weight to WT
# unless in EXCLUDE_VEC then set wt. to 0
def weight2 (w,ecv):
for i,nc in zip(range(len(nclist)),nclist):
if ecv.contains(i):
nc.weight[0]=0
else:
nc.weight[0]=w
def setdelay (delay):
for n in nclist: n.delay = delay
def settau (tau):
for c in cells: c.pp.tau = tau
# //** interval(low,high) randomly sets cells to have periods between low and high
def interval (low,high):
rdm[0].uniform(low,high)
vec[0].resize(ncell)
vec[0].setrand(rdm[0])
for i in xrange(0,len(cells)):
cells[i].pp.invl = vec[0][i]
# //** setparams() sets weight, delay, tau and intervals
def setparams (w=-0.2,delay=2,ta=10,low=10,high=12):
weight(w)
setdelay(delay)
settau(ta)
interval(low, high)
# //* Run code
# //** savspks() -- save to a vector
def savspks ():
for ii in xrange(0,2): animv[ii].resize(0)
for c in cells:
if h.cvode.netconlist(c, '', '').count()==0:
tmpobj=h.NetCon(c.pp, None)
else:
tmpobj=cvode.netconlist(c, '', '').object(0)
tmpobj.record(tvec,ind)
animv[0].append(h.objnum(tmpobj))
animv[1].append(h.objnum(tmpobj.pre)) # in this case are all in a row anyway
# //** showspks() -- show spikes on graph g
def markspks ():
g.erase()
ind.mark(g,tvec,"O",4,2,1)
g.flush()
g.exec_menu("View = plot")
def showspks ():
# ind.mark(g,tvec,"O",8,2,1)
for x in tvec:
g.beginline(3,1)
g.line(x,ind.min())
g.line(x,ind.max())
g.flush()
g.exec_menu("View = plot")
# //** syncer() :: returns sync measure 0 to <1
# // measures how well spikes "fill up" the time
# // assumes spike times in tvec, tstop
# // param: width
# // syncer doesn't take account of prob of overlaps
# // due to too many spikes stuffed into too little time
def syncer (vec=tvec,width=1):
t0=-1; cnt=0
for tt in vec:
if tt>=t0+width: t0=tt; cnt+=1
return 1.0-cnt/(h.tstop/width)
# ** autorun() changes weights
def autorun ():
veclist.remove_all()
rvc() # clear the storage vectors
for w in [-0.5,-0.3,-0.1,-0.01,-0.001]:
weight(w)
run()
savevec([ind,tvec])
vec[1].append(w)
s = syncer()
vec[2].append(s)
print w,s
g.erase()
vec[2].plot(g,vec[1])
vec[2].mark(g,vec[1],"O",8,2,1)
g.exec_menu("View = plot")
# //** autorun1() changes connections density
def autorun1 ():
veclist.remove_all()
rvc()
maxx=-0.1
pij_inc=0.1
S=ncell*ncell
vec[3].resize(0)
for ii in xrange(0,10):
C = (1-ii*pij_inc) # // percent convergence
w=maxx/C # // scale weight up as convergence goes down
setparams()
weight2(w,vec[3])
run()
savevec([ind,tvec])
print S-vec[3].size(),syncer()
vec[1].append((S-vec[3].size())/S)
vec[2].append(syncer())
rdm[0].discunif(0,S-1)
rdmunq(vec[3],0.1*S,rdm[0]) # increase those set to 0
g.erase()
vec[2].plot(g,vec[1])
vec[2].mark(g,vec[1],"O",8,2,1)
g.exec_menu("View = plot")
# //* Utility functions
# //** savevec(list of vectors) add vectors onto veclist
def savevec (vecs):
for vec in vecs:
tmpvec = h.Vector(vec.size())
tmpvec.copy(vec)
veclist.append(tmpvec)
tmpvec = h.nil
# //** rvc() clears vec[0..9]
def rvc ():
for ii in xrange(0,10):
vec[ii].resize(0)
# //** setdensity(pij) sets connection density to 0<pij<1
def setdensity (pij):
maxx=-0.1
S=nclist.count()
rdm[0].discunif(0,S-1)
vec[3].resize(0)
rdmunq(vec[3],(1-pij)*S,rdm[0]) # // number to set to 0
C = (1-pij) # // percent convergence
w=maxx/C # // scale weight up as convergence goes down
setparams()
weight2(w,vec[3])
# //** rdmunq(vec,n,rdm) -- augment vec1 by n unique vals from rdm
def rdmunq (vec,n,rm):
num=0
flag=1
loop=0
scr[0].resize(n*4) # // hopefully will get what we want
while flag:
scr[0].setrand(rm)
for ii in xrange(0,int(scr[0].size())):
xx=scr[0].x[ii]
if not vec.contains(xx):
vec.append(xx)
num+=1
if num==n:
flag=0
break
loop+=1
if loop==10:
print "rdmunq ERR; inf loop"
flag=0
break
# //** rdmord (vec,n) randomly ordered numbers 0->n-1 in vec
def rdmord (vec,n):
rdm[0].uniform(0,100)
scr[0].resize(n)
scr[0].setrand(rdm[0])
scr[0].sortindex(vec)
# // vcount (num,vec)
def vcount (num,vec):
scr[0].where(vec,"==",num)
return scr[0].size()
# //* Mapping functions
# //** getcnum(CELL_OBJ) return index given cell object
def getcnum (cell):
return cell.id
# h.sprint(h.tstr,"%s",cell)
# h("sscanf("
# if h.sscanf(h.tstr,"IntervalFire[%d]",&x) != 1:
# x=-1
# return x
# //** fconn(PREVEC,POSTVEC) places values of
# // pre- and post-syn cells in parallel vectors
# // only lists pairs with non-zero connections
# // getcnum() returns index of cell obj
def fconn (prev,post):
prev.resize(0)
postv.resize(0)
for ii in xrange(0,nclist.count()):
XO=nclist.o(ii)
if XO.weight[0]!=0:
prev.append(getcnum(XO.pre))
postv.append(getcnum(XO.syn))
# //** showconns() -- show all the connections as line segments
def showconns ():
g.erase()
fconn(scr[0],scr[1])
for ii in xrange(0,scr[0].size()):
pr=scr[0].x[ii]
po=scr[1].x[ii]
drawline(pr,po,10)
g.flush()
# //** showconv1(ID,color) -- show convergence to one cell as line seg
def showconv1 (ID,colr=2):
fconn(scr[0],scr[1])
for ii in xrange(0,scr[0].size()):
pr=scr[0].x[ii]
po=scr[1].x[ii]
if po==ID:
drawline(pr,po,10,colr,3)
print pr
print ''
g.flush()
# //** showdiv1() -- show divergence from one cell as line seg
def showdiv1 (id,colr=2):
fconn(scr[0],scr[1])
for ii in xrange(0,scr[0].size()):
pr=scr[0].x[ii]
po=scr[1].x[ii]
if pr==id:
drawline(pr,po,10,colr,8)
print po
print ''
g.flush()
def xpos (x,cols):
return x%cols
def ypos (co1s,y):
return int(co1s/y)
# //*** func distn () calc distance
def distn (c1,c2,cols):
from math import sqrt
xd=xpos(c1,cols)-xpos(c2,cols)
yd=ypos(c1,cols)-ypos(c2,cols)
return sqrt(xd*xd+yd*yd)
# //*** distwire(pij)
def distwire (pij):
allsyns = nclist.count()
total=pij*allsyns # how many syns to set
rdm[1].uniform(0,1) # for flipping coin
# // maxdist==12.728 for 10x10; mindist=1 (neighbors)
maxdist=0.33*distn(0,ncell-1,sqrt(ncell)) # the full dist from lower left to upper right
maxwt=-0.9/pij # norm wt by convergence
loop=cnt=0 # counters
for i in xrange(0,allsyns): nclist.o(i).weight[0] = 0 # # clear weights
while cnt<total and loop<4:
rdmord(vec[3],allsyns) # # test each synapse in random order
for ii in xrange(0,vec[3].size()):
XO=nclist.object(vec[3].x[ii]) # # pick a synapse
# max prob of connection is 0.8*(1-mindist/maxdist)~74% for 10x10
# zero prob of diag connection from corner to corner
prob = 1.0*(1-(distn(getcnum(XO.pre),getcnum(XO.syn),sqrt(ncell))/maxdist))
if rdm[1].repick<prob:
XO.weight=maxwt
cnt+=1
if cnt>=total:
break # finished
print cnt,total
if cnt<total:
print "distwire ERR: target ", total, "set ", cnt
# //*** drawline(beg,end,columns[,color,line_width])
def drawline (beg,end,cols,clr=4,lwid=1):
# local beg,end,cols,clr,lwid
g.beginline(clr,lwid)
g.line(xpos(beg,cols),ypos(beg,cols))
g.line(xpos(end,cols),ypos(end,cols))
# //* Animation
# //** animplot() put up the shape plot for hinton diagram
def animplot ():
gx[Gshp] = h.PlotShape(sl,0)
h.flush_list.append(gx[Gshp])
ctern(gx[Gshp],0,2)
gx[Gshp].view(1,0,1.1,1.1,500,200,100,100)
drawcells()
# //* Ternary color map
def ctern (cm,minc,maxc):
cm.colormap(3)
cm.colormap(0, 255, 0, 0)
cm.colormap(1, 255, 255, 0)
cm.colormap(2, 0, 0, 255)
cm.scale(minc, maxc)
# //** drawcells() draw squares of hinton diagram
def drawcells ():
if ncell!=100:
print "ERROR: drawcells() currently written for ncell=100"
return
gx[Gshp].erase_all()
drv.resize(ncell)
xoff=1.1
yoff=0.1
wdt=0.1
nx=ny=10
for i in xrange(0,nx):
for j in xrange(0,ny):
#gx[Gshp].hinton(&drv.x[j*nx+i],(i+.5)*wdt+xoff,(j+.5)*wdt+yoff,wdt)
gx[Gshp].hinton(drv.x[j*nx+i],(i+.5)*wdt+xoff,(j+.5)*wdt+yoff,wdt)
gx[Gshp].size(1, 2.2, 0, 1.2)
gx[Gshp].exec_menu("Shape Plot")
# //** chkhint(cell#) light up a location for a single cell
def chkhint (cellid):
drv.fill(0)
drv.x[cellid]=2
gx.flush()
# //** anim() animates sim stored in tvec,ind
def anim ():
tstep=0.1
sz = ind.size()-1
gx[Gshp].exec_menu("Shape Plot")
scr[0].copy(tvec)
scr[0].add(500*tstep) # how many steps to keep it illuminated
drv.fill(0)
ii=jj=0
for tt in numpy.linspace(0,h.tstop,h.tstop/tstep+1):
while ii < sz and tt > tvec.x[ii]:
drv.x[animv.indwhere("==",ind.x[ii])]=2
ii = ii + 1
while jj < sz and tt > scr[0].x[jj]:
drv.x[animv.indwhere("==",ind.x[jj])]=0
jj = jj + 1
gx[Gshp].flush()
doEvents()
# //* Run sequences
h.cvode_active(1)
def run ():h.run()
def runseq ():
createnet()
setparams()
savspks()
run()
# from net import *
# ncell=10
# w=-1e-6
# runseq()
# g=new Graph()
# g.erase_all
# markspks()
# showspks()
# w=-0.3 # and repeat
# setparams()
# run() # then erase graph and show
createnet()
savspks()
setparams(w=-0.1)
# don't repeat runseq() unless want to change # of cells
run();g.erase_all();markspks();showspks()
# animplot()
# anim()