package ai.tripl.arc.transform;

import ai.tripl.arc.api.API;
import ai.tripl.arc.util.log.logger.Logger;
import java.net.URI;
import java.util.HashMap;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.functions$;
import scala.Array$;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Option$;
import scala.Predef$;
import scala.Predef$DummyImplicit$;
import scala.Serializable;
import scala.Some;
import scala.Tuple12;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.collection.immutable.Map;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayOps;
import scala.math.Ordering$Double$;
import scala.reflect.ClassTag$;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ObjectRef;
import scala.util.Either;
import scala.util.Left;
import scala.util.Right;

/* compiled from: MLTransform.scala */
/* loaded from: input_file:ai/tripl/arc/transform/MLTransformStage$.class */
/**
 * Scala companion object for {@code MLTransformStage} (compiled from MLTransform.scala), as
 * recovered by a decompiler. The singleton instance is held in {@link #MODULE$}.
 *
 * <p>Contains the stage's runtime logic ({@link #execute}), a pipeline-flattening helper
 * ({@link #modelStages}) and the case-class factory/extractor ({@link #apply} / {@link #unapply}).
 *
 * <p>NOTE(review): this is decompiler output and does not compile as Java source as written. For
 * example, {@code execute} casts a {@code Right} local to {@code Left} and assigns a
 * {@code PipelineModel} to a {@code CrossValidatorModel} local. Make behavioral changes in the
 * Scala source rather than here.
 */
public final class MLTransformStage$ implements Serializable {
    /** The singleton instance; assigned by the private constructor during static initialization. */
    public static MLTransformStage$ MODULE$;

    static {
        // Constructing the object runs the private constructor, which publishes itself into MODULE$.
        new MLTransformStage$();
    }

    /**
     * Applies the stage's fitted model to {@code inputView} and registers the result as the temp
     * view {@code outputView}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Read the input table and choose the transformer to apply. This is the
     *       {@code CrossValidatorModel} when {@code model()} is {@code Right}, otherwise the
     *       {@code PipelineModel} held in {@code Left}.</li>
     *   <li>Resolve a {@code PipelineModel} (the cross-validator's {@code bestModel()} in the
     *       {@code Right} case) and flatten it into leaf stages. These stages are used only to
     *       discover the prediction/probability output column names.</li>
     *   <li>Transform the input. If at least one stage has a {@code predictionCol} param, project
     *       to the original input columns plus every prediction and probability column. Otherwise
     *       keep every column the transform produced.</li>
     *   <li>Overwrite each probability column with the maximum element of its vector.</li>
     *   <li>Repartition according to {@code partitionBy()} / {@code numPartitions()}. Then register
     *       the view with {@code createTempView} when {@code immutableViews()} is set, or with
     *       {@code createOrReplaceTempView} otherwise.</li>
     *   <li>For non-streaming results, record {@code outputColumns} and {@code numPartitions} in
     *       {@code stageDetail}. If {@code persist()} is set, also cache the view at
     *       {@code storageLevel()}, record the row count as {@code records}, and record approximate
     *       deciles of each (now scalar) probability column as {@code percentiles}.</li>
     * </ol>
     *
     * <p>Errors: an exception raised while resolving the pipeline stages is wrapped in
     * {@code MLTransformStage$$anon$1}. An exception raised while transforming, repartitioning,
     * registering or measuring the output is wrapped in {@code MLTransformStage$$anon$2}.
     * NOTE(review): as decompiled, the second try/catch is nested inside the first. If
     * {@code MLTransformStage$$anon$2} is an {@code Exception}, it would be re-caught and wrapped
     * again in {@code $anon$1}. Confirm against the Scala source.
     *
     * @param mLTransformStage stage configuration: input/output views, model, partitioning and
     *     persist flag
     * @param sparkSession session used to read the input table and cache the output view
     * @param logger not referenced in this method body
     * @param aRCContext supplies {@code immutableViews()} and {@code storageLevel()}
     * @return {@code Option} wrapping the registered output dataset
     */
    public Option<Dataset<Row>> execute(MLTransformStage mLTransformStage, SparkSession sparkSession, Logger logger, API.ARCContext aRCContext) {
        CrossValidatorModel crossValidatorModel;
        PipelineModel pipelineModel;
        Dataset repartition;
        // Input rows to be scored.
        Dataset table = sparkSession.table(mLTransformStage.inputView());
        // Transformer actually applied to the data: the whole CrossValidatorModel (Right) or the PipelineModel (Left).
        Right model = mLTransformStage.model();
        if (model instanceof Right) {
            crossValidatorModel = (CrossValidatorModel) model.value();
        } else {
            if (!(model instanceof Left)) {
                throw new MatchError(model);
            }
            crossValidatorModel = (PipelineModel) ((Left) model).value();
        }
        CrossValidatorModel crossValidatorModel2 = crossValidatorModel;
        try {
            // Pipeline used only to enumerate stages: the CV's bestModel() for Right, the model itself for Left.
            Right model2 = mLTransformStage.model();
            if (model2 instanceof Right) {
                pipelineModel = (PipelineModel) ((CrossValidatorModel) model2.value()).bestModel();
            } else {
                if (!(model2 instanceof Left)) {
                    throw new MatchError(model2);
                }
                pipelineModel = (PipelineModel) ((Left) model2).value();
            }
            Seq<Transformer> modelStages = modelStages(pipelineModel);
            try {
                Dataset transform = crossValidatorModel2.transform(table);
                // Original input columns, kept at the front of the output projection.
                Column[] columnArr = (Column[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(table.schema().fields())).map(structField -> {
                    return functions$.MODULE$.col(structField.name());
                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)));
                // For every stage with a "predictionCol" param, take its value from get(...).
                // If that yields no value, fall back to the literal "prediction".
                Seq seq = (Seq) ((TraversableLike) ((TraversableLike) ((TraversableLike) ((TraversableLike) modelStages.filter(transformer -> {
                    return BoxesRunTime.boxToBoolean(transformer.hasParam("predictionCol"));
                })).map(transformer2 -> {
                    return transformer2.get(transformer2.getParam("predictionCol"));
                }, Seq$.MODULE$.canBuildFrom())).map(option -> {
                    return option.getOrElse(() -> {
                        return "prediction";
                    });
                }, Seq$.MODULE$.canBuildFrom())).map(obj -> {
                    return obj.toString();
                }, Seq$.MODULE$.canBuildFrom())).map(str -> {
                    return functions$.MODULE$.col(str);
                }, Seq$.MODULE$.canBuildFrom());
                // Same lookup for "probabilityCol", falling back to the literal "probability".
                Seq seq2 = (Seq) ((TraversableLike) ((TraversableLike) ((TraversableLike) ((TraversableLike) modelStages.filter(transformer3 -> {
                    return BoxesRunTime.boxToBoolean(transformer3.hasParam("probabilityCol"));
                })).map(transformer4 -> {
                    return transformer4.get(transformer4.getParam("probabilityCol"));
                }, Seq$.MODULE$.canBuildFrom())).map(option2 -> {
                    return option2.getOrElse(() -> {
                        return "probability";
                    });
                }, Seq$.MODULE$.canBuildFrom())).map(obj2 -> {
                    return obj2.toString();
                }, Seq$.MODULE$.canBuildFrom())).map(str2 -> {
                    return functions$.MODULE$.col(str2);
                }, Seq$.MODULE$.canBuildFrom());
                // Project to input + prediction + probability columns only when a predictionCol stage exists;
                // otherwise keep the full transform output. Held in an ObjectRef because lambdas below reassign it.
                ObjectRef create = ObjectRef.create(seq.isEmpty() ? transform : transform.select(Predef$.MODULE$.wrapRefArray((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(columnArr)).$plus$plus(seq, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class))))).$plus$plus(seq2, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class))))));
                // UDF mapping a Vector to its largest element (body in $anonfun$execute$14).
                UserDefinedFunction udf = functions$.MODULE$.udf(vector -> {
                    return BoxesRunTime.boxToDouble($anonfun$execute$14(vector));
                }, package$.MODULE$.universe().TypeTag().Double(), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: ai.tripl.arc.transform.MLTransformStage$$typecreator1$1
                    public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                        mirror.universe();
                        return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
                    }
                }));
                // Overwrite each probability column (same name) with its maximum probability.
                seq2.foreach(column -> {
                    $anonfun$execute$15(create, udf, column);
                    return BoxedUnit.UNIT;
                });
                // Repartitioning: numPartitions only when partitionBy is empty;
                // otherwise hash-partition by the partitionBy columns, optionally with numPartitions.
                List<String> partitionBy = mLTransformStage.partitionBy();
                if (Nil$.MODULE$.equals(partitionBy)) {
                    Some numPartitions = mLTransformStage.numPartitions();
                    if (numPartitions instanceof Some) {
                        repartition = ((Dataset) create.elem).repartition(BoxesRunTime.unboxToInt(numPartitions.value()));
                    } else {
                        if (!None$.MODULE$.equals(numPartitions)) {
                            throw new MatchError(numPartitions);
                        }
                        repartition = (Dataset) create.elem;
                    }
                } else {
                    List list = (List) partitionBy.map(str3 -> {
                        return ((Dataset) create.elem).apply(str3);
                    }, List$.MODULE$.canBuildFrom());
                    Some numPartitions2 = mLTransformStage.numPartitions();
                    if (numPartitions2 instanceof Some) {
                        repartition = ((Dataset) create.elem).repartition(BoxesRunTime.unboxToInt(numPartitions2.value()), list);
                    } else {
                        if (!None$.MODULE$.equals(numPartitions2)) {
                            throw new MatchError(numPartitions2);
                        }
                        repartition = ((Dataset) create.elem).repartition(list);
                    }
                }
                Dataset dataset = repartition;
                // immutableViews: createTempView (does not replace an existing view); otherwise replace.
                if (aRCContext.immutableViews()) {
                    dataset.createTempView(mLTransformStage.outputView());
                } else {
                    dataset.createOrReplaceTempView(mLTransformStage.outputView());
                }
                // Streaming datasets skip all metrics and caching.
                if (dataset.isStreaming()) {
                    BoxedUnit boxedUnit = BoxedUnit.UNIT;
                } else {
                    mLTransformStage.stageDetail().put("outputColumns", Integer.valueOf(dataset.schema().length()));
                    mLTransformStage.stageDetail().put("numPartitions", Integer.valueOf(dataset.rdd().partitions().length));
                    if (mLTransformStage.persist()) {
                        sparkSession.catalog().cacheTable(mLTransformStage.outputView(), aRCContext.storageLevel());
                        mLTransformStage.stageDetail().put("records", Long.valueOf(dataset.count()));
                        // Approximate deciles (0.1..1.0, relative error 0.1) of each probability column,
                        // which by now holds the max-probability doubles.
                        HashMap hashMap = new HashMap();
                        seq2.foreach(column2 -> {
                            return (Double[]) hashMap.put(column2.toString(), new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(dataset.stat().approxQuantile(column2.toString(), new double[]{0.1d, 0.2d, 0.3d, 0.4d, 0.5d, 0.6d, 0.7d, 0.8d, 0.9d, 1.0d}, 0.1d))).map(obj3 -> {
                                return Double.valueOf(BoxesRunTime.unboxToDouble(obj3));
                            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Double.class))));
                        });
                        if (hashMap.size() > 0) {
                            mLTransformStage.stageDetail().put("percentiles", hashMap);
                        } else {
                            BoxedUnit boxedUnit2 = BoxedUnit.UNIT;
                        }
                    } else {
                        BoxedUnit boxedUnit3 = BoxedUnit.UNIT;
                    }
                }
                return Option$.MODULE$.apply(dataset);
            } catch (Exception e) {
                throw new MLTransformStage$$anon$2(e, mLTransformStage);
            }
        } catch (Exception e2) {
            throw new MLTransformStage$$anon$1(e2, mLTransformStage);
        }
    }

    /**
     * Flattens a possibly nested {@code PipelineModel} into its leaf stages.
     *
     * <p>Stages that are themselves {@code PipelineModel}s are expanded recursively. All other
     * transformers are kept as-is, in the order they appear in {@code stages()}.
     *
     * @param pipelineModel the pipeline to flatten
     * @return the non-{@code PipelineModel} transformers, in stage order
     */
    public Seq<Transformer> modelStages(PipelineModel pipelineModel) {
        return (Seq) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(pipelineModel.stages())).flatMap(transformer -> {
            if (!(transformer instanceof PipelineModel)) {
                return Nil$.MODULE$.$colon$colon(transformer);
            }
            return MODULE$.modelStages((PipelineModel) transformer);
        }, Array$.MODULE$.fallbackCanBuildFrom(Predef$DummyImplicit$.MODULE$.dummyImplicit()));
    }

    /**
     * Case-class factory: constructs an {@code MLTransformStage} from its twelve fields.
     * The arguments are passed straight through to the constructor. They are presumably in the
     * same order as {@link #unapply}'s tuple: plugin, id, name, description, inputURI, model,
     * inputView, outputView, params, persist, numPartitions, partitionBy.
     */
    public MLTransformStage apply(MLTransform mLTransform, Option<String> option, String str, Option<String> option2, URI uri, Either<PipelineModel, CrossValidatorModel> either, String str2, String str3, Map<String, String> map, boolean z, Option<Object> option3, List<String> list) {
        return new MLTransformStage(mLTransform, option, str, option2, uri, either, str2, str3, map, z, option3, list);
    }

    /**
     * Case-class extractor: returns {@code None} for a {@code null} stage. Otherwise returns
     * {@code Some} of a {@code Tuple12} with the stage's fields; {@code persist} is boxed to a
     * {@code Boolean}.
     */
    public Option<Tuple12<MLTransform, Option<String>, String, Option<String>, URI, Either<PipelineModel, CrossValidatorModel>, String, String, Map<String, String>, Object, Option<Object>, List<String>>> unapply(MLTransformStage mLTransformStage) {
        return mLTransformStage == null ? None$.MODULE$ : new Some(new Tuple12(mLTransformStage.plugin(), mLTransformStage.id(), mLTransformStage.name(), mLTransformStage.description(), mLTransformStage.inputURI(), mLTransformStage.model(), mLTransformStage.inputView(), mLTransformStage.outputView(), mLTransformStage.params(), BoxesRunTime.boxToBoolean(mLTransformStage.persist()), mLTransformStage.numPartitions(), mLTransformStage.partitionBy()));
    }

    /** Serialization hook: resolves a deserialized copy back to the existing singleton. */
    private Object readResolve() {
        return MODULE$;
    }

    /**
     * Body of the probability UDF: the largest element of {@code vector.toArray()}.
     * NOTE(review): Scala's {@code max} throws on an empty collection, so a zero-length vector
     * would fail here.
     */
    public static final /* synthetic */ double $anonfun$execute$14(Vector vector) {
        return BoxesRunTime.unboxToDouble(new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(vector.toArray())).max(Ordering$Double$.MODULE$));
    }

    /**
     * Replaces the column named {@code String.valueOf(column)} in the dataset held by
     * {@code objectRef} with {@code userDefinedFunction(column)}. Reassigns {@code objectRef.elem}.
     */
    public static final /* synthetic */ void $anonfun$execute$15(ObjectRef objectRef, UserDefinedFunction userDefinedFunction, Column column) {
        objectRef.elem = ((Dataset) objectRef.elem).withColumn(String.valueOf(column), userDefinedFunction.apply(Predef$.MODULE$.wrapRefArray(new Column[]{column})));
    }

    /** Private: invoked only by the static initializer; publishes this instance as {@link #MODULE$}. */
    private MLTransformStage$() {
        MODULE$ = this;
    }
}
