#!/usr/bin/python -W ignore

__author__ = "Zafeirios Fountas"
__credits__ = ["Murray Shanahan", "Mark Humphries",  "Rob Leech", "Jeanette Hellgren Kotaleski"]
__license__ = "GPLv3"
__version__ = "0.1"
__maintainer__ = "Zafeirios Fountas"
__email__ = "fountas@outlook.com"
__status__ = "Published"

from sys import argv, exit
from numpy import pi
from random import random
from bg_model.model import BasalGanglia
# -----------------------------------------------------------------------------

def print_usage() :
    """Print a one-line summary of the basic command-line options.

    Only the most common options are listed. Options in square brackets may
    be left out. `-print` is a bare switch: the argument parser does not read
    a value after it, so the usage text must not advertise one.
    """
    print("Usage: python experiment1.py -fr <freq> -dop <dop> -ph <phase> "+\
          "[-file <data_file_name>] [-print]")

# The script still runs when no option is given (every setting has a default),
# but the user is pointed at the help switch first.
if len(argv) < 2:
    for banner_line in (
        "---------------------------------------------- ",
        "    Warning: Running without arguments.        ",
        "    To see list of basic arguments, run:       ",
        "                        bgrun --help           ",
        "                                               ",
    ):
        print(banner_line)

# -- Defaults for every setting that the command line can override -------------

# Input oscillation
FREQ2 = 20.0              # "-fr"
FREQ1 = -1.0              # "-fr1"; -1 means "use the same value as FREQ2"
PHASE = 2.0 * random()    # "-ph"; multiplied by pi before it reaches the model
DOP = 0.3                 # "-dop"; could be random()*1.0

# Run control and output
DURATION = 1500           # ms
SEED = 1000
DATA_FILE = ""            # Empty string: results are printed instead of saved
PRINT = False
PLOTS = False
PLOT_GPe = False
PLASTICITY = True

# What to record
REC_GPe_types = False
REC_RASTERS = False
REC_BINS = False

# Experiment types / input modes
INITIAL_PERIOD = False    # Run 500ms in tonic mode first, so that the model
                          # is ready for transient recordings
RANDOM_WALK = False       # Random input that simulates behaviour
TONIC = False             # Tonic mode
ONE_CHANNEL = False       # Oscillation only in one channel

# Input amplitudes: peak ("max") and baseline ("base") per channel
T1max = 7.5
T2max = 10.0
T1base = 3.0
T2base = 0.0
T3base = 3.0

# Input envelope
RAMP = -1                 # Duration of a fade-in effect of the input;
                          # -1 deactivates it
END = -1                  # Really the duration of the input stimulus

# Network modifications
PD_OFF_STATE = False
WEIGHT_Bahuguna = 1.5     # Ratio of SD2 to SD1 connections compared to
                          # SD1 to SD2
P_STRIATUM_WEIGHT = 1.0   # Multiplier for the striatal connection
                          # probabilities
# Positive values divide the default STN<->GPe conductances; -1 keeps them.
GG_stn_gpe = -1
GG_gpe_stn = -1
GPe_density = [-1] * 3    # All -1: keep the model's own densities
GPe_type = -1

# Parse arguments
#
# Options are matched left to right. Value-taking options consume the token(s)
# that follow them; switches consume nothing. Any problem is reported together
# with the usage line and ends the script with exit status 1.

def _usage_error(message) :
    """Print `message` and the usage line, then exit with status 1."""
    print(message)
    print_usage()
    exit(1)

def _option_value(option, index) :
    """Return argv[index], the value of `option`; abort if it is missing."""
    if index >= len(argv) :
        _usage_error("Error: Missing value for " + option)
    return argv[index]

def _float_value(option, index) :
    """Return the value of `option` found at argv[index] as a float."""
    text = _option_value(option, index)
    try :
        return float(text)
    except ValueError :
        _usage_error("Error: Invalid number for " + option + ": " + text)

