"""
Snakemake rules for preprocessing of the simulation data
"""

rule rate_time_series:
    """
    Run compute_rate_time_series.py for one simulation, area and method.

    Requires the spike recording file of every population in
    `population_list` for the given area. Declares one rate time series
    file per population plus one area-level file, all inside
    `Analysis/rate_time_series_{method}` of the simulation folder.
    """
    input:
        expand(
            os.path.join(
                DATA_DIR,
                '{{simulation}}',
                'recordings',
                '-'.join(('{{simulation}}', 'spikes-{{area}}-{pop}.npy')),
            ),
            pop=population_list,
        )
    output:
        expand(
            os.path.join(
                DATA_DIR,
                '{{simulation}}',
                'Analysis',
                'rate_time_series_{{method}}',
                'rate_time_series_{{method}}_{{area}}_{pop}.npy',
            ),
            pop=population_list,
        ),
        os.path.join(
            DATA_DIR,
            '{simulation}',
            'Analysis',
            'rate_time_series_{method}',
            'rate_time_series_{method}_{area}.npy',
        )
    shell:
        'python3 compute_rate_time_series.py {} {{wildcards.simulation}} {{wildcards.area}} {{wildcards.method}}'.format(DATA_DIR)

rule rate_histogram:
    """
    Run compute_rate_histogram.py for one simulation and area.

    Requires the spike recording file of every population in
    `population_list` for the given area and declares a single
    `rate_histogram_{area}.npy` file in `Analysis/rate_histogram`.
    """
    input:
        expand(
            os.path.join(
                DATA_DIR,
                '{{simulation}}',
                'recordings',
                '-'.join(('{{simulation}}', 'spikes-{{area}}-{pop}.npy')),
            ),
            pop=population_list,
        )
    output:
        os.path.join(
            DATA_DIR,
            '{simulation}',
            'Analysis',
            'rate_histogram',
            'rate_histogram_{area}.npy',
        ),
    shell:
        'python3 compute_rate_histogram.py {} {{wildcards.simulation}} {{wildcards.area}}'.format(DATA_DIR)

rule power_spectrum:
    """
    Run compute_power_spectrum.py for one simulation, area and method.

    Depends on the area-level rate time series file of the same method
    and declares `power_spectrum_{method}_{area}.npy` in
    `Analysis/power_spectrum_{method}`.
    """
    input:
        os.path.join(
            DATA_DIR,
            '{simulation}',
            'Analysis',
            'rate_time_series_{method}',
            'rate_time_series_{method}_{area}.npy',
        )
    output:
        os.path.join(
            DATA_DIR,
            '{simulation}',
            'Analysis',
            'power_spectrum_{method}',
            'power_spectrum_{method}_{area}.npy',
        ),
    shell:
        'python3 compute_power_spectrum.py {} {{wildcards.simulation}} {{wildcards.area}} {{wildcards.method}}'.format(DATA_DIR)

rule pop_rates:
    """
    Run compute_pop_rates.py for one simulation.

    Requires the spike recording file of every area/population combination
    from `area_list` x `population_list` and declares `pop_rates.json` in
    the simulation's Analysis folder.

    DATA_DIR is passed as the first command-line argument, matching the
    other compute_* rules in this file. The command template previously had
    no `{}` placeholder, so `.format(DATA_DIR)` silently discarded DATA_DIR
    and the script received only the simulation label.

    NOTE(review): the input file names here (`spikes_{area}_{pop}.npy`)
    differ from the `{simulation}-spikes-{area}-{pop}.npy` pattern used by
    rate_time_series -- confirm which naming the recordings actually use.
    """
    input:
        expand(
            os.path.join(
                DATA_DIR,
                '{{simulation}}',
                'recordings',
                'spikes_{area}_{pop}.npy',
            ),
            area=area_list,
            pop=population_list,
        )
    output:
        os.path.join(DATA_DIR, '{simulation}', 'Analysis', 'pop_rates.json')
    shell:
        'python3 compute_pop_rates.py {} {{wildcards.simulation}}'.format(DATA_DIR)

rule pop_LvR:
    """
    Run compute_pop_LvR.py for one simulation.

    Requires the spike recording file of every area/population combination
    from `area_list` x `population_list` and declares `pop_LvR.json` in the
    simulation's Analysis folder.

    DATA_DIR is passed as the first command-line argument, matching the
    other compute_* rules in this file. The command template previously had
    no `{}` placeholder, so `.format(DATA_DIR)` silently discarded DATA_DIR
    and the script received only the simulation label.

    NOTE(review): the input file names here (`spikes_{area}_{pop}.npy`)
    differ from the `{simulation}-spikes-{area}-{pop}.npy` pattern used by
    rate_time_series -- confirm which naming the recordings actually use.
    """
    input:
        expand(
            os.path.join(
                DATA_DIR,
                '{{simulation}}',
                'recordings',
                'spikes_{area}_{pop}.npy',
            ),
            area=area_list,
            pop=population_list,
        )
    output:
        os.path.join(DATA_DIR, '{simulation}', 'Analysis', 'pop_LvR.json')
    shell:
        'python3 compute_pop_LvR.py {} {{wildcards.simulation}}'.format(DATA_DIR)

