package edu.umass.cs.mallet.base.classify;

import edu.umass.cs.mallet.base.fst.Transducer;
import edu.umass.cs.mallet.base.pipe.Pipe;
import edu.umass.cs.mallet.base.types.Alphabet;
import edu.umass.cs.mallet.base.types.Instance;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.types.Label;
import edu.umass.cs.mallet.base.types.LabelAlphabet;
import edu.umass.cs.mallet.base.types.Labeling;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.logging.Logger;

/* loaded from: input_file:WEB-INF/lib/mallet-0.1.3.jar:edu/umass/cs/mallet/base/classify/Classifier.class */
/**
 * Abstract base for classifiers that label {@link Instance}s.
 *
 * <p>Subclasses implement {@link #classify(Instance)}; this class supplies the
 * batch variants and a set of evaluation helpers (accuracy, precision,
 * recall, F1, average rank of the target label) that operate on lists of
 * {@link Classification} results.
 *
 * <p>The synthetic {@code $assertionsDisabled} / {@code class$...} members
 * are kept as found in this file so that the class layout does not change.
 */
public abstract class Classifier {
    private static final Logger logger;
    /** Pipe whose data/target alphabets are exposed through the accessors below. */
    protected Pipe instancePipe;
    /** True when assertions are disabled for this class; set in the static initializer. */
    static final boolean $assertionsDisabled;
    /** Lazily filled cache used by the static initializer. */
    static Class class$edu$umass$cs$mallet$base$classify$Classifier;
    /** Lazily filled cache used by the assertion in {@link #Classifier(Pipe)}. */
    static Class class$edu$umass$cs$mallet$base$types$LabelAlphabet;

    /**
     * Creates a classifier with no instance pipe set.
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public Classifier() {
    }

    /**
     * Creates a classifier bound to the given pipe.
     *
     * <p>When assertions are enabled for this class, an {@link AssertionError}
     * is thrown if the pipe has no target alphabet, or if the target
     * alphabet's class fails the {@code isAssignableFrom(LabelAlphabet)} check.
     *
     * @param pipe the pipe stored in {@link #instancePipe}
     */
    public Classifier(Pipe pipe) {
        this.instancePipe = pipe;
        if ($assertionsDisabled) {
            return;
        }
        if (pipe.getTargetAlphabet() == null) {
            throw new AssertionError();
        }
        Class<?> targetAlphabetClass = pipe.getTargetAlphabet().getClass();
        Class<?> labelAlphabetClass = class$edu$umass$cs$mallet$base$types$LabelAlphabet;
        if (labelAlphabetClass == null) {
            labelAlphabetClass = class$("edu.umass.cs.mallet.base.types.LabelAlphabet");
            class$edu$umass$cs$mallet$base$types$LabelAlphabet = labelAlphabetClass;
        }
        // NOTE(review): the direction of this check (target class is assignable FROM
        // LabelAlphabet) is kept exactly as found; confirm it is the intended one.
        if (!targetAlphabetClass.isAssignableFrom(labelAlphabetClass)) {
            throw new AssertionError();
        }
    }

    /**
     * @return the pipe this classifier was constructed with (may be null for
     *         the no-argument constructor)
     */
    public Pipe getInstancePipe() {
        return this.instancePipe;
    }

    /**
     * @return the data alphabet of the instance pipe
     */
    public Alphabet getAlphabet() {
        return this.instancePipe.getDataAlphabet();
    }

    /**
     * @return the target alphabet of the instance pipe, cast to {@link LabelAlphabet}
     * @throws ClassCastException if the target alphabet is not a {@link LabelAlphabet}
     */
    public LabelAlphabet getLabelAlphabet() {
        return (LabelAlphabet) this.instancePipe.getTargetAlphabet();
    }

    /**
     * Classifies every instance of the list, in iteration order.
     *
     * @param instanceList the instances to classify
     * @return a list of {@link Classification} objects, one per instance
     */
    public ArrayList classify(InstanceList instanceList) {
        ArrayList results = new ArrayList(instanceList.size());
        InstanceList.Iterator it = instanceList.iterator();
        while (it.hasNext()) {
            results.add(classify(it.nextInstance()));
        }
        return results;
    }

    /**
     * Classifies every instance of the array.
     *
     * @param instanceArr the instances to classify
     * @return an array of the same length holding the result for each position
     */
    public Classification[] classify(Instance[] instanceArr) {
        Classification[] results = new Classification[instanceArr.length];
        for (int i = 0; i < instanceArr.length; i++) {
            results[i] = classify(instanceArr[i]);
        }
        return results;
    }

    /**
     * Classifies a single instance.
     *
     * @param instance the instance to classify
     * @return the classification produced by the concrete classifier
     */
    public abstract Classification classify(Instance instance);

    /**
     * Classifies an arbitrary object. An {@link Instance} is classified
     * directly; anything else is first wrapped in a new {@link Instance}
     * built with this classifier's pipe.
     *
     * @param obj an {@link Instance} or raw data for a new instance
     * @return the classification of the (possibly newly created) instance
     */
    public Classification classify(Object obj) {
        if (obj instanceof Instance) {
            return classify((Instance) obj);
        }
        return classify(new Instance(obj, null, null, null, this.instancePipe));
    }

