package edu.umass.cs.mallet.grmm.learning;

import edu.umass.cs.mallet.base.maximize.LimitedMemoryBFGS;
import edu.umass.cs.mallet.base.maximize.Maximizable;
import edu.umass.cs.mallet.base.types.Alphabet;
import edu.umass.cs.mallet.base.types.Instance;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.types.Label;
import edu.umass.cs.mallet.base.types.LabelAlphabet;
import edu.umass.cs.mallet.base.types.Labels;
import edu.umass.cs.mallet.base.types.LabelsSequence;
import edu.umass.cs.mallet.base.util.MalletLogger;
import edu.umass.cs.mallet.base.util.Timing;
import edu.umass.cs.mallet.grmm.learning.ACRF;
import gnu.trove.TIntArrayList;
import java.io.File;
import java.io.FileWriter;
import java.io.PrintWriter;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.logging.Logger;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.configuration.tree.DefaultExpressionEngine;

/* loaded from: input_file:edu/umass/cs/mallet/grmm/learning/ACRFTrainer.class */
/**
 * Trains an {@link ACRF} with limited-memory BFGS, one optimizer iteration at a time.
 *
 * <p>After each iteration an {@link ACRFEvaluator} is invoked; training stops when the optimizer
 * reports convergence, the evaluator asks to stop, the objective value changes by less than a
 * threshold proportional to the model size, or the iteration budget is exhausted. The nested
 * {@link LogEvaluator} and {@link FileEvaluator} report per-label and per-factor statistics
 * gathered in a {@link TestResults}.
 */
public class ACRFTrainer {
    private static final Logger logger = MalletLogger.getLogger(ACRFTrainer.class.getName());

    /** Training-set fractions used, in order, by {@code incrementalTrain} before the full-data run. */
    private static final double[] SIZE = {0.1d, 0.5d};

    /** Optimizer iterations spent on each subset in {@code incrementalTrain}. */
    private static final int SUBSET_ITER = 10;

    /** Fixed-seed generator so the splits made by the staged {@code train} overload are repeatable. */
    private static final Random r = new Random(1729L);

    /** Prefix handed to every evaluator (via {@code setOutputPrefix}) before it runs. */
    private File outputPrefix = new File("");

    /**
     * Evaluator that appends the testing-set statistics to a file. It never asks training to stop.
     */
    public static class FileEvaluator extends ACRFEvaluator {
        private final File file;

        /**
         * @param file destination file; results are appended ({@code FileWriter} in append mode)
         */
        public FileEvaluator(File file) {
            this.file = file;
        }

        /**
         * Tests on {@code testing} when this iteration is due for evaluation and a testing list
         * was supplied.
         *
         * @return always {@code true}
         */
        @Override
        public boolean evaluate(ACRF acrf, int iter, InstanceList training, InstanceList validation, InstanceList testing) {
            if (shouldDoEvaluate(iter) && testing != null) {
                test(acrf, testing, "Testing ");
            }
            return true;
        }

        /**
         * Compares {@code returned} labelings against the targets of {@code gold} and appends the
         * resulting table to the file. I/O failures are reported but do not propagate.
         *
         * @param gold        instances whose targets are {@code LabelsSequence}s
         * @param returned    one {@code LabelsSequence} per instance, in the same order
         * @param description prefix written on the per-factor and joint-accuracy lines
         */
        @Override
        public void test(InstanceList gold, List returned, String description) {
            ACRFTrainer.logger.info("Number of testing instances = " + gold.size());
            TestResults results = LogEvaluator.computeTestResults(gold, returned);
            PrintWriter out = null;
            try {
                out = new PrintWriter(new FileWriter(this.file, true));
                results.print(description, out);
            } catch (Exception e) {
                // Best effort: a failed report write is printed but does not abort training.
                e.printStackTrace();
            } finally {
                // Close even when print() fails so the file handle is not leaked.
                if (out != null) {
                    out.close();
                }
            }
        }
    }