rule corrcoeff:
    """
    Run compute_corrcoeff.py for one simulation.

    Requires the spike recording file of every area/population combination
    from `area_list` x `population_list` and declares `corrcoeff.json` in
    the simulation's Analysis folder.

    DATA_DIR is passed as the first command-line argument, matching the
    other compute_* rules in this file. The command template previously had
    no `{}` placeholder, so `.format(DATA_DIR)` silently discarded DATA_DIR
    and the script received only the simulation label.

    NOTE(review): the input file names here (`spikes_{area}_{pop}.npy`)
    differ from the `{simulation}-spikes-{area}-{pop}.npy` pattern used by
    rate_time_series -- confirm which naming the recordings actually use.
    """
    input:
        expand(
            os.path.join(
                DATA_DIR,
                '{{simulation}}',
                'recordings',
                'spikes_{area}_{pop}.npy',
            ),
            area=area_list,
            pop=population_list,
        )
    output:
        os.path.join(DATA_DIR, '{simulation}', 'Analysis', 'corrcoeff.json')
    shell:
        'python3 compute_corrcoeff.py {} {{wildcards.simulation}}'.format(DATA_DIR)

rule process_chu2014_data:
    """
    Run process_chu2014_data.py on the raw spike-time files of the listed
    recording channels found below `chu2014_path`.

    Declares fifteen `.npy` files in `<chu2014_path>/Analysis` (spike data,
    neuron depths, power spectra, rate time series, spectrogram and rate
    histograms). The script is invoked with DATA_DIR and the last entry of
    `SIM_LABELS['Fig6']`.

    NOTE(review): inputs and outputs live under `chu2014_path`, but the
    command line only passes DATA_DIR -- presumably the script locates the
    experimental data on its own; confirm.
    """
    input:
        expand(
            os.path.join(
                chu2014_path,
                'crcns-pvc5/rawSpikeTime/times_090425blk10_ch{channel}.mat',
            ),
            channel=[
                66, 67, 68, 69, 70, 72, 76, 77, 78, 79,
                80, 82, 83, 84, 85, 86, 87, 88, 90, 91,
                92, 93, 94, 95, 98, 99, 100, 101, 102, 103,
                106, 107, 108, 109, 110, 111, 112, 114, 115, 116,
                117, 122, 123, 124, 125, 126,
            ],
        )
    output:
        os.path.join(chu2014_path, 'Analysis',
                     'spike_data_1mm.npy'),
        os.path.join(chu2014_path, 'Analysis',
                     'neuron_depths.npy'),
        os.path.join(chu2014_path, 'Analysis',
                     'power_spectrum_freq.npy'),
        os.path.join(chu2014_path, 'Analysis',
                     'power_spectrum_V1_full.npy'),
        os.path.join(chu2014_path, 'Analysis',
                     'power_spectrum_V1_low_fluct.npy'),
        os.path.join(chu2014_path, 'Analysis',
                     'power_spectrum_V1_high_fluct.npy'),
        os.path.join(chu2014_path, 'Analysis',
                     'rate_time_series_time.npy'),
        os.path.join(chu2014_path, 'Analysis',
                     'rate_time_series_V1.npy'),
        os.path.join(chu2014_path, 'Analysis',
                     'spectrogram_freq.npy'),
        os.path.join(chu2014_path, 'Analysis',
                     'spectrogram_Sxx.npy'),
        os.path.join(chu2014_path, 'Analysis',
                     'spectrogram_time.npy'),
        os.path.join(chu2014_path, 'Analysis',
                     'rate_histogram_bins.npy'),
        os.path.join(chu2014_path, 'Analysis',
                     'rate_histogram_low.npy'),
        os.path.join(chu2014_path, 'Analysis',
                     'rate_histogram_high.npy'),
        os.path.join(chu2014_path, 'Analysis',
                     'rate_histogram_full.npy')
    shell:
        'python3 process_chu2014_data.py {} {}'.format(DATA_DIR, SIM_LABELS['Fig6'][-1])

rule cross_correlation:
    """
    Run compute_cross_correlation.py for a pair of areas of one simulation.

    Depends on the area-level `rate_time_series_full` files of `area1` and
    `area2` and declares `cross_correlation_{area1}_{area2}.npy` in
    `Analysis/cross_correlation`.
    """
    input:
        os.path.join(
            DATA_DIR,
            '{simulation}',
            'Analysis',
            'rate_time_series_full',
            'rate_time_series_full_{area1}.npy',
        ),
        os.path.join(
            DATA_DIR,
            '{simulation}',
            'Analysis',
            'rate_time_series_full',
            'rate_time_series_full_{area2}.npy',
        )
    output:
        os.path.join(
            DATA_DIR,
            '{simulation}',
            'Analysis',
            'cross_correlation',
            'cross_correlation_{area1}_{area2}.npy',
        )
    shell:
        'python3 compute_cross_correlation.py {} {{wildcards.simulation}} {{wildcards.area1}} {{wildcards.area2}}'.format(DATA_DIR)

