package ws.palladian.helper.math;

import java.nio.CharBuffer;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Set;
import ws.palladian.helper.collection.CountMatrix;
import ws.palladian.helper.io.FileHelper;

/* loaded from: input_file:ws/palladian/helper/math/ConfusionMatrix.class */
/**
 * Confusion matrix for evaluating a classifier: it counts how often a document of a given real category was assigned
 * a given (possibly different) category, and derives per-category and aggregate evaluation measures from those
 * counts.
 * <p>
 * The backing {@link CountMatrix} is keyed as (assigned category, real category). Column sums are read as "classified
 * documents" and row sums as "real documents".
 */
public class ConfusionMatrix {

    /** Counts keyed as (assigned category, real category). */
    private final CountMatrix<String> confusionMatrix = CountMatrix.create();

    /**
     * Records a single classification.
     *
     * @param realCategory The true category of the document.
     * @param assignedCategory The category the classifier assigned to the document.
     */
    public void add(String realCategory, String assignedCategory) {
        add(realCategory, assignedCategory, 1);
    }

    /**
     * Records {@code count} classifications of documents with the given real category as the given assigned category.
     *
     * @param realCategory The true category of the documents.
     * @param assignedCategory The category the classifier assigned to the documents.
     * @param count The number of such classifications to add.
     */
    public void add(String realCategory, String assignedCategory, int count) {
        // Note the swapped order: the matrix is keyed (assigned, real).
        confusionMatrix.add(assignedCategory, realCategory, count);
    }

    /**
     * @return The fraction of all documents whose assigned category equals their real category; {@code NaN} if no
     *         documents were added.
     */
    public double getAccuracy() {
        // Cast before dividing; integer division would truncate the ratio to 0 or 1.
        return (double) getTotalCorrect() / getTotalDocuments();
    }

    /** @return The number of documents whose assigned category equals their real category, over all categories. */
    public int getTotalCorrect() {
        int totalCorrect = 0;
        for (String category : getCategories()) {
            totalCorrect += getCorrectlyClassifiedDocuments(category);
        }
        return totalCorrect;
    }

    /**
     * @param category The category.
     * @return The number of documents of the given real category that were also assigned that category.
     */
    public int getCorrectlyClassifiedDocuments(String category) {
        return confusionMatrix.getCount(category, category);
    }

    /**
     * @param category The category.
     * @return The number of documents that were assigned the given category (column sum).
     */
    public int getClassifiedDocuments(String category) {
        return confusionMatrix.getColumn(category).getSum();
    }

    /**
     * @param category The category.
     * @return The number of documents whose real category is the given category (row sum).
     */
    public int getRealDocuments(String category) {
        return confusionMatrix.getRow(category).getSum();
    }

    /**
     * @param realCategory The true category.
     * @param assignedCategory The assigned category.
     * @return The number of documents of {@code realCategory} that were assigned {@code assignedCategory}.
     */
    public int getConfusions(String realCategory, String assignedCategory) {
        return confusionMatrix.getCount(assignedCategory, realCategory);
    }

    /**
     * @return The row keys of the matrix, i.e. the real categories that have been added. A category that was only
     *         ever assigned, but never occurred as a real category, is not included.
     */
    public Set<String> getCategories() {
        return confusionMatrix.getRowKeys();
    }

    /** @return The total number of documents, summed over the real categories. */
    public int getTotalDocuments() {
        int totalDocuments = 0;
        for (String category : getCategories()) {
            totalDocuments += getRealDocuments(category);
        }
        return totalDocuments;
    }

    /**
     * @return The largest prior of any real category (the share of documents in the most frequent real category), or
     *         0 if no documents were added.
     */
    public double getHighestPrior() {
        // Priors are computed from real categories (rows), so iterate over the row keys. Iterating over the assigned
        // categories would miss real categories that were never predicted.
        int maxRealDocuments = 0;
        for (String category : getCategories()) {
            maxRealDocuments = Math.max(maxRealDocuments, getRealDocuments(category));
        }
        int totalDocuments = getTotalDocuments();
        if (totalDocuments == 0) {
            return 0.0;
        }
        return (double) maxRealDocuments / totalDocuments;
    }

    /**
     * @return The accuracy divided by the highest prior, i.e. relative to the accuracy of always predicting the most
     *         frequent real category. {@code NaN} or infinite if the highest prior is 0.
     */
    public double getSuperiority() {
        return getAccuracy() / getHighestPrior();
    }

    /**
     * @param category The category.
     * @return Correctly classified / classified documents for the category; {@code NaN} if no document was assigned
     *         the category.
     */
    public double getPrecision(String category) {
        int correct = getCorrectlyClassifiedDocuments(category);
        int classified = getClassifiedDocuments(category);
        if (classified == 0) {
            return Double.NaN;
        }
        return (double) correct / classified;
    }

