/*=================================================================
 *
 *  Main.cu
 *  Author: Andrew Magis
 *  Main tsp_cuda application entry point
 *
 *=================================================================*/

#include <ctype.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/time.h>
#include <unistd.h>
#include <string>
#include <vector>

#include "FileParser.h"
#include "Kernels.cu"
#include "RankSum.h"
#include "DisjointPair.cu"
#include "LOOCV.h"
#include "TSP.cu"
#include "TST.cu"

/*
 * Prints the command-line synopsis to stdout and terminates the process
 * with exit status 1. Never returns.
 */
void Usage() {
	// Each entry is written verbatim (none contains a printf conversion).
	static const char* const kUsageLines[] = {
		"Usage: tsp_cuda input_filename algorithm [-N num_genes] [-p] [-o output_filename] [-d devicenum]\n\n",
		"\tinput_filename (required)\t\tcomma or TAB-delimited file of expression data\n",
		"\talgorithm (required)\t\t\tmust be one of the following: tsp, tst\n",
		"\t-N num_genes (optional)\t\t\tFilter for the top differentially expressed genes\n",
		"\t-p (optional)\t\t\t\tLook for probe names in input file (row headers)\n",
		"\t-o output_filename (optional)\t\tWrite output to this filename\n",
		"\t-d device_num (optional)\t\tUse GPU device device_num (defaults to 0)\n",
		"\n"
	};
	const size_t line_count = sizeof(kUsageLines) / sizeof(kUsageLines[0]);

	for (size_t line = 0; line < line_count; ++line) {
		fputs(kUsageLines[line], stdout);
	}
	exit(1);
}

/*
 * Maps a classifier index (relative to the possibly filtered rank matrices)
 * back to a row index of the original input. When indices is NULL no
 * filtering was applied and idx is returned unchanged.
 */
static int MapProbeIndex(const int* indices, int idx) {
	return (indices != NULL) ? indices[idx] : idx;
}

/*
 * Returns the probe name stored at idx, or "" when idx lies outside the
 * header vector, so that a short header list cannot cause an out-of-bounds read.
 */
static const char* ProbeNameAt(const std::vector<std::string>& probe_names, int idx) {
	if ((idx < 0) || ((size_t)idx >= probe_names.size())) {
		return "";
	}
	return probe_names[idx].c_str();
}

/*
 * Writes "label: [ v0 v1 ... \n" for at most the first six entries of probs.
 * This is the same format the report has always used, including the
 * unclosed bracket.
 */
static void PrintProbabilities(FILE* output, const char* label, const std::vector<float>& probs) {
	fprintf(output, "%s: [ ", label);
	for (size_t i = 0; (i < 6) && (i < probs.size()); i++) {
		fprintf(output, "%.4f ", probs[i]);
	}
	fprintf(output, "\n");
}

/*
 * Writes the "*** Classifier" section of the report for the best result.
 *   output        - destination stream
 *   algorithm     - 0 = TSP, 1 = kTSP, 2 = TST; the third index (z) and the
 *                   class probabilities are only reported for TST
 *   best          - the top-ranked ScoreElement
 *   indices       - filtered-row -> original-row table, or NULL if unfiltered
 *   class1_probs, class2_probs - TST class probabilities
 *   probe_names   - row header names from the parser
 * Indices are printed 1-based.
 */
static void PrintClassifier(FILE* output, int algorithm, const ScoreElement& best, const int* indices,
		const std::vector<float>& class1_probs, const std::vector<float>& class2_probs,
		const std::vector<std::string>& probe_names) {

	const bool triplet = (algorithm == 2);
	const int i = MapProbeIndex(indices, (int)best.row);
	const int j = MapProbeIndex(indices, (int)best.col);
	// z is only used for TST; don't read (or index through) it otherwise
	const int k = triplet ? MapProbeIndex(indices, (int)best.z) : 0;

	fprintf(output, "\n*** Classifier:\n");
	fprintf(output, "Indexi: %d\n", i + 1);
	fprintf(output, "Indexj: %d\n", j + 1);
	if (triplet) {
		fprintf(output, "Indexk: %d\n", k + 1);
	}
	fprintf(output, "Primary: %.4f\n", best.primary);
	if (triplet) {
		PrintProbabilities(output, "Class1_probs", class1_probs);
		PrintProbabilities(output, "Class2_probs", class2_probs);
	} else {
		fprintf(output, "Secondary: %.4f\n", best.secondary);
	}
	fprintf(output, "Name1: '%s'\n", ProbeNameAt(probe_names, i));
	fprintf(output, "Name2: '%s'\n", ProbeNameAt(probe_names, j));
	if (triplet) {
		fprintf(output, "Name3: '%s'\n", ProbeNameAt(probe_names, k));
	}
}

