{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Jupyter notebook for testing purposes. Will create formal julia analysis script later. \n",
    "\n",
    "# Load the shared helpers. NOTE(review): the packages used in later cells (CSV, DataFrame,\n",
    "# plotting, OrderedDict, ...) are presumably brought into scope by this file; it is not\n",
    "# shown here -- confirm.\n",
    "include(\"utilities.jl\");\n",
    "# Plot defaults applied to every figure created below: set `show` to false and use the\n",
    "# Computer Modern font family.\n",
    "default(show=false)\n",
    "default(fontfamily=\"Computer Modern\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dict{String, Vector{Int64}} with 1 entry:\n",
       "  \"DG\" => [0, 1]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# HYPERPARAMS\n",
    "# Number of runs per frequency band; the run index is appended to the band label to form\n",
    "# the model label passed to load_spike_files.\n",
    "n_runs = 10\n",
    "# Pattern indices whose spike/stimulus files are loaded.\n",
    "patterns = 0:1\n",
    "# NOTE(review): not referenced in the cells shown here; presumably the simulation length -- confirm units.\n",
    "duration = 750\n",
    "# Frequency-band identifiers as they appear in the data file names.\n",
    "labels = [\"theta\" , \"ftheta\", \"alpha\", \"beta\", \"gamma\"]\n",
    "# LaTeX display names for the bands. The entries are space-separated, so this is a 1x5 row\n",
    "# matrix rather than a Vector.\n",
    "freqlabels = [L\"\\theta\" L\"\\theta_{fast}\" L\"\\alpha\" L\"\\beta\" L\"\\gamma\"]\n",
    "\n",
    "# File extension appended to saved figure names.\n",
    "fig_ext = \".png\"\n",
    "\n",
    "# CREATE NECESSARY DIRECTORIES \n",
    "# NOTE(review): this helper is not defined in the notebook; presumably provided by utilities.jl.\n",
    "create_directories_if_not_exist()\n",
    "\n",
    "# IDENTIFY NEURON POPULATION RANGES\n",
    "# Each value is [lb, ub]; load_spike_files tags neuron ids satisfying lb <= id < ub\n",
    "# (upper bound exclusive), so [0,1] selects neuron 0 only.\n",
    "populations = Dict(\n",
    "    \"DG\" => [0,1]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "load_spike_files (generic function with 1 method)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\"\"\"\n",
    "    load_spike_files(patterns, model_label, neuron_ids; neurons_per_pattern=1, data_dir=\"data/gcpp/\")\n",
    "\n",
    "Read the spike and stimulus files of every pattern in `patterns` and return them stacked\n",
    "in one `DataFrame`.\n",
    "\n",
    "For each pattern `p` two tab-delimited, header-less files are read:\n",
    "`<data_dir>DGsp-<p>-<neurons_per_pattern>-<model_label>.txt` (spikes) and\n",
    "`<data_dir>StimIn-<p>-<neurons_per_pattern>-<model_label>.txt` (stimuli).\n",
    "\n",
    "# Arguments\n",
    "- `patterns`: pattern indices to load.\n",
    "- `model_label`: label embedded in the file names and stored in the `Model` column.\n",
    "- `neuron_ids`: maps a population name to `[lb, ub]`; spike rows with `lb <= Neuron < ub`\n",
    "  get that name in `Population`. Rows matching no range keep an empty string.\n",
    "- `neurons_per_pattern`: value embedded in the file names and stored in `NeuronsPerPattern`.\n",
    "- `data_dir`: directory prefix, concatenated as-is (include the trailing separator).\n",
    "\n",
    "# Returns\n",
    "A `DataFrame` with columns `Time`, `Neuron`, `Population`, `Pattern`, `NeuronsPerPattern`\n",
    "and `Model`. Stimulus rows carry the population `PP`. Files with zero rows are skipped;\n",
    "if no file contributes rows an empty `DataFrame()` is returned.\n",
    "\n",
    "Errors raised by `CSV.read` are not caught here.\n",
    "\"\"\"\n",
    "function load_spike_files(\n",
    "    patterns::Union{UnitRange{Int64}, Vector{Int64}}, \n",
    "    model_label::String,\n",
    "    neuron_ids::Dict;\n",
    "    neurons_per_pattern::Int64=1,\n",
    "    data_dir::String=\"data/gcpp/\")\n",
    "\n",
    "    # Name the two data columns and attach the provenance columns shared by both file kinds.\n",
    "    function tag!(table, population, pattern)\n",
    "        rename!(table, [\"Time\", \"Neuron\"])\n",
    "        table[:, \"Population\"] .= population\n",
    "        table[:, \"Pattern\"] .= pattern\n",
    "        table[:, \"NeuronsPerPattern\"] .= neurons_per_pattern\n",
    "        table[:, \"Model\"] .= model_label\n",
    "        return table\n",
    "    end\n",
    "\n",
    "    # Collect the per-file tables and concatenate once at the end, instead of re-copying\n",
    "    # the accumulated DataFrame on every iteration.\n",
    "    frames = DataFrame[]\n",
    "    for p ∈ patterns\n",
    "        suffix = \"-\"*string(p)*\"-\"*string(neurons_per_pattern)*\"-\"*model_label*\".txt\"\n",
    "        spikes = CSV.read(data_dir*\"DGsp\"*suffix, DataFrame; delim=\"\\t\", header=0)\n",
    "        stimin = CSV.read(data_dir*\"StimIn\"*suffix, DataFrame; delim=\"\\t\", header=0)\n",
    "\n",
    "        if size(spikes, 1) > 0\n",
    "            tag!(spikes, \"\", p)\n",
    "            # Label each spike with the population whose half-open id range contains it.\n",
    "            for (name, (lb, ub)) ∈ neuron_ids\n",
    "                spikes[lb .<= spikes[:, \"Neuron\"] .< ub, \"Population\"] .= name\n",
    "            end\n",
    "            push!(frames, spikes)\n",
    "        end\n",
    "\n",
    "        if size(stimin, 1) > 0\n",
    "            push!(frames, tag!(stimin, \"PP\", p))\n",
    "        end\n",
    "    end\n",
    "\n",
    "    return isempty(frames) ? DataFrame() : reduce(vcat, frames)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ISI figure\n",
    "# Mean inter-spike interval (ISI) of the DG population for pattern 1: one value per run and\n",
    "# frequency band, summarised as mean +/- standard deviation across runs.\n",
    "# Uses `n_runs`, `labels`, `patterns`, `populations` and `fig_ext` from the hyperparameter cell.\n",
    "\n",
    "# Per-run mean ISI, keyed by band label in the order given by `labels`.\n",
    "isi_by_band = OrderedDict(label => Float64[] for label ∈ labels)\n",
    "\n",
    "for run_ ∈ 1:n_runs\n",
    "    for label ∈ labels\n",
    "        spikes = load_spike_files(patterns, label*\"-$run_\", populations)\n",
    "        dg_spikes = spikes[(spikes.Population .== \"DG\") .& (spikes.Pattern .== 1), :]\n",
    "\n",
    "        # ISIs are the successive differences of the spike times, taken in file order.\n",
    "        isis = diff(dg_spikes[!, \"Time\"])\n",
    "        if isempty(isis)\n",
    "            # A mean ISI needs at least two spikes; skip this run instead of aborting the sweep.\n",
    "            @warn \"Fewer than two DG spikes; run excluded from the ISI statistics\" label run_\n",
    "            continue\n",
    "        end\n",
    "\n",
    "        push!(isi_by_band[label], mean(isis))\n",
    "    end\n",
    "end\n",
    "\n",
    "# Across-run summary per band (the error bars show the standard deviation).\n",
    "avgisi = [mean(isi_by_band[label]) for label ∈ labels]\n",
    "sdisi = [std(isi_by_band[label]) for label ∈ labels]\n",
    "\n",
    "# Tick labels as a Vector (comma-separated) so the five bands form a single line series.\n",
    "band_ticks = [L\"\\theta\", L\"\\theta_{f}\", L\"\\alpha\", L\"\\beta\", L\"\\gamma\"]\n",
    "\n",
    "isi_fig = plot(band_ticks, avgisi, yerror = sdisi, xlabel = \"Input Frequency Band\", yticks = [50, 75, 100],\n",
    "            xtickfont=font(16), xguidefontsize=11, ytickfontsize = 16, yguidefontsize = 14,\n",
    "            markerstrokewidth = 2, grid=:none, ylabel = \"Mean ISI (ms)\",\n",
    "            c = :black, linewidth = 4, dpi=300, size=(250,220), label=nothing)\n",
    "\n",
    "savefig(isi_fig, \"figures/pattern-separation/ISI_10runs-CCN\"*fig_ext)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 222,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "┌ Warning: thread = 1 warning: parsed expected 2 columns, but didn't reach end of line around data row: 1. Parsing extra columns and widening final columnset\n",
      "└ @ CSV /Users/selenasingh/.julia/packages/CSV/jFiCn/src/file.jl:579\n"
     ]
    }
   ],
   "source": [
    "# VOLTAGE TRACE + INPUT SPIKE TRAIN FIGURE\n",
    "# Voltage trace for gamma-band input (run 1) stacked above a raster of the PP input spikes\n",
    "# for pattern 1.\n",
    "vt_band = \"gamma\"\n",
    "vt_model = vt_band*\"-1\"\n",
    "\n",
    "# NOTE(review): the file name presumably follows the <prefix>-<pattern>-<neurons_per_pattern>-<model>\n",
    "# scheme used by load_spike_files -- confirm.\n",
    "volt_fname = \"data/gcpp/DGVt-1-1-\"*vt_model*\".txt\"\n",
    "volttrace = CSV.read(volt_fname, DataFrame; delim=\"\\t\", header=1)\n",
    "# CSV.jl warns that the data rows hold more fields than the 2-column header and widens the\n",
    "# table; only the first two columns are used.\n",
    "volttrace = volttrace[:, 1:2]\n",
    "rename!(volttrace, [\"Time\", \"Voltage\"])\n",
    "\n",
    "spikes = load_spike_files(patterns, vt_model, populations)\n",
    "stimin = spikes[(spikes.Population .== \"PP\") .& (spikes.Pattern .== 1), :]\n",
    "\n",
    "# Plot the first 2850 voltage samples against a raster x-range of 0-285 (presumably 0.1 time\n",
    "# units per sample -- TODO confirm). Clamp so a shorter recording does not throw a BoundsError.\n",
    "n_samples = min(2850, size(volttrace, 1))\n",
    "\n",
    "stim_raster = scatter(stimin[:, \"Time\"], stimin[:, \"Neuron\"],\n",
    "            marker = (:vline, 30, 0.8, :red, stroke(4, 0.8, :black, :vline)),\n",
    "            label=nothing, c=:red, msc=:black,\n",
    "            xlims = [0,285], ylims = [0,5], axis=([], false), grid = false)\n",
    "\n",
    "vtrace = plot(volttrace[1:n_samples, \"Time\"], volttrace[1:n_samples, \"Voltage\"],\n",
    "        color = :black, linewidth = 2, legend = :none, axis=([], false), grid = false,\n",
    "        size = (600,300))\n",
    "\n",
    "# Voltage trace on top (95% of the height), input raster underneath.\n",
    "fig = plot(vtrace, stim_raster, layout=grid(2, 1, heights=[0.95, 0.05]), size=(500, 400), dpi = 300)\n",
    "\n",
    "savefig(fig, \"figures/voltage-tracings/voltage-tracings-\"*vt_band*fig_ext)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.7.3",
   "language": "julia",
   "name": "julia-1.7"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.7.3"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "40d3a090f54c6569ab1632332b64b2c03c39dcf918b08424e98f38b5ae0af88f"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}