/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.eval;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.commons.lang3.SerializationUtils;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.learning.error.MeanSquaredError;
import org.neuroph.eval.ClassifierEvaluator;
import org.neuroph.eval.ErrorEvaluator;
import org.neuroph.eval.Evaluation;
import org.neuroph.eval.EvaluationResult;
import org.neuroph.eval.FoldResult;
import org.neuroph.eval.classification.ClassificationMetrics;
import org.neuroph.eval.classification.ConfusionMatrix;

/**
 * K-fold cross validation for a Neuroph {@link NeuralNetwork}.
 *
 * <p>{@link #run()} splits the data set into {@code numFolds} folds with {@code DataSet.split}.
 * For each fold it trains an independent deep copy of the network on the remaining folds and
 * evaluates that copy on the held-out fold. It then prints per-fold aggregates and a summary to
 * standard output. The network passed to the constructor is only cloned, never trained.
 *
 * <p>{@link #run()} reassigns instance fields without synchronization, so a single instance
 * should not be run from several threads at once.
 */
public class KFoldCrossValidation {
    /** Upper bound on the number of folds trained concurrently. */
    private static final int MAX_PARALLEL_FOLDS = 4;

    private final NeuralNetwork neuralNetwork;
    private final DataSet dataSet;
    private final int numFolds;
    private List<ConfusionMatrix> confusionMatrices;
    private List<ClassificationMetrics.Stats> statlist;
    private List<FoldResult> crossFoldResults;

    /**
     * Creates a cross validation run.
     *
     * @param neuralNetwork template network; each fold trains a serialization-based deep copy of it,
     *     so it must be {@link Serializable} at runtime
     * @param dataSet the full data set to split into folds
     * @param numFolds number of folds; must be at least 2 so every training set is non-empty
     * @throws NullPointerException if {@code neuralNetwork} or {@code dataSet} is null
     * @throws IllegalArgumentException if {@code numFolds < 2}
     */
    public KFoldCrossValidation(NeuralNetwork neuralNetwork, DataSet dataSet, int numFolds) {
        this.neuralNetwork = Objects.requireNonNull(neuralNetwork, "neuralNetwork");
        this.dataSet = Objects.requireNonNull(dataSet, "dataSet");
        if (numFolds < 2) {
            throw new IllegalArgumentException("numFolds must be at least 2, got " + numFolds);
        }
        this.numFolds = numFolds;
    }

    /**
     * Trains and evaluates one network copy per fold and prints the results.
     *
     * <p>The per-fold results are kept (see {@link #getResultsByFolds()}). The fold confusion
     * matrices are summed, and the fold metrics are averaged. Both are printed to standard output.
     *
     * @return an evaluation result whose confusion matrix is the sum of all fold confusion matrices
     * @throws InterruptedException if interrupted while waiting for the fold workers
     * @throws ExecutionException if training or evaluating any fold failed; its cause is the
     *     worker's exception
     */
    public EvaluationResult run() throws InterruptedException, ExecutionException {
        this.confusionMatrices = new ArrayList<>();
        this.statlist = new ArrayList<>();
        this.crossFoldResults = new ArrayList<>();

        List<FoldResult> results = this.evaluateFolds(this.createWorkers());
        for (FoldResult foldResult : results) {
            ConfusionMatrix foldMatrix = foldResult.getConfusionMatrix();
            this.confusionMatrices.add(foldMatrix);
            this.crossFoldResults.add(foldResult);
            this.statlist.add(ClassificationMetrics.average(ClassificationMetrics.createFromMatrix(foldMatrix)));
        }

        ConfusionMatrix confusionMatrix = this.sumConfusionMatrix(this.confusionMatrices, this.dataSet);
        System.out.println();
        System.out.println("All folds results:");
        System.out.println();
        EvaluationResult sumEval = new EvaluationResult();
        sumEval.setConfusionMatrix(confusionMatrix);
        this.printFoldResults(confusionMatrix, 0);
        ClassificationMetrics.Stats allstat = averageClassificationMetrics(this.statlist);
        this.printResults(this.dataSet, allstat, this.numFolds);
        return sumEval;
    }

