package ws.palladian.classification.xgboost;

import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ws.palladian.classification.featureselection.FeatureRanking;
import ws.palladian.classification.featureselection.RankingSource;
import ws.palladian.core.Model;
import ws.palladian.helper.collection.CollectionHelper;
import ws.palladian.helper.io.FileHelper;

/**
 * A trained XGBoost model. It wraps the native {@link Booster} together with the mappings that
 * translate between Palladian's string labels/features and XGBoost's integer indices.
 *
 * <p>Feature scores and model dumps need an XGBoost "feature map" text file. That file is written
 * lazily to a temporary location on first use. The reference to it is {@code transient}, so a
 * deserialized model (for example one loaded on another machine, or after the temp directory was
 * cleaned) regenerates the file instead of pointing at a stale path.
 */
public class XGBoostModel implements Model, RankingSource {
    private static final long serialVersionUID = 1;

    /** The trained native booster. */
    private final Booster booster;

    /** Class labels; the label at position {@code i} is returned by {@link #getLabel(int)}. */
    private final List<String> labelIndices;

    /** Feature name to XGBoost feature index. */
    private final Map<String, Integer> featureIndices;

    /**
     * Lazily written feature map file. It is not serialized because a temp-file path is not
     * portable across JVM runs or machines. Guarded by {@code this}.
     */
    private transient File featureMapFile;

    /**
     * Creates a new model. The label list and the feature index map are copied defensively.
     *
     * @param booster the trained booster
     * @param labels class labels, ordered by their XGBoost label index
     * @param featureIndices mapping from feature name to XGBoost feature index
     */
    // Originally package-private (per decompiler info); kept public for compatibility.
    public XGBoostModel(Booster booster, List<String> labels, Map<String, Integer> featureIndices) {
        this.booster = booster;
        this.labelIndices = new ArrayList<>(labels);
        this.featureIndices = new HashMap<>(featureIndices);
    }

    /** @return the wrapped native booster */
    public Booster getBooster() {
        return this.booster;
    }

    /** @return an unmodifiable view of the feature name to feature index mapping */
    public Map<String, Integer> getFeatureIndices() {
        return Collections.unmodifiableMap(this.featureIndices);
    }

    /**
     * @param i label index
     * @return the label stored at index {@code i}
     * @throws IndexOutOfBoundsException if {@code i} is not a valid label index
     */
    public String getLabel(int i) {
        return this.labelIndices.get(i);
    }

    /** @return a new, mutable set containing all class labels of this model */
    public Set<String> getCategories() {
        return new HashSet<>(this.labelIndices);
    }

    /**
     * Builds the contents of an XGBoost feature map. Each line has the form
     * {@code <index>\t<name>\tq}. Lines are sorted by ascending feature index, and whitespace
     * inside feature names is replaced by {@code _} so that it cannot break the tab-separated
     * format.
     */
    private String buildFeatureMap() {
        List<Map.Entry<String, Integer>> entries = new ArrayList<>(this.featureIndices.entrySet());
        // XGBoost's text feature map expects feature ids in sorted order.
        entries.sort(Map.Entry.comparingByValue());
        StringBuilder sb = new StringBuilder();
        for (Map.Entry<String, Integer> entry : entries) {
            sb.append(entry.getValue()).append('\t');
            // "q" marks the feature as quantitative in XGBoost's feature map format.
            sb.append(entry.getKey().replaceAll("\\s", "_")).append("\tq\n");
        }
        return sb.toString();
    }

    /**
     * Ensures that the feature map file exists on disk and returns its absolute path. The file is
     * written again if it was never written in this JVM, for example after deserialization, or if
     * it has since been deleted. The field is only assigned after a successful write, so a failed
     * attempt is retried on the next call.
     *
     * @return absolute path of the feature map file
     * @throws IllegalStateException if the file cannot be written
     */
    private synchronized String conditionallyWriteFeatureMap() {
        if (this.featureMapFile == null || !this.featureMapFile.isFile()) {
            File file = FileHelper.getTempFile();
            try {
                Files.write(file.toPath(), buildFeatureMap().getBytes(StandardCharsets.UTF_8));
            } catch (IOException e) {
                throw new IllegalStateException("Could not write XGBoost feature map to " + file, e);
            }
            // The file is only a scratch artifact; do not leave it behind in the temp directory.
            file.deleteOnExit();
            this.featureMapFile = file;
        }
        return this.featureMapFile.getAbsolutePath();
    }

    /** @return a feature ranking built from {@link #getFeatureScore()} */
    public FeatureRanking getFeatureRanking() {
        return new FeatureRanking(getFeatureScore());
    }

    /**
     * Returns the per-feature scores reported by {@link Booster#getFeatureScore(String)}. Features
     * are named using this model's feature map.
     *
     * @return map from feature name to score, as returned by the booster
     * @throws IllegalStateException if XGBoost fails or the feature map cannot be written
     * @deprecated use {@link #getFeatureRanking()} instead
     */
    @Deprecated
    public Map<String, Integer> getFeatureScore() {
        String featureMapPath = conditionallyWriteFeatureMap();
        try {
            return this.booster.getFeatureScore(featureMapPath);
        } catch (XGBoostError e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Returns the booster's model dump (with statistics), using feature names from the feature map.
     *
     * @throws IllegalStateException if XGBoost fails or the feature map cannot be written
     */
    @Override
    public String toString() {
        String featureMapPath = conditionallyWriteFeatureMap();
        try {
            return Arrays.toString(this.booster.getModelDump(featureMapPath, true));
        } catch (XGBoostError e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Command-line entry point. It deserializes a model from the given path and prints its
     * feature ranking.
     *
     * @param args exactly one argument: the path to a serialized {@code XGBoostModel}
     * @throws IOException if the model cannot be read
     */
    public static void main(String[] args) throws IOException {
        if (args.length != 1) {
            throw new IllegalArgumentException("First argument must be path to the model");
        }
        File modelFile = new File(args[0]);
        if (!modelFile.isFile()) {
            throw new IllegalArgumentException(modelFile + " is not a file");
        }
        // Java-native deserialization: only load model files from a trusted source.
        XGBoostModel model = (XGBoostModel) FileHelper.deserialize(args[0]);
        CollectionHelper.print(model.getFeatureRanking().getAll());
    }
}
