package ws.palladian.classification.universal;

import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.Set;

import org.apache.commons.lang3.Validate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ws.palladian.classification.nb.NaiveBayesClassifier;
import ws.palladian.classification.nb.NaiveBayesLearner;
import ws.palladian.classification.nb.NaiveBayesModel;
import ws.palladian.classification.numeric.KnnClassifier;
import ws.palladian.classification.numeric.KnnLearner;
import ws.palladian.classification.numeric.KnnModel;
import ws.palladian.classification.text.DictionaryModel;
import ws.palladian.classification.text.FeatureSetting;
import ws.palladian.classification.text.FeatureSettingBuilder;
import ws.palladian.classification.text.PalladianTextClassifier;
import ws.palladian.classification.utils.NoNormalizer;
import ws.palladian.core.AbstractLearner;
import ws.palladian.core.CategoryEntries;
import ws.palladian.core.CategoryEntriesBuilder;
import ws.palladian.core.Classifier;
import ws.palladian.core.FeatureVector;
import ws.palladian.core.dataset.Dataset;

/* loaded from: input_file:ws/palladian/classification/universal/UniversalClassifier.class */
/**
 * Learner and classifier that combines up to three sub-classifiers and merges their results.
 * <p>
 * The sub-classifiers are a {@link PalladianTextClassifier}, a {@link KnnClassifier} and a
 * {@link NaiveBayesClassifier}. Which of them are trained is chosen by the {@link ClassifierSetting}s
 * passed at construction time. At classification time, every sub-model present in the
 * {@link UniversalClassifierModel} is applied, and its result is added to one {@link CategoryEntriesBuilder}.
 */
public class UniversalClassifier extends AbstractLearner<UniversalClassifierModel> implements Classifier<UniversalClassifierModel> {
    private static final Logger LOGGER = LoggerFactory.getLogger(UniversalClassifier.class);

    /** Text classifier; it both trains and applies the {@link DictionaryModel} (TEXT component). */
    private final PalladianTextClassifier textClassifier;

    /** k-NN classifier applied to the {@link KnnModel} (KNN component); constructed with k = 3. */
    private final KnnClassifier numericClassifier;

    /** Naive Bayes classifier applied to the {@link NaiveBayesModel} (BAYES component). */
    private final NaiveBayesClassifier nominalClassifier;

    /** The enabled components; only these are trained by {@link #train(Dataset)}. */
    private final Set<ClassifierSetting> settings;

    /** The sub-classifiers that can be enabled in a {@link UniversalClassifier}. */
    public enum ClassifierSetting {
        /** k-nearest-neighbor classifier ({@link KnnClassifier}). */
        KNN,
        /** Text classifier ({@link PalladianTextClassifier}). */
        TEXT,
        /** Naive Bayes classifier ({@link NaiveBayesClassifier}). */
        BAYES
    }

    /**
     * Creates a classifier with all {@link ClassifierSetting}s enabled.
     * <p>
     * The text feature setting is {@code FeatureSettingBuilder.chars(3, 7)}, which presumably means
     * character n-grams of length 3 to 7.
     */
    public UniversalClassifier() {
        // NOTE(review): "m44create" looks like a decompiler-generated name for a builder's create
        // method; verify it against the FeatureSettingBuilder API.
        this(FeatureSettingBuilder.chars(3, 7).m44create(), ClassifierSetting.values());
    }

    /**
     * Creates a classifier with the given text feature setting and the given enabled components.
     *
     * @param featureSetting the feature setting for the text classifier; must not be {@code null}
     * @param classifierSettingArr the components to enable; must not be {@code null} and must not contain
     *            {@code null}. Duplicates are ignored. An empty array enables nothing.
     * @throws NullPointerException if {@code featureSetting} or {@code classifierSettingArr} is {@code null}
     * @throws IllegalArgumentException if {@code classifierSettingArr} contains a {@code null} element
     */
    public UniversalClassifier(FeatureSetting featureSetting, ClassifierSetting... classifierSettingArr) {
        Validate.notNull(featureSetting, "featureSetting must not be null");
        Validate.notNull(classifierSettingArr, "settings must not be null");
        Validate.noNullElements(classifierSettingArr, "settings must not contain null elements; null at index %d");
        this.textClassifier = new PalladianTextClassifier(featureSetting);
        this.numericClassifier = new KnnClassifier(3);
        // NOTE(review): the meaning of (1.0E-5, false) is defined by NaiveBayesClassifier; confirm there
        // (presumably a smoothing/minimum-probability value and a flag).
        this.nominalClassifier = new NaiveBayesClassifier(1.0E-5d, false);
        // EnumSet: type-safe, compact set for enum values (replaces a raw HashSet).
        EnumSet<ClassifierSetting> enabled = EnumSet.noneOf(ClassifierSetting.class);
        enabled.addAll(Arrays.asList(classifierSettingArr));
        this.settings = enabled;
    }

    /**
     * Trains one sub-model per enabled {@link ClassifierSetting} on the given dataset.
     *
     * @param dataset the training data
     * @return a model holding the trained sub-models; components that are not enabled are {@code null}
     */
    @Override // ws.palladian.core.Learner
    public UniversalClassifierModel train(Dataset dataset) {
        NaiveBayesModel naiveBayesModel = null;
        KnnModel knnModel = null;
        DictionaryModel dictionaryModel = null;
        if (this.settings.contains(ClassifierSetting.TEXT)) {
            LOGGER.debug("training text classifier");
            dictionaryModel = this.textClassifier.train(dataset);
        }
        if (this.settings.contains(ClassifierSetting.KNN)) {
            LOGGER.debug("training knn classifier");
            // No normalization is applied to the numeric features before k-NN training.
            knnModel = new KnnLearner(new NoNormalizer()).train(dataset);
        }
        if (this.settings.contains(ClassifierSetting.BAYES)) {
            LOGGER.debug("training bayes classifier");
            naiveBayesModel = new NaiveBayesLearner().train(dataset);
        }
        return new UniversalClassifierModel(naiveBayesModel, knnModel, dictionaryModel);
    }

    /**
     * Classifies a feature vector with every sub-model present in the given model.
     * <p>
     * Each non-{@code null} sub-model is applied by its matching classifier. Its result is passed to
     * {@link CategoryEntriesBuilder#add}. How the results are combined (for example summed or normalized)
     * is decided by that builder; NOTE(review): confirm there.
     *
     * @param featureVector the instance to classify
     * @param universalClassifierModel a model produced by {@link #train(Dataset)}; must not be {@code null}
     * @return the merged category entries; empty if the model contains no sub-models
     */
    @Override // ws.palladian.core.Classifier
    public CategoryEntries classify(FeatureVector featureVector, UniversalClassifierModel universalClassifierModel) {
        CategoryEntriesBuilder categoryEntriesBuilder = new CategoryEntriesBuilder();
        if (universalClassifierModel.getDictionaryModel() != null) {
            categoryEntriesBuilder.add(this.textClassifier.classify(featureVector, universalClassifierModel.getDictionaryModel()));
        }
        if (universalClassifierModel.getKnnModel() != null) {
            categoryEntriesBuilder.add(this.numericClassifier.classify(featureVector, universalClassifierModel.getKnnModel()));
        }
        if (universalClassifierModel.getBayesModel() != null) {
            categoryEntriesBuilder.add(this.nominalClassifier.classify(featureVector, universalClassifierModel.getBayesModel()));
        }
        // NOTE(review): "m76create" looks like a decompiler-generated name; verify against CategoryEntriesBuilder.
        return categoryEntriesBuilder.m76create();
    }
}
