package ws.palladian.classification.xgboost;

import gnu.trove.list.array.TFloatArrayList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.list.array.TLongArrayList;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.IEvaluation;
import ml.dmlc.xgboost4j.java.IObjective;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.apache.commons.lang3.Validate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ws.palladian.core.AbstractLearner;
import ws.palladian.core.FeatureVector;
import ws.palladian.core.Instance;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.core.dataset.statistics.DatasetStatistics;
import ws.palladian.core.value.NullValue;
import ws.palladian.core.value.NumericValue;
import ws.palladian.core.value.Value;
import ws.palladian.helper.collection.CollectionHelper;
import ws.palladian.helper.collection.Vector;
import ws.palladian.helper.io.CloseableIterator;

/* loaded from: input_file:ws/palladian/classification/xgboost/XGBoostLearner.class */
/**
 * Learner which trains an {@link XGBoostModel} using xgboost4j ({@link XGBoost#train}).
 *
 * <p>Only features whose values are {@link NumericValue}s are passed to XGBoost; all other
 * features are ignored. Instances are encoded as a sparse CSR matrix.
 */
public class XGBoostLearner extends AbstractLearner<XGBoostModel> {
    private static final Logger LOGGER = LoggerFactory.getLogger(XGBoostLearner.class);

    /** Number of boosting rounds used by the no-argument constructor. */
    private static final int DEFAULT_ROUNDS = 100;

    private final Map<String, Object> params;
    private final int rounds;
    private final IEvaluation evaluation;

    /**
     * Create a learner with explicit XGBoost parameters.
     *
     * @param params XGBoost parameters; copied, must not be {@code null}.
     * @param rounds number of boosting rounds; must be at least 1.
     * @param evaluation custom evaluation handed to {@link XGBoost#train}, or {@code null}.
     * @throws NullPointerException if {@code params} is {@code null}.
     * @throws IllegalArgumentException if {@code rounds} is less than 1.
     */
    public XGBoostLearner(Map<String, Object> params, int rounds, IEvaluation evaluation) {
        Validate.notNull(params, "params must not be null");
        Validate.isTrue(rounds > 0, "round must be at least 1");
        this.params = new HashMap<>(params);
        this.rounds = rounds;
        this.evaluation = evaluation;
    }

    /**
     * Create a learner with explicit XGBoost parameters and no custom evaluation.
     *
     * @param params XGBoost parameters; copied, must not be {@code null}.
     * @param rounds number of boosting rounds; must be at least 1.
     */
    public XGBoostLearner(Map<String, Object> params, int rounds) {
        this(params, rounds, null);
    }

    /** Create a learner with the default parameters ({@link #createDefaultParams()}) and 100 rounds. */
    public XGBoostLearner() {
        this(createDefaultParams(), DEFAULT_ROUNDS, null);
    }

    /** Build the default parameter map used by the no-argument constructor. */
    private static Map<String, Object> createDefaultParams() {
        Map<String, Object> defaults = new HashMap<>();
        defaults.put("objective", "binary:logistic");
        // NOTE(review): depending on the xgboost4j version, early stopping may have to be passed
        // to XGBoost.train as a separate argument rather than as a booster parameter — confirm.
        defaults.put("early_stopping_rounds", "50");
        defaults.put("eval_metric", "auc");
        defaults.put("booster", "gbtree");
        defaults.put("eta", 0.02d);
        defaults.put("subsample", 0.7d);
        defaults.put("colsample_bytree", 0.7d);
        defaults.put("min_child_weight", 0);
        defaults.put("max_depth", 10);
        defaults.put("silent", 0);
        return defaults;
    }

    /**
     * Train a model, optionally reporting progress on a dedicated validation set.
     *
     * @param training the training data; must contain at least one numeric feature.
     * @param validation optional validation data (added to the watch list), or {@code null}.
     * @return the trained model.
     * @throws IllegalArgumentException if the training data has no numeric features.
     * @throws IllegalStateException wrapping any {@link XGBoostError}.
     */
    public XGBoostModel train(Dataset training, Dataset validation) {
        List<String> categories = new ArrayList<>(new DatasetStatistics(training).getCategoryStatistics().getValues());
        List<String> featureNames = new ArrayList<>(training.getFeatureInformation().getFeatureNamesOfType(NumericValue.class));
        if (featureNames.isEmpty()) {
            throw new IllegalArgumentException("The training data contains no numeric features.");
        }
        Map<String, Object> effectiveParams = new HashMap<>(this.params);
        if (categories.size() > 2) {
            // NOTE(review): with more than two classes the configured objective must also be a
            // multi-class one (the default "binary:logistic" is not) — this is not checked here.
            LOGGER.debug("num_class = {}", categories.size());
            effectiveParams.put("num_class", categories.size());
        }
        Map<String, Integer> categoryIndices = CollectionHelper.createIndexMap(categories);
        Map<String, Integer> featureIndices = CollectionHelper.createIndexMap(featureNames);
        Map<String, DMatrix> watches = new HashMap<>();
        try {
            DMatrix trainingMatrix = makeMatrix(training, categoryIndices, featureIndices);
            watches.put("training", trainingMatrix);
            if (validation != null) {
                LOGGER.debug("Using dedicated validation set");
                watches.put("validation", makeMatrix(validation, categoryIndices, featureIndices));
            }
            return new XGBoostModel(
                    XGBoost.train(trainingMatrix, effectiveParams, this.rounds, watches, (IObjective) null, this.evaluation),
                    categories, featureIndices);
        } catch (XGBoostError e) {
            throw new IllegalStateException("XGBoost training failed", e);
        } finally {
            // DMatrix wraps a native handle; release it explicitly instead of relying on finalization.
            for (DMatrix matrix : watches.values()) {
                matrix.dispose();
            }
        }
    }

    /**
     * Train a model without a dedicated validation set.
     *
     * @param dataset the training data.
     * @return the trained model.
     */
    public XGBoostModel train(Dataset dataset) {
        return train(dataset, (Dataset) null);
    }

    /**
     * @deprecated name produced by the decompiler for {@code train(Dataset, Dataset)};
     *             use {@link #train(Dataset, Dataset)}.
     */
    @Deprecated
    public XGBoostModel m22train(Dataset dataset, Dataset dataset2) {
        return train(dataset, dataset2);
    }

    /**
     * @deprecated name produced by the decompiler for {@code train(Dataset)};
     *             use {@link #train(Dataset)}.
     */
    @Deprecated
    public XGBoostModel m23train(Dataset dataset) {
        return train(dataset);
    }

    /**
     * Encode a dataset as a CSR {@link DMatrix} with labels set to category indices.
     *
     * <p>Instances whose category is not in {@code labelIndices} are skipped, e.g. validation
     * instances with a category not seen in the training data.
     *
     * @param dataset the instances to encode.
     * @param labelIndices category name to label index.
     * @param featureIndices feature name to column index.
     * @return the populated matrix; the caller is responsible for disposing it.
     * @throws XGBoostError if the native matrix cannot be created.
     */
    private static DMatrix makeMatrix(Dataset dataset, Map<String, Integer> labelIndices,
            Map<String, Integer> featureIndices) throws XGBoostError {
        TFloatArrayList labels = new TFloatArrayList();
        TFloatArrayList values = new TFloatArrayList();
        TLongArrayList rowOffsets = new TLongArrayList();
        TIntArrayList columnIndices = new TIntArrayList();
        long offset = 0;
        rowOffsets.add(offset);
        // NOTE(review): the iterator is not closed here; if CloseableIterator holds resources,
        // consider closing it once its close() contract is confirmed.
        CloseableIterator<Instance> iterator = dataset.iterator();
        while (iterator.hasNext()) {
            Instance instance = iterator.next();
            Integer labelIndex = labelIndices.get(instance.getCategory());
            if (labelIndex == null) {
                continue;
            }
            labels.add(labelIndex.intValue());
            offset += makeRow(featureIndices, values, columnIndices, instance.getVector());
            // CSR row pointer: cumulative count of non-zero entries up to the end of this row.
            rowOffsets.add(offset);
        }
        DMatrix matrix = new DMatrix(rowOffsets.toArray(), columnIndices.toArray(), values.toArray(),
                DMatrix.SparseType.CSR);
        matrix.setLabel(labels.toArray());
        return matrix;
    }

    /**
     * Append one instance's numeric feature values to CSR value and column-index buffers.
     *
     * <p>Features not in {@code featureIndices}, {@link NullValue#NULL} values, non-numeric values,
     * and values whose magnitude is not greater than {@link Float#MIN_VALUE} (zero and NaN
     * included) are omitted. NOTE(review): XGBoost treats entries absent from a sparse matrix as
     * missing, so explicit zeros become "missing" — confirm that is intended.
     *
     * @param featureIndices feature name to column index.
     * @param values receives the feature values of this row.
     * @param columnIndices receives the column index of each appended value.
     * @param featureVector the instance's features.
     * @return the number of entries appended.
     */
    public static long makeRow(Map<String, Integer> featureIndices, TFloatArrayList values,
            TIntArrayList columnIndices, FeatureVector featureVector) {
        long appended = 0;
        for (Vector.VectorEntry<String, Value> entry : featureVector) {
            Integer column = featureIndices.get(entry.key());
            Value value = entry.value();
            if (column == null || value == NullValue.NULL || !(value instanceof NumericValue)) {
                continue;
            }
            float floatValue = ((NumericValue) value).getFloat();
            if (Math.abs(floatValue) > Float.MIN_VALUE) {
                values.add(floatValue);
                columnIndices.add(column.intValue());
                appended++;
            }
        }
        return appended;
    }
}
