package ws.palladian.classification.evaluation.roc;

import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.Validate;
import ws.palladian.classification.evaluation.AbstractClassificationEvaluator;
import ws.palladian.classification.evaluation.AbstractGraphPainter;
import ws.palladian.classification.evaluation.Graph;
import ws.palladian.core.Classifier;
import ws.palladian.core.Instance;
import ws.palladian.core.Model;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.helper.collection.AbstractIterator2;
import ws.palladian.helper.functional.Factory;
import ws.palladian.helper.math.ConfusionMatrix;

/* loaded from: input_file:ws/palladian/classification/evaluation/roc/RocCurves.class */
public class RocCurves implements Iterable<RocCurves.EvaluationPoint>, Graph {
    /** All scored entries, sorted by descending confidence (see {@link ResultEntry#compareTo}). */
    private final List<ResultEntry> results;
    /** Number of entries whose true category is the positive class. */
    private final int positives;
    /** Number of entries whose true category is the negative class. */
    private final int negatives;

    /**
     * One point of the ROC curve, emitted for each entry in descending-confidence order.
     * Sensitivity and specificity are computed as if all entries ranked up to and including
     * this one were predicted positive.
     */
    public static final class EvaluationPoint {
        private final double sensitivity;
        private final double specificity;
        private final double threshold;

        EvaluationPoint(double sensitivity, double specificity, double threshold) {
            this.sensitivity = sensitivity;
            this.specificity = specificity;
            this.threshold = threshold;
        }

        /** @return the true positive rate at this point; NaN if there are no positive entries. */
        public double getSensitivity() {
            return this.sensitivity;
        }

        /** @return the true negative rate at this point; NaN if there are no negative entries. */
        public double getSpecificity() {
            return this.specificity;
        }

        /** @return the confidence value of the entry that produced this point. */
        public double getThreshold() {
            return this.threshold;
        }

        @Override
        public String toString() {
            return AbstractGraphPainter.format(this.threshold) + ": sensitivity=" + AbstractGraphPainter.format(this.sensitivity) + ", specificity=" + AbstractGraphPainter.format(this.specificity);
        }
    }

    /** The best Matthews correlation coefficient found and the threshold at which it occurred. */
    public static final class MccAtThreshold {
        public final double mcc;
        public final double threshold;

        MccAtThreshold(double mcc, double threshold) {
            this.mcc = mcc;
            this.threshold = threshold;
        }

        @Override
        public String toString() {
            return this.mcc + " @ " + this.threshold;
        }
    }

    /**
     * A single scored example: whether it truly belongs to the positive class, and the
     * classifier's confidence for the positive class. The natural order is by
     * <em>descending</em> confidence; it is not consistent with {@code equals}.
     */
    public static final class ResultEntry implements Comparable<ResultEntry> {
        final boolean trueCategory;
        final double confidence;

        ResultEntry(boolean trueCategory, double confidence) {
            this.trueCategory = trueCategory;
            this.confidence = confidence;
        }

        @Override
        public int compareTo(ResultEntry other) {
            // Arguments swapped on purpose: highest confidence sorts first.
            return Double.compare(other.confidence, this.confidence);
        }
    }

    /** Incrementally collects scored entries and creates an immutable {@link RocCurves}. */
    public static class RocCurvesBuilder implements Factory<RocCurves> {
        private final List<ResultEntry> results = new ArrayList<>();

        /**
         * Adds one scored entry.
         *
         * @param trueCategory {@code true} if the entry belongs to the positive class.
         * @param confidence the classifier's confidence for the positive class.
         */
        public void add(boolean trueCategory, double confidence) {
            this.results.add(new ResultEntry(trueCategory, confidence));
        }

        /**
         * Adds all entries of an existing {@link RocCurves} (e.g. to merge cross-validation folds).
         *
         * @param rocCurves the curves to merge, not {@code null}.
         */
        public void add(RocCurves rocCurves) {
            Objects.requireNonNull(rocCurves, "rocCurves was null");
            this.results.addAll(rocCurves.results);
        }

