from helpers import original_data_path, chu2014_path

# Area labels; rules further down expand per-area file patterns over this
# list (e.g. rate time series, cross-correlations, Granger causality files).
area_list = [
    'V1', 'V2', 'VP', 'V3', 'V3A', 'MT', 'V4t', 'V4',
    'VOT', 'MSTd', 'PIP', 'PO', 'DP', 'MIP', 'MDP', 'VIP',
    'LIP', 'PITv', 'PITd', 'MSTl', 'CITv', 'CITd', 'FEF', 'TF',
    'AITv', 'FST', '7a', 'STPp', 'STPa', '46', 'AITd', 'TH',
]

# Population labels: every prefix 23, 4, 5, 6 combined with the suffixes
# E and I, in that order ('23E', '23I', '4E', ...).
population_list = [prefix + suffix
                   for prefix in ('23', '4', '5', '6')
                   for suffix in ('E', 'I')]

# True: use the original data set (original_data_path together with
# ORIGINAL_SIM_LABELS). False: use data_path from config and the labels
# returned by create_label_dict().
LOAD_ORIGINAL_DATA = True

# Simulation labels of the original data set, keyed by figure.
#
# Each label is a 40-character hexadecimal string that names the
# simulation's sub-directory below the data directory.
#   'all'  : every label referenced by any figure
#   'Fig1' : None (no simulation label is assigned to this figure)
#   'Fig3' : a single label given as a plain string (not a list)
#   others : list of labels; several rules rely on the position of an
#            entry (e.g. SIM_LABELS['Fig5'][0], SIM_LABELS['Fig8'][-1]),
#            so the order of the lists must not be changed.
ORIGINAL_SIM_LABELS = {'all': ['533d73357fbe99f6178029e6054b571b485f40f6',
                               '0adda4a542c3d5d43aebf7c30d876b6c5fd1d63e',
                               '33fb5955558ba8bb15a3fdce49dfd914682ef3ea',
                               '3afaec94d650c637ef8419611c3f80b3cb3ff539',
                               # The comma after the next entry is required:
                               # without it Python joins the two adjacent
                               # string literals into one bogus label.
                               '99c0024eacc275d13f719afd59357f7d12f02b77',
                               '783cedb0ff27240133e3daa63f5d0b8d3c2e6b79',
                               '380856f3b32f49c124345c08f5991090860bf9a3',
                               '5a7c6c2d6d48a8b687b8c6853fb4d98048681045',
                               'c1876856b1b2cf1346430cf14e8d6b0509914ca1',
                               'a30f6fba65bad6d9062e8cc51f5483baf84a46b7',
                               '1474e1884422b5b2096d3b7a20fd4bdf388af7e0',
                               'f18158895a5d682db5002489d12d27d7a974146f',
                               '08a3a1a88c19193b0af9d9d8f7a52344d1b17498',
                               '5bdd72887b191ec22a5abcc04ca4a488ea216e32'],
                       'Fig1': None,
                       'Fig2': ['533d73357fbe99f6178029e6054b571b485f40f6',
                                '0adda4a542c3d5d43aebf7c30d876b6c5fd1d63e',
                                '33fb5955558ba8bb15a3fdce49dfd914682ef3ea'],
                       'Fig3': '33fb5955558ba8bb15a3fdce49dfd914682ef3ea',
                       'Fig4': ['33fb5955558ba8bb15a3fdce49dfd914682ef3ea',
                                '1474e1884422b5b2096d3b7a20fd4bdf388af7e0',
                                '99c0024eacc275d13f719afd59357f7d12f02b77',
                                'f18158895a5d682db5002489d12d27d7a974146f',
                                '08a3a1a88c19193b0af9d9d8f7a52344d1b17498',
                                '5bdd72887b191ec22a5abcc04ca4a488ea216e32'],
                       'Fig5': ['3afaec94d650c637ef8419611c3f80b3cb3ff539',
                                '99c0024eacc275d13f719afd59357f7d12f02b77'],
                       'Fig6': ['33fb5955558ba8bb15a3fdce49dfd914682ef3ea',
                                '5bdd72887b191ec22a5abcc04ca4a488ea216e32',
                                '99c0024eacc275d13f719afd59357f7d12f02b77',
                                '3afaec94d650c637ef8419611c3f80b3cb3ff539'],
                       'Fig7': ['99c0024eacc275d13f719afd59357f7d12f02b77'],
                       'Fig8': ['33fb5955558ba8bb15a3fdce49dfd914682ef3ea',
                                '783cedb0ff27240133e3daa63f5d0b8d3c2e6b79',
                                '380856f3b32f49c124345c08f5991090860bf9a3',
                                '5a7c6c2d6d48a8b687b8c6853fb4d98048681045',
                                'c1876856b1b2cf1346430cf14e8d6b0509914ca1',
                                'a30f6fba65bad6d9062e8cc51f5483baf84a46b7',
                                '1474e1884422b5b2096d3b7a20fd4bdf388af7e0',
                                'f18158895a5d682db5002489d12d27d7a974146f',
                                '08a3a1a88c19193b0af9d9d8f7a52344d1b17498',
                                '5bdd72887b191ec22a5abcc04ca4a488ea216e32',
                                '99c0024eacc275d13f719afd59357f7d12f02b77'],
                       'Fig9': ['99c0024eacc275d13f719afd59357f7d12f02b77']}


