package ws.palladian.kaggle.restaurants.utils;

import java.util.Objects;
import java.util.concurrent.TimeUnit;

import ws.palladian.classification.evaluation.roc.RocCurves;
import ws.palladian.core.CategoryEntries;
import ws.palladian.core.Classifier;
import ws.palladian.core.Instance;
import ws.palladian.core.Learner;
import ws.palladian.core.Model;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.helper.io.CloseableIterator;
import ws.palladian.helper.math.ConfusionMatrix;
import ws.palladian.helper.math.ThresholdAnalyzer;

/* loaded from: input_file:ws/palladian/kaggle/restaurants/utils/ClassifierCombination.class */
public final class ClassifierCombination<M extends Model> {
    private final Learner<M> learner;
    private final Classifier<M> classifier;

    /* loaded from: input_file:ws/palladian/kaggle/restaurants/utils/ClassifierCombination$EvaluationResult.class */
    public static final class EvaluationResult<M extends Model> {
        private final ConfusionMatrix confusionMatrix;
        private final M model;
        private final long trainingTime;
        private final long testingTime;
        private final ThresholdAnalyzer thresholdAnalyzer;
        private RocCurves rocCurves;

        EvaluationResult(ConfusionMatrix confusionMatrix, M m, long j, long j2, ThresholdAnalyzer thresholdAnalyzer, RocCurves rocCurves) {
            this.confusionMatrix = confusionMatrix;
            this.model = m;
            this.trainingTime = j;
            this.testingTime = j2;
            this.thresholdAnalyzer = thresholdAnalyzer;
            this.rocCurves = rocCurves;
        }

        public ConfusionMatrix getConfusionMatrix() {
            return this.confusionMatrix;
        }

        public M getModel() {
            return this.model;
        }

        public long getTrainingTime() {
            return this.trainingTime;
        }

        public long getTestingTime() {
            return this.testingTime;
        }

        public ThresholdAnalyzer getThresholdAnalyzer() {
            return this.thresholdAnalyzer;
        }

        public RocCurves getRocCurves() {
            return this.rocCurves;
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    public <LC extends Learner<M> & Classifier<M>> ClassifierCombination(LC lc) {
        this.learner = lc;
        this.classifier = (Classifier) lc;
    }

    public ClassifierCombination(Learner<M> learner, Classifier<M> classifier) {
        this.learner = learner;
        this.classifier = classifier;
    }

    @Deprecated
    public ConfusionMatrix evaluate(Dataset dataset, Dataset dataset2) {
        return runEvaluation(dataset, dataset2, "true").getConfusionMatrix();
    }

    public EvaluationResult<M> runEvaluation(Dataset dataset, Dataset dataset2) {
        return runEvaluation(dataset, dataset2, "true");
    }

    public EvaluationResult<M> runEvaluation(Dataset dataset, Dataset dataset2, String str) {
        long currentTimeMillis = System.currentTimeMillis();
        Model train = this.learner.train(dataset);
        long currentTimeMillis2 = System.currentTimeMillis() - currentTimeMillis;
        ConfusionMatrix confusionMatrix = new ConfusionMatrix();
        ThresholdAnalyzer thresholdAnalyzer = new ThresholdAnalyzer(100);
        RocCurves.RocCurvesBuilder rocCurvesBuilder = new RocCurves.RocCurvesBuilder();
        long currentTimeMillis3 = System.currentTimeMillis();
        CloseableIterator it = dataset2.iterator();
        while (it.hasNext()) {
            Instance instance = (Instance) it.next();
            CategoryEntries classify = this.classifier.classify(instance.getVector(), train);
            String mostLikelyCategory = classify.getMostLikelyCategory();
            String category = instance.getCategory();
            double probability = classify.getProbability(str);
            boolean equals = category.equals(str);
            confusionMatrix.add(category, mostLikelyCategory);
            thresholdAnalyzer.add(equals, probability);
            rocCurvesBuilder.add(equals, probability);
        }
        return new EvaluationResult<>(confusionMatrix, train, currentTimeMillis2, System.currentTimeMillis() - currentTimeMillis3, thresholdAnalyzer, rocCurvesBuilder.create());
    }

    public Learner<M> getLearner() {
        return this.learner;
    }

    public Classifier<M> getClassifier() {
        return this.classifier;
    }

    public String toString() {
        return this.classifier.toString();
    }
}