    /** Builds one worker per fold: fold {@code i} is validated on, all other folds are trained on. */
    private List<CrossValidationWorker> createWorkers() {
        DataSet[] foldDataSets = this.dataSet.split(this.numFolds);
        List<CrossValidationWorker> workers = new ArrayList<>(this.numFolds);
        for (int i = 0; i < this.numFolds; ++i) {
            DataSet validationSet = foldDataSets[i];
            DataSet trainingSet = this.createTrainingSetFromFolds(foldDataSets, i);
            workers.add(new CrossValidationWorker(trainingSet, validationSet));
        }
        return workers;
    }

    /** Runs the workers on a bounded pool, returning results in worker order; always shuts the pool down. */
    private List<FoldResult> evaluateFolds(List<CrossValidationWorker> workers)
            throws InterruptedException, ExecutionException {
        // Never start more threads than there are folds; each thread holds a full network copy.
        int poolSize = Math.max(1, Math.min(workers.size(), MAX_PARALLEL_FOLDS));
        ExecutorService executor = Executors.newFixedThreadPool(poolSize);
        try {
            // invokeAll blocks until every task has finished, so no extra latch is needed.
            List<Future<FoldResult>> futures = executor.invokeAll(workers);
            List<FoldResult> results = new ArrayList<>(futures.size());
            for (Future<FoldResult> future : futures) {
                results.add(future.get());
            }
            return results;
        } finally {
            // Runs on failure too, so a throwing fold cannot leak the pool's threads.
            executor.shutdown();
        }
    }

    /**
     * Sums confusion matrices cell by cell.
     *
     * <p>The result uses the class labels of the first matrix. Its dimension is that label count,
     * so the summed values always match the matrix they are stored in. This matters for
     * single-output (binary) networks, whose matrices have more classes than the data set has
     * outputs.
     *
     * @param cmList matrices to sum; all are assumed to share the first matrix's labels
     * @param dataSet retained for API compatibility; not used for sizing
     * @return a new matrix holding the element-wise sum
     * @throws IllegalArgumentException if {@code cmList} is null or empty
     */
    public ConfusionMatrix sumConfusionMatrix(List<ConfusionMatrix> cmList, DataSet dataSet) {
        if (cmList == null || cmList.isEmpty()) {
            throw new IllegalArgumentException("At least one confusion matrix is required");
        }
        ConfusionMatrix cm = new ConfusionMatrix(cmList.get(0).getClassLabels());
        int classCount = cmList.get(0).getClassLabels().length;
        int[][] totals = new int[classCount][classCount];
        for (ConfusionMatrix c : cmList) {
            for (int actual = 0; actual < classCount; ++actual) {
                for (int predicted = 0; predicted < classCount; ++predicted) {
                    totals[actual][predicted] += c.get(actual, predicted);
                }
            }
        }
        cm.setValues(totals);
        return cm;
    }

    /**
     * Averages accuracy, precision, recall, F-score and correlation coefficient over the given stats.
     *
     * <p>The result is a fresh object; the elements of {@code metricsList} are not modified. Other
     * fields of the returned object keep their default values.
     *
     * @param metricsList per-fold statistics
     * @return the arithmetic mean of the listed fields
     * @throws IllegalArgumentException if {@code metricsList} is null or empty
     */
    public static ClassificationMetrics.Stats averageClassificationMetrics(List<ClassificationMetrics.Stats> metricsList) {
        if (metricsList == null || metricsList.isEmpty()) {
            throw new IllegalArgumentException("At least one Stats instance is required");
        }
        ClassificationMetrics.Stats average = new ClassificationMetrics.Stats();
        for (ClassificationMetrics.Stats st : metricsList) {
            average.accuracy += st.accuracy;
            average.precision += st.precision;
            average.recall += st.recall;
            average.fScore += st.fScore;
            average.correlationCoefficient += st.correlationCoefficient;
        }
        double count = metricsList.size();
        average.accuracy /= count;
        average.precision /= count;
        average.recall /= count;
        average.fScore /= count;
        average.correlationCoefficient /= count;
        return average;
    }