    /**
     * Classifies the list and returns the accuracy of the results.
     *
     * @param instanceList the instances to classify and score
     * @return see {@link #getAccuracy(ArrayList)}
     */
    public double getAccuracy(InstanceList instanceList) {
        return getAccuracy(classify(instanceList));
    }

    /**
     * Computes the fraction of classifications whose best label is correct.
     *
     * @param arrayList a list of {@link Classification} objects
     * @return correct / total as a floating-point ratio; {@code NaN} for an empty list
     */
    public double getAccuracy(ArrayList arrayList) {
        int numCorrect = 0;
        for (int i = 0; i < arrayList.size(); i++) {
            if (((Classification) arrayList.get(i)).bestLabelIsCorrect()) {
                numCorrect++;
            }
        }
        // Divide in double: int / int would truncate the ratio to 0 or 1.
        return ((double) numCorrect) / arrayList.size();
    }

    /**
     * Classifies the list and returns the precision for the given label entry.
     *
     * @param instanceList the instances to classify and score
     * @param obj the label entry, looked up in the label alphabet without adding it
     * @return see {@link #getPrecision(ArrayList, int)}
     */
    public double getPrecision(InstanceList instanceList, Object obj) {
        return getPrecision(classify(instanceList), getLabelAlphabet().lookupIndex(obj, false));
    }

    /**
     * Returns the precision for the given label entry.
     *
     * @param arrayList a list of {@link Classification} objects
     * @param obj the label entry, looked up in the label alphabet without adding it
     * @return see {@link #getPrecision(ArrayList, int)}
     */
    public double getPrecision(ArrayList arrayList, Object obj) {
        return getPrecision(arrayList, getLabelAlphabet().lookupIndex(obj, false));
    }

    /**
     * Classifies the list and returns the precision for the given label index.
     *
     * @param instanceList the instances to classify and score
     * @param i the label index
     * @return see {@link #getPrecision(ArrayList, int)}
     */
    public double getPrecision(InstanceList instanceList, int i) {
        return getPrecision(classify(instanceList), i);
    }

    /**
     * Computes precision for label index {@code i}: among the classifications
     * whose best label is {@code i}, the fraction whose instance labeling also
     * has {@code i} as its best index.
     *
     * @param arrayList a list of {@link Classification} objects
     * @param i the label index
     * @return the precision as a floating-point ratio; {@code NaN} (after a
     *         logged warning) when no classification has best label {@code i}
     */
    public double getPrecision(ArrayList arrayList, int i) {
        int numCorrect = 0;
        int numPredicted = 0;
        for (int k = 0; k < arrayList.size(); k++) {
            Classification classification = (Classification) arrayList.get(k);
            int trueIndex = classification.getInstance().getLabeling().getBestIndex();
            int predictedIndex = classification.getLabeling().getBestIndex();
            if (predictedIndex == i) {
                numPredicted++;
                if (trueIndex == i) {
                    numCorrect++;
                }
            }
        }
        if (numPredicted == 0) {
            logger.warning("No class instances: dividing by 0");
        }
        // Divide in double: avoids truncation, and 0/0 yields NaN instead of
        // throwing ArithmeticException right after the warning.
        return ((double) numCorrect) / numPredicted;
    }

    /**
     * Classifies the list and returns the recall for the given label entry.
     *
     * @param instanceList the instances to classify and score
     * @param obj the label entry, looked up in the label alphabet without adding it
     * @return see {@link #getRecall(ArrayList, int)}
     */
    public double getRecall(InstanceList instanceList, Object obj) {
        return getRecall(classify(instanceList), getLabelAlphabet().lookupIndex(obj, false));
    }

    /**
     * Classifies the list and returns the recall for the given label index.
     *
     * @param instanceList the instances to classify and score
     * @param i the label index
     * @return see {@link #getRecall(ArrayList, int)}
     */
    public double getRecall(InstanceList instanceList, int i) {
        return getRecall(classify(instanceList), i);
    }

    /**
     * Returns the recall for the given label entry.
     *
     * @param arrayList a list of {@link Classification} objects
     * @param obj the label entry, looked up in the label alphabet without adding it
     * @return see {@link #getRecall(ArrayList, int)}
     */
    public double getRecall(ArrayList arrayList, Object obj) {
        return getRecall(arrayList, getLabelAlphabet().lookupIndex(obj, false));
    }

    /**
     * Computes recall for label index {@code i}: among the classifications
     * whose instance labeling has {@code i} as its best index, the fraction
     * whose best label is also {@code i}.
     *
     * @param arrayList a list of {@link Classification} objects
     * @param i the label index
     * @return the recall as a floating-point ratio; {@code NaN} (after a
     *         logged warning) when no instance has best index {@code i}
     */
    public double getRecall(ArrayList arrayList, int i) {
        int numCorrect = 0;
        int numActual = 0;
        for (int k = 0; k < arrayList.size(); k++) {
            Classification classification = (Classification) arrayList.get(k);
            int trueIndex = classification.getInstance().getLabeling().getBestIndex();
            int predictedIndex = classification.getLabeling().getBestIndex();
            if (trueIndex == i) {
                numActual++;
                if (predictedIndex == i) {
                    numCorrect++;
                }
            }
        }
        if (numActual == 0) {
            logger.warning("No class instances: dividing by 0");
        }
        // Divide in double: avoids truncation, and 0/0 yields NaN instead of
        // throwing ArithmeticException right after the warning.
        return ((double) numCorrect) / numActual;
    }

