package edu.umass.cs.mallet.base.fst;

import edu.umass.cs.mallet.base.fst.CRF4;
import edu.umass.cs.mallet.base.fst.Transducer;
import edu.umass.cs.mallet.base.maximize.LimitedMemoryBFGS;
import edu.umass.cs.mallet.base.maximize.Maximizable;
import edu.umass.cs.mallet.base.pipe.Pipe;
import edu.umass.cs.mallet.base.types.Alphabet;
import edu.umass.cs.mallet.base.types.FeatureVector;
import edu.umass.cs.mallet.base.types.FeatureVectorSequence;
import edu.umass.cs.mallet.base.types.Instance;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.types.Sequence;
import edu.umass.cs.mallet.base.util.MalletLogger;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.BitSet;
import java.util.logging.Logger;

/* loaded from: input_file:edu/umass/cs/mallet/base/fst/MEMM.class */
public class MEMM extends CRF4 implements Serializable {
    private static Logger logger;
    private boolean gatheringTrainingData;
    private InstanceList trainingGatheredFor;
    static Class class$edu$umass$cs$mallet$base$fst$MEMM;
    static final boolean $assertionsDisabled;

    /* loaded from: input_file:edu/umass/cs/mallet/base/fst/MEMM$MaximizableMEMM.class */
    /**
     * Gradient-maximizable objective for a {@link MEMM}. Instead of running over whole
     * sequences it walks the per-state training sets stored in {@link State#trainingSet}.
     */
    public class MaximizableMEMM extends CRF4.MaximizableCRF implements Maximizable.ByGradient {
        private final MEMM this$0;

        /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
        protected MaximizableMEMM(MEMM memm, InstanceList instanceList, MEMM memm2) {
            super(memm, instanceList, memm2);
            this.this$0 = memm;
        }

        /**
         * Visits every instance in every state's training set and increments the count of
         * each outgoing transition by {@code exp(-cost) * instanceWeight}.
         *
         * @param z when {@code true} the instance's target label is handed to the transition
         *          iterator (presumably restricting it to transitions with that label -- confirm
         *          against CRF4.TransitionIterator) and nothing is added to the returned sum;
         *          when {@code false} a {@code null} label is passed and the weighted negative
         *          cost of each transition whose output is the target label is accumulated
         * @return the accumulated sum of {@code -instanceWeight * cost}; {@code 0.0} when
         *         {@code z} is {@code true}
         * @throws IllegalStateException if a transition matching the target label has infinite cost
         */
        protected double gatherExpectationsOrConstraints(boolean z) {
            // Remember whether this call is the one that created the infinite-value bookkeeping.
            boolean initializingInfiniteValues = false;
            if (this.infiniteValues == null) {
                this.infiniteValues = new BitSet();
                initializingInfiniteValues = true;
            }

            double labelLogProb = 0.0d;
            for (int stateIndex = 0; stateIndex < this.crf.numStates(); stateIndex++) {
                State state = (State) this.crf.getState(stateIndex);
                if (state.trainingSet == null) {
                    System.out.println("Empty training set for state " + state.name);
                    continue;
                }
                for (int instIndex = 0; instIndex < state.trainingSet.size(); instIndex++) {
                    Instance instance = state.trainingSet.getInstance(instIndex);
                    double instanceWeight = state.trainingSet.getInstanceWeight(instIndex);
                    FeatureVector featureVector = (FeatureVector) instance.getData();
                    String label = (String) instance.getTarget();
                    TransitionIterator iter = new TransitionIterator(state, featureVector, z ? label : null, this.crf);
                    while (iter.hasNext()) {
                        // Advance the iterator; the returned destination state is not needed.
                        // Without this call hasNext() never changes and the loop cannot end.
                        iter.nextState();
                        double cost = iter.getCost();
                        iter.incrementCount(Math.exp(-cost) * instanceWeight);
                        // NOTE(review): identity comparison -- relies on the stored target being
                        // the very String instance the iterator reports as output; confirm.
                        if (!z && iter.getOutput() == label) {
                            if (Double.isInfinite(cost)) {
                                MEMM.logger.warning("State " + stateIndex + " transition " + instIndex + " has infinite cost; skipping.");
                                if (initializingInfiniteValues) {
                                    throw new IllegalStateException("Infinite-cost transitions not yet supported");
                                }
                                if (!this.infiniteValues.get(instIndex)) {
                                    throw new IllegalStateException("Instance i used to have non-infinite value, but now it has infinite value.");
                                }
                            } else {
                                labelLogProb += (-instanceWeight) * cost;
                            }
                        }
                    }
                }
            }

            // Initial/final expectations are simply set equal to the corresponding constraints.
            for (int stateIndex = 0; stateIndex < this.crf.numStates(); stateIndex++) {
                State state = (State) this.crf.getState(stateIndex);
                state.initialExpectation = state.initialConstraint;
                state.finalExpectation = state.finalConstraint;
            }
            return labelLogProb;
        }

