package edu.umass.cs.mallet.base.pipe;

import edu.umass.cs.mallet.base.classify.BalancedWinnowTrainer;
import edu.umass.cs.mallet.base.classify.Classification;
import edu.umass.cs.mallet.base.classify.Classifier;
import edu.umass.cs.mallet.base.classify.ClassifierTrainer;
import edu.umass.cs.mallet.base.classify.Trial;
import edu.umass.cs.mallet.base.types.Alphabet;
import edu.umass.cs.mallet.base.types.AugmentableFeatureVector;
import edu.umass.cs.mallet.base.types.FeatureVector;
import edu.umass.cs.mallet.base.types.FeatureVectorSequence;
import edu.umass.cs.mallet.base.types.Instance;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.types.LabelAlphabet;
import edu.umass.cs.mallet.base.types.LabelSequence;
import edu.umass.cs.mallet.base.types.LabelVector;
import edu.umass.cs.mallet.base.util.MalletLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.HashMap;
import java.util.logging.Logger;

/* loaded from: input_file:WEB-INF/lib/mallet-0.1.3.jar:edu/umass/cs/mallet/base/pipe/AddClassifierTokenPredictions.class */
/**
 * Pipe that augments every token of a {@link FeatureVectorSequence} with features derived from
 * the predictions of a per-token classifier.
 *
 * <p>Each token is turned into its own {@link Instance} (see {@link #convert(Instance, Noop)}) and
 * classified by {@link TokenClassifiers}. For every rank {@code r} listed in
 * {@code m_predRanks2add}, the label found at {@code getLabelAtRank(r)} of the resulting
 * {@link LabelVector} is added to the token's features as a feature named
 * {@code "TOK_PRED=<label>_@_RANK_<r>"} with value 1.0.
 *
 * <p>While {@code m_inProduction} is false, tokens whose generated name was held out of a
 * cross-validation fold during training are classified by that fold's classifier.
 */
public class AddClassifierTokenPredictions extends Pipe implements Serializable {
    private static final Logger logger =
        MalletLogger.getLogger(AddClassifierTokenPredictions.class.getName());

    /** Ranks (as passed to {@code LabelVector.getLabelAtRank}) whose labels become features. */
    int[] m_predRanks2add;
    /** Classifier(s) used to predict a label for each token. */
    TokenClassifiers m_tokenClassifiers;
    /** Stored from the constructor; not read anywhere in this class. */
    boolean m_binary;
    /** When false, held-out fold classifiers are used for tokens they did not train on. */
    boolean m_inProduction;
    /** Clone of the token classifier's data alphabet, extended with the prediction features. */
    Alphabet m_dataAlphabet;
    private static final long serialVersionUID = 1;

    /**
     * Wraps one classifier trained on the full token data set and, optionally, one classifier
     * per cross-validation fold.
     *
     * <p>Each fold classifier is registered in {@code m_table} under the names of the instances
     * held out of that fold's training split. {@link #classify(Instance, boolean)} can then score
     * a training token with a model that never saw it.
     */
    public static class TokenClassifiers extends Classifier implements Serializable {
        /** Number of cross-validation folds; 0 disables fold training. */
        int m_numCV;
        /** Seed passed to the cross-validation split iterator. */
        int m_randSeed;
        ClassifierTrainer m_trainer;
        /** Classifier trained on the complete data set. */
        Classifier m_tokenClassifier;
        /** Maps a held-out instance name to the fold classifier that did not train on it. */
        HashMap m_table;
        private static final long serialVersionUID = 1;
        private static final int CURRENT_SERIAL_VERSION = 1;

        /** Trains with {@link BalancedWinnowTrainer}, random seed 0 and 5 folds. */
        public TokenClassifiers(InstanceList trainList) {
            this(trainList, 0, 5);
        }

        /** Trains with {@link BalancedWinnowTrainer} and the given seed and fold count. */
        public TokenClassifiers(InstanceList trainList, int randSeed, int numCV) {
            this(new BalancedWinnowTrainer(), trainList, randSeed, numCV);
        }

        /**
         * Trains the full-data classifier and, when {@code numCV != 0}, one classifier per fold.
         *
         * @param trainer   trainer used for every model
         * @param trainList per-token training instances; its pipe becomes this classifier's pipe
         * @param randSeed  seed for the cross-validation split
         * @param numCV     number of folds, or 0 to skip cross-validation
         */
        public TokenClassifiers(ClassifierTrainer trainer, InstanceList trainList, int randSeed, int numCV) {
            super(trainList.getPipe());
            this.m_trainer = trainer;
            this.m_randSeed = randSeed;
            this.m_numCV = numCV;
            this.m_table = new HashMap();
            doTraining(trainList);
        }

