from Controller import *
import matplotlib.pyplot as plt
import pylab
import Image, glob
from time import time
import math
# Wall-clock time at which the whole run started (reported at the end).
startingTime = time()

# ---- experiment configuration ----
nTrainings = 2001       # number of iterations of the main loop
ndim = 30               # second Controller argument (meaning defined in Controller -- TODO confirm)
dim_stim = 10           # side length of the square stimulus grid
inputFromImages = 1     # 1: train on image patches, 0: train on random uniform colours
controller = Controller(dim_stim, ndim)
Hebb = 0                # 1: use controller.hebbUpdate() instead of the controller's learning states

# One entry appended per iteration to each weight history.
weights_LtoL4, weights_C1toL4, weights_C2toL23, weights_L4toL23, weights_L23toL5 = [], [], [], [], []

# Iteration time indexed by thread count; every slot starts at -1.
TimeAndthread = [-1] * 16

# Number of threads employed by the NEURON parallel context.
nThreads = 7
controller.pc = h.ParallelContext()
controller.pc.nthread(nThreads)
nThreadTuning = 0       # non-zero enables automatic thread-count tuning
class InputStim:
    """A single stimulus pixel holding an r, g, b colour triple.

    The class-level values of 0 serve as the colour until setRGB is called.
    """
    r = g = b = 0

    def setRGB(self, r, g, b):
        """Store the given colour components on this instance."""
        self.r, self.g, self.b = r, g, b
# dim_stim x dim_stim grid of stimulus pixels, indexed inputStim[y][x].
inputStim = [[InputStim() for x in range(dim_stim)] for y in range(dim_stim)]

# Training images.  glob returns files in arbitrary (directory-listing)
# order, and the training loop walks through inputImgs by index, so sort the
# file names to make the presentation order reproducible.  convert("RGB")
# guarantees each pixel is an (r, g, b) triple even for grayscale or CMYK
# JPEGs, which the training loop's tmp[0..2] indexing relies on.
input_files = sorted(glob.glob("./input/*.jpg"))
inputImgs = []
for data in input_files:
    im = Image.open(data).convert("RGB")
    tmp = im.load()  # pixel-access object, indexed [x, y]
    inputImgs.append(tmp)
if inputFromImages and not inputImgs:
    # Fail immediately instead of with an IndexError once training starts.
    raise IOError("inputFromImages is set but no ./input/*.jpg images were found")

resultFolderName = "results"
controller.resultFolderName = resultFolderName
controller.variables.nTrainings = nTrainings
controller.initExtra()
if Hebb:
    # Hebbian mode: switch the controller's learning states off; the training
    # loop calls controller.hebbUpdate() instead.
    controller.setLearningStates(0)