    /**
     * @param category The category.
     * @return Correctly classified / real documents for the category; 1.0 if the category has no real documents.
     */
    public double getRecall(String category) {
        int correct = getCorrectlyClassifiedDocuments(category);
        int real = getRealDocuments(category);
        if (real == 0) {
            return 1.0;
        }
        return (double) correct / real;
    }

    /**
     * Calculates the F-beta measure {@code (1 + beta^2) * P * R / (beta^2 * P + R)} for a category.
     *
     * @param beta The beta weight; 1.0 yields F1, values below 1 emphasize precision.
     * @param category The category.
     * @return The F-beta measure; {@code NaN} if precision is undefined, and 0 if precision and recall are both 0.
     */
    public double getF(double beta, String category) {
        double precision = getPrecision(category);
        if (Double.isNaN(precision)) {
            return Double.NaN;
        }
        double recall = getRecall(category);
        if (precision + recall == 0) {
            // The formula would be 0/0 = NaN, and the averaging methods would then silently skip this category.
            // By convention, F is 0 when nothing was recalled correctly.
            return 0.0;
        }
        double betaSquared = beta * beta;
        return (1.0 + betaSquared) * ((precision * recall) / ((betaSquared * precision) + recall));
    }

    /**
     * @param category The category.
     * @return TP / (TP + FN) for the category; {@code NaN} if the category has no real documents.
     */
    public double getSensitivity(String category) {
        int truePositives = getCorrectlyClassifiedDocuments(category);
        int falseNegatives = getRealDocuments(category) - truePositives;
        int denominator = truePositives + falseNegatives;
        if (denominator == 0) {
            return Double.NaN;
        }
        return (double) truePositives / denominator;
    }

    /**
     * @param category The category.
     * @return TN / (TN + FP) for the category; {@code NaN} if the denominator is 0.
     */
    public double getSpecificity(String category) {
        int truePositives = getCorrectlyClassifiedDocuments(category);
        int realDocuments = getRealDocuments(category);
        int classifiedDocuments = getClassifiedDocuments(category);
        int falsePositives = classifiedDocuments - truePositives;
        int falseNegatives = realDocuments - truePositives;
        int trueNegatives = getTotalDocuments() - classifiedDocuments - falseNegatives;
        int denominator = trueNegatives + falsePositives;
        if (denominator == 0) {
            return Double.NaN;
        }
        return (double) trueNegatives / denominator;
    }

    /**
     * @param category The category.
     * @return (TP + TN) / (TP + TN + FP + FN), treating the category as the positive class; {@code NaN} if the
     *         denominator is 0.
     */
    public double getAccuracy(String category) {
        int truePositives = getCorrectlyClassifiedDocuments(category);
        int realDocuments = getRealDocuments(category);
        int classifiedDocuments = getClassifiedDocuments(category);
        int falsePositives = classifiedDocuments - truePositives;
        int falseNegatives = realDocuments - truePositives;
        int trueNegatives = getTotalDocuments() - classifiedDocuments - falseNegatives;
        int denominator = truePositives + trueNegatives + falsePositives + falseNegatives;
        if (denominator == 0) {
            return Double.NaN;
        }
        return (double) (truePositives + trueNegatives) / denominator;
    }

    /**
     * @param category The category.
     * @return The share of all documents whose real category is the given one; 0 if no documents were added.
     */
    public double getPrior(String category) {
        int realDocuments = getRealDocuments(category);
        int totalDocuments = getTotalDocuments();
        if (totalDocuments == 0) {
            return 0.0;
        }
        return (double) realDocuments / totalDocuments;
    }

    /**
     * Averages precision over all categories. Categories with undefined ({@code NaN}) precision add nothing to the
     * sum, but are still counted in the unweighted denominator.
     *
     * @param weighted {@code true} to weight each category by its prior (the weighted sum is returned as-is);
     *            {@code false} for the plain mean over all categories.
     * @return The average, or {@code NaN} for an unweighted average without categories.
     */
    public double getAveragePrecision(boolean weighted) {
        double sum = 0.0;
        for (String category : getCategories()) {
            double precision = getPrecision(category);
            if (!Double.isNaN(precision)) {
                sum += precision * weightOf(category, weighted);
            }
        }
        return toAverage(sum, weighted);
    }

    /**
     * Averages recall over all categories. {@code NaN} values are skipped in the sum.
     *
     * @param weighted {@code true} to weight by prior; {@code false} for the plain mean over all categories.
     * @return The average, or {@code NaN} for an unweighted average without categories.
     */
    public double getAverageRecall(boolean weighted) {
        double sum = 0.0;
        for (String category : getCategories()) {
            double recall = getRecall(category);
            if (!Double.isNaN(recall)) {
                sum += recall * weightOf(category, weighted);
            }
        }
        return toAverage(sum, weighted);
    }

