package ws.palladian.classification.text;

import java.io.IOException;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import opennlp.tools.doccat.DoccatModel;
import opennlp.tools.doccat.DocumentCategorizerME;
import opennlp.tools.doccat.DocumentSample;
import opennlp.tools.doccat.FeatureGenerator;
import opennlp.tools.util.ObjectStream;
import org.apache.commons.lang3.Validate;
import ws.palladian.core.AbstractLearner;
import ws.palladian.core.CategoryEntries;
import ws.palladian.core.CategoryEntriesBuilder;
import ws.palladian.core.Classifier;
import ws.palladian.core.FeatureVector;
import ws.palladian.core.Instance;
import ws.palladian.core.Model;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.core.value.TextValue;
import ws.palladian.helper.constants.Language;

/* loaded from: input_file:ws/palladian/classification/text/OpenNlpTextClassifier.class */
public final class OpenNlpTextClassifier extends AbstractLearner<OpenNlpTextClassifierModel> implements Classifier<OpenNlpTextClassifierModel> {
    // Third argument handed to DocumentCategorizerME.train(...) when training (presumably the feature cutoff, per the name).
    private static final int DEFAULT_CUTOFF = 5;
    // Fourth argument handed to DocumentCategorizerME.train(...) when training (presumably the iteration count, per the name).
    private static final int DEFAULT_ITERATIONS = 100;
    // Language of the texts; its ISO 639-1 code is passed to the OpenNLP trainer. Validated as non-null by the constructor.
    private final Language language;
    // Feature generator used both for training and for classification. Validated as non-null by the constructor.
    private final FeatureGenerator featureGenerator;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:ws/palladian/classification/text/OpenNlpTextClassifier$InstanceObjectStream.class */
    public static final class InstanceObjectStream implements ObjectStream<DocumentSample> {
        private final Iterator<? extends Instance> instances;

        private InstanceObjectStream(Iterator<? extends Instance> it) {
            this.instances = it;
        }

        /* renamed from: read, reason: merged with bridge method [inline-methods] */
        public DocumentSample m16read() throws IOException {
            if (!this.instances.hasNext()) {
                return null;
            }
            Instance next = this.instances.next();
            return new DocumentSample(next.getCategory(), ((TextValue) next.getVector().get("text")).getText());
        }

        public void reset() throws IOException, UnsupportedOperationException {
        }

        public void close() throws IOException {
        }
    }

    /* loaded from: input_file:ws/palladian/classification/text/OpenNlpTextClassifier$OpenNlpTextClassifierModel.class */
    /**
     * Model produced by training: wraps the OpenNLP {@link DoccatModel} together with the class name
     * of the {@link FeatureGenerator} that was used to train it.
     */
    public static final class OpenNlpTextClassifierModel implements Model {
        private static final long serialVersionUID = 1;
        private final DoccatModel doccatModel;
        private final String featureGeneratorName;

        /**
         * @param doccatModel the trained OpenNLP model
         * @param str fully qualified class name of the feature generator used for training
         */
        OpenNlpTextClassifierModel(DoccatModel doccatModel, String str) {
            this.doccatModel = doccatModel;
            this.featureGeneratorName = str;
        }

        /**
         * Collects the category names known to the wrapped model.
         *
         * @return an unmodifiable set of category names
         */
        public Set<String> getCategories() {
            Set<String> categories = new HashSet<String>();
            DocumentCategorizerME categorizer = new DocumentCategorizerME(this.doccatModel);
            int count = categorizer.getNumberOfCategories();
            for (int i = 0; i < count; i++) {
                categories.add(categorizer.getCategory(i));
            }
            return Collections.unmodifiableSet(categories);
        }

        /**
         * Reflectively creates a fresh instance of the feature generator whose class name is stored
         * in this model, using its no-argument constructor.
         *
         * @return a new feature generator instance
         * @throws IllegalStateException if the class cannot be found or instantiated; the underlying
         *         reflection failure is attached as the cause
         */
        public FeatureGenerator getFeatureGenerator() {
            try {
                return (FeatureGenerator) Class.forName(this.featureGeneratorName).newInstance();
            } catch (ClassNotFoundException | IllegalAccessException | InstantiationException e) {
                // Keep the original failure as the cause so the actual reason is not lost.
                throw new IllegalStateException("Could not instantiate \"" + this.featureGeneratorName + "\".", e);
            }
        }
    }

    /**
     * Creates a classifier for the given language using the given feature generator.
     * Both arguments are checked with {@code Validate.notNull}, language first.
     *
     * @param language the language of the texts; must not be null
     * @param featureGenerator the feature generator for training and classification; must not be null
     */
    public OpenNlpTextClassifier(Language language, FeatureGenerator featureGenerator) {
        Validate.notNull(language, "language must not be null");
        Validate.notNull(featureGenerator, "featureGenerator must not be null");
        this.featureGenerator = featureGenerator;
        this.language = language;
    }

    /**
     * Classifies the text stored under the {@code "text"} key of the feature vector with the given model.
     * One entry is produced per score returned by the OpenNLP categorizer, keyed by the category at
     * the same index.
     *
     * @param featureVector vector whose {@code "text"} value is cast to {@link TextValue}
     * @param openNlpTextClassifierModel the trained model to classify with
     * @return the category entries built from the categorizer's scores
     */
    public CategoryEntries classify(FeatureVector featureVector, OpenNlpTextClassifierModel openNlpTextClassifierModel) {
        TextValue textValue = (TextValue) featureVector.get("text");
        String content = textValue.getText();
        // NOTE(review): this uses the classifier's own feature generator, not the one named in the model — confirm intended.
        FeatureGenerator[] generators = {this.featureGenerator};
        DocumentCategorizerME categorizer = new DocumentCategorizerME(openNlpTextClassifierModel.doccatModel, generators);
        double[] scores = categorizer.categorize(content);
        CategoryEntriesBuilder builder = new CategoryEntriesBuilder();
        int index = 0;
        for (double score : scores) {
            builder.set(categorizer.getCategory(index), score);
            index++;
        }
        return builder.create();
    }

    /**
     * Trains an OpenNLP document categorizer on the given dataset, using this classifier's language
     * code, {@code DEFAULT_CUTOFF}, {@code DEFAULT_ITERATIONS} and feature generator.
     *
     * <p>The decompiler had renamed this method to {@code m15train} (bridge-method merge), which left
     * the superclass {@code train} method unimplemented; the proper name is restored here.
     *
     * @param dataset the training instances; each is read through {@link InstanceObjectStream}
     * @return a model wrapping the trained {@link DoccatModel} and the feature generator's class name
     * @throws IllegalStateException if OpenNLP reports an {@link IOException} during training
     */
    @Override
    public OpenNlpTextClassifierModel train(Dataset dataset) {
        // The constructor rejects a null language, so the empty-string fallback is purely defensive.
        String languageCode = this.language != null ? this.language.getIso6391() : "";
        FeatureGenerator[] generators = {this.featureGenerator};
        try {
            DoccatModel doccatModel = DocumentCategorizerME.train(languageCode, new InstanceObjectStream(dataset.iterator()), DEFAULT_CUTOFF, DEFAULT_ITERATIONS, generators);
            return new OpenNlpTextClassifierModel(doccatModel, this.featureGenerator.getClass().getName());
        } catch (IOException e) {
            throw new IllegalStateException("Encountered IOException during training", e);
        }
    }

    /**
     * Alias kept only for code compiled against the decompiler-generated name.
     *
     * @deprecated use {@link #train(Dataset)}
     */
    @Deprecated
    public OpenNlpTextClassifierModel m15train(Dataset dataset) {
        return train(dataset);
    }
}
