# -*- coding: utf-8 -*-
"""
Created on Sun Mar  6 18:22:04 2011

@author: -
"""
import os
import numpy
from matplotlib import pyplot
from neuronpy.graphics import spikeplot
from bulbspikes import *
from neuronpy.util import spiketrain
from params import sim_var

# Simulation output and analysis input files are looked up in the parent of
# the current working directory, expressed as a relative path.
homedir = os.path.relpath('..')
analysis_path = homedir

def format_axes(ax, dt=1, ylim=(0.,4.)):
    """Relabel the x axis of ``ax`` as a symmetric lag scale.

    Five evenly spaced ticks are placed across the current x limits and
    labelled from ``-h`` to ``+h``, where ``h`` is half the x range
    multiplied by ``dt`` and truncated to an integer.  The y limits are set
    to ``ylim`` and both axis labels are set.
    """
    lower, upper = ax.get_xlim()
    half_span = int((upper*dt - lower*dt)/2.)

    tick_positions = numpy.linspace(lower, upper, 5)
    # Labels are truncated to integers, matching the tick count above.
    tick_labels = numpy.linspace(-half_span, half_span, 5).astype(int)

    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels)
    ax.set_xlabel('lag (ms)')
    ax.set_ylim(ylim)
    ax.set_ylabel('Synchronization magnitude')

def draw_cell(cellid, ax, color='black'):
    """Draw one schematic cell on ``ax`` and apply the schematic axis format.

    The cell is centred at ``x = 10 + 20*cellid`` on an x grid of the
    integers 0..100.  It consists of a log-shaped lateral dendrite profile,
    a filled soma where that profile is below 1, a vertical primary
    dendrite from y=0 to y=8 and a glomerulus marker at y=9.

    cellid -- integer index of the cell; sets the horizontal position.
    ax     -- matplotlib axes to draw on.
    color  -- colour used for every element of the cell.
    """
    xloc = 10+cellid*20
    xs = numpy.arange(101)
    # Lateral dends
    yvec = numpy.log(numpy.abs(xs - xloc) + 1)
    ax.plot(xs, yvec, color=color)
    # Soma
    # The comparison is passed directly as the fill mask.  Taking ``.mask``
    # from ``numpy.ma.masked_where`` gives the same array when some element
    # is below 1, but it can collapse to a scalar when none is (a cell
    # centred outside 0..100), which ``fill_between`` cannot use.
    ax.fill_between(xs, numpy.ones(len(xs)), yvec, where=yvec < 1.,
        color=color, linewidth=0.)
    # Glom
    ax.plot([xloc], [9], color=color, marker='o', markersize=10, markerfacecolor='white', markeredgecolor=color)
    ax.plot([xloc], [9], color=color, marker='o', markersize=9, alpha=0.25)
    ax.plot([xloc], [9], color=color, marker='1', markersize=7, markeredgewidth=2)
    # Primary dendrite
    ax.plot([xloc, xloc], [0,8], color=color, linewidth=2)
    format_schematic_axis(ax)
    
def draw_weights(cellids, ax, color='black',scale=1.):
    """Draw granule cells and their connections for the given cells.

    Weights are loaded through ``synweightsnapshot.SynWeightSnapshot`` from
    the file named by ``sim_var['wt_input_file']`` in
    ``sim_var['weight_dir']`` under ``homedir``.  For every cell in
    ``cellids`` the slice ``m2g[cellid, :, 0]`` is normalised by its maximum
    (presumably mitral-to-granule weights -- confirm against
    ``synweightsnapshot``); each index whose normalised weight exceeds
    0.0001 gets a dot on the cell's dendrite profile, a triangle below the
    axis and a line joining the two.

    cellids -- iterable of cell indices.
    ax      -- matplotlib axes to draw on.
    color   -- colour of all markers and lines.
    scale   -- multiplier applied to the marker sizes.
    """
    import synweightsnapshot
    snapshot = synweightsnapshot.SynWeightSnapshot(
            nummit=sim_var['num_mitral'],
            numgran=sim_var['num_granule'])

    raw_weights = snapshot.read_file(
            sim_var['wt_input_file'],
            os.path.join(homedir, sim_var['weight_dir']))
    snapshot.parse_data(raw_weights)

    for cellid in cellids:
        # Same horizontal placement rule as the schematic cells.
        cell_x = 10+cellid*20
        normed = snapshot.m2g[cellid,:,0]
        normed = normed/numpy.max(normed)

        for gran_x, weight in enumerate(normed):
            # Written as a negated ">" so NaN weights are skipped too.
            if not weight > 0.0001:
                continue
            dend_y = numpy.log(numpy.abs(gran_x - cell_x) + 1)
            # Alternate between two rows so neighbouring markers do not
            # sit at the same height.
            gran_y = -3.5+((gran_x%2)*1.5)
            ax.plot([gran_x],[dend_y], marker='o', markerfacecolor=color, markersize=4.*scale, markeredgecolor=color)
            ax.plot([gran_x,gran_x],[dend_y, gran_y], color=color)
            ax.plot([gran_x],[gran_y], marker='^', markerfacecolor=color, markersize=6.*scale, markeredgecolor=color)
    format_schematic_axis(ax)
    
def format_schematic_axis(ax):
    """Apply the fixed layout of the schematic panel to ``ax``.

    The x axis spans 0..100 with ticks at 10, 30, 50, 70 and 90 that are
    labelled with ten times their position ('distance in microns').  The y
    axis spans -5..11 and carries no ticks.  Only the bottom spine stays
    visible.
    """
    tick_positions = [10,30,50,70,90]

    # Horizontal axis
    ax.set_xlim((0,100))
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(numpy.asarray(tick_positions)*10)
    ax.set_xlabel('distance in microns')

    # Vertical axis: fixed range, no ticks
    ax.set_ylim((-5,11))
    ax.set_yticks([])

    # Frame: hide everything except the bottom spine
    for side in ('left', 'right', 'top'):
        ax.spines[side].set_color('none')
    ax.spines['bottom'].set_color('black')
    ax.xaxis.set_ticks_position('bottom')


def _read_stim_events(filename, num_cells=5):
    """Group the rows of a two-column text file by their first column.

    ``filename`` is looked up in ``analysis_path``.  Each row is read as
    ``(index, value)``: ``value`` is appended to the list at position
    ``int(index)`` of the result, in file order.

    Returns a list of ``num_cells`` lists.  Raises ``IndexError`` when a
    row's index is not below ``num_cells``.
    """
    # ndmin=2 keeps a file holding a single row two-dimensional, so the
    # loop below always receives whole rows rather than bare scalars.
    table = numpy.loadtxt(os.path.join(analysis_path, filename), ndmin=2)
    data = [[] for _ in range(num_cells)]
    for row in table:
        data[int(row[0])].append(row[1])
    return data


def read_weightevents(num_cells=5):
    """Return the values of 'stimweightevents.txt' as one list per index.

    See ``_read_stim_events`` for the file layout; ``num_cells`` is the
    number of lists returned (5 by default).
    """
    return _read_stim_events('stimweightevents.txt', num_cells)


def read_delayevents(num_cells=5):
    """Return the values of 'stimdelayevents.txt' as one list per index.

    See ``_read_stim_events`` for the file layout; ``num_cells`` is the
    number of lists returned (5 by default).
    """
    return _read_stim_events('stimdelayevents.txt', num_cells)


def _add_sniff_epscs(trace, kernel, breath_events, delays, weights, tstop, dt):
    """Add one delayed, weighted copy of ``kernel`` to ``trace`` per sniff.

    ``trace`` is modified in place.  Event ``k`` starts at sample
    ``int((breath_events[k] + delays[k])/dt)`` and is scaled by
    ``weights[k]``.  Processing stops at the first breath time that is not
    below ``tstop``.
    """
    for k, breath in enumerate(breath_events):
        if breath >= tstop:
            break
        onset = int((breath + delays[k])/dt)
        # Size the kernel from the destination slice: ``kernel[:-onset]``
        # is empty for ``onset == 0`` and would then fail to broadcast.
        tail = trace[onset:]    # view, so += below writes into ``trace``
        tail += kernel[:len(tail)]*weights[k]


def raster(pair=[0,4], cluster_width=5, fi=.005, xlim=(1000,2000)):
    """Build the three-panel figure and save it as 'fig1.pdf'.

    Panel A is the schematic of the two cells in ``pair`` (red and blue)
    with their granule-cell weights.  Panel B shows, for each of the two
    cells, the sum of double-exponential events triggered by the sniff
    times ('EPSC onto tuft').  Panel C is the spike raster of the two cells
    and of the granule cells around each of them; spikes of the second cell
    flagged by ``spiketrain.get_sync_traits`` are marked with an asterisk.

    Inputs are 'spikeout.spk' and 'breathevents.txt' in ``homedir`` plus the
    files read by ``read_weightevents`` and ``read_delayevents``.  The
    figure is written to ``analysis_path``.

    pair          -- indices of the two cells to compare (never modified).
    cluster_width -- number of granule cells shown around each cell.
    fi            -- currently unused.
    xlim          -- (start, stop) of the displayed time window; ``stop``
                     also sets the length of the panel B traces.

    The annotation positions 1190, 1215, 1400 and 2000 are fixed data
    coordinates and do not follow ``xlim``.
    """
    fig = pyplot.figure(figsize=(9.5,5.7))
    raster_ax = fig.add_axes([.1,.1,.8,.27])
    schematic_ax = fig.add_axes([.1,.85,.8,.1])
    syn_ax = fig.add_axes([.1,.45,.8,.225])

    draw_cell(pair[0], schematic_ax, color='red')
    draw_cell(pair[1], schematic_ax, color='blue')
    draw_weights(pair, schematic_ax, color='black')

    # Analyze an output file in some_dir
    bulb_spikes = BulbSpikes(sim_time=sim_var['tstop'])
    bulb_spikes.read_file(os.path.join(homedir,'spikeout.spk'))
    # atleast_1d: a file with a single event loads as a 0-d array, which
    # cannot be iterated or indexed below.
    breath_events = numpy.atleast_1d(
            numpy.loadtxt(os.path.join(homedir, 'breathevents.txt')))

    wts = read_weightevents()
    delays = read_delayevents()

    dt = 1
    tstop = xlim[1]
    x = numpy.arange(0,tstop,dt)
    # Sized from x so the length is an integer under true division too.
    y0 = numpy.zeros(len(x))
    y1 = numpy.zeros(len(x))
    # Event shape: difference of two decaying exponentials.
    EXP = numpy.exp(numpy.multiply(x,-1./200.))-numpy.exp( \
            numpy.multiply(x,-1./20.))

    _add_sniff_epscs(y0, EXP, breath_events, delays[pair[0]], wts[pair[0]],
            tstop, dt)
    _add_sniff_epscs(y1, EXP, breath_events, delays[pair[1]], wts[pair[1]],
            tstop, dt)

    # plot() returns a list of lines; the legend needs the line itself.
    redplt = syn_ax.plot(x,y0, color='red')[0]
    blueplt = syn_ax.plot(x,y1, color='blue')[0]
    breathplt = None
    for breath in breath_events:
        breathplt = syn_ax.plot([breath, breath], [0,2], linestyle='--', \
                color='gray', linewidth=2)[0]
    syn_ax.set_xlim(xlim)
    syn_ax.set_ylim(0,1.6)
    syn_ax.set_yticks([])
    syn_ax.set_xticks([])
    syn_ax.set_ylabel('EPSC onto tuft')
    legend_handles = [redplt, blueplt]
    legend_labels = ['input onto red', 'input onto blue']
    if breathplt is not None:
        legend_handles.insert(0, breathplt)
        legend_labels.insert(0, 'sniff event')
    syn_ax.legend(legend_handles, legend_labels, \
            bbox_to_anchor=(0, 1.15, 1., .102), loc=1, ncol=3, mode="expand", \
            borderaxespad=0., handletextpad=.2)

    # Mark sniff interval: first pair of consecutive sniffs whose first
    # event lies after the start of the window.  Pairing the events keeps
    # the lookup of the following sniff inside the array.
    for sniff, next_sniff in zip(breath_events[:-1], breath_events[1:]):
        if sniff > xlim[0]:
            syn_ax.annotate('', xy=(sniff, .28),  xycoords='data',
                    xytext=(next_sniff, .28), \
                    textcoords='data', \
                    arrowprops=dict(arrowstyle="|-|", linewidth=2)
                    )

            syn_ax.text((sniff+next_sniff)/2., .53, \
                    'sniff every\n150 - 250 ms', \
                    horizontalalignment='center', verticalalignment='top', \
                    backgroundcolor='white')
            break

    # Mark amplitude interval
    syn_ax.annotate('', xy=(1190, 1.28),  xycoords='data',
            xytext=(1190, 1.12), \
            textcoords='data', \
            arrowprops=dict(arrowstyle="|-|", linewidth=2)
            )

    syn_ax.text(1215, 1.21, \
            '+/- 5%', \
            horizontalalignment='left', verticalalignment='center')

    # Mark delay interval at the first sniff after t = 1400
    for sniff in breath_events:
        if sniff > 1400:
            syn_ax.annotate('', xy=(sniff-2, .5),  xycoords='data',
                    xytext=(sniff+17, .5), \
                    textcoords='data', \
                    arrowprops=dict(arrowstyle="|-|", linewidth=2)
                    )

            syn_ax.text(sniff+7.5, .28, \
                    'delay 0-15 ms', \
                    horizontalalignment='center', verticalalignment='top', \
                    backgroundcolor='white')
            break

    spikes = bulb_spikes.get_mitral_spikes()
    ref=spikes[pair[0]]
    comp=spikes[pair[1]]
    gcspikes = bulb_spikes.get_granule_spikes()
    # Granule trains whose index is within cluster_width/2 of each cell's
    # position (10 + 20*cell, the same rule used in the schematic).
    half_width = int(cluster_width/2.)
    mididx = 10+pair[0]*20
    gcleft = gcspikes[mididx-half_width:mididx+half_width+1]
    mididx = 10+pair[1]*20
    gcright = gcspikes[mididx-half_width:mididx+half_width+1]

    sp = spikeplot.SpikePlot(fig=fig, savefig=False)
    sp.set_markercolor('blue')
    sp.set_markeredgewidth(2.)
    sp.set_markerscale(4)
    sp.plot_spikes([comp], label='comp', cell_offset=cluster_width*2+5, \
            draw=False )
    sp.set_markercolor('red')
    sp.plot_spikes([ref], label='ref', cell_offset=cluster_width*2+2, \
            draw=False)
    sp.set_markerscale(1.3)

    sp.set_markeredgewidth(1.5)
    sp.set_markercolor('blue')
    sp.plot_spikes(gcright, label='gcright', cell_offset=cluster_width, \
            draw=False)
    sp.set_markercolor('red')
    sp.plot_spikes(gcleft, label='gcleft', cell_offset=0, \
            draw=False)

    # Only the mask for ``comp`` is used; the other results are discarded.
    _, _, mask_b, _ = spiketrain.get_sync_traits(ref, comp, window=5)
    for k, flag in enumerate(mask_b):
        if flag == 1 and xlim[0] <= comp[k] < xlim[1]:
            raster_ax.text(comp[k],cluster_width*2+8.5, '*', \
                color='purple', fontweight='bold', \
                horizontalalignment='center', verticalalignment='center')

    raster_ax.text(2000,cluster_width*2+8.5, '(synchronized)', color='purple', \
            horizontalalignment='center', verticalalignment='center',
            fontsize=11)

    raster_ax.set_yticks([])
    ylim = (0.5, cluster_width*2+7.5)
    for breath in breath_events:
        raster_ax.plot([breath, breath], [ylim[0], ylim[1]], linestyle='--', color='gray', linewidth=2)

    sp.update_xlim(xlim)
    raster_ax.set_ylim(ylim)
    raster_ax.set_xlabel('time (ms)')
    raster_ax.set_ylabel('spike output\n  granule      mitral\n\n', horizontalalignment='center')

    # Panel letters, placed in figure coordinates just above each panel.
    pos = schematic_ax.get_position()
    schematic_ax.text(.025, pos.ymax+.02, 'A)', transform=fig.transFigure,
          verticalalignment='baseline')
    pos = syn_ax.get_position()
    syn_ax.text(.025, pos.ymax+.07, 'B)', transform=fig.transFigure,
          verticalalignment='baseline')
    pos = raster_ax.get_position()
    raster_ax.text(.025, pos.ymax+.02, 'C)', transform=fig.transFigure,
          verticalalignment='baseline')

    fig.savefig(os.path.join(analysis_path, 'fig1.pdf'))
# Build and save the figure with the default arguments.  There is no
# ``__main__`` guard, so this also runs whenever the module is imported.
raster()