i = 1
while i < len(argv) :
    option = argv[i]
    if option == "--help" :
        print_usage()
        exit()
    elif option == "-seed" :
        i+=1
        seed_text = _option_value(option, i)
        if seed_text == "rand" or seed_text == "random" :
            SEED = int(random()*10000)
        else :
            try :
                SEED = int(seed_text)
            except ValueError :
                _usage_error("Error: Invalid seed " + seed_text)
    elif option == "-duration" :
        i+=1
        DURATION = _float_value(option, i)
    elif option == "-fr" :
        i+=1
        FREQ2 = _float_value(option, i)
    elif option == "-fr1" :
        i+=1
        FREQ1 = _float_value(option, i)
    elif option == "-dop" :
        i+=1
        DOP = _float_value(option, i)
    elif option == "-end" :
        i+=1
        END = _float_value(option, i)
    elif option == "-ph" :
        i+=1
        PHASE = _float_value(option, i)
    elif option == "-plots" :
        PLOTS = True
    elif option == "-plot_gpe" :
        PLOT_GPe = True
    elif option == "-file" :
        i+=1
        DATA_FILE = _option_value(option, i)
    elif option == "-print" :
        PRINT = True
    # -- INPUT TYPES ------------------
    elif option == "-initial_period" :
        INITIAL_PERIOD = True
    elif option == "-random_walk" :
        RANDOM_WALK = True
    # -- RECORDING OPTIONS ------------
    elif option == "-rec_GPe_types" :
        REC_GPe_types = True
    elif option == "-rec_rasters" :
        REC_RASTERS = True
    elif option == "-rec_bins" :
        REC_BINS = True
    # -- INPUT MODES ------------------
    elif option == "-tonic" :
        TONIC = True
    elif option == "-one_channel" :
        ONE_CHANNEL = True
    # -----
    elif option == "-ramp" :
        i+=1
        RAMP = _float_value(option, i)
    elif option == "-GG_stn_gpe" :
        i+=1
        GG_stn_gpe = _float_value(option, i)
    elif option == "-GG_gpe_stn" :
        i+=1
        GG_gpe_stn = _float_value(option, i)
    # -- AMPLITUDE OF INPUTS ----------
    elif option == "-T1base" :
        i+=1
        T1base = _float_value(option, i)
    elif option == "-T2base" :
        i+=1
        T2base = _float_value(option, i)
    elif option == "-T3base" :
        i+=1
        T3base = _float_value(option, i)
    elif option == "-T1max" :
        i+=1
        T1max = _float_value(option, i)
    elif option == "-T2max" :
        i+=1
        T2max = _float_value(option, i)
    # -----
    elif option == "-pd_off_state" :
        # This is a switch. Older invocations may still pass a value after it
        # (e.g. "-pd_off_state True"), so a following token is consumed only
        # when it is not itself an option; otherwise the next option would be
        # swallowed and its value reported as an unknown command.
        PD_OFF_STATE = True
        if i+1 < len(argv) and not argv[i+1].startswith("-") :
            i+=1
            PD_OFF_STATE = argv[i] != "False"
    elif option == "-weight_Bahuguna" :
        i+=1
        WEIGHT_Bahuguna = _float_value(option, i)
    elif option == "-P_striatum_weight" :
        i+=1
        P_STRIATUM_WEIGHT = _float_value(option, i)
    elif option == "-GPe_density" :
        # Two values are read (types A and B); type C gets the remainder.
        i+=1
        GPe_density[0] = _float_value(option, i)
        i+=1
        GPe_density[1] = _float_value(option, i)
        GPe_density[2] = 1.0 - GPe_density[0] - GPe_density[1]
    elif option == "-gpe_type" : # Allows only one GPe neuron type
        i+=1
        gpe_type_name = _option_value(option, i)
        gpe_type_ids = {"A" : 0, "B" : 1, "C" : 2}
        if gpe_type_name not in gpe_type_ids :
            _usage_error("Error: Unknown option for gpe_type (A/B/C) " +
                         gpe_type_name)
        GPe_type = gpe_type_ids[gpe_type_name]
    elif option == "-plasticity" :
        i+=1
        plasticity_text = _option_value(option, i)
        if plasticity_text == "True" :
            PLASTICITY = True
        elif plasticity_text == "False" :
            PLASTICITY = False
        else :
            _usage_error("Error: Unknown option for plasticity (True/False) " +
                         plasticity_text)
    else :
        _usage_error("Error: Unknown command " + option)
    i+=1

# -----------------------------------------------------------------------------

# Build the model, then push the command-line settings onto it.
bg = BasalGanglia(PRINT=PRINT, PLOTS=PLOTS, SEED=SEED)

# Recording switches are only ever turned on here; when a switch was not
# requested the model attribute is left untouched.
if REC_RASTERS :
    bg.RECORD_RASTERS = True
if REC_BINS :
    bg.RECORD_BINS = True
