package edu.umass.cs.mallet.share.upenn;

import edu.umass.cs.mallet.base.classify.Classification;
import edu.umass.cs.mallet.base.classify.Classifier;
import edu.umass.cs.mallet.base.classify.MaxEntTrainer;
import edu.umass.cs.mallet.base.pipe.CharSequence2TokenSequence;
import edu.umass.cs.mallet.base.pipe.CharSequenceArray2TokenSequence;
import edu.umass.cs.mallet.base.pipe.FeatureSequence2FeatureVector;
import edu.umass.cs.mallet.base.pipe.Pipe;
import edu.umass.cs.mallet.base.pipe.SerialPipes;
import edu.umass.cs.mallet.base.pipe.Target2Label;
import edu.umass.cs.mallet.base.pipe.TokenSequence2FeatureSequence;
import edu.umass.cs.mallet.base.pipe.iterator.ArrayDataAndTargetIterator;
import edu.umass.cs.mallet.base.pipe.iterator.ArrayIterator;
import edu.umass.cs.mallet.base.pipe.iterator.LineIterator;
import edu.umass.cs.mallet.base.pipe.iterator.PipeExtendedIterator;
import edu.umass.cs.mallet.base.pipe.iterator.PipeInputIterator;
import edu.umass.cs.mallet.base.types.Alphabet;
import edu.umass.cs.mallet.base.types.Instance;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.types.LabelAlphabet;
import edu.umass.cs.mallet.base.types.Labeling;
import edu.umass.cs.mallet.base.types.TokenSequence;
import edu.umass.cs.mallet.base.util.CharSequenceLexer;
import edu.umass.cs.mallet.base.util.CommandOption;
import edu.umass.cs.mallet.base.util.MalletLogger;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import org.apache.xerces.impl.xs.SchemaSymbols;
import org.codehaus.groovy.tools.shell.util.ANSI;

/* loaded from: input_file:WEB-INF/lib/mallet-0.1.3.jar:edu/umass/cs/mallet/share/upenn/MaxEntShell.class */
/**
 * Static helpers and a command-line front end for training, saving, loading,
 * testing and applying a maximum-entropy classifier built with
 * {@link MaxEntTrainer}.
 *
 * <p>All functionality is exposed through static methods; the class cannot be
 * instantiated.
 */
public class MaxEntShell {
    private static final Logger logger = MalletLogger.getLogger(MaxEntShell.class.getName());

    /**
     * Line format for labeled (training/testing) files: the first run of
     * non-whitespace characters is capture group 1, the remainder of the line
     * is capture group 2.
     */
    private static final Pattern LABELED_LINE_PATTERN = Pattern.compile("^\\s*(\\S+)\\s*(.*)\\s*$");

    /** Line format for unlabeled files: the whole line is capture group 1. */
    private static final Pattern UNLABELED_LINE_PATTERN = Pattern.compile("^\\s*(.*)\\s*$");

    private static final CommandOption.Double gaussianVarianceOption = new CommandOption.Double(
            MaxEntShell.class, "gaussian-variance", "decimal", true, 1.0d,
            "The gaussian prior variance used for training.", null);

    private static final CommandOption.File trainOption = new CommandOption.File(
            MaxEntShell.class, "train", "FILENAME", true, null, "Training datafile", null);

    private static final CommandOption.File testOption = new CommandOption.File(
            MaxEntShell.class, "test", "filename", true, null, "Test datafile", null);

    private static final CommandOption.File classifyOption = new CommandOption.File(
            MaxEntShell.class, "classify", "filename", true, null, "Datafile to classify", null);

    private static final CommandOption.File modelOption = new CommandOption.File(
            MaxEntShell.class, "model", "filename", true, null, "Model file", null);

    private static final CommandOption.String encodingOption = new CommandOption.String(
            MaxEntShell.class, "encoding", "character-encoding-name", true, null,
            "Input character encoding", null);

    private static final CommandOption.Boolean internalTestOption = new CommandOption.Boolean(
            MaxEntShell.class, "internal-test", "true|false", true, false, "Run internal tests", null);

    // Must be declared after the individual options it aggregates: static
    // initializers run in textual order.
    private static final CommandOption.List commandOptions = new CommandOption.List(
            "Training, testing and running a generic tagger.",
            new CommandOption[]{gaussianVarianceOption, trainOption, testOption, modelOption,
                    classifyOption, encodingOption, internalTestOption});

    /** Feature rows of the tiny built-in data set used by {@code --internal-test}. */
    private static final String[][] internalData = {{"a", "b"}, {"b", "c"}, {"a", "c"}};

    /** Target labels for {@link #internalData}, one per row. */
    private static final String[] internalTargets = {"yes", "no", "no"};

    /** Single unlabeled instance classified by {@code --internal-test}. */
    private static final String[] internalInstance = {"a", "b", "c"};

    private MaxEntShell() {
    }

    /**
     * Trains a classifier from in-memory data.
     *
     * @param features one array of feature strings per training instance
     * @param labels the target label of each instance, parallel to {@code features}
     * @param gaussianVariance value handed to the {@link MaxEntTrainer} constructor
     * @param modelFile where to serialize the trained classifier, or {@code null} to skip saving
     * @return the trained classifier
     * @throws IOException if the model file cannot be written
     */
    public static Classifier train(String[][] features, String[] labels, double gaussianVariance, File modelFile)
            throws IOException {
        PipeInputIterator source = new PipeExtendedIterator(
                new ArrayDataAndTargetIterator(features, labels), new CharSequenceArray2TokenSequence());
        return train(source, gaussianVariance, modelFile);
    }

    /**
     * Trains a classifier from an iterator over instances.
     *
     * <p>The instances are run through a pipe that maps targets to labels and
     * token sequences to feature vectors, then handed to a
     * {@link MaxEntTrainer}. Feature, label and instance counts as well as the
     * training accuracy are logged.
     *
     * @param data iterator supplying the training instances
     * @param gaussianVariance value handed to the {@link MaxEntTrainer} constructor
     * @param modelFile where to serialize the trained classifier, or {@code null} to skip saving
     * @return the trained classifier
     * @throws IOException if the model file cannot be written
     */
    public static Classifier train(PipeInputIterator data, double gaussianVariance, File modelFile)
            throws IOException {
        Alphabet featureAlphabet = new Alphabet();
        LabelAlphabet labelAlphabet = new LabelAlphabet();
        Pipe instancePipe = new SerialPipes(new Pipe[]{
                new Target2Label(labelAlphabet),
                new TokenSequence2FeatureSequence(featureAlphabet),
                new FeatureSequence2FeatureVector()});

        InstanceList trainingList = new InstanceList(instancePipe);
        trainingList.add(data);
        logger.info("# features = " + featureAlphabet.size());
        logger.info("# labels = " + labelAlphabet.size());
        logger.info("# training instances = " + trainingList.size());

        Classifier classifier = new MaxEntTrainer(gaussianVariance).train(trainingList);
        logger.info("The training accuracy is " + classifier.getAccuracy(trainingList));

        // Freeze the feature alphabet once training is complete.
        featureAlphabet.stopGrowth();

        if (modelFile != null) {
            save(classifier, modelFile);
        }
        return classifier;
    }

    /**
     * Serializes a classifier to a file, closing the file even when writing fails.
     */
    private static void save(Classifier classifier, File modelFile) throws IOException {
        FileOutputStream fileOut = new FileOutputStream(modelFile);
        try {
            ObjectOutputStream objectOut = new ObjectOutputStream(fileOut);
            objectOut.writeObject(classifier);
            // Flushes buffered object data and closes the underlying file.
            objectOut.close();
        } finally {
            // Releases the file handle on the failure path; a second close is harmless.
            fileOut.close();
        }
    }

    /**
     * Measures classifier accuracy on in-memory labeled data.
     *
     * @param classifier the classifier to evaluate
     * @param features one array of feature strings per test instance
     * @param labels the expected label of each instance, parallel to {@code features}
     * @return the value of {@link Classifier#getAccuracy} on the resulting instance list
     */
    public static double test(Classifier classifier, String[][] features, String[] labels) {
        PipeInputIterator source = new PipeExtendedIterator(
                new ArrayDataAndTargetIterator(features, labels), new CharSequenceArray2TokenSequence());
        return test(classifier, source);
    }

    /**
     * Measures classifier accuracy on instances supplied by an iterator.
     *
     * @param classifier the classifier to evaluate; its instance pipe processes the data
     * @param data iterator supplying the labeled test instances
     * @return the value of {@link Classifier#getAccuracy} on the resulting instance list
     */
    public static double test(Classifier classifier, PipeInputIterator data) {
        InstanceList testList = new InstanceList(classifier.getInstancePipe());
        testList.add(data);
        logger.info("# test instances = " + testList.size());
        return classifier.getAccuracy(testList);
    }

    /**
     * Classifies a single instance given as an array of feature strings.
     *
     * @param classifier the classifier to apply
     * @param features the feature strings of the instance
     * @return the classification produced by the classifier
     */
    public static Classification classify(Classifier classifier, String[] features) {
        Instance instance = new Instance(
                new TokenSequence(features), null, null, null, classifier.getInstancePipe());
        return classifier.classify(instance);
    }

    /**
     * Classifies several in-memory instances.
     *
     * @param classifier the classifier to apply
     * @param features one array of feature strings per instance
     * @return one classification per instance
     */
    public static Classification[] classify(Classifier classifier, String[][] features) {
        PipeInputIterator source = new PipeExtendedIterator(
                new ArrayIterator(features), new CharSequenceArray2TokenSequence());
        return classify(classifier, source);
    }

    /**
     * Classifies instances supplied by an iterator.
     *
     * @param classifier the classifier to apply; its instance pipe processes the data
     * @param data iterator supplying the unlabeled instances
     * @return one classification per instance
     */
    public static Classification[] classify(Classifier classifier, PipeInputIterator data) {
        InstanceList unlabeledList = new InstanceList(classifier.getInstancePipe());
        unlabeledList.add(data);
        logger.info("# unlabeled instances = " + unlabeledList.size());
        return (Classification[]) classifier.classify(unlabeledList).toArray(new Classification[0]);
    }

    /**
     * Loads a classifier previously serialized by one of the {@code train} methods.
     *
     * <p>NOTE(review): this uses Java native deserialization; only load model
     * files that come from a trusted source.
     *
     * @param file the serialized model file
     * @return the deserialized classifier
     * @throws IOException if the file cannot be read
     * @throws ClassNotFoundException if a class in the serialized data is unavailable
     */
    public static Classifier load(File file) throws IOException, ClassNotFoundException {
        FileInputStream fileIn = new FileInputStream(file);
        try {
            ObjectInputStream objectIn = new ObjectInputStream(fileIn);
            return (Classifier) objectIn.readObject();
        } finally {
            // Closing the file also covers the case where the wrapper or the read fails.
            fileIn.close();
        }
    }

    /**
     * Trains on the built-in data set, prints the training accuracy, and
     * prints the label scores for the built-in test instance.
     */
    private static void internalTest() throws IOException {
        Classifier classifier = train(internalData, internalTargets, 1.0d, null);
        System.out.println("Training accuracy = " + test(classifier, internalData, internalTargets));
        printLabeling(classify(classifier, internalInstance).getLabeling());
    }

    /**
     * Prints every label of the labeling's alphabet followed by its value,
     * space-separated, on one line of standard output.
     */
    private static void printLabeling(Labeling labeling) {
        LabelAlphabet labels = labeling.getLabelAlphabet();
        for (int i = 0; i < labels.size(); i++) {
            System.out.print(labels.lookupObject(i) + " " + labeling.value(i) + " ");
        }
        System.out.println();
    }

    /**
     * Opens a file for reading text.
     *
     * @param file the file to read
     * @param encoding character encoding name, or {@code null} to use {@link FileReader}'s default
     * @return a reader over the file; the caller is responsible for closing it
     * @throws IOException if the file cannot be opened or the encoding is unsupported
     */
    private static InputStreamReader getReader(File file, String encoding) throws IOException {
        if (encoding == null) {
            return new FileReader(file);
        }
        FileInputStream in = new FileInputStream(file);
        try {
            return new InputStreamReader(in, encoding);
        } catch (IOException e) {
            // An unsupported encoding name must not leak the already-opened file.
            in.close();
            throw e;
        }
    }

    /**
     * Command-line entry point.
     *
     * <p>Processing order: optionally run the internal test; obtain a
     * classifier by training ({@code --train}, saved to {@code --model} when
     * given) or otherwise by loading {@code --model}; then, if a classifier is
     * available, report accuracy on {@code --test} and print label scores for
     * every line of {@code --classify}.
     *
     * @param args command-line arguments handled by the declared command options
     * @throws Exception on option, I/O or deserialization failures
     */
    public static void main(String[] args) throws Exception {
        CharSequence2TokenSequence tokenizer = new CharSequence2TokenSequence(
                new CharSequenceLexer(CharSequenceLexer.LEX_NONWHITESPACE_TOGETHER));
        commandOptions.process(args);

        if (internalTestOption.value) {
            internalTest();
        }

        Classifier classifier = null;
        if (trainOption.value != null) {
            InputStreamReader trainReader = getReader(trainOption.value, encodingOption.value);
            try {
                classifier = train(
                        new PipeExtendedIterator(
                                new LineIterator(trainReader, LABELED_LINE_PATTERN, 2, 1, -1), tokenizer),
                        gaussianVarianceOption.value, modelOption.value);
            } finally {
                trainReader.close();
            }
        } else if (modelOption.value != null) {
            classifier = load(modelOption.value);
        }

        if (classifier == null) {
            return;
        }

        if (testOption.value != null) {
            InputStreamReader testReader = getReader(testOption.value, encodingOption.value);
            try {
                double accuracy = test(classifier, new PipeExtendedIterator(
                        new LineIterator(testReader, LABELED_LINE_PATTERN, 2, 1, -1), tokenizer));
                System.out.println("The testing accuracy is " + accuracy);
            } finally {
                testReader.close();
            }
        }

        if (classifyOption.value != null) {
            // Input to classify carries no targets, so target processing is switched off.
            classifier.getInstancePipe().setTargetProcessing(false);
            InputStreamReader classifyReader = getReader(classifyOption.value, encodingOption.value);
            Classification[] results;
            try {
                results = classify(classifier, new PipeExtendedIterator(
                        new LineIterator(classifyReader, UNLABELED_LINE_PATTERN, 1, -1, -1), tokenizer));
            } finally {
                classifyReader.close();
            }
            for (Classification result : results) {
                printLabeling(result.getLabeling());
            }
        }
    }
}