    /**
     * Evaluator that logs training- and testing-set statistics through the class logger and
     * remembers the most recent {@link TestResults}. It never asks training to stop.
     */
    public static class LogEvaluator extends ACRFEvaluator {
        private TestResults lastResults;

        /**
         * When this iteration is due for evaluation, tests on the training list and then on the
         * testing list, skipping whichever is {@code null}.
         *
         * @return always {@code true}
         */
        @Override
        public boolean evaluate(ACRF acrf, int iter, InstanceList training, InstanceList validation, InstanceList testing) {
            if (shouldDoEvaluate(iter)) {
                if (training != null) {
                    test(acrf, training, "Training");
                }
                if (testing != null) {
                    test(acrf, testing, "Testing");
                }
            }
            return true;
        }

        /**
         * Logs statistics comparing {@code returned} with the targets of {@code gold} and keeps
         * them for {@link #getJointAccuracy()}.
         */
        @Override
        public void test(InstanceList gold, List returned, String description) {
            ACRFTrainer.logger.info(description + ": Number of instances = " + gold.size());
            TestResults results = computeTestResults(gold, returned);
            results.log(description);
            this.lastResults = results;
        }

        /**
         * Builds a {@link TestResults} by pairing each instance of {@code gold} with the labeling
         * at the same position in {@code returned}.
         *
         * @param gold     non-empty list; targets must be {@code LabelsSequence}s
         * @param returned predicted {@code LabelsSequence}s, at least as many as {@code gold}
         */
        public static TestResults computeTestResults(InstanceList gold, List returned) {
            TestResults results = new TestResults(gold);
            InstanceList.Iterator goldIt = gold.iterator();
            Iterator returnedIt = returned.iterator();
            while (goldIt.hasNext()) {
                LabelsSequence predicted = (LabelsSequence) returnedIt.next();
                LabelsSequence truth = (LabelsSequence) ((Instance) goldIt.next()).getTarget();
                compareLabelings(results, predicted, truth);
            }
            results.computeStatistics();
            return results;
        }

        /** Adds every position of one (returned, true) sequence pair to {@code results}. */
        static void compareLabelings(TestResults results, LabelsSequence returned, LabelsSequence truth) {
            assert returned.size() == truth.size();
            for (int pos = 0; pos < returned.size(); pos++) {
                results.incrementCount(returned.getLabels(pos), truth.getLabels(pos));
            }
        }

        /**
         * Joint accuracy from the most recent {@code test} call. Throws NullPointerException
         * if no test has been run yet.
         */
        public double getJointAccuracy() {
            return this.lastResults.getJointAccuracy();
        }
    }

    /**
     * Confusion-matrix statistics comparing returned labelings with true labelings.
     *
     * <p>Label values of every factor share one {@link Alphabet}; indices into
     * {@code confusion}, {@code precision}, {@code recall}, etc. are indices of that alphabet,
     * and {@code factors[f]} lists the indices belonging to factor {@code f}. In
     * {@code confusion[t][r]}, {@code t} is the true label and {@code r} the returned one.
     */
    public static class TestResults {
        public int[][] confusion;
        public int numClasses;
        /** Row sums of {@code confusion}: how often each label was the true label. */
        public int[] trueCounts;
        /** Column sums of {@code confusion}: how often each label was returned. */
        public int[] returnedCounts;
        public double[] precision;
        public double[] recall;
        public double[] f1;
        public TIntArrayList[] factors;
        /** Number of sequence positions compared. */
        public int maxT;
        /** Number of positions at which every factor's returned label was correct. */
        public int correctT;
        public Alphabet alphabet;

        /** Sizes the statistics from the first instance of {@code instanceList}. */
        TestResults(InstanceList instanceList) {
            this(instanceList.getInstance(0));
        }

        /** Sizes the statistics from the label alphabets found in {@code instance}'s target. */
        TestResults(Instance instance) {
            this.maxT = 0;
            this.correctT = 0;
            this.alphabet = new Alphabet();
            setupAlphabet(instance);
            this.numClasses = this.alphabet.size();
            this.confusion = new int[this.numClasses][this.numClasses];
            this.precision = new double[this.numClasses];
            this.recall = new double[this.numClasses];
            this.f1 = new double[this.numClasses];
        }

