import os, sys
sys.path.append('../packages')
from pylab import *
from tables import *
from numpy import *
from math import *
from pypcsimplus import *
import pypcsimplus
import pypcsimplus.common
from pypcsimplus.clusterUtils import *
from frame import FrameAxes
import numpy
import time

from matplotlib import rc
# Render all figure text through LaTeX using a sans-serif font family.
_font_settings = {
    'family': 'sans-serif',
    'sans-serif': ['Computer Modern Sans serif'],
}
rc('font', **_font_settings)
rc('text', usetex = True)

# Seed the star-imported `random` object; it is used further down to draw
# the neuron pairs for the cross-covariance panels.
random.seed(123525)

if len(sys.argv) > 1:
    h5filename = sys.argv[1]

else:
    h5filename = last_file(".*\.h5")
    
print " loading h5 filename : ", h5filename

h5file = openFile(h5filename, mode = "r", title = "Biofeedback DASTDP Experiment results")

p = constructParametersFromH5File(h5file)
r = constructRecordingsFromH5File(h5file)

# Reciprocal scale factors applied to the recorded weight traces.
_inv_high_ou_scale = 1.0/p.WHighOUScale
_inv_low_ou_scale = 1.0/p.WLowOUScale

# Work on copies so the arrays held by r are not modified by the in-place
# scaling.
other_circ_not_ou_weights = r.other_circ_not_ou_weights.copy()
other_circ_not_ou_weights *= _inv_low_ou_scale

other_circ_ou_weights = r.other_circ_ou_weights.copy()
other_circ_ou_weights *= _inv_high_ou_scale

# All "other" weight traces stacked row-wise: not-OU rows first, OU rows last.
other_weights = vstack((other_circ_not_ou_weights, other_circ_ou_weights))

# NOTE(review): both reinforced populations are scaled with WLowOUScale,
# whereas the "other" OU weights use WHighOUScale -- confirm this is intended.
reinforced_circ_weights = vstack((r.reinforced_ou_weights, r.reinforced_other_weights))
reinforced_circ_weights *= _inv_low_ou_scale

# Upper weight bound used below to normalise weights to w/w_max.
p.Wmax = 2 * p.Wexc

# Spike trains of the neurons listed in r.exc_ou_nrn_idxs, in that order.
exc_ou_spikes = [r.spikes[nrn_idx] for nrn_idx in r.exc_ou_nrn_idxs]

# Spike trains of the neurons listed in r.exc_other_nrn_idxs, in that order.
exc_other_spikes = [r.spikes[nrn_idx] for nrn_idx in r.exc_other_nrn_idxs]


def _avg_and_std(weights):
    """Return (average, standard deviation) of *weights* along axis 0."""
    return average(weights, axis = 0), std(weights, axis = 0)


# Reinforced circuit: both populations together (rescaled), then each
# population separately (taken straight from the recordings).
reinforced_circ_avg_weights, reinforced_circ_std_weights = _avg_and_std(reinforced_circ_weights)
reinforced_ou_avg_weights, reinforced_ou_std_weights = _avg_and_std(r.reinforced_ou_weights)
reinforced_other_avg_weights, reinforced_other_std_weights = _avg_and_std(r.reinforced_other_weights)

# Other circuit: all traces together, then the not-OU and OU groups separately.
other_avg_weights, other_std_weights = _avg_and_std(other_weights)
other_circ_avg_not_ou_weights, other_circ_std_not_ou_weights = _avg_and_std(other_circ_not_ou_weights)
other_circ_avg_ou_weights, other_circ_std_ou_weights = _avg_and_std(other_circ_ou_weights)


# (start, end) of every analysis window; the plot labels further down give
# these values in seconds.
analysis_intervals = [
    (300, 360),
    (600, 660),
    (900, 960),
    (1140, 1200),
]

# One colour per analysis window.
plot_colors = ['r', 'g', 'b', 'm']
gray_colors = ['0.75', '0.6', '0.4', '0.0']

# Window length, taken from the first window.
_first_start, _first_end = analysis_intervals[0]
IntervalLength = _first_end - _first_start


           
    
# Histogram, per analysis window, of the time-averaged normalised weight
# (w / p.Wmax) of every row of other_weights.  Each histogram is divided by
# the number of rows, so its entries are fractions of synapses.
w_hist_array = []
numBins = 30

