####################################################################
# SCRIPT TO RUN ANALYSIS 
####################################################################
include("utilities.jl");
default(show=false)


# HYPERPARAMS
n_runs = 14
patterns = 0:24
spontaneousactivity = true


labels = ["HC_CTRL", "LR_CTRL", "NR_CTRL", "HC_LITM", "LR_LITM", "NR_LITM"]
psc = Dict("HC_CTRL"=>[], "LR_CTRL"=>[], "NR_CTRL"=>[], "HC_LITM"=>[], "LR_LITM"=>[], "NR_LITM"=>[])


if spontaneousactivity
    fig_ext = "-wSA.png"
    csv_ext = "-wSA.csv"
else
    fig_ext = "-nSA.png"
    csv_ext = "-nSA.csv"
end

# CREATE NECESSARY DIRECTORIES 
create_directories_if_not_exist()

# IDENTIFY NEURON POPULATION RANGES
populations = Dict(
    "DG" => [0, 500],
    "BC" => [500,506],
    "MC" => [506, 521],
    "HIPP" => [521, 527]
)

println("finished set up. beginning raster plot.")

for run ∈ 1:n_runs
    for i ∈ 1:length(labels)
        spikes = load_spike_files(patterns, labels[i]*"-$run", populations)
        println(i)
        # CREATE RASTER PLOTS
        for p ∈ unique(spikes.Pattern)
            stimin = spikes[(spikes.Population .== "PP") .& (spikes.Pattern .== p), :]
            plots = []
            append!(plots, [raster_plot(stimin; ylab="PP")])
            println(p)
            for pop ∈ keys(populations)
                lb, ub = populations[pop]
                popspikes = spikes[(spikes.Population .== pop) .& (spikes.Pattern .== p),:]
                #if size(popspikes,1) > 0
                append!(plots, [raster_plot(popspikes; xlab="", ylab=pop)])
                #end
                println(pop)
            end
            fig = plot(reverse(plots)..., layout=grid(5, 1, heights=[0.15, 0.15, 0.15, 0.4, 0.15]), size=(400, 500))
            savefig(fig, "figures/raster-plots/raster-"*string(p)*"-"*labels[i]*"-$run"*"-blSA"*".png")
        end
    end 
end 

println("Finished printing raster plots.")

println("Starting PS analysis.")
# PATTERN SEPARATION CURVES
colors=[:blue, :red, :green, :black, :gray, :purple]
global psfig = plot([0;1], [0;1], ls=:dash, c=:black, 
                        xlabel="Input Correlation "*L"(r_{in})", 
                        ylabel="Output Correlation "*L"(r_{out})", 
                        size=(400, 400),
                        label=nothing, legend=:outerbottom)

#psc = Dict("HC_CTRL"=>[], "LR_CTRL"=>[], "NR_CTRL"=>[]) #, "HC_LITM"=>[], "LR_CTRL"=>[], "LR_LITM"=>[], "NR_CTRL"=>[], "NR_LITM"=>[])

for i ∈ 1:length(labels)
    println(labels[i])
    for run ∈ 1:n_runs
        spikes = load_spike_files(patterns, labels[i]*"-$run", populations)
        
        out = pattern_separation_curve(spikes, 100, 500)
        x, y = out[:,"Input Correlation"], out[:, "Output Correlation"]
        
        # Remove NaNs before fitting
        idx_ = (.!isnan.(x) .& .!isnan.(y))
        x = x[idx_]
        y = y[idx_]

        f = fit_power_law(x, y)
        append!(psc[labels[i]], f(0.6))

        if (run == n_runs) 
            psm = round(mean(psc[labels[i]]), digits=2)
            psse = std(psc[labels[i]])/sqrt(n_runs)
            pslci = round(psm - 1.96*psse, digits=2)
            psuci = round(psm + 1.96*psse, digits=2)
            psc_label = labels[i]*" (PS="*string(psm)*" ["*string(pslci)*", "*string(psuci)*"])"
        else
            psc_label = nothing
        end
        global psfig = scatter!(x, y, c=colors[i], alpha=1/(2*n_runs), label=nothing)
        global psfig = plot!(0:0.01:1, x -> f(x), c=colors[i], label=psc_label)
        println(run)
    end 
end 
psfig
savefig(psfig, "figures/pattern-separation/pattern-separation-curve"*fig_ext)

println("Finished plotting PS curves.")

println("Beginning AUC calculation. This will take .5 hours.")

auc_save = OrderedDict("HC_CTRL"=>[], "LR_CTRL"=>[], "NR_CTRL"=>[], "HC_LITM"=>[], "LR_LITM"=>[], "NR_LITM"=>[])
auc_means = OrderedDict("HC_CTRL"=>[], "LR_CTRL"=>[], "NR_CTRL"=>[], "HC_LITM"=>[], "LR_LITM"=>[], "NR_LITM"=>[])
auc_ses = OrderedDict("HC_CTRL"=>[], "LR_CTRL"=>[], "NR_CTRL"=>[], "HC_LITM"=>[], "LR_LITM"=>[], "NR_LITM"=>[])

for i ∈ 1:length(labels)
    println(i)
    for run ∈ 1:n_runs
        println(run)
        spikes = load_spike_files(patterns, labels[i]*"-$run", populations)
        
        out = pattern_separation_curve(spikes, 100, 500)
        x, y = out[:,"Input Correlation"], out[:, "Output Correlation"]
        
        # Remove NaNs before fitting
        idx_ = (.!isnan.(x) .& .!isnan.(y))
        x = x[idx_]
        y = y[idx_]

        auc = compute_auc(x, y)
        append!(auc_save[labels[i]], auc)
        if (run == n_runs) 
            aucm = round(mean(auc_save[labels[i]]), digits=2)
            append!(auc_means[labels[i]], aucm)
            aucse = std(auc_save[labels[i]])/sqrt(n_runs)
            append!(auc_ses[labels[i]], aucse)
        end
    end 
end 

df_aucsave = DataFrame(auc_save)
df_aucmeans = DataFrame(auc_means)
df_aucses = DataFrame(auc_ses)
CSV.write("figures/pattern-separation/auc_raw"*csv_ext, df_aucsave)
CSV.write("figures/pattern-separation/auc_means"*csv_ext, df_aucmeans)
CSV.write("figures/pattern-separation/auc_ses"*csv_ext, df_aucses)

println("CSVs saved.")

unpack(a) = eltype(a[1])[el[1] for el in a]

data = unpack(collect(values(auc_means)))
data_err = unpack(collect(values(auc_ses)))

pltlabels = ["HC", "LR", "NR"]
auc_fig = groupedbar(pltlabels, 
                [data[1:3] data[4:6]], 
                xlabel = "Group",
                #xtickfont=font(12),
                ylabel = L"AUC_{PS}",
                ylimits = (-0.15, 0.4),
                c = [:gray :white], 
                markerstrokewidth = 1,
                yerror = [data_err[1:3] data_err[4:6]], 
                dpi=300, size=(350,350),
                label=["Baseline" "Lithium"],
                grid = :none
                )
savefig(auc_fig, "figures/pattern-separation/auc-curve"*fig_ext)

println("Finished plotting AUC bars.")