package ws.palladian.classification.utils;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import org.apache.commons.lang3.Validate;
import ws.palladian.core.Classifier;
import ws.palladian.core.Instance;
import ws.palladian.core.Learner;
import ws.palladian.core.Model;
import ws.palladian.helper.ProgressMonitor;
import ws.palladian.helper.io.FileHelper;
import ws.palladian.helper.math.ConfusionMatrix;
import ws.palladian.helper.math.MathHelper;
import ws.palladian.helper.math.ThresholdAnalyzer;

/* loaded from: input_file:ws/palladian/classification/utils/ClassifierEvaluation.class */
/**
 * Static helpers for evaluating a {@link Classifier} against labelled {@link Instance}s. Results are
 * collected in a {@link ConfusionMatrix} (or a {@link ThresholdAnalyzer} for threshold analysis).
 *
 * <p>This class is not instantiable.
 */
public final class ClassifierEvaluation {

    /** Field separator used in the per-instance protocol lines written by the logging evaluation. */
    private static final String PROTOCOL_SEPARATOR = "###";

    private ClassifierEvaluation() {
        // utility class, no instances
    }

    /**
     * Evaluates a classifier on the given instances, classifying each one through
     * {@code ClassificationUtils.classifyWithMultipleModels} with all supplied models.
     *
     * @param classifier the classifier to evaluate
     * @param iterable the labelled instances to classify
     * @param mArr the model(s) handed to the classifier
     * @return a confusion matrix of (actual category, most likely predicted category) pairs
     * @deprecated marked deprecated in the original code; the intended replacement is not visible here.
     */
    @Deprecated
    public static <M extends Model> ConfusionMatrix evaluate(Classifier<M> classifier, Iterable<? extends Instance> iterable, M... mArr) {
        ConfusionMatrix confusionMatrix = new ConfusionMatrix();
        for (Instance instance : iterable) {
            String predicted = ClassificationUtils.classifyWithMultipleModels(classifier, instance.getVector(), mArr).getMostLikelyCategory();
            confusionMatrix.add(instance.getCategory(), predicted);
        }
        return confusionMatrix;
    }

    /**
     * Evaluates a classifier on the given instances and additionally records one protocol line per
     * instance. A header line is appended to {@code list} first, followed by one line per instance of
     * the form {@code vector###actualCategory###predictedCategory}.
     *
     * @param classifier the classifier to evaluate
     * @param iterable the labelled instances to classify
     * @param list a mutable list that receives the header and the per-instance protocol lines
     * @param m the model handed to the classifier
     * @return a confusion matrix of (actual category, most likely predicted category) pairs
     */
    public static <M extends Model> ConfusionMatrix evaluate(Classifier<M> classifier, Iterable<? extends Instance> iterable, List<String> list, M m) {
        ConfusionMatrix confusionMatrix = new ConfusionMatrix();
        list.add("Vector###Correct Category###Classified Category");
        for (Instance instance : iterable) {
            String predicted = ClassificationUtils.classifyWithMultipleModels(classifier, instance.getVector(), m).getMostLikelyCategory();
            String actual = instance.getCategory();
            list.add(instance.getVector().toString() + PROTOCOL_SEPARATOR + actual + PROTOCOL_SEPARATOR + predicted);
            confusionMatrix.add(actual, predicted);
        }
        return confusionMatrix;
    }

    /**
     * Evaluates a classifier with a single model by calling {@code classifier.classify} directly on
     * every instance.
     *
     * @param classifier the classifier to evaluate
     * @param m the model handed to the classifier
     * @param iterable the labelled instances to classify
     * @return a confusion matrix of (actual category, most likely predicted category) pairs
     * @deprecated marked deprecated in the original code; the intended replacement is not visible here.
     */
    @Deprecated
    public static <M extends Model> ConfusionMatrix evaluate(Classifier<M> classifier, M m, Iterable<? extends Instance> iterable) {
        ConfusionMatrix confusionMatrix = new ConfusionMatrix();
        for (Instance instance : iterable) {
            String predicted = classifier.classify(instance.getVector(), m).getMostLikelyCategory();
            confusionMatrix.add(instance.getCategory(), predicted);
        }
        return confusionMatrix;
    }

    /**
     * Performs a simple hold-out evaluation: the first half of {@code list} (indices
     * {@code [0, size/2)}) is used for training, the remaining instances (indices
     * {@code [size/2, size)}) are used for testing.
     *
     * @param learner the learner used to train the model, not {@code null}
     * @param classifier the classifier to evaluate, not {@code null}
     * @param list the labelled instances, not {@code null}, at least two elements
     * @return a confusion matrix for the testing half
     * @throws NullPointerException if any argument is {@code null}
     * @throws IllegalArgumentException if {@code list} has fewer than two elements
     */
    public static <M extends Model> ConfusionMatrix evaluate(Learner<M> learner, Classifier<M> classifier, List<? extends Instance> list) {
        Validate.notNull(learner, "learner must not be null");
        Validate.notNull(classifier, "classifier must not be null");
        Validate.notNull(list, "instances must not be null");
        // Two instances are enough: one to train on, one to test on (matches the message).
        Validate.isTrue(list.size() >= 2, "instances must contain at least two elements");
        int splitIndex = list.size() / 2;
        List<? extends Instance> trainingHalf = list.subList(0, splitIndex);
        // subList's end index is exclusive, so list.size() is needed to include the last instance.
        List<? extends Instance> testingHalf = list.subList(splitIndex, list.size());
        M model = learner.train(trainingHalf);
        return evaluate(classifier, testingHalf, model);
    }

