package ws.palladian.classification.featureselection;

import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;

import org.apache.commons.lang.Validate;

import ws.palladian.classification.evaluation.ClassificationEvaluator;
import ws.palladian.classification.evaluation.ConfusionMatrixEvaluator;
import ws.palladian.classification.evaluation.roc.RocCurves;
import ws.palladian.classification.featureselection.FeatureSelector;
import ws.palladian.core.Classifier;
import ws.palladian.core.Learner;
import ws.palladian.core.Model;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.helper.functional.Factories;
import ws.palladian.helper.functional.Factory;
import ws.palladian.helper.functional.Filter;
import ws.palladian.helper.functional.Function;
import ws.palladian.helper.math.ConfusionMatrix;

/* loaded from: input_file:ws/palladian/classification/featureselection/FeatureSelectorConfig.class */
/**
 * Configuration for a {@link FeatureSelector}.
 * <p>
 * The configuration bundles four settings:
 * <ul>
 * <li>the evaluation used to score a run ({@link #evaluator()});</li>
 * <li>the number of threads;</li>
 * <li>the feature groups;</li>
 * <li>whether selection runs backward or forward.</li>
 * </ul>
 * All fields are final, and the feature groups are an unmodifiable snapshot taken when the configuration is
 * created. Obtain instances via {@link #with(Learner, Classifier)} or {@link #with(Factory, Factory)} and the
 * returned {@link Builder}.
 */
public class FeatureSelectorConfig {
    private final EvaluationConfig<?, ?> evaluator;
    private final int numThreads;
    private final Collection<? extends Filter<? super String>> featureGroups;
    private final boolean backward;

    /**
     * Fluent builder for {@link FeatureSelectorConfig} and {@link FeatureSelector}.
     * <p>
     * The defaults are:
     * <ul>
     * <li>one thread;</li>
     * <li>no feature groups;</li>
     * <li>backward direction;</li>
     * <li>scoring by accuracy on a confusion matrix.</li>
     * </ul>
     *
     * @param <M> the model type shared by the learner and classifier.
     */
    public static final class Builder<M extends Model> implements Factory<FeatureSelector> {
        private final Factory<? extends Learner<M>> learnerFactory;
        private final Factory<? extends Classifier<M>> classifierFactory;
        private EvaluationConfig<M, ?> evaluator;
        private int numThreads;
        private Collection<Filter<? super String>> featureGroups;
        private boolean backward;

        private Builder(Learner<M> learner, Classifier<M> classifier) {
            this(Factories.constant(learner), Factories.constant(classifier));
        }

        private Builder(Factory<? extends Learner<M>> learnerFactory,
                Factory<? extends Classifier<M>> classifierFactory) {
            this.numThreads = 1;
            this.featureGroups = new HashSet<Filter<? super String>>();
            this.backward = true;
            this.learnerFactory = learnerFactory;
            this.classifierFactory = classifierFactory;
            // Default scoring: accuracy derived from a confusion matrix.
            this.evaluator = new EvaluationConfig<>(learnerFactory, classifierFactory,
                    new ConfusionMatrixEvaluator(), FeatureSelector.ACCURACY_SCORER);
        }

        /**
         * Scores runs by applying {@code function} to a {@link ConfusionMatrix}.
         *
         * @param function maps a confusion matrix to a score; must not be {@code null}.
         * @return this builder.
         * @deprecated use {@link #evaluator(ClassificationEvaluator, Function)} with a
         *             {@link ConfusionMatrixEvaluator}, or {@link #scoreAccuracy()} / {@link #scoreF1(String)}.
         */
        @Deprecated
        public Builder<M> scorer(Function<ConfusionMatrix, Double> function) {
            Validate.notNull(function, "scorer must not be null");
            return evaluator(new ConfusionMatrixEvaluator(), function);
        }

        /**
         * Scores runs by {@link FeatureSelector#ACCURACY_SCORER} on a confusion matrix. This is the default.
         *
         * @return this builder.
         */
        public Builder<M> scoreAccuracy() {
            return evaluator(new ConfusionMatrixEvaluator(), FeatureSelector.ACCURACY_SCORER);
        }

        /**
         * Scores runs by {@link FeatureSelector.FMeasureScorer} for the given class on a confusion matrix.
         *
         * @param str the class name handed to {@code FMeasureScorer}.
         * @return this builder.
         */
        public Builder<M> scoreF1(String str) {
            return evaluator(new ConfusionMatrixEvaluator(), new FeatureSelector.FMeasureScorer(str));
        }

        /**
         * Scores runs by the area under the ROC curve for the given class.
         *
         * @param str the class name; must not be {@code null}.
         * @return this builder.
         */
        public Builder<M> scoreAuc(String str) {
            Validate.notNull(str, "className must not be null");
            return evaluator(new RocCurves.RocCurvesEvaluator(str), new Function<RocCurves, Double>() {
                @Override
                public Double compute(RocCurves rocCurves) {
                    return rocCurves.getAreaUnderCurve();
                }
            });
        }

        /**
         * Sets an arbitrary evaluation.
         * <p>
         * A run is evaluated with {@code classificationEvaluator}. Its result is then converted to a score with
         * {@code function}.
         *
         * @param classificationEvaluator produces an evaluation result; must not be {@code null}.
         * @param function maps the evaluation result to a score; must not be {@code null}.
         * @param <R> the evaluation result type.
         * @return this builder.
         */
        public <R> Builder<M> evaluator(ClassificationEvaluator<R> classificationEvaluator,
                Function<R, Double> function) {
            Validate.notNull(classificationEvaluator, "evaluator must not be null");
            Validate.notNull(function, "mapper must not be null");
            this.evaluator = new EvaluationConfig<M, R>(this.learnerFactory, this.classifierFactory,
                    classificationEvaluator, function);
            return this;
        }

        /**
         * Sets the number of threads.
         *
         * @param i the thread count; must be greater than zero.
         * @return this builder.
         */
        public Builder<M> numThreads(int i) {
            Validate.isTrue(i > 0, "numThreads must be greater zero");
            this.numThreads = i;
            return this;
        }

        /**
         * Replaces all feature groups with a copy of the given collection.
         *
         * @param collection the feature groups; must not be {@code null}.
         * @return this builder.
         */
        public Builder<M> featureGroups(Collection<? extends Filter<? super String>> collection) {
            Validate.notNull(collection, "featureGroups must not be null");
            this.featureGroups = new HashSet<Filter<? super String>>(collection);
            return this;
        }

        /**
         * Adds a single feature group.
         *
         * @param filter the feature group; must not be {@code null}.
         * @return this builder.
         */
        public Builder<M> addFeatureGroup(Filter<? super String> filter) {
            Validate.notNull(filter, "featureGroup must not be null");
            this.featureGroups.add(filter);
            return this;
        }

        /**
         * Selects the forward direction by clearing the backward flag.
         *
         * @return this builder.
         */
        public Builder<M> forward() {
            this.backward = false;
            return this;
        }

        /**
         * Selects the backward direction by setting the backward flag. This is the default.
         *
         * @return this builder.
         */
        public Builder<M> backward() {
            this.backward = true;
            return this;
        }

        /**
         * Builds a {@link FeatureSelector} from the current builder state.
         *
         * @return a new feature selector.
         * @throws IllegalArgumentException if no learner or classifier factory was specified.
         */
        @Override
        public FeatureSelector create() {
            return new FeatureSelector(createConfig());
        }

        /**
         * Alias of {@link #create()}.
         *
         * @return a new feature selector.
         * @deprecated a decompiler-generated name kept only for compatibility; use {@link #create()}.
         */
        @Deprecated
        public FeatureSelector m15create() {
            return create();
        }

        /**
         * Builds a {@link FeatureSelectorConfig} from the current builder state.
         * <p>
         * Later changes to this builder do not affect the returned configuration.
         *
         * @return a new configuration.
         * @throws IllegalArgumentException if no learner or classifier factory was specified.
         */
        public FeatureSelectorConfig createConfig() {
            if (this.learnerFactory == null) {
                throw new IllegalArgumentException("no learner specified");
            }
            if (this.classifierFactory == null) {
                throw new IllegalArgumentException("no classifier specified");
            }
            return new FeatureSelectorConfig(this);
        }
    }

    /**
     * Scores a train/evaluate run.
     * <p>
     * Each call to {@link #score(Dataset, Dataset)} does the following:
     * <ol>
     * <li>obtains a learner and a classifier from their factories;</li>
     * <li>runs the {@link ClassificationEvaluator} on them;</li>
     * <li>maps the evaluator's result to a {@code double}.</li>
     * </ol>
     *
     * @param <M> the model type.
     * @param <R> the evaluator's result type.
     */
    public static final class EvaluationConfig<M extends Model, R> {
        private final Factory<? extends Learner<M>> learnerFactory;
        private final Factory<? extends Classifier<M>> classifierFactory;
        private final ClassificationEvaluator<R> evaluator;
        private final Function<R, Double> mapper;

        private EvaluationConfig(Factory<? extends Learner<M>> learnerFactory,
                Factory<? extends Classifier<M>> classifierFactory, ClassificationEvaluator<R> evaluator,
                Function<R, Double> mapper) {
            this.learnerFactory = learnerFactory;
            this.classifierFactory = classifierFactory;
            this.evaluator = evaluator;
            this.mapper = mapper;
        }

        /**
         * Evaluates a freshly created learner and classifier and returns the mapped score.
         * <p>
         * Both datasets are passed to {@code ClassificationEvaluator.evaluate} in this order. Presumably the first
         * is training data and the second is evaluation data — confirm against {@code ClassificationEvaluator}.
         *
         * @param trainData the first dataset handed to the evaluator.
         * @param testData the second dataset handed to the evaluator.
         * @return the score produced by the mapper.
         */
        public double score(Dataset trainData, Dataset testData) {
            // A new learner/classifier per call, so concurrent scoring does not share instances created here.
            Learner<M> learner = learnerFactory.create();
            Classifier<M> classifier = classifierFactory.create();
            R result = evaluator.evaluate(learner, classifier, trainData, testData);
            return mapper.compute(result);
        }
    }

    /**
     * Starts a builder that always uses the given learner and classifier instances.
     *
     * @param learner the learner; must not be {@code null}.
     * @param classifier the classifier; must not be {@code null}.
     * @param <M> the model type.
     * @return a new builder.
     */
    public static <M extends Model> Builder<M> with(Learner<M> learner, Classifier<M> classifier) {
        // The instances are wrapped in constant factories. Check for null here, because the factory-null checks in
        // createConfig() cannot catch a null instance.
        Validate.notNull(learner, "learner must not be null");
        Validate.notNull(classifier, "classifier must not be null");
        return new Builder<>(learner, classifier);
    }

    /**
     * Starts a builder that obtains learners and classifiers from the given factories.
     *
     * @param factory the learner factory.
     * @param factory2 the classifier factory.
     * @param <M> the model type.
     * @return a new builder.
     */
    public static <M extends Model> Builder<M> with(Factory<? extends Learner<M>> factory,
            Factory<? extends Classifier<M>> factory2) {
        return new Builder<>(factory, factory2);
    }

    protected FeatureSelectorConfig(Builder<?> builder) {
        this.evaluator = builder.evaluator;
        this.numThreads = builder.numThreads;
        // Snapshot, so that later builder mutations (e.g. addFeatureGroup) do not leak into this config.
        this.featureGroups = Collections.unmodifiableCollection(
                new HashSet<Filter<? super String>>(builder.featureGroups));
        this.backward = builder.backward;
    }

    /** @return the evaluation used to score runs. */
    public EvaluationConfig<?, ?> evaluator() {
        return this.evaluator;
    }

    /** @return the configured number of threads (at least one). */
    public int numThreads() {
        return this.numThreads;
    }

    /**
     * @return an unmodifiable snapshot of the configured feature groups. It may be empty; how
     *         {@link FeatureSelector} treats an empty collection is defined there.
     */
    public Collection<? extends Filter<? super String>> featureGroups() {
        return this.featureGroups;
    }

    /** @return {@code true} for backward direction, {@code false} for forward. */
    public boolean isBackward() {
        return this.backward;
    }
}
