package net.sf.javaml.classification.evaluation;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import net.sf.javaml.classification.Classifier;
import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.DefaultDataset;
import net.sf.javaml.core.Instance;

/* loaded from: input_file:net/sf/javaml/classification/evaluation/CrossValidation.class */
public class CrossValidation {
    /** Number of folds used by {@link #crossValidation(Dataset)}. */
    private static final int DEFAULT_FOLDS = 10;

    /** Classifier that is retrained on every training split. */
    private final Classifier classifier;

    /**
     * Creates a cross-validation driver for the given classifier.
     *
     * @param classifier the classifier to evaluate; its {@code buildClassifier} is invoked once per
     *     fold, so after a run it is left trained on the last fold's training set
     */
    public CrossValidation(Classifier classifier) {
        this.classifier = classifier;
    }

    /**
     * Runs k-fold cross-validation and returns one-vs-rest confusion counts per class.
     *
     * <p>The dataset is split with {@code dataset.folds(numFolds, random)}. For each fold, the
     * classifier is trained on all remaining folds and then classifies every instance of the
     * held-out fold. For each prediction, every class in {@code dataset.classes()} has exactly one
     * of its counters incremented:
     * <ul>
     *   <li>{@code tp} if it is both the actual and the predicted class,
     *   <li>{@code fp} if it is the predicted but not the actual class,
     *   <li>{@code fn} if it is the actual but not the predicted class,
     *   <li>{@code tn} otherwise.
     * </ul>
     * A {@code null} prediction is treated as matching no class.
     *
     * @param dataset the data to evaluate on
     * @param numFolds number of folds requested from {@code dataset.folds}
     * @param random source of randomness used to shuffle the data into folds
     * @return a map from each class value to its accumulated {@link PerformanceMeasure}
     */
    public Map<Object, PerformanceMeasure> crossValidation(Dataset dataset, int numFolds, Random random) {
        Dataset[] folds = dataset.folds(numFolds, random);
        Map<Object, PerformanceMeasure> measures = new HashMap<>();
        for (Object classValue : dataset.classes()) {
            measures.put(classValue, new PerformanceMeasure());
        }
        // Iterate over the folds actually returned so the index can never run past the array.
        for (int testIndex = 0; testIndex < folds.length; testIndex++) {
            this.classifier.buildClassifier(trainingSet(folds, testIndex));
            for (Instance instance : folds[testIndex]) {
                Object predicted = this.classifier.classify(instance);
                tally(measures, instance.classValue(), predicted);
            }
        }
        return measures;
    }

    /**
     * Runs cross-validation with a freshly, uniquely seeded {@link Random}.
     *
     * @param dataset the data to evaluate on
     * @param numFolds number of folds requested from {@code dataset.folds}
     * @return a map from each class value to its accumulated {@link PerformanceMeasure}
     */
    public Map<Object, PerformanceMeasure> crossValidation(Dataset dataset, int numFolds) {
        // new Random() is seeded uniquely per instance; seeding with currentTimeMillis would give
        // identical fold splits for runs started within the same millisecond.
        return crossValidation(dataset, numFolds, new Random());
    }

    /**
     * Runs {@value #DEFAULT_FOLDS}-fold cross-validation with a freshly seeded {@link Random}.
     *
     * @param dataset the data to evaluate on
     * @return a map from each class value to its accumulated {@link PerformanceMeasure}
     */
    public Map<Object, PerformanceMeasure> crossValidation(Dataset dataset) {
        return crossValidation(dataset, DEFAULT_FOLDS);
    }

    /** Concatenates every fold except the one at {@code heldOut} into a new training set. */
    private static DefaultDataset trainingSet(Dataset[] folds, int heldOut) {
        DefaultDataset training = new DefaultDataset();
        for (int f = 0; f < folds.length; f++) {
            if (f != heldOut) {
                training.addAll(folds[f]);
            }
        }
        return training;
    }

    /** Records one prediction into the per-class one-vs-rest confusion counts. */
    private static void tally(Map<Object, PerformanceMeasure> measures, Object actual, Object predicted) {
        for (Map.Entry<Object, PerformanceMeasure> entry : measures.entrySet()) {
            // Objects.equals tolerates a null prediction or class value instead of throwing NPE.
            boolean isActual = Objects.equals(entry.getKey(), actual);
            boolean isPredicted = Objects.equals(entry.getKey(), predicted);
            PerformanceMeasure measure = entry.getValue();
            if (isActual && isPredicted) {
                measure.tp += 1.0d;
            } else if (isPredicted) {
                measure.fp += 1.0d;
            } else if (isActual) {
                measure.fn += 1.0d;
            } else {
                measure.tn += 1.0d;
            }
        }
    }
}