/*
 * tsp_cuda entry point.
 *   argv[1]  input file (comma or TAB delimited expression data)
 *   argv[2]  algorithm: tsp/TSP, ktsp/KTSP (not implemented), tst/TST
 *   options  -N num_genes, -p, -o output_filename, -d device_num
 * The CUDA_DEVICE environment variable selects the default GPU; -d overrides it.
 * Exits with status 1 on usage errors, I/O or CUDA setup failures, or when
 * no pair is found.
 */
int main(int argc, char** argv) {

	FILE* output = stdout;
	int filter = 0;
	bool probes_defined = false;

	//Read in the CUDA_DEVICE environment variable (may be overridden by -d)
	int device_num = 0;
	const char* env_device = getenv("CUDA_DEVICE");
	if (env_device != NULL) {
		device_num = atoi(env_device);
	}

	if (argc < 3) {
		Usage();
	}

	//Read in mandatory command line arguments first
	std::string input_file = argv[1];

	//Read in mandatory algorithm argument
	int algorithm = -1;
	if ((strcmp(argv[2], "tsp") == 0) || (strcmp(argv[2], "TSP") == 0)) {
		algorithm = 0;
		printf("Performing TSP algorithm\n");
	} else if ((strcmp(argv[2], "ktsp") == 0) || (strcmp(argv[2], "KTSP") == 0)) {
		algorithm = 1;
		printf("Performing kTSP algorithm\n");
	} else if ((strcmp(argv[2], "tst") == 0) || (strcmp(argv[2], "TST") == 0)) {
		algorithm = 2;
		printf("Performing TST algorithm\n");
	} else {
		printf("Invalid algorithm selection '%s'\n", argv[2]);
		Usage();	// does not return
	}

	// getopt returns int: storing it in a char breaks the -1 test on
	// platforms where char is unsigned.
	int c;
	// argv[2] (the algorithm) plays the role of getopt's program name, so
	// option scanning starts at argv[3].
	while ((c = getopt(argc-2, &argv[2], "N:o:d:p")) != -1) {

		switch (c) {
		//If filter for differential expression
		case 'N':
			filter = atoi(optarg);
			printf("Filtering for top %d differentially expressed genes\n", filter);
			break;
		//If output filename is provided
		case 'o': {
			FILE* file = fopen(optarg, "w");
			if (file == NULL) {
				fprintf(stderr, "Unable to open output file '%s'\n", optarg);
				exit(1);
			}
			// A repeated -o replaces (and closes) any earlier output file
			if (output != stdout) {
				fclose(output);
			}
			output = file;
			printf("Set output file to be '%s'\n", optarg);
			break;
		}
		//If the device number is specified on command line
		case 'd':
			device_num = atoi(optarg);
			printf("Set device num to be %d\n", device_num);
			break;
		case 'p':
			probes_defined = true;
			printf("Looking for probe names in input file\n");
			break;
		case '?':
			if ((optopt == 'N') || (optopt == 'o') || (optopt == 'd'))
				fprintf(stderr, "Option -%c requires an argument.\n", optopt);
			else if (isprint(optopt))
				fprintf(stderr, "Unknown option `-%c'.\n", optopt);
			else
				fprintf(stderr, "Unknown option character `\\x%x'.\n", optopt);
			// fall through: every getopt error ends with the usage message
		default:
			Usage();
		}
	}

	//Set the device
	cudaError_t cuda_status = cudaSetDevice(device_num);
	if (cuda_status != cudaSuccess) {
		fprintf(stderr, "Unable to select CUDA device %d: %s\n", device_num, cudaGetErrorString(cuda_status));
		exit(1);
	}

	//Open the input file
	printf("Loading in data from file '%s'\n", input_file.c_str());
	FileParser<float> parse(input_file, probes_defined, false);
	printf("Finished loading data\n");

	//Get the numbers of columns in each class (row 0 holds the class labels)
	unsigned int num_class1 = parse.GetClassNum(0);
	unsigned int num_class2 = parse.GetClassNum(1);
	const unsigned int total_probes = parse.Rows()-1;
	unsigned int num_probes = total_probes;
	unsigned int num_samples = parse.Cols();

	// Requesting more genes than exist would read past the RankSum index table
	if ((filter > 0) && ((unsigned int)filter > total_probes)) {
		printf("Requested %d genes but only %u are available; using %u\n", filter, total_probes, total_probes);
		filter = (int)total_probes;
	}

	//Convert data to ranks
	parse.ConvertToRanks();
	float *ranks1 = NULL, *ranks2 = NULL;

	//Filter for differential expression
	std::vector<float*> ranksum;
	if (filter > 0) {

		//Get all the data and compute the ranks
		float *ranks = parse.GetData();

		//Calculate differentially expressed genes
		printf("Filtering for differentially expressed genes\n");
		ranksum = RankSum(ranks, num_probes, num_samples, parse.GetRow(0), 1);

		//Extract out these rows from the original data
		num_probes = filter;
		ranks1 = parse.GetClassColOrder(0, filter, (int*)ranksum[2]);
		ranks2 = parse.GetClassColOrder(1, filter, (int*)ranksum[2]);
		delete[] ranks;

	} else {
		ranks1 = parse.GetClassColOrder(0);
		ranks2 = parse.GetClassColOrder(1);
	}

	// Maps filtered row indices back to input rows; NULL when unfiltered
	int* indices = (filter > 0) ? (int*)ranksum[2] : NULL;

	std::vector<ScoreElement> pairs;
	std::vector<float> class1_probs, class2_probs;
	if (algorithm == 0) {

		//Call TSP on the ranks, passing in the dimensions of the data
		printf("Running TSP algorithm on GPU\n");
		std::vector<float*> pointers = TSP(ranks1, ranks2, num_probes, num_class1, num_probes, num_class2);

		//Find all the disjoint pairs
		printf("Finding disjoint pairs\n");
		pairs = DisjointPairKernel(pointers[0], pointers[1], num_probes, num_probes, 10);

	} else if (algorithm == 1) {

		printf("kTSP is not implemented in this version.  Try the MATLAB version\n");

	} else if (algorithm == 2) {

		printf("Running TST algorithm on GPU\n");
		pairs = TST(ranks1, ranks2, num_probes, num_class1, num_probes, num_class2);

		//Calculate probabilities and scores for this classifier (only if TST found one)
		if (!pairs.empty()) {
			pairs[0].primary = CalculateTSTProbabilities(ranks1, ranks2, num_probes, num_class1, num_probes, num_class2,
				pairs[0].row, pairs[0].col, pairs[0].z, class1_probs, class2_probs);
		}

	} else {

		printf("Unknown algorithm: %d\n", algorithm);
	}

	if (pairs.empty()) {
		fprintf(output, "No pairs found.  Try different input settings\n");
		exit(1);
	}

	//Get the probe names
	std::vector<std::string> probe_names = parse.GetRowHeader();

	//Output the run information and classifier to the output stream
	printf("Finished running tsp_cuda\n");
	fprintf(output, "\n*** Run information\n");
	if (algorithm == 0)
		fprintf(output, "Algorithm: top-scoring pair\n");
	else if (algorithm == 1)
		fprintf(output, "Algorithm: k-top-scoring pair\n");
	else if (algorithm == 2)
		fprintf(output, "Algorithm: top-scoring triplet\n");
	else
		fprintf(output, "Unknown algorithm\n");
	fprintf(output, "Input file: '%s'\n", input_file.c_str());
	fprintf(output, "Num probes: %u\n", total_probes);
	fprintf(output, "Num samples: %u\n", num_samples);
	fprintf(output, "Class1 samples: %u Class2 samples: %u\n", num_class1, num_class2);
	if (filter > 0) {
		fprintf(output, "Filtered for the top %d differentially expressed genes\n", filter);
	} else {
		fprintf(output, "No filtering for differentially expressed genes\n");
	}
	PrintClassifier(output, algorithm, pairs[0], indices, class1_probs, class2_probs, probe_names);

	//If we ran TSP we can do a cross validation on the classifier
	if (algorithm == 0) {

		//Get all the data (caller owns the returned buffer)
		float *ranks = parse.GetData();

		std::vector<int> preds;
		float error = LOOCV(ranks, parse.Rows()-1, num_samples, parse.GetRow(0), pairs, 1, indices, preds);
		delete[] ranks;

		fprintf(output, "\n*** Cross-validation:\n");
		fprintf(output, "Error Rate: %.4f\n", error);
	}

	delete[] ranks1;
	delete[] ranks2;
	if (filter > 0) {
		delete[] ranksum[0];
		delete[] ranksum[1];
		delete[] ranksum[2];
	}

	if (output != stdout) {
		fclose(output);
	}
	return 0;
}