        private void doTraining(InstanceList trainList) {
            AddClassifierTokenPredictions.logger.info(
                "Training token classifier on entire data set (size=" + trainList.size() + ")...");
            this.m_tokenClassifier = this.m_trainer.train(trainList);
            AddClassifierTokenPredictions.logger.info(
                "Training set accuracy = " + new Trial(this.m_tokenClassifier, trainList).accuracy());
            if (this.m_numCV == 0) {
                return;
            }
            // CrossValidationIterator is an inner class of InstanceList, so it must be created
            // through the list instance it iterates over.
            InstanceList.CrossValidationIterator cvIter =
                trainList.new CrossValidationIterator(this.m_numCV, this.m_randSeed);
            int fold = 0; // 1-based fold number for logging; incremented before first use
            while (cvIter.hasNext()) {
                fold++;
                InstanceList[] split = cvIter.nextSplit();
                InstanceList foldTrain = split[0];
                InstanceList heldOut = split[1];
                AddClassifierTokenPredictions.logger.info("Training token classifier on cv fold " + fold
                    + " / " + this.m_numCV + " (size=" + foldTrain.size() + ")...");
                Classifier foldClassifier = this.m_trainer.train(foldTrain);
                AddClassifierTokenPredictions.logger.info(
                    "Within-fold accuracy = " + new Trial(foldClassifier, foldTrain).accuracy());
                AddClassifierTokenPredictions.logger.info(
                    "Out-of-fold accuracy = " + new Trial(foldClassifier, heldOut).accuracy());
                for (int j = 0; j < heldOut.size(); j++) {
                    this.m_table.put(heldOut.getInstance(j).getName(), foldClassifier);
                }
            }
        }

        /** Classifies with the full-data classifier. */
        @Override
        public Classification classify(Instance instance) {
            return classify(instance, false);
        }