    /**
     * Averages the F-beta measure over all categories. {@code NaN} values are skipped in the sum.
     *
     * @param beta The beta weight passed to {@link #getF(double, String)}.
     * @param weighted {@code true} to weight by prior; {@code false} for the plain mean over all categories.
     * @return The average, or {@code NaN} for an unweighted average without categories.
     */
    public double getAverageF(double beta, boolean weighted) {
        double sum = 0.0;
        for (String category : getCategories()) {
            double f = getF(beta, category);
            if (!Double.isNaN(f)) {
                sum += f * weightOf(category, weighted);
            }
        }
        return toAverage(sum, weighted);
    }

    /**
     * Averages sensitivity over all categories. {@code NaN} values are skipped in the sum.
     *
     * @param weighted {@code true} to weight by prior; {@code false} for the plain mean over all categories.
     * @return The average, or {@code NaN} for an unweighted average without categories.
     */
    public double getAverageSensitivity(boolean weighted) {
        double sum = 0.0;
        for (String category : getCategories()) {
            double sensitivity = getSensitivity(category);
            if (!Double.isNaN(sensitivity)) {
                sum += sensitivity * weightOf(category, weighted);
            }
        }
        return toAverage(sum, weighted);
    }

    /**
     * Averages specificity over all categories. Unlike most other averages, a single undefined specificity makes the
     * whole result {@code NaN}.
     *
     * @param weighted {@code true} to weight by prior; {@code false} for the plain mean over all categories.
     * @return The average, or {@code NaN}.
     */
    public double getAverageSpecificity(boolean weighted) {
        double sum = 0.0;
        for (String category : getCategories()) {
            double specificity = getSpecificity(category);
            if (Double.isNaN(specificity)) {
                return Double.NaN;
            }
            sum += specificity * weightOf(category, weighted);
        }
        return toAverage(sum, weighted);
    }

    /**
     * Averages the per-category accuracy over all categories. A single undefined value makes the result {@code NaN}.
     *
     * @param weighted {@code true} to weight by prior; {@code false} for the plain mean over all categories.
     * @return The average, or {@code NaN}.
     * @deprecated No replacement is named in this class; {@link #getAccuracy()} gives the overall accuracy.
     */
    @Deprecated
    public double getAverageAccuracy(boolean weighted) {
        double sum = 0.0;
        for (String category : getCategories()) {
            double accuracy = getAccuracy(category);
            if (Double.isNaN(accuracy)) {
                return Double.NaN;
            }
            sum += accuracy * weightOf(category, weighted);
        }
        return toAverage(sum, weighted);
    }

    /** Weight of a category's value in an average: its prior when weighted, otherwise 1. */
    private double weightOf(String category, boolean weighted) {
        return weighted ? getPrior(category) : 1.0;
    }

    /** Turns a per-category sum into the reported average: weighted sums as-is, unweighted divided by #categories. */
    private double toAverage(double sum, boolean weighted) {
        if (weighted) {
            return sum;
        }
        int numCategories = getCategories().size();
        if (numCategories == 0) {
            return Double.NaN;
        }
        return sum / numCategories;
    }

    /**
     * Calculates the Matthews correlation coefficient, treating the first category (in {@link #getCategories()}
     * iteration order) as the positive class.
     *
     * @return The MCC.
     * @throws IllegalStateException If there are not exactly two categories.
     */
    public double getMatthewsCorrelationCoefficient() {
        if (getCategories().size() != 2) {
            throw new IllegalStateException("Matthews correlation coefficient only works for binary classifications");
        }
        Iterator<String> iterator = getCategories().iterator();
        String positive = iterator.next();
        String negative = iterator.next();
        int truePositives = getConfusions(positive, positive);
        int trueNegatives = getConfusions(negative, negative);
        int falsePositives = getConfusions(negative, positive);
        int falseNegatives = getConfusions(positive, negative);
        return calculateMatthewsCorrelationCoefficient(truePositives, trueNegatives, falsePositives, falseNegatives);
    }