# Choose where the figure rules read their data from and which simulation
# labels they use.
if not LOAD_ORIGINAL_DATA:
    # Own simulations: these modules are imported only on this branch, so
    # they are not needed when working with the original data set.
    from network_simulations import create_label_dict
    from config import data_path
    DATA_DIR = data_path
    SIM_LABELS = create_label_dict()
else:
    # Original data set: fixed location and the hard-coded label table.
    DATA_DIR = original_data_path
    SIM_LABELS = ORIGINAL_SIM_LABELS

# Every rule joins DATA_DIR into its input paths, so stop right away with a
# clear message if no directory has been configured.
if DATA_DIR is None:
    raise TypeError(
        "The path to the data files is None. "
        "Please define the data path."
    )

# Default target rule: it has no output or command of its own and only
# requests the final EPS file of each of the nine figures, which pulls in
# the figure rules below.
# NOTE(review): no rule visible in this part of the file produces
# 'Fig1_model_overview.eps'; presumably it is provided by the included
# preprocessing Snakefile or already exists on disk -- TODO confirm.
rule all:
    input:
        'Fig1_model_overview.eps',
        'Fig2_bistability.eps',
        'Fig3_ground_state_chi1.eps',
        'Fig4_metastability.eps',
        'Fig5_ground_state.eps',
        'Fig6_comparison_exp_spiking_data.eps',
        'Fig7_temporal_hierarchy.eps',
        'Fig8_interactions.eps',
        'Fig9_laminar_interactions.eps'

include: './Snakefile_preprocessing'

# Creates Fig2_bistability.eps by running Fig2_bistability.py.
# Requires Analysis/pop_rates.json of every simulation listed in
# SIM_LABELS['Fig2'].
rule Fig2_bistability:
    input:
        expand(os.path.join(DATA_DIR, '{simulation}', 'Analysis/pop_rates.json'), simulation=SIM_LABELS['Fig2'])
    output:
        'Fig2_bistability.eps'
    shell:
        'python3 Fig2_bistability.py'
    
