package ws.palladian.helper.math;

import java.text.NumberFormat;
import java.util.Collections;
import java.util.Iterator;
import java.util.Locale;
import org.apache.commons.lang3.StringUtils;
import ws.palladian.helper.collection.AbstractIterator2;
import ws.palladian.helper.collection.Bag;

/* loaded from: input_file:ws/palladian/helper/math/ThresholdAnalyzer.class */
/**
 * Evaluates a scored binary decision at different confidence thresholds.
 * <p>
 * Items are registered via {@link #add(boolean, double)} with a flag telling whether they are relevant and a
 * confidence value in [0,1]. Confidence values are discretized into the bins {@code 0..numBins} (see
 * {@link #getBin(double)}). For a given threshold, all items in the threshold's bin or above count as retrieved;
 * precision, recall, F1 and accuracy are derived from that. Iterating this object yields one
 * {@link ThresholdEntry} per bin, from the lowest to the highest bin that received any item.
 */
public class ThresholdAnalyzer implements Iterable<ThresholdAnalyzer.ThresholdEntry> {
    /** Number of bins; confidence values in [0,1] map to the bins {@code 0..numBins}. */
    private final int numBins;
    /** Per-bin counts of the items flagged as relevant. */
    private final Bag<Integer> relevantItems;
    /** Per-bin counts of all added items, relevant or not. */
    private final Bag<Integer> retrievedItems;

    /**
     * The evaluation measures at one threshold. A measure is {@code NaN} when its denominator is zero (e.g.
     * recall when no relevant item has been added).
     */
    public static final class ThresholdEntry {
        /** Length of the F1 bar (in {@code '*'} characters) for an F1 of 1.0. */
        private static final int F1_BAR_LENGTH = 50;
        private final double t;
        private final double pr;
        private final double rc;
        private final double accuracy;

        ThresholdEntry(double threshold, double precision, double recall, double accuracy) {
            this.t = threshold;
            this.pr = precision;
            this.rc = recall;
            this.accuracy = accuracy;
        }

        /** @return the threshold this entry was computed for. */
        public double getThreshold() {
            return this.t;
        }

        /** @return relevant items at/above the threshold divided by all items at/above the threshold. */
        public double getPrecision() {
            return this.pr;
        }

        /** @return relevant items at/above the threshold divided by all relevant items. */
        public double getRecall() {
            return this.rc;
        }

        /** @return the harmonic mean of precision and recall; {@code NaN} if both are zero or either is NaN. */
        public double getF1() {
            return 2.0d * this.pr * this.rc / (this.pr + this.rc);
        }

        /**
         * @return relevant items at/above the threshold plus irrelevant items below it, divided by all items.
         */
        public double getAccuracy() {
            return this.accuracy;
        }

        @Override
        public String toString() {
            return internalToString("threshold=%s: pr=%s, rc=%s, f1=%s, acc=%s");
        }

        /**
         * Formats this entry with the given {@link String#format(String, Object...)} pattern. The pattern receives
         * six string arguments in this order: threshold, precision, recall, F1, accuracy, and a bar of {@code '*'}
         * characters proportional to F1. Numbers use US formatting with at most five fraction digits; NaN is
         * rendered as {@code "NaN"}. Intended for use by {@link ThresholdAnalyzer#toString()}.
         *
         * @param pattern the format pattern.
         * @return the formatted entry.
         */
        public String internalToString(String pattern) {
            NumberFormat format = NumberFormat.getNumberInstance(Locale.US);
            format.setMaximumFractionDigits(5);
            double f1 = getF1();
            return String.format(pattern, formatNumber(format, this.t), formatNumber(format, this.pr),
                    formatNumber(format, this.rc), formatNumber(format, f1), formatNumber(format, this.accuracy),
                    makeBar(f1));
        }

        /** Formats {@code value}, rendering NaN as {@code "NaN"}. */
        private static String formatNumber(NumberFormat format, double value) {
            // NumberFormat's NaN symbol depends on the JDK's locale data (it may be U+FFFD), so NaN is handled here
            // instead of post-processing the formatted string
            return Double.isNaN(value) ? "NaN" : format.format(value);
        }

        /** Builds a bar of {@code '*'} characters whose length is proportional to {@code f1}. */
        private static String makeBar(double f1) {
            return StringUtils.repeat('*', (int) Math.round(F1_BAR_LENGTH * f1));
        }
    }

    /** Creates an analyzer with {@code numBins = 5}. */
    public ThresholdAnalyzer() {
        this(5);
    }

    /**
     * Creates an analyzer with the given number of bins.
     *
     * @param numBins number of bins used to discretize confidence values; must be at least two.
     * @throws IllegalArgumentException if {@code numBins} is less than two.
     */
    public ThresholdAnalyzer(int numBins) {
        if (numBins < 2) {
            throw new IllegalArgumentException("numBins must be at least two, was " + numBins);
        }
        this.numBins = numBins;
        this.retrievedItems = new Bag<>();
        this.relevantItems = new Bag<>();
    }

    /**
     * @deprecated use {@code getEntry(threshold).getPrecision()} instead.
     */
    @Deprecated
    public double getPrecision(double threshold) {
        return getEntry(threshold).getPrecision();
    }

    /**
     * @deprecated use {@code getEntry(threshold).getRecall()} instead.
     */
    @Deprecated
    public double getRecall(double threshold) {
        return getEntry(threshold).getRecall();
    }

    /**
     * @deprecated use {@code getEntry(threshold).getF1()} instead.
     */
    @Deprecated
    public double getF1(double threshold) {
        return getEntry(threshold).getF1();
    }