        /**
         * Registers every value of every factor's label alphabet, read from the labels at the
         * first position of {@code instance}'s target, and records which indices belong to
         * which factor.
         */
        private void setupAlphabet(Instance instance) {
            Labels firstPosition = ((LabelsSequence) instance.getTarget()).getLabels(0);
            this.factors = new TIntArrayList[firstPosition.size()];
            for (int f = 0; f < firstPosition.size(); f++) {
                LabelAlphabet dict = firstPosition.get(f).getLabelAlphabet();
                this.factors[f] = new TIntArrayList(dict.size());
                for (int k = 0; k < dict.size(); k++) {
                    this.factors[f].add(this.alphabet.lookupIndex(dict.lookupObject(k)));
                }
            }
        }

        /**
         * Records one position: updates the confusion matrix for every factor and counts the
         * position as jointly correct only if all factors match.
         *
         * @param returned labels returned by the model at this position
         * @param truth    true labels at this position
         */
        void incrementCount(Labels returned, Labels truth) {
            boolean allCorrect = true;
            for (int f = 0; f < returned.size(); f++) {
                // NOTE(review): lookupIndex may add unseen entries, which would index past
                // numClasses; this assumes all labels were registered in setupAlphabet.
                int trueIdx = this.alphabet.lookupIndex(truth.get(f).getEntry());
                int returnedIdx = this.alphabet.lookupIndex(returned.get(f).getEntry());
                if (trueIdx != returnedIdx) {
                    allCorrect = false;
                }
                this.confusion[trueIdx][returnedIdx]++;
            }
            this.maxT++;
            if (allCorrect) {
                this.correctT++;
            }
        }

        /** Derives true/returned counts and per-label precision, recall and F1 from the matrix. */
        void computeStatistics() {
            this.trueCounts = new int[this.numClasses];
            this.returnedCounts = new int[this.numClasses];
            for (int t = 0; t < this.numClasses; t++) {
                for (int ret = 0; ret < this.numClasses; ret++) {
                    this.trueCounts[t] += this.confusion[t][ret];
                    this.returnedCounts[ret] += this.confusion[t][ret];
                }
            }
            for (int c = 0; c < this.numClasses; c++) {
                double correct = this.confusion[c][c];
                if (this.returnedCounts[c] == 0) {
                    this.precision[c] = correct == 0.0d ? 1.0d : 0.0d;
                } else {
                    this.precision[c] = correct / this.returnedCounts[c];
                }
                this.recall[c] = this.trueCounts[c] == 0 ? 1.0d : correct / this.trueCounts[c];
                double sum = this.precision[c] + this.recall[c];
                // Avoid 0/0: precision and recall both 0 gives F1 = 0 instead of NaN.
                this.f1[c] = sum == 0.0d ? 0.0d : 2.0d * this.precision[c] * this.recall[c] / sum;
            }
        }

        /** Logs the statistics with an empty prefix. */
        public void log() {
            log("");
        }

        /** Logs the per-label table, per-factor accuracy and joint accuracy, each line prefixed by {@code str}. */
        public void log(String str) {
            ACRFTrainer.logger.info(str + ":  i\tLabel\tN\tCorrect\tReturned\tP\tR\tF1");
            for (int c = 0; c < this.numClasses; c++) {
                ACRFTrainer.logger.info(str + ":  " + classRow(c));
            }
            for (int f = 0; f < this.factors.length; f++) {
                int[] counts = factorCounts(f);
                ACRFTrainer.logger.info(str + ":  Factor " + f + " accuracy: (" + counts[0] + " " + counts[1] + ") " + ratio(counts[0], counts[1]));
            }
            ACRFTrainer.logger.info(str + " CorrectT " + this.correctT + "  maxt " + this.maxT);
            ACRFTrainer.logger.info(str + " Joint accuracy: " + getJointAccuracy());
        }