# Creates Fig3_ground_state_chi1.eps by running Fig3_ground_state_chi1.py.
# All inputs belong to the single simulation SIM_LABELS['Fig3'] (used here
# as a plain string):
#   - Analysis/pop_rates.json, pop_LvR.json and corrcoeff.json
#   - recordings/<label>-spikes-<area>-<pop>.npy for the areas V1, V2 and
#     FEF and every entry of population_list
#   - the 'rate_time_series_full' and 'rate_time_series_auto_kernel' files
#     of the areas V1, V2 and FEF
rule Fig3_ground_state_chi1:
    input:
        os.path.join(DATA_DIR, SIM_LABELS['Fig3'], 'Analysis/pop_rates.json'),
        os.path.join(DATA_DIR, SIM_LABELS['Fig3'], 'Analysis/pop_LvR.json'),
        os.path.join(DATA_DIR, SIM_LABELS['Fig3'], 'Analysis/corrcoeff.json'),
        expand(os.path.join(DATA_DIR, SIM_LABELS['Fig3'], 'recordings', '-'.join((SIM_LABELS['Fig3'], 'spikes-{area}-{pop}.npy'))),
               area=['V1', 'V2', 'FEF'], pop=population_list),
        expand(os.path.join(DATA_DIR, SIM_LABELS['Fig3'], 'Analysis', 'rate_time_series_full', 'rate_time_series_full_{area}.npy'),
               area=['V1', 'V2', 'FEF']),
        expand(os.path.join(DATA_DIR, SIM_LABELS['Fig3'], 'Analysis', 'rate_time_series_auto_kernel', 'rate_time_series_auto_kernel_{area}.npy'), area=['V1', 'V2', 'FEF']),
    output:
        'Fig3_ground_state_chi1.eps'
    shell:
        'python3 Fig3_ground_state_chi1.py'

# Produces one theory result file per value of the wildcard
# 'cc_weights_factor' by running Fig4_theory.py with that value as its only
# command-line argument. The rule declares no input files.
rule Fig4_theory_calculations:
    output:
        'Fig4_theory_data/results_{cc_weights_factor}.npy'
    shell:
        'python3 Fig4_theory.py {wildcards.cc_weights_factor}'        

# Creates Fig4_metastability.eps by running Fig4_metastability.py.
# Requires the full V1 rate time series of every simulation in
# SIM_LABELS['Fig4'] and the theory results (rule Fig4_theory_calculations)
# for six values of cc_weights_factor.
# NOTE(review): the factors are floats that expand() turns into text, so
# 2. is presumably requested as 'results_2.0.npy' -- confirm that
# Fig4_theory.py writes its output under exactly these names.
rule Fig4_metastability:
    input:
        expand(os.path.join(DATA_DIR, '{simulation}', 'Analysis', 'rate_time_series_full', 'rate_time_series_full_V1.npy'),
               simulation=SIM_LABELS['Fig4']),
        expand('Fig4_theory_data/results_{cc_weights_factor}.npy',
               cc_weights_factor=[1.0, 1.8, 1.9, 2., 2.1, 2.5])
    output:
        'Fig4_metastability.eps'
    shell:
        'python3 Fig4_metastability.py'
    
# Creates Fig5_ground_state.eps by running Fig5_ground_state.py.
# Two simulations are involved:
#   - SIM_LABELS['Fig5'][1] supplies the analysis results (pop_rates.json,
#     pop_LvR.json, corrcoeff.json and both kinds of rate time series for
#     the areas V1, V2 and FEF)
#   - SIM_LABELS['Fig5'][0] supplies the spike recordings for the areas V1,
#     V2 and FEF and every entry of population_list
rule Fig5_ground_state:
    input:
        os.path.join(DATA_DIR, SIM_LABELS['Fig5'][1], 'Analysis', 'pop_rates.json'),
        os.path.join(DATA_DIR, SIM_LABELS['Fig5'][1], 'Analysis', 'pop_LvR.json'),
        os.path.join(DATA_DIR, SIM_LABELS['Fig5'][1], 'Analysis', 'corrcoeff.json'),
        expand(os.path.join(DATA_DIR, SIM_LABELS['Fig5'][0], 'recordings', '-'.join((SIM_LABELS['Fig5'][0], 'spikes-{area}-{pop}.npy'))),
               area=['V1', 'V2', 'FEF'], pop=population_list),
        expand(os.path.join(DATA_DIR, SIM_LABELS['Fig5'][1], 'Analysis', 'rate_time_series_full', 'rate_time_series_full_{area}.npy'),
               area=['V1', 'V2', 'FEF']),
        expand(os.path.join(DATA_DIR, SIM_LABELS['Fig5'][1], 'Analysis', 'rate_time_series_auto_kernel', 'rate_time_series_auto_kernel_{area}.npy'), area=['V1', 'V2', 'FEF']),
    output:
        'Fig5_ground_state.eps'
    shell:
        'python3 Fig5_ground_state.py'

