package ws.palladian.classification.evaluation.reliability;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import ws.palladian.classification.evaluation.AbstractClassificationEvaluator;
import ws.palladian.classification.evaluation.Graph;
import ws.palladian.classification.evaluation.LogLossEvaluator;
import ws.palladian.core.Classifier;
import ws.palladian.core.Instance;
import ws.palladian.core.Model;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.helper.collection.Bag;
import ws.palladian.helper.collection.LazyMap;
import ws.palladian.helper.math.SlimStats;
import ws.palladian.helper.math.Stats;

/* loaded from: input_file:ws/palladian/classification/evaluation/reliability/ReliabilityDiagramEvaluator.class */
public class ReliabilityDiagramEvaluator extends AbstractClassificationEvaluator<ReliabilityDiagram> {
    private final String trueClass;
    private final int numBins;

    /* loaded from: input_file:ws/palladian/classification/evaluation/reliability/ReliabilityDiagramEvaluator$DataPoint.class */
    /**
     * One bin of a reliability diagram: the mean predicted probability of the items that fell into
     * the bin, together with how many items the bin holds and how many of them are positive.
     */
    public static final class DataPoint {
        /** Mean of the predicted probabilities assigned to this bin. */
        final double mean;
        /** Total number of items in this bin. */
        final int numItems;
        /** Number of items in this bin whose actual category is the positive class. */
        final int numPositiveItems;

        DataPoint(double d, int i, int i2) {
            this.mean = d;
            this.numItems = i;
            this.numPositiveItems = i2;
        }

        /**
         * Returns the observed fraction of positive items in this bin.
         *
         * @return {@code numPositiveItems / numItems} as a floating-point ratio; {@code NaN} when
         *         the bin holds no items.
         */
        public double positiveFraction() {
            // Cast first: dividing two ints would truncate every ratio below 1 down to 0.
            return (double) this.numPositiveItems / this.numItems;
        }

        /**
         * Formats this point as {@code mean:positiveFraction}.
         */
        @Override
        public String toString() {
            return String.format("%f:%f", this.mean, positiveFraction());
        }
    }

    /* loaded from: input_file:ws/palladian/classification/evaluation/reliability/ReliabilityDiagramEvaluator$ReliabilityDiagram.class */
    /**
     * Evaluation result: the per-bin data points of a reliability diagram plus a log loss value.
     * Iterating over the diagram yields its data points in the order of the list it was built from.
     */
    public static class ReliabilityDiagram implements Graph, Iterable<DataPoint> {
        /** Title under which this diagram is handed to the painter. */
        private static final String CURVE_TITLE = "Reliability";

        /** Data points backing this diagram; the list is stored as given, without copying. */
        private final List<DataPoint> dataPoints;

        /** Log loss value supplied when the diagram was created. */
        public final double logLoss;

        ReliabilityDiagram(List<DataPoint> points, double logLossValue) {
            dataPoints = points;
            logLoss = logLossValue;
        }

        /**
         * Displays this diagram via a fresh {@link ReliabilityDiagramPainter}.
         */
        @Override // ws.palladian.classification.evaluation.Graph
        public void show() {
            ReliabilityDiagramPainter painter = new ReliabilityDiagramPainter();
            painter.add(this, CURVE_TITLE).showCurves();
        }

        /**
         * Writes this diagram to the given file via a fresh {@link ReliabilityDiagramPainter}.
         *
         * @param file destination passed on to the painter.
         * @throws IOException if the painter fails to save the curves.
         */
        @Override // ws.palladian.classification.evaluation.Graph
        public void save(File file) throws IOException {
            ReliabilityDiagramPainter painter = new ReliabilityDiagramPainter();
            painter.add(this, CURVE_TITLE).saveCurves(file);
        }

        /**
         * Returns an iterator over the data points of this diagram.
         */
        @Override // java.lang.Iterable
        public Iterator<DataPoint> iterator() {
            return dataPoints.iterator();
        }
    }

