package ws.palladian.classification.featureselection;

import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.Function;
import java.util.function.Predicate;

import org.apache.commons.lang3.Validate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ws.palladian.classification.dt.QuickDtClassifier;
import ws.palladian.classification.dt.QuickDtLearner;
import ws.palladian.classification.utils.CsvDatasetReaderConfig;
import ws.palladian.core.Classifier;
import ws.palladian.core.Instance;
import ws.palladian.core.Learner;
import ws.palladian.core.Model;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.core.dataset.DefaultDataset;
import ws.palladian.helper.ProgressMonitor;
import ws.palladian.helper.ProgressReporter;
import ws.palladian.helper.collection.CollectionHelper;
import ws.palladian.helper.functional.Factories;
import ws.palladian.helper.functional.Factory;
import ws.palladian.helper.functional.Predicates;
import ws.palladian.helper.math.ConfusionMatrix;

/**
 * Wrapper-style feature ranker which greedily selects features by repeatedly evaluating candidate feature sets with
 * the evaluator from its {@link FeatureSelectorConfig}.
 * <p>
 * <b>Forward mode.</b> Each round tries adding every remaining feature (or configured feature group) to the already
 * selected ones. The candidate with the highest score is kept.
 * <p>
 * <b>Backward mode.</b> Each round tries removing every remaining candidate in addition to the already removed ones.
 * The candidate whose removal yields the highest score is removed.
 * <p>
 * The evaluations of one round are executed on a fixed-size thread pool sized by the configured number of threads.
 */
public final class FeatureSelector extends AbstractFeatureRanker {

    private static final Logger LOGGER = LoggerFactory.getLogger(FeatureSelector.class);

    private final FeatureSelectorConfig config;

    /**
     * Scores a {@link ConfusionMatrix} by its accuracy.
     *
     * @deprecated configure the scorer via {@link FeatureSelectorConfig} (e.g. {@code scoreAccuracy()}).
     */
    @Deprecated
    public static final Function<ConfusionMatrix, Double> ACCURACY_SCORER = confusionMatrix -> confusionMatrix
            .getAccuracy();

    /**
     * Scores a {@link ConfusionMatrix} by the F1 measure of one class. An undefined (NaN) F-measure is mapped to 0.
     *
     * @deprecated configure the scorer via {@link FeatureSelectorConfig}.
     */
    @Deprecated
    public static final class FMeasureScorer implements Function<ConfusionMatrix, Double> {
        private final String className;

        /**
         * @param className the class whose F1 measure is used as score; must not be empty.
         * @throws IllegalArgumentException if {@code className} is empty.
         * @throws NullPointerException if {@code className} is {@code null}.
         */
        public FMeasureScorer(String className) {
            Validate.notEmpty(className, "className must not be empty");
            this.className = className;
        }

        @Override
        public Double apply(ConfusionMatrix confusionMatrix) {
            double f = confusionMatrix.getF(1.0, className);
            // an undefined F-measure must not propagate as NaN into the score comparison
            return Double.isNaN(f) ? 0.0 : f;
        }

        @Override
        public String toString() {
            return "FMeasureScorer [class=" + className + "]";
        }
    }

    /**
     * Evaluates one candidate feature filter with the configured evaluator.
     * <p>
     * Each call increments the supplied progress reporter once.
     */
    public final class TestRun implements Callable<TestRunResult> {
        private final Dataset trainData;
        private final Dataset testData;
        /** Filter determining which features are used for training and testing. */
        private final Predicate<? super String> features;
        /** The candidate feature (group) that this run evaluates; reported back in the result. */
        private final Predicate<? super String> evaluatedFeature;
        private final ProgressReporter progress;

        public TestRun(Dataset trainData, Dataset testData, Predicate<? super String> features,
                Predicate<? super String> evaluatedFeature, ProgressReporter progress) {
            this.trainData = trainData;
            this.testData = testData;
            this.features = features;
            this.evaluatedFeature = evaluatedFeature;
            this.progress = progress;
        }

        @Override
        public TestRunResult call() throws Exception {
            LOGGER.trace("Starting evaluation for {}", features);
            Double score = Double.valueOf(FeatureSelector.this.config.evaluator()
                    .score(trainData.filterFeatures(features), testData.filterFeatures(features)));
            LOGGER.debug("Finished evaluation for {}, score {}", features, score);
            progress.increment();
            return new TestRunResult(score, evaluatedFeature);
        }
    }

    /** Score obtained by a {@link TestRun}, together with the candidate it evaluated. */
    public static final class TestRunResult {
        private final Double score;
        private final Predicate<? super String> evaluatedFeature;

        public TestRunResult(Double score, Predicate<? super String> evaluatedFeature) {
            this.score = score;
            this.evaluatedFeature = evaluatedFeature;
        }
    }

    /** @param featureSelectorConfig the configuration (evaluator, direction, feature groups, threads). */
    public FeatureSelector(FeatureSelectorConfig featureSelectorConfig) {
        this.config = featureSelectorConfig;
    }

    /** @deprecated use {@link FeatureSelectorConfig} to create a configuration. */
    @Deprecated
    public <M extends Model> FeatureSelector(Learner<M> learner, Classifier<M> classifier,
            Function<ConfusionMatrix, Double> scorer) {
        this(FeatureSelectorConfig.with(learner, classifier).scorer(scorer).createConfig());
    }

    /** @deprecated use {@link FeatureSelectorConfig} to create a configuration. */
    @Deprecated
    public <M extends Model> FeatureSelector(Learner<M> learner, Classifier<M> classifier) {
        this(FeatureSelectorConfig.with(learner, classifier).scoreAccuracy().createConfig());
    }

    /** @deprecated use {@link FeatureSelectorConfig} to create a configuration. */
    @Deprecated
    public <M extends Model> FeatureSelector(Factory<? extends Learner<M>> learnerFactory,
            Factory<? extends Classifier<M>> classifierFactory, Function<ConfusionMatrix, Double> scorer,
            int numThreads) {
        this(FeatureSelectorConfig.with(learnerFactory, classifierFactory).scorer(scorer).numThreads(numThreads)
                .createConfig());
    }

    /** @deprecated use {@link #rankFeatures(Dataset, Dataset, ProgressReporter)}. */
    @Deprecated
    public FeatureRanking rankFeatures(Iterable<? extends Instance> trainSet,
            Iterable<? extends Instance> validationSet, ProgressReporter progress) {
        return rankFeatures((Dataset) new DefaultDataset(trainSet), (Dataset) new DefaultDataset(validationSet),
                progress);
    }

    /**
     * Ranks all features (or configured feature groups) of {@code trainSet} by greedy forward selection or backward
     * elimination.
     * <p>
     * The resulting ranking maps each candidate's {@code toString()} to a number assigned in selection order:
     * <ul>
     * <li>Forward mode: the first selected candidate gets {@code n-1} and the last gets {@code 0}.</li>
     * <li>Backward mode: the first removed candidate gets {@code 1} and the last gets {@code n}.</li>
     * </ul>
     * Here {@code n} is the number of candidates.
     *
     * @param trainSet the data used for training during each evaluation.
     * @param validationSet the data used for scoring during each evaluation.
     * @param progress receives the total number of evaluations and one increment per evaluation.
     * @return the feature ranking.
     * @throws IllegalStateException wrapping any failure during evaluation (including interruption, in which case
     *             the thread's interrupt flag is restored).
     */
    @Override
    public FeatureRanking rankFeatures(Dataset trainSet, Dataset validationSet, ProgressReporter progress) {
        Set<Predicate<? super String>> candidates = constructFeatureFilters(trainSet);
        int numCandidates = candidates.size();
        int iterations = numCandidates * (numCandidates + 1) / 2;
        boolean backward = config.isBackward();
        // backward elimination performs one additional baseline evaluation with all features, which also
        // increments the progress reporter
        progress.startTask("Feature selection", backward ? iterations + 1 : iterations);
        LOGGER.info("# of features or feature sets: {}", numCandidates);
        LOGGER.info("# of iterations: {}", iterations);
        try {
            if (backward) {
                TestRunResult baseline = new TestRun(trainSet, validationSet, Predicates.ALL, Predicates.NONE,
                        progress).call();
                LOGGER.info("Score with all features {}", baseline.score);
            }
            ExecutorService executor = Executors.newFixedThreadPool(config.numThreads());
            try {
                return new FeatureRanking(selectFeatures(executor, trainSet, validationSet, candidates, progress));
            } finally {
                // release the worker threads on every path, not only after a successful run
                executor.shutdownNow();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
    }

    /** Runs the greedy selection rounds and returns the rank number assigned to each candidate's string form. */
    private Map<String, Integer> selectFeatures(ExecutorService executor, Dataset trainSet, Dataset validationSet,
            Set<Predicate<? super String>> candidates, ProgressReporter progress)
            throws InterruptedException, ExecutionException {
        boolean backward = config.isBackward();
        Map<String, Integer> ranks = new HashMap<>();
        // forward: features selected so far; backward: features eliminated so far
        List<Predicate<? super String>> selected = new ArrayList<>();
        Set<Predicate<? super String>> remaining = new HashSet<>(candidates);
        int rank = backward ? 0 : candidates.size();
        while (!remaining.isEmpty()) {
            List<TestRun> runs = new ArrayList<>(remaining.size());
            for (Predicate<? super String> candidate : remaining) {
                List<Predicate<? super String>> combined = new ArrayList<>(selected);
                combined.add(candidate);
                Predicate<? super String> filter = Predicates.or(combined);
                if (backward) {
                    // keep everything except the already eliminated features and the candidate
                    filter = Predicates.not(filter);
                }
                runs.add(new TestRun(trainSet, validationSet, filter, candidate, progress));
            }
            Predicate<? super String> best = null;
            double bestScore = 0;
            for (Future<TestRunResult> future : executor.invokeAll(runs)) {
                TestRunResult result = future.get();
                double score = result.score;
                // ">=" lets later candidates win ties; a NaN best score must not block real scores, since every
                // comparison against NaN is false
                if (best == null || score >= bestScore || Double.isNaN(bestScore)) {
                    bestScore = score;
                    best = result.evaluatedFeature;
                }
            }
            LOGGER.info("Selected {}, score {}", best, bestScore);
            selected.add(best);
            remaining.remove(best);
            rank += backward ? 1 : -1;
            ranks.put(best.toString(), rank);
        }
        return ranks;
    }

    /**
     * Builds the candidate set: every configured feature group, plus one equality filter for each feature name of
     * {@code dataset} that is not matched by any group.
     */
    private Set<Predicate<? super String>> constructFeatureFilters(Dataset dataset) {
        Set<Predicate<? super String>> filters = new HashSet<>(config.featureGroups());
        Predicate<? super String> ungrouped = Predicates.not(Predicates.or(config.featureGroups()));
        for (String featureName : CollectionHelper.filter(dataset.getFeatureInformation().getFeatureNames(),
                ungrouped)) {
            filters.add(Predicates.equal(featureName));
        }
        return filters;
    }

    /**
     * Example usage. Optionally takes the training CSV path and the validation CSV path as arguments; the default
     * paths are placeholders that must be adapted.
     */
    public static void main(String[] args) {
        String trainPath = args.length > 0 ? args[0] : "/path/to/training.csv";
        String validationPath = args.length > 1 ? args[1] : "/path/to/validation.csv";
        Factory<QuickDtLearner> learnerFactory = new Factory<QuickDtLearner>() {
            @Override
            public QuickDtLearner create() {
                return QuickDtLearner.randomForest(10);
            }
        };
        FeatureSelector selector = new FeatureSelector(learnerFactory, Factories.constant(new QuickDtClassifier()),
                new FMeasureScorer("true"), 1);
        Dataset trainSet = (Dataset) CsvDatasetReaderConfig.filePath(new File(trainPath)).m62create();
        Dataset validationSet = (Dataset) CsvDatasetReaderConfig.filePath(new File(validationPath)).m62create();
        ProgressReporter progress = new ProgressMonitor();
        CollectionHelper.print(selector.rankFeatures(trainSet, validationSet, progress).getAll());
    }
}