        /** @return new curves over a copy of the entries collected so far; the builder stays usable. */
        @Override
        public RocCurves create() {
            return new RocCurves(this.results);
        }
    }

    /**
     * Evaluates a binary classifier on a dataset, treating {@code trueCategory} as the
     * positive class and using the classifier's probability for it as the confidence.
     */
    public static final class RocCurvesEvaluator extends AbstractClassificationEvaluator<RocCurves> {
        private final String trueCategory;

        /** @param trueCategory name of the category treated as the positive class. */
        public RocCurvesEvaluator(String trueCategory) {
            this.trueCategory = trueCategory;
        }

        /**
         * @throws IllegalArgumentException if the model does not have exactly two categories.
         * @throws IllegalStateException if the model does not know {@code trueCategory}.
         */
        @Override
        public <M extends Model> RocCurves evaluate(Classifier<M> classifier, M model, Dataset dataset) {
            Validate.isTrue(model.getCategories().size() == 2, "binary model required");
            if (!model.getCategories().contains(this.trueCategory)) {
                throw new IllegalStateException("Model has no category \"" + this.trueCategory + "\".");
            }
            List<ResultEntry> entries = new ArrayList<>();
            // NOTE(review): iterator2() is the name the decompiled Dataset exposes here; confirm
            // it matches the real Dataset API (possibly iterator()).
            Iterator<Instance> instances = dataset.iterator2();
            while (instances.hasNext()) {
                Instance instance = instances.next();
                // Objects.equals avoids an NPE for instances without a category.
                boolean positive = Objects.equals(this.trueCategory, instance.getCategory());
                double confidence = classifier.classify(instance.getVector(), model).getProbability(this.trueCategory);
                entries.add(new ResultEntry(positive, confidence));
            }
            return new RocCurves(entries);
        }

        @Override
        public String getCsvHeader(RocCurves rocCurves) {
            return "AUC";
        }

        @Override
        public String getCsvLine(RocCurves rocCurves) {
            return String.valueOf(rocCurves.getAreaUnderCurve());
        }
    }

    /**
     * Creates curves over a sorted copy of the given entries.
     *
     * @param entries the scored entries; the list is copied, not retained.
     */
    RocCurves(List<ResultEntry> entries) {
        Objects.requireNonNull(entries, "entries must not be null");
        this.results = new ArrayList<>(entries);
        Collections.sort(this.results);
        int positiveCount = 0;
        for (ResultEntry entry : this.results) {
            if (entry.trueCategory) {
                positiveCount++;
            }
        }
        this.positives = positiveCount;
        this.negatives = this.results.size() - positiveCount;
    }

    /**
     * Builds the confusion matrix obtained when every entry with
     * {@code confidence >= threshold} is predicted positive. Real and predicted labels are
     * the strings {@code "true"} / {@code "false"}.
     *
     * @param threshold the decision threshold.
     * @return a new confusion matrix.
     */
    public ConfusionMatrix getConfusionMatrix(double threshold) {
        ConfusionMatrix confusionMatrix = new ConfusionMatrix();
        for (ResultEntry entry : this.results) {
            confusionMatrix.add(Boolean.toString(entry.trueCategory), Boolean.toString(entry.confidence >= threshold));
        }
        return confusionMatrix;
    }

    /**
     * Iterates the ROC points, one per entry, in descending-confidence order. Tied
     * confidences yield several points with the same threshold.
     */
    @Override
    public Iterator<EvaluationPoint> iterator() {
        return new AbstractIterator2<EvaluationPoint>() {
            private final Iterator<ResultEntry> entries = RocCurves.this.results.iterator();
            private int truePositives = 0;
            // Initially everything is predicted negative, so all negatives are true negatives.
            private int trueNegatives = RocCurves.this.negatives;

            @Override
            protected EvaluationPoint getNext() {
                if (!this.entries.hasNext()) {
                    return (EvaluationPoint) finished();
                }
                ResultEntry entry = this.entries.next();
                if (entry.trueCategory) {
                    this.truePositives++;
                } else {
                    this.trueNegatives--;
                }
                // Floating-point division: integer division would collapse the rates to 0 or 1.
                double sensitivity = (double) this.truePositives / RocCurves.this.positives;
                double specificity = (double) this.trueNegatives / RocCurves.this.negatives;
                return new EvaluationPoint(sensitivity, specificity, entry.confidence);
            }
        };
    }