    /**
     * Creates an evaluator that builds reliability diagrams for one class of interest.
     *
     * @param str name of the category treated as the positive ("true") class; must not be null.
     * @param i   number of bins the probability range is divided into; must be at least 1.
     * @throws NullPointerException     if {@code str} is null.
     * @throws IllegalArgumentException if {@code i} is smaller than 1.
     */
    public ReliabilityDiagramEvaluator(String str, int i) {
        // Fail fast: a null class name or a non-positive bin count would otherwise only surface
        // later as a failure during evaluation or as a silently empty/degenerate diagram.
        if (str == null) {
            throw new NullPointerException("trueClass must not be null");
        }
        if (i < 1) {
            throw new IllegalArgumentException("numBins must be greater than zero, was " + i);
        }
        this.trueClass = str;
        this.numBins = i;
    }

    /**
     * Classifies every instance of the dataset, groups the predicted probabilities for the
     * configured true class into bins, and returns the resulting reliability diagram together with
     * the mean of the per-instance {@link LogLossEvaluator#logLoss} values.
     *
     * <p>A prediction is assigned to bin {@code round(numBins * probability)}, so bin indices run
     * from {@code 0} to {@code numBins} inclusive. Bins that received no instance are omitted from
     * the diagram.
     *
     * @param classifier classifier used to produce the predictions.
     * @param m          model handed to the classifier for every instance.
     * @param dataset    instances to evaluate.
     * @return the diagram with one data point per non-empty bin, in ascending bin order; its log
     *         loss is {@code NaN} when the dataset contains no instances.
     */
    @Override // ws.palladian.classification.evaluation.ClassificationEvaluator
    public <M extends Model> ReliabilityDiagram evaluate(Classifier<M> classifier, M m, Dataset dataset) {
        Bag binCounts = new Bag();
        Bag binPositiveCounts = new Bag();
        LazyMap binProbabilities = new LazyMap(SlimStats.FACTORY);
        int numInstances = 0;
        double logLossSum = 0.0d;
        Iterator<Instance> instances = dataset.iterator2();
        while (instances.hasNext()) {
            Instance instance = instances.next();
            double probability = classifier.classify(instance.getVector(), m).getProbability(this.trueClass);
            boolean positive = instance.getCategory().equals(this.trueClass);
            // Rounding to the nearest bin boundary produces indices 0..numBins (numBins + 1 bins).
            Integer bin = Integer.valueOf((int) Math.round(this.numBins * probability));
            binCounts.add(bin);
            if (positive) {
                binPositiveCounts.add(bin);
            }
            ((Stats) binProbabilities.get(bin)).add(Double.valueOf(probability));
            logLossSum += LogLossEvaluator.logLoss(positive, probability);
            numInstances++;
        }
        double meanLogLoss = logLossSum / numInstances;
        List<DataPoint> dataPoints = new ArrayList<>();
        // Inclusive upper bound: index numBins holds the predictions closest to 1.0 and was
        // previously skipped, dropping the most confident predictions from the diagram.
        for (int index = 0; index <= this.numBins; index++) {
            Integer bin = Integer.valueOf(index);
            int count = binCounts.count(bin);
            if (count != 0) {
                double meanProbability = ((Stats) binProbabilities.get(bin)).getMean();
                dataPoints.add(new DataPoint(meanProbability, count, binPositiveCounts.count(bin)));
            }
        }
        return new ReliabilityDiagram(Collections.unmodifiableList(dataPoints), meanLogLoss);
    }

    // Type-erased counterpart of the generic evaluate(Classifier<M>, M, Dataset) above; it only
    // forwards its arguments to that method and returns the result as Object.
    // NOTE(review): the bridge/synthetic markers suggest this is a compiler-generated bridge
    // method reproduced by the decompiler rather than hand-written code — confirm whether it
    // should exist in source form, since the casts below look like decompiler output.
    @Override // ws.palladian.classification.evaluation.ClassificationEvaluator
    public /* bridge */ /* synthetic */ Object evaluate(Classifier classifier, Model model, Dataset dataset) {
        return evaluate((Classifier<Classifier>) classifier, (Classifier) model, dataset);
    }
}