if REC_GPe_types :
    bg.RECORD_GPe_types = True

bg.GPe_type = GPe_type

# The PD 'off' state replaces the requested dopamine level with zero.
bg.pars.DOPAMINE = 0.0 if PD_OFF_STATE else DOP

# FREQ1 == -1 is the "same as FREQ2" marker.
bg.pars.iFreq_LOW_T1 = FREQ2 if FREQ1 == -1.0 else FREQ1
bg.pars.iFreq_LOW_T2 = FREQ2
bg.pars.iPhase_LOW = PHASE * pi

bg.Plasticity = PLASTICITY

# Scale the probability of recurrent connections in the striatum (MSN-MSN)
# and, with them, the FSI connections. Each probability is capped at 1.
if P_STRIATUM_WEIGHT != 1.0 :
    for prob_name in ("P_MSN_MSN_in", "P_MSN_MSN_ex",
                      "P_FSI_MSN_in", "P_FSI_MSN_ex",
                      "P_FSI_FSI_in", "P_FSI_FSI_ex") :
        scaled_prob = getattr(bg.pars, prob_name) * P_STRIATUM_WEIGHT
        setattr(bg.pars, prob_name, min(scaled_prob, 1.0))



# Start of the one-line run description; later sections append to it before
# it is printed.
# NOTE: Not every argument is shown yet - change that!
string = ("Starting experiment1 for freq " + str(FREQ2) +
          ", dop " + str(DOP) +
          ", phase " + str(PHASE))
string += ", plasticity" if PLASTICITY else ""
string += ", plots" if PLOTS else ""
# The seed is only mentioned when it differs from the default.
string += (", seed: " + str(SEED)) if SEED != 1000 else ""
string += (", filename " + DATA_FILE) if DATA_FILE != "" else ", without saving"


# -- EXPERIMENT PARAMETERS -----------------------------------------------------
from brian import Hz, ms

# -- STN<->GPe conductance scaling ---------------------------------------------
# A non-negative factor divides the model's default conductances; the default
# of -1 leaves them untouched. Zero cannot be used as a divisor, so it is
# rejected up front instead of failing (or producing infinities) mid-setup.
if GG_stn_gpe == 0 or GG_gpe_stn == 0 :
    raise ValueError("GG_stn_gpe and GG_gpe_stn must be non-zero: the " +
                     "default conductances are divided by them")

if GG_stn_gpe >= 0 :
    g0ampa = bg.pars.syn["GPe"]["G"]["STN-GPe"]["AMPA"]["value"]
    g0nmda = bg.pars.syn["GPe"]["G"]["STN-GPe"]["NMDA"]["value"]
    bg.pars.syn["GPe"]["G"]["STN-GPe"]["AMPA"]["value"] = g0ampa / float(GG_stn_gpe)
    bg.pars.syn["GPe"]["G"]["STN-GPe"]["NMDA"]["value"] = g0nmda / float(GG_stn_gpe)
if GG_gpe_stn >= 0 :
    g0 = bg.pars.syn["STN"]["G"]["GABA"]["value"]
    bg.pars.syn["STN"]["G"]["GABA"]["value"] = g0 / float(GG_gpe_stn)

# -- Input envelope ------------------------------------------------------------
# The fade-in is configured only when a ramp was requested (-1 = deactivated).
# It must not depend on the conductance options above.
if RAMP > 0 :
    bg.RAMP = RAMP
    string += ", ramp" + str(RAMP)

# Same rule for the stimulus duration, which the model takes in ms.
if END > 0 :
    bg.END = END * ms
    string += ", end" + str(END)

# -- STIMULATION ---------------------------------------------------------------
# Choose the oscillation amplitudes (channels 1 and 2) and the baselines
# (channels 1 to 3) as plain numbers, then hand them to the model in Hz.
if TONIC :
    # ZERO CHANNELS ACTIVE - tonic mode: no oscillation and a fixed baseline,
    # whatever the -T*base / -T*max options say.
    input_amplitudes = (0.0, 0.0)
    input_baselines = (3.0, 3.0, 3.0)
    string += ", with tonic input"
elif ONE_CHANNEL :
    # ONE CHANNEL ACTIVE (rest tonic): only channel 2 oscillates.
    input_amplitudes = (0.0, T2max - T2base)
    input_baselines = (T1base, T2base, T3base)
    string += ", with one active channel at "+str(T2max)+"Hz and the "
    string += "neighbouring channels at "+str(T1base)+"Hz and "+str(T3base)+"Hz"