    /**
     * Computes the area under the ROC curve using the trapezoidal rule over specificity.
     * Only the last point of each group of tied confidences is an achievable operating point.
     * Integrating over those group endpoints, starting from (sensitivity 0, specificity 1),
     * makes ties contribute a diagonal segment, independent of input order.
     *
     * @return the AUC in [0, 1]; 0 when there are no entries; NaN when only one class is present.
     */
    public double getAreaUnderCurve() {
        double doubledArea = 0.0;
        double lastSensitivity = 0.0;
        double lastSpecificity = 1.0;
        EvaluationPoint pending = null;
        for (EvaluationPoint point : this) {
            if (pending != null && Double.compare(point.threshold, pending.threshold) != 0) {
                // pending closed its tie group: it is a real vertex of the curve.
                doubledArea += (lastSpecificity - pending.specificity) * (pending.sensitivity + lastSensitivity);
                lastSensitivity = pending.sensitivity;
                lastSpecificity = pending.specificity;
            }
            pending = point;
        }
        if (pending != null) {
            doubledArea += (lastSpecificity - pending.specificity) * (pending.sensitivity + lastSensitivity);
        }
        return doubledArea / 2.0;
    }

    /** @deprecated use {@link #show()} instead. */
    @Deprecated
    public void showCurves() {
        new RocCurvesPainter().add(this, "ROC").showCurves();
    }

    /** @deprecated use {@link #save(File)} instead. */
    @Deprecated
    public void saveCurves(File file) throws IOException {
        new RocCurvesPainter().add(this, "ROC").saveCurves(file);
    }

    @Override
    public void show() {
        showCurves();
    }

    @Override
    public void save(File file) throws IOException {
        saveCurves(file);
    }

    /**
     * Writes one line per entry, in descending-confidence order, in the form
     * {@code <trueCategory><separator><confidence>}.
     *
     * @param printStream target stream, not {@code null}.
     * @param separator column separator character.
     */
    public void writeEntries(PrintStream printStream, char separator) {
        Objects.requireNonNull(printStream, "stream must not be null");
        for (ResultEntry entry : this.results) {
            printStream.println(String.valueOf(entry.trueCategory) + separator + entry.confidence);
        }
    }

    /**
     * Scans the thresholds {@code 0, 1/steps, 2/steps, ..., (steps-1)/steps} and returns the
     * one with the highest Matthews correlation coefficient. The first threshold wins on ties.
     * Thresholds whose MCC is NaN are skipped.
     *
     * @param steps number of thresholds to probe; must be positive.
     * @return the best MCC and its threshold. If every MCC is NaN, the MCC is
     *         {@code -Infinity} and the threshold is 0.
     * @throws IllegalArgumentException if {@code steps <= 0}.
     */
    public MccAtThreshold determineBestMCC(int steps) {
        Validate.isTrue(steps > 0, "steps must be positive, was %s", steps);
        double bestThreshold = 0.0d;
        // MCC ranges over [-1, 1]; Double.MIN_VALUE is positive and would hide non-positive optima.
        double bestMcc = Double.NEGATIVE_INFINITY;
        for (int step = 0; step < steps; step++) {
            double threshold = (double) step / steps;
            double mcc = getConfusionMatrix(threshold).getMatthewsCorrelationCoefficient();
            if (mcc > bestMcc) {
                bestMcc = mcc;
                bestThreshold = threshold;
            }
        }
        return new MccAtThreshold(bestMcc, bestThreshold);
    }
}