# Creates Fig6_comparison_exp_spiking_data.eps by running
# Fig6_comparison_exp_spiking_data.py.
# Simulation inputs: the V1 'power_spectrum_subsample' and 'rate_histogram'
# files of every simulation in SIM_LABELS['Fig6'].
# Further inputs: fifteen .npy files from the 'Analysis' directory below
# chu2014_path.
# NOTE(review): the third expand() asks for rate_histogram_V1.npy of the
# first two and the last Fig6 labels, which the second expand() already
# requests for all Fig6 labels. It looks redundant as written; possibly a
# different file was intended -- TODO confirm.
rule Fig6_comparison_exp_spiking_data:
    input:
        expand(os.path.join(DATA_DIR, '{simulation}', 'Analysis', 'power_spectrum_subsample', 'power_spectrum_subsample_V1.npy'),
               simulation=SIM_LABELS['Fig6']),
        expand(os.path.join(DATA_DIR, '{simulation}', 'Analysis', 'rate_histogram', 'rate_histogram_V1.npy'),
               simulation=SIM_LABELS['Fig6']),
        expand(os.path.join(DATA_DIR,'{simulation}', 'Analysis', 'rate_histogram', 'rate_histogram_V1.npy'),
               simulation=SIM_LABELS['Fig6'][:2] + SIM_LABELS['Fig6'][-1:]),
        os.path.join(chu2014_path, 'Analysis', 'spike_data_1mm.npy'),
        os.path.join(chu2014_path, 'Analysis', 'neuron_depths.npy'),
        os.path.join(chu2014_path, 'Analysis', 'power_spectrum_freq.npy'),
        os.path.join(chu2014_path, 'Analysis', 'power_spectrum_V1_full.npy'),
        os.path.join(chu2014_path, 'Analysis', 'power_spectrum_V1_low_fluct.npy'),
        os.path.join(chu2014_path, 'Analysis', 'power_spectrum_V1_high_fluct.npy'),
        os.path.join(chu2014_path, 'Analysis', 'rate_time_series_time.npy'),
        os.path.join(chu2014_path, 'Analysis', 'rate_time_series_V1.npy'),
        os.path.join(chu2014_path, 'Analysis', 'spectrogram_freq.npy'),
        os.path.join(chu2014_path, 'Analysis', 'spectrogram_Sxx.npy'),
        os.path.join(chu2014_path, 'Analysis', 'spectrogram_time.npy'),
        os.path.join(chu2014_path, 'Analysis', 'rate_histogram_bins.npy'),
        os.path.join(chu2014_path, 'Analysis', 'rate_histogram_low.npy'),
        os.path.join(chu2014_path, 'Analysis', 'rate_histogram_high.npy'),
        os.path.join(chu2014_path, 'Analysis', 'rate_histogram_full.npy')
    output:
        'Fig6_comparison_exp_spiking_data.eps'
    shell:
        'python3 Fig6_comparison_exp_spiking_data.py'

# Creates Fig7_temporal_hierarchy.eps by running Fig7_temporal_hierarchy.py.
# For every simulation in SIM_LABELS['Fig7'] it requires the full rate time
# series of every area in area_list and one cross-correlation file for
# every ordered pair (area1, area2) drawn from area_list, including pairs
# of an area with itself.
rule Fig7_temporal_hierarchy:
    input:
        expand(os.path.join(DATA_DIR, '{simulation}', 'Analysis', 'rate_time_series_full', 'rate_time_series_full_{area}.npy'),
               simulation=SIM_LABELS['Fig7'], area=area_list),
        expand(os.path.join(DATA_DIR, '{simulation}', 'Analysis', 'cross_correlation', 'cross_correlation_{area1}_{area2}.npy'),
               simulation=SIM_LABELS['Fig7'], area1=area_list, area2=area_list),
    output:
        'Fig7_temporal_hierarchy.eps'
    shell:
        'python3 Fig7_temporal_hierarchy.py'

