package hivemall.xgboost;

import hivemall.UDTFWithOptions;
import hivemall.annotations.VisibleForTesting;
import hivemall.utils.collections.lists.FloatArrayList;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.OptionUtils;
import hivemall.utils.math.MathUtils;
import hivemall.xgboost.utils.DMatrixBuilder;
import hivemall.xgboost.utils.DenseDMatrixBuilder;
import hivemall.xgboost.utils.NativeLibLoader;
import hivemall.xgboost.utils.SparseDMatrixBuilder;
import hivemall.xgboost.utils.XGBoostUtils;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import matrix4j.utils.lang.ArrayUtils;
import matrix4j.utils.lang.Primitives;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.Text;
import org.apache.lucene.util.packed.PackedInts;

@Description(name = "train_xgboost", value = "_FUNC_(array<string|double> features, <int|double> target, const string options) - Returns a relation consisting of <string model_id, string model>", extended = "SELECT \n  train_xgboost(features, label, '-objective binary:logistic -iters 10') \n    as (model_id, model)\nfrom (\n  select features, label\n  from xgb_input\n  cluster by rand(43) -- shuffle\n) shuffled;")
/**
 * Hive UDTF that trains an XGBoost model from (features, target) rows.
 *
 * <p>Rows are buffered into a {@link DMatrixBuilder} (dense for {@code array<number>} features,
 * sparse for {@code array<string>} features) while {@link #process(Object[])} is called. Training
 * happens once in {@link #close()}, which forwards a single row {@code (model_id, model)} holding
 * the serialized booster.
 */
public class XGBoostTrainUDTF extends UDTFWithOptions {
    private static final Log logger = LogFactory.getLog(XGBoostTrainUDTF.class);

    // Input inspectors, resolved in initialize()
    private ListObjectInspector featureListOI;
    private PrimitiveObjectInspector featureElemOI;
    private PrimitiveObjectInspector targetOI;

    /** true when features are numeric (dense, positional); false for string (sparse) features. */
    private boolean denseInput;
    /** Accumulates feature rows until close(); set to null once the DMatrix is built. */
    private DMatrixBuilder matrixBuilder;
    /** One label per processed row; set to null once the DMatrix is built. */
    private FloatArrayList labels;
    /** Value of -num_class; only set when the option is given. */
    protected int numClass;

    /** XGBoost hyperparameters plus the training-loop settings (num_round, early stopping, ...). */
    @Nonnull
    protected final Map<String, Object> params = new HashMap<String, Object>();
    /** Coarse objective family resolved from the -objective prefix; drives label validation. */
    protected ObjectiveType objectiveType = null;

    /** Coarse classification of an XGBoost objective string by its prefix. */
    public enum ObjectiveType {
        regression,
        binary,
        multiclass,
        rank,
        other;

        /**
         * Maps an objective name to its family.
         *
         * @param str objective such as {@code "reg:linear"} or {@code "multi:softmax"}
         * @return the family for a {@code reg:}, {@code binary:}, {@code multi:} or {@code rank:}
         *         prefix, or {@link #other} for anything else
         */
        @Nonnull
        public static ObjectiveType resolve(@Nonnull String str) {
            if (str.startsWith("reg:")) {
                return regression;
            }
            if (str.startsWith("binary:")) {
                return binary;
            }
            if (str.startsWith("multi:")) {
                return multiclass;
            }
            if (str.startsWith("rank:")) {
                return rank;
            }
            return other;
        }
    }

