/*=================================================================
 *
 *  loocvmex.c
 *  Author: Andrew Magis
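 *  Leave-one-out cross-validation (LOOCV) of a gene-pair classifier
 *  Inputs: data (noncomplex single, rows = genes, columns = samples),
 *          labels (noncomplex single 1xN row, every entry 0 or 1),
 *          classifier (noncomplex int32 Mx2 matrix of 1-based data row
 *          indices; each row is one gene pair)
 *  Outputs: (1) fraction of samples predicted correctly among those that
 *           received a non-tied majority prediction (an accuracy, not an
 *           error rate; NaN when every sample is a tie), and
 *           (2) 2xN int32 matrix of [prediction; true label] per sample
 *           (tied predictions are reported as 0)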
 *
 *
 *=================================================================*/

#include <math.h>
#include "mex.h"
#include <vector>

//#define DEBUG 

// Vote cast by one gene pair (row_i, row_j), computed from the training set.
//   less_label0 / less_label1: number of training samples of class 0 / class 1
//                              with data[row_i] < data[row_j]
//   size_label0 / size_label1: number of training samples of each class
// Returns 0.f when the pair's ordering is more frequent in class 0, 1.f when
// it is more frequent in class 1, and 0.5f when the frequencies are equal or
// not comparable (e.g. 0/0 = NaN when a class has no training samples left).
static float pair_vote(size_t less_label0, size_t less_label1,
                       float size_label0, float size_label1)
{
	const float freq_label0 = static_cast<float>(less_label0) / size_label0;
	const float freq_label1 = static_cast<float>(less_label1) / size_label1;
	if (freq_label0 > freq_label1)
		return 0.f;
	if (freq_label1 > freq_label0)
		return 1.f;
	return 0.5f;
}

// Core LOOCV computation on raw column-major arrays (no MATLAB API calls).
//   data      num_rows x num_samples, column-major (rows = genes, columns = samples)
//   labels    num_samples labels, each exactly 0.f or 1.f (validated by the caller)
//   pairs     num_pairs x 2, column-major, 1-based row indices in [1, num_rows]
//             (validated by the caller)
//   accuracy  receives num_correct / num_tried (NaN when every sample is a tie)
//   preds     2 x num_samples, column-major: [prediction; label] per sample
//
// Rather than recounting the training set for every left-out sample
// (O(num_samples^2 * num_pairs)), per-pair counts are taken once over all
// samples and the left-out sample's own contribution is subtracted. This
// yields the same integer counts in O(num_samples * num_pairs).
static void loocv_pairs(const float *data, size_t num_rows, size_t num_samples,
                        const float *labels, const int *pairs, size_t num_pairs,
                        float *accuracy, int *preds)
{
	// Class sizes over the whole dataset.
	size_t size_label0 = 0;
	size_t size_label1 = 0;
	for (size_t s = 0; s < num_samples; ++s) {
		if (labels[s] == 0.f)
			++size_label0;
		else
			++size_label1;
	}

	// less_label0[k] / less_label1[k]: samples of each class whose value in
	// row pairs[k] is strictly smaller than in row pairs[num_pairs + k].
	std::vector<size_t> less_label0(num_pairs, 0);
	std::vector<size_t> less_label1(num_pairs, 0);
	for (size_t k = 0; k < num_pairs; ++k) {
		const size_t row_i = static_cast<size_t>(pairs[k] - 1);
		const size_t row_j = static_cast<size_t>(pairs[num_pairs + k] - 1);
		for (size_t s = 0; s < num_samples; ++s) {
			const float *column = data + num_rows * s;
			if (column[row_i] < column[row_j]) {
				if (labels[s] == 0.f)
					++less_label0[k];
				else
					++less_label1[k];
			}
		}
	}

	size_t num_correct = 0;
	size_t num_tried = 0;  // samples that received a non-tied majority prediction

	for (size_t sample = 0; sample < num_samples; ++sample) {
		const float *loo = data + num_rows * sample;
		const bool loo_is_label0 = (labels[sample] == 0.f);

		// Training-set class sizes with the left-out sample removed.
		const float train_size0 =
			static_cast<float>(loo_is_label0 ? size_label0 - 1 : size_label0);
		const float train_size1 =
			static_cast<float>(loo_is_label0 ? size_label1 : size_label1 - 1);

		// Sum of the per-pair predictions (each 0, 0.5 or 1), kept in half
		// units so the majority test below is exact integer arithmetic.
		size_t half_votes = 0;
		for (size_t k = 0; k < num_pairs; ++k) {
			const size_t row_i = static_cast<size_t>(pairs[k] - 1);
			const size_t row_j = static_cast<size_t>(pairs[num_pairs + k] - 1);
			const bool loo_less = loo[row_i] < loo[row_j];

			// Training counts for this pair, excluding the left-out sample.
			size_t train_less0 = less_label0[k];
			size_t train_less1 = less_label1[k];
			if (loo_less) {
				if (loo_is_label0)
					--train_less0;
				else
					--train_less1;
			}

			const float vote =
				pair_vote(train_less0, train_less1, train_size0, train_size1);

			// The vote says which class tends to have row_i < row_j. Apply it
			// as-is when the left-out sample shows that ordering. Invert it for
			// the opposite ordering. Abstain (0.5) when the two values are equal
			// or not comparable (NaN).
			float pair_prediction;
			if (loo_less)
				pair_prediction = vote;
			else if (loo[row_i] > loo[row_j])
				pair_prediction = 1.f - vote;
			else
				pair_prediction = 0.5f;

			half_votes += static_cast<size_t>(2.f * pair_prediction);
		}

		// Majority vote: mean < 0.5 -> 0, mean > 0.5 -> 1, otherwise a tie.
		// mean == half_votes / (2 * num_pairs), so the comparison against 0.5
		// becomes a comparison of half_votes against num_pairs. No pairs at
		// all also counts as a tie.
		const int label = static_cast<int>(labels[sample]);
		int prediction = 0;
		if (half_votes != num_pairs) {
			prediction = (half_votes > num_pairs) ? 1 : 0;
			++num_tried;
			if (prediction == label)
				++num_correct;
		}

		// NOTE(review): a tie is written as 0, which is indistinguishable from a
		// class-0 prediction in this output (historically a truncated 0.5).
		// Kept as-is for compatibility with existing callers.
		preds[2 * sample] = prediction;
		preds[2 * sample + 1] = label;
	}

	// Although this output was historically called an "error rate", it is the
	// fraction of correct predictions among non-tied samples. It is NaN (0/0)
	// when every sample was a tie.
	*accuracy = static_cast<float>(num_correct) / static_cast<float>(num_tried);
}

// MEX entry point: [accuracy, preds] = loocvmex(data, labels, classifier)
//   data        single, genes x samples
//   labels      single, 1 x samples, every entry 0 or 1
//   classifier  int32, M x 2, 1-based row indices into data (one gene pair per row)
// Validates every input before touching the data. Any violation aborts via
// mexErrMsgTxt. The computation itself lives in loocv_pairs.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[]) {

	//Error check
	if (nrhs != 3) {
		mexErrMsgTxt("Three inputs required (data, labels, classifer (Mx2 matrix of indices)).");
	}
	if (nlhs != 2) {
		mexErrMsgTxt("Two outputs required (error rate of classifier, results of predictions)");
	}
	// The input must be a noncomplex single.
	if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS || mxIsComplex(prhs[0])) {
		mexErrMsgTxt("Class1 Input must be a noncomplex single.");
	}
	if (mxGetClassID(prhs[1]) != mxSINGLE_CLASS || mxIsComplex(prhs[1])) {
		mexErrMsgTxt("Label Input must be a noncomplex single.");
	}
	if (mxGetClassID(prhs[2]) != mxINT32_CLASS || mxIsComplex(prhs[2])) {
		mexErrMsgTxt("Classifier Input must be a noncomplex INT32.");
	}

	//m is the number of rows (genes)
	//n is the number of chips (samples)
	const size_t m1 = mxGetM(prhs[0]);
	const size_t n1 = mxGetN(prhs[0]);
	const size_t m2 = mxGetM(prhs[1]);
	const size_t n2 = mxGetN(prhs[1]);
	const size_t m3 = mxGetM(prhs[2]);
	const size_t n3 = mxGetN(prhs[2]);
	if (n1 != n2) {
		mexErrMsgTxt("Number of samples for data != number of labels\n");
	}
	if (m2 != 1) {
		mexErrMsgTxt("Only one row permitted for the labels\n");
	}
	if (n3 != 2) {
		mexErrMsgTxt("This function only works with pairs of classifiers\n");
	}

#ifdef DEBUG
	mexPrintf("Data: [%lu, %lu] Labels: [%lu, %lu]\n",
	          static_cast<unsigned long>(m1), static_cast<unsigned long>(n1),
	          static_cast<unsigned long>(m2), static_cast<unsigned long>(n2));
	mexPrintf("Performing LOOCV on classifier containing %lu pairs\n",
	          static_cast<unsigned long>(m3));
#endif

	const float *data = static_cast<const float *>(mxGetData(prhs[0]));
	const float *labels = static_cast<const float *>(mxGetData(prhs[1]));
	const int *classifiers = static_cast<const int *>(mxGetData(prhs[2]));

	// Labels must be exactly 0 or 1 (NaN fails both tests and is rejected).
	for (size_t i = 0; i < n2; ++i) {
		if (labels[i] != 0.f && labels[i] != 1.f)
			mexErrMsgTxt("Invalid class label: must be all 0 or 1\n");
	}

	// Every pair index is dereferenced as data[index - 1 + m1 * sample], so it
	// must lie in [1, m1]. Otherwise the reads would go out of bounds.
	for (size_t k = 0; k < m3 * n3; ++k) {
		if (classifiers[k] < 1 || static_cast<size_t>(classifiers[k]) > m1)
			mexErrMsgTxt("Classifier indices must be between 1 and the number of data rows\n");
	}

	// Create the outputs. mxCreateNumericMatrix zero-initializes them.
	plhs[0] = mxCreateNumericMatrix(1, 1, mxSINGLE_CLASS, mxREAL);
	plhs[1] = mxCreateNumericMatrix(2, n2, mxINT32_CLASS, mxREAL);

	float *accuracy = static_cast<float *>(mxGetData(plhs[0]));
	int *preds = static_cast<int *>(mxGetData(plhs[1]));

	loocv_pairs(data, m1, n1, labels, classifiers, m3, accuracy, preds);
}


