package ws.palladian.classification.featureselection;

import java.io.File;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

import org.apache.commons.lang3.Validate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ws.palladian.classification.evaluation.ClassificationEvaluator;
import ws.palladian.classification.evaluation.ConfusionMatrixEvaluator;
import ws.palladian.classification.nb.NaiveBayesClassifier;
import ws.palladian.classification.nb.NaiveBayesLearner;
import ws.palladian.classification.utils.CsvDatasetReaderConfig;
import ws.palladian.core.Classifier;
import ws.palladian.core.Instance;
import ws.palladian.core.Learner;
import ws.palladian.core.Model;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.core.dataset.DefaultDataset;
import ws.palladian.helper.ProgressReporter;
import ws.palladian.helper.collection.CollectionHelper;
import ws.palladian.helper.functional.Filter;
import ws.palladian.helper.functional.Filters;
import ws.palladian.helper.functional.Function;
import ws.palladian.helper.math.ConfusionMatrix;

/**
 * Feature ranker which trains and evaluates a classifier separately for every single feature.
 * <p>
 * For each feature name of the training dataset, both the training and the validation dataset are filtered with
 * {@code Filters.equal(featureName)} (presumably reducing them to that one feature), a model is trained on the
 * filtered training data, evaluated on the filtered validation data, and the evaluation result is converted to a
 * {@code Double} score by the supplied mapper. All scores are collected into a {@link FeatureRanking}.
 */
public final class SingleFeatureClassification extends AbstractFeatureRanker {

    private static final Logger LOGGER = LoggerFactory.getLogger(SingleFeatureClassification.class);

    /** Training CSV used by {@link #main(String[])} when no paths are passed on the command line. */
    private static final String DEFAULT_TRAIN_PATH = "/Users/pk/Dropbox/LocationExtraction/BFE/fd_merged_train.csv";

    /** Validation CSV used by {@link #main(String[])} when no paths are passed on the command line. */
    private static final String DEFAULT_VALIDATION_PATH = "/Users/pk/Dropbox/LocationExtraction/BFE/fd_merged_validation.csv";

    private final EvaluatorAndMapper<?, ?> evaluatorAndMapper;

    /**
     * Binds learner, classifier, evaluator and mapper together so that their type parameters ({@code R} for the
     * evaluation result, {@code M} for the model) are checked once at construction and need no casts later.
     */
    private static final class EvaluatorAndMapper<R, M extends Model> {
        private final Learner<M> learner;
        private final Classifier<M> classifier;
        private final ClassificationEvaluator<R> evaluator;
        private final Function<R, Double> mapper;

        EvaluatorAndMapper(Learner<M> learner, Classifier<M> classifier, ClassificationEvaluator<R> evaluator,
                Function<R, Double> mapper) {
            this.learner = learner;
            this.classifier = classifier;
            this.evaluator = evaluator;
            this.mapper = mapper;
        }

        /**
         * Trains a model on {@code trainData}, evaluates it on {@code testData} and maps the result to a score.
         *
         * @param trainData the data to train the model on.
         * @param testData the data to evaluate the trained model on.
         * @return the score produced by the mapper for the evaluation result.
         */
        Double evaluate(Dataset trainData, Dataset testData) {
            M model = learner.train(trainData);
            R result = evaluator.evaluate(classifier, model, testData);
            return mapper.compute(result);
        }
    }

    /**
     * Create a new single feature classification ranker.
     *
     * @param learner the learner used to train a model per feature, not {@code null}.
     * @param classifier the classifier used to apply the trained models, not {@code null}.
     * @param evaluator the evaluator producing a result of type {@code R} for each feature, not {@code null}.
     * @param mapper converts an evaluation result to the numeric score of the feature, not {@code null}.
     * @throws NullPointerException if any argument is {@code null} (raised by {@code Validate.notNull}).
     */
    public <R, M extends Model> SingleFeatureClassification(Learner<M> learner, Classifier<M> classifier,
            ClassificationEvaluator<R> evaluator, Function<R, Double> mapper) {
        Validate.notNull(learner, "learner must not be null");
        Validate.notNull(classifier, "classifier must not be null");
        Validate.notNull(evaluator, "evaluator must not be null");
        Validate.notNull(mapper, "mapper must not be null");
        this.evaluatorAndMapper = new EvaluatorAndMapper<>(learner, classifier, evaluator, mapper);
    }

    /**
     * Create a new ranker which evaluates with a {@link ConfusionMatrixEvaluator}.
     *
     * @param learner the learner, not {@code null}.
     * @param classifier the classifier, not {@code null}.
     * @param mapper converts the confusion matrix of each feature to its score, not {@code null}.
     * @deprecated Use
     *             {@link #SingleFeatureClassification(Learner, Classifier, ClassificationEvaluator, Function)}
     *             with {@code new ConfusionMatrixEvaluator()} instead.
     */
    @Deprecated
    public <M extends Model> SingleFeatureClassification(Learner<M> learner, Classifier<M> classifier,
            Function<ConfusionMatrix, Double> mapper) {
        this(learner, classifier, new ConfusionMatrixEvaluator(), mapper);
    }

    /**
     * Rank features using instances instead of datasets.
     *
     * @param trainSet the training instances.
     * @param validationSet the validation instances.
     * @return the feature ranking.
     * @deprecated Wrap the instances in a {@link DefaultDataset} and use {@code rankFeatures(Dataset, Dataset)}.
     */
    @Deprecated
    public FeatureRanking rankFeatures(Iterable<? extends Instance> trainSet, Iterable<? extends Instance> validationSet) {
        return rankFeatures((Dataset) new DefaultDataset(trainSet), (Dataset) new DefaultDataset(validationSet));
    }

    /**
     * Score every feature of {@code trainSet} by training and evaluating on that feature alone.
     *
     * @param trainSet the training data; its feature names determine which features are ranked.
     * @param validationSet the data each single-feature model is evaluated on.
     * @param progressReporter receives one increment per evaluated feature.
     * @return a ranking built from the per-feature scores.
     */
    @Override
    public FeatureRanking rankFeatures(Dataset trainSet, Dataset validationSet, ProgressReporter progressReporter) {
        Map<String, Double> scores = new HashMap<>();
        Set<String> featureNames = trainSet.getFeatureInformation().getFeatureNames();
        progressReporter.startTask("Single feature classification", featureNames.size());
        for (String featureName : featureNames) {
            // the same filter is applied to both datasets so train and validation see an identical feature space
            Filter<? super String> onlyThisFeature = Filters.equal(featureName);
            Double score = evaluatorAndMapper.evaluate(trainSet.filterFeatures(onlyThisFeature),
                    validationSet.filterFeatures(onlyThisFeature));
            LOGGER.info("Finished testing with {}: {}", featureName, score);
            progressReporter.increment();
            scores.put(featureName, score);
        }
        return new FeatureRanking(scores);
    }

    /**
     * Demo: ranks the features of a CSV dataset with Naive Bayes, scoring by the F1 measure of class {@code "true"}.
     *
     * @param args optional {@code <trainCsv> <validationCsv>}; when fewer than two arguments are given, the
     *            built-in default paths are used.
     */
    public static void main(String[] args) {
        String trainPath = args.length >= 2 ? args[0] : DEFAULT_TRAIN_PATH;
        String validationPath = args.length >= 2 ? args[1] : DEFAULT_VALIDATION_PATH;

        Function<ConfusionMatrix, Double> f1OfTrue = new Function<ConfusionMatrix, Double>() {
            @Override
            public Double compute(ConfusionMatrix confusionMatrix) {
                double f = confusionMatrix.getF(1.0d, "true");
                // a NaN F-measure (e.g. no predictions for the class) is treated as the worst score
                return Double.valueOf(Double.isNaN(f) ? 0.0d : f);
            }
        };
        SingleFeatureClassification ranker = new SingleFeatureClassification(new NaiveBayesLearner(),
                new NaiveBayesClassifier(), new ConfusionMatrixEvaluator(), f1OfTrue);

        Dataset trainData = (Dataset) CsvDatasetReaderConfig.filePath(new File(trainPath)).create();
        Dataset validationData = (Dataset) CsvDatasetReaderConfig.filePath(new File(validationPath)).create();
        CollectionHelper.print(ranker.rankFeatures(trainData, validationData).getAll());
    }
}