    /** Declares every option accepted in the third (constant string) argument. */
    @Override // hivemall.UDTFWithOptions
    protected Options getOptions() {
        final Options options = new Options();
        // Training loop
        options.addOption("num_round", "iters", true, "Number of boosting iterations [default: 10]");
        options.addOption("maximize_evaluation_metrics", true, "Maximize evaluation metrics [default: false]");
        options.addOption("num_early_stopping_rounds", true, "Minimum rounds required for early stopping [default: 0]");
        options.addOption("validation_ratio", true, "Validation ratio in range [0.0,1.0) [default: 0.2]");
        // General parameters
        options.addOption("booster", true, "Set a booster to use, gbtree or gblinear or dart. [default: gbtree]");
        options.addOption("silent", true, "Deprecated. Please use verbosity instead. 0 means printing running messages, 1 means silent mode [default: 1]");
        options.addOption("verbosity", true, "Verbosity of printing messages. Choices: 0 (silent), 1 (warning), 2 (info), 3 (debug). [default: 0]");
        options.addOption("disable_default_eval_metric", true, "Flag to disable default metric. Set to >0 to disable. [default: 0]");
        options.addOption("num_pbuffer", true, "Size of prediction buffer [default: set automatically by xgboost]");
        options.addOption("num_feature", true, "Feature dimension used in boosting [default: set automatically by xgboost]");
        // Booster parameters
        options.addOption("lambda", "reg_lambda", true, "L2 regularization term on weights. Increasing this value will make model more conservative. [default: 1.0 for gbtree, 0.0 for gblinear]");
        options.addOption("alpha", "reg_alpha", true, "L1 regularization term on weights. Increasing this value will make model more conservative. [default: 0.0]");
        options.addOption("updater", true, "A comma-separated string that defines the sequence of tree updaters to run. For a full list of valid inputs, please refer to XGBoost Parameters. [default: 'grow_colmaker,prune' for gbtree, 'shotgun' for gblinear]");
        options.addOption("eta", "learning_rate", true, "Step size shrinkage used in update to prevent overfitting [default: 0.3]");
        options.addOption("gamma", "min_split_loss", true, "Minimum loss reduction required to make a further partition on a leaf node of the tree. [default: 0.0]");
        options.addOption("max_depth", true, "Max depth of decision tree [default: 6]");
        options.addOption("min_child_weight", true, "Minimum sum of instance weight (hessian) needed in a child [default: 1.0]");
        options.addOption("max_delta_step", true, "Maximum delta step we allow each tree's weight estimation to be [default: 0]");
        options.addOption("subsample", true, "Subsample ratio of the training instance in range (0.0,1.0] [default: 1.0]");
        options.addOption("colsample_bytree", true, "Subsample ratio of columns when constructing each tree [default: 1.0]");
        options.addOption("colsample_bylevel", true, "Subsample ratio of columns for each level [default: 1.0]");
        options.addOption("colsample_bynode", true, "Subsample ratio of columns for each node [default: 1.0]");
        options.addOption("tree_method", true, "The tree construction algorithm used in XGBoost. [default: auto, Choices: auto, exact, approx, hist]");
        options.addOption("sketch_eps", true, "This roughly translates into O(1 / sketch_eps) number of bins. \nCompared to directly select number of bins, this comes with theoretical guarantee with sketch accuracy.\nOnly used for tree_method=approx. Usually user does not have to tune this.  [default: 0.03]");
        options.addOption("scale_pos_weight", true, "Control the balance of positive and negative weights, useful for unbalanced classes. A typical value to consider: sum(negative instances) / sum(positive instances) [default: 1.0]");
        options.addOption("refresh_leaf", true, "This is a parameter of the refresh updater plugin. When this flag is 1, tree leafs as well as tree nodes’ stats are updated. When it is 0, only node stats are updated. [default: 1]");
        options.addOption("process_type", true, "A type of boosting process to run. [Choices: default, update]");
        options.addOption("grow_policy", true, "Controls a way new nodes are added to the tree. Currently supported only if tree_method is set to hist. [default: depthwise, Choices: depthwise, lossguide]");
        options.addOption("max_leaves", true, "Maximum number of nodes to be added. Only relevant when grow_policy=lossguide is set. [default: 0]");
        options.addOption("max_bin", true, "Maximum number of discrete bins to bucket continuous features. Only used if tree_method is set to hist. [default: 256]");
        options.addOption("num_parallel_tree", true, "Number of parallel trees constructed during each iteration. This option is used to support boosted random forest. Usually no need to tune (default 1 is enough) for gradient boosting trees. [default: 1]");
        // Dart booster
        options.addOption("sample_type", true, "Type of sampling algorithm. [Choices: uniform (default), weighted]");
        options.addOption("normalize_type", true, "Type of normalization algorithm. [Choices: tree (default), forest]");
        options.addOption("rate_drop", true, "Dropout rate in range [0.0, 1.0]. [default: 0.0]");
        options.addOption("one_drop", true, "When this flag is enabled, at least one tree is always dropped during the dropout. 0 or 1. [default: 0]");
        options.addOption("skip_drop", true, "Probability of skipping the dropout procedure during a boosting iteration in range [0.0, 1.0]. [default: 0.0]");
        // Linear booster
        options.addOption("lambda_bias", true, "L2 regularization term on bias [default: 0.0]");
        options.addOption("feature_selector", true, "Feature selection and ordering method. [Choices: cyclic (default), shuffle, random, greedy, thrifty]");
        options.addOption("top_k", true, "The number of top features to select in greedy and thrifty feature selector. The value of 0 means using all the features. [default: 0]");
        // Learning task
        options.addOption("tweedie_variance_power", true, "Parameter that controls the variance of the Tweedie distribution in range [1.0, 2.0]. [default: 1.5]");
        options.addOption("objective", true, "Specifies the learning task and the corresponding learning objective. Examples: reg:linear, reg:logistic, multi:softmax. For a full list of valid inputs, refer to XGBoost Parameters. [default: reg:linear]");
        options.addOption("base_score", true, "Initial prediction score of all instances, global bias [default: 0.5]");
        options.addOption("eval_metric", true, "Evaluation metrics for validation data. A default metric is assigned according to the objective:\n- rmse: for regression\n- error: for classification\n- map: for ranking\nFor a list of valid inputs, see XGBoost Parameters.");
        options.addOption("seed", true, "Random number seed. [default: 43]");
        options.addOption("num_class", true, "Number of classes to classify");
        return options;
    }

