package hivemall.xgboost.utils;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.util.FVec;
import hivemall.utils.io.FastByteArrayInputStream;
import hivemall.utils.io.IOUtils;
import hivemall.xgboost.XGBoostBatchPredictUDTF;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.io.Text;

/* loaded from: input_file:hivemall/xgboost/utils/XGBoostUtils.class */
/**
 * Static helpers shared by the XGBoost integration: version lookup, construction and disposal
 * of {@link DMatrix}/{@link Booster} objects, model (de)serialization to/from {@link Text}, and
 * parsing of {@code "index:value"} feature strings.
 *
 * <p>This class is not instantiable.</p>
 */
public final class XGBoostUtils {

    /** Classpath resource that carries the {@code version} property. */
    private static final String VERSION_RESOURCE = "xgboost4j-version.properties";

    private XGBoostUtils() {
    }

    /**
     * Reads the {@code version} property from {@code xgboost4j-version.properties}, looked up
     * through the current thread's context class loader.
     *
     * @return the value of the {@code version} property, or {@code "<unknown>"} when the
     *         resource exists but does not define that property
     * @throws HiveException if the resource cannot be found or cannot be read
     */
    @Nonnull
    public static String getVersion() throws HiveException {
        final Properties props = new Properties();
        final ClassLoader loader = Thread.currentThread().getContextClassLoader();
        // try-with-resources guarantees the stream is closed even when load() fails;
        // a null resource is permitted here and is simply not closed.
        try (InputStream in = loader.getResourceAsStream(VERSION_RESOURCE)) {
            if (in == null) {
                // Properties.load(null) would otherwise surface as an undeclared NPE.
                throw new HiveException(
                    "Failed to load " + VERSION_RESOURCE + ": resource not found on the classpath");
            }
            props.load(in);
        } catch (IOException e) {
            throw new HiveException("Failed to load " + VERSION_RESOURCE, e);
        }
        return props.getProperty("version", "<unknown>");
    }

    /**
     * Builds a {@link DMatrix} from a snapshot of the given rows.
     *
     * @param list rows to feed into the matrix; the list itself is copied, not retained
     * @return a new matrix created with an empty cache-info string
     * @throws XGBoostError if the {@link DMatrix} constructor fails
     */
    // The DMatrix constructor expects an iterator over the base point type, which this file
    // does not import; a raw copy bridges the element-type invariance exactly as before.
    @SuppressWarnings({"rawtypes", "unchecked"})
    @Nonnull
    public static DMatrix createDMatrix(@Nonnull List<XGBoostBatchPredictUDTF.LabeledPointWithRowId> list) throws XGBoostError {
        final List points = new ArrayList(list);
        return new DMatrix(points.iterator(), "");
    }

    /**
     * Creates a {@link Booster} by reflectively invoking the declared
     * {@code Booster(Map, DMatrix[])} constructor, forcing it accessible first.
     *
     * @param dMatrix the single matrix passed to the constructor (wrapped in a one-element array)
     * @param map booster parameters handed to the constructor unchanged
     * @return the newly constructed booster
     * @throws NoSuchMethodException if {@link Booster} declares no such constructor
     * @throws XGBoostError declared for callers; raised failures of the constructor itself
     *         arrive wrapped in {@link InvocationTargetException}
     * @throws IllegalAccessException if the constructor cannot be accessed
     * @throws InvocationTargetException if the constructor throws
     * @throws InstantiationException if {@link Booster} cannot be instantiated
     */
    @Nonnull
    public static Booster createBooster(@Nonnull DMatrix dMatrix, @Nonnull Map<String, Object> map) throws NoSuchMethodException, XGBoostError, IllegalAccessException, InvocationTargetException, InstantiationException {
        final Constructor<Booster> ctor = Booster.class.getDeclaredConstructor(Map.class, DMatrix[].class);
        ctor.setAccessible(true);
        return ctor.newInstance(map, new DMatrix[] {dMatrix});
    }

    /**
     * Disposes the given matrix, ignoring any failure.
     *
     * @param dMatrix matrix to dispose; {@code null} is a no-op
     */
    public static void close(@Nullable DMatrix dMatrix) {
        if (dMatrix == null) {
            return;
        }
        try {
            dMatrix.dispose();
        } catch (Throwable ignored) {
            // Best-effort cleanup: a failed dispose must never mask the caller's own outcome.
        }
    }