    /**
     * Classifies the list and returns the F1 score for the given label entry.
     *
     * @param instanceList the instances to classify and score
     * @param obj the label entry, looked up in the label alphabet without adding it
     * @return see {@link #getF1(ArrayList, int)}
     */
    public double getF1(InstanceList instanceList, Object obj) {
        return getF1(classify(instanceList), getLabelAlphabet().lookupIndex(obj, false));
    }

    /**
     * Classifies the list and returns the F1 score for the given label index.
     *
     * @param instanceList the instances to classify and score
     * @param i the label index
     * @return see {@link #getF1(ArrayList, int)}
     */
    public double getF1(InstanceList instanceList, int i) {
        return getF1(classify(instanceList), i);
    }

    /**
     * Returns the F1 score for the given label entry.
     *
     * @param arrayList a list of {@link Classification} objects
     * @param obj the label entry, looked up in the label alphabet without adding it
     * @return see {@link #getF1(ArrayList, int)}
     */
    public double getF1(ArrayList arrayList, Object obj) {
        return getF1(arrayList, getLabelAlphabet().lookupIndex(obj, false));
    }

    /**
     * Computes {@code 2 * precision * recall / (precision + recall)} for
     * label index {@code i}. A warning is logged when both precision and
     * recall equal {@code Transducer.ZERO_COST} (presumably 0.0 - confirm
     * against Transducer); the division is still performed.
     *
     * @param arrayList a list of {@link Classification} objects
     * @param i the label index
     * @return the F1 score computed from precision and recall
     */
    public double getF1(ArrayList arrayList, int i) {
        double precision = getPrecision(arrayList, i);
        double recall = getRecall(arrayList, i);
        if (precision == Transducer.ZERO_COST && recall == Transducer.ZERO_COST) {
            logger.warning("Precision and recall are 0: dividing by 0");
        }
        return ((2.0d * precision) * recall) / (precision + recall);
    }

    /**
     * Classifies the list and returns the average rank of the target label.
     *
     * @param instanceList the instances to classify and score
     * @return see {@link #getAvgPosAc(ArrayList)}
     */
    public double getAvgPosAc(InstanceList instanceList) {
        return getAvgPosAc(classify(instanceList));
    }

    /**
     * Averages, over all classifications, the rank that the classification's
     * labeling gives to the instance's target label. One line per
     * classification is also printed to standard output.
     *
     * @param arrayList a list of {@link Classification} objects whose
     *        instances have a {@link Label} target
     * @return the sum of ranks divided by the list size; {@code NaN} for an empty list
     */
    public double getAvgPosAc(ArrayList arrayList) {
        double rankSum = 0.0d;
        for (int i = 0; i < arrayList.size(); i++) {
            Classification classification = (Classification) arrayList.get(i);
            Instance instance = classification.getInstance();
            Labeling labeling = classification.getLabeling();
            Label target = (Label) instance.getTarget();
            int rank = labeling.getRank(target);
            System.out.println(rank + "   " + target.toString() + " Best was: " + labeling.getLabelAtRank(0));
            rankSum += rank;
        }
        return rankSum / arrayList.size();
    }

    /**
     * Prints a placeholder description of this classifier to standard output.
     */
    public void print() {
        System.out.println("Classifier " + getClass().getName() + "\n  Detailed printout not yet implemented.");
    }

    /**
     * Prints a placeholder description of this classifier to the given writer.
     *
     * @param printWriter the destination writer
     */
    public void print(PrintWriter printWriter) {
        printWriter.println("Classifier " + getClass().getName() + "\n  Detailed printout not yet implemented.");
    }

    /**
     * Loads a class by name, converting a failed lookup into an unchecked error.
     *
     * @param str the fully qualified class name
     * @return the loaded class
     * @throws NoClassDefFoundError if the class cannot be found; the original
     *         {@link ClassNotFoundException} is attached as its cause
     */
    static Class class$(String str) {
        try {
            return Class.forName(str);
        } catch (ClassNotFoundException e) {
            // initCause returns Throwable, which cannot be thrown from here;
            // attach the cause first and throw the error itself.
            NoClassDefFoundError error = new NoClassDefFoundError();
            error.initCause(e);
            throw error;
        }
    }

    static {
        Class cls = class$edu$umass$cs$mallet$base$classify$Classifier;
        if (cls == null) {
            cls = class$("edu.umass.cs.mallet.base.classify.Classifier");
            class$edu$umass$cs$mallet$base$classify$Classifier = cls;
        }
        $assertionsDisabled = !cls.desiredAssertionStatus();
        logger = Logger.getLogger("edu.umass.cs.mallet.base.classify.Classifier");
    }
}