    /**
     * Parses the option string (3rd argument, if present) into {@link #params}, and resolves
     * {@link #objectiveType} and {@link #numClass}.
     *
     * @param objectInspectorArr UDTF argument inspectors; index 2, when present, must be a constant string
     * @return the parsed command line
     * @throws UDFArgumentException if -objective is missing, validation_ratio is outside [0,1),
     *         or a multi: objective is given without -num_class
     */
    @Override // hivemall.UDTFWithOptions
    @Nonnull
    protected CommandLine processOptions(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        final CommandLine cl;
        if (objectInspectorArr.length >= 3) {
            cl = parseOptions(HiveUtils.getConstString(objectInspectorArr, 2));
        } else {
            cl = parseOptions("");
        }

        String objective = cl.getOptionValue("objective");
        if (objective == null) {
            showHelp("Please provide \"-objective XXX\" option in the 3rd argument.\n\nHere is the list of supported objectives: \n - Regression:\n {reg:squarederror, reg:logistic, reg:gamma, reg:tweedie}\n - Binary classification: {binary:logistic, binary:logitraw, binary:hinge}\n - Multiclass classification:\n {multi:softmax, multi:softprob}\n - Ranking:\n {rank:pairwise, rank:ndcg, rank:map}\n - Other:\n {count:poisson, survival:cox}");
            // showHelp is expected to throw; guard anyway so a null objective can never reach equals() below
            throw new UDFArgumentException("-objective option is required");
        }
        if (objective.equals("reg:squarederror")) {
            // passed to XGBoost under the name reg:linear
            objective = "reg:linear";
        }
        final String booster = cl.getOptionValue("booster", "gbtree");

        // Training-loop settings (consumed by close(), not by the native booster config)
        params.put("num_round", Integer.valueOf(Primitives.parseInt(cl.getOptionValue("num_round"), 10)));
        params.put("maximize_evaluation_metrics", Boolean.valueOf(Primitives.parseBoolean(cl.getOptionValue("maximize_evaluation_metrics"), false)));
        params.put("num_early_stopping_rounds", Integer.valueOf(Primitives.parseInt(cl.getOptionValue("num_early_stopping_rounds"), 0)));
        final double validationRatio = Primitives.parseDouble(cl.getOptionValue("validation_ratio"), 0.2d);
        if (validationRatio < 0.0d || validationRatio >= 1.0d) {
            throw new UDFArgumentException("Invalid validation_ratio=" + validationRatio);
        }
        params.put("validation_ratio", Double.valueOf(validationRatio));

        // General parameters
        params.put("booster", booster);
        params.put("silent", Integer.valueOf(Primitives.parseInt(cl.getOptionValue("silent"), 1)));
        params.put("verbosity", Integer.valueOf(Primitives.parseInt(cl.getOptionValue("verbosity"), 0)));
        // "nthread" is not registered in getOptions(), so this always resolves to 1
        params.put("nthread", Integer.valueOf(Primitives.parseInt(cl.getOptionValue("nthread"), 1)));
        params.put("disable_default_eval_metric", Integer.valueOf(Primitives.parseInt(cl.getOptionValue("disable_default_eval_metric"), 0)));
        if (cl.hasOption("num_pbuffer")) {
            params.put("num_pbuffer", Integer.valueOf(cl.getOptionValue("num_pbuffer")));
        }
        if (cl.hasOption("num_feature")) {
            params.put("num_feature", Integer.valueOf(cl.getOptionValue("num_feature")));
        }

        // Tree booster parameters; dart is a tree booster and accepts all of them as well
        if (booster.equals("gbtree") || booster.equals("dart")) {
            params.put("eta", Double.valueOf(Primitives.parseDouble(cl.getOptionValue("eta"), 0.3d)));
            params.put("gamma", Double.valueOf(Primitives.parseDouble(cl.getOptionValue("gamma"), 0.0d)));
            params.put("max_depth", Integer.valueOf(Primitives.parseInt(cl.getOptionValue("max_depth"), 6)));
            params.put("min_child_weight", Double.valueOf(Primitives.parseDouble(cl.getOptionValue("min_child_weight"), 1.0d)));
            params.put("max_delta_step", Double.valueOf(Primitives.parseDouble(cl.getOptionValue("max_delta_step"), 0.0d)));
            params.put("subsample", Double.valueOf(Primitives.parseDouble(cl.getOptionValue("subsample"), 1.0d)));
            params.put("colsample_bytree", Double.valueOf(Primitives.parseDouble(cl.getOptionValue("colsample_bytree"), 1.0d)));
            params.put("colsample_bylevel", Double.valueOf(Primitives.parseDouble(cl.getOptionValue("colsample_bylevel"), 1.0d)));
            params.put("colsample_bynode", Double.valueOf(Primitives.parseDouble(cl.getOptionValue("colsample_bynode"), 1.0d)));
            params.put("lambda", Double.valueOf(Primitives.parseDouble(cl.getOptionValue("lambda"), 1.0d)));
            params.put("alpha", Double.valueOf(Primitives.parseDouble(cl.getOptionValue("alpha"), 0.0d)));
            params.put("tree_method", cl.getOptionValue("tree_method", "auto"));
            params.put("sketch_eps", Double.valueOf(Primitives.parseDouble(cl.getOptionValue("sketch_eps"), 0.03d)));
            params.put("scale_pos_weight", Double.valueOf(Primitives.parseDouble(cl.getOptionValue("scale_pos_weight"), 1.0d)));
            params.put("updater", cl.getOptionValue("updater", "grow_colmaker,prune"));
            params.put("refresh_leaf", Integer.valueOf(Primitives.parseInt(cl.getOptionValue("refresh_leaf"), 1)));
            params.put("process_type", cl.getOptionValue("process_type", "default"));
            params.put("grow_policy", cl.getOptionValue("grow_policy", "depthwise"));
            params.put("max_leaves", Integer.valueOf(Primitives.parseInt(cl.getOptionValue("max_leaves"), 0)));
            params.put("max_bin", Integer.valueOf(Primitives.parseInt(cl.getOptionValue("max_bin"), 256)));
            params.put("num_parallel_tree", Integer.valueOf(Primitives.parseInt(cl.getOptionValue("num_parallel_tree"), 1)));
        }

        // Dart-specific dropout parameters
        if (booster.equals("dart")) {
            params.put("sample_type", cl.getOptionValue("sample_type", "uniform"));
            params.put("normalize_type", cl.getOptionValue("normalize_type", "tree"));
            params.put("rate_drop", Double.valueOf(Primitives.parseDouble(cl.getOptionValue("rate_drop"), 0.0d)));
            params.put("one_drop", Integer.valueOf(Primitives.parseInt(cl.getOptionValue("one_drop"), 0)));
            params.put("skip_drop", Double.valueOf(Primitives.parseDouble(cl.getOptionValue("skip_drop"), 0.0d)));
        }

        // Linear booster parameters
        if (booster.equals("gblinear")) {
            params.put("lambda", Double.valueOf(Primitives.parseDouble(cl.getOptionValue("lambda"), 0.0d)));
            params.put("lambda_bias", Double.valueOf(Primitives.parseDouble(cl.getOptionValue("lambda_bias"), 0.0d)));
            params.put("alpha", Double.valueOf(Primitives.parseDouble(cl.getOptionValue("alpha"), 0.0d)));
            params.put("updater", cl.getOptionValue("updater", "shotgun"));
            params.put("feature_selector", cl.getOptionValue("feature_selector", "cyclic"));
            params.put("top_k", Integer.valueOf(Primitives.parseInt(cl.getOptionValue("top_k"), 0)));
        }

        // Objective-specific parameters
        if (objective.equals("reg:tweedie")) {
            params.put("tweedie_variance_power", Double.valueOf(Primitives.parseDouble(cl.getOptionValue("tweedie_variance_power"), 1.5d)));
        }
        if (objective.equals("count:poisson")) {
            // overrides the tree-booster default of 0.0 set above
            params.put("max_delta_step", Double.valueOf(Primitives.parseDouble(cl.getOptionValue("max_delta_step"), 0.7d)));
        }

        // Learning task parameters
        params.put("objective", objective);
        params.put("base_score", Double.valueOf(Primitives.parseDouble(cl.getOptionValue("base_score"), 0.5d)));
        if (cl.hasOption("eval_metric")) {
            params.put("eval_metric", cl.getOptionValue("eval_metric"));
        }
        params.put("seed", Long.valueOf(Primitives.parseLong(cl.getOptionValue("seed"), 43L)));
        if (cl.hasOption("num_class")) {
            this.numClass = Integer.parseInt(cl.getOptionValue("num_class"));
            params.put("num_class", Integer.valueOf(numClass));
        } else if (objective.startsWith("multi:")) {
            throw new UDFArgumentException("-num_class is required for multiclass classification");
        }

        if (logger.isInfoEnabled()) {
            logger.info("XGboost training hyperparameters: " + params.toString());
        }
        this.objectiveType = ObjectiveType.resolve(objective);
        return cl;
    }

