#include <stdio.h>
#include <iostream>
#include <math.h>


#define I	7
#define gna	35.
#define gk	9.
#define gl	0.1
#define ena	55.
#define ek	(-90.)
#define el	(-65.)
#define dt	0.01

// Set GPU parallelization 
#define BLOCKS  4
#define THREADS 256

// Set simulation time
#define TIME_ITERATIONS 6000000l

__global__
void run(float *v, float *h, float *n)
{
	int i = blockIdx.x * blockDim.x + threadIdx.x;
	float minf, hinf, htau, ninf, ntau, a, b, vm, dv=(ena-ek)/201., ras;
	float tbl[201][6];
	int  vi;
	for(unsigned long t=0; t<20l; t++){
		tbl[t][0] = vm = ek+dv*t;
		a = 0.1*(vm+35.)/(1.0-exp(-(vm+35.)/10.)) ;
		b = 4.0*exp(-(vm+60.)/18.);
		tbl[t][1] = a/(a+b);
		a = 0.07*exp(-(vm+58.)/20.);
		b = 1.0/(1.0+exp(-(vm+28.)/10.));
		tbl[t][2] =  a/(a+b);
		tbl[t][3] = 1./(a+b);
		a = 0.01*(vm+34.)/(1.0-exp(-(vm+34.)/10.));
		b = 0.125*exp(-(vm+44.)/80.);
		tbl[t][4] =  a/(a+b);
		tbl[t][5] = 1./(a+b);
	}
	for(unsigned long t   = 0; t<TIME_ITERATIONS;    ++t){
		vi = (int)floor((v[i]-ek)/dv);
		ras   = (v[i] - tbl[vi][0])/dv;
		minf = tbl[vi][1] + (tbl[vi+1][1] - tbl[vi][1])*ras;
		hinf = tbl[vi][2] + (tbl[vi+1][2] - tbl[vi][2])*ras;
		htau = tbl[vi][3] + (tbl[vi+1][3] - tbl[vi][3])*ras;
		ninf = tbl[vi][4] + (tbl[vi+1][4] - tbl[vi][4])*ras;
		ntau = tbl[vi][5] + (tbl[vi+1][5] - tbl[vi][5])*ras;
		
		n[i] += dt*(ninf - n[i])/ntau;
		h[i] += dt*(hinf - h[i])/htau;
		v[i] += dt*(-gna*minf*minf*minf*h[i]*(v[i]-ena)-gk*n[i]*n[i]*n[i]*n[i]*(v[i]-ek)-gl*(v[i]-el)+I);
	}
	
}

int main(void)
{
	int N = BLOCKS*THREADS;
	float *v, *h, *n;

	// Allocate Unified Memory – accessible from CPU or GPU
	cudaMallocManaged(&v, N*sizeof(float));
	cudaMallocManaged(&h, N*sizeof(float));
	cudaMallocManaged(&n, N*sizeof(float));

	// initialize arrays on the host
	for (int i = 0; i < N; i++) {
		v[i] = -63.f;
		h[i] = n[i] = 0.f;
	}

	// Run kernel on the GPU
	run<<<BLOCKS, THREADS>>>(v, h, n);

	// Wait for GPU to finish before accessing on host
	cudaDeviceSynchronize();
	
	//check for errors
	cudaError_t e = cudaGetLastError();
	if(e){
		printf("ERROR (%d): %s\n",e,cudaGetErrorString(e));
	}

	// Free memory
	cudaFree(v);
	cudaFree(h);
	cudaFree(n);

	return 0;
}