else :
    # TWO CHANNELS ACTIVE: channels 1 and 2 compete.
    input_amplitudes = (T1max - T1base, T2max - T2base)
    input_baselines = (T1base, T2base, T3base)
    string += ", with competing input at "+str(T1max)+"Hz and "+ str(T2max)+"Hz"

bg.pars.T1_amp = input_amplitudes[0] * Hz
bg.pars.T2_amp = input_amplitudes[1] * Hz
bg.pars.base_input_T1 = input_baselines[0]*Hz
bg.pars.base_input_T2 = input_baselines[1]*Hz
bg.pars.base_input_T3 = input_baselines[2]*Hz
# ------------------------------------------------------------------------------

# Override the GPe population densities only when "-GPe_density" changed at
# least one entry from its -1 placeholder.
if any(share != -1 for share in GPe_density) :
    for neuron_type, share in zip(("GPe-typeA", "GPe-typeB", "GPe-typeC"),
                                  GPe_density) :
        bg.pars.neur[neuron_type]["density"]["value"] = share
    string += ", with GPe_density" + str(GPe_density)

# The model keeps its own weight unless a non-default one was requested.
if WEIGHT_Bahuguna != 1.5 :
    print("weight_Bahuguna changed.")
    bg.weight_Bahuguna = WEIGHT_Bahuguna





# ---------------- INITIALIZATION AND RUN ------------------------------------ #
# Tail of the run description: the duration in seconds, prefixed by the PD
# note when that state is active.
str_end = " for " + str(DURATION/1000.0) + " seconds!"

if PD_OFF_STATE :
    bg.PD_OFF_WEIGHT = 1.1
    str_end = " in PD 'off' state" + str_end

# The initial-period experiment takes precedence over the random walk when
# both were requested.
if INITIAL_PERIOD :
    experiment_name = "initial_period"
    mode_text = " with 'initial tonic period' "
elif RANDOM_WALK :
    experiment_name = "random_walk"
    mode_text = " randomly "
else :
    experiment_name = "none"
    mode_text = ""

# Announce the run before the model is initialised.
print(string + mode_text + str_end)
bg.initialize(EXPERIMENT = experiment_name)

data1 = bg.run_experiment(DURATION=DURATION)


# ---------------------------------------------------------------------------- #
#                                                                              #
#                               RECORDING DATA                                 #
#                                                                              #
# ---------------------------------------------------------------------------- #

# Annotate the returned statistics with the settings of this run.
run_stats = data1["STATS"]

# Both scaling factors are stored as soon as either one was set.
if GG_stn_gpe >= 0 or GG_gpe_stn >= 0 :
    run_stats["GG_stn_gpe"] = GG_stn_gpe
    run_stats["GG_gpe_stn"] = GG_gpe_stn

if GPe_density != [-1,-1,-1] :
    run_stats["GPe_density"] = GPe_density

#data1["STATS"]["T1_base"] = bg.pars.base_input_T1
#data1["STATS"]["T2_base"] = bg.pars.base_input_T2
#data1["STATS"]["T3_base"] = bg.pars.base_input_T3
#data1["STATS"]["T1_amp"] = bg.pars.T1_amp
#data1["STATS"]["T2_amp"] = bg.pars.T2_amp
#data1["STATS"]["base"] = [bg.pars.base_input_T1, bg.pars.base_input_T2, bg.pars.base_input_T3]
run_stats["plasticity"] = PLASTICITY
run_stats["seed"] = SEED

if END > 0 :
    run_stats["stim_dur"] = END

run_stats["P_STRIATUM_WEIGHT"] = P_STRIATUM_WEIGHT
run_stats["initial_period"] = True if INITIAL_PERIOD else False

# Without a file name the data is dumped to the console instead of saved.
if DATA_FILE == "" :
    print(" ".join(["Data not saved:", str(data1)]))
    bg.print_results()
else :
    bg.save_data(DATA_FILE)

# Plotting prints the results (again, if they were already printed above) and
# writes the figures to the 'figures' folder.
if PLOTS :
    bg.print_results()
    #bg.plot_SNr()
    bg.SAVEPLOTS = True
    bg.FOLDER = 'figures'
    bg.plot_results()

if PLOT_GPe :
    bg.plot_GPe()
    bg.plot_GPe_results()
    bg.print_GPe_results()