package ws.palladian.classification.quickml;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

import org.apache.commons.lang3.Validate;
import quickml.data.instances.ClassifierInstance;
import quickml.supervised.PredictiveModelBuilder;
import quickml.supervised.classifier.Classifier;
import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForestBuilder;
import quickml.supervised.tree.decisionTree.DecisionTreeBuilder;
import ws.palladian.classification.quickml.FlyweightAttributesMap;
import ws.palladian.core.AbstractLearner;
import ws.palladian.core.Instance;
import ws.palladian.core.dataset.Dataset;

/* loaded from: input_file:ws/palladian/classification/quickml/QuickMlLearner.class */
public final class QuickMlLearner extends AbstractLearner<QuickMlModel> {
    private final PredictiveModelBuilder<? extends Classifier, ClassifierInstance> builder;

    public static QuickMlLearner randomForest() {
        return randomForest(10);
    }

    public static QuickMlLearner randomForest(int i) {
        Validate.isTrue(i > 0, "numTrees must be greater zero", new Object[0]);
        return new QuickMlLearner(new RandomDecisionForestBuilder(new DecisionTreeBuilder().ignoreAttributeProbability(0.7d)).numTrees(i));
    }

    public static QuickMlLearner tree() {
        return new QuickMlLearner(new DecisionTreeBuilder());
    }

    public QuickMlLearner(PredictiveModelBuilder<? extends Classifier, ClassifierInstance> predictiveModelBuilder) {
        Validate.notNull(predictiveModelBuilder, "builder must not be null", new Object[0]);
        this.builder = predictiveModelBuilder;
    }

    /* renamed from: train, reason: merged with bridge method [inline-methods] */
    public QuickMlModel m9train(Dataset dataset) {
        Validate.notNull(dataset, "instances must not be null", new Object[0]);
        FlyweightAttributesMap.Builder builder = new FlyweightAttributesMap.Builder(dataset.getFeatureInformation().getFeatureNames());
        ArrayList arrayList = new ArrayList();
        HashSet hashSet = new HashSet();
        Iterator it = dataset.iterator();
        while (it.hasNext()) {
            Instance instance = (Instance) it.next();
            arrayList.add(new ClassifierInstance(builder.create(instance.getVector()), instance.getCategory()));
            hashSet.add(instance.getCategory());
        }
        return new QuickMlModel(this.builder.buildPredictiveModel(arrayList), hashSet);
    }

    public String toString() {
        return getClass().getSimpleName() + " (" + this.builder.getClass().getSimpleName() + ")";
    }
}
