package edu.umass.cs.mallet.base.classify;

import edu.umass.cs.mallet.base.fst.Transducer;
import edu.umass.cs.mallet.base.types.Instance;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.types.MatrixOps;
import edu.umass.cs.mallet.base.util.MalletLogger;
import edu.umass.cs.mallet.base.util.Maths;
import java.util.Arrays;
import java.util.Random;
import java.util.logging.Logger;

/* loaded from: input_file:WEB-INF/lib/mallet-0.1.3.jar:edu/umass/cs/mallet/base/classify/AdaBoostM2Trainer.class */
/**
 * Trains an {@link AdaBoostM2} ensemble by boosting a {@link Boostable} weak learner with resampling.
 *
 * <p>Training keeps a distribution {@code D} over the "mislabel" multiset: one entry for every
 * (training instance, incorrect label) pair. Each round:
 * <ol>
 *   <li>draws a weighted resample of that multiset and trains the weak learner on it;</li>
 *   <li>computes the pseudo-loss
 *       {@code eps = 0.5 * sum_k D(k) * (1 - h(x_k, correct) + h(x_k, wrong_k))};</li>
 *   <li>gives the weak classifier the weight {@code log(1 / beta)}, {@code beta = eps / (1 - eps)};</li>
 *   <li>multiplies each {@code D(k)} by {@code beta^((1 + h(x_k, correct) - h(x_k, wrong_k)) / 2)}
 *       and renormalizes.</li>
 * </ol>
 * Here {@code h(x, y)} is the weak classifier's labeling value for label {@code y}.
 */
public class AdaBoostM2Trainer extends ClassifierTrainer {
    private static final Logger logger = MalletLogger.getLogger(AdaBoostM2Trainer.class.getName());

    /**
     * How many times a round redraws its training sample while the weak classifier's
     * pseudo-loss stays (almost) zero before boosting is stopped.
     */
    private static final int MAX_NUM_RESAMPLING_ITERATIONS = 10;

    ClassifierTrainer weakLearner;
    int numRounds;

    /**
     * @param weakLearner trainer for the weak classifiers; must implement {@link Boostable}
     * @param numRounds   number of boosting rounds; must be positive
     * @throws IllegalArgumentException if {@code weakLearner} is not {@link Boostable} or
     *                                  {@code numRounds <= 0}
     */
    public AdaBoostM2Trainer(ClassifierTrainer weakLearner, int numRounds) {
        if (!(weakLearner instanceof Boostable)) {
            throw new IllegalArgumentException("weak learner not boostable");
        }
        if (numRounds <= 0) {
            throw new IllegalArgumentException("number of rounds must be positive");
        }
        this.weakLearner = weakLearner;
        this.numRounds = numRounds;
    }

    /** Equivalent to {@code new AdaBoostM2Trainer(weakLearner, 100)}. */
    public AdaBoostM2Trainer(ClassifierTrainer weakLearner) {
        this(weakLearner, 100);
    }