# Creates Fig8_interactions.eps by running Fig8_interactions.py.
# Requires functional_connectivity_synaptic_input.npy of every simulation in
# SIM_LABELS['Fig8']. Only for the last label of that list it additionally
# requires FC_synaptic_input_communities.json and
# functional_connectivity_bold_signal.npy. The files 'Fig8_exp_func_conn.csv'
# and 'FC_exp_communities.json' are given as relative paths without DATA_DIR.
rule Fig8_interactions:
    input:
        expand(os.path.join(DATA_DIR, '{simulation}', 'Analysis', 'functional_connectivity_synaptic_input.npy'),
               simulation=SIM_LABELS['Fig8']),
        expand(os.path.join(DATA_DIR, '{simulation}', 'Analysis', 'FC_synaptic_input_communities.json'),
               simulation=SIM_LABELS['Fig8'][-1]),
        'Fig8_exp_func_conn.csv',
        'FC_exp_communities.json',
        expand(os.path.join(DATA_DIR, '{simulation}', 'Analysis', 'functional_connectivity_bold_signal.npy'),
               simulation=SIM_LABELS['Fig8'][-1])
    output:
        'Fig8_interactions.eps'
    shell:
        'python3 Fig8_interactions.py'
        
# Creates Fig9_laminar_interactions.eps by running
# Fig9_laminar_interactions.py.
# Requires the Granger causality JSON file of every area/population pair for
# each simulation in SIM_LABELS['Fig9']. For the first Fig9 label it also
# requires the intermediate products of the other Fig9 rules: the
# significant-channels JSON and the HL/HZ/LH 'interactions' and 'paths' EPS
# files.
rule Fig9_laminar_interactions:
    input:
        expand(os.path.join(DATA_DIR, '{simulation}', 'Analysis', 'granger_causality', 'granger_causality_{area}_{pop}.json'),
               simulation=SIM_LABELS['Fig9'], area=area_list, pop=population_list),
        'Fig9_{}_significant_channels.json'.format(SIM_LABELS['Fig9'][0]),
        'Fig9_{}_HL_interactions.eps'.format(SIM_LABELS['Fig9'][0]),
        'Fig9_{}_HZ_interactions.eps'.format(SIM_LABELS['Fig9'][0]),
        'Fig9_{}_LH_interactions.eps'.format(SIM_LABELS['Fig9'][0]),
        'Fig9_{}_HL_paths.eps'.format(SIM_LABELS['Fig9'][0]),
        'Fig9_{}_HZ_paths.eps'.format(SIM_LABELS['Fig9'][0]),
        'Fig9_{}_LH_paths.eps'.format(SIM_LABELS['Fig9'][0]),
    output:
        'Fig9_laminar_interactions.eps'
    shell:
        'python3 Fig9_laminar_interactions.py'

# Runs Fig9_significant_channels.py for the simulation given by the wildcard
# 'simulation'. Declared outputs are the significant-channels JSON and the
# three 'lw_*_interactions.tex' files in Fig9_tex_files.
# The doubled braces in '{{simulation}}' keep {simulation} as a rule
# wildcard while expand() fills in {area} and {pop}.
# In the shell string, .format(DATA_DIR) inserts the data directory when the
# Snakefile is read; the doubled braces leave {wildcards.simulation} for
# Snakemake to substitute per job.
# NOTE(review): DATA_DIR is inserted unquoted, so a path containing spaces
# would be split into several arguments.
rule Fig9_significant_channels:
    input:
        expand(os.path.join(DATA_DIR, '{{simulation}}', 'Analysis', 'granger_causality', 'granger_causality_{area}_{pop}.json'),
               area=area_list, pop=population_list)
    output:
        'Fig9_{simulation}_significant_channels.json',
        'Fig9_tex_files/{simulation}_lw_HL_interactions.tex',
        'Fig9_tex_files/{simulation}_lw_LH_interactions.tex',
        'Fig9_tex_files/{simulation}_lw_HZ_interactions.tex'
    shell:
        'python3 Fig9_significant_channels.py {} {{wildcards.simulation}}'.format(DATA_DIR)