rule synaptic_input:
    """
    Run compute_synaptic_input.py for one simulation and area.

    Depends on the per-population `rate_time_series_full` files of the area
    (one per entry of `population_list`) and declares
    `synaptic_input_{area}.npy` in `Analysis/synaptic_input`.
    """
    input:
        expand(
            os.path.join(
                DATA_DIR,
                '{{simulation}}',
                'Analysis',
                'rate_time_series_full',
                'rate_time_series_full_{{area}}_{pop}.npy',
            ),
            pop=population_list,
        ),
    output:
        os.path.join(
            DATA_DIR,
            '{simulation}',
            'Analysis',
            'synaptic_input',
            'synaptic_input_{area}.npy',
        )
    shell:
        'python3 compute_synaptic_input.py {} {{wildcards.simulation}} {{wildcards.area}}'.format(DATA_DIR)

rule functional_connectivity:
    """
    Run compute_functional_connectivity.py for one simulation and method.

    The `method` wildcard names both the Analysis sub-folder and the file
    prefix of the per-area inputs (`{method}/{method}_{area}.npy`, one per
    entry of `area_list`). Declares `functional_connectivity_{method}.npy`
    in the simulation's Analysis folder.
    """
    input:
        expand(
            os.path.join(
                DATA_DIR,
                '{{simulation}}',
                'Analysis',
                '{{method}}',
                '{{method}}_{area}.npy',
            ),
            area=area_list,
        )
    output:
        os.path.join(
            DATA_DIR,
            '{simulation}',
            'Analysis',
            'functional_connectivity_{method}.npy',
        )
    shell:
        'python3 compute_functional_connectivity.py {} {{wildcards.simulation}} {{wildcards.method}}'.format(DATA_DIR)

rule communities_FC:
    """
    Run compute_louvain_communities.py for one simulation and method.

    Depends on `functional_connectivity_{method}.npy` and declares
    `FC_{method}_communities.json`, both in the simulation's Analysis
    folder.
    """
    input:
        os.path.join(
            DATA_DIR,
            '{simulation}',
            'Analysis',
            'functional_connectivity_{method}.npy',
        )
    output:
        os.path.join(
            DATA_DIR,
            '{simulation}',
            'Analysis',
            'FC_{method}_communities.json',
        )
    shell:
        'python3 compute_louvain_communities.py {} {{wildcards.simulation}} {{wildcards.method}}'.format(DATA_DIR)

rule communities_FC_exp:
    """
    Run compute_louvain_communities.py with the fixed arguments
    `None exp None`.

    Depends on `Fig8_exp_func_conn.csv` and declares
    `FC_exp_communities.json`, both relative to the working directory.
    """
    input:
        'Fig8_exp_func_conn.csv'
    output:
        'FC_exp_communities.json'
    shell:
        'python3 compute_louvain_communities.py None exp None'

rule bold_signal:
    """
    Run compute_bold_signal.py for one simulation and area.

    Depends on `synaptic_input_{area}.npy` and declares
    `bold_signal_{area}.npy` in `Analysis/bold_signal`.
    """
    input:
        os.path.join(
            DATA_DIR,
            '{simulation}',
            'Analysis',
            'synaptic_input',
            'synaptic_input_{area}.npy',
        )
    output:
        os.path.join(
            DATA_DIR,
            '{simulation}',
            'Analysis',
            'bold_signal',
            'bold_signal_{area}.npy',
        )
    shell:
        'python3 compute_bold_signal.py {} {{wildcards.simulation}} {{wildcards.area}}'.format(DATA_DIR)

rule granger_causality:
    """
    Run compute_granger_causality.py for one simulation, area and
    population.

    Depends on the `rate_time_series_full` file of that area/population and
    declares `granger_causality_{area}_{pop}.json` in
    `Analysis/granger_causality`.
    """
    input:
        os.path.join(
            DATA_DIR,
            '{simulation}',
            'Analysis',
            'rate_time_series_full',
            'rate_time_series_full_{area}_{pop}.npy',
        )
    output:
        os.path.join(
            DATA_DIR,
            '{simulation}',
            'Analysis',
            'granger_causality',
            'granger_causality_{area}_{pop}.json',
        )
    shell:
        'python3 compute_granger_causality.py {} {{wildcards.simulation}} {{wildcards.area}} {{wildcards.pop}}'.format(DATA_DIR)