import pandas as pd
import os.path as op
import os
import seaborn as sns
import matplotlib.pyplot as plt
from datetime import datetime
from definitions import RESULTS_FOLDER, FIGURE_FOLDER, ROOT_FOLDER
# Default seaborn colour palette; the point-plot functions index into it.
col_pal = sns.color_palette()

# Locations of the input data, the model output and the saved figures.
# The root folders are imported from the project's `definitions` module.
packard_folder = op.join(ROOT_FOLDER, 'data')
results_folder = op.join(RESULTS_FOLDER, 'plusmaze')
figure_location = op.join(FIGURE_FOLDER, 'packard')

# exist_ok=True replaces the check-then-create pattern, which can fail if the
# folder appears between the existence test and the makedirs call.
os.makedirs(figure_location, exist_ok=True)
def load_original_data():
    """Load the rat plus-maze data and return the 'Place' rows with percentages.

    Reads ``plusmaze_ratdata.csv`` from ``packard_folder``, counts the rows in
    every (injection site, test day, treatment, behaviour) cell and expresses
    the 'Run' count of each cell as a percentage of the summed counts of its
    (injection site, test day, treatment) group.

    Returns:
        A DataFrame holding only the rows whose 'Behaviour' is 'Place', with
        the grouping keys as regular columns plus 'total' and 'Percentage'.
    """
    condition_keys = ['Injection site', 'Test day', 'Treatment']
    raw = pd.read_csv(op.join(packard_folder, 'plusmaze_ratdata.csv'))

    # aggfunc=len turns every remaining column into a per-cell row count.
    counts = raw.pivot_table(
        index=condition_keys + ['Behaviour'],
        aggfunc=len,
        margins=True,
    )
    counts['total'] = counts.groupby(condition_keys).transform('sum')
    counts['Percentage'] = counts['Run'] / counts['total'] * 100

    # Discard the margins row, order by the third index level ('Treatment')
    # in descending order, then flatten the index back into columns.
    counts = (
        counts.drop('All')
        .sort_index(ascending=False, level=2)
        .reset_index()
    )

    is_place = counts['Behaviour'] == 'Place'
    return counts[is_place]
def load_model_data():
    """Load the newest model summaries and return the 'place' rows with percentages.

    For each of the three groups ('control', 'inactivate_HPC',
    'inactivate_DLS') the most recent run folder under ``results_folder`` is
    located with ``get_most_recent_model_run`` and its ``summary.csv`` is read.
    The summaries are concatenated, counted per (trial, group, score) and the
    'agent' count is expressed as a percentage of the per-(trial, group) total.

    Returns:
        A DataFrame holding only the rows whose 'score' is 'place', with
        'trial', 'group' and 'score' as regular columns plus 'total' and
        'Percentage'.
    """
    groups = ['control', 'inactivate_HPC', 'inactivate_DLS']

    # Resolve every group's newest run folder before reading any CSV.
    latest_run = {
        name: get_most_recent_model_run(op.join(results_folder, name))
        for name in groups
    }

    # Concatenation order follows `groups`: control, HPC, DLS.
    summaries = [
        pd.read_csv(op.join(results_folder, name, latest_run[name], 'summary.csv'))
        for name in groups
    ]
    combined = pd.concat(summaries)

    # aggfunc=len turns every remaining column into a per-cell row count.
    table = combined.pivot_table(
        index=['trial', 'group', 'score'],
        aggfunc=len,
        margins=True,
    )
    table['total'] = table['agent'].groupby(['trial', 'group']).sum()
    table['Percentage'] = table['agent'] / table['total'] * 100

    # Discard the margins row and flatten the index back into columns.
    table = table.drop('All').reset_index()

    is_place = table['score'] == 'place'
    return table[is_place]
def get_most_recent_model_run(directory):
    """Return the name of the most recent timestamped entry in *directory*.

    Run folders are expected to be named with the timestamp format
    ``'%Y-%m-%d %H:%M:%S.%f'``. Entries whose names do not parse with that
    format are ignored, so runs from any year are considered and stray files
    do not break the lookup.

    Args:
        directory: Path of the folder whose entries are timestamped runs.

    Returns:
        The entry name (exactly as it appears on disk) with the latest
        timestamp.

    Raises:
        IndexError: If *directory* contains no entry named with a timestamp.
    """
    timestamp_format = '%Y-%m-%d %H:%M:%S.%f'

    runs = []
    for name in os.listdir(directory):
        try:
            stamp = datetime.strptime(name, timestamp_format)
        except ValueError:
            # Not a run folder; skip it.
            continue
        # Keep the on-disk name: str(datetime) omits the fractional part when
        # the microseconds are zero and so may not match the folder name.
        runs.append((stamp, name))

    if not runs:
        raise IndexError(
            'no timestamped model run found in {!r}'.format(directory)
        )
    return max(runs)[1]
def performance_barchart(data=None):
    """Draw and save a bar chart of place-strategy percentage per injection site.

    One panel is drawn per 'Test day', with bars grouped by 'Injection site'
    and coloured by 'Treatment'. The figure is written to
    ``pm_barchart.pdf`` inside ``figure_location``.

    Args:
        data: DataFrame with 'Injection site', 'Percentage', 'Treatment' and
            'Test day' columns. When omitted, the module-level ``df`` is used,
            which keeps the original zero-argument call working.

    Returns:
        The grid object returned by ``sns.catplot``.
    """
    if data is None:
        # Legacy behaviour: read the global assigned by the script section.
        data = df

    grid = sns.catplot(
        x="Injection site",
        y="Percentage",
        hue="Treatment",
        data=data,
        kind="bar",
        col="Test day",
    )
    grid.set_ylabels('% place strategy')

    # catplot creates its own figure, which is the current one at this point.
    plt.savefig(op.join(figure_location, 'pm_barchart.pdf'), format='pdf')
    return grid
def performance_pointplot(ax, data=None):
    """Draw and save point plots of place-strategy percentage on *ax*.

    Two point plots are overlaid on the same axes, one for rows whose
    'Injection site' is 'Hippocampus' and one for 'Striatum', each split by
    'Treatment' across 'Test day'. The figure owning *ax* is written to
    ``pm_originaldata_pointplot.pdf`` inside ``figure_location``.

    Args:
        ax: Matplotlib axes to draw on.
        data: DataFrame with 'Injection site', 'Test day', 'Percentage' and
            'Treatment' columns. When omitted, the module-level ``df`` is
            used, which keeps the original single-argument call working.
    """
    if data is None:
        # Legacy behaviour: read the global assigned by the script section.
        data = df

    pastel = sns.color_palette('pastel')

    sns.pointplot(
        ax=ax,
        data=data[data['Injection site'] == 'Hippocampus'],
        x='Test day',
        y='Percentage',
        hue='Treatment',
        palette=[col_pal[0], col_pal[3]],
        markers=["o", "o"],
    )
    sns.pointplot(
        ax=ax,
        data=data[data['Injection site'] == 'Striatum'],
        x='Test day',
        y='Percentage',
        hue='Treatment',
        palette=[pastel[0], col_pal[2]],
        markers=["o", "o"],
    )

    # Replace the automatic legend entries with descriptive condition names.
    handles = ax.get_legend_handles_labels()[0]
    ax.legend(
        handles,
        [
            'Control - saline HPC',
            'Inactivate HPC - lidocaine',
            'Control - saline DLS',
            'Inactivate DLS - lidocaine',
        ],
        title='Treatment',
    )

    # Act on the supplied axes and its figure rather than on pyplot's
    # "current" ones, so the function is correct even when *ax* is not current.
    ax.set_ylabel('% place strategy')
    ax.set_ylim([0, 100])
    ax.figure.savefig(
        op.join(figure_location, 'pm_originaldata_pointplot.pdf'), format='pdf'
    )
def performance_pointplot_model(ax, df):
    """Draw and save a point plot of the model's place-strategy percentage.

    Plots 'Percentage' against 'trial' with one line per 'group' on *ax*, and
    writes the figure owning *ax* to ``pm_model_pointplot.pdf`` inside
    ``figure_location``.

    Args:
        ax: Matplotlib axes to draw on.
        df: DataFrame with 'trial', 'Percentage' and 'group' columns.
    """
    sns.pointplot(
        ax=ax,
        data=df,
        x='trial',
        y='Percentage',
        hue='group',
        palette=[col_pal[0], col_pal[2], col_pal[3]],
    )

    # Replace the automatic legend entries with descriptive model names.
    handles = ax.get_legend_handles_labels()[0]
    ax.legend(
        handles,
        [
            'Control - full model',
            'Inactivate DLS - only SR',
            'Inactivate HPC - only MF',
        ],
        title='Model',
    )

    # Act on the supplied axes and its figure rather than on pyplot's
    # "current" ones, so the function is correct even when *ax* is not current.
    ax.set_ylabel('% place strategy')
    ax.set_ylim([0, 100])
    ax.set_xlabel('Trial')
    ax.set_xticklabels(['Early', 'Late'])
    ax.figure.savefig(
        op.join(figure_location, 'pm_model_pointplot.pdf'), format='pdf'
    )
if __name__ == '__main__':
    # Plot model results.
    df = load_model_data()
    fig = plt.figure(figsize=(5, 4.5))
    ax = fig.add_subplot()
    performance_pointplot_model(ax, df)
    plt.show()

    # Plot original data. `df` has to stay a module-level name here because
    # performance_barchart and performance_pointplot read it as a global when
    # called without a data argument.
    df = load_original_data()

    # No plt.figure() before the bar chart: catplot is a figure-level function
    # that creates its own figure, so a pre-made one would only be displayed
    # as an extra, empty window.
    ax = performance_barchart()
    plt.show()

    fig2 = plt.figure(figsize=(5, 4.5))
    ax2 = fig2.add_subplot()
    performance_pointplot(ax2)
    plt.show()