# controller.loadWeightsAndDelays("Network_1600.obj",resultFolderName)
# Main loop.  Every iteration:
#   (1) records the current STDP weights;
#   (2) on every 100th iteration, plots the weight history, saves the network
#       and probes it with 8 uniform colours while learning is switched off;
#   (3) except on the final iteration, presents one training stimulus and
#       runs the simulation.
for itr in range(nTrainings):
    # if itr<1601:
    #     continue;

    # -- (1) weight snapshot: one row per iteration in each history --
    # As before, C1toL4/C2toL23 are indexed over the length of LtoL4 and
    # L23toL5 over the length of L4toL23.
    idxToL4 = range(len(controller.NetCons_STDP_LtoL4))
    idxToL23 = range(len(controller.NetCons_STDP_L4toL23))
    weights_LtoL4.append([controller.NetCons_STDP_LtoL4[i].weight[0] for i in idxToL4])
    weights_C1toL4.append([controller.NetCons_STDP_C1toL4[i].weight[0] for i in idxToL4])
    weights_C2toL23.append([controller.NetCons_STDP_C2toL23[i].weight[0] for i in idxToL4])
    weights_L4toL23.append([controller.NetCons_STDP_L4toL23[i].weight[0] for i in idxToL23])
    weights_L23toL5.append([controller.NetCons_STDP_L23toL5[i].weight[0] for i in idxToL23])

    if itr % 100 == 0:
        # -- (2a) weight dynamics: whole history so far, one panel per projection --
        fig1 = plt.gcf()
        plt.clf()
        plt.subplot(511)
        plt.plot(weights_L23toL5)
        plt.subplot(512)
        plt.plot(weights_L4toL23)
        plt.subplot(513)
        plt.plot(weights_C2toL23)
        plt.subplot(514)
        plt.plot(weights_C1toL4)
        plt.subplot(515)
        plt.plot(weights_LtoL4)
        fig1.savefig(resultFolderName + "/weightDynamics_L5_L4toL23_C2toL23_C1_L.png",dpi=100)

        # -- (2b) save network state --
        controller.saveWeightsAndDelays(itr)

        # -- (2c) test the network with 8 different colour inputs --
        # (alternative schedule: test only when itr == 2000)
        controller.setLearningStates(0)  # stop synaptic modifications
        fig1 = plt.gcf()
        plt.clf()
        for r in range(2):
            for g in range(2):
                for b in range(2):
                    # Corners of the RGB cube, except that black is replaced
                    # by (0.5, 0, 1) and white by (1, 0.5, 0).
                    r2, g2, b2 = r, g, b
                    if r == 0 and g == 0 and b == 0:
                        r2, g2, b2 = 0.5, 0, 1
                    if r == 1 and g == 1 and b == 1:
                        r2, g2, b2 = 1, 0.5, 0
                    startTime = time()
                    for y in range(dim_stim):
                        for x in range(dim_stim):
                            inputStim[y][x].setRGB(r2, g2, b2)
                    controller.setInput(inputStim, 0.8)
                    controller.recordVols()
                    if itr == -1:  # disabled debug toggle (never true)
                        controller.recordChannelVols()
                    controller.run()
                    controller.updateSpikeCount()
                    # Always true under the current every-100 test schedule;
                    # kept so firing rates are still only written every 100
                    # iterations if the schedule above is changed.
                    if itr % 100 == 0:
                        controller.outputFR(itr)
                    timeSpent = time() - startTime
                    TimeAndthread[nThreads] = timeSpent
                    if nThreadTuning:
                        # Hill-climb the thread count: add threads until one
                        # more is slower than the previous count, then step
                        # back and stop.  Slots still at the initial -1 were
                        # never measured and must not count as "faster"
                        # (the original's -1 < timeSpent was always true).
                        # Also stop before running off the end of TimeAndthread.
                        prevTime = TimeAndthread[nThreads - 1]
                        if nThreads > 1 and 0 <= prevTime < timeSpent:
                            nThreads = nThreads - 1
                            nThreadTuning = 0
                        elif nThreads + 1 < len(TimeAndthread):
                            nThreads = nThreads + 1
                        else:
                            nThreadTuning = 0
                        controller.pc.nthread(nThreads)
                    # Disabled toggles: x % 2000 is never negative, and itr
                    # is never -1.
                    if itr % 2000 == -1:
                        controller.saveSpikeDetails(r2, g2, b2, itr)
                    if itr == -1:
                        controller.saveChannelSpikeDetails(r2, g2, b2, itr)
                    # 8x4 grid: L5 counts in cells 1-8, L2/3 in 13-20, L4 in 25-32.
                    cell = r * 4 + g * 2 + b
                    plt.subplot(8, 4, cell + 1)
                    plt.imshow(controller.spikeCount_L5, cmap=pylab.gray())
                    plt.colorbar()
                    plt.subplot(8, 4, cell + 13)
                    plt.imshow(controller.spikeCount_L23, cmap=pylab.gray())
                    plt.colorbar()
                    plt.subplot(8, 4, cell + 25)
                    plt.imshow(controller.spikeCount_L4, cmap=pylab.gray())
                    plt.colorbar()
        #plt.show()
        fig1.savefig(resultFolderName+"/"+str(itr),format="png")#dpi=100,
        # Re-enable learning only outside Hebbian mode.  The original used
        # `~Hebb`, a bitwise NOT that yields -1 or -2 and is therefore always
        # true, so learning was re-enabled in Hebbian mode as well.
        if not Hebb:
            controller.setLearningStates(1)  # start synaptic modifications

    # The final iteration only tests (above); it does not train.
    if itr == nTrainings - 1:
        break
    print(itr)
    startTime = time()
    # controller.setLR(1-(itr/nTrainings))#set learning rate

    # -- (3) training stimulus --
    if inputFromImages:
        # Walk through inputImgs in order over the course of training
        # (// matches the original Python 2 integer division).
        loadedImg = inputImgs[len(inputImgs) * itr // nTrainings]
        # Random patch origin.  int() selects the same pixel as the original's
        # implicit float truncation, but makes the integer index explicit.
        # NOTE(review): assumes every image is at least 200 + dim_stim pixels
        # in each dimension -- confirm for the images in ./input.
        xBegin = int(200 * random())
        yBegin = int(200 * random())
        tmp_r_tot = 0
        tmp_g_tot = 0
        tmp_b_tot = 0
        for y in range(dim_stim):
            for x in range(dim_stim):
                pixel = loadedImg[x + xBegin, y + yBegin]
                red, green, blue = pixel[0] / 255.0, pixel[1] / 255.0, pixel[2] / 255.0
                inputStim[y][x].setRGB(red, green, blue)
                tmp_r_tot += red
                tmp_g_tot += green
                tmp_b_tot += blue
        nPixels = dim_stim * dim_stim
        # Record the mean colour of the patch.
        controller.saveColor(tmp_r_tot / nPixels, tmp_g_tot / nPixels, tmp_b_tot / nPixels, itr)
    else:
        # One random colour applied uniformly to the whole grid.
        input_r = random()
        input_g = random()
        input_b = random()
        for y in range(dim_stim):
            for x in range(dim_stim):
                inputStim[y][x].setRGB(input_r, input_g, input_b)
        controller.saveColor(input_r, input_g, input_b, itr)

    controller.setInput(inputStim)
    controller.recordVols()
    if Hebb:
        controller.recordChannelVols()
    controller.run()
    if Hebb:
        controller.hebbUpdate()
    # NOTE(review): confirm whether weightNormalization() is meant to run
    # only in Hebbian mode; here it runs after every training step.
    controller.weightNormalization()
    # controller.weightNormalization2()
    #controller.drawGraph()
    timeSpent = time() - startTime
    print("iteration time:"+str(timeSpent)+" with nThreads:"+str(nThreads))
    # Training iterations still to come are itr+1 .. nTrainings-2 (the last
    # iteration only tests).  The original's nTrainings-itr over-counted by
    # two, contradicting "at least".
    print("estimated remaining: at least "+str(timeSpent*(nTrainings-itr-2))+" + testing Time")
    TimeAndthread[nThreads] = timeSpent
    # Thread-count tuning, when enabled, happens in the test phase above.
print (time() - startingTime)
# raw_input("Press Enter to exit...")