package edu.umass.cs.mallet.grmm.learning.extract;

import edu.umass.cs.mallet.base.extract.Extraction;
import edu.umass.cs.mallet.base.extract.TokenizationFilter;
import edu.umass.cs.mallet.base.fst.Transducer;
import edu.umass.cs.mallet.base.pipe.Pipe;
import edu.umass.cs.mallet.base.pipe.PipeUtils;
import edu.umass.cs.mallet.base.pipe.iterator.PipeInputIterator;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.util.CollectionUtils;
import edu.umass.cs.mallet.base.util.FileUtils;
import edu.umass.cs.mallet.base.util.MalletLogger;
import edu.umass.cs.mallet.base.util.Timing;
import edu.umass.cs.mallet.grmm.inference.Inferencer;
import edu.umass.cs.mallet.grmm.learning.ACRF;
import edu.umass.cs.mallet.grmm.learning.ACRFEvaluator;
import edu.umass.cs.mallet.grmm.learning.ACRFTrainer;
import edu.umass.cs.mallet.grmm.learning.AcrfSerialEvaluator;
import edu.umass.cs.mallet.grmm.util.PipedIterator;
import edu.umass.cs.mallet.grmm.util.RememberTokenizationPipe;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.logging.Logger;
import org.apache.tools.ant.taskdefs.optional.ejb.GenericDeploymentTool;
import org.codehaus.groovy.tools.shell.util.ANSI;

/* loaded from: input_file:WEB-INF/lib/mallet-0.1.3.jar:edu/umass/cs/mallet/grmm/learning/extract/ACRFExtractorTrainer.class */
public class ACRFExtractorTrainer {
    private static final Logger logger;

    /** Templates used to build the ACRF (a growing prefix of them in per-template mode). */
    protected ACRF.Template[] tmpls;
    /** Training instances; loaded lazily from {@link #trainIterator} if not set directly. */
    protected InstanceList training;
    /** Testing instances; may remain null when no test source is configured. */
    protected InstanceList testing;
    private PipeInputIterator testIterator;
    private PipeInputIterator trainIterator;
    /** Feature pipe; always prefixed with a {@link RememberTokenizationPipe} by {@link #setPipes}. */
    protected Pipe featurePipe;
    /** Tokenization pipe applied to raw input before the feature pipe. */
    protected Pipe tokPipe;
    TokenizationFilter filter;
    private Inferencer inferencer;
    private Inferencer viterbiInferencer;
    private boolean cacheUnrolledGraphs;
    /** Random source used to draw data subsets; only set via {@link #setDataSubsets}. */
    private Random r;
    /** Synthetic class-literal cache emitted by old javac; kept for compatibility. */
    static Class class$edu$umass$cs$mallet$grmm$learning$extract$ACRFExtractorTrainer;
    private int numIter = 99999;
    ACRFTrainer trainer = new ACRFTrainer();
    protected ACRFEvaluator evaluator = new ACRFTrainer.LogEvaluator();
    /** Checkpoint every this many iterations; values <= 0 disable checkpointing. */
    private int numCheckpointIterations = -1;
    private File checkpointDirectory = null;
    private boolean usePerTemplateTrain = false;
    private int perTemplateIterations = 100;
    /** Fraction of training data to keep; values <= 0 keep everything. */
    private double trainingPct = -1.0d;
    /** Fraction of testing data to keep; values <= 0 keep everything. */
    private double testingPct = -1.0d;

    /**
     * Evaluator that serializes an {@link ACRFExtractor} snapshot of the model being trained
     * every {@code interval} iterations. It never requests that training stop.
     */
    public static class CheckpointingEvaluator extends ACRFEvaluator {
        private final File directory;
        private final int interval;
        private final Pipe tokPipe;
        private final Pipe featurePipe;