        /** Writes the same report as {@link #log(String)} to {@code printWriter}. */
        public void print(String str, PrintWriter printWriter) {
            printWriter.println("i\tLabel\tN\tCorrect\tReturned\tP\tR\tF1");
            for (int c = 0; c < this.numClasses; c++) {
                printWriter.println(classRow(c));
            }
            for (int f = 0; f < this.factors.length; f++) {
                int[] counts = factorCounts(f);
                printWriter.println(str + " Factor " + f + " accuracy: (" + counts[0] + " " + counts[1] + ") " + ratio(counts[0], counts[1]));
            }
            printWriter.println(str + " CorrectT " + this.correctT + "  maxt " + this.maxT);
            printWriter.println(str + " Joint accuracy: " + getJointAccuracy());
        }

        /** Prints every (true, returned, count) cell of the confusion matrix to standard output. */
        void printConfusion() {
            System.out.println("True\t\tReturned\tCount");
            for (int t = 0; t < this.numClasses; t++) {
                for (int ret = 0; ret < this.numClasses; ret++) {
                    System.out.println(t + "\t\t" + ret + "\t" + this.confusion[t][ret]);
                }
            }
        }

        /** Fraction of positions where every factor was labeled correctly (NaN if none were compared). */
        public double getJointAccuracy() {
            return ratio(this.correctT, this.maxT);
        }

        /** One tab-separated row of the per-label table for alphabet index {@code c}. */
        private String classRow(int c) {
            return c + "\t" + this.alphabet.lookupObject(c) + "\t" + this.trueCounts[c] + "\t"
                    + this.confusion[c][c] + "\t" + this.returnedCounts[c] + "\t" + this.precision[c]
                    + "\t" + this.recall[c] + "\t" + this.f1[c] + "\t";
        }

        /** Returns {correct, returned} summed over the alphabet indices that belong to factor {@code f}. */
        private int[] factorCounts(int f) {
            int correct = 0;
            int returned = 0;
            TIntArrayList members = this.factors[f];
            for (int k = 0; k < members.size(); k++) {
                int idx = members.get(k);
                correct += this.confusion[idx][idx];
                returned += this.returnedCounts[idx];
            }
            return new int[]{correct, returned};
        }

        /** Floating-point ratio; integer division would truncate accuracies to 0 or 1. */
        private static double ratio(int numerator, int denominator) {
            return (double) numerator / denominator;
        }
    }

    /** Sets the prefix passed to evaluators before each evaluation. */
    public void setOutputPrefix(File file) {
        this.outputPrefix = file;
    }

    /** Trains for at most one iteration, reporting through a new {@link LogEvaluator}. */
    public boolean train(ACRF acrf, InstanceList instanceList) {
        return train(acrf, instanceList, null, null, new LogEvaluator(), 1);
    }

    /** Trains for at most {@code i} iterations, reporting through a new {@link LogEvaluator}. */
    public boolean train(ACRF acrf, InstanceList instanceList, int i) {
        return train(acrf, instanceList, null, null, new LogEvaluator(), i);
    }

    /** Trains for at most {@code i} iterations with the given evaluator and no validation/testing lists. */
    public boolean train(ACRF acrf, InstanceList instanceList, ACRFEvaluator aCRFEvaluator, int i) {
        return train(acrf, instanceList, null, null, aCRFEvaluator, i);
    }

    /** Trains for at most {@code i} iterations, reporting through a new {@link LogEvaluator}. */
    public boolean train(ACRF acrf, InstanceList instanceList, InstanceList instanceList2, InstanceList instanceList3, int i) {
        return train(acrf, instanceList, instanceList2, instanceList3, new LogEvaluator(), i);
    }

    /** Trains for at most {@code i} iterations on the objective built by {@link #createMaximizable}. */
    public boolean train(ACRF acrf, InstanceList instanceList, InstanceList instanceList2, InstanceList instanceList3, ACRFEvaluator aCRFEvaluator, int i) {
        return train(acrf, instanceList, instanceList2, instanceList3, aCRFEvaluator, i, createMaximizable(acrf, instanceList));
    }