    /**
     * Iterates over the entries for the thresholds {@code bin / numBins}, for every bin from the lowest to the
     * highest bin that received at least one item. The iterator is empty if no item has been added.
     */
    @Override
    public Iterator<ThresholdEntry> iterator() {
        if (this.retrievedItems.size() == 0) {
            // Collections.min/max would throw NoSuchElementException on an empty bag
            return Collections.emptyIterator();
        }
        final int firstBin = ((Integer) Collections.min(this.retrievedItems.uniqueItems())).intValue();
        final int lastBin = ((Integer) Collections.max(this.retrievedItems.uniqueItems())).intValue();
        return new AbstractIterator2<ThresholdEntry>() {
            int bin = firstBin;

            @Override
            public ThresholdEntry getNext() {
                // lastBin <= numBins (guaranteed by getBin), so every threshold produced here lies in [0,1]
                if (this.bin > lastBin) {
                    return finished();
                }
                double threshold = (double) this.bin / ThresholdAnalyzer.this.numBins;
                this.bin++;
                return getEntry(threshold);
            }
        };
    }

    /**
     * Computes the evaluation measures at the given threshold. Counting is done on the threshold's bin (see
     * {@link #getBin(double)}), while the returned entry reports {@code threshold} as passed.
     *
     * @param threshold the threshold in [0,1].
     * @return the entry; measures with a zero denominator are {@code NaN}.
     * @throws IllegalArgumentException if {@code threshold} is outside [0,1] or NaN.
     */
    public ThresholdEntry getEntry(double threshold) {
        int numRelevantAt = getNumRelevantAt(threshold);
        int numIrrelevantBelow = getNumIrrelevantBelow(threshold);
        // floating-point division: int/int would truncate every measure to 0 or 1 and throw on a zero denominator
        double precision = (double) numRelevantAt / getRetrievedAt(threshold);
        double recall = (double) numRelevantAt / this.relevantItems.size();
        double accuracy = (double) (numRelevantAt + numIrrelevantBelow) / this.retrievedItems.size();
        return new ThresholdEntry(threshold, precision, recall, accuracy);
    }

    /**
     * @return the highest F1 over all iterated entries, or {@code 0} if there are none.
     * @deprecated use {@link #getMaxF1Entry()} instead.
     */
    @Deprecated
    public double getMaxF1() {
        ThresholdEntry maxF1Entry = getMaxF1Entry();
        return maxF1Entry != null ? maxF1Entry.getF1() : 0.0d;
    }

    /**
     * @return the iterated entry with the highest F1 (the first one on ties), or {@code null} if no item has been
     *         added.
     */
    public ThresholdEntry getMaxF1Entry() {
        ThresholdEntry best = null;
        for (ThresholdEntry entry : this) {
            if (best == null || best.getF1() < entry.getF1()) {
                best = entry;
            }
        }
        return best;
    }

    /**
     * Registers one item.
     *
     * @param relevant whether the item is relevant.
     * @param confidence the item's confidence value in [0,1].
     * @throws IllegalArgumentException if {@code confidence} is outside [0,1] or NaN.
     */
    public void add(boolean relevant, double confidence) {
        int bin = getBin(confidence);
        if (relevant) {
            this.relevantItems.add(Integer.valueOf(bin));
        }
        this.retrievedItems.add(Integer.valueOf(bin));
    }

    /**
     * Maps a value in [0,1] to its bin {@code round(value * numBins)}, i.e. one of {@code 0..numBins}.
     *
     * @throws IllegalArgumentException if {@code value} is outside [0,1] or NaN.
     */
    int getBin(double value) {
        // negated range check so that NaN (for which every comparison is false) is rejected too
        if (!(value >= 0.0d && value <= 1.0d)) {
            throw new IllegalArgumentException("Threshold must be in range [0,1], but was " + value);
        }
        return (int) Math.round(value * this.numBins);
    }

    /** @return the number of items (relevant or not) in the threshold's bin or above. */
    int getRetrievedAt(double threshold) {
        return countFromBin(this.retrievedItems, getBin(threshold));
    }

    /** @return the number of relevant items in the threshold's bin or above. */
    int getNumRelevantAt(double threshold) {
        return countFromBin(this.relevantItems, getBin(threshold));
    }

    /** @return the number of irrelevant items in bins strictly below the threshold's bin. */
    int getNumIrrelevantBelow(double threshold) {
        int upperBin = getBin(threshold);
        int count = 0;
        for (int bin = 0; bin < upperBin; bin++) {
            count += this.retrievedItems.count(Integer.valueOf(bin)) - this.relevantItems.count(Integer.valueOf(bin));
        }
        return count;
    }

    /** Sums the counts of {@code bag} over the bins {@code fromBin..numBins}. */
    private int countFromBin(Bag<Integer> bag, int fromBin) {
        int count = 0;
        for (int bin = fromBin; bin <= this.numBins; bin++) {
            count += bag.count(Integer.valueOf(bin));
        }
        return count;
    }

    /**
     * Renders a tab-separated table with one row per iterated entry, followed by the maximum F1 and its threshold
     * (omitted when no item has been added).
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("t\tPr\tRc\tF1\tAcc\tF1\n");
        for (ThresholdEntry entry : this) {
            sb.append(entry.internalToString("%s\t%s\t%s\t%s\t%s\t%s\n"));
        }
        ThresholdEntry maxF1Entry = getMaxF1Entry();
        if (maxF1Entry != null) {
            NumberFormat format = NumberFormat.getNumberInstance(Locale.US);
            sb.append('\n').append("Max. F1=").append(ThresholdEntry.formatNumber(format, maxF1Entry.getF1()))
                    .append("@t=").append(ThresholdEntry.formatNumber(format, maxF1Entry.getThreshold()));
        }
        return sb.toString();
    }
}
