package org.apache.spark.ml;

import org.apache.spark.ml.DLEstimatorBase;
import org.apache.spark.ml.DLParams;
import org.apache.spark.ml.DLTransformerBase;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.FloatType$;
import org.apache.spark.sql.types.StructType;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;

/* compiled from: DLEstimatorBase.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005-bAB\u0001\u0003\u0003\u0003\u0011!BA\bE\u0019\u0016\u001bH/[7bi>\u0014()Y:f\u0015\t\u0019A!\u0001\u0002nY*\u0011QAB\u0001\u0006gB\f'o\u001b\u0006\u0003\u000f!\ta!\u00199bG\",'\"A\u0005\u0002\u0007=\u0014x-F\u0002\f_I\u0019B\u0001\u0001\u0007 EA\u0019QB\u0004\t\u000e\u0003\tI!a\u0004\u0002\u0003\u0013\u0015\u001bH/[7bi>\u0014\bCA\t\u0013\u0019\u0001!Qa\u0005\u0001C\u0002U\u0011\u0011!T\u0002\u0001#\t1B\u0004\u0005\u0002\u001855\t\u0001DC\u0001\u001a\u0003\u0015\u00198-\u00197b\u0013\tY\u0002DA\u0004O_RD\u0017N\\4\u0011\u00075i\u0002#\u0003\u0002\u001f\u0005\t\tB\t\u0014+sC:\u001chm\u001c:nKJ\u0014\u0015m]3\u0011\u00055\u0001\u0013BA\u0011\u0003\u0005!!E\nU1sC6\u001c\bCA\u0012)\u001b\u0005!#BA\u0013'\u0003\u0019\u0019\b.\u0019:fI*\u0011qEA\u0001\u0006a\u0006\u0014\u0018-\\\u0005\u0003S\u0011\u00121\u0002S1t\u0019\u0006\u0014W\r\\\"pY\")1\u0006\u0001C\u0001Y\u00051A(\u001b8jiz\"\u0012!\f\t\u0005\u001b\u0001q\u0003\u0003\u0005\u0002\u0012_\u0011)\u0001\u0007\u0001b\u0001c\t9A*Z1s]\u0016\u0014\u0018C\u0001\f.\u0011\u0015\u0019\u0004\u0001\"\u00055\u0003A9W\r\u001e'bE\u0016d\u0017I\u001d:bs\u000e{G.F\u00016!\t1\u0014H\u0004\u0002\u0018o%\u0011\u0001\bG\u0001\u0007!J,G-\u001a4\n\u0005iZ$AB*ue&twM\u0003\u000291!)Q\b\u0001D\t}\u0005Y\u0011N\u001c;fe:\fGNR5u)\t\u0001r\bC\u0003Ay\u0001\u0007\u0011)A\bgK\u0006$XO]3B]\u0012d\u0015MY3m!\r\u0011UiR\u0007\u0002\u0007*\u0011A\tB\u0001\u0004e\u0012$\u0017B\u0001$D\u0005\r\u0011F\t\u0012\t\u0005/!S%*\u0003\u0002J1\t1A+\u001e9mKJ\u00022aS*W\u001d\ta\u0015K\u0004\u0002N!6\taJ\u0003\u0002P)\u00051AH]8pizJ\u0011!G\u0005\u0003%b\tq\u0001]1dW\u0006<W-\u0003\u0002U+\n\u00191+Z9\u000b\u0005IC\u0002CA\fX\u0013\tA\u0006D\u0001\u0004B]f4\u0016\r\u001c\u0005\u00065\u0002!\teW\u0001\u0004M&$HC\u0001\t]\u0011\u0015i\u0016\f1\u0001_\u0003\u001d!\u0017\r^1tKR\u0004$a\u00184\u0011\u0007\u0001\u001cW-D\u0001b\u0015\t\u0011G!A\u0002tc2L!\u0001Z1\u0003\u000f\u0011\u000bG/Y:fiB\u0011\u0011C\u001a\u0003\nOr\u000b
\t\u0011!A\u0003\u0002!\u00141a\u0018\u00132#\t1\u0012\u000e\u0005\u0002\u0018U&\u00111\u000e\u0007\u0002\u0004\u0003:L\b\"B7\u0001\t#q\u0017a\u0003;p\u0003J\u0014\u0018-\u001f+za\u0016$\"!Q8\t\u000buc\u0007\u0019\u00019\u0011\u0005E\\hB\u0001:{\u001d\t\u0019\u0018P\u0004\u0002uq:\u0011Qo\u001e\b\u0003\u001bZL\u0011!C\u0005\u0003\u000f!I!!\u0002\u0004\n\u0005\t$\u0011B\u0001*b\u0013\taXPA\u0005ECR\fgI]1nK*\u0011!+\u0019\u0005\u0007\u007f\u0002!\t&!\u0001\u0002\u001dY\fG.\u001b3bi\u0016\u001c6\r[3nCR!\u00111AA\u0005!\r9\u0012QA\u0005\u0004\u0003\u000fA\"\u0001B+oSRDq!a\u0003\u007f\u0001\u0004\ti!\u0001\u0004tG\",W.\u0019\t\u0005\u0003\u001f\t)\"\u0004\u0002\u0002\u0012)\u0019\u00111C1\u0002\u000bQL\b/Z:\n\t\u0005]\u0011\u0011\u0003\u0002\u000b'R\u0014Xo\u0019;UsB,\u0007bBA\u000e\u0001\u0011\u0005\u0013QD\u0001\u0005G>\u0004\u0018\u0010F\u0002/\u0003?A\u0001\"!\t\u0002\u001a\u0001\u0007\u00111E\u0001\u0006Kb$(/\u0019\t\u0005\u0003K\t9#D\u0001'\u0013\r\tIC\n\u0002\t!\u0006\u0014\u0018-\\'ba\u0002")
/* loaded from: input_file:org/apache/spark/ml/DLEstimatorBase.class */
/*
 * Abstract Spark ML Estimator base class for DL estimators.
 *
 * fit(...) checks the input schema, then converts each row's feature and label columns into a
 * Tuple2 of Seq<Object> values. It hands the resulting RDD to the subclass-provided internalFit(...),
 * which builds the model of type M.
 *
 * Type parameters:
 *   Learner - the concrete estimator subtype (self-referential bound); it is the return type of copy.
 *   M       - the DLTransformerBase model type produced by fit.
 *
 * NOTE(review): this file is decompiler output of a Scala class. The "compiled from:
 * DLEstimatorBase.scala" comment and the ScalaSignature annotation show this. The class references
 * Scala compiler-synthesized members: DLParams.Cclass, HasLabelCol.class.$init$, and the
 * DLEstimatorBase$$anonfun$... classes. It also assigns final fields outside the constructor. It is
 * therefore not compilable Java as written, so make behavioral changes in the Scala source instead.
 */
public abstract class DLEstimatorBase<Learner extends DLEstimatorBase<Learner, M>, M extends DLTransformerBase<M>> extends Estimator<M> implements DLParams, HasLabelCol {
    // Backing storage for the label/prediction/features column params. They are filled in by the
    // trait setter methods below, presumably during the $init$ calls made by the constructor
    // (Scala trait encoding) — confirm against the Scala source.
    // NOTE(review): HasPredictionCol/HasFeaturesCol are not listed in the implements clause here;
    // presumably they are inherited through DLParams — confirm.
    private final Param<String> labelCol;
    private final Param<String> predictionCol;
    private final Param<String> featuresCol;

    /** Returns the Param holding the label column name. */
    public final Param<String> labelCol() {
        return this.labelCol;
    }

    /**
     * Compiler-synthesized setter for the labelCol field (HasLabelCol trait encoding).
     * NOTE(review): assigning a final field here is a decompiler artifact and does not compile in Java.
     */
    public final void org$apache$spark$ml$param$shared$HasLabelCol$_setter_$labelCol_$eq(Param param) {
        this.labelCol = param;
    }

    /** Returns the current label column name, delegating to the HasLabelCol trait implementation. */
    public final String getLabelCol() {
        return HasLabelCol.class.getLabelCol(this);
    }

    /**
     * Delegates to the DLParams trait implementation (DLParams.Cclass).
     *
     * @param row the input row
     * @param dataType the column's data type
     * @param i presumably the column index within {@code row} — confirm in DLParams
     * @return the column value as a Seq of boxed values (as produced by DLParams)
     */
    @Override // org.apache.spark.ml.DLParams
    public Seq<Object> supportedTypesToSeq(Row row, DataType dataType, int i) {
        return DLParams.Cclass.supportedTypesToSeq(this, row, dataType, i);
    }

    /** Returns the feature-array column name, delegating to the DLParams trait implementation. */
    @Override // org.apache.spark.ml.DLParams
    public String getFeatureArrayCol() {
        return DLParams.Cclass.getFeatureArrayCol(this);
    }

    /** Returns the Param holding the prediction column name. */
    public final Param<String> predictionCol() {
        return this.predictionCol;
    }

    /**
     * Compiler-synthesized setter for the predictionCol field (HasPredictionCol trait encoding).
     * NOTE(review): assigning a final field here is a decompiler artifact.
     */
    public final void org$apache$spark$ml$param$shared$HasPredictionCol$_setter_$predictionCol_$eq(Param param) {
        this.predictionCol = param;
    }

    /** Returns the current prediction column name, delegating to the HasPredictionCol trait implementation. */
    public final String getPredictionCol() {
        return HasPredictionCol.class.getPredictionCol(this);
    }

    /** Returns the Param holding the features column name. */
    public final Param<String> featuresCol() {
        return this.featuresCol;
    }

    /**
     * Compiler-synthesized setter for the featuresCol field (HasFeaturesCol trait encoding).
     * NOTE(review): assigning a final field here is a decompiler artifact.
     */
    public final void org$apache$spark$ml$param$shared$HasFeaturesCol$_setter_$featuresCol_$eq(Param param) {
        this.featuresCol = param;
    }

    /** Returns the current features column name, delegating to the HasFeaturesCol trait implementation. */
    public final String getFeaturesCol() {
        return HasFeaturesCol.class.getFeaturesCol(this);
    }

    /**
     * Returns the label column name with the suffix "_Array" appended
     * (e.g. label column "label" gives "label_Array").
     */
    public String getLabelArrayCol() {
        return new StringBuilder().append((String) $(labelCol())).append("_Array").toString();
    }

    /**
     * Subclass hook that trains and returns a model from the converted training data.
     *
     * @param rdd one Tuple2 of Seq values per input row, as built by toArrayType. Presumably
     *     (features, label): the row mapper receives the feature column's type/index before the
     *     label column's — confirm in DLEstimatorBase$$anonfun$toArrayType$1.
     * @return the trained model
     */
    public abstract M internalFit(RDD<Tuple2<Seq<Object>, Seq<Object>>> rdd);

    /**
     * Fits a model to {@code dataset}.
     *
     * Runs transformSchema on the dataset's schema first. The second argument {@code true} is
     * presumably Spark's PipelineStage logging flag — confirm. It then converts the rows with
     * toArrayType and delegates training to internalFit.
     */
    public M fit(Dataset<?> dataset) {
        transformSchema(dataset.schema(), true);
        return internalFit(toArrayType(dataset.toDF()));
    }

    /**
     * Converts each row of {@code dataset} into a Tuple2 of Seq values.
     *
     * The data type and field index of the features and label columns are looked up once from the
     * schema and passed to the synthetic mapper DLEstimatorBase$$anonfun$toArrayType$1. That mapper
     * then runs over every row of {@code dataset.rdd()}. The per-row conversion logic lives in that
     * mapper and is not visible here.
     */
    public RDD<Tuple2<Seq<Object>, Seq<Object>>> toArrayType(Dataset<Row> dataset) {
        return dataset.rdd().map(new DLEstimatorBase$$anonfun$toArrayType$1(this, dataset.schema().apply((String) $(featuresCol())).dataType(), dataset.schema().fieldIndex((String) $(featuresCol())), dataset.schema().apply((String) $(labelCol())).dataType(), dataset.schema().fieldIndex((String) $(labelCol()))), ClassTag$.MODULE$.apply(Tuple2.class));
    }

    /**
     * Validates {@code structType} for fitting.
     *
     * First it runs the DLParams schema checks. It then requires, via Predef.require, that the label
     * column's data type is one of:
     *   ArrayType(DoubleType, containsNull = false),
     *   ArrayType(FloatType, containsNull = false),
     *   VectorUDT, or
     *   DoubleType.
     * The match predicate (anonfun$validateSchema$4) presumably compares each type for equality —
     * confirm. The failure message is built lazily by anonfun$validateSchema$3.
     */
    @Override // org.apache.spark.ml.DLParams
    public void validateSchema(StructType structType) {
        DLParams.Cclass.validateSchema(this, structType);
        Seq apply = Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new DataType[]{new ArrayType(DoubleType$.MODULE$, false), new ArrayType(FloatType$.MODULE$, false), new VectorUDT(), DoubleType$.MODULE$}));
        DataType dataType = structType.apply((String) $(labelCol())).dataType();
        Predef$.MODULE$.require(apply.exists(new DLEstimatorBase$$anonfun$validateSchema$4(this, dataType)), new DLEstimatorBase$$anonfun$validateSchema$3(this, apply, dataType));
    }

    /**
     * Implements {@code copy(ParamMap)}: returns a copy of this estimator with {@code paramMap}
     * applied, via Params.defaultCopy. (The decompiler renamed it from {@code copy}; see below.)
     */
    /* renamed from: copy, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public Learner m2copy(ParamMap paramMap) {
        return defaultCopy(paramMap);
    }

    /**
     * Synthetic bridge method returning the erased Model type; it forwards to fit(Dataset).
     * (The decompiler renamed it from {@code fit}; see below.)
     */
    /* renamed from: fit, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ Model m3fit(Dataset dataset) {
        return fit((Dataset<?>) dataset);
    }

    /**
     * Runs the Scala trait initializers in the order HasFeaturesCol, HasPredictionCol, DLParams,
     * HasLabelCol. NOTE(review): keep this order. DLParams initialization may depend on the
     * features/prediction params already being set — confirm before reordering.
     */
    public DLEstimatorBase() {
        HasFeaturesCol.class.$init$(this);
        HasPredictionCol.class.$init$(this);
        DLParams.Cclass.$init$(this);
        HasLabelCol.class.$init$(this);
    }
}
