from Controller import *
import matplotlib.pyplot as plt
import pylab
from time import time
###########
##@summary: This is to test the network saved in **.obj, which is generated during training by rumMe.py
##@author: Akihiro Eguchi
##Aug 13, 2013

# ---- Run configuration ------------------------------------------------------
nTrainings = 2001        # not referenced in the visible part of this script
ndim = 30                # second constructor argument of Controller
dim_stim = 10            # first constructor argument; also the side length of the stimulus grid
resultFolderName = "results"

# Build the controller and tell it where result files live.
controller = Controller(dim_stim, ndim)
controller.resultFolderName = resultFolderName

# ---- Test switches ----------------------------------------------------------
transOn = 1              # 1: also run the slightly shifted colour variants of each stimulus
singleColor = 1          # 1: uniform single-colour test; anything else: two-colour patch test
alteringInput = 0        # not fully implemented yet
itr = 2000               # iteration tag used in "Network_<itr>.obj" and in output file names

class InputStim:
    """One cell of the input stimulus, holding an RGB colour.

    The three channels default to 0 through class-level attributes; calling
    ``setRGB`` stores per-instance values that shadow those defaults.
    """
    r = g = b = 0

    def setRGB(self,r,g,b):
        """Store the red, green and blue channel values on this cell."""
        self.r, self.g, self.b = r, g, b

# Stimulus grid: dim_stim x dim_stim independent InputStim cells, indexed as
# inputStim[y][x].  ``range`` is used (rather than the Python-2-only
# ``xrange``) to match every other loop in this script.
inputStim = [[InputStim() for _ in range(dim_stim)] for _ in range(dim_stim)]

# Load the weights and delays stored as "Network_<itr>.obj" in the results
# folder, then freeze learning so that testing does not modify the network.
controller.loadWeightsAndDelays("Network_"+str(itr)+".obj",resultFolderName)
controller.setLearningStates(0)  # stop synaptic modifications
controller.variables.tstop = 300  # presumably the simulation stop time -- confirm units in Controller

# Run the simulation on several threads through a ParallelContext.
nThreads = 7
controller.pc = h.ParallelContext()
controller.pc.nthread(nThreads)

if alteringInput:
    controller.AlteringInputInit()


def _firstWeights(netcons, count):
    """Return ``netcons[i].weight[0]`` for every ``i`` in ``range(count)``."""
    return [netcons[i].weight[0] for i in range(count)]


# Snapshot the current weight of every plastic connection so that the weight
# distributions can be plotted as histograms.
#
# The three input-side lists share one loop bound (the length of the L->L4
# list) and the two remaining lists share another (the length of the L4->L23
# list).
# NOTE(review): this assumes the lists within each group have equal length --
# confirm against Controller.
_nInputSide = len(controller.NetCons_STDP_LtoL4)
weightTemp_LtoL4 = _firstWeights(controller.NetCons_STDP_LtoL4, _nInputSide)
weightTemp_C1toL4 = _firstWeights(controller.NetCons_STDP_C1toL4, _nInputSide)
weightTemp_C2toL23 = _firstWeights(controller.NetCons_STDP_C2toL23, _nInputSide)

_nUpper = len(controller.NetCons_STDP_L4toL23)
weightTemp_L4toL23 = _firstWeights(controller.NetCons_STDP_L4toL23, _nUpper)
weightTemp_L23toL5 = _firstWeights(controller.NetCons_STDP_L23toL5, _nUpper)


# One histogram panel per connection type, stacked top to bottom in a 5x1
# grid.  Each entry is (weights, upper end of the histogram range, x label).
histogramPanels = [
    (weightTemp_L23toL5, 0.005*3, "synaptic weights between V1_L23 and V1_L5"),
    (weightTemp_C2toL23, 0.005*1.5, "synaptic weights between C2 and V1_L23"),
    (weightTemp_L4toL23, 0.005*1.5, "synaptic weights between V1_L4 and V1_L23"),
    (weightTemp_C1toL4, 0.005, "synaptic weights between C1 and V1_L4"),
    (weightTemp_LtoL4, 0.005, "synaptic weights between L and V1_L4"),
]

for panelIndex, (panelWeights, upperBound, panelLabel) in enumerate(histogramPanels, 1):
    plt.subplot(5, 1, panelIndex)
    # 100 bins spanning [0, upperBound]; the x limits match the bin range.
    plt.hist(panelWeights, bins=100, range=[0, upperBound])
    plt.xlim(0, upperBound)
    plt.xlabel(panelLabel)
    plt.ylabel("number of synapses")
    # NOTE(review): hiding the x axis presumably also hides the xlabel set
    # just above -- confirm this is intended.
    frame1 = plt.gca()
    frame1.axes.get_xaxis().set_visible(False)

plt.show()



#0.5 0 0.5
#0 0 1
#0 1 1
#0 1 0
#1 1 0
#1 0.5 0
#1 0 0
#1 0 1

# Colours presented in single-colour mode, as (R G B):
# purple       0.5 0   1
# blue         0   0   1
# light-green  0   1   0
# light-blue   0   1   1
# red          1   0   0
# pink         1   0   1
# yellow       1   1   0
# orange       1   0.5 0
def _fillUniform(red, green, blue):
    """Paint every cell of the module-level ``inputStim`` grid with one RGB colour."""
    for stimRow in inputStim:
        for cell in stimRow:
            cell.setRGB(red, green, blue)