# Number of recorded weight samples per unit of simulated time.  Computed in
# floating point so Python 2 integer division cannot truncate it.
# Assumes the samples are evenly spread over [0, p.Tsim] -- TODO confirm.
samples_per_time_unit = len(other_weights[0]) / float(p.Tsim)

for win_start, win_end in analysis_intervals:
    # Slice bounds must be integers: float indices raise TypeError on
    # NumPy >= 1.12.
    ind_start = int(round(win_start * samples_per_time_unit))
    ind_end = int(round(win_end * samples_per_time_unit))

    # Mean over the window (axis 1 = samples) of each normalised weight trace.
    mean_norm_weights = mean(other_weights[:, ind_start:ind_end] / p.Wmax, 1)

    counts = histogram(mean_norm_weights, bins = numBins, range = (0,1), normed = False)[0]
    w_hist_array.append(counts / float(len(other_weights)))
    
    
numBins = 20

# Number of spike trains among the first p.nExcNeurons entries of r.spikes.
_num_exc = len(r.spikes[:p.nExcNeurons])

# avg_rates[t][i]: value returned by calc_rate_2 for neuron i, using only the
# spikes that clip_window keeps for analysis window t.
avg_rates = []
for win_start, win_end in analysis_intervals:
    avg_rates.append([
        calc_rate_2(clip_window(r.spikes[nrn], win_start, win_end), 1, 1, IntervalLength)
        for nrn in range(_num_exc)
    ])

# One histogram per window over 0..18, divided by the number of neurons so
# that the entries are fractions of neurons.
rate_hist_array = []
for window_rates in avg_rates:
    counts = histogram(window_rates, bins = numBins, range = (0,18), normed = False)[0]
    rate_hist_array.append(counts/float(len(window_rates)))
    
    

# ---- Figure 1, panels A and B ------------------------------------------
f = figure(1, figsize = (8,6), facecolor = 'w')
f.subplots_adjust(top= 0.95, left = 0.085, bottom = 0.11, right = 0.95, hspace = 0.46, wspace = 0.30)
clf()

plot_labels = ['300-360 sec', '600-660 sec', '900-960 sec', '1140-1200 sec' ]

# Panel A: distribution of normalised synaptic weights, one curve per window.
ax = subplot(2,2,1, projection = 'frameaxes')
for w_hist, line_color, line_label in zip(w_hist_array, plot_colors, plot_labels):
    n_bins = len(w_hist)
    # One x position per histogram bin, starting at 0.
    w_axis = arange(0, 1.0+0.9/n_bins, (1.0+1.0/n_bins)/n_bins)
    plot(w_axis, w_hist, linewidth = 2, color = line_color, label = line_label)

legend(prop = matplotlib.font_manager.FontProperties(size = 'xx-small'))

xlabel('synaptic weight $(w/w_{max})$')
ylabel('frac. synapses [\%]')
# Tick positions are fractions; the tick labels show them as percentages.
frac_ticks_a = arange(0,0.161,0.04)
yticks(frac_ticks_a, [ "%d" % (x) for x in frac_ticks_a*100 ] )
weight_ticks = arange(0,1.01,0.2)
xticks(weight_ticks, [ "%.1f" % (x) for x in weight_ticks ] )
ylim(0,0.17)

text(-0.2,1.04, 'A',fontsize = 'x-large', transform = ax.transAxes)

# Panel B: distribution of firing rates, one curve per window.
ax = subplot(2,2,2, projection = 'frameaxes')
for rate_hist, line_color, line_label in zip(rate_hist_array, plot_colors, plot_labels):
    n_bins = len(rate_hist)
    rate_axis = arange(0, 18.0+17.9/n_bins, (18.0+17.9/n_bins)/n_bins)
    plot(rate_axis, rate_hist, linewidth = 2, color = line_color, label = line_label)

ylabel('frac. neurons [\%]')
xlabel('firing rate [Hz]')
rate_ticks = arange(0,18,4)
xticks(rate_ticks, [ "%d" % (x) for x in rate_ticks ])
frac_ticks_b = arange(0, 0.251, 0.05)
yticks(frac_ticks_b, [ "%d" % (x) for x in frac_ticks_b*100 ])
ylim(0, 0.251)

text(-0.2, 1.04, 'B', fontsize = 'x-large', transform = ax.transAxes)

legend(prop = matplotlib.font_manager.FontProperties(size = 'xx-small'))

IPcontroller = {'host' : 'cluster1', 'engine_port' : 32100, 'rc_port' : 32101, 'task_port' : 32102}
IPcluster = Cluster(ClusterConfig(configFile= './clusterconf.py', controller=IPcontroller))
IPcluster.start(waitafter = 3.0)
IPcluster.connect()
rc = IPcluster.getRemoteControllerClient()

print "len(rc) is ", len(rc)
while len(rc) < 30:
    print "len(rc) is ", len(rc)
    time.sleep(1)
    
rc.execute('from pypcsimplus import *')
rc.execute('import pypcsimplus')    
rc.execute('import pypcsimplus.common')    

corrDT = 0.2e-3
leftCorrBound = -100e-3
rightCorrBound = 100e-3
numNrnPairs = 200
pairs = set([])
while len(pairs) < numNrnPairs:
    rnd_pair  = (random.randint(0,p.nExcNeurons-1),random.randint(0,p.nExcNeurons-1))
    if not rnd_pair[0] == rnd_pair[1]:
        pairs.add(rnd_pair)

i = 0
figure_letter = ['C', 'D']
for pl_i in (0,len(analysis_intervals)-1):
    print "processing cross_covariance for interval %d" % (pl_i)
    leftSimT = analysis_intervals[pl_i][0]
    rightSimT = analysis_intervals[pl_i][1]
    
    arg_list = []
    for n1, n2 in pairs:        
        arg_list.append([r.spikes[n1], r.spikes[n2], corrDT, (leftSimT, rightSimT), (leftCorrBound, rightCorrBound)])
                
    corr = rc.map(pypcsimplus.common.cross_covariance_spikes_1arg, arg_list)
    
    avg_corr = mean(vstack(tuple(corr)), 0)
    
    ax = subplot(2,2,i + 3, projection = 'frameaxes')     
    text(-0.2,1.04, figure_letter[i],fontsize = 'x-large', transform = ax.transAxes)
    vlines( arange(int(leftCorrBound/corrDT) * corrDT, int(rightCorrBound/corrDT) * corrDT + corrDT/100, corrDT), zeros(len(avg_corr)), avg_corr, linewidth = 1)
    xlim(leftCorrBound,rightCorrBound)
    xticks(arange(-100e-3,101e-3,50e-3), [  '%d' % (x) for x in arange(-100,101,50)])
    
    xlabel("$\\tau$ [ms]")
    ylabel("$C(\\tau)\quad (\\cdot 10^{-4})$")
    yticks(arange(-0.0005,0.00101,0.0005), [ "%d" % (x) for x in arange(-5,11,5)])
    
    ylim(-0.0007,0.00101)
    i += 1
    
IPcluster.stop()

figure_letter = ['D', 'F']
pl_array = [0,3]
    
savefig('correlations.eps')


# ---- Figure 2, panel B: firing rates of the exc_ou_spikes neurons --------
f = figure(2, figsize = (8,6), facecolor = 'w')
f.subplots_adjust(top= 0.95, left = 0.085, bottom = 0.11, right = 0.95, hspace = 0.46, wspace = 0.30)
clf()

plot_labels = ['300-360 sec', '600-660 sec', '900-960 sec', '1140-1200 sec' ]

ax = subplot(2,2,2, projection = 'frameaxes')

numBins = 20

# avg_rates[t][i]: value returned by calc_rate_2 for spike train i of
# exc_ou_spikes, restricted by clip_window to analysis window t.
avg_rates = []
for win_start, win_end in analysis_intervals:
    avg_rates.append([
        calc_rate_2(clip_window(spike_train, win_start, win_end), 1, 1, IntervalLength)
        for spike_train in exc_ou_spikes
    ])

# Histogram per window over 0..14, divided by the number of neurons.
rate_hist_array = []
for window_rates in avg_rates:
    counts = numpy.histogram(window_rates, bins = numBins, range = (0,14), normed = False)[0]
    rate_hist_array.append(counts/float(len(window_rates)))

for rate_hist, line_color, line_label in zip(rate_hist_array, plot_colors, plot_labels):
    n_bins = len(rate_hist)
    # One x position per histogram bin, starting at 0.
    rate_axis = arange(0, 14.0+13.9/n_bins, (14.0+13.9/n_bins)/n_bins)
    plot(rate_axis, rate_hist, linewidth = 2, color = line_color, label = line_label)

ylabel('frac. neurons [\%]')
xlabel('firing rate [Hz]')
rate_ticks = arange(0,15,4)
xticks(rate_ticks, [ "%d" % (x) for x in rate_ticks ])
# Tick positions are fractions; the tick labels show them as percentages.
frac_ticks = arange(0, 0.251, 0.05)
yticks(frac_ticks, [ "%d" % (x) for x in frac_ticks*100 ])
ylim(0, 0.251)

text(-0.2, 1.04, 'B', fontsize = 'x-large', transform = ax.transAxes)

legend(prop = matplotlib.font_manager.FontProperties(size = 'xx-small'))


IPcontroller = {'host' : 'cluster1', 'engine_port' : 32100, 'rc_port' : 32101, 'task_port' : 32102}
IPcluster = Cluster(ClusterConfig(configFile = './clusterconf.py', controller=IPcontroller))
IPcluster.start(waitafter = 3.0)
IPcluster.connect()
rc = IPcluster.getRemoteControllerClient()

print "len(rc) is ", len(rc)
while len(rc) < 30:
    print "len(rc) is ", len(rc)
    time.sleep(1)



rc.execute('import pypcsimplus')
rc.execute('from pypcsimplus import *')    
rc.execute('import pypcsimplus.common')    

corrDT = 0.2e-3
leftCorrBound = -100e-3
rightCorrBound = 100e-3
numNrnPairs = 200
pairs = set([])
while len(pairs) < numNrnPairs:
    rnd_pair  = (random.randint(0,len(exc_ou_spikes)-1),random.randint(0,len(exc_ou_spikes)-1))
    if not rnd_pair[0] == rnd_pair[1]:
        pairs.add(rnd_pair)

i = 0
figure_letter = ['C', 'D']
for pl_i in (0,len(analysis_intervals)-1):
    print "processing cross_covariance for interval %d" % (pl_i)
    leftSimT = analysis_intervals[pl_i][0]
    rightSimT = analysis_intervals[pl_i][1]
    
    arg_list = []
    for n1, n2 in pairs:        
        arg_list.append([exc_ou_spikes[n1], exc_ou_spikes[n2], corrDT, (leftSimT, rightSimT), (leftCorrBound, rightCorrBound)])
                
    corr = rc.map(pypcsimplus.common.cross_covariance_spikes_1arg, arg_list)
    
    avg_corr = mean(vstack(tuple(corr)), 0)
    
    ax = subplot(2,2,i + 3, projection = 'frameaxes')     
    text(-0.2,1.04, figure_letter[i],fontsize = 'x-large', transform = ax.transAxes)
    vlines( arange(int(leftCorrBound/corrDT) * corrDT, int(rightCorrBound/corrDT) * corrDT + corrDT/100, corrDT), zeros(len(avg_corr)), avg_corr, linewidth = 1)
    xlim(leftCorrBound,rightCorrBound)
    xticks(arange(-100e-3,101e-3,50e-3), [  '%d' % (x) for x in arange(-100,101,50)])
    
    xlabel("$\\tau$ [ms]")
    ylabel("$C(\\tau)\quad (\\cdot 10^{-4})$")
    yticks(arange(-0.0005,0.00101,0.0005), [ "%d" % (x) for x in arange(-5,11,5)])
    
    ylim(-0.0005,0.00071)
    i += 1
    
IPcluster.stop()

figure_letter = ['D', 'F']
pl_array = [0,3]

savefig('dynamicsHighOUNeurons.eps')


# ---- Figure 3, panel B: firing rates of the exc_other_spikes neurons -----
f = figure(3, figsize = (8,6), facecolor = 'w')
f.subplots_adjust(top= 0.95, left = 0.085, bottom = 0.11, right = 0.95, hspace = 0.46, wspace = 0.30)
clf()

plot_labels = ['300-360 sec', '600-660 sec', '900-960 sec', '1140-1200 sec' ]

ax = subplot(2,2,2, projection = 'frameaxes')

numBins = 20

# avg_rates[t][i]: value returned by calc_rate_2 for spike train i of
# exc_other_spikes, restricted by clip_window to analysis window t.
avg_rates = []
for win_start, win_end in analysis_intervals:
    avg_rates.append([
        calc_rate_2(clip_window(spike_train, win_start, win_end), 1, 1, IntervalLength)
        for spike_train in exc_other_spikes
    ])

# Histogram per window over 0..18, divided by the number of neurons.
rate_hist_array = []
for window_rates in avg_rates:
    counts = numpy.histogram(window_rates, bins = numBins, range = (0,18), normed = False)[0]
    rate_hist_array.append(counts/float(len(window_rates)))

for rate_hist, line_color, line_label in zip(rate_hist_array, plot_colors, plot_labels):
    n_bins = len(rate_hist)
    # One x position per histogram bin, starting at 0.
    rate_axis = arange(0, 18.0+17.9/n_bins, (18.0+17.9/n_bins)/n_bins)
    plot(rate_axis, rate_hist, linewidth = 2, color = line_color, label = line_label)

ylabel('frac. neurons [\%]')
xlabel('firing rate [Hz]')
rate_ticks = arange(0,18,4)
xticks(rate_ticks, [ "%d" % (x) for x in rate_ticks ])
# Tick positions are fractions; the tick labels show them as percentages.
frac_ticks = arange(0, 0.251, 0.05)
yticks(frac_ticks, [ "%d" % (x) for x in frac_ticks*100 ])
ylim(0, 0.151)

text(-0.2, 1.04, 'B', fontsize = 'x-large', transform = ax.transAxes)

legend(prop = matplotlib.font_manager.FontProperties(size = 'xx-small'))

IPcontroller = {'host' : 'cluster1', 'engine_port' : 32100, 'rc_port' : 32101, 'task_port' : 32102}
IPcluster = Cluster(ClusterConfig(configFile = './clusterconf.py', controller=IPcontroller))
IPcluster.start(waitafter = 3.0)
IPcluster.connect()
rc = IPcluster.getRemoteControllerClient()

print "len(rc) is ", len(rc)
while len(rc) < 30:
    print "len(rc) is ", len(rc)
    time.sleep(1)


rc.execute('import pypcsimplus')
rc.execute('from pypcsimplus import *')    
rc.execute('import pypcsimplus.common')    


corrDT = 0.2e-3
leftCorrBound = -100e-3
rightCorrBound = 100e-3
numNrnPairs = 200
pairs = set([])
while len(pairs) < numNrnPairs:
    rnd_pair  = (random.randint(0,len(exc_other_spikes)-1),random.randint(0,len(exc_other_spikes)-1))
    if not rnd_pair[0] == rnd_pair[1]:
        pairs.add(rnd_pair)

i = 0
figure_letter = ['C', 'D']
for pl_i in (0,len(analysis_intervals)-1):        
    leftSimT = analysis_intervals[pl_i][0]
    rightSimT = analysis_intervals[pl_i][1]
    
    arg_list = []
    for n1, n2 in pairs:        
        arg_list.append([exc_other_spikes[n1], exc_other_spikes[n2], corrDT, (leftSimT, rightSimT), (leftCorrBound, rightCorrBound)])
                
    corr = rc.map(pypcsimplus.common.cross_covariance_spikes_1arg, arg_list)
    
    avg_corr = mean(vstack(tuple(corr)), 0)
    
    ax = subplot(2,2,i + 3, projection = 'frameaxes')     
    text(-0.2,1.04, figure_letter[i],fontsize = 'x-large', transform = ax.transAxes)
    vlines( arange(int(leftCorrBound/corrDT) * corrDT, int(rightCorrBound/corrDT) * corrDT + corrDT/100, corrDT), zeros(len(avg_corr)), avg_corr, linewidth = 1)
    xlim(leftCorrBound,rightCorrBound)
    xticks(arange(-100e-3,101e-3,50e-3), [  '%d' % (x) for x in arange(-100,101,50)])
    
    xlabel("$\\tau$ [ms]")
    ylabel("$C(\\tau)\quad (\\cdot 10^{-4})$")
    yticks(arange(-0.0005,0.00161,0.0005), [ "%d" % (x) for x in arange(-5,17,5)])
    
    ylim(-0.00101,0.00161)
    i += 1
    
    
IPcluster.stop()



figure_letter = ['D', 'F']
pl_array = [0,3]


savefig('dynamicsLowOUNeurons.eps')