        /**
         * @param file directory to write checkpoints into
         * @param i checkpoint interval in iterations; values <= 0 disable checkpointing
         * @param pipe tokenization pipe stored in each checkpointed extractor
         * @param pipe2 feature pipe stored in each checkpointed extractor
         */
        public CheckpointingEvaluator(File file, int i, Pipe pipe, Pipe pipe2) {
            this.directory = file;
            this.interval = i;
            this.tokPipe = pipe;
            this.featurePipe = pipe2;
        }

        /**
         * Writes {@code extor.<iter>.ser.gz} (gzipped serialized extractor) into the
         * checkpoint directory when {@code iter} is a positive multiple of the interval.
         *
         * @return always {@code true}
         */
        @Override // edu.umass.cs.mallet.grmm.learning.ACRFEvaluator
        public boolean evaluate(ACRF acrf, int iter, InstanceList trainingList, InstanceList validationList, InstanceList testingList) {
            // Guard the modulus: a non-positive interval would otherwise throw ArithmeticException.
            if (interval <= 0 || iter <= 0 || iter % interval != 0) {
                return true;
            }
            File checkpoint = new File(directory, "extor." + iter + ".ser.gz");
            FileUtils.writeGzippedObject(checkpoint, new ACRFExtractor(acrf, tokPipe, featurePipe));
            return true;
        }

        /** No-op: checkpointing performs no test-set evaluation. */
        @Override // edu.umass.cs.mallet.grmm.learning.ACRFEvaluator
        public void test(InstanceList instanceList, List list, String str) {
        }
    }

    /** Sets the ACRF templates to train with. */
    public ACRFExtractorTrainer setTemplates(ACRF.Template[] templateArr) {
        this.tmpls = templateArr;
        return this;
    }

    /**
     * Sets raw data sources that are piped through the tokenization pipe on first use.
     *
     * @param trainSource iterator over raw training data (required before training)
     * @param testSource iterator over raw testing data, or null for no test set
     */
    public ACRFExtractorTrainer setDataSource(PipeInputIterator trainSource, PipeInputIterator testSource) {
        this.trainIterator = trainSource;
        this.testIterator = testSource;
        return this;
    }

    /** Sets already-piped training and testing instances, bypassing lazy loading. */
    public ACRFExtractorTrainer setData(InstanceList trainingList, InstanceList testingList) {
        this.training = trainingList;
        this.testing = testingList;
        return this;
    }

    /** Sets the maximum number of iterations for the (final) training run. */
    public ACRFExtractorTrainer setNumIterations(int iterations) {
        this.numIter = iterations;
        return this;
    }

    public int getNumIter() {
        return this.numIter;
    }

    /**
     * Sets the pipes used for data loading and for the produced extractor.
     * A {@link RememberTokenizationPipe} is prepended to the feature pipe.
     *
     * @param tokenizationPipe pipe applied to raw input first
     * @param features pipe that extracts features from the tokenized input
     */
    public ACRFExtractorTrainer setPipes(Pipe tokenizationPipe, Pipe features) {
        this.featurePipe = PipeUtils.concatenatePipes(new RememberTokenizationPipe(), features);
        this.tokPipe = tokenizationPipe;
        return this;
    }

    /** Sets the evaluator invoked during training (wrapped with checkpointing if enabled). */
    public ACRFExtractorTrainer setEvaluator(ACRFEvaluator aCRFEvaluator) {
        this.evaluator = aCRFEvaluator;
        return this;
    }

    /** Sets the optimizer used to train the ACRF. */
    public ACRFExtractorTrainer setTrainingMethod(ACRFTrainer aCRFTrainer) {
        this.trainer = aCRFTrainer;
        return this;
    }

    /**
     * Misspelled legacy name; kept for backward compatibility.
     *
     * @deprecated use {@link #setTokenizationFilter(TokenizationFilter)}
     */
    public ACRFExtractorTrainer setTokenizatioFilter(TokenizationFilter tokenizationFilter) {
        this.filter = tokenizationFilter;
        return this;
    }

    /** Sets a filter to install on the produced extractor; null leaves the extractor's default. */
    public ACRFExtractorTrainer setTokenizationFilter(TokenizationFilter tokenizationFilter) {
        return setTokenizatioFilter(tokenizationFilter);
    }