def _simulateAndCount():
    """Run one simulation on the current input and refresh the spike counts.

    Calls ``recordVols``, ``run`` and ``updateSpikeCount`` on the module-level
    ``controller``, in that order.
    """
    controller.recordVols()
    controller.run()
    controller.updateSpikeCount()


def _drawSpikeCounts(nRows, nCols, position, spikeCounts):
    """Show ``spikeCounts`` as a grey-scale image with a colour bar.

    The image is drawn in cell ``position`` (1-based) of an
    ``nRows`` x ``nCols`` subplot grid of the current figure.
    """
    plt.subplot(nRows, nCols, position)
    # The colormap is named explicitly; pylab.gray() returns None and only
    # works as a ``cmap`` argument through its global side effect.
    plt.imshow(spikeCounts, cmap="gray")
    plt.colorbar()


if singleColor == 1:
    # Single-colour test: present eight uniform colours, one per (r, g, b)
    # combination of 0/1, and plot the L5, L23 and L4 spike counts for each.
    fig1 = plt.gcf()
    plt.clf()
    for r in range(2):
        for g in range(2):
            for b in range(2):
                # Black and white are replaced by purple and orange.
                r2, g2, b2 = r, g, b
                if (r, g, b) == (0, 0, 0):
                    r2, g2, b2 = 0.5, 0, 1
                if (r, g, b) == (1, 1, 1):
                    r2, g2, b2 = 1, 0.5, 0

                _fillUniform(r2, g2, b2)
                if alteringInput:
                    controller.setAlteringInput(inputStim, 0.25)
                else:
                    controller.setInput(inputStim, 0.8)
                # Optional extras, currently disabled:
                #   controller.recordChannelVols() before the run;
                #   controller.outputFR(itr), controller.saveSpikeDetails(r,g,b,itr)
                #   and controller.saveChannelSpikeDetails(r,g,b,itr) after it.
                _simulateAndCount()

                # 8x4 subplot grid: cells 1-8 show L5, 13-20 show L23 and
                # 25-32 show L4, one column position per colour.
                panel = r*4 + g*2 + b + 1
                _drawSpikeCounts(8, 4, panel, controller.spikeCount_L5)
                _drawSpikeCounts(8, 4, panel + 12, controller.spikeCount_L23)
                _drawSpikeCounts(8, 4, panel + 24, controller.spikeCount_L4)

                if transOn:
                    controller.outputFR_trans(r2, g2, b2, itr)

                    # Transformation: vary the input with similar colours.
                    # Each channel is nudged by modVal, upwards when it is 0
                    # and downwards otherwise.
                    modVal = 0.01
                    rMod = r2 + modVal if r2 == 0 else r2 - modVal
                    gMod = g2 + modVal if g2 == 0 else g2 - modVal
                    bMod = b2 + modVal if b2 == 0 else b2 - modVal

                    # Every non-empty combination of nudged channels, in the
                    # order in which the results are written out.
                    colourVariants = (
                        (rMod, g2, b2),
                        (r2, gMod, b2),
                        (r2, g2, bMod),
                        (rMod, gMod, b2),
                        (r2, gMod, bMod),
                        (rMod, g2, bMod),
                        (rMod, gMod, bMod),
                    )
                    for varR, varG, varB in colourVariants:
                        _fillUniform(varR, varG, varB)
                        # NOTE(review): unlike the base presentation above,
                        # setInput is called here without a strength argument
                        # and regardless of alteringInput -- confirm that
                        # Controller's default is the intended value.
                        controller.setInput(inputStim)
                        _simulateAndCount()
                        # Results are tagged with the unmodified base colour.
                        controller.outputFR_trans(r2, g2, b2, itr)
    fig1.savefig(resultFolderName+"/"+str(itr),dpi=100)
    plt.show()
else:
    # Two-colour test: a background of (r, g, 0) with a contrasting
    # (1-r, 1-g, 0) patch, for the two combinations where r != g.
    fig1 = plt.gcf()
    plt.clf()
    b = 0
    # The patch covers cells whose x AND y both lie strictly between one
    # third and two thirds of the grid.  Floor division keeps the bounds
    # integral regardless of the Python version.
    patchLow = dim_stim // 3
    patchHigh = dim_stim * 2 // 3
    for r in range(2):
        for g in range(2):
            if r == g:
                continue
            for y in range(dim_stim):
                for x in range(dim_stim):
                    if patchLow < y < patchHigh and patchLow < x < patchHigh:
                        inputStim[y][x].setRGB(1-r, 1-g, b)
                    else:
                        inputStim[y][x].setRGB(r, g, b)
            controller.setInput(inputStim, 0.3)
            _simulateAndCount()
            controller.outputFR(itr)

            # r is 0 or 1 here, so the two runs fill the two rows of a 2x1 grid.
            _drawSpikeCounts(2, 1, r + 1, controller.spikeCount_L4)

            controller.saveSpikeDetails(r, g, b, 111110)

    fig1.savefig(resultFolderName+"/multiColTest300_normal"+str(itr),dpi=100)

controller.setLearningStates(1)  # start synaptic modifications



    
# raw_input("Press Enter to exit...")