package com.intel.analytics.bigdl.ppml.utils;

import com.intel.analytics.bigdl.dllib.feature.dataset.AbstractDataSet;
import com.intel.analytics.bigdl.dllib.feature.dataset.DataSet$;
import com.intel.analytics.bigdl.dllib.feature.dataset.MiniBatch;
import com.intel.analytics.bigdl.dllib.feature.dataset.Sample;
import com.intel.analytics.bigdl.dllib.feature.dataset.Sample$;
import com.intel.analytics.bigdl.dllib.feature.dataset.SampleToMiniBatch$;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.ppml.FLContext$;
import java.util.ArrayList;
import java.util.Arrays;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.ArrayType$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.FloatType$;
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.MapType$;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.StructType$;
import scala.MatchError;
import scala.Predef$;
import scala.collection.JavaConverters$;
import scala.collection.Seq;
import scala.collection.TraversableOnce;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.List;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;

/* compiled from: DataFrameUtils.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/ppml/utils/DataFrameUtils$.class */
/**
 * Helpers that turn a Spark {@code Dataset<Row>} into BigDL {@link Sample}s / mini-batches.
 *
 * <p>This is the Java form of the Scala singleton {@code DataFrameUtils}; {@link #MODULE$}
 * holds the single instance created by the static initializer.
 */
public final class DataFrameUtils$ {
    public static DataFrameUtils$ MODULE$;
    private final Logger logger;

    static {
        new DataFrameUtils$();
    }

    /** Returns the logger of this singleton. */
    public Logger logger() {
        return this.logger;
    }

    /**
     * Converts a DataFrame into a local mini-batch dataset.
     *
     * <p>The samples built by {@link #dataFrameToSampleRDD} are {@code collect()}-ed to the
     * driver, so the whole dataset must fit in driver memory.
     *
     * @param dataset       input rows
     * @param featureColumn feature column names, or {@code null} (see {@link #dataFrameToSampleRDD})
     * @param labelColumn   label column names, or {@code null}
     * @param hasLabel      whether samples carry a label tensor
     * @param batchSize     mini-batch size passed to {@code SampleToMiniBatch}
     * @return the batched dataset
     */
    public AbstractDataSet<MiniBatch<Object>, ?> dataFrameToMiniBatch(Dataset<Row> dataset, String[] featureColumn, String[] labelColumn, boolean hasLabel, int batchSize) {
        return DataSet$.MODULE$
                .array(dataFrameToSampleRDD(dataset, featureColumn, labelColumn, hasLabel, batchSize).collect())
                .$minus$greater(
                        SampleToMiniBatch$.MODULE$.apply(
                                batchSize,
                                SampleToMiniBatch$.MODULE$.apply$default$2(),
                                SampleToMiniBatch$.MODULE$.apply$default$3(),
                                SampleToMiniBatch$.MODULE$.apply$default$4(),
                                false,
                                ClassTag$.MODULE$.Float(),
                                TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$),
                        ClassTag$.MODULE$.apply(MiniBatch.class));
    }

    public String[] dataFrameToMiniBatch$default$2() {
        return null;
    }

    public String[] dataFrameToMiniBatch$default$3() {
        return null;
    }

    public boolean dataFrameToMiniBatch$default$4() {
        return true;
    }

    public int dataFrameToMiniBatch$default$5() {
        return 4;
    }

    /**
     * Converts every row of a DataFrame into a float {@link Sample}.
     *
     * <p>Column selection:
     * <ul>
     *   <li>Both {@code featureColumn} and {@code labelColumn} are {@code null}: every column is
     *       used. With {@code hasLabel}, the last column is the label and the rest are features.</li>
     *   <li>{@code featureColumn} is given: only those columns, plus {@code labelColumn} when
     *       {@code hasLabel}, are selected. Other columns of {@code dataset} are ignored.</li>
     * </ul>
     * All selected columns are cast to float. Array, struct and map columns are not supported.
     *
     * @param dataset       input rows
     * @param featureColumn feature column names, or {@code null}
     * @param labelColumn   label column names, or {@code null}; ignored when {@code !hasLabel}
     * @param hasLabel      whether samples carry a label tensor
     * @param batchSize     unused here; kept for signature compatibility
     * @return an RDD with one sample per row
     * @throws IllegalArgumentException if the column arguments are inconsistent
     * @throws Error if a selected column has a complex type ("not implemented")
     */
    public RDD<Sample<Object>> dataFrameToSampleRDD(Dataset<Row> dataset, String[] featureColumn, String[] labelColumn, boolean hasLabel, int batchSize) {
        // Called for its side effect only; the returned session is not used here.
        FLContext$.MODULE$.getSparkSession();
        // Validate on the driver so misuse fails fast instead of as an NPE inside a task.
        validateColumnArguments(featureColumn, labelColumn, hasLabel);
        if (featureColumn == null) {
            // After validation this means both column arrays are null. Warn once, not per row.
            logger().warn("featureColumn and labelColumn are not provided, would take the last column as label column, and others would be feature columns");
        }
        Dataset<Row> prepared = castColumnsToFloat(selectColumns(dataset, featureColumn, labelColumn, hasLabel));
        return prepared.rdd().map(
                row -> rowToSample(row, featureColumn, labelColumn, hasLabel),
                ClassTag$.MODULE$.apply(Sample.class));
    }

    /** Checks the feature/label argument combinations accepted by {@link #dataFrameToSampleRDD}. */
    private static void validateColumnArguments(String[] featureColumn, String[] labelColumn, boolean hasLabel) {
        if (featureColumn == null && labelColumn == null) {
            return;
        }
        if (hasLabel) {
            Predef$.MODULE$.require(featureColumn != null && labelColumn != null, () -> {
                return "You must provide both featureColumn and labelColumn or neither in training or evaluation.\nIf neither, the last would be used as label and the rest are the features";
            });
        } else {
            Predef$.MODULE$.require(featureColumn != null, () -> {
                return "You must provide featureColumn in predict";
            });
        }
        Predef$.MODULE$.require(featureColumn.length > 0, () -> {
            return "featureColumn must not be empty";
        });
    }

    /** Projects {@code dataset} onto the feature columns followed by the label columns (if used). */
    private static Dataset<Row> selectColumns(Dataset<Row> dataset, String[] featureColumn, String[] labelColumn, boolean hasLabel) {
        if (featureColumn == null) {
            return dataset;
        }
        String[] labels = hasLabel ? labelColumn : new String[0];
        int restFeatures = featureColumn.length - 1;
        String[] rest = new String[restFeatures + labels.length];
        System.arraycopy(featureColumn, 1, rest, 0, restFeatures);
        System.arraycopy(labels, 0, rest, restFeatures, labels.length);
        return dataset.select(featureColumn[0], rest);
    }

    /** Casts every column of {@code frame} to float; complex-typed columns are rejected. */
    private static Dataset<Row> castColumnsToFloat(Dataset<Row> frame) {
        Dataset<Row> result = frame;
        for (String name : frame.columns()) {
            String kind = MODULE$.getGenericType(frame.schema().apply(name).dataType());
            if (!"scalar".equals(kind)) {
                throw new Error("not implemented");
            }
            result = result.withColumn(name, result.col(name).cast(FloatType$.MODULE$));
        }
        return result;
    }

    /** Builds one sample from a row that has already been selected and cast to float. */
    private static Sample<Object> rowToSample(Row row, String[] featureColumn, String[] labelColumn, boolean hasLabel) {
        final int featureNum;
        final int labelNum;
        final float[] values;
        if (featureColumn == null) {
            // Positional mode: the last column is the label when hasLabel.
            labelNum = hasLabel ? 1 : 0;
            featureNum = row.size() - labelNum;
            values = new float[row.size()];
            for (int idx = 0; idx < values.length; idx++) {
                values[idx] = BoxesRunTime.unboxToFloat(row.get(idx));
            }
        } else {
            String[] labels = hasLabel ? labelColumn : new String[0];
            featureNum = featureColumn.length;
            labelNum = labels.length;
            values = new float[featureNum + labelNum];
            for (int idx = 0; idx < featureNum; idx++) {
                values[idx] = BoxesRunTime.unboxToFloat(row.getAs(featureColumn[idx]));
            }
            for (int idx = 0; idx < labelNum; idx++) {
                values[featureNum + idx] = BoxesRunTime.unboxToFloat(row.getAs(labels[idx]));
            }
        }
        if (!hasLabel) {
            return Sample$.MODULE$.apply(floatTensor(Arrays.copyOfRange(values, 0, featureNum)), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
        }
        Predef$.MODULE$.require(featureNum + labelNum == row.size(), () -> {
            return "size mismatch";
        });
        Tensor<Object> features = floatTensor(Arrays.copyOfRange(values, 0, featureNum));
        Tensor<Object> target = floatTensor(Arrays.copyOfRange(values, featureNum, values.length));
        return Sample$.MODULE$.apply(features, target, ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
    }

    /** Wraps {@code data} in a 1-D float tensor of length {@code data.length}. */
    private static Tensor<Object> floatTensor(float[] data) {
        return Tensor$.MODULE$.apply(data, new int[]{data.length}, ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
    }

    public String[] dataFrameToSampleRDD$default$2() {
        return null;
    }

    public String[] dataFrameToSampleRDD$default$3() {
        return null;
    }

    public boolean dataFrameToSampleRDD$default$4() {
        return true;
    }

    public int dataFrameToSampleRDD$default$5() {
        return 4;
    }

    /**
     * Classifies a Spark column type.
     *
     * <p>Uses {@code instanceof}, because concrete column types are instances such as
     * {@code ArrayType(FloatType, true)}. They never equal the {@code ArrayType$}/
     * {@code StructType$}/{@code MapType$} companion objects.
     *
     * @param dataType column type, may be {@code null}
     * @return {@code "complex"} for array, struct and map types; otherwise {@code "scalar"}
     */
    public String getGenericType(DataType dataType) {
        boolean complex = dataType instanceof ArrayType
                || dataType instanceof StructType
                || dataType instanceof MapType;
        return complex ? "complex" : "scalar";
    }

    /**
     * Scala-generated closure body: casts column {@code str} (as typed in {@code dataset})
     * to float in {@code objectRef.elem}. No longer used internally; kept for compatibility.
     */
    public static final /* synthetic */ void $anonfun$dataFrameToSampleRDD$1(Dataset dataset, ObjectRef objectRef, String str) {
        String genericType = MODULE$.getGenericType(dataset.schema().apply(str).dataType());
        if ("scalar".equals(genericType)) {
            objectRef.elem = ((Dataset) objectRef.elem).withColumn(str, dataset.col(str).cast(FloatType$.MODULE$));
        } else {
            if (!"complex".equals(genericType)) {
                throw new MatchError(genericType);
            }
            throw new Error("not implemented");
        }
    }

    /** Scala-generated closure body; appends {@code row[str]} to the list. Kept for compatibility. */
    public static final /* synthetic */ boolean $anonfun$dataFrameToSampleRDD$3(ArrayList arrayList, Row row, String str) {
        return arrayList.add(row.getAs(str));
    }

    /** Scala-generated closure body; appends {@code row[str]} to the list. Kept for compatibility. */
    public static final /* synthetic */ boolean $anonfun$dataFrameToSampleRDD$5(ArrayList arrayList, Row row, String str) {
        return arrayList.add(row.getAs(str));
    }

    private DataFrameUtils$() {
        MODULE$ = this;
        this.logger = LogManager.getLogger(getClass());
    }
}