        /**
         * Runs the expectation-gathering pass.
         *
         * @return the value of {@code gatherExpectationsOrConstraints(false)}
         */
        @Override // edu.umass.cs.mallet.base.fst.CRF4.MaximizableCRF
        protected double getExpectationValue() {
            return gatherExpectationsOrConstraints(false);
        }
    }

    /* loaded from: input_file:edu/umass/cs/mallet/base/fst/MEMM$State.class */
    /**
     * A {@link CRF4.State} that also carries the training instances collected for it.
     * Initial/final count increments are ignored while the owning {@link MEMM} is
     * gathering training data.
     */
    public static class State extends CRF4.State implements Serializable {
        /** Instances gathered for this state; {@code null} until the first one is added. */
        InstanceList trainingSet;

        /** Forwards every argument unchanged to the {@link CRF4.State} constructor. */
        protected State(String str, int i, double d, double d2, String[] strArr, String[] strArr2, String[][] strArr3, CRF4 crf4) {
            super(str, i, d, d2, strArr, strArr2, strArr3, crf4);
        }

        /**
         * Builds an MEMM transition iterator for position {@code i} of the input sequence.
         *
         * @param sequence  input sequence; must be a {@link FeatureVectorSequence}
         * @param i         input position; must be non-negative
         * @param sequence2 output sequence, or {@code null} when no output label is supplied
         * @param i2        output position; must be non-negative
         * @throws UnsupportedOperationException if either position is negative or the input
         *                                       sequence is {@code null}
         */
        @Override // edu.umass.cs.mallet.base.fst.CRF4.State, edu.umass.cs.mallet.base.fst.Transducer.State
        public Transducer.TransitionIterator transitionIterator(Sequence sequence, int i, Sequence sequence2, int i2) {
            boolean negativePosition = i < 0 || i2 < 0;
            if (negativePosition) {
                throw new UnsupportedOperationException("Epsilon transitions not implemented.");
            }
            if (sequence == null) {
                throw new UnsupportedOperationException("CRFs are not generative models; must have an input sequence.");
            }
            FeatureVectorSequence input = (FeatureVectorSequence) sequence;
            String outputLabel = null;
            if (sequence2 != null) {
                outputLabel = (String) sequence2.get(i2);
            }
            return new TransitionIterator(this, input, i, outputLabel, this.crf);
        }

        /** Delegates to the superclass unless the owning MEMM is gathering training data. */
        @Override // edu.umass.cs.mallet.base.fst.CRF4.State, edu.umass.cs.mallet.base.fst.Transducer.State
        public void incrementFinalCount(double d) {
            MEMM owner = (MEMM) this.crf;
            if (!owner.gatheringTrainingData) {
                super.incrementFinalCount(d);
            }
        }

