package ws.palladian.classification.text;

import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.function.Function;
import org.apache.commons.lang3.Validate;
import ws.palladian.classification.text.DictionaryTrieModel;
import ws.palladian.core.AbstractLearner;
import ws.palladian.core.Category;
import ws.palladian.core.CategoryEntries;
import ws.palladian.core.CategoryEntriesBuilder;
import ws.palladian.core.Classifier;
import ws.palladian.core.FeatureVector;
import ws.palladian.core.Instance;
import ws.palladian.core.InstanceBuilder;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.core.value.TextValue;
import ws.palladian.helper.ProgressMonitor;
import ws.palladian.helper.collection.Bag;

/* loaded from: input_file:ws/palladian/classification/text/PalladianTextClassifier.class */
/**
 * Dictionary-based text classifier.
 *
 * <p>Training feeds the preprocessed terms of every instance into a {@link DictionaryBuilder};
 * classification looks up each term of the input text in the resulting {@link DictionaryModel}
 * and combines per-term, per-category scores computed by a pluggable {@link Scorer}.
 *
 * <p>NOTE(review): the same {@code dictionaryBuilder} instance is reused by every call to
 * {@link #train(Dataset)}; whether documents accumulate across calls depends on the builder's
 * {@code create()} implementation, which is not visible here.
 */
public class PalladianTextClassifier extends AbstractLearner<DictionaryModel> implements Classifier<DictionaryModel> {
    /** Key under which the text is looked up in a feature vector. */
    public static final String VECTOR_TEXT_IDENTIFIER = "text";
    /** Scorer used when no explicit scorer is supplied to a constructor. */
    public static final Scorer DEFAULT_SCORER = new DefaultScorer();
    private final DictionaryBuilder dictionaryBuilder;
    private final FeatureSetting featureSetting;
    private final Scorer scorer;
    private final Function<String, Iterator<String>> preprocessor;

    /**
     * Default scoring strategy: a term contributes the squared ratio of its count in a category
     * to its total count in the dictionary; categories without a term match are not scored.
     */
    public static class DefaultScorer implements Scorer {
        /**
         * @return {@code (termCategoryCount / dictCount)^2}, or {@code 0} when {@code dictCount}
         *         is zero
         */
        @Override
        public double score(String term, String category, int termCategoryCount, int dictCount, int docCount,
                int categorySum, int numUniqTerms, int numDocs, int numTerms) {
            if (dictCount == 0) {
                // Term unknown to the dictionary; also guards against division by zero.
                return 0.0d;
            }
            // Cast before dividing: int / int would truncate the ratio to 0 or 1.
            double termCategoryProbability = (double) termCategoryCount / dictCount;
            return termCategoryProbability * termCategoryProbability;
        }

        /**
         * @return {@code summedTermScore} if any term matched, otherwise {@code categoryProbability}
         */
        @Override
        public double scoreCategory(String category, double summedTermScore, double categoryProbability,
                boolean matched) {
            return matched ? summedTermScore : categoryProbability;
        }

        /** @return always {@code false}; only categories containing a term are scored */
        @Override
        public boolean scoreNonMatches() {
            return false;
        }

        @Override
        public String toString() {
            return getClass().getSimpleName();
        }
    }

    /** Strategy for turning dictionary statistics into category scores. */
    public interface Scorer {
        /**
         * Scores a single term for a single category.
         *
         * @param term              the term from the text being classified
         * @param category          name of the category being scored
         * @param termCategoryCount count of the term in the category ({@code 0} for non-matches)
         * @param dictCount         total count of the term across all categories in the dictionary
         * @param docCount          number of occurrences of the term in the text being classified
         * @param categorySum       count recorded for the category in the model's term counts
         * @param numUniqTerms      value of {@code DictionaryModel.getNumUniqTerms()}
         * @param numDocs           value of {@code DictionaryModel.getNumDocuments()}
         * @param numTerms          value of {@code DictionaryModel.getNumTerms()}
         * @return the score that is added to the category's running total
         */
        double score(String term, String category, int termCategoryCount, int dictCount, int docCount,
                int categorySum, int numUniqTerms, int numDocs, int numTerms);

        /**
         * Produces the final score of a category once all terms have been processed.
         *
         * @param category            name of the category
         * @param summedTermScore     sum of the term scores accumulated for the category
         * @param categoryProbability probability of the category taken from the model's document counts
         * @param matched             {@code true} if the total accumulated score over all categories is non-zero
         * @return the final score stored for the category
         */
        double scoreCategory(String category, double summedTermScore, double categoryProbability, boolean matched);

        /**
         * @return {@code true} if {@link #score} should also be invoked (with a term-category count
         *         of zero) for categories in which a term does not occur
         */
        boolean scoreNonMatches();
    }

    /**
     * Creates a classifier using {@link #DEFAULT_SCORER} and a {@link DictionaryTrieModel.Builder}.
     *
     * @param featureSetting the feature setting, not {@code null}
     */
    public PalladianTextClassifier(FeatureSetting featureSetting) {
        this(featureSetting, DEFAULT_SCORER);
    }

    /**
     * Creates a classifier using {@link #DEFAULT_SCORER} and the given dictionary builder.
     *
     * @param featureSetting    the feature setting, not {@code null}; it is also set on the builder
     * @param dictionaryBuilder the builder collecting the training documents, not {@code null}
     */
    public PalladianTextClassifier(FeatureSetting featureSetting, DictionaryBuilder dictionaryBuilder) {
        Validate.notNull(dictionaryBuilder, "dictionaryBuilder must not be null");
        Validate.notNull(featureSetting, "featureSetting must not be null");
        this.dictionaryBuilder = dictionaryBuilder;
        this.dictionaryBuilder.setFeatureSetting(featureSetting);
        this.featureSetting = featureSetting;
        this.scorer = DEFAULT_SCORER;
        this.preprocessor = new Preprocessor(featureSetting);
    }

    /**
     * Creates a classifier using the given scorer and a {@link DictionaryTrieModel.Builder}.
     *
     * @param featureSetting the feature setting, not {@code null}
     * @param scorer         the scoring strategy, not {@code null}
     */
    public PalladianTextClassifier(FeatureSetting featureSetting, Scorer scorer) {
        Validate.notNull(featureSetting, "featureSetting must not be null");
        Validate.notNull(scorer, "scorer must not be null");
        this.dictionaryBuilder = new DictionaryTrieModel.Builder();
        this.dictionaryBuilder.setFeatureSetting(featureSetting);
        this.featureSetting = featureSetting;
        this.scorer = scorer;
        this.preprocessor = new Preprocessor(featureSetting);
    }

    /**
     * Adds every instance of the dataset to the dictionary builder and creates the model.
     *
     * <p>For each instance, the value stored under {@link #VECTOR_TEXT_IDENTIFIER} is cast to
     * {@link TextValue}, preprocessed into terms, and the distinct terms (collection stops once
     * {@code featureSetting.getMaxTerms()} distinct terms are reached) are added together with
     * the instance's category and weight.
     *
     * @param dataset the training data, not {@code null}
     * @return the model created by the dictionary builder
     */
    @Override
    public DictionaryModel train(Dataset dataset) {
        Validate.notNull(dataset, "dataset must not be null");
        ProgressMonitor progressMonitor = new ProgressMonitor(dataset.size(), 0.1d, "Training text classifier");
        // NOTE(review): `iterator2` looks like a decompiler-generated alias of `iterator()` — confirm against Dataset.
        Iterator<Instance> instances = dataset.iterator2();
        while (instances.hasNext()) {
            Instance instance = instances.next();
            String category = instance.getCategory();
            TextValue textValue = (TextValue) instance.getVector().get(VECTOR_TEXT_IDENTIFIER);
            Iterator<String> terms = this.preprocessor.apply(textValue.getText());
            // A set, because each distinct term is handed to the builder once per document.
            HashSet<String> uniqueTerms = new HashSet<>();
            while (terms.hasNext() && uniqueTerms.size() < this.featureSetting.getMaxTerms()) {
                uniqueTerms.add(terms.next());
            }
            this.dictionaryBuilder.addDocument(uniqueTerms, category, instance.getWeight());
            progressMonitor.incrementAndPrintProgress();
        }
        return (DictionaryModel) this.dictionaryBuilder.create();
    }

    /**
     * Classifies the text stored under {@link #VECTOR_TEXT_IDENTIFIER} in the feature vector.
     *
     * <p>Every distinct term of the text is scored against each category it occurs in (and, if
     * the scorer requests it, against all remaining categories with a term-category count of
     * zero). Afterwards each category listed in the model's document counts receives its final
     * score from {@link Scorer#scoreCategory}.
     *
     * @param featureVector   vector holding the text, not {@code null}
     * @param dictionaryModel the trained model, not {@code null}
     * @return the scored categories
     */
    @Override
    public CategoryEntries classify(FeatureVector featureVector, DictionaryModel dictionaryModel) {
        Validate.notNull(featureVector, "featureVector must not be null");
        Validate.notNull(dictionaryModel, "model must not be null");
        CategoryEntriesBuilder resultBuilder = new CategoryEntriesBuilder();

        // Count term occurrences, stopping once the configured number of distinct terms is reached.
        TextValue textValue = (TextValue) featureVector.get(VECTOR_TEXT_IDENTIFIER);
        Iterator<String> terms = this.preprocessor.apply(textValue.getText());
        Bag<String> termBag = new Bag<>();
        while (terms.hasNext() && termBag.uniqueItems().size() < this.featureSetting.getMaxTerms()) {
            termBag.add(terms.next());
        }

        CategoryEntries categoryTermCounts = dictionaryModel.getTermCounts();
        int numUniqTerms = dictionaryModel.getNumUniqTerms();
        int numDocs = dictionaryModel.getNumDocuments();
        int numTerms = dictionaryModel.getNumTerms();
        boolean scoreNonMatches = this.scorer.scoreNonMatches();
        // Categories that matched the current term; only maintained when non-matches are scored.
        HashSet<String> matchedCategories = new HashSet<>();

        for (Map.Entry<String, Integer> termEntry : termBag.unique()) {
            String term = termEntry.getKey();
            int docCount = termEntry.getValue();
            CategoryEntries termCategories = dictionaryModel.getCategoryEntries(term);
            int dictCount = termCategories.getTotalCount();

            for (Category matched : termCategories) {
                String categoryName = matched.getName();
                double termScore = this.scorer.score(term, categoryName, matched.getCount(), dictCount, docCount,
                        categoryTermCounts.getCount(categoryName), numUniqTerms, numDocs, numTerms);
                resultBuilder.add(categoryName, termScore);
                if (scoreNonMatches) {
                    matchedCategories.add(categoryName);
                }
            }

            if (scoreNonMatches) {
                // Score the categories in which the term does not occur, with a count of zero.
                for (Category candidate : categoryTermCounts) {
                    String categoryName = candidate.getName();
                    if (!matchedCategories.contains(categoryName)) {
                        double termScore = this.scorer.score(term, categoryName, 0, dictCount, docCount,
                                candidate.getCount(), numUniqTerms, numDocs, numTerms);
                        resultBuilder.add(categoryName, termScore);
                    }
                }
                matchedCategories.clear();
            }
        }

        // Evaluated once, before the per-category scores are overwritten below.
        boolean anyMatch = resultBuilder.getTotalScore() != 0.0d;
        for (Category documentCategory : dictionaryModel.getDocumentCounts()) {
            String categoryName = documentCategory.getName();
            double finalScore = this.scorer.scoreCategory(categoryName, resultBuilder.getScore(categoryName),
                    documentCategory.getProbability(), anyMatch);
            resultBuilder.set(categoryName, finalScore);
        }
        // NOTE(review): `m76create` looks like a decompiler-generated alias of `create()` — confirm
        // against CategoryEntriesBuilder.
        return resultBuilder.m76create();
    }

    /**
     * Convenience overload that wraps the text in an instance and classifies it.
     *
     * @param text            the text to classify, not {@code null}
     * @param dictionaryModel the trained model, not {@code null}
     * @return the scored categories
     */
    public CategoryEntries classify(String text, DictionaryModel dictionaryModel) {
        Validate.notNull(text, "text must not be null");
        Validate.notNull(dictionaryModel, "model must not be null");
        return classify(new InstanceBuilder().setText(text).create(), dictionaryModel);
    }

    /** @return the simple class name followed by the configured scorer and feature setting */
    @Override
    public String toString() {
        return getClass().getSimpleName() + "[scorer=" + this.scorer + ", featureSetting=" + this.featureSetting + "]";
    }
}