    /** If true, each trained ACRF is asked to cache unrolled graphs. */
    public ACRFExtractorTrainer setCacheUnrolledGraphs(boolean cache) {
        this.cacheUnrolledGraphs = cache;
        return this;
    }

    /** Sets the checkpoint interval in iterations; values <= 0 disable checkpointing. */
    public ACRFExtractorTrainer setNumCheckpointIterations(int iterations) {
        this.numCheckpointIterations = iterations;
        return this;
    }

    /** Sets the directory that checkpoint files are written into. */
    public ACRFExtractorTrainer setCheckpointDirectory(File directory) {
        this.checkpointDirectory = directory;
        return this;
    }

    /** If true, train by adding templates one at a time (see {@link #perTemplateTrain()}). */
    public ACRFExtractorTrainer setUsePerTemplateTrain(boolean usePerTemplate) {
        this.usePerTemplateTrain = usePerTemplate;
        return this;
    }

    /** Sets the iteration budget for each incremental round of per-template training. */
    public ACRFExtractorTrainer setPerTemplateIterations(int iterations) {
        this.perTemplateIterations = iterations;
        return this;
    }

    public ACRFTrainer getTrainer() {
        return this.trainer;
    }

    public TokenizationFilter getFilter() {
        return this.filter;
    }

    /**
     * Trains an ACRF (per-template or all at once, per configuration) and wraps it with the
     * configured pipes and optional tokenization filter.
     */
    public ACRFExtractor trainExtractor() {
        ACRF acrf = usePerTemplateTrain ? perTemplateTrain() : trainAcrf();
        ACRFExtractor extractor = new ACRFExtractor(acrf, tokPipe, featurePipe);
        if (filter != null) {
            extractor.setTokenizationFilter(filter);
        }
        return extractor;
    }

    /**
     * Trains incrementally. Round k trains a fresh ACRF on templates [0..k] for
     * {@code perTemplateIterations}. Afterwards, the last model is trained for up to
     * {@code numIter} more iterations unless the final round's train call returned true.
     *
     * @throws IllegalStateException if no templates are configured
     */
    private ACRF perTemplateTrain() {
        if (tmpls == null || tmpls.length == 0) {
            throw new IllegalStateException("perTemplateTrain requires at least one template; call setTemplates() first");
        }
        Timing timing = new Timing();
        if (training == null) {
            setupData();
        }
        // NOTE(review): the boolean from ACRFTrainer.train is presumably "converged" — confirm.
        boolean finished = false;
        ACRF acrf = null;
        for (int round = 0; round < tmpls.length; round++) {
            ACRF.Template[] prefix = new ACRF.Template[round + 1];
            System.arraycopy(tmpls, 0, prefix, 0, prefix.length);
            logger.info("***PerTemplateTrain: Round " + round + "\n  Templates: "
                    + CollectionUtils.dumpToString(Arrays.asList(prefix), " "));
            acrf = new ACRF(featurePipe, prefix);
            setupAcrf(acrf);
            finished = trainer.train(acrf, training, null, testing, setupEvaluator("tmpl" + round), perTemplateIterations);
            timing.tick("PerTemplateTrain round " + round);
        }
        ACRFEvaluator finalEvaluator = setupEvaluator("full");
        if (!finished) {
            trainer.train(acrf, training, null, testing, finalEvaluator, numIter);
        }
        return acrf;
    }

    /** Trains a single ACRF over all templates for up to {@code numIter} iterations. */
    public ACRF trainAcrf() {
        if (training == null) {
            setupData();
        }
        ACRF acrf = new ACRF(featurePipe, tmpls);
        setupAcrf(acrf);
        trainer.train(acrf, training, null, testing, setupEvaluator(""), numIter);
        return acrf;
    }

    /** Applies the configured caching flag and inferencers to a freshly built ACRF. */
    private void setupAcrf(ACRF acrf) {
        if (cacheUnrolledGraphs) {
            acrf.setCacheUnrolledGraphs(true);
        }
        if (inferencer != null) {
            acrf.setInferencer(inferencer);
        }
        if (viterbiInferencer != null) {
            acrf.setViterbiInferencer(viterbiInferencer);
        }
    }

    /**
     * Returns the configured evaluator, chained with a {@link CheckpointingEvaluator} when
     * checkpointing is enabled.
     *
     * @param label round label; currently unused
     */
    private ACRFEvaluator setupEvaluator(String label) {
        if (numCheckpointIterations <= 0) {
            return evaluator;
        }
        List<ACRFEvaluator> chain = new ArrayList<ACRFEvaluator>();
        chain.add(evaluator);
        chain.add(new CheckpointingEvaluator(checkpointDirectory, numCheckpointIterations, tokPipe, featurePipe));
        return new AcrfSerialEvaluator(chain);
    }

    /**
     * Builds {@link #training} (and {@link #testing} when a test source exists) by piping the
     * raw sources through the tokenization pipe into feature-pipe instance lists. Each list
     * is then optionally subsetted to its configured fraction.
     *
     * @throws IllegalStateException if no training data source has been configured
     */
    protected void setupData() {
        if (trainIterator == null) {
            throw new IllegalStateException("No training data source: call setData() or setDataSource() first");
        }
        Timing timing = new Timing();
        training = new InstanceList(featurePipe);
        training.add(new PipedIterator(trainIterator, tokPipe));
        if (trainingPct > 0.0) {
            training = subsetData(training, trainingPct);
        }
        if (testIterator != null) {
            testing = new InstanceList(featurePipe);
            testing.add(new PipedIterator(testIterator, tokPipe));
            if (testingPct > 0.0) {
                // Bug fix: previously used trainingPct here, ignoring the test fraction.
                testing = subsetData(testing, testingPct);
            }
        }
        timing.tick("Data loading");
    }

    /** Returns a random split of {@code list} containing roughly {@code fraction} of it. */
    private InstanceList subsetData(InstanceList list, double fraction) {
        return list.split(r, new double[]{fraction, 1.0d - fraction})[0];
    }

    /** Returns the training instances, loading them on first access. */
    public InstanceList getTrainingData() {
        if (training == null) {
            setupData();
        }
        return training;
    }

    /**
     * Returns the testing instances, loading data on first access. Returns null when no test
     * source was configured; data is not reloaded on each call.
     */
    public InstanceList getTestingData() {
        if (testing == null && training == null) {
            setupData();
        }
        return testing;
    }

    /** Runs {@code extractor} over the testing data, loading it first if necessary. */
    public Extraction extractOnTestData(ACRFExtractor extractor) {
        return extractor.extract(getTestingData());
    }

    public ACRFExtractorTrainer setInferencer(Inferencer inferencer) {
        this.inferencer = inferencer;
        return this;
    }

    public ACRFExtractorTrainer setViterbiInferencer(Inferencer inferencer) {
        this.viterbiInferencer = inferencer;
        return this;
    }

    /**
     * Requests that loaded data be randomly subsetted.
     *
     * @param random source of randomness for the split
     * @param trainFraction fraction of training data to keep (<= 0 keeps all)
     * @param testFraction fraction of testing data to keep (<= 0 keeps all)
     */
    public ACRFExtractorTrainer setDataSubsets(Random random, double trainFraction, double testFraction) {
        this.r = random;
        this.trainingPct = trainFraction;
        this.testingPct = testFraction;
        return this;
    }

    /** Legacy javac helper for class literals; the cause is attached, not thrown. */
    static Class class$(String str) {
        try {
            return Class.forName(str);
        } catch (ClassNotFoundException e) {
            NoClassDefFoundError error = new NoClassDefFoundError(e.getMessage());
            error.initCause(e);
            throw error;
        }
    }

    static {
        class$edu$umass$cs$mallet$grmm$learning$extract$ACRFExtractorTrainer = ACRFExtractorTrainer.class;
        logger = MalletLogger.getLogger(ACRFExtractorTrainer.class.getName());
    }
}