    /**
     * Prints the cross validation summary (instance count, fold count and the given averaged stats)
     * to standard output.
     */
    public void printResults(DataSet dataset, ClassificationMetrics.Stats nst, int numfolds) {
        System.out.println();
        System.out.println("=== Cross validation result ===");
        System.out.println("Instances: " + dataset.size());
        System.out.println("Number of folds: " + numfolds);
        System.out.println("\n");
        System.out.println("=== Summary ===");
        System.out.println("Accuracy: " + nst.accuracy);
        System.out.println("Precision: " + nst.precision);
        System.out.println("Recall: " + nst.recall);
        System.out.println("FScore: " + nst.fScore);
        System.out.println("Correlation coefficient: " + nst.correlationCoefficient);
    }

    /**
     * Prints a confusion matrix to standard output, followed by the per-class metrics derived from
     * it and their average.
     *
     * @param confusionMatrix matrix to report
     * @param foldIdx label printed after "Fold: "
     */
    public void printFoldResults(ConfusionMatrix confusionMatrix, int foldIdx) {
        System.out.println();
        System.out.println("Fold: " + foldIdx);
        System.out.println();
        System.out.println("Confusion matrix:\r\n");
        System.out.println(confusionMatrix.toString() + "\r\n\r\n");
        System.out.println("Classification metrics\r\n");
        ClassificationMetrics[] metrics = ClassificationMetrics.createFromMatrix(confusionMatrix);
        ClassificationMetrics.Stats stat = ClassificationMetrics.average(metrics);
        for (ClassificationMetrics cm : metrics) {
            System.out.println(cm.toString() + "\r\n");
        }
        System.out.println(stat.toString());
    }

    /** Prints the averaged classification metrics derived from {@code confusionMatrix} to standard output. */
    public void printStats(ConfusionMatrix confusionMatrix) {
        ClassificationMetrics[] metrics = ClassificationMetrics.createFromMatrix(confusionMatrix);
        ClassificationMetrics.Stats stat = ClassificationMetrics.average(metrics);
        System.out.println(stat.toString());
    }

    /**
     * Concatenates every fold except {@code excludeIdx} into a new data set with the same input and
     * output sizes as the source data set.
     */
    private DataSet createTrainingSetFromFolds(DataSet[] folds, int excludeIdx) {
        DataSet trainingSet = new DataSet(this.dataSet.getInputSize(), this.dataSet.getOutputSize());
        for (int i = 0; i < folds.length; ++i) {
            if (i == excludeIdx) {
                continue;
            }
            trainingSet.addAll(folds[i]);
        }
        return trainingSet;
    }

    /**
     * Returns the per-fold results collected by the most recent {@link #run()}, in fold order.
     *
     * @return the internal result list, or {@code null} if {@code run()} has not been called yet
     */
    public List<FoldResult> getResultsByFolds() {
        return this.crossFoldResults;
    }

    /** Trains a private copy of the template network on one training split and evaluates it. */
    private class CrossValidationWorker implements Callable<FoldResult> {
        private final DataSet trainingSet;
        private final DataSet validationSet;

        public CrossValidationWorker(DataSet trainingSet, DataSet validationSet) {
            this.trainingSet = trainingSet;
            this.validationSet = validationSet;
        }

        @Override
        public FoldResult call() throws Exception {
            // Each fold works on its own deep copy, so folds are independent and the caller's
            // network is never trained or mutated from several threads at once.
            NeuralNetwork neuralNet = (NeuralNetwork) SerializationUtils.clone((Serializable) KFoldCrossValidation.this.neuralNetwork);
            Evaluation evaluation = new Evaluation();
            evaluation.addEvaluator(new ErrorEvaluator(new MeanSquaredError()));
            if (neuralNet.getOutputsCount() == 1) {
                evaluation.addEvaluator(new ClassifierEvaluator.Binary(0.5));
            } else {
                evaluation.addEvaluator(new ClassifierEvaluator.MultiClass(KFoldCrossValidation.this.dataSet.getColumnNames()));
            }
            // Train the copy that is evaluated below, not the shared template network.
            neuralNet.learn(this.trainingSet);
            EvaluationResult evaluationResult = evaluation.evaluate(neuralNet, this.validationSet);
            FoldResult foldResult = new FoldResult(neuralNet, this.trainingSet, this.validationSet);
            foldResult.setConfusionMatrix(evaluationResult.getConfusionMatrix());
            return foldResult;
        }
    }
}