    /**
     * Disposes the given booster, ignoring any failure.
     *
     * @param booster booster to dispose; {@code null} is a no-op
     */
    public static void close(@Nullable Booster booster) {
        if (booster == null) {
            return;
        }
        try {
            booster.dispose();
        } catch (Throwable ignored) {
            // Best-effort cleanup: a failed dispose must never mask the caller's own outcome.
        }
    }

    /**
     * Serializes a booster into a {@link Text} holding the compressed-text form of
     * {@code booster.toByteArray()}.
     *
     * @param booster model to serialize
     * @return the serialized model
     * @throws HiveException wrapping any {@link Throwable} raised during serialization
     */
    @Nonnull
    public static Text serializeBooster(@Nonnull Booster booster) throws HiveException {
        try {
            return new Text(IOUtils.toCompressedText(booster.toByteArray()));
        } catch (Throwable th) {
            throw new HiveException("Failed to serialize a booster", th);
        }
    }

    /**
     * Restores a {@link Booster} from compressed text, the inverse of
     * {@link #serializeBooster(Booster)}.
     *
     * @param text serialized model; only the first {@code text.getLength()} bytes of its
     *        backing array are read
     * @return the loaded booster
     * @throws HiveException wrapping any {@link Throwable} raised while decoding or loading
     */
    @Nonnull
    public static Booster deserializeBooster(@Nonnull Text text) throws HiveException {
        try {
            final byte[] model = IOUtils.fromCompressedText(text.getBytes(), text.getLength());
            return XGBoost.loadModel(new FastByteArrayInputStream(model));
        } catch (Throwable th) {
            throw new HiveException("Failed to deserialize a booster", th);
        }
    }

    /**
     * Creates a {@link Predictor} from a model stored as compressed text.
     *
     * @param text serialized model; only the first {@code text.getLength()} bytes of its
     *        backing array are read
     * @return a predictor backed by the decoded model
     * @throws HiveException wrapping any {@link Throwable} raised while decoding or loading
     */
    @Nonnull
    public static Predictor loadPredictor(@Nonnull Text text) throws HiveException {
        try {
            final byte[] model = IOUtils.fromCompressedText(text.getBytes(), text.getLength());
            return new Predictor(new FastByteArrayInputStream(model));
        } catch (Throwable th) {
            throw new HiveException("Failed to create a predictor", th);
        }
    }

    /**
     * Parses the {@code "index:value"} entries of {@code strArr} in the half-open range
     * {@code [i, i2)} into a feature vector. {@code null} entries are skipped; when an index
     * occurs more than once, the last occurrence wins.
     *
     * @param strArr feature strings, each of the form {@code <int index>:<float value>}
     * @param i first position to parse (inclusive)
     * @param i2 position at which to stop (exclusive)
     * @return the feature vector built from the parsed index/value pairs
     * @throws IllegalArgumentException if an entry has no {@code ':'} after at least one
     *         character, or if its index or value is not a valid number
     */
    @Nonnull
    public static FVec parseRowAsFVec(@Nonnull String[] strArr, int i, int i2) {
        // Size the map for the range actually parsed, clamped to what the array can hold.
        final int expected = Math.min(strArr.length, Math.max(0, i2 - i));
        final Map<Integer, Float> features = new HashMap<>((int) (expected * 1.5d));
        for (int pos = i; pos < i2; pos++) {
            final String feature = strArr[pos];
            if (feature == null) {
                continue;
            }
            final int sep = feature.indexOf(':');
            if (sep < 1) {
                // Either no separator at all, or an empty index part.
                throw new IllegalArgumentException("Invalid feature format: " + feature);
            }
            final int index;
            final float value;
            try {
                index = Integer.parseInt(feature.substring(0, sep));
                value = Float.parseFloat(feature.substring(sep + 1));
            } catch (NumberFormatException e) {
                // Keep the cause so the offending token is visible in the stack trace.
                throw new IllegalArgumentException("Failed to parse a feature value: " + feature, e);
            }
            features.put(index, value);
        }
        return FVec.Transformer.fromMap(features);
    }
}
