package com.intel.analytics.bigdl.dllib.example.nnframes.xgboost;

import com.intel.analytics.bigdl.dllib.NNContext$;
import com.intel.analytics.bigdl.dllib.nnframes.XGBClassifier;
import ml.dmlc.xgboost4j.scala.spark.TrackerConf;
import org.apache.spark.SparkContext;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SQLContext$;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Tuple2;
import scala.Tuple4;
import scala.collection.SeqLike;
import scala.collection.immutable.StringOps;
import scala.sys.package$;

/* compiled from: xgbClassifierTrainingExample.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/example/nnframes/xgboost/xgbClassifierTrainingExample$.class */
/**
 * Java form of the Scala {@code object xgbClassifierTrainingExample}.
 *
 * <p>Trains a multi-class {@link XGBClassifier} on an iris-style CSV file and saves the fitted model.
 * The CSV has four numeric feature columns followed by a string class label.
 *
 * <p>Command-line arguments: {@code inputPath numThreads numRound modelSavePath}.
 */
public final class xgbClassifierTrainingExample$ {
    /** Names of the four numeric feature columns, in CSV column order. */
    private static final String[] FEATURE_COLUMNS =
            {"sepal length", "sepal width", "petal length", "petal width"};

    /**
     * Singleton instance, following the Scala {@code object} encoding.
     *
     * <p>Initialised directly: a {@code final} field cannot be set to {@code null} and then
     * reassigned from the constructor.
     */
    public static final xgbClassifierTrainingExample$ MODULE$ = new xgbClassifierTrainingExample$();

    private xgbClassifierTrainingExample$() {
    }

    /**
     * Runs the training example.
     *
     * @param strArr {@code [inputPath, numThreads, numRound, modelSavePath]}; with fewer than four
     *               elements a usage message is printed and the process exits with status 1
     * @throws NumberFormatException if {@code numThreads} or {@code numRound} is not an integer
     */
    public void main(String[] strArr) {
        if (strArr.length < 4) {
            Predef$.MODULE$.println("Usage: program inputPath numThreads numRound modelsavePath");
            throw package$.MODULE$.exit(1);
        }
        String inputPath = strArr[0];
        // Parse numeric arguments before starting Spark so that a malformed value fails fast
        // instead of leaving a SparkContext running.
        int numThreads = Integer.parseInt(strArr[1]);
        int numRound = Integer.parseInt(strArr[2]);
        String modelSavePath = strArr[3];

        SparkContext sc = NNContext$.MODULE$.initNNContext();
        try {
            SQLContext sqlContext = SQLContext$.MODULE$.getOrCreate(sc);
            Dataset<Row> data = toFeatureDataset(readIrisCsv(sqlContext, inputPath));

            Dataset<Row>[] splits = data.randomSplit(new double[]{0.6d, 0.2d, 0.1d, 0.1d});
            // Mirrors the original Scala `val Array(a, b, c, d) = ...` pattern match.
            if (splits.length != 4) {
                throw new MatchError(splits);
            }
            Dataset<Row> train = splits[0];
            Dataset<Row> evalSet1 = splits[1];
            Dataset<Row> evalSet2 = splits[2];
            // NOTE(review): splits[3] (10% of the rows) is carved off but never used here.

            // Scala's `k -> v` just builds a Tuple2, so the parameter map is assembled from tuples.
            Tuple2[] evalSets = new Tuple2[]{
                    new Tuple2<>("eval1", evalSet1),
                    new Tuple2<>("eval2", evalSet2)};
            Tuple2[] params = new Tuple2[]{
                    new Tuple2<>("tracker_conf", new TrackerConf(3600L, "scala")),
                    new Tuple2<>("eval_sets",
                            Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(evalSets)))};
            XGBClassifier classifier =
                    new XGBClassifier(Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(params)));

            classifier.setFeaturesCol("features");
            classifier.setLabelCol("classIndex");
            classifier.setNumClass(3);
            classifier.setMaxDepth(2);
            classifier.setNumWorkers(1);
            classifier.setNthread(numThreads);
            classifier.setNumRound(numRound);
            classifier.setTreeMethod("auto");
            classifier.setObjective("multi:softprob");
            classifier.setTimeoutRequestWorkers(180000L);

            classifier.fit(train).save(modelSavePath);
        } finally {
            // Always release the Spark context, even when loading or training fails.
            sc.stop();
        }
    }

    /**
     * Reads the CSV at {@code path} with an explicit five-column schema.
     *
     * <p>The schema is the four nullable double feature columns followed by a nullable string
     * {@code class} column. Spark's default CSV options are used (no header option is set).
     */
    private static Dataset<Row> readIrisCsv(SQLContext sqlContext, String path) {
        StructField[] fields = new StructField[FEATURE_COLUMNS.length + 1];
        for (int i = 0; i < FEATURE_COLUMNS.length; i++) {
            fields[i] = new StructField(
                    FEATURE_COLUMNS[i], DoubleType$.MODULE$, true, StructField$.MODULE$.apply$default$4());
        }
        fields[FEATURE_COLUMNS.length] = new StructField(
                "class", StringType$.MODULE$, true, StructField$.MODULE$.apply$default$4());
        return sqlContext.read().schema(new StructType(fields)).csv(path);
    }

    /**
     * Turns the raw rows into training rows with exactly two columns.
     *
     * <p>The string label {@code class} is indexed into {@code classIndex} and then dropped. The
     * four feature columns are assembled into a {@code features} vector. Only
     * {@code features} and {@code classIndex} are kept.
     */
    private static Dataset<Row> toFeatureDataset(Dataset<Row> raw) {
        Dataset<Row> indexed = new StringIndexer()
                .setInputCol("class")
                .setOutputCol("classIndex")
                .fit(raw)
                .transform(raw)
                .drop("class");
        return new VectorAssembler()
                // Pass a copy so the shared constant array cannot be mutated through the stage.
                .setInputCols(FEATURE_COLUMNS.clone())
                .setOutputCol("features")
                .transform(indexed)
                .select("features", Predef$.MODULE$.wrapRefArray(new String[]{"classIndex"}));
    }
}