    /**
     * Feeds a {@link ThresholdAnalyzer} with, for every instance, whether its actual category equals
     * {@code str} and the probability the classifier assigns to {@code str}.
     *
     * @param classifier the classifier to evaluate
     * @param m the model handed to the classifier; must have exactly two categories
     * @param iterable the labelled instances to classify
     * @param str the category treated as the "correct"/positive class
     * @return the populated threshold analyzer
     * @throws IllegalArgumentException if the model does not have exactly two categories
     * @deprecated marked deprecated in the original code; the intended replacement is not visible here.
     */
    @Deprecated
    public static <M extends Model> ThresholdAnalyzer thresholdAnalysis(Classifier<M> classifier, M m, Iterable<? extends Instance> iterable, String str) {
        Validate.isTrue(m.getCategories().size() == 2, "binary model required");
        // NOTE(review): 100 is presumably the number of threshold bins - confirm against ThresholdAnalyzer.
        ThresholdAnalyzer thresholdAnalyzer = new ThresholdAnalyzer(100);
        for (Instance instance : iterable) {
            boolean relevant = instance.getCategory().equals(str);
            double probability = classifier.classify(instance.getVector(), m).getProbability(str);
            thresholdAnalyzer.add(relevant, probability);
        }
        return thresholdAnalyzer;
    }

    /**
     * Creates learning curves by training on growing prefixes of a shuffled copy of the training set
     * and evaluating each resulting model on both that prefix and the test set. One CSV row per step
     * (train item count, train percentage, then precision/recall/F1 for {@code str} on train and test
     * data) is appended to a file named {@code learningCurves_<currentTimeMillis>.csv}.
     *
     * <p>NOTE(review): the loop stops before the prefix reaches the full training set size, so no row
     * is written for 100% of the training data; this mirrors the existing behavior.
     *
     * @param learner the learner used to train the models, not {@code null}
     * @param classifier the classifier to evaluate, not {@code null}
     * @param collection the training set, not {@code null}
     * @param collection2 the test set, not {@code null}
     * @param str the category for which precision, recall and F1 are reported, not {@code null}
     * @param i the step size, i.e. number of training instances added per step, at least one
     * @throws NullPointerException if a required argument is {@code null}
     * @throws IllegalArgumentException if {@code i} is smaller than one
     */
    public static <M extends Model> void createLearningCurves(Learner<M> learner, Classifier<M> classifier, Collection<? extends Instance> collection, Collection<? extends Instance> collection2, String str, int i) {
        Validate.notNull(learner, "learner must not be null");
        Validate.notNull(classifier, "classifier must not be null");
        Validate.notNull(collection, "trainSet must not be null");
        Validate.notNull(collection2, "testSet must not be null");
        Validate.isTrue(i >= 1, "stepSize must be greater/equal one");
        Validate.notNull(str, "correctClass must not be null");
        String fileName = String.format("learningCurves_%s.csv", System.currentTimeMillis());
        int trainSize = collection.size();
        // Exact number of loop iterations below: all k >= 1 with k * i < trainSize.
        int totalSteps = Math.max(0, (trainSize - 1) / i);
        ProgressMonitor progressMonitor = new ProgressMonitor();
        progressMonitor.startTask("Creating learning curves", totalSteps);
        // Shuffle a copy so the caller's collection is left untouched.
        List<Instance> shuffled = new ArrayList<Instance>(collection);
        Collections.shuffle(shuffled);
        FileHelper.appendFile(fileName, "trainItems;trainPercent;trainPrecision;trainRecall;trainF1;testPrecision;testRecall;testF1;\n");
        for (int trainCount = i; trainCount < trainSize; trainCount += i) {
            List<Instance> currentTrainSet = shuffled.subList(0, trainCount);
            M model = learner.train(currentTrainSet);
            ConfusionMatrix trainResult = evaluate(classifier, currentTrainSet, model);
            ConfusionMatrix testResult = evaluate(classifier, collection2, model);
            StringBuilder row = new StringBuilder();
            row.append(trainCount).append(';');
            row.append(MathHelper.round((trainCount * 100.0d) / trainSize, 2)).append(';');
            row.append(trainResult.getPrecision(str)).append(';');
            row.append(trainResult.getRecall(str)).append(';');
            row.append(trainResult.getF(1.0d, str)).append(';');
            row.append(testResult.getPrecision(str)).append(';');
            row.append(testResult.getRecall(str)).append(';');
            row.append(testResult.getF(1.0d, str)).append(';');
            row.append('\n');
            FileHelper.appendFile(fileName, row);
            progressMonitor.increment();
        }
    }
}