    /**
     * Computes {@code (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))}.
     *
     * @param truePositives TP.
     * @param trueNegatives TN.
     * @param falsePositives FP.
     * @param falseNegatives FN.
     * @return The MCC, or 0 if the denominator is 0.
     */
    public static double calculateMatthewsCorrelationCoefficient(int truePositives, int trueNegatives,
            int falsePositives, int falseNegatives) {
        // Widen before adding/multiplying; int arithmetic overflows for large counts.
        double denominator = Math.sqrt((long) truePositives + falsePositives)
                * Math.sqrt((long) truePositives + falseNegatives) * Math.sqrt((long) trueNegatives + falsePositives)
                * Math.sqrt((long) trueNegatives + falseNegatives);
        if (denominator == 0.0) {
            return 0.0;
        }
        long numerator = (long) truePositives * trueNegatives - (long) falsePositives * falseNegatives;
        return numerator / denominator;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("Confusion Matrix:\n\n");
        ArrayList<String> categories = new ArrayList<String>(getCategories());

        // Header line listing the assigned categories; also determine the widest label for alignment.
        StringBuilder header = new StringBuilder();
        int maxLength = 0;
        for (String category : categories) {
            header.append(category).append(" ");
            maxLength = Math.max(maxLength, category.length());
        }
        String indent = spaces(maxLength);
        sb.append(indent).append("\t").append("classified as:\n");
        sb.append(indent).append("\t").append((CharSequence) header);
        sb.append(FileHelper.NEWLINE_CHARACTER);

        // One row per real category; each count is roughly centered under its column label.
        for (String realCategory : categories) {
            sb.append(realCategory);
            sb.append(spaces(maxLength - realCategory.length()));
            sb.append("\t");
            for (String assignedCategory : categories) {
                Integer stored = confusionMatrix.get(assignedCategory, realCategory);
                String count = String.valueOf(stored == null ? 0 : stored.intValue());
                int freeSpace = assignedCategory.length() - count.length();
                int leftPadding = Math.max((int) Math.ceil((freeSpace + 1) / 2.0), 0);
                sb.append(spaces(leftPadding));
                sb.append(count);
                sb.append(spaces(Math.max(freeSpace - leftPadding, 1)));
            }
            sb.append(FileHelper.NEWLINE_CHARACTER);
        }
        sb.append(FileHelper.NEWLINE_CHARACTER);

        // Per-category statistics table.
        sb.append(indent).append("  ").append("prior  precision recall f1-measure accuracy\n");
        for (String category : categories) {
            sb.append(category).append(": ");
            sb.append(spaces(maxLength - category.length()));
            double prior = MathHelper.round(getPrior(category), 4);
            double precision = MathHelper.round(getPrecision(category), 4);
            double recall = MathHelper.round(getRecall(category), 4);
            double accuracy = MathHelper.round(getAccuracy(category), 4);
            double f1 = MathHelper.round(getF(1.0, category), 4);
            sb.append(prior).append(columnPadding(prior, "prior  "));
            sb.append(precision).append(columnPadding(precision, "precision "));
            sb.append(recall).append(columnPadding(recall, "recall "));
            sb.append(f1).append(columnPadding(f1, "f1-measure "));
            sb.append(accuracy);
            sb.append(FileHelper.NEWLINE_CHARACTER);
        }
        sb.append(FileHelper.NEWLINE_CHARACTER);

        // Aggregate statistics.
        sb.append("Accuracy:\t").append(MathHelper.round(getAccuracy(), 4)).append('\n');
        sb.append("Highest Prior:\t").append(MathHelper.round(getHighestPrior(), 4)).append('\n');
        sb.append("Superiority:\t").append(MathHelper.round(getSuperiority(), 4)).append('\n');
        if (getCategories().size() == 2) {
            sb.append("Matthews Correlation Coefficient:\t").append(MathHelper.round(getMatthewsCorrelationCoefficient(), 4)).append('\n');
        }
        sb.append("# Documents:\t").append(getTotalDocuments()).append('\n');
        sb.append("# Correctly Classified:\t").append(getTotalCorrect()).append('\n');
        sb.append(FileHelper.NEWLINE_CHARACTER);
        sb.append("Average Precision:\t").append(MathHelper.round(getAveragePrecision(true), 4)).append('\n');
        sb.append("Average Recall:\t").append(MathHelper.round(getAverageRecall(true), 4)).append('\n');
        // beta = 1.0 so the value matches the "F1" label (getF is parameterized by beta; 0.5 would give F0.5).
        sb.append("Average F1:\t").append(MathHelper.round(getAverageF(1.0, true), 4)).append('\n');
        sb.append("Average Sensitivity:\t").append(MathHelper.round(getAverageSensitivity(true), 4)).append('\n');
        sb.append("Average Specificity:\t").append(MathHelper.round(getAverageSpecificity(true), 4)).append('\n');
        return sb.toString();
    }

    /** Returns a string of {@code count} spaces (empty for non-positive counts). */
    private static String spaces(int count) {
        return CharBuffer.allocate(Math.max(count, 0)).toString().replace((char) 0, ' ');
    }

    /** Padding that fills the rest of a statistics-table column after the given value has been printed. */
    private static String columnPadding(double value, String columnHeader) {
        return spaces(columnHeader.length() - String.valueOf(value).length());
    }
}