    /**
     * Validates the arguments and prepares the row buffers.
     *
     * @param objectInspectorArr {@code (array<number|string> features, number target [, const string options])}
     * @return a struct inspector for {@code (string model_id, string model)}
     * @throws UDFArgumentException on a wrong argument count, an unsupported feature element type,
     *         or invalid options
     */
    @Override
    public StructObjectInspector initialize(@Nonnull ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        if (objectInspectorArr.length != 2 && objectInspectorArr.length != 3) {
            showHelp("Invalid argument length=" + objectInspectorArr.length);
        }
        processOptions(objectInspectorArr);

        final ListObjectInspector listOI = HiveUtils.asListOI(objectInspectorArr, 0);
        final ObjectInspector elemOI = listOI.getListElementObjectInspector();
        this.featureListOI = listOI;
        if (HiveUtils.isNumberOI(elemOI)) {
            // numeric features: the list position is the column index
            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
            this.denseInput = true;
            this.matrixBuilder = new DenseDMatrixBuilder(8192);
        } else if (HiveUtils.isStringOI(elemOI)) {
            // string features: each element is handed to the sparse builder as-is
            this.featureElemOI = HiveUtils.asStringOI(elemOI);
            this.denseInput = false;
            this.matrixBuilder = new SparseDMatrixBuilder(8192);
        } else {
            throw new UDFArgumentException("train_xgboost takes array<double> or array<string> for the first argument: " + listOI.getTypeName());
        }
        this.targetOI = HiveUtils.asDoubleCompatibleOI(objectInspectorArr, 1);
        this.labels = new FloatArrayList(1024);

        final ArrayList<String> fieldNames = new ArrayList<String>(2);
        final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(2);
        fieldNames.add("model_id");
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        fieldNames.add("model");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Validates and normalizes a raw target value according to {@link #objectiveType}.
     *
     * <ul>
     * <li>binary: accepts -1, 0 or 1; positive values map to 1 and the others to 0</li>
     * <li>multiclass: requires an integral value in {@code [0, numClass)}</li>
     * <li>otherwise: returned unchanged</li>
     * </ul>
     *
     * @param f raw target value
     * @return the label to store
     * @throws HiveException if the value is invalid for the objective
     */
    protected float processTargetValue(float f) throws HiveException {
        switch (objectiveType) {
            case binary:
                if (f != -1.0f && f != 0.0f && f != 1.0f) {
                    throw new UDFArgumentException("Invalid label value for classification: " + f);
                }
                return f > 0.0f ? 1.0f : 0.0f;
            case multiclass: {
                final int label = (int) f;
                if (label != f) {
                    throw new UDFArgumentException("Invalid target value for class label: " + f);
                }
                if (label < 0 || label >= numClass) {
                    throw new UDFArgumentException("target must be {0.0, ..., " + String.format("%.1f", Double.valueOf(numClass - 1.0d)) + "}: " + f);
                }
                return f;
            }
            default:
                return f;
        }
    }

    /**
     * Buffers one training row.
     *
     * @param objArr {@code [features, target, (options)]}
     * @throws HiveException if features or target is NULL, or the target is invalid for the objective
     */
    @Override
    public void process(@Nonnull Object[] objArr) throws HiveException {
        if (objArr[0] == null) {
            throw new HiveException("array<double> features was null");
        }
        if (objArr[1] == null) {
            // reject rather than silently turning a NULL target into a label value
            throw new HiveException("target was null");
        }
        parseFeatures(objArr[0], matrixBuilder);
        final float target = PrimitiveObjectInspectorUtils.getFloat(objArr[1], targetOI);
        labels.add(processTargetValue(target));
    }

    /**
     * Appends the features of one row to the builder and terminates the row. NULL list elements
     * are skipped.
     */
    private void parseFeatures(@Nonnull Object obj, @Nonnull DMatrixBuilder dMatrixBuilder) {
        final int size = featureListOI.getListLength(obj);
        for (int i = 0; i < size; i++) {
            final Object elem = featureListOI.getListElement(obj, i);
            if (elem == null) {
                continue;
            }
            if (denseInput) {
                dMatrixBuilder.nextColumn(i, PrimitiveObjectInspectorUtils.getFloat(elem, featureElemOI));
            } else {
                dMatrixBuilder.nextColumn(elem.toString());
            }
        }
        dMatrixBuilder.nextRow();
    }

    /**
     * Builds the DMatrix from the buffered rows, trains the booster, and forwards
     * {@code (model_id, serialized model)}.
     *
     * <p>When {@code num_early_stopping_rounds > 0}, a {@code validation_ratio} share of the rows
     * (shuffled with {@code seed}) is held out and evaluated after every round. The DMatrix and
     * Booster are closed whether training succeeds or fails.
     *
     * @throws HiveException wrapping any failure during training, serialization or forwarding
     */
    @Override
    public void close() throws HiveException {
        DMatrix dmatrix = null;
        Booster booster = null;
        try {
            dmatrix = matrixBuilder.buildMatrix(labels.toArray(true));
            // the row buffers are no longer needed once the DMatrix is built
            this.matrixBuilder = null;
            this.labels = null;

            final int numRound = OptionUtils.getInt(params, "num_round");
            final int earlyStoppingRounds = OptionUtils.getInt(params, "num_early_stopping_rounds");
            if (earlyStoppingRounds > 0) {
                final double validationRatio = OptionUtils.getDouble(params, "validation_ratio");
                final long seed = OptionUtils.getLong(params, "seed");

                final int numRows = (int) dmatrix.rowNum();
                final int[] rows = MathUtils.permutation(numRows);
                ArrayUtils.shuffle(rows, new Random(seed));

                // the first numTest shuffled rows are held out for validation
                final int numTest = (int) (numRows * validationRatio);
                DMatrix dtrain = null;
                DMatrix dtest = null;
                try {
                    dtest = dmatrix.slice(Arrays.copyOf(rows, numTest));
                    dtrain = dmatrix.slice(Arrays.copyOfRange(rows, numTest, rows.length));
                    booster = train(dtrain, dtest, numRound, earlyStoppingRounds, params);
                } finally {
                    XGBoostUtils.close(dtrain);
                    XGBoostUtils.close(dtest);
                }
            } else {
                booster = train(dmatrix, numRound, params);
            }
            onFinishTraining(booster);

            final String modelId = generateUniqueModelId();
            final Text model = XGBoostUtils.serializeBooster(booster);
            logger.info("model_id:" + modelId + ", size:" + model.getLength());
            forward(new Object[] {modelId, model});
        } catch (Throwable e) {
            throw new HiveException(e);
        } finally {
            // release native resources on both success and failure
            XGBoostUtils.close(dmatrix);
            XGBoostUtils.close(booster);
        }
    }

    /** Hook invoked with the trained booster before it is serialized; no-op by default. */
    @VisibleForTesting
    protected void onFinishTraining(@Nonnull Booster booster) {
    }

    /**
     * Trains for a fixed number of rounds without validation.
     *
     * @param dMatrix training data
     * @param i number of boosting rounds
     * @param map booster parameters
     * @return the trained booster (the caller owns and must close it)
     */
    @Nonnull
    private static Booster train(@Nonnull DMatrix dMatrix, @Nonnegative int i, @Nonnull Map<String, Object> map) throws NoSuchMethodException, IllegalAccessException, InvocationTargetException, InstantiationException, XGBoostError {
        final Booster booster = XGBoostUtils.createBooster(dMatrix, map);
        for (int iter = 0; iter < i; iter++) {
            booster.update(dMatrix, iter);
        }
        return booster;
    }

    /**
     * Trains with per-round evaluation on a validation set, stopping early once the metric has not
     * improved for {@code i2} rounds.
     *
     * @param dMatrix training data
     * @param dMatrix2 validation data
     * @param i maximum number of boosting rounds
     * @param i2 early-stopping patience in rounds
     * @param map booster parameters; {@code maximize_evaluation_metrics} selects the improvement direction
     * @return the trained booster (the caller owns and must close it)
     */
    @Nonnull
    private static Booster train(@Nonnull DMatrix dMatrix, @Nonnull DMatrix dMatrix2, @Nonnegative int i, @Nonnegative int i2, @Nonnull Map<String, Object> map) throws NoSuchMethodException, IllegalAccessException, InvocationTargetException, InstantiationException, XGBoostError {
        final Booster booster = XGBoostUtils.createBooster(dMatrix, map);
        final boolean maximize = OptionUtils.getBoolean(map, "maximize_evaluation_metrics");
        float bestScore = maximize ? -Float.MAX_VALUE : Float.MAX_VALUE;
        int bestIteration = 0;
        final float[] metricsOut = new float[1];
        for (int iter = 0; iter < i; iter++) {
            booster.update(dMatrix, iter);
            logger.info(booster.evalSet(new DMatrix[] {dMatrix2}, new String[] {"test"}, iter, metricsOut));
            final float score = metricsOut[0];
            final boolean improved = maximize ? score > bestScore : score < bestScore;
            if (improved) {
                bestScore = score;
                bestIteration = iter;
            }
            if (shouldEarlyStop(i2, iter, bestIteration)) {
                logger.info(String.format("early stopping after %d rounds away from the best iteration", Integer.valueOf(i2)));
                break;
            }
        }
        return booster;
    }

    /** Returns true once the current round is at least {@code i} rounds past the best round. */
    private static boolean shouldEarlyStop(int i, int i2, int i3) {
        return i2 - i3 >= i;
    }

    /** Builds a model id of the form {@code xgbmodel-<unique task id>}. */
    @Nonnull
    private static String generateUniqueModelId() {
        return "xgbmodel-" + HadoopUtils.getUniqueTaskIdString();
    }

    static {
        NativeLibLoader.initXGBoost();
    }
}
