function result = ktsp_cuda(data, labels, probes, max_k, filter)
%KTSP_CUDA Implementation of the k-top scoring pair algorithm on the CUDA GPU architecture
%
%   [RESULT] = KTSP_CUDA(DATA) assumes the class identifiers are in the first row
%   of the DATA matrix and removes it.  DATA is then split based on the class
%   labels and the TSP algorithm is performed.  Rows are assumed to be probes and
%   columns experiments.  The header row must contain only zeros and ones for the
%   two class labels.
%
%   [RESULT] = KTSP_CUDA(DATA, LABELS) splits the data based on the class labels
%   and performs the TSP algorithm. No header row is assumed. Rows are assumed to
%   be probes and columns experiments. If DATA is an MxN matrix, LABELS must be
%   a vector of N elements containing only zeros and ones (both classes must be
%   present).  Passing [] falls back to reading the labels from the first row.
%
%   [RESULT] = KTSP_CUDA(DATA, LABELS, PROBES) defines the probe names for each row
%   of the matrix.  PROBES must be a cell array of strings with one entry per row
%   of DATA.  If PROBES is absent or {}, default names 'probe1', 'probe2', ... are
%   created.
%
%   [RESULT] = KTSP_CUDA(DATA, LABELS, PROBES, MAXK) is the same as above, but
%   MAXK is the maximum number of top scoring pairs considered for a classifier.
%   Defaults to 10.
%
%   [RESULT] = KTSP_CUDA(DATA, LABELS, PROBES, MAXK, FILTER) is the same as above, but
%   the genes are sorted for differential expression using the Wilcoxon rank sum test
%   and only the top FILTER genes are used for the TSP calculations.  FILTER <= 0
%   (the default) disables filtering; a FILTER larger than the number of ranked
%   probes is clamped to that number.
%
%RESULT is a struct containing the following fields:
%
%   primary: 	primary TSP score
%   secondary: 	secondary TSP score for breaking ties
%   vote: 	which class this score votes for (0=class1, 1=class2)
%   labels:	labels for each sample (0 or 1) as a 1xN single row vector
%   probes:	probe names, either provided by user or generated automatically
%   cvn:	number of samples left out for cross-validation
%   filter:	filtered size of results based on differential expression (0 if none)
%   k: 		Size of classifier (1 for TSP, determined by kTSP algorithm)
%   indices:	(only when FILTER > 0) rows of the original data matrix that were
%   		kept by the differential expression filter

	if (nargin < 1)
		error('Usage: [RESULT] = KTSP_CUDA(DATA, LABELS, PROBES, MAXK, FILTER)');
	end
	if (nargin < 2)
		% Get the labels from the data matrix first row
		labels = [];
	end
	if (nargin < 3)
		probes = {};
	end
	if (nargin < 4)
		% Default upper bound on the number of pairs in the classifier
		max_k = 10;
	end
	if (nargin < 5)
		% No filtering for differential expression
		filter = 0;
	end

	% Begin the program timer
	tic;

	% If the label set is empty, get the first row of the data matrix
	if (isempty(labels))
		labels = data(1,:);
		data(1,:) = [];
	% Check to make sure the number of labels is ok
	elseif (length(labels) ~= size(data, 2))
		error('Number of class labels does not match number of cols of data');
	end

	% Work with a row vector from here on so that column indexing and deletion
	% behave the same regardless of the orientation the caller used
	labels = reshape(labels, 1, []);

	% Now check to make sure the labels are exactly the two classes 0 and 1
	if ~isequal(unique(labels), [0 1])
		error('Class labels must be only 0 or 1')
	end

	% If the probe set is empty, create a default set of probe names
	if (isempty(probes))
		probes = cell(size(data, 1), 1);
		for j=1:size(data,1)
			probes{j} = ['probe', int2str(j)];
		end
	% Otherwise, check that the probe list is the correct size
	elseif (length(probes) ~= size(data, 1))
		error('Number of probe names does not match number of rows of data');
	end

	% Now we have ensured all the data is okay.  Lets impute any missing data
	if any(isnan(data(:)))
		fprintf('Input matrix contains NaNs, imputing...\n');
		data = knnimpute(data);
	end

	% The MEX routines below are handed single precision inputs
	if ~isa(data, 'single')
		data = single(data);
	end
	if ~isa(labels, 'single')
		labels = single(labels);
	end

	% Calculate the ranks of the data
	ranks = tiedrankmex(data);

	% If asked to filter for differential expression, do so
	if (filter > 0)
		% Yes, this is calculating differential expression based on ranks.  This is how
		% it is done in Lin et al 2009
		% NOTE(review): the third output is presumably the row order sorted by
		% significance -- confirm against ranksummex
		[~, ~, indices] = ranksummex(ranks, labels, 1);
		% Never ask for more probes than were ranked
		filter = min(filter, numel(indices));
		indices = indices(1:filter);
		ranks = ranks(indices, :);
	else
		% Make sure this variable is zero if it is not positive
		filter = 0;
	end

	% N samples are left out per iteration, for m iterations
	N = 3;
	m = round(size(data,2)/N);

	% Output to the user
	fprintf('Running kTSP with the following settings:\n');
	fprintf('%d Probes and %d Samples\n', size(data, 1), size(data, 2));
	fprintf('Class1 Size: %d Class2 Size: %d\n', sum(labels==0), sum(labels==1));
	fprintf('Max size of kTSP classifer: %d\n', max_k);
	fprintf('Perturbing original data set by %d at a time\n', N);
	if (filter > 0)
		fprintf('Filtering for top %d differentially expressed genes\n', filter);
	else
		fprintf('No filtering for differentially expressed genes\n');
	end
	fprintf('\n');

	% Accumulators used to average the cross-validation score for each k in 1..max_k
	error_means = zeros(max_k, 1);
	error_counts = zeros(max_k, 1);

	% Begin the k-TSP main loop.  We leave out 3 arrays at a time (default, can be changed in
	% a future release).  For each iteration we calculate the TSP score matrices
	for j=1:m

		fprintf('kTSP main loop %d out of %d\n', j, m);

		% Leave out arrays from the data set.  The columns are drawn with
		% replacement, so fewer than N distinct samples may be removed.  A
		% dedicated variable is used so the filter indices above are preserved.
		left_out = randi(size(ranks, 2), N, 1);
		tdata = ranks;
		tlabels = labels;
		tdata(:, left_out) = [];
		tlabels(:, left_out) = [];

		% Compute the score using this reduced data set (GPU call)
		[primary, secondary, vote] = nvtspmex(tdata(:, tlabels==0), tdata(:, tlabels==1));

		% Get top k disjoint pairs from these matrices, sorting first by primary score and second by secondary score (GPU call)
		[sorted_pri, ~, indexi, indexj] = nvdisjointpairmex(primary, secondary, max_k, 0);

		% Get the actual number of k returned
		actual_k = size(sorted_pri, 1);

		% Go through all of these disjoint pairs (only odd numbers)
		for current_k=1:2:actual_k

			% Score this set of classifiers by cross-validation
			[error_rate, ~] = loocvmex(ranks, labels, [indexi(1:current_k) indexj(1:current_k)]);

			% Add this score to the running totals and increment the count for this k
			error_means(current_k, 1) = error_means(current_k, 1) + error_rate;
			error_counts(current_k, 1) = error_counts(current_k, 1) + 1;

		end
	end

	% Now take the mean of the scores, we will get NaN for all the even values of k
	error_means = error_means ./ error_counts;

	% Get the optimal value of k: the k with the largest mean score (max skips NaNs)
	% NOTE(review): this assumes loocvmex reports a "higher is better" score,
	% as the MATLAB reference CVErrorRate does -- confirm
	[~, optimal_k] = max(error_means);

	% Finally, using the entire set of data, calculate the TSP scores
	result = struct;
	[result.primary, result.secondary, result.vote] = nvtspmex(ranks(:, labels==0), ranks(:, labels==1));

	% Add the labels and probes to the structure
	result.labels = labels;
	result.probes = probes;
	result.cvn = N;
	result.filter = filter;
	result.k = optimal_k;

	% If we have filtered for differential expression, put in indices to original data matrix
	% for each of the filtered genes
	if (filter > 0)
		result.indices = indices;
	end

	elapsed = toc;
	fprintf('Optimal k: %d\n', result.k);
	fprintf('Finished running kTSP in %.2f seconds\n', elapsed);
		

% Matlab CV function - now obsolete due to C++ version MEX file
function [error_rate] = CVErrorRate(indexi, indexj, data, labels)
%CVERRORRATE Leave-one-out cross-validation of a k-pair TSP classifier
%
%   indexi, indexj: vectors (row or column) of equal length; pair k compares
%                   row indexi(k) of DATA against row indexj(k)
%   data:           matrix with one row per probe and one column per sample
%   labels:         vector of 0/1 class labels, one per column of DATA
%
%   For each sample, the sample is left out, each pair's vote is derived from
%   the remaining samples, and the left-out sample is predicted by majority
%   vote over the pairs.
%
%   NOTE: despite its name, the returned value is the fraction of samples
%   predicted CORRECTLY (num_correct / number of labels), i.e. higher is better.

	num_pairs = numel(indexi);
	num_correct = 0;

	% Go through each sample and remove it
	for j=1:size(data, 2)

		% Labels of the training samples (everything except sample j)
		cv_labels = labels;
		cv_labels(j) = [];
		in_class0 = (cv_labels == 0);
		in_class1 = (cv_labels == 1);

		% One 0/1 prediction per pair for the left-out sample
		predictions = zeros(1, num_pairs);

		for k=1:num_pairs

			row_i = data(indexi(k), :);
			row_j = data(indexj(k), :);

			% Training values for this pair with sample j removed
			train_i = row_i; train_i(j) = [];
			train_j = row_j; train_j(j) = [];

			% We are not calculating the TSP score here, we are calculating
			% the vote for this particular pair given this data.  This is done
			% calculating Pr(data1 < data2 | C=C1), and Pr(data1 < data2 | C=C2)
			class1_score = sum(train_i(in_class0) < train_j(in_class0)) / sum(in_class0);
			class2_score = sum(train_i(in_class1) < train_j(in_class1)) / sum(in_class1);

			% The pair votes "class 1" only when class 2's probability is strictly
			% larger.  Ties (and NaN scores from an empty class) are treated like
			% a class-0 vote.
			% NOTE(review): ties may have been meant to abstain -- confirm
			votes_class1 = (class2_score > class1_score);

			% If the ordering seen in the left-out sample matches the ordering the
			% pair associates with class 1, predict 1; otherwise predict 0
			observed_less = (row_i(j) < row_j(j));
			predictions(k) = double(observed_less == votes_class1);

		end

		% Majority of the per-pair predictions (mode returns the smaller value on
		% an even split, and NaN when there are no pairs)
		label_prediction = mode(predictions);

		% Check to see if it is correct
		if (label_prediction == labels(j))
			num_correct = num_correct + 1;
		end
	end

	% Fraction of samples predicted correctly
	error_rate = num_correct / length(labels);