    /** Builds the objective to optimize; subclasses may override to substitute another one. */
    protected Maximizable.ByGradient createMaximizable(ACRF acrf, InstanceList instanceList) {
        return acrf.getMaximizable(instanceList);
    }

    /** {@code incrementalTrain} with a new {@link LogEvaluator}. */
    public boolean incrementalTrain(ACRF acrf, InstanceList instanceList, InstanceList instanceList2, InstanceList instanceList3, int i) {
        return incrementalTrain(acrf, instanceList, instanceList2, instanceList3, new LogEvaluator(), i);
    }

    /**
     * Warm-starts training on growing subsets of the training data (fractions in {@code SIZE},
     * {@code SUBSET_ITER} iterations each), then trains on the full list.
     *
     * @param training   full training list
     * @param validation passed through to the evaluator
     * @param testing    used only in the final full-data run
     * @param evaluator  evaluator invoked after every iteration
     * @param numIter    iteration budget for the final full-data run
     * @return result of the final full-data {@code train} call
     */
    public boolean incrementalTrain(ACRF acrf, InstanceList training, InstanceList validation, InstanceList testing, ACRFEvaluator evaluator, int numIter) {
        long startTime = System.currentTimeMillis();
        for (int round = 0; round < SIZE.length; round++) {
            InstanceList subset = training.split(new double[]{SIZE[round], 1.0d - SIZE[round]})[0];
            logger.info("Training on subset of size " + subset.size());
            // The optimizer sees only the subset; the evaluator still receives the full training list.
            train(acrf, training, validation, null, evaluator, SUBSET_ITER, createMaximizable(acrf, subset));
            logger.info("Subset training " + round + " finished...");
        }
        logger.info("All subset training finished.  Time = " + (System.currentTimeMillis() - startTime) + " ms.");
        return train(acrf, training, validation, testing, evaluator, numIter);
    }

    /**
     * Core training loop: runs L-BFGS one iteration at a time on {@code maximizable}.
     *
     * <p>Stops when the optimizer or evaluator reports convergence, when the objective changes
     * by less than {@code 1e-5 * totalNodes} between iterations, after {@code numIter}
     * iterations, or after two consecutive iterations throw a {@link RuntimeException}. The
     * first exception only resets the L-BFGS state; early stopping is suppressed for that
     * iteration. Afterwards, if a testing list and evaluator were given, the evaluator tests on
     * the testing list.
     *
     * @return {@code true} if training stopped because of convergence (or was declared converged
     *         after repeated exceptions); {@code false} if the iteration budget ran out
     */
    public boolean train(ACRF acrf, InstanceList training, InstanceList validation, InstanceList testing, ACRFEvaluator evaluator, int numIter, Maximizable.ByGradient maximizable) {
        LimitedMemoryBFGS maximizer = new LimitedMemoryBFGS();
        boolean converged = false;
        // Cleared when an exception triggers an L-BFGS reset; set again only after a clean iteration.
        boolean resetOnError = true;
        long startTime = System.currentTimeMillis();

        int totalNodes = (maximizable instanceof ACRF.MaximizableACRF)
                ? ((ACRF.MaximizableACRF) maximizable).getTotalNodes() : 0;
        // Early-stopping threshold on the change in objective value, scaled by the node count.
        double threshold = 1.0E-5d * totalNodes;

        if (testing == null) {
            logger.warning("ACRF trainer: No test set provided.");
        }

        double prevValue = Double.NEGATIVE_INFINITY;
        for (int iter = 0; iter < numIter; iter++) {
            logger.info("ACRF trainer iteration " + iter + " at time " + (System.currentTimeMillis() - startTime));
            try {
                // Non-short-circuit '|' so the evaluator runs even when the optimizer reports convergence.
                converged = maximizer.maximize(maximizable, 1) | callEvaluator(acrf, training, validation, testing, iter, evaluator);
                if (converged) {
                    break;
                }
                resetOnError = true;
            } catch (RuntimeException e) {
                e.printStackTrace();
                if (resetOnError) {
                    logger.warning("Exception in iteration " + iter + ":" + e + "\n  Resetting LBFGs and trying again...");
                    maximizer.reset();
                    resetOnError = false;
                } else {
                    logger.warning("Exception in iteration " + iter + ":" + e + "\n   Quitting and saying converged...");
                    converged = true;
                    break;
                }
            }

            double value = maximizable.getValue();
            if (Math.abs(value - prevValue) >= threshold) {
                prevValue = value;
            } else if (resetOnError) {
                // Not applied in the iteration that just reset L-BFGS after an exception.
                logger.info("ACRFTrainer saying converged:  Current value " + value + ", previous " + prevValue
                        + "\n...threshold was " + threshold + " = 1e-5 * " + totalNodes);
                converged = true;
                break;
            }
        }

        logger.info("ACRF training time (ms) = " + (System.currentTimeMillis() - startTime));
        if (maximizable instanceof ACRF.MaximizableACRF) {
            ((ACRF.MaximizableACRF) maximizable).report();
        }
        if (testing != null && evaluator != null) {
            evaluator.test(acrf, testing, "Testing");
        }
        return converged;
    }

    /**
     * Runs the evaluator for one iteration with unrolled-graph caching temporarily disabled.
     * The caching flag is restored afterwards, even if the evaluator throws.
     *
     * @return {@code true} if training should stop (the evaluator returned {@code false});
     *         {@code false} to continue, including when {@code evaluator} is {@code null}
     */
    private boolean callEvaluator(ACRF acrf, InstanceList training, InstanceList validation, InstanceList testing, int iter, ACRFEvaluator evaluator) {
        if (evaluator == null) {
            return false;
        }
        evaluator.setOutputPrefix(this.outputPrefix);
        boolean oldCache = acrf.isCacheUnrolledGraphs();
        acrf.setCacheUnrolledGraphs(false);
        Timing timing = new Timing();
        try {
            boolean keepGoing = evaluator.evaluate(acrf, iter, training, validation, testing);
            if (!keepGoing) {
                logger.info("ACRF trainer: evaluator returned false. Quitting.");
            }
            timing.tick("Evaluation time (iteration " + iter + ")");
            return !keepGoing;
        } finally {
            acrf.setCacheUnrolledGraphs(oldCache);
        }
    }

    /** Tests {@code acrf} on {@code instanceList} with a single evaluator. */
    public void test(ACRF acrf, InstanceList instanceList, ACRFEvaluator aCRFEvaluator) {
        test(acrf, instanceList, new ACRFEvaluator[]{aCRFEvaluator});
    }

    /**
     * Decodes {@code instanceList} once with {@code getBestLabels} and hands the predictions to
     * every evaluator.
     */
    public void test(ACRF acrf, InstanceList instanceList, ACRFEvaluator[] aCRFEvaluatorArr) {
        List bestLabels = acrf.getBestLabels(instanceList);
        for (int k = 0; k < aCRFEvaluatorArr.length; k++) {
            aCRFEvaluatorArr[k].setOutputPrefix(this.outputPrefix);
            aCRFEvaluatorArr[k].test(instanceList, bestLabels, "Testing");
        }
    }

    /**
     * Staged training: for each entry of {@code proportions}, trains for
     * {@code iterPerProportion} iterations on a random split of {@code training}, then trains on
     * the full list with an iteration cap of 99999.
     */
    public void train(ACRF acrf, InstanceList training, InstanceList validation, InstanceList testing, ACRFEvaluator evaluator, double[] proportions, int iterPerProportion) {
        for (int round = 0; round < proportions.length; round++) {
            double proportion = proportions[round];
            // NOTE(review): split weights are {proportion, 1.0}; if split() normalizes its weights,
            // the first part holds proportion / (1 + proportion) of the data, not `proportion`.
            InstanceList[] parts = training.split(r, new double[]{proportion, 1.0d});
            logger.info("ACRF trainer: Round " + round + ", training proportion = " + proportion);
            train(acrf, parts[0], validation, testing, evaluator, iterPerProportion);
        }
        logger.info("ACRF trainer: Training on full data");
        train(acrf, training, validation, testing, evaluator, 99999);
    }
}