    /**
     * Runs up to {@code numRounds} boosting rounds and returns the weighted ensemble.
     *
     * <p>If, in some round, the pseudo-loss is still (almost) zero after
     * {@link #MAX_NUM_RESAMPLING_ITERATIONS} resamples, boosting stops early. That round's
     * classifier would get an unbounded weight {@code log((1 - eps) / eps)}, so it is dropped and
     * the classifiers from the earlier rounds are returned. In round 0 there is nothing earlier,
     * so that single classifier is returned with weight 1.
     *
     * @param trainingList      labeled training instances
     * @param validationList    passed through to the weak learner's {@code train}
     * @param testList          ignored by this trainer
     * @param evaluator         ignored by this trainer
     * @param initialClassifier ignored by this trainer
     * @return the boosted classifier
     * @throws UnsupportedOperationException if {@code trainingList} has a feature selection
     * @throws IllegalArgumentException      if there are no training instances or fewer than
     *                                       two target classes
     */
    @Override
    public Classifier train(InstanceList trainingList, InstanceList validationList, InstanceList testList,
                           ClassifierEvaluating evaluator, Classifier initialClassifier) {
        if (trainingList.getFeatureSelection() != null) {
            throw new UnsupportedOperationException("FeatureSelection not yet implemented.");
        }
        int numClasses = trainingList.getTargetAlphabet().size();
        int numInstances = trainingList.size();
        if (numClasses < 2 || numInstances == 0) {
            // Otherwise the mislabel set below is empty (or has negative size) and training
            // fails later with a misleading error.
            throw new IllegalArgumentException("AdaBoost.M2 needs at least one training instance and two classes; got "
                    + numInstances + " instance(s) and " + numClasses + " class(es)");
        }

        // Mislabel set: entry k pairs mislabelInstances[k] with one of its incorrect labels, wrongLabels[k].
        int numMislabels = numInstances * (numClasses - 1);
        InstanceList mislabelInstances = new InstanceList();
        int[] wrongLabels = new int[numMislabels];
        int next = 0;
        for (int i = 0; i < numInstances; i++) {
            Instance instance = trainingList.getInstance(i);
            int correctLabel = instance.getLabeling().getBestIndex();
            for (int label = 0; label < numClasses; label++) {
                if (label != correctLabel) {
                    mislabelInstances.add(instance, 1.0d);
                    wrongLabels[next++] = label;
                }
            }
        }

        // Start from the uniform distribution over mislabel pairs.
        double[] weights = new double[numMislabels];
        Arrays.fill(weights, 1.0d / numMislabels);
        int[] pairIndices = new int[numMislabels];
        for (int k = 0; k < numMislabels; k++) {
            pairIndices[k] = k;
        }

        Random random = new Random();
        Classifier[] weakClassifiers = new Classifier[this.numRounds];
        double[] classifierWeights = new double[this.numRounds];
        double[] exponents = new double[numMislabels];

        for (int round = 0; round < this.numRounds; round++) {
            logger.info("===========  AdaBoostM2Trainer round " + (round + 1) + " begin");

            double pseudoLoss;
            int attempts = 0;
            do {
                InstanceList sample = new InstanceList();
                int[] drawn = sampleWithWeights(pairIndices, weights, random);
                for (int j = 0; j < drawn.length; j++) {
                    sample.add(mislabelInstances.getInstance(drawn[j]), 1.0d);
                }
                weakClassifiers[round] = this.weakLearner.train(sample, validationList);

                pseudoLoss = 0.0d;
                for (int k = 0; k < numMislabels; k++) {
                    Classification classification = weakClassifiers[round].classify(mislabelInstances.getInstance(k));
                    double hCorrect = classification.valueOfCorrectLabel();
                    double hWrong = classification.getLabeling().value(wrongLabels[k]);
                    pseudoLoss += weights[k] * ((1.0d - hCorrect) + hWrong);
                    exponents[k] = (1.0d + hCorrect) - hWrong;
                }
                pseudoLoss *= 0.5d;
                attempts++;
                // A zero pseudo-loss may just be an unlucky resample, so redraw a few times.
            } while (Maths.almostEquals(pseudoLoss, 0.0d) && attempts < MAX_NUM_RESAMPLING_ITERATIONS);

            if (Maths.almostEquals(pseudoLoss, 0.0d)) {
                logger.info("AdaBoostM2Trainer stopped at " + (round + 1) + " / " + this.numRounds
                        + " pseudo-loss=" + pseudoLoss);
                int numToKeep = (round == 0) ? 1 : round;
                if (round == 0) {
                    classifierWeights[0] = 1.0d;
                }
                double[] keptWeights = new double[numToKeep];
                Classifier[] keptClassifiers = new Classifier[numToKeep];
                System.arraycopy(classifierWeights, 0, keptWeights, 0, numToKeep);
                System.arraycopy(weakClassifiers, 0, keptClassifiers, 0, numToKeep);
                logClassifierWeights(keptWeights);
                return new AdaBoostM2(mislabelInstances.getPipe(), keptClassifiers, keptWeights);
            }

            double beta = pseudoLoss / (1.0d - pseudoLoss);
            classifierWeights[round] = Math.log(1.0d / beta);
            double total = 0.0d;
            for (int k = 0; k < numMislabels; k++) {
                weights[k] *= Math.pow(beta, 0.5d * exponents[k]);
                total += weights[k];
            }
            MatrixOps.timesEquals(weights, 1.0d / total);

            logger.info("===========  AdaBoostM2Trainer round " + (round + 1) + " finished, pseudo-loss = " + pseudoLoss);
        }
        logClassifierWeights(classifierWeights);
        return new AdaBoostM2(mislabelInstances.getPipe(), weakClassifiers, classifierWeights);
    }

    /** Logs one line per weak-classifier weight. */
    private static void logClassifierWeights(double[] classifierWeights) {
        for (int i = 0; i < classifierWeights.length; i++) {
            logger.info("AdaBoostM2Trainer weight[weakLearner[" + i + "]]=" + classifierWeights[i]);
        }
    }

    /**
     * Draws {@code a.length} elements of {@code a} independently and with replacement.
     * Element {@code i} is chosen with probability {@code weights[i] / sum(weights)}.
     *
     * @param a       values to sample from
     * @param weights non-negative, unnormalized sampling weights, parallel to {@code a}
     * @param r       source of randomness
     * @return a new array of length {@code a.length} holding the sampled values
     * @throws IllegalArgumentException if the lengths differ, any weight is negative or NaN,
     *                                  or the weights do not sum to a positive value
     */
    private int[] sampleWithWeights(int[] a, double[] weights, Random r) {
        if (weights.length != a.length) {
            throw new IllegalArgumentException("length of weight vector must equal number of data points");
        }
        double[] cumulative = new double[a.length];
        double sumOfWeights = 0.0d;
        int lastPositive = -1;
        for (int i = 0; i < a.length; i++) {
            // Negated comparison so that NaN is rejected as well.
            if (!(weights[i] >= 0.0d)) {
                throw new IllegalArgumentException("weight vector must be non-negative");
            }
            if (weights[i] > 0.0d) {
                lastPositive = i;
            }
            sumOfWeights += weights[i];
            cumulative[i] = sumOfWeights;
        }
        if (!(sumOfWeights > 0.0d)) {
            throw new IllegalArgumentException("weights must sum to positive value");
        }
        int[] sample = new int[a.length];
        for (int s = 0; s < sample.length; s++) {
            // Inverse-CDF draw: target is uniform on [0, sumOfWeights).
            double target = r.nextDouble() * sumOfWeights;
            sample[s] = a[firstIndexAbove(cumulative, lastPositive, target)];
        }
        return sample;
    }

    /**
     * Binary search over the non-decreasing array {@code cumulative[0..last]}.
     * Returns the smallest index {@code i <= last} with {@code cumulative[i] > target},
     * or {@code last} if there is none. Passing the last positive-weight index as
     * {@code last} means a zero-weight entry is never returned, even if {@code target}
     * rounds up to the total.
     */
    private static int firstIndexAbove(double[] cumulative, int last, double target) {
        int lo = 0;
        int hi = last;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (cumulative[mid] > target) {
                hi = mid;
            } else {
                lo = mid + 1;
            }
        }
        return lo;
    }
}
