/*=================================================================
 *
 *  nvtstmemmex.c
 *  Author: Andrew Magis
 *  Calculate TST scores on the GPU; does not return the TST matrix
 *  Inputs: Class 1 data, Class 2 data
 *  Outputs: sorted scores of disjoint gene triples (disjoint within each
 *           z slice), index i, index j, index k (1-based gene indices)
 *
 *=================================================================*/

#include <math.h>
#include "mex.h"
#include <vector>
#include <algorithm>

//#define DEBUG

//Print the name, total global memory and maximum threads per block of a CUDA
//device, or the CUDA error string if its properties cannot be queried.
//  device - CUDA device ordinal passed to cudaGetDeviceProperties
void DisplayDeviceProperties(int device) {

	cudaDeviceProp deviceProp;
	memset(&deviceProp, 0, sizeof (deviceProp));

	printf("-----\n");

	//Keep the returned code: it is the error we want to report
	cudaError_t err = cudaGetDeviceProperties(&deviceProp, device);
	if (err == cudaSuccess) {
		printf("Device Name\t\t\t\t%s\n", deviceProp.name);
		//totalGlobalMem is a size_t; cast so the specifier is correct on both
		//LP64 (Linux/macOS) and LLP64 (Windows, where long is 32-bit) platforms
		printf("Total Global Memory\t\t\t%llu KB\n",
		       (unsigned long long)(deviceProp.totalGlobalMem / 1024));
		printf("Maximum threads per block\t\t%d\n", deviceProp.maxThreadsPerBlock);
	} else {
		printf("\n%s\n", cudaGetErrorString(err));
	}

	printf("------\n");
}

//Edge length of the 2-D tstKernel blocks (THREADS x THREADS threads); the gene
//count is padded up to a multiple of this value before launching.
#define THREADS 8
//Threads per 1-D block for maxKernel and clearKernel. maxKernel's shared-memory
//reduction halves blockDim.x each step, so this must be a power of two.
#define REDUCTION_THREADS 128
//Absolute value. X is evaluated more than once, so pass side-effect-free expressions.
#define ABSMACRO(X) (((X)<0)?(-(X)):(X))

//Adds to scores[0..5], for one class, the fraction of samples in which genes
//a, b and c fall into each of the six possible orderings:
//  [0] a<=b<=c  [1] a<=c<=b  [2] b<=a<=c  [3] b<=c<=a  [4] c<=a<=b  [5] c<=b<=a
//A sample with tied values satisfies several orderings; its unit weight is then
//split evenly between them. a, b and c point at each gene's value in the first
//sample, and successive samples are m floats apart (column-major, padded rows).
//scores must start at zero; at the end every entry is multiplied by 1/n.
__device__ __forceinline__ void accumulateOrderingFractions(const float *a, const float *b, const float *c,
                                                            unsigned int n, unsigned int m, float scores[6]) {

	//Pre-calculate the inverse of the class length
	const float ninverse = 1.f / (float)n;

	for (unsigned int i = 0; i < n*m; i += m) {

		//Copy the three values into registers first
		const float va = a[i];
		const float vb = b[i];
		const float vc = c[i];

		float temp[6] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
		float icount = 0.f;

		if ((va <= vb) && (vb <= vc)) { temp[0] = 1.f; icount += 1.f; }
		if ((va <= vc) && (vc <= vb)) { temp[1] = 1.f; icount += 1.f; }
		if ((vb <= va) && (va <= vc)) { temp[2] = 1.f; icount += 1.f; }
		if ((vb <= vc) && (vc <= va)) { temp[3] = 1.f; icount += 1.f; }
		if ((vc <= va) && (va <= vb)) { temp[4] = 1.f; icount += 1.f; }
		if ((vc <= vb) && (vb <= va)) { temp[5] = 1.f; icount += 1.f; }

		//If there was a tie, split the weight (won't happen very often)
		if (icount > 1.f) {
			#pragma unroll
			for (int j = 0; j < 6; j++) {
				temp[j] = __fdividef(temp[j], icount);
			}
		}

		#pragma unroll
		for (int j = 0; j < 6; j++) {
			scores[j] += temp[j];
		}
	}

	//Scale the accumulated counts by the number of samples
	#pragma unroll
	for (int j = 0; j < 6; j++) {
		scores[j] *= ninverse;
	}
}

//Kernel running on the GPU.
//Computes one z-slice of the TST scores. For every gene pair (x, y) with
//x > y > zcoord it writes sum over the six orderings of |P1(order) - P2(order)|
//for the triple (x, y, zcoord) to d_s1[x*m + y]. Every other entry of the
//slice is written as 0.
//d_class1 / d_class2 hold n1 / n2 samples, each m floats long (gene index
//fastest). There is no bounds check: the launch must cover exactly an m x m
//index space (2-D grid of 2-D blocks).
__global__ void tstKernel(float *d_class1, float *d_class2, unsigned int n1, unsigned int n2, unsigned int m, unsigned int zcoord, float *d_s1) {

	const unsigned int x = blockIdx.x*blockDim.x + threadIdx.x;
	const unsigned int y = blockIdx.y*blockDim.y + threadIdx.y;

	//Per-thread results. Nothing is exchanged between threads, so these stay
	//in registers rather than shared memory (no block-size assumption either).
	float class1_scores[6] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
	float class2_scores[6] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f};

	//Only the strictly "upper" region x > y > zcoord is computed; all other
	//threads fall through and write a zero score
	if ((x > y) && (y > zcoord)) {
		accumulateOrderingFractions(&d_class1[x], &d_class1[y], &d_class1[zcoord], n1, m, class1_scores);
		accumulateOrderingFractions(&d_class2[x], &d_class2[y], &d_class2[zcoord], n2, m, class2_scores);
	}

	//Finally, sum the absolute class differences
	float sum = 0.f;
	#pragma unroll
	for (int i = 0; i < 6; i++) {
		sum += (float)ABSMACRO(class1_scores[i]-class2_scores[i]);
	}

	//Write the result to global memory
	d_s1[x*m + y] = sum;
}

//Row-wise max reduction. Block b scans the m floats starting at d_tsp[m*b].
//It writes the largest value to maxValue[b] and the position m1*b + column of
//that value to maxIndex[b]. A row whose flag d_baddata[b] is non-zero is
//skipped and reports (0, 0).
//Launch: gridDim.x = number of rows (d_baddata needs that many entries) and
//blockDim.x == REDUCTION_THREADS (a power of two, required by the halving
//reduction below).
__global__ void maxKernel(float *d_tsp, unsigned int m, unsigned int m1, float *maxValue, unsigned int *maxIndex, float *d_baddata) {

	__shared__ float sdata[REDUCTION_THREADS];
	//Indices must stay integral: a float represents integers exactly only up to
	//2^24, and m1*row + column exceeds that once m1 passes ~4096 genes
	__shared__ unsigned int sIndex[REDUCTION_THREADS];
	float s_maxValue = -1e-6f;
	unsigned int s_index = 0;

	//The flag depends only on blockIdx, so the whole block returns together and
	//no thread is left waiting at a __syncthreads()
	if (d_baddata[blockIdx.x] != 0) {
		maxValue[blockIdx.x] = 0.f;
		maxIndex[blockIdx.x] = 0u;
		return;
	}

	for (unsigned int i = 0; i < m; i += REDUCTION_THREADS) {

		//Set shared memory to be zero
		sdata[threadIdx.x] = 0.f;
		sIndex[threadIdx.x] = 0u;

		// Go to correct location in memory
		const float *g_idata = d_tsp + m*blockIdx.x + i;

		//Check to see if we will overshoot the actual data
		unsigned int WA = m-i > REDUCTION_THREADS ? REDUCTION_THREADS : m-i;

		if (threadIdx.x < WA) {
			sdata[threadIdx.x] = g_idata[threadIdx.x];
			sIndex[threadIdx.x] = m1*blockIdx.x + i + threadIdx.x;
		}
		__syncthreads();

		// do reduction in shared mem
		for (unsigned int s = blockDim.x/2; s > 0; s >>= 1) {
			if (threadIdx.x < s) {
				if (sdata[threadIdx.x + s] > sdata[threadIdx.x]) {
					sdata[threadIdx.x] = sdata[threadIdx.x + s];
					sIndex[threadIdx.x] = sIndex[threadIdx.x + s];
				}
			}
			__syncthreads();
		}

		// Keep track of largest element of this round
		if (threadIdx.x == 0) {
			if (sdata[0] > s_maxValue) {
				s_maxValue = sdata[0];
				s_index = sIndex[0];
			}
		}
	}

	if (threadIdx.x == 0) {
		maxValue[blockIdx.x] = s_maxValue;
		maxIndex[blockIdx.x] = s_index;
	}
}

//Zeroes the m contiguous floats starting at d_tsp[m*col] and the m strided
//entries d_tsp[m*k + row] (k = 0..m-1), then sets d_baddata[col] to 1 so later
//maxKernel passes skip that row.
//Expects a single 1-D block; only threads below REDUCTION_THREADS write.
__global__ void clearKernel(float *d_tsp, unsigned int m, unsigned int row, unsigned int col, float *d_baddata) {

	float *contiguous = d_tsp + m*col;

	for (unsigned int base = 0; base < m; base += REDUCTION_THREADS) {

		//Number of live elements in this chunk (the last one may be short)
		const unsigned int remaining = m - base;
		const unsigned int width = (remaining > REDUCTION_THREADS) ? REDUCTION_THREADS : remaining;
		const unsigned int k = base + threadIdx.x;

		if (threadIdx.x < width) {
			contiguous[k] = 0.f;
			d_tsp[m*k + row] = 0.f;
		}
		__syncthreads();
	}

	d_baddata[col] = 1.f;
}
//One candidate triple found by the GPU search together with its TST score.
struct ScoreElement {
	unsigned int row;      //0-based gene index (position % m1 of the maximum)
	unsigned int col;      //0-based gene index (position / m1 of the maximum)
	unsigned int z;        //0-based index of the z slice (third gene)
	float primary;         //TST score used for ranking
};

//Strict-weak-ordering predicate for the std sort algorithms: returns true when
//left should come before right, so elements are ranked by descending primary
//(TST) score. There is no secondary key; equal scores compare equivalent.
bool sort_pred(const ScoreElement& left, const ScoreElement& right) {
	return left.primary > right.primary;
}

//Copy an m1-row, n-column column-major (MATLAB layout) matrix into dst, a
//zeroed buffer of m*n floats, padding every column from m1 to m rows.
static void copyPaddedColumns(float *dst, const float *src, unsigned int m1, unsigned int m, unsigned int n) {
	for (unsigned int i = 0; i < n; i++) {
		memcpy(dst, src, m1*sizeof(float));
		src += m1;
		dst += m;
	}
}

//Raise a MATLAB error if a preceding kernel launch failed. The <<<>>> launch
//syntax returns nothing, so launch errors only surface via cudaGetLastError().
static void checkKernelLaunch(const char *msg) {
	if (cudaGetLastError() != cudaSuccess)
		mexErrMsgTxt(msg);
}

//MEX gateway.
//Inputs:  prhs[0] class 1 data, prhs[1] class 2 data (single, real, genes x samples,
//         same number of genes).
//Outputs: plhs[0] scores (single), plhs[1..3] 1-based gene indices i, j, k (int32),
//         sorted by descending score. For each z slice up to `stop` disjoint pairs are kept.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[]) {

#ifdef DEBUG
	DisplayDeviceProperties(0);

	//Time the execution of this function
	cudaEvent_t start_event, stop_event;
	cudaEventCreate(&start_event);
	cudaEventCreate(&stop_event);
	cudaEventRecord(start_event, 0);
	cudaEventSynchronize(start_event);
	float time_run;
#endif

	//Error check
	if (nrhs != 2) {
		mexErrMsgTxt("Two inputs required (class 1 ranks, class 2 ranks)");
	}
	if (nlhs != 4) {
		mexErrMsgTxt("Four outputs required");
	}
	// The input must be a noncomplex single.
	if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS || mxIsComplex(prhs[0])) {
		mexErrMsgTxt("Class1 Input must be a noncomplex single.");
	}
	if (mxGetClassID(prhs[1]) != mxSINGLE_CLASS || mxIsComplex(prhs[1])) {
		mexErrMsgTxt("Class2 Input must be a noncomplex single.");
	}

	//m is the number of rows (genes)
	//n is the number of chips (samples)
	unsigned int m1 = mxGetM(prhs[0]);
	unsigned int n1 = mxGetN(prhs[0]);
	unsigned int m2 = mxGetM(prhs[1]);
	unsigned int n2 = mxGetN(prhs[1]);
	if (m1 != m2) {
		mexErrMsgTxt("Number of genes for class 1 != class 2\n");
	}
	//Class scores are normalised by the sample count; an empty class would turn
	//every score into NaN and silently return nothing
	if (n1 == 0 || n2 == 0) {
		mexErrMsgTxt("Each class must contain at least one sample.");
	}

	//Create a padded m which is a multiple of THREADS (ceiling division)
	unsigned int m = ((m1 + THREADS - 1) / THREADS) * THREADS;

#ifdef DEBUG
	printf("Class1 Ranks: [%u, %u] Class2 Ranks: [%u, %u]\n", m1, n1, m2, n2);
	printf("Thread Dimension: %d Padded length: %u\n", THREADS, m);
#endif

	//Byte counts computed in size_t so m*m and m*n cannot wrap in 32-bit math
	size_t class1_size = (size_t)m * n1 * sizeof(float);
	size_t class2_size = (size_t)m * n2 * sizeof(float);
	size_t result_size_gpu = (size_t)m * m * sizeof(float);
	size_t row_values_size = (size_t)m * sizeof(float);
	size_t row_indices_size = (size_t)m * sizeof(unsigned int);

	//Allocate space on the GPU for the input data and one z slice of scores
	float *d_class1, *d_class2, *d_s1;
	if ( cudaMalloc( (void**)&d_class1, class1_size ) != cudaSuccess )
		mexErrMsgTxt("Memory allocating failure on the GPU.");
	if ( cudaMalloc( (void**)&d_class2, class2_size )  != cudaSuccess )
		mexErrMsgTxt("Memory allocating failure on the GPU.");
	if ( cudaMalloc( (void**)&d_s1, result_size_gpu )  != cudaSuccess )
		mexErrMsgTxt("Memory allocating failure on the GPU.");

	//Pinned host staging buffers for the input data with zeroed padding.
	//(The score slice never comes back to the host, so no host copy of it is needed.)
	float *h_class1, *h_class2;
	if (cudaMallocHost((void**)&h_class1, class1_size) != cudaSuccess)
		mexErrMsgTxt("Memory allocating failure on the host.");
	if (cudaMallocHost((void**)&h_class2, class2_size) != cudaSuccess)
		mexErrMsgTxt("Memory allocating failure on the host.");

	//Zero out this memory, then copy the data into the padded layout
	memset(h_class1, 0, class1_size);
	memset(h_class2, 0, class2_size);
	copyPaddedColumns(h_class1, (const float*)mxGetData(prhs[0]), m1, m, n1);
	copyPaddedColumns(h_class2, (const float*)mxGetData(prhs[1]), m1, m, n2);

	//Copy data to the GPU
	if (cudaMemcpy(d_class1, h_class1, class1_size, cudaMemcpyHostToDevice) != cudaSuccess)
		mexErrMsgTxt("Error copying memory to the GPU.");
	if (cudaMemcpy(d_class2, h_class2, class2_size, cudaMemcpyHostToDevice) != cudaSuccess)
		mexErrMsgTxt("Error copying memory to the GPU.");

	//Allocate space for the maximum value calculations
	float *d_maxValues, *h_maxValues, *d_maxValue, *h_maxValue;
	unsigned int *d_maxIndices, *d_maxIndex, *h_maxIndices, *h_maxIndex;
	if ( cudaMalloc( (void**)&d_maxValues, row_values_size)  != cudaSuccess )
		mexErrMsgTxt("Memory allocating failure on the GPU.");
	if ( cudaMalloc( (void**)&d_maxValue, sizeof(float))  != cudaSuccess )
		mexErrMsgTxt("Memory allocating failure on the GPU.");
	if ( cudaMalloc( (void**)&d_maxIndices, row_indices_size)  != cudaSuccess )
		mexErrMsgTxt("Memory allocating failure on the GPU.");
	if ( cudaMalloc( (void**)&d_maxIndex, sizeof(unsigned int))  != cudaSuccess )
		mexErrMsgTxt("Memory allocating failure on the GPU.");
	if (cudaMallocHost((void**)&h_maxValues, row_values_size) != cudaSuccess)
		mexErrMsgTxt("Memory allocating failure on the host.");
	if (cudaMallocHost((void**)&h_maxValue, sizeof(float)) != cudaSuccess)
		mexErrMsgTxt("Memory allocating failure on the host.");
	if (cudaMallocHost((void**)&h_maxIndices, row_indices_size) != cudaSuccess)
		mexErrMsgTxt("Memory allocating failure on the host.");
	if (cudaMallocHost((void**)&h_maxIndex, sizeof(unsigned int)) != cudaSuccess)
		mexErrMsgTxt("Memory allocating failure on the host.");

	//Per-row "bad" flags for maxKernel. They need one entry per PADDED row (m,
	//not m1): maxKernel runs m blocks and reads d_baddata[blockIdx.x].
	float *h_baddata, *d_baddata, *h_baddata_single, *d_baddata_single;
	if ( cudaMalloc( (void**)&d_baddata, row_values_size)  != cudaSuccess )
		mexErrMsgTxt("Memory allocating failure on the GPU.");
	if ( cudaMalloc( (void**)&d_baddata_single, sizeof(float))  != cudaSuccess )
		mexErrMsgTxt("Memory allocating failure on the GPU.");
	if (cudaMallocHost((void**)&h_baddata, row_values_size) != cudaSuccess)
		mexErrMsgTxt("Memory allocating failure on the host.");
	if (cudaMallocHost((void**)&h_baddata_single, sizeof(float)) != cudaSuccess)
		mexErrMsgTxt("Memory allocating failure on the host.");

	//Padding rows hold no real gene (their data is all zeros), so they start
	//every slice flagged bad and can never be reported. The host copy is not
	//modified later, so it is prepared once and re-uploaded for each slice.
	memset(h_baddata, 0, row_values_size);
	for (unsigned int k = m1; k < m; k++) {
		h_baddata[k] = 1.f;
	}
	h_baddata_single[0] = 0.f;

	//Set the dimension of the blocks and grids
	dim3 dimBlock(THREADS, THREADS);
	dim3 dimGrid(m/THREADS, m/THREADS);

	//Set the dimension of the parallel reduction blocks and grids
	dim3 dimBlockMax(REDUCTION_THREADS, 1, 1);
	dim3 dimGridMax(m, 1, 1);

#ifdef DEBUG
	printf("Scheduling [%d %d] threads in [%u %u] blocks for %u executions\n", THREADS, THREADS, m/THREADS, m/THREADS, m1);
#endif

	std::vector<ScoreElement> scores;

	//Get top three for each 2D matrix.  For the current purposes this is enough,
	//since we are only interested in the maximum value at this time.  If a kTST is
	//developed, we would need a more rigorous method to choose disjoint triples.
	int stop = 3;

	//No streams here
	for (unsigned int zcoord = 0; zcoord < m1; zcoord++) {

		//Compute the TST scores of this z slice
		tstKernel<<<dimGrid, dimBlock>>>(d_class1, d_class2, n1, n2, m, zcoord, d_s1);
		checkKernelLaunch("Error launching the TST kernel on the GPU.");
		if (cudaDeviceSynchronize() != cudaSuccess)
			mexErrMsgTxt("Error running the TST kernel on the GPU.");

		//Reset the bad data flags on the GPU (clearKernel sets them during the search)
		if (cudaMemcpy(d_baddata, h_baddata, row_values_size, cudaMemcpyHostToDevice) != cudaSuccess)
			mexErrMsgTxt("Error copying memory to the GPU.");
		if (cudaMemcpy(d_baddata_single, h_baddata_single, sizeof(float), cudaMemcpyHostToDevice) != cudaSuccess)
			mexErrMsgTxt("Error copying memory to the GPU.");

		//Find the set of maximum scores
		for (int i = 0; i < stop; i++) {

			//Stage 1: maximum of every row of the slice
			maxKernel<<<dimGridMax, dimBlockMax>>>(d_s1, m, m1, d_maxValues, d_maxIndices, d_baddata);
			checkKernelLaunch("Error launching the max kernel on the GPU.");

			//Blocking copies: they also wait for the kernel to finish
			if (cudaMemcpy(h_maxValues, d_maxValues, row_values_size, cudaMemcpyDeviceToHost) != cudaSuccess)
				mexErrMsgTxt("Error copying memory from the GPU.");
			if (cudaMemcpy(h_maxIndices, d_maxIndices, row_indices_size, cudaMemcpyDeviceToHost) != cudaSuccess)
				mexErrMsgTxt("Error copying memory from the GPU.");

			//Stage 2: which row holds the overall maximum
			maxKernel<<<1, dimBlockMax>>>(d_maxValues, m, m1, d_maxValue, d_maxIndex, d_baddata_single);
			checkKernelLaunch("Error launching the max kernel on the GPU.");

			if (cudaMemcpy(h_maxValue, d_maxValue, sizeof(float), cudaMemcpyDeviceToHost) != cudaSuccess)
				mexErrMsgTxt("Error copying memory from the GPU.");
			if (cudaMemcpy(h_maxIndex, d_maxIndex, sizeof(unsigned int), cudaMemcpyDeviceToHost) != cudaSuccess)
				mexErrMsgTxt("Error copying memory from the GPU.");

			//If we can no longer find scores, exit the loop
			const unsigned int best_row = h_maxIndex[0];
			if (h_maxValues[best_row] <= 0) {
				break;
			}

			//Convert the flattened index m1*x + y into its two gene indices
			const unsigned int index = h_maxIndices[best_row];
			const unsigned int col = index / m1;
			const unsigned int row = index % m1;

			//Add these scores to the vector of scores
			ScoreElement score;
			score.row = row;
			score.col = col;
			score.z = zcoord;
			score.primary = h_maxValues[best_row];
			scores.push_back(score);

			//Clear both genes' rows and columns so the next pick is disjoint
			clearKernel<<<1, dimBlockMax>>>(d_s1, m, row, col, d_baddata);
			clearKernel<<<1, dimBlockMax>>>(d_s1, m, col, row, d_baddata);
			checkKernelLaunch("Error launching the clear kernel on the GPU.");
		}

		//Make sure all work is complete (and succeeded) before continuing
		if (cudaDeviceSynchronize() != cudaSuccess)
			mexErrMsgTxt("Error running the max/clear kernels on the GPU.");
	}

	//Order by descending score. stable_sort keeps tied scores in discovery order
	//(z slice, then extraction order), so the output order is deterministic.
	std::stable_sort(scores.begin(), scores.end(), sort_pred);

	//Create the output for the top scoring triples
	plhs[0] = mxCreateNumericMatrix(scores.size(), 1, mxSINGLE_CLASS, mxREAL);
	plhs[1] = mxCreateNumericMatrix(scores.size(), 1, mxINT32_CLASS, mxREAL);
	plhs[2] = mxCreateNumericMatrix(scores.size(), 1, mxINT32_CLASS, mxREAL);
	plhs[3] = mxCreateNumericMatrix(scores.size(), 1, mxINT32_CLASS, mxREAL);

	float *maxscores = (float*) mxGetData(plhs[0]);
	int *indexi = (int*) mxGetData(plhs[1]);
	int *indexj = (int*) mxGetData(plhs[2]);
	int *indexk = (int*) mxGetData(plhs[3]);

	//MATLAB indices are 1-based
	for (size_t i = 0; i < scores.size(); i++) {
		maxscores[i] = scores[i].primary;
		indexi[i] = scores[i].row+1;
		indexj[i] = scores[i].col+1;
		indexk[i] = scores[i].z+1;
	}

#ifdef DEBUG
	cudaEventRecord(stop_event, 0);
	cudaEventSynchronize(stop_event); // block until the event is actually recorded
	cudaEventElapsedTime(&time_run, start_event, stop_event);
	printf("Finished running nvTST in %.6f seconds\n", time_run / 1000.0);
	cudaEventDestroy(start_event);
	cudaEventDestroy(stop_event);
#endif

	//Clear up memory on the device
	cudaFree(d_class1);
	cudaFree(d_class2);
	cudaFree(d_s1);
	cudaFree(d_maxValues);
	cudaFree(d_maxValue);
	cudaFree(d_maxIndex);
	cudaFree(d_maxIndices);
	cudaFree(d_baddata);
	cudaFree(d_baddata_single);

	//Clear up memory on the host
	cudaFreeHost(h_class1);
	cudaFreeHost(h_class2);
	cudaFreeHost(h_maxValues);
	cudaFreeHost(h_maxValue);
	cudaFreeHost(h_maxIndices);
	cudaFreeHost(h_maxIndex);
	cudaFreeHost(h_baddata);
	cudaFreeHost(h_baddata_single);
}