        /**
         * Classifies {@code instance}.
         *
         * @param instance             token instance to classify
         * @param useHeldOutClassifier if true and the instance's name was held out of a fold,
         *                             that fold's classifier is used; otherwise the full-data one
         */
        public Classification classify(Instance instance, boolean useHeldOutClassifier) {
            Object name = instance.getName();
            if (useHeldOutClassifier && this.m_table.containsKey(name)) {
                return ((Classifier) this.m_table.get(name)).classify(instance);
            }
            return this.m_tokenClassifier.classify(instance);
        }

        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(CURRENT_SERIAL_VERSION);
            out.writeObject(getInstancePipe());
            out.writeInt(this.m_numCV);
            out.writeInt(this.m_randSeed);
            out.writeObject(this.m_table);
            out.writeObject(this.m_tokenClassifier);
            out.writeObject(this.m_trainer);
        }

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            int version = in.readInt();
            if (version != CURRENT_SERIAL_VERSION) {
                throw new ClassNotFoundException("Mismatched TokenClassifiers versions: wanted "
                    + CURRENT_SERIAL_VERSION + ", got " + version);
            }
            this.instancePipe = (Pipe) in.readObject();
            this.m_numCV = in.readInt();
            this.m_randSeed = in.readInt();
            this.m_table = (HashMap) in.readObject();
            this.m_tokenClassifier = (Classifier) in.readObject();
            this.m_trainer = (ClassifierTrainer) in.readObject();
        }
    }

    /** Builds token classifiers from {@code trainList} without a test-set accuracy report. */
    public AddClassifierTokenPredictions(InstanceList trainList) {
        this(trainList, null);
    }

    /**
     * Builds token classifiers from {@code trainList}, adding the label at rank 1 as a feature.
     *
     * @param trainList sequence instances; its pipe is cast to {@link Noop}
     * @param testList  optional sequence instances (may be null) used only to log accuracy
     */
    public AddClassifierTokenPredictions(InstanceList trainList, InstanceList testList) {
        this(new TokenClassifiers(convert(trainList, (Noop) trainList.getPipe())), new int[]{1}, true,
             convert(testList, (Noop) trainList.getPipe()));
    }

    /**
     * @param tokenClassifiers trained per-token classifiers
     * @param predRanks2add    ranks whose predicted labels are added as features
     * @param binary           stored in {@code m_binary}
     * @param testList         optional per-token instances (may be null) used only to log accuracy
     */
    public AddClassifierTokenPredictions(TokenClassifiers tokenClassifiers, int[] predRanks2add, boolean binary, InstanceList testList) {
        this.m_predRanks2add = predRanks2add;
        this.m_binary = binary;
        this.m_tokenClassifiers = tokenClassifiers;
        this.m_inProduction = false;
        this.m_dataAlphabet = (Alphabet) tokenClassifiers.getAlphabet().clone();
        // Pre-register every (label, rank) feature so pipe() finds them in the alphabet.
        LabelAlphabet labels = tokenClassifiers.getLabelAlphabet();
        for (int r = 0; r < this.m_predRanks2add.length; r++) {
            for (int l = 0; l < labels.size(); l++) {
                this.m_dataAlphabet.lookupIndex(
                    predictionFeatureName(labels.lookupObject(l), this.m_predRanks2add[r]), true);
            }
        }
        if (testList != null) {
            logger.info("Token classifier accuracy on test set = "
                + new Trial(this.m_tokenClassifiers, testList).accuracy());
        }
    }

    /** Builds the feature name used for "label {@code label} predicted at rank {@code rank}". */
    private static String predictionFeatureName(Object label, int rank) {
        return "TOK_PRED=" + label.toString() + "_@_RANK_" + rank;
    }

    public void setInProduction(boolean inProduction) {
        this.m_inProduction = inProduction;
    }

    public boolean getInProduction() {
        return this.m_inProduction;
    }

    /**
     * Sets the production flag on {@code pipe} if it is an {@code AddClassifierTokenPredictions},
     * or recursively on every member if it is a {@link SerialPipes}. Other pipes are ignored.
     */
    public static void setInProduction(Pipe pipe, boolean inProduction) {
        if (pipe instanceof AddClassifierTokenPredictions) {
            ((AddClassifierTokenPredictions) pipe).setInProduction(inProduction);
            return;
        }
        if (pipe instanceof SerialPipes) {
            SerialPipes serialPipes = (SerialPipes) pipe;
            for (int i = 0; i < serialPipes.size(); i++) {
                setInProduction(serialPipes.getPipe(i), inProduction);
            }
        }
    }

    @Override
    public Alphabet getDataAlphabet() {
        return this.m_dataAlphabet;
    }

    /**
     * Replaces the instance's {@link FeatureVectorSequence} with one whose vectors also contain
     * the token-prediction features. The instance is modified in place and returned.
     */
    @Override
    public Instance pipe(Instance instance) {
        FeatureVectorSequence sequence = (FeatureVectorSequence) instance.getData();
        InstanceList tokens = convert(instance, (Noop) this.m_tokenClassifiers.getInstancePipe());
        assert sequence.size() == tokens.size();
        FeatureVector[] augmented = new FeatureVector[sequence.size()];
        for (int i = 0; i < tokens.size(); i++) {
            Instance token = tokens.getInstance(i);
            LabelVector predictions =
                this.m_tokenClassifiers.classify(token, !this.m_inProduction).getLabelVector();
            AugmentableFeatureVector tokenFeatures = (AugmentableFeatureVector) token.getData();
            int[] indices = tokenFeatures.getIndices();
            AugmentableFeatureVector withPredictions = new AugmentableFeatureVector(this.m_dataAlphabet,
                indices, tokenFeatures.getValues(), indices.length + this.m_predRanks2add.length);
            for (int r = 0; r < this.m_predRanks2add.length; r++) {
                int rank = this.m_predRanks2add[r];
                int featureIndex = this.m_dataAlphabet.lookupIndex(
                    predictionFeatureName(predictions.getLabelAtRank(rank), rank));
                assert featureIndex >= 0 : "prediction feature missing from alphabet";
                withPredictions.add(featureIndex, 1.0d);
            }
            augmented[i] = withPredictions;
        }
        instance.setData(new FeatureVectorSequence(augmented));
        return instance;
    }

    /**
     * Converts every sequence instance in {@code instanceList} into per-token instances.
     *
     * @return a new list piped through {@code noop}, or null if {@code instanceList} is null
     */
    public static InstanceList convert(InstanceList instanceList, Noop noop) {
        if (instanceList == null) {
            return null;
        }
        InstanceList tokens = new InstanceList(noop);
        for (int i = 0; i < instanceList.size(); i++) {
            tokens.add(convert(instanceList.getInstance(i), noop));
        }
        return tokens;
    }

    /**
     * Splits one sequence instance into one instance per token.
     *
     * <p>Each token gets an {@link AugmentableFeatureVector} over {@code noop}'s data alphabet. It is
     * named {@code "<name>_@_POS_<i+1>"}, with "NONAME" used when the instance has no name. Its
     * target is the label at that position. If the instance has no target (unlabeled input), the
     * token targets are null.
     */
    public static InstanceList convert(Instance instance, Noop noop) {
        InstanceList tokens = new InstanceList(noop);
        Object data = instance.getData();
        assert data instanceof FeatureVectorSequence;
        FeatureVectorSequence sequence = (FeatureVectorSequence) data;
        LabelSequence labels = (LabelSequence) instance.getTarget();
        assert labels == null || sequence.size() == labels.size();
        Object name = instance.getName() == null ? "NONAME" : instance.getName();
        for (int i = 0; i < sequence.size(); i++) {
            FeatureVector fv = sequence.getFeatureVector(i);
            int[] indices = fv.getIndices();
            Object label = labels == null ? null : labels.getLabelAtPosition(i);
            tokens.add(new Instance(
                new AugmentableFeatureVector(noop.getDataAlphabet(), indices, fv.getValues(), indices.length),
                label,
                name.toString() + "_@_POS_" + (i + 1),
                instance.getSource(),
                noop));
        }
        return tokens;
    }
}
