package ws.palladian.classification.dt;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.commons.lang3.Validate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickdt.HashMapAttributes;
import quickdt.PredictiveModel;
import quickdt.PredictiveModelBuilder;
import quickdt.TreeBuilder;
import quickdt.randomForest.RandomForestBuilder;
import ws.palladian.core.AbstractLearner;
import ws.palladian.core.FeatureVector;
import ws.palladian.core.Instance;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.core.value.NominalValue;
import ws.palladian.core.value.NumericValue;
import ws.palladian.core.value.Value;
import ws.palladian.helper.collection.Vector;

/* loaded from: input_file:ws/palladian/classification/dt/QuickDtLearner.class */
public final class QuickDtLearner extends AbstractLearner<QuickDtModel> {
    private static final Logger LOGGER = LoggerFactory.getLogger(QuickDtLearner.class);
    private final PredictiveModelBuilder<? extends PredictiveModel> builder;

    public static QuickDtLearner randomForest() {
        return randomForest(10);
    }

    public static QuickDtLearner randomForest(int i) {
        Validate.isTrue(i > 0, "numTrees must be greater zero", new Object[0]);
        return new QuickDtLearner(new RandomForestBuilder(new TreeBuilder().ignoreAttributeAtNodeProbability(0.7d)).numTrees(i));
    }

    public static QuickDtLearner tree() {
        return new QuickDtLearner(new TreeBuilder());
    }

    @Deprecated
    public QuickDtLearner(PredictiveModelBuilder<? extends PredictiveModel> predictiveModelBuilder) {
        Validate.notNull(predictiveModelBuilder, "builder must not be null", new Object[0]);
        this.builder = predictiveModelBuilder;
    }

    @Override // ws.palladian.core.Learner
    public QuickDtModel train(Dataset dataset) {
        HashSet hashSet = new HashSet();
        HashSet hashSet2 = new HashSet();
        for (Instance instance : dataset) {
            hashSet.add(HashMapAttributes.create(getInput(instance.getVector())).classification(instance.getCategory()));
            hashSet2.add(instance.getCategory());
        }
        return new QuickDtModel(this.builder.buildPredictiveModel(hashSet), hashSet2);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public static Serializable[] getInput(FeatureVector featureVector) {
        ArrayList arrayList = new ArrayList();
        Iterator it = featureVector.iterator();
        while (it.hasNext()) {
            Vector.VectorEntry vectorEntry = (Vector.VectorEntry) it.next();
            Value value = (Value) vectorEntry.value();
            if (value instanceof NominalValue) {
                arrayList.add(vectorEntry.key());
                arrayList.add(((NominalValue) value).getString());
            } else if (value instanceof NumericValue) {
                arrayList.add(vectorEntry.key());
                arrayList.add(Double.valueOf(((NumericValue) value).getDouble()));
            } else {
                LOGGER.trace("Unsupported type for {}: {}", vectorEntry.key(), value.getClass().getName());
            }
        }
        return (Serializable[]) arrayList.toArray(new Serializable[arrayList.size()]);
    }

    public String toString() {
        return "QuickDtLearner (" + this.builder + ")";
    }
}