        /** Delegates to the superclass unless the owning MEMM is gathering training data. */
        @Override // edu.umass.cs.mallet.base.fst.CRF4.State, edu.umass.cs.mallet.base.fst.Transducer.State
        public void incrementInitialCount(double d) {
            MEMM owner = (MEMM) this.crf;
            if (!owner.gatheringTrainingData) {
                super.incrementInitialCount(d);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:edu/umass/cs/mallet/base/fst/MEMM$TransitionIterator.class */
    /**
     * Transition iterator whose costs are shifted, at construction time, by the
     * {@link Transducer#sumNegLogProb} combination of all outgoing costs of the source
     * state (presumably a per-state normalization in negative-log space -- confirm
     * against Transducer). While the owning {@link MEMM} is gathering training data,
     * count increments are recorded as training instances on the source state instead
     * of being forwarded to the superclass.
     */
    public static class TransitionIterator extends CRF4.TransitionIterator implements Serializable {
        /** Combined cost of all outgoing transitions, as computed by {@link #normalizeCosts()}. */
        private double sum;

        /** Creates an iterator for position {@code i} of a feature-vector sequence and normalizes its costs. */
        public TransitionIterator(State state, FeatureVectorSequence featureVectorSequence, int i, String str, CRF4 crf4) {
            super(state, featureVectorSequence, i, str, crf4);
            normalizeCosts();
        }

        /** Creates an iterator for a single feature vector and normalizes its costs. */
        public TransitionIterator(State state, FeatureVector featureVector, String str, CRF4 crf4) {
            super(state, featureVector, str, crf4);
            normalizeCosts();
        }

        /**
         * Folds all costs together with {@link Transducer#sumNegLogProb}, starting from
         * positive infinity, and subtracts the result from every cost. If the combined
         * value is infinite the costs are left untouched.
         */
        private void normalizeCosts() {
            double total = Double.POSITIVE_INFINITY;
            for (int k = 0; k < this.costs.length; k++) {
                total = Transducer.sumNegLogProb(total, this.costs[k]);
            }
            this.sum = total;
            assert !Double.isNaN(this.sum);
            if (Double.isInfinite(this.sum)) {
                return;
            }
            for (int k = 0; k < this.costs.length; k++) {
                this.costs[k] -= this.sum;
            }
        }

        /**
         * Outside of training-data gathering this forwards to the superclass. While
         * gathering, it appends the current input/output pair, weighted by {@code d}, to
         * the source state's training set -- unless some training has already been done
         * or {@code d} equals {@link Transducer#ZERO_COST}.
         *
         * @param d the count (used as the instance weight while gathering)
         */
        @Override // edu.umass.cs.mallet.base.fst.CRF4.TransitionIterator, edu.umass.cs.mallet.base.fst.Transducer.TransitionIterator
        public void incrementCount(double d) {
            MEMM owner = (MEMM) this.crf;
            if (!owner.gatheringTrainingData) {
                super.incrementCount(d);
                return;
            }
            if (this.crf.someTrainingDone || d == Transducer.ZERO_COST) {
                return;
            }
            State sourceState = (State) this.source;
            if (sourceState.trainingSet == null) {
                // Created lazily on the first recorded instance.
                sourceState.trainingSet = new InstanceList(null);
            }
            sourceState.trainingSet.add(getInput(), getOutput(), null, null, d);
        }

        /**
         * Appends the normalization value to the superclass description.
         *
         * @return the superclass text followed by {@code "Log Z = <sum>"} and a newline,
         *         with the sum formatted using the pattern {@code 0.###}
         */
        @Override // edu.umass.cs.mallet.base.fst.CRF4.TransitionIterator, edu.umass.cs.mallet.base.fst.Transducer.TransitionIterator
        public String describeTransition(double d) {
            String formattedSum = new DecimalFormat("0.###").format(this.sum);
            return super.describeTransition(d) + "Log Z = " + formattedSum + "\n";
        }
    }

    /**
     * Creates an MEMM by passing both pipes to the {@link CRF4} constructor.
     * Training-data gathering starts switched off.
     */
    public MEMM(Pipe pipe, Pipe pipe2) {
        super(pipe, pipe2);
        gatheringTrainingData = false;
    }

    /**
     * Creates an MEMM by passing both alphabets to the {@link CRF4} constructor.
     * Training-data gathering starts switched off.
     */
    public MEMM(Alphabet alphabet, Alphabet alphabet2) {
        super(alphabet, alphabet2);
        gatheringTrainingData = false;
    }

    /**
     * Creates an MEMM by passing the given {@link CRF4} to the superclass constructor.
     * Training-data gathering starts switched off.
     */
    public MEMM(CRF4 crf4) {
        super(crf4);
        gatheringTrainingData = false;
    }

    /**
     * Factory hook that makes this transducer build {@link MEMM.State} instances;
     * every argument is forwarded unchanged to the state constructor.
     */
    @Override // edu.umass.cs.mallet.base.fst.CRF4
    protected CRF4.State newState(String str, int i, double d, double d2, String[] strArr, String[] strArr2, String[][] strArr3, CRF4 crf4) {
        State memmState = new State(str, i, d, d2, strArr, strArr2, strArr3, crf4);
        return memmState;
    }

    /**
     * Trains this MEMM with {@link LimitedMemoryBFGS}, one maximizer step per iteration.
     *
     * <p>The per-state training sets are gathered first if they were not already gathered
     * for {@code instanceList}; then the weight dimensions are set up and the constraints
     * pass is run. After each iteration the evaluator, if any, is consulted and may stop
     * training by returning {@code false}.
     *
     * @param instanceList        training instances (asserted non-empty)
     * @param instanceList2       passed through to the evaluator
     * @param instanceList3       passed through to the evaluator
     * @param transducerEvaluator optional per-iteration evaluator; may be {@code null}
     * @param i                   maximum number of iterations; nothing is done when {@code <= 0}
     * @return the convergence flag from the last maximizer step ({@code true} as well when an
     *         {@link IllegalArgumentException} from the maximizer was caught); {@code false}
     *         when {@code i <= 0}
     */
    @Override // edu.umass.cs.mallet.base.fst.CRF4
    public boolean train(InstanceList instanceList, InstanceList instanceList2, InstanceList instanceList3, TransducerEvaluator transducerEvaluator, int i) {
        if (i <= 0) {
            return false;
        }
        if (!$assertionsDisabled && instanceList.size() <= 0) {
            throw new AssertionError();
        }
        if (this.trainingGatheredFor != instanceList) {
            gatherTrainingSets(instanceList);
        }
        if (this.useSparseWeights) {
            setWeightsDimensionAsIn(instanceList);
        } else {
            setWeightsDimensionDensely();
        }

        MaximizableMEMM objective = new MaximizableMEMM(this, instanceList, this);
        objective.gatherExpectationsOrConstraints(true);
        LimitedMemoryBFGS optimizer = new LimitedMemoryBFGS();
        boolean converged = false;
        logger.info("CRF about to train with " + i + " iterations");

        for (int iteration = 0; iteration < i; iteration++) {
            try {
                converged = optimizer.maximize(objective, 1);
                logger.info("CRF finished one iteration of maximizer, i=" + iteration);
            } catch (IllegalArgumentException e) {
                // A failure inside the maximizer is deliberately treated as convergence.
                e.printStackTrace();
                logger.info("Catching exception; saying converged.");
                converged = true;
            }
            if (transducerEvaluator != null) {
                boolean finishedTraining = converged || iteration == i - 1;
                boolean keepGoing = transducerEvaluator.evaluate(this, finishedTraining, iteration, converged, objective.getValue(), instanceList, instanceList2, instanceList3);
                if (!keepGoing) {
                    break;
                }
            }
            if (converged) {
                logger.info("CRF training has converged, i=" + iteration);
                break;
            }
        }

        logger.info("About to setTrainable(false)");
        setTrainable(false);
        logger.info("Done setTrainable(false)");
        return converged;
    }

    /**
     * Runs {@code forwardBackward} over every instance with the gathering flag raised, so
     * that transition count increments are recorded as per-state training instances (see
     * {@link TransitionIterator#incrementCount(double)}).
     *
     * <p>The gathering flag is always lowered again, even if {@code forwardBackward}
     * throws; otherwise a failed gather would leave the model silently diverting all
     * later count increments.
     *
     * @param instanceList the training instances; remembered as the set gathered for
     * @throws UnsupportedOperationException if training sets were already gathered
     */
    void gatherTrainingSets(InstanceList instanceList) {
        if (this.trainingGatheredFor != null) {
            throw new UnsupportedOperationException("Training with multiple sets not supported.");
        }
        this.trainingGatheredFor = instanceList;
        this.gatheringTrainingData = true;
        try {
            for (int i = 0; i < instanceList.size(); i++) {
                Instance instance = instanceList.getInstance(i);
                // NOTE(review): meaning of the trailing `true` flag is defined by
                // forwardBackward in a superclass -- confirm.
                forwardBackward((Sequence) instance.getData(), (Sequence) instance.getTarget(), true);
            }
        } finally {
            this.gatheringTrainingData = false;
        }
    }

    /**
     * Unsupported for MEMMs.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // edu.umass.cs.mallet.base.fst.CRF4
    public boolean train(InstanceList instanceList, InstanceList instanceList2, InstanceList instanceList3, TransducerEvaluator transducerEvaluator, int i, int i2, double[] dArr) {
        throw new UnsupportedOperationException("MEMM does not support this train(...) overload; use train(InstanceList, InstanceList, InstanceList, TransducerEvaluator, int).");
    }

    /**
     * Unsupported for MEMMs.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // edu.umass.cs.mallet.base.fst.CRF4
    public boolean trainWithFeatureInduction(InstanceList instanceList, InstanceList instanceList2, InstanceList instanceList3, TransducerEvaluator transducerEvaluator, int i, int i2, int i3, int i4, double d, boolean z, double[] dArr, String str) {
        throw new UnsupportedOperationException("MEMM does not support trainWithFeatureInduction.");
    }

    /**
     * Creates the MEMM-specific objective for the given training instances.
     *
     * @param instanceList the training instances handed to the objective
     * @return a new {@link MaximizableMEMM} bound to this model
     */
    @Override // edu.umass.cs.mallet.base.fst.CRF4
    public CRF4.MaximizableCRF getMaximizableCRF(InstanceList instanceList) {
        MaximizableMEMM objective = new MaximizableMEMM(this, instanceList, this);
        return objective;
    }

    /**
     * Dumps, to standard output, the training set gathered for every state: a header
     * line per state, then either "No data" or the target and data of each instance.
     */
    public void printInstanceLists() {
        int stateCount = numStates();
        for (int stateIndex = 0; stateIndex < stateCount; stateIndex++) {
            State state = (State) getState(stateIndex);
            InstanceList gathered = state.trainingSet;
            System.out.println("State " + stateIndex + " : " + state.getName());
            if (gathered == null) {
                System.out.println("No data");
                continue;
            }
            for (int instIndex = 0; instIndex < gathered.size(); instIndex++) {
                Instance instance = gathered.getInstance(instIndex);
                System.out.println("From : " + state.getName() + " To : " + instance.getTarget());
                System.out.println("Instance " + instIndex);
                System.out.println(instance.getTarget());
                System.out.println(instance.getData());
            }
        }
    }

    /**
     * Compiler-style helper that resolves a class by name.
     *
     * @param str fully qualified class name
     * @return the class returned by {@link Class#forName(String)}
     * @throws NoClassDefFoundError if the class cannot be found; the original
     *                              {@link ClassNotFoundException} is attached as the cause
     */
    static Class class$(String str) {
        try {
            return Class.forName(str);
        } catch (ClassNotFoundException e) {
            // initCause returns Throwable, which cannot be thrown from here without a
            // throws clause; attach the cause first and throw the typed error instead.
            NoClassDefFoundError error = new NoClassDefFoundError();
            error.initCause(e);
            throw error;
        }
    }

    static {
        // Resolve the MEMM class object once, populating the lookup cache if it is
        // still empty, then derive both the assertion switch and the logger from it.
        Class memmClass = class$edu$umass$cs$mallet$base$fst$MEMM;
        if (memmClass == null) {
            memmClass = class$("edu.umass.cs.mallet.base.fst.MEMM");
            class$edu$umass$cs$mallet$base$fst$MEMM = memmClass;
        }
        $assertionsDisabled = !memmClass.desiredAssertionStatus();
        logger = MalletLogger.getLogger(memmClass.getName());
    }
}