# Turns one '<simulation>_lw_<type>_interactions.tex' file into
# 'Fig9_<simulation>_<type>_interactions.eps'. Steps, all executed inside
# Fig9_tex_files:
#   1. rename the input to 'lw_<type>_interactions.tex'
#   2. run latex on '<type>_interactions.tex'
#   3. convert the DVI to PDF (dvipdf) and the PDF to EPS (pdftops -eps)
#   4. move the EPS one directory up under its final name
# The adjacent string literals are concatenated without a separator, so the
# command reads '... &&mv ...'; the shell still parses this as '&&'
# followed by the next command.
# NOTE(review): step 1 uses 'mv', so the declared input file no longer
# exists under its original name once the rule has run.
# NOTE(review): '<type>_interactions.tex' is not produced by any rule visible
# here; presumably it is a static file in Fig9_tex_files that includes
# 'lw_<type>_interactions.tex' -- TODO confirm.
rule Fig9_interactions:
    input:
        'Fig9_tex_files/{simulation}_lw_{type}_interactions.tex'
    output:
        'Fig9_{simulation}_{type}_interactions.eps'
    shell:
        'cd Fig9_tex_files &&'
        'mv {wildcards.simulation}_lw_{wildcards.type}_interactions.tex lw_{wildcards.type}_interactions.tex &&'
        'latex {wildcards.type}_interactions.tex &&'
        'dvipdf {wildcards.type}_interactions.dvi &&'
        'pdftops -eps {wildcards.type}_interactions.pdf &&'
        'mv {wildcards.type}_interactions.eps ../Fig9_{wildcards.simulation}_{wildcards.type}_interactions.eps'

# Runs Fig9_path_analysis.py for the simulation given by the wildcard
# 'simulation' and declares the three 'lw_*_paths.tex' files in
# Fig9_tex_files as outputs.
# The only input is the simulation's Analysis/pop_rates.json. The path holds
# no {area} or {pop} placeholder, so it is written as a plain wildcard
# pattern instead of being passed through expand() with unused area/pop
# lists (which could only yield the same path many times over).
# In the shell string, .format(DATA_DIR) inserts the data directory when the
# Snakefile is read; the doubled braces leave {wildcards.simulation} for
# Snakemake to substitute per job.
rule Fig9_path_analysis:
    input:
        os.path.join(DATA_DIR, '{simulation}', 'Analysis', 'pop_rates.json')
    output:
        'Fig9_tex_files/{simulation}_lw_HL_paths.tex',
        'Fig9_tex_files/{simulation}_lw_LH_paths.tex',
        'Fig9_tex_files/{simulation}_lw_HZ_paths.tex'
    shell:
        'python3 Fig9_path_analysis.py {} {{wildcards.simulation}}'.format(DATA_DIR)
    
# Turns one '<simulation>_lw_<type>_paths.tex' file into
# 'Fig9_<simulation>_<type>_paths.eps'. Steps, all executed inside
# Fig9_tex_files:
#   1. rename the input to 'lw_<type>_paths.tex'
#   2. run latex on '<type>_paths.tex'
#   3. convert the DVI to PDF (dvipdf) and the PDF to EPS (pdftops -eps)
#   4. move the EPS one directory up under its final name
# The adjacent string literals are concatenated without a separator, so the
# command reads '... &&mv ...'; the shell still parses this as '&&'
# followed by the next command.
# NOTE(review): step 1 uses 'mv', so the declared input file no longer
# exists under its original name once the rule has run.
# NOTE(review): '<type>_paths.tex' is not produced by any rule visible here;
# presumably it is a static file in Fig9_tex_files that includes
# 'lw_<type>_paths.tex' -- TODO confirm.
rule Fig9_paths:
    input:
        'Fig9_tex_files/{simulation}_lw_{type}_paths.tex'
    output:
        'Fig9_{simulation}_{type}_paths.eps'
    shell:
        'cd Fig9_tex_files &&'
        'mv {wildcards.simulation}_lw_{wildcards.type}_paths.tex lw_{wildcards.type}_paths.tex &&'
        'latex {wildcards.type}_paths.tex &&'
        'dvipdf {wildcards.type}_paths.dvi &&'
        'pdftops -eps {wildcards.type}_paths.pdf &&'
        'mv {wildcards.type}_paths.eps ../Fig9_{wildcards.simulation}_{wildcards.type}_paths.eps'