using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.IO;
using System.Diagnostics;

namespace Core
{
	//Manages spike propagation through layers
	public class Network
	{
		public List<List<SpikeData>> OrderedSpikes { get; set; }
		public List<List<SpikeData[,]>> Spikes3D { get; set; }
		private Random rnd;
		
		public Network(Random random)
		{
			rnd = random;
		}

		//Train layer S2 with all of the input images
		//This is for the pure STDP (no RL)
		public void TrainKernel(List<List<SpikeData>> orderedSpikes, List<List<SpikeData[,]>> spikeData,
			Kernel kernel, int numberOfEpoch, int numberOfImageRepeat,
			int learningRateIncCount = 400, float learningRateCo = 2.0f)
		{
			OrderedSpikes = orderedSpikes;
			Spikes3D = spikeData;
			
			//training new kernel
			Console.WriteLine($"Training kernel {kernel.Name}...");
			Console.WriteLine($"Shuffling input...");
			int[] shuffled = Enumerable.Range(0, OrderedSpikes.Count).OrderBy(s => rnd.Next()).ToArray();

			int cnt = 1;
			for (int i = 0; i < numberOfEpoch; i++)
			{
				Console.WriteLine("Epoch: " + (i + 1));
				for (int m = 0; m < OrderedSpikes.Count; m++)
				{
					int j = shuffled[m];
					for (int k = 0; k < numberOfImageRepeat; k++, cnt++)
					{
						if (cnt % learningRateIncCount == 0)
						{
							if (kernel.Ap < 0.25)
							{
								Console.WriteLine(cnt);
								kernel.Ap *= learningRateCo;
								kernel.An *= learningRateCo;
							}
						}
						kernel.TrainKernel(OrderedSpikes[j], Spikes3D[j]);
						kernel.ApplySTDP(true);
					}
				}
			}
			Console.WriteLine("Done.");
		}

		//Train layer S2 with all of the input images
		//This is for reinforcement learning
		public void TrainKernelRL(List<List<SpikeData>> orderedSpikes, List<List<SpikeData[,]>> spikeData,
			Kernel kernel, int numberOfEpoch, int numberOfImageRepeat, List<int> labels, int[] neuronAssociations,
			bool[] isActive = null)
		{
			OrderedSpikes = orderedSpikes;
			Spikes3D = spikeData;

			//training new kernel
			Console.WriteLine($"Training kernel {kernel.Name}... (RL)");
			Console.WriteLine($"Shuffling input...");
			int[] shuffled = Enumerable.Range(0, OrderedSpikes.Count).OrderBy(s => rnd.Next()).ToArray();

			int cnt = 1;
			for (int i = 0; i < numberOfEpoch; i++)
			{
				Console.WriteLine("Epoch: " + (i + 1));
				for (int m = 0; m < OrderedSpikes.Count; m++)
				{
					int j = shuffled[m];
					for (int k = 0; k < numberOfImageRepeat; k++, cnt++)
					{
						List<SpikeData> winners = kernel.TrainKernelAndGetSTDPWinners(OrderedSpikes[j], Spikes3D[j], isActive);
						kernel.ApplySTDP(winners.Count != 0 && neuronAssociations[winners[0].Feature] == labels[j]);
					}
				}
			}
			Console.WriteLine("Done.");
		}
		
		//Extracts S2 first spike for each image
		public List<List<int>> GetKernelFirstSpikes(List<List<SpikeData>> orderedSpikes,
			List<List<SpikeData[,]>> spikeData, Kernel kernel)
		{
			List<List<SpikeData>> winners = new List<List<SpikeData>>();
			for (int i = 0; i < orderedSpikes.Count; i++)
			{
				winners.Add(kernel.TestKernelForFirstSpikes(orderedSpikes[i], spikeData[i]));
			}

			Console.Write("Computing first spikes...");
			List<List<int>> firstSpikes = new List<List<int>>(spikeData.Count);
			for (int img = 0; img < orderedSpikes.Count; img++)
			{
				List<int> spikeOrNot = Enumerable.Repeat(0, kernel.NumberOfFeature).ToList();
				if (winners[img].Count != 0)
					spikeOrNot[winners[img][0].Feature] = 1;
				firstSpikes.Add(spikeOrNot);
			}
			Console.WriteLine(" Done.");
			return firstSpikes;
		}

		//Extracts S2 potentials for each image
		private List<List<float[,]>> TestNetworkForPotentials(List<List<SpikeData>> orderedSpikes,
			List<List<SpikeData[,]>> spikeData, Kernel kernel)
		{
			OrderedSpikes = orderedSpikes;
			Spikes3D = spikeData;

			List<List<float[,]>> potentials = new List<List<float[,]>>();
			Console.Write($"Testing kernel {kernel.Name}...");

			for (int i = 0; i < OrderedSpikes.Count; i++)
			{
				potentials.Add(kernel.TestKernelForPotentials(OrderedSpikes[i], Spikes3D[i]));
			}

			Console.WriteLine(" Done.");
			return potentials;
		}

		//Extracts S2 maximum potentials for each image
		public List<List<float>> GetKernelPotentials(List<List<SpikeData>> orderedSpikes,
			List<List<SpikeData[,]>> spikeData, Kernel kernel)
		{
			List<List<float[,]>> potentials = TestNetworkForPotentials(orderedSpikes, spikeData, kernel);

			Console.Write("Pooling potentials...");
			List<List<float>> potentials_pooled = new List<List<float>>(potentials.Count);
			for (int img = 0; img < potentials.Count; img++)
			{
				List<float> p = Enumerable.Repeat(0f, kernel.NumberOfFeature).ToList();
				for (int ftr = 0; ftr < potentials[img].Count; ++ftr)
				{
					for (int row = 0; row < potentials[img][ftr].GetLength(0); ++row)
					{
						for (int col = 0; col < potentials[img][ftr].GetLength(1); ++col)
						{
							if (p[ftr] < potentials[img][ftr][row, col])
							{
								p[ftr] = potentials[img][ftr][row, col];
							}
						}
					}
				}
				potentials_pooled.Add(p);
			}
			Console.WriteLine(" Done.");
			return potentials_pooled;
		}

		public void SaveKernel(string taskName, Kernel kernel, List<float[,]> preRealFeatures)
		{
			//saving
			kernel.ComputeRealFeatures(preRealFeatures);
			Console.Write($"Saving kernel {kernel.Name}...");
			kernel.SaveKernel(Path.Combine(taskName,$"{kernel.Name}.kernel"));
			Console.WriteLine(" Done.");

			SaveKernelFeatures(taskName, kernel);
			//SaveKernelMergedWeights(taskName, kernel);
			//SaveKernelWeights(taskName, kernel);
		}

		//Save synaptic weights of S2 in a text file
		//Note that only the maximum weight among synapses corresponding to orientations will be saved
		public static void SaveKernelMergedWeights(string taskName, Kernel ker)
		{
			Console.Write("Saving merged weights for kernel " + ker.Name + " to file...");
			if(!Directory.Exists(Path.Combine(taskName, $"{ker.Name}_MergedWeights")))
			{
				Directory.CreateDirectory(Path.Combine(taskName, $"{ker.Name}_MergedWeights"));
			}

			for (int i = 0; i < ker.Weights.Count; i++)
			{
				string pw = Path.Combine($"{ker.Name}_MergedWeights", ker.Name + "_SumWeights_n" + i + ".txt");
				pw = Path.Combine(taskName, pw);
				StreamWriter sumWriter = new StreamWriter(pw);

				for (int j = 0; j < ker.Weights[i].GetLength(0); j++)
				{
					for (int k = 0; k < ker.Weights[i].GetLength(1); k++)
					{
						float sum = 0;
						for (int m = 0; m < ker.Weights[i].GetLength(2); m++)
						{
							sum += ker.Weights[i][j, k, m];
						}
						if (sum > 1)
							sum = 1;
						sumWriter.Write(sum + " ");
					}
					sumWriter.WriteLine();
				}

				sumWriter.Close();
			}
			Console.WriteLine(" Done.");

			Console.Write("Generating gnuplot file...");
			string p = Path.Combine($"{ker.Name}_MergedWeights", ker.Name + "_SumWeightsPlot.plt");
			p = Path.Combine(taskName, p);
			StreamWriter gnuWriter = new StreamWriter(p);
			gnuWriter.WriteLine($"set xrange [-0.5:{ker.Weights[0].GetLength(1) - 0.5}]; set yrange [{ker.Weights[0].GetLength(0) - 0.5}:-0.5]");
			//gnuWriter.WriteLine($"set xrange [-0.5:{ker.RealFeatures[0].GetLength(1) - 0.5}]; set yrange [{ker.RealFeatures[0].GetLength(0) - 0.5}:-0.5]");
			gnuWriter.WriteLine("set size ratio 1");
			gnuWriter.WriteLine("set cbrange [0:1]");
			gnuWriter.WriteLine("set pm3d map");
			gnuWriter.WriteLine("unset colorbox");
			gnuWriter.WriteLine("set palette gray");
			gnuWriter.WriteLine("set terminal png");

			gnuWriter.WriteLine("do for [i = 0:" + (ker.NumberOfFeature - 1) + "]{");
			gnuWriter.WriteLine($"\tt = sprintf('Sum Weights | Kernel: {ker.Name}, Neuron: %d', i)");
			//gnuWriter.WriteLine($"\tt = sprintf('Feature | Kernel: {ker.Name}, Neuron: %d', i)");
			gnuWriter.WriteLine("\tset title t");
			gnuWriter.WriteLine($"\toutfile = sprintf('{ker.Name}_SumWeights_n%03.0f.png', i)");
			gnuWriter.WriteLine("\tset output outfile");
			gnuWriter.WriteLine($"\tinfile = sprintf('{ker.Name}_SumWeights_n%d.txt', i)");
			gnuWriter.WriteLine("\tsplot infile matrix with image");
			gnuWriter.WriteLine("}");
			gnuWriter.Close();
			Console.WriteLine(" Done.");
		}

		//Save deconvolution of weights of S2 in a text file (for visualization)
		public static void SaveKernelFeatures(string taskName, Kernel ker)
		{
			Console.Write("Saving features for kernel " + ker.Name + " to file...");
			if (!Directory.Exists(Path.Combine(taskName, $"{ker.Name}_Features")))
			{
				Directory.CreateDirectory(Path.Combine(taskName, $"{ker.Name}_Features"));
			}

			for (int i = 0; i < ker.Weights.Count; i++)
			{
				string pw = Path.Combine($"{ker.Name}_Features", ker.Name + "_Features_n" + i + ".txt");
				pw = Path.Combine(taskName, pw);
				StreamWriter sumWriter = new StreamWriter(pw);

				for (int j = 0; j < ker.RealFeatures[i].GetLength(0); j++)
				{
					for (int k = 0; k < ker.RealFeatures[i].GetLength(1); k++)
					{
						sumWriter.Write(ker.RealFeatures[i][j, k] + " ");
					}
					sumWriter.WriteLine();
				}
				sumWriter.Close();
			}
			Console.WriteLine(" Done.");

			Console.Write("Generating gnuplot file...");
			string p = Path.Combine($"{ker.Name}_Features", ker.Name + "_FeaturesPlot.plt");
			p = Path.Combine(taskName, p);

			StreamWriter gnuWriter = new StreamWriter(p);

			gnuWriter.WriteLine($"set xrange [-0.5:{ker.RealFeatures[0].GetLength(1) - 0.5}]; set yrange [{ker.RealFeatures[0].GetLength(0) - 0.5}:-0.5]");
			gnuWriter.WriteLine("set size ratio 1");
			gnuWriter.WriteLine("set cbrange [0:1]");
			gnuWriter.WriteLine("set pm3d map");
			gnuWriter.WriteLine("unset colorbox");
			gnuWriter.WriteLine("set palette gray");
			gnuWriter.WriteLine("set terminal pngcairo enhanced crop");
			gnuWriter.WriteLine("unset xtics");
			gnuWriter.WriteLine("unset ytics");

			gnuWriter.WriteLine("do for [i = 0:" + (ker.NumberOfFeature - 1) + "]{");
			gnuWriter.WriteLine($"\toutfile = sprintf('{ker.Name}_Features_n%03.0f.png', i)");
			gnuWriter.WriteLine("\tset output outfile");
			gnuWriter.WriteLine($"\tinfile = sprintf('{ker.Name}_Features_n%d.txt', i)");
			gnuWriter.WriteLine("\tsplot infile matrix with image");
			gnuWriter.WriteLine("}");
			gnuWriter.Close();
			Console.WriteLine(" Done.");
		}

		//Save synaptic weights of S2 in a text file
		public static void SaveKernelWeights(string taskName, Kernel ker)
		{
			Console.WriteLine("Saving kernel weights to file...");
			if (!Directory.Exists(Path.Combine(taskName, $"{ker.Name}_Weights")))
			{
				Directory.CreateDirectory(Path.Combine(taskName, $"{ker.Name}_Weights"));
			}
			for (int i = 0; i < ker.Weights.Count; i++)
			{
				string pw = Path.Combine($"{ker.Name}_Weights", ker.Name + "_Weights_n" + i + ".txt");
				pw = Path.Combine(taskName, pw);
				StreamWriter writer = new StreamWriter(pw);
				for (int j = 0; j < ker.Weights[i].GetLength(0); j++)
				{
					for (int m = 0; m < ker.Weights[i].GetLength(2); m++)
					{
						for (int k = 0; k < ker.Weights[i].GetLength(1); k++)
						{
							writer.Write(ker.Weights[i][j, k, m] + " ");
						}
					}
					writer.WriteLine();
				}
				writer.Close();
			}

			Console.Write("Generating gnuplot file...");
			string p = Path.Combine($"{ker.Name}_Weights", ker.Name + "_WeightsPlot.plt");
			p = Path.Combine(taskName, p);
			StreamWriter gnuWriter = new StreamWriter(p);
			gnuWriter.WriteLine($"set xrange [-0.5:{ker.Weights[0].GetLength(2) * ker.Weights[0].GetLength(1) - 0.5}]; set yrange [{ker.Weights[0].GetLength(0) - 0.5}:-0.5]");
			gnuWriter.WriteLine($"set size ratio {1f / ker.Weights[0].GetLength(2)}");
			gnuWriter.WriteLine("set cbrange [0:1]");
			gnuWriter.WriteLine("set pm3d map");
			gnuWriter.WriteLine("unset colorbox");
			gnuWriter.WriteLine("set palette gray");
			gnuWriter.WriteLine("set terminal png");

			gnuWriter.WriteLine("do for [i = 0:" + (ker.NumberOfFeature - 1) + "]{");
			gnuWriter.WriteLine($"\tt = sprintf('Weights | Kernel: {ker.Name}, Neuron: %d', i)");
			gnuWriter.WriteLine("\tset title t");
			gnuWriter.WriteLine($"\toutfile = sprintf('{ker.Name}_Weights_n%03.0f.png', i)");
			gnuWriter.WriteLine("\tset output outfile");
			gnuWriter.WriteLine($"\tinfile = sprintf('{ker.Name}_Weights_n%d.txt', i)");
			gnuWriter.WriteLine("\tsplot infile matrix with image");
			gnuWriter.WriteLine("}");
			gnuWriter.Close();
			Console.WriteLine(" Done.");
			Console.WriteLine(" Done.");
		}
	}
}