function classifier = tst_cuda(data, labels, probes, filter)
%TST_CUDA Implementation of the top-scoring triple algorithm on the CUDA GPU architecture
%
%   [RESULT] = TST_CUDA(DATA) assumes the class identifiers are in the first row
%   of the DATA matrix and removes it.  DATA is then split based on the class
%   labels and the TST algorithm is performed.  Rows are assumed to be probes and
%   columns experiments.  The header row must contain only zeros and ones for the
%   two class labels.
%
%   [RESULT] = TST_CUDA(DATA, LABELS) splits the data based on the class labels
%   and performs the TST algorithm. No header row is assumed. Rows are assumed to
%   be probes and columns experiments. If DATA is an MxN matrix, LABELS must be
%   a vector of N elements containing only zeros and ones for the two class
%   labels.  Passing [] for LABELS reads them from the first row of DATA.
%
%   [RESULT] = TST_CUDA(DATA, LABELS, PROBES) defines the probe names for each row
%   of the matrix.  PROBES must be a cell array of strings.  If PROBES is absent
%   or {}, a default set of probe names ('probe1', 'probe2', ...) is created.
%
%   [RESULT] = TST_CUDA(DATA, LABELS, PROBES, FILTER) is the same as above, but
%   the genes are sorted for differential expression using the Wilcoxon rank sum
%   test and only the top FILTER genes are used for the TST calculations.  A
%   FILTER larger than the number of available genes is reduced to that number.
%   If FILTER is absent or not positive, the entire dataset is used and the
%   user is asked to confirm before the computation starts.
%
%RESULT is a struct containing the following fields:
%
%   indexi, indexj, indexk:  row indices (into the data matrix, after removal
%                  of any label header row) of the three probes of the
%                  top-scoring triple
%   primary:       TST score recomputed in MATLAB for that triple
%   class1_probs:  1x6 vector, fraction of class 0 samples showing each of the
%                  six orderings of the three probes
%   class2_probs:  1x6 vector, the same for the class 1 samples
%   name1, name2, name3:  1x1 cells holding the probe names of the triple
%
%   Errors are raised if the number of labels does not match the number of
%   columns, if the labels are not exactly the two values 0 and 1, or if the
%   number of probe names does not match the number of rows.

	if (nargin < 1)
		error('Usage: [RESULT] = TST_CUDA(DATA, LABELS, PROBES, FILTER)');
	end
	if (nargin < 2)
		% Get the labels from the data matrix first row
		labels = [];
	end
	if (nargin < 3)
		probes = {};
	end
	if (nargin < 4)
		% No filtering for differential expression
		filter = 0;
	end

	% Begin the program timer
	tic;

	% If the label set is empty, get the first row of the data matrix
	if (isempty(labels))
		labels = data(1,:);
		data(1,:) = [];
	% Otherwise check that there is exactly one label per column.  This has to
	% be ELSEIF: a bare ELSE would raise the error for every supplied label set.
	elseif (numel(labels) ~= size(data, 2))
		error('Number of class labels does not match number of cols of data');
	end

	% Now check to make sure the labels are only zeros and ones.  Both classes
	% must be present.  The unique values are forced into a row so the check
	% does not depend on the orientation of the label vector.
	label_values = unique(labels(:))';
	if ~isequal(label_values, [0 1])
		error('Class labels must be only 0 or 1')
	end

	% If the probe set is empty, create a default set of probe names
	if (isempty(probes))
		probes = cell(size(data, 1), 1);
		for j=1:size(data,1)
			probes{j} = ['probe', int2str(j)];
		end
	% Otherwise, check that the probe list is the correct size
	elseif (length(probes) ~= size(data, 1))
		error('Number of probe names does not match number of rows of data');
	end

	% Now we have ensured all the data is okay.  Lets impute any missing data
	if any(isnan(data(:)))
		fprintf('Input matrix contains NaNs, imputing...\n');
		data = knnimpute(data);
	end

	% Convert the inputs to single precision if they are not already
	% (presumably what the MEX routines expect -- TODO confirm)
	if ~isa(data, 'single')
		data = single(data);
	end
	if ~isa(labels, 'single')
		labels = single(labels);
	end

	% Calculate the ranks of the data
	ranks = tiedrankmex(data);

	% If asked to filter for differential expression, do so
	if (filter > 0)
		% Yes, this is calculating differential expression based on ranks.  This is how
		% it is done in Lin et al 2009
		% NOTE(review): indices is presumably the probe order sorted by
		% differential expression -- confirm against ranksummex.
		[unsorted, wilcox, indices] = ranksummex(ranks, labels, 1);

		% Never ask for more genes than are available
		filter = min(filter, numel(indices));
		ranks = ranks(indices(1:filter), :);
	else
		% Make sure this variable is zero if it is not positive
		filter = 0;

		% Without filtering, rows of RANKS map one-to-one onto rows of DATA.
		% INDICES must still exist because it is used below to translate the
		% GPU result back to the original rows.
		indices = (1:size(ranks, 1))';

		fprintf('Warning! You are about to run the top-scoring triple algorithm on your entire dataset\n');
		fprintf('Even on the GPU this can take a significant amount of time\n');
		fprintf('Press any key to continue, CTRL-C to cancel\n');
		waitforbuttonpress;

	end

	% Output to the user
	fprintf('Running TST with the following settings:\n');
	fprintf('%d Probes and %d Samples\n', size(data, 1), size(data, 2));
	fprintf('Class1 Size: %d Class2 Size: %d\n', sum(labels==0), sum(labels==1));
	if (filter > 0)
		fprintf('Filtering for top %d differentially expressed genes\n', filter);
	else
		fprintf('No filtering for differentially expressed genes\n');
	end
	fprintf('\n');

	% Finally we can run the TST algorithm on the GPU
	classifier = struct;
	[primary, row, col, z] = nvtstmemmex(ranks(:, labels==0), ranks(:, labels==1));

	% Get actual indices into the original matrix
	classifier.indexi = indices(row(1));
	classifier.indexj = indices(col(1));
	classifier.indexk = indices(z(1));

	% Now we have the top scoring triples. Because the GPU does not save the permutation
	% probabilities, we will recreate them for the TST here
	[score, prob_class1, prob_class2] = CalculateProbabilities(ranks, labels, row(1), col(1), z(1), primary(1));
	classifier.primary = score;
	classifier.class1_probs = prob_class1;
	classifier.class2_probs = prob_class2;
	classifier.name1 = probes(indices(row(1)));
	classifier.name2 = probes(indices(col(1)));
	classifier.name3 = probes(indices(z(1)));

	% toc returns fractional seconds, so print it as a floating-point value
	elapsed = toc;
	fprintf('Finished running TST in %.2f seconds\n', elapsed);
	
	
function [score, class1_scores, class2_scores] = CalculateProbabilities(ranks, labels, row, col, z, primary)
%CALCULATEPROBABILITIES Ordering frequencies and score of one probe triple
%
%   RANKS is the rank matrix (rows are probes, columns are samples), LABELS
%   holds a 0 or 1 for each column, and ROW, COL and Z are the row indices
%   into RANKS of the three probes of the triple.
%
%   CLASS1_SCORES and CLASS2_SCORES are 1x6 vectors giving, for the samples
%   labelled 0 and 1 respectively, the fraction of samples showing each of
%   the six orderings of the three probes (a = ROW, b = COL, c = Z):
%
%      1: a<=b<=c   2: a<=c<=b   3: b<=a<=c
%      4: b<=c<=a   5: c<=a<=b   6: c<=b<=a
%
%   A sample with ties matches several orderings; its single vote is split
%   evenly between all orderings it matches.
%
%   SCORE is (sum(abs(CLASS1_SCORES - CLASS2_SCORES)) + 2) / 4.
%
%   PRIMARY (the score reported by the GPU) is not used in the calculation;
%   it is only referenced by the commented-out debugging print at the end.

	% Ordering frequencies within each class
	class1_scores = TripleOrderingFrequencies(ranks(row, labels==0), ranks(col, labels==0), ranks(z, labels==0));
	class2_scores = TripleOrderingFrequencies(ranks(row, labels==1), ranks(col, labels==1), ranks(z, labels==1));

	% Now calculate the score.
	score = (sum(abs(class1_scores - class2_scores))+2)/4;

	% Compare scores
	%fprintf('GPU Score: %.6f MATLAB Score: %.6f\n', primary, score);


function freqs = TripleOrderingFrequencies(a, b, c)
%TRIPLEORDERINGFREQUENCIES Fraction of samples showing each ordering of (a, b, c)
%
%   A, B and C are row vectors holding one value per sample.  FREQS is a 1x6
%   vector; ties split a sample's vote evenly among the matching orderings.

	nsamples = size(a, 2);
	freqs = zeros(1, 6);

	% Accumulate sample by sample (rather than with a vectorised sum) so the
	% floating-point result does not depend on the summation order.
	for j=1:nsamples

		hits = [ (a(j) <= b(j)) && (b(j) <= c(j)), ...
		         (a(j) <= c(j)) && (c(j) <= b(j)), ...
		         (b(j) <= a(j)) && (a(j) <= c(j)), ...
		         (b(j) <= c(j)) && (c(j) <= a(j)), ...
		         (c(j) <= a(j)) && (a(j) <= b(j)), ...
		         (c(j) <= b(j)) && (b(j) <= a(j)) ];

		votes = double(hits);
		nhits = sum(hits);

		% Divide if there is a tie
		if (nhits > 1)
			votes = votes ./ nhits;
		end

		freqs = freqs + votes;

	end

	% Scale the ordering counts by the number of samples
	freqs = freqs ./ nsamples;