import sys
import os
import re
sys.path.append('../packages')
from numpy import *
import random, getopt
from datetime import datetime
from math import *
from pylab import *
from tables import *
from math import exp 
from BeforeAfterExperiment import *
from frame import FrameAxes

from matplotlib import rc
# Render all figure text through LaTeX, using a sans-serif face.
rc('font', family='sans-serif', **{'sans-serif': ['Computer Modern Sans serif']})
rc('text', usetex=True)

# Network instance handed to Recordings(...) when the per-run spike
# recordings are aggregated in multi_run_and_save_beforeAfter.
dummy_net = SingleThreadNetwork()

 

def display_spike_trains(f, idx, XBeforeMin, XBeforeMax, ax):
    """Draw a four-row spike raster for run ``idx`` stored in ``f``.

    Rows from top to bottom (y = 7, 5, 3, 1): response before learning,
    target train, realizable part of the target, response after learning.
    Spike times are shifted by ``XBeforeMin``. The x range is
    ``[0, XBeforeMax - XBeforeMin + 0.04]``. Row captions are placed to the
    left of ``ax`` in axes coordinates.
    """
    recs = constructRecordingsFromH5File(f)

    # (spike-train collection, row height), drawn in this order.
    raster_rows = [
        (recs.before_learning_spikes, 7),
        (recs.target_nrn_spikes, 5),
        (recs.realiz_target_nrn_spikes, 3),
        (recs.after_learning_spikes, 1),
    ]
    for trains, row_y in raster_rows:
        train = trains[idx]
        count = len(train)
        # visible=False is passed to errorbar; the intent appears to be hiding
        # the data line so that only the vertical bars show as spike ticks.
        errorbar(train - XBeforeMin, row_y * ones(count), 0.8 * ones(count),
                 capsize=0, visible=False, color='k')

    xlim(0, XBeforeMax - XBeforeMin + 0.04)
    xlabel("time [sec]")
    ylim(0, 8)
    yticks([])

    # Row captions: (vertical position in axes coordinates, caption text).
    row_captions = [
        (0.871, 'before learning'),
        (0.715, 'target $S^*$'),
        (0.625, '(= rewarded'),
        (0.535, 'spike times)'),
        (0.385, 'realizable part'),
        (0.298, 'of target $S^*$'),
        (0.125, 'after learning'),
    ]
    for y_frac, caption in row_captions:
        text(-0.27, y_frac, caption, horizontalalignment='center',
             verticalalignment='center', fontsize=13, transform=ax.transAxes)


def norm_vec(v):
    """Return the Euclidean length of the vector ``v``."""
    squared_length = dot(v, v)
    return sqrt(squared_length)
 
def rectangle_kernel(x, width):
    """Rectangular (box) kernel with half-width ``width``.

    Returns 1 when ``x`` lies strictly inside ``(-width, width)`` and 0
    otherwise. ``x`` is expected to be a scalar.
    """
    # A chained comparison requires BOTH bounds. The former test
    # `x < width or x > -width` held for every x when width > 0, which made
    # the kernel constant 1.
    if -width < x < width:
        return 1
    return 0
    
def calculate_corr_coeff(spikes, target_spikes, start, end, sigma = 2e-3, dt = 1e-3):
    """Return the normalised correlation of two spike trains over [start, end].

    Both trains are convolved with ``rectangle_kernel`` of half-width
    ``sigma``, using ``convolve_spikes`` with step ``dt`` and kernel support
    ``[-10*sigma, 10*sigma]``. The result is the cosine similarity of the two
    convolved vectors.

    sigma, dt : kernel half-width and sampling step. The defaults reproduce
        the previously hard-coded values.

    NOTE(review): if either convolved vector is all zeros (presumably a
    train with no spikes in the window), the division yields nan with a
    numpy warning.
    """
    kernel = lambda x: rectangle_kernel(x, sigma)
    support = sigma * 10

    kernel_spikes = convolve_spikes(spikes, kernel, dt, start, end, -support, support)
    target_kernel = convolve_spikes(target_spikes, kernel, dt, start, end, -support, support)

    return dot(kernel_spikes, target_kernel) / (norm_vec(kernel_spikes) * norm_vec(target_kernel))

def generate_spike_corr(h5file, XBeforeMin, XBeforeMax):
    """Correlate each recorded response with the realizable target train.

    For every entry of ``before_learning_spikes`` the response and the
    matching realizable target are clipped to [XBeforeMin, XBeforeMax] and
    correlated. A final value for the after-learning response of entry 0
    is appended. Returns a list of ``len(before_learning_spikes) + 1`` values.
    """
    recs = constructRecordingsFromH5File(h5file)
    window_len = XBeforeMax - XBeforeMin

    def windowed(train):
        # clip_window with shift=True; presumably re-references times to XBeforeMin.
        return clip_window(train, XBeforeMin, XBeforeMax, shift = True)

    corr = []
    for i, before in enumerate(recs.before_learning_spikes):
        response = windowed(before)
        target = windowed(recs.realiz_target_nrn_spikes[i])
        corr.append(calculate_corr_coeff(response, target, 0, window_len))

    final_response = windowed(recs.after_learning_spikes[0])
    final_target = windowed(recs.realiz_target_nrn_spikes[0])
    corr.append(calculate_corr_coeff(final_response, final_target, 0, window_len))
    return corr
    
def multi_run_and_save_beforeAfter(results_file, h5file, numRuns, Tsim, numSamples = 200):
    """Re-run a 'beforeAfter' experiment at ``numRuns`` points of a stored run.

    Run i uses ``sampleIdx = i * (numSamples // numRuns)`` together with
    ``results_file`` as the biofeed model parameters, and ``Tsim`` as the
    experiment duration. Each run's data file is read and then deleted. The
    before/after/target/realizable-target spike trains of all runs are
    collected into one Recordings object, which is saved to ``h5file``.

    numSamples : number of sample indices to spread the runs over. The
        default 200 is the previously hard-coded value. NOTE(review):
        presumably the number of weight snapshots in results_file; confirm.

    Returns (aggregated recordings, biofeed recordings of the first run).
    Raises ValueError if numRuns < 1 (previously an unbound-name error).
    """
    if numRuns < 1:
        raise ValueError("numRuns must be at least 1, got %r" % (numRuns,))

    new_rec = Recordings(dummy_net)
    new_rec.before_learning_spikes = []
    new_rec.after_learning_spikes = []
    new_rec.target_nrn_spikes = []
    new_rec.realiz_target_nrn_spikes = []

    # Floor division matches the former int(200/numRuns) for positive ints.
    stride = numSamples // numRuns
    first_run_rec = None
    for i in range(numRuns):
        sampleIdx = i * stride
        exper = BeforeAfterExperiment('beforeAfter', experParams = {"Tsim" : Tsim}, modelParams = {"biofeed": {"sampleIdx":sampleIdx, "h5filename" : results_file}})
        exper.run("longrun")
        r = constructRecordingsFromH5File(exper.data_filename).biofeed
        # The per-run data file is only needed to extract the spike trains.
        os.remove(exper.data_filename)
        if i == 0:
            first_run_rec = r
        new_rec.before_learning_spikes.append(array(r.before_learning_nrn_spikes))
        new_rec.after_learning_spikes.append(array(r.after_learning_nrn_spikes))
        new_rec.realiz_target_nrn_spikes.append(array(r.realiz_target_nrn_spikes))
        new_rec.target_nrn_spikes.append(array(r.target_nrn_spikes))

    new_rec.src_filename = results_file
    new_rec.saveInOneH5File(h5file)
    return new_rec, first_run_rec
    
def plot_spike_corr(corr, p):
    """Plot spike correlation versus learning time.

    corr : sequence as built by generate_spike_corr. Entries 0..n-2 are taken
        at evenly spaced points during learning, and the last entry is after
        learning.
    p : parameter object; only ``p.experiment.Tsim`` (seconds, since it is
        divided by 60 for a minutes axis) is read.
    """
    ep = p.experiment
    plot(corr, 'k-')
    plot(corr, 'kd', markersize = 5)

    n_points = len(corr)
    last = max(n_points - 1, 1)  # avoid division by zero for a single point
    # One label per tick: tick i corresponds to i/last of the simulation and
    # the final tick to its end. The old arange produced one label fewer than
    # ticks (an error in recent matplotlib) and a zero step for Tsim < 360.
    labels = ["%d" % (i * ep.Tsim / 60.0 / last) for i in range(n_points)]
    xticks(arange(0, n_points), labels)
    xlim(0, last + 0.1)
    xlabel('time [min]')
    ylabel('spike correlation')
    yticks(arange(0.50, 0.95, 0.1), ['%.2f' % (x,) for x in arange(0.50, 0.95, 0.1)])
    ylim(0.50, 0.91)
    
def plot_weightvec_angle(p, r):
    """Plot the angle between the weight vector and the target weight vector.

    p : parameter object. Reads p.experiment (Tsim, DTsim) and p.biofeed
        (numStrongTargetSynapses, numWeakTargetSynapses, Wmax, samplingTime).
    r : recordings whose ``r.weights`` array is indexed [synapse, snapshot],
        one column per snapshot.
    """
    ep = p.experiment
    bp = p.biofeed

    # Target: strong synapses at Wmax followed by weak synapses at 0.
    target_w = hstack((ones(bp.numStrongTargetSynapses) * bp.Wmax,
                       zeros(bp.numWeakTargetSynapses)))
    unit_target_w = target_w / sqrt(inner(target_w, target_w))

    # Normalise each column (snapshot) to unit length. A snapshot whose
    # weights are all zero yields nan here.
    weights = r.weights
    unit_weights = weights / sqrt((weights * weights).sum(axis = 0))

    # Rounding can push a cosine marginally outside [-1, 1], where arccos is nan.
    angle = arccos(dot(unit_target_w, unit_weights).clip(-1.0, 1.0))

    # Build the time axis from the sample count. A float-step arange can
    # produce one element too many and make plot() reject the data.
    dt = ep.DTsim * bp.samplingTime
    plot(arange(len(angle)) * dt, angle, 'k-')

    xlim(0, ep.Tsim + 1)
    # Derive labels from the tick positions so the counts always match.
    ticks = arange(0, ep.Tsim + 1, ep.Tsim / 4.0)
    xticks(ticks, ["%d" % (t / 60.0) for t in ticks])
    xlabel('time [min]')
    ylabel('angular error [rad]')
    yticks(arange(0.0, 1.01, 0.2), ["%.1f" % x for x in arange(0.0, 1.01, 0.2)])
    ylim(0.0, 0.9)
    
def plot_weight_evolution(p, r, ax):
    """Show ``r.weights`` as an image (synapses x snapshots) with a colour legend.

    The area of ``ax`` is split into two new axes: the weight image on the
    left 80% and a narrow colour legend on the right. ``ax`` itself keeps no
    ticks. Reads p.experiment.Tsim (seconds) for the time labels.
    """
    ep = p.experiment

    box()
    xticks([])
    yticks([])

    ax_length = 0.8   # fraction of ax's width used by the weight image
    ax_gap = 0.08     # gap between image and legend, as a fraction of the width
    leg_width = 0.07  # fraction of ax's width used by the legend

    # Convert [x0, y0, x1, y1] to [left, bottom, width, height].
    ax_pos = ax.get_position().get_points().flatten()
    ax_pos[2] -= ax_pos[0]
    ax_pos[3] -= ax_pos[1]

    leg_ax_pos = list(ax_pos)
    leg_ax_pos[0] = leg_ax_pos[0] + leg_ax_pos[2] * (ax_length + ax_gap)
    leg_ax_pos[2] = leg_width * leg_ax_pos[2]
    leg_ax = axes(leg_ax_pos)

    # Select the jet colormap once, before either image is drawn, so that
    # the legend and the weight image share it (replaces three jet() calls).
    jet()

    # Vertical gradient 1 -> 0 used as the colour legend.
    arr = arange(1, 0, -0.01)
    arr.resize(100, 1)
    imshow(arr, aspect = 0.098)
    xticks([])
    yticks(arange(0, 101, 50), ['0', '0.5', '1'])
    text(1.17, 0.5, '$w/w_{max}$', horizontalalignment = 'center', verticalalignment = 'center', rotation = 90, transform = ax.transAxes)
    leg_ax.yaxis.tick_right()

    im_ax_pos = list(ax_pos)
    im_ax_pos[2] = im_ax_pos[2] * ax_length
    axes(im_ax_pos)
    imshow(r.weights, aspect = 1.7, interpolation = 'nearest')

    # NOTE(review): row 0 is drawn at the top yet labelled '100'; confirm
    # that the reversed synapse labels are intended.
    yticks(arange(0, 101, 50), ['%d' % (x,) for x in arange(100, -1, -50)])
    # Five ticks at columns 0..200 labelled 0..Tsim minutes in quarters.
    # Exactly one label per tick; the old arange could yield a different
    # count (or a zero step) for short Tsim. Assumes ~200 snapshot columns.
    xticks(arange(0, 201, 50), ["%d" % (ep.Tsim * k / 4.0 / 60.0) for k in range(5)])
    xlabel('time [min]')
    ylabel(r'synapse \#')
    
    

def plot_multi_run_wstar(directory):
    """Overlay average strong/weak weight trajectories of every run in ``directory``.

    Every file whose name matches ``biofeed*.h5`` is loaded in sorted order.
    For each file, the mean weight of the first ``numStrongTargetSynapses``
    synapses is drawn as a solid line and the mean of the remaining synapses
    as a dashed line, one colour per file (colours cycle after five files).
    Axis limits use the parameters of the last file loaded.

    Raises ValueError if no matching file is found.
    """
    # Raw string with the extension anchored, so e.g. 'x.h5.bak' is skipped.
    pattern = re.compile(r'biofeed.*\.h5$')
    files = sorted(x for x in os.listdir(directory) if pattern.match(x))
    print(files)
    if not files:
        raise ValueError("no biofeed*.h5 files found in %r" % (directory,))

    plot_colors = ['b', 'r', 'g', 'm', 'k']
    for col_n, fname in enumerate(files):
        h5file = openFile(os.path.join(directory, fname), mode = "r")
        try:
            all_p = constructParametersFromH5File(h5file)
            all_r = constructRecordingsFromH5File(h5file)
        finally:
            h5file.close()

        bp = all_p.biofeed
        ep = all_p.experiment
        r = all_r.biofeed

        n_strong = bp.numStrongTargetSynapses
        strong_syn_avg = average(r.weights[:n_strong], 0)
        weak_syn_avg = average(r.weights[n_strong:], 0)

        # Index-based time axis: one time point per weight snapshot.
        dt = ep.DTsim * bp.samplingTime
        color = plot_colors[col_n % len(plot_colors)]
        plot(arange(len(strong_syn_avg)) * dt, strong_syn_avg, color + '-')
        plot(arange(len(weak_syn_avg)) * dt, weak_syn_avg, color + '--')

    xlim(0, ep.Tsim + 1)
    # Derive labels from the tick positions so the counts always match.
    ticks = arange(0, ep.Tsim + 1, ep.Tsim / 4.0)
    xticks(ticks, ["%d" % (t / 60.0) for t in ticks])
    xlabel('time [min]')
    ylim(0, bp.Wmax)
    yticks(arange(0, bp.Wmax * 1.001, bp.Wmax / 5.0), ["%.1f" % i for i in arange(0, 1.01, 0.2)])
    ylabel('avg. weights $(w/w_{max})$')
    
    
def plot_weight_change_fig(r, p):
    """Plot each synapse's initial (x) and final (o) weight against the target.

    r : recordings; ``r.weights`` is indexed [synapse][snapshot].
    p : biofeed parameters; reads Wmax and numStrongTargetSynapses.

    The target (dotted) is Wmax for the first numStrongTargetSynapses
    synapses and 0 for the rest. The dashed line marks Wmax/2. The final
    weight is the mean of the last 10 snapshots.
    """
    n_syn = len(r.weights)
    n_strong = p.numStrongTargetSynapses
    syn_idx = arange(n_syn)

    initial_weights = [w[0] for w in r.weights]
    # w[-10:] covers the last 10 snapshots; the former w[-10:-1] dropped the final one.
    last_weights = [mean(w[-10:]) for w in r.weights]

    # (The former failed_strong/failed_weak counts were unused and compared a
    # Python list with a float, which raises TypeError on Python 3.)
    target_w = hstack((p.Wmax * ones(n_strong), zeros(n_syn - n_strong)))
    plot(syn_idx, target_w, 'k:')
    plot(syn_idx, p.Wmax / 2.0 * ones(n_syn), 'k--')
    plot(syn_idx, initial_weights, 'k x', markersize = 3.4)
    plot(syn_idx, last_weights, 'k o', markersize = 3.4)
    vlines(syn_idx, initial_weights, last_weights)
    xlabel(r'synapse \#')
    ylim(0, p.Wmax)
    yticks(arange(0, p.Wmax * 1.001, p.Wmax / 5.0), ["%.1f" % i for i in arange(0, 1.01, 0.2)])
    ylabel('syn. weight $(w/w_{max})$')
    xticks(arange(0, n_syn + 1, 50), ['%d' % (x,) for x in arange(0, n_syn + 1, 50)])
    
   
# Script entry point: produces (or, in 'just_corr' mode, locates) the
# spike-correlation recordings and draws the six-panel figure A-F, which is
# saved as "wstar_static_current.eps".
if __name__ == "__main__":
    # NOTE(review): the second assignment overrides the first, so only the
    # 'complete' branch below is reachable unless this line is edited.
    mode = "just_corr"
    mode = 'complete'
    
    # Window [XBeforeMin, XBeforeMax] used by generate_spike_corr (panel C);
    # XBeforeMax is also passed as Tsim to multi_run_and_save_beforeAfter.
    XBeforeMin, XBeforeMax = (5,35)   
    if mode == 'complete':
        # argv[1]: simulation directory; otherwise last_created_dir('biofeed.*')
        # (presumably the most recently created match). The simulation file is
        # last_file('biofeed.*er18.*', sim_dir) inside it.
        if len(sys.argv) > 1:
            sim_dir = sys.argv[1]
        else:
            sim_dir = last_created_dir('biofeed.*')
        sim_file = os.path.join(sim_dir, last_file('biofeed.*er18.*', sim_dir))
        
        print " loading simulation filename : ", sim_file
        
        # argv[2]: optional name passed to open_experiment_h5file (default 'noname').
        output_name = 'noname'
        if len(sys.argv) > 2:
            output_name = sys.argv[2]
        spikes_h5file = open_experiment_h5file("spikes_corr", output_name)
        
        # Six re-runs at evenly spaced points of the stored simulation, saved
        # into spikes_h5file. The returned recordings are not used below.
        new_rec, first_run_rec = multi_run_and_save_beforeAfter(sim_file, spikes_h5file, 6, XBeforeMax)        
    else:
        # NOTE(review): in this branch sys.argv[1] is read both as sim_dir and
        # as spikes_h5file, and the sim_file built from sim_dir is replaced by
        # src_filename below. Confirm the intended command-line layout.
        if len(sys.argv) > 1:
            sim_dir = sys.argv[1]
        else:
            sim_dir = last_created_dir('biofeed.*')
        sim_file = os.path.join(sim_dir, last_file('biofeed.*', sim_dir))            
        if len(sys.argv) > 1:
            spikes_h5file = sys.argv[1]
        else:
            spikes_h5file = last_file('spikes_corr.*\.h5$')        
        print " loading h5 filename : ", spikes_h5file
        # The simulation file is taken from the src_filename recorded in the spikes file.
        sim_file = constructRecordingsFromH5File(spikes_h5file).src_filename

        print "loading sim h5 filename : " , sim_file
    
    
        
    # Recordings and parameters of the original simulation.
    sim_r = constructRecordingsFromH5File(sim_file).biofeed
    sim_p = constructParametersFromH5File(sim_file)
    
    pp = sim_p.biofeed
    
    # NOTE(review): both lines use pp.stdpApos; the A_minus line pairs it with
    # pp.KappaAneg. Confirm whether a negative-amplitude STDP parameter was
    # intended there (as written, the printed ratio is KappaApos / KappaAneg).
    A_plus_kappa_theory = pp.DAStdpRate * pp.stdpApos * pp.KappaApos / (0.01 * pp.Wmax)
    A_minus_kappa_theory = pp.DAStdpRate * pp.stdpApos * pp.KappaAneg / (0.01 * pp.Wmax)
    
    print "A_plus_kappa_theory = ", A_plus_kappa_theory
    print "A_minus_kappa_theory = ", A_minus_kappa_theory
    print " ratio of A_plus and A_minus kappas = ", A_plus_kappa_theory / A_minus_kappa_theory
    
        
    # 3 x 2 grid of panels; the drawing functions act on the current axes.
    f = figure(1,figsize=(8,9), facecolor = 'w')
    
    f.subplots_adjust(top= 0.93, left = 0.11, bottom = 0.06, right = 0.93, hspace = 0.55, wspace = 0.55)
    clf()
    
    print sim_p
    
    # Panel A: average strong/weak weights of all biofeed*.h5 runs in sim_dir.
    # 'frameaxes' is a projection name, presumably registered by the frame module.
    ax = subplot(3, 2, 1, projection = 'frameaxes')
    text(-0.25, 1.13, 'A', fontsize = 'x-large', transform = ax.transAxes )
    plot_multi_run_wstar(sim_dir)
    
    
    
    # Panel B: spike raster of run 0 between t = 19.1 and t = 21.6.
    ax = subplot(3, 2, 2, projection = 'frameaxes')
    text(-0.25, 1.13,  'B', fontsize = 'x-large', transform = ax.transAxes )
    start = 19.1 
    end = 21.6
    spike_idx = 0
    display_spike_trains(spikes_h5file, spike_idx, start, end, ax)

    
    # Panel C: spike correlation with the realizable target over learning.
    ax = subplot(3, 2, 3, projection = 'frameaxes')
    text(-0.25, 1.13,  'C', fontsize = 'x-large', transform = ax.transAxes )
    corr = generate_spike_corr(spikes_h5file, XBeforeMin, XBeforeMax)
    plot_spike_corr(corr, sim_p)
    
    
    
    # Panel D: angle between the weight vector and the target weight vector.
    ax = subplot(3, 2, 4, projection = 'frameaxes')
    text(-0.25, 1.13,  'D', fontsize = 'x-large', transform = ax.transAxes )
    plot_weightvec_angle(sim_p, sim_r)
    
    
    
    # Panel E: initial vs final weight of every synapse.
    ax = subplot(3, 2, 5)
    text(-0.25, 1.12,  'E', fontsize = 'x-large', transform = ax.transAxes )
    plot_weight_change_fig(sim_r, sim_p.biofeed)
    
    
    # Panel F: weight matrix over time with a colour legend.
    ax = subplot(3, 2, 6)
    text(-0.25, 1.12,  'F', fontsize = 'x-large', transform = ax.transAxes )    
    plot_weight_evolution(sim_p, sim_r, ax)    


    savefig("wstar_static_current.eps")