"""
 * Copyright (C) 2004 Evan Thomas
 * 
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or (at
 * your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
"""

"""
Test the C implementation of the Tiesinga et al. network;
try to reproduce Fig. 3.
"""

import os
import sys
from random import choice
from time import clock
from time import time as walltime
from p3 import *
from py2mat import Matwrap

from tiesinga_C import Pyramidal, epspSynapse
#from tiesinga_pyramid import Pyramidal, epspSynapse
from wb_C import Interneuron, ipspSynapse

# Population sizes: Ne excitatory (pyramidal) cells and one quarter as
# many inhibitory interneurons.
Ne = 500
Ni = Ne // 4

# Synaptic conductances, assigned to s.dyn.gSyn of every excitatory
# (gAMPA) and inhibitory (gGABA) synapse created below.
gAMPA = 0.003
gGABA = 0.001

# Pyramidal-cell channel conductances (Kahp, Ca and KM channels).
gKahp = 0.8
gCa   = 10
gKM   = 0.2

# Base injected current; each pyramidal cell receives Iapp + 0.75.
Iapp  = 0.27

# Threshold compared against rand_flat() for every excitatory source/target
# pair when deciding whether to create a synapse.
Pee   = 0.1

# Build the neuron populations.  Every pyramidal cell gets the shared
# channel conductances and injected current defined above plus a random
# initial Kahp.q; every interneuron gets a random Em of -75 + 20*rand_flat().
def _new_pyramidal():
    """Create one Pyramidal cell configured from the module-level settings."""
    pyr = Pyramidal()
    pyr.Kahp.gKahp = gKahp
    pyr.Ca.gCa = gCa
    pyr.KM.gKM = gKM
    pyr.current.Iinject = Iapp + 0.75
    pyr.Kahp.q = 0.1*rand_flat()
    return pyr


def _new_interneuron():
    """Create one Interneuron with a randomised resting parameter Em."""
    inter = Interneuron()
    inter.Em = -75 + 20*rand_flat()
    return inter


# Pyramidal cells are built first, then interneurons, so rand_flat() is
# consumed in the same order as a pair of plain loops would consume it.
Enet = [_new_pyramidal() for _ in range(Ne)]
Inet = [_new_interneuron() for _ in range(Ni)]

# Make synapses.
# Each excitatory cell contacts every other cell (excitatory or inhibitory)
# whenever Pee >= rand_flat() -- presumably probability Pee, assuming
# rand_flat() is uniform on [0, 1); confirm in p3.  Each inhibitory cell
# contacts every other cell unconditionally.  Self-connections are skipped.
all_cells = Enet + Inet  # loop-invariant target list, built once
for src in Enet:
    for tgt in all_cells:
        if src is tgt:  # identity, not equality: skip only the cell itself
            continue
        if Pee >= rand_flat():
            s = epspSynapse(src, tgt)
            s.dyn.gSyn = gAMPA
for src in Inet:
    for tgt in all_cells:
        if src is tgt:
            continue
        s = ipspSynapse(src, tgt)
        s.dyn.gSyn = gGABA

network = Enet + Inet

# Set emtrace on compartment 0 of ten randomly drawn cells (presumably this
# enables membrane tracing consumed by the trace handler).  The draws are
# made with replacement, so fewer than ten distinct cells may be traced.
traced = [choice(network) for i in range(10)]
for cell in traced:
    cell.compartments[0].emtrace = True

###########################################
# Run options for the GD simulation object #
###########################################
gd = GD()
gd.duration = 500
gd.tolerance = 1e-4
gd.network = network

#####################
# Initialise matlab #
#####################
m = Matwrap()
m.closeOnDel = False
m.write('cd \'' + os.getcwd() + '\'')
print m.read()

#############################################
# Special AP capture so that only principal #
# neurons are captured                      #
#############################################
# State used by ap_print: the AP output file handle (None until the first
# call opens it) and the base name of that file.
apx, apfilename = None, 'ap'
def ap_print(cell):
    """Append the action-potential times of a pyramidal cell to the AP file.

    Used as ``gd.ap_handler``.  The output file is opened lazily on the
    first call and kept in the module global ``apx``: ``<apfilename>.dat``
    when ``communicator.size == 1``, otherwise
    ``<apfilename>_<size>_<rank>.dat`` so each rank gets its own file.

    Each AP time of compartment 0 is written as one ``<cell id> <time>``
    line.  Cells that are not ``Pyramidal`` (the interneurons in this
    script) contribute nothing.  The file is flushed on every call.

    :param cell: the cell whose ``compartments[0].APtimes`` are reported.
    """
    # Only apx is rebound here; apfilename is merely read.
    global apx
    if apx is None:
        if communicator.size == 1:
            fn = apfilename + '.dat'
        else:
            fn = '%s_%d_%d.dat' % \
                 (apfilename, communicator.size, communicator.rank)
        apx = open(fn, 'w')

    if isinstance(cell, Pyramidal):
        cmpt = cell.compartments[0]
        apx.writelines('%d %g\n' % (cell.id, tm) for tm in cmpt.APtimes)
    apx.flush()
gd.ap_handler = ap_print

##############################################
# Trace capture.  Only compartment 0 of each #
# cell is written, so neurons with differing #
# numbers of compartments are handled alike. #
##############################################
# Trace output file handle (None until first use) and the trace file name.
tracex, trfilename = None, 'trace.dat'
def trace_print(cell):
    """Append the traced samples of a cell's first compartment to the trace file.

    Used as ``gd.trace_handler``.  On the first call the file named by the
    module global ``trfilename`` is opened, kept in ``tracex``, and an info
    message is emitted.  Each sample is written as
    ``<cell id> <compartment index> <time> <value>``; only compartment 0 is
    written, so the index column is always 0.  The file is flushed on every
    call.

    :param cell: the cell whose ``compartments[0].traceTimes`` /
        ``traceData`` samples are reported (paired element by element).
    """
    global tracex
    if tracex is None:
        # Honour the configured name instead of a duplicated literal.
        tracex = open(trfilename, 'w')
        message_print(info, 'Open of trace file %s successful.\n' % trfilename)

    cmpt = cell.compartments[0]
    for t, value in zip(cmpt.traceTimes, cmpt.traceData):
        tracex.write('%d %d %g %g\n' % (cell.id, 0, t, value))

    tracex.flush()

gd.trace_handler = trace_print

#######
# Run #
#######
# Wall-clock timing: time.clock() on Unix reports CPU time, which is not
# the "elapsed time" the message below claims to report.
start = walltime()
parplex(gd)
set_message_option('info')
# Fixed a garbled call (was `messJ<;>t`, a syntax error); message_print is
# the reporting function used elsewhere in this script.
message_print(info, 'made it - elapsed time %ds' % int(walltime()-start))

#print 'cell.Na.h =', cell.Na.h
#print 'cell.Kdr.n =', cell.Kdr.n
#print 'cell.Ca.s =', cell.Ca.s
#print 'cell.Kahp.q =', cell.Kahp.q
#print 'cell.Kahp.Ca =', cell.Kahp.Ca
#print 'cell.KC.c =', cell.KC.c
#print 'cell.KO.a =', cell.KO.a
#print 'cell.NaO.b =', cell.NaO.b
#print 'cell.KM.r =', cell.KM.r

m.write('figure(1);raster(\'ap.dat\')\n')
m.write('figure(2);traceplot(\'trace.dat\')\n')
print m.read()
