package org.apache.spark.ml.bundle.ops.classification;

import ml.combust.bundle.dsl.Attribute;
import ml.combust.bundle.dsl.Bundle$BuiltinOps$classification$;
import ml.combust.bundle.dsl.ReadableModel;
import ml.combust.bundle.dsl.ReadableNode;
import ml.combust.bundle.dsl.Shape;
import ml.combust.bundle.dsl.Shape$;
import ml.combust.bundle.dsl.Value$;
import ml.combust.bundle.dsl.WritableModel;
import ml.combust.bundle.op.OpModel;
import ml.combust.bundle.op.OpNode;
import ml.combust.bundle.serializer.BundleContext;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.linalg.Vectors$;
import scala.Predef$;
import scala.reflect.ClassTag$;

/* compiled from: LogisticRegressionOp.scala */
/* loaded from: input_file:org/apache/spark/ml/bundle/ops/classification/LogisticRegressionOp$.class */
/**
 * Bundle serializer ({@code OpNode}) for Spark's {@link LogisticRegressionModel}.
 *
 * <p>This is the decompiled form of a Scala {@code object}. The static initializer constructs the
 * single instance, and the constructor publishes it through {@code MODULE$}. Several decisions
 * are delegated to synthetic {@code LogisticRegressionOp$$anonfun$...} /
 * {@code LogisticRegressionOp$$anon$1$$anonfun$...} classes that are defined elsewhere. Their
 * exact behavior is not visible here, so the comments below that describe them are hedged.
 *
 * <p>NOTE(review): this text is decompiler output, not compilable Java as written. For example,
 * it assigns the {@code final} field {@code MODULE$} in the constructor and calls methods named
 * {@code double}/{@code long} on {@code Value$.MODULE$}. Edit the original Scala source instead.
 *
 * <p>The node type and the model type are the same class ({@code LogisticRegressionModel}).
 */
public final class LogisticRegressionOp$ implements OpNode<LogisticRegressionModel, LogisticRegressionModel> {
    /** Singleton instance; set by the private constructor during static initialization. */
    public static final LogisticRegressionOp$ MODULE$ = null;
    /** Handles reading/writing the model's attributes (coefficients, intercept, etc.). */
    private final OpModel<LogisticRegressionModel> Model;

    static {
        // Scala object initialization: the constructor assigns MODULE$ = this.
        new LogisticRegressionOp$();
    }

    /**
     * Returns the {@code OpModel} that stores and loads the model-level attributes.
     *
     * @return the op model created in the constructor
     */
    public OpModel<LogisticRegressionModel> Model() {
        return this.Model;
    }

    /**
     * Returns the bundle node name for a model, which is the Spark model's uid.
     *
     * @param logisticRegressionModel the Spark model being serialized
     * @return {@code logisticRegressionModel.uid()}
     */
    public String name(LogisticRegressionModel logisticRegressionModel) {
        return logisticRegressionModel.uid();
    }

    /**
     * Extracts the model from a node. Node and model are the same object here, so this is the
     * identity function.
     *
     * @param logisticRegressionModel the node
     * @return the same instance
     */
    public LogisticRegressionModel model(LogisticRegressionModel logisticRegressionModel) {
        return logisticRegressionModel;
    }

    /**
     * Rebuilds the node from a loaded model plus the node's name and shape.
     *
     * <p>Creates a fresh {@code LogisticRegressionModel} whose uid is the bundle node name. It uses
     * the loaded model's coefficients and intercept, and copies the loaded model's param map. The
     * features column is taken from the shape's "features" input, and the prediction column from
     * its "prediction" output.
     *
     * @param bundleContext the serialization context (unused here)
     * @param readableNode the node metadata (name and input/output shape)
     * @param logisticRegressionModel the model produced by {@code Model().load(...)}
     * @return the reconstructed model with column params applied
     */
    public LogisticRegressionModel load(BundleContext bundleContext, ReadableNode readableNode, LogisticRegressionModel logisticRegressionModel) {
        // The model-level load creates the model with an empty uid; the real uid (the node name) is assigned here.
        LogisticRegressionModel predictionCol = new LogisticRegressionModel(readableNode.name(), logisticRegressionModel.coefficients(), logisticRegressionModel.intercept()).copy(logisticRegressionModel.extractParamMap()).setFeaturesCol(readableNode.shape().input("features").name()).setPredictionCol(readableNode.shape().output("prediction").name());
        // "probability" is an optional output. If present, anonfun$load$3 is applied (presumably sets
        // the probability column); otherwise anonfun$load$4 (presumably returns the model unchanged).
        // TODO confirm against the Scala source.
        return (LogisticRegressionModel) readableNode.shape().getOutput("probability").map(new LogisticRegressionOp$$anonfun$load$3(predictionCol)).getOrElse(new LogisticRegressionOp$$anonfun$load$4(predictionCol));
    }

    /**
     * Describes the node's input/output ports for the bundle.
     *
     * <p>Maps the model's features column to input port "features" and its prediction column to
     * output port "prediction".
     *
     * @param logisticRegressionModel the Spark model being serialized
     * @return the node shape
     */
    public Shape shape(LogisticRegressionModel logisticRegressionModel) {
        Shape withOutput = new Shape(Shape$.MODULE$.apply$default$1()).withInput(logisticRegressionModel.getFeaturesCol(), "features").withOutput(logisticRegressionModel.getPredictionCol(), "prediction");
        // get(probabilityCol) yields an Option. If defined, anonfun$shape$1 is applied (presumably
        // adds a "probability" output); otherwise anonfun$shape$2 (presumably keeps the shape as is).
        // TODO confirm against the Scala source.
        return (Shape) logisticRegressionModel.get(logisticRegressionModel.probabilityCol()).map(new LogisticRegressionOp$$anonfun$shape$1(withOutput)).getOrElse(new LogisticRegressionOp$$anonfun$shape$2(withOutput));
    }

    /** Creates the singleton and its attribute-level {@code OpModel}. */
    private LogisticRegressionOp$() {
        MODULE$ = this;
        this.Model = new OpModel<LogisticRegressionModel>() { // from class: org.apache.spark.ml.bundle.ops.classification.LogisticRegressionOp$$anon$1
            /** Bundle op name, taken from the built-in classification op registry. */
            public String opName() {
                return Bundle$BuiltinOps$classification$.MODULE$.logistic_regression();
            }

            /**
             * Writes the model attributes "coefficients" (double vector), "intercept" (double) and
             * "num_classes" (long, from {@code numClasses()}).
             *
             * <p>If {@code threshold} is present on the model, anonfun$store$1 is applied
             * (presumably adds a "threshold" attribute); otherwise anonfun$store$2. TODO confirm.
             *
             * @return the writable model with attributes added
             */
            public WritableModel store(BundleContext bundleContext, WritableModel writableModel, LogisticRegressionModel logisticRegressionModel) {
                WritableModel withAttr = writableModel.withAttr(new Attribute("coefficients", Value$.MODULE$.doubleVector(Predef$.MODULE$.wrapDoubleArray(logisticRegressionModel.coefficients().toArray())))).withAttr(new Attribute("intercept", Value$.MODULE$.double(logisticRegressionModel.intercept()))).withAttr(new Attribute("num_classes", Value$.MODULE$.long(logisticRegressionModel.numClasses())));
                return (WritableModel) logisticRegressionModel.get(logisticRegressionModel.threshold()).map(new LogisticRegressionOp$$anon$1$$anonfun$store$1(this, withAttr)).getOrElse(new LogisticRegressionOp$$anon$1$$anonfun$store$2(this, withAttr));
            }

            /**
             * Reads the model attributes back into a {@code LogisticRegressionModel}.
             *
             * <p>Only binary models are accepted: "num_classes" must equal 2. The model is created
             * with an empty uid because the node-level {@code load} builds a new model carrying the
             * node name. An optional "threshold" value is handed to anonfun$load$1 (presumably
             * sets the threshold); otherwise anonfun$load$2 is applied. TODO confirm.
             *
             * @return the loaded model
             * @throws Error if "num_classes" is not 2
             */
            /* renamed from: load, reason: merged with bridge method [inline-methods] */
            public LogisticRegressionModel m14load(BundleContext bundleContext, ReadableModel readableModel) {
                // NOTE(review): java.lang.Error is unusual for invalid input; an Exception subtype
                // would be more conventional, but changing it alters what callers can catch.
                if (readableModel.value("num_classes").getLong() != 2) {
                    throw new Error("Only binary logistic regression supported in Spark");
                }
                LogisticRegressionModel logisticRegressionModel = new LogisticRegressionModel("", Vectors$.MODULE$.dense((double[]) readableModel.value("coefficients").getDoubleVector().toArray(ClassTag$.MODULE$.Double())), readableModel.value("intercept").getDouble());
                return (LogisticRegressionModel) readableModel.getValue("threshold").map(new LogisticRegressionOp$$anon$1$$anonfun$load$1(this, logisticRegressionModel)).getOrElse(new LogisticRegressionOp$$anon$1$$anonfun$load$2(this, logisticRegressionModel));
            }
        };
    }
}
