package org.apache.spark.ml.bundle.ops.classification;

import ml.combust.bundle.dsl.Attribute;
import ml.combust.bundle.dsl.Bundle$BuiltinOps$classification$;
import ml.combust.bundle.dsl.ReadableModel;
import ml.combust.bundle.dsl.ReadableNode;
import ml.combust.bundle.dsl.Shape;
import ml.combust.bundle.dsl.Shape$;
import ml.combust.bundle.dsl.Value$;
import ml.combust.bundle.dsl.WritableModel;
import ml.combust.bundle.op.OpModel;
import ml.combust.bundle.op.OpNode;
import ml.combust.bundle.serializer.BundleContext;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import scala.Array$;
import scala.Predef$;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.reflect.ClassTag$;
import scala.runtime.IntRef;

/* compiled from: GBTClassifierOp.scala */
/* loaded from: input_file:org/apache/spark/ml/bundle/ops/classification/GBTClassifierOp$.class */
/**
 * MLeap bundle op for Spark's {@link GBTClassificationModel}.
 *
 * <p>This is the decompiled form of a Scala {@code object}: a single shared instance
 * ({@link #MODULE$}) acts as the {@link OpNode} (node name, shape, node-level load) and, via
 * {@link #Model()}, exposes the {@link OpModel} that writes/reads the model attributes
 * {@code num_features}, {@code num_classes}, {@code tree_weights} and {@code trees}.
 */
public final class GBTClassifierOp$ implements OpNode<GBTClassificationModel, GBTClassificationModel> {
    /** Singleton instance; the private constructor assigns it when the static initializer runs. */
    public static final GBTClassifierOp$ MODULE$ = null;
    private final OpModel<GBTClassificationModel> Model;

    static {
        new GBTClassifierOp$();
    }

    /** Returns the {@link OpModel} responsible for storing and loading the model attributes. */
    public OpModel<GBTClassificationModel> Model() {
        return this.Model;
    }

    /** Uses the Spark model's uid as the bundle node name. */
    public String name(GBTClassificationModel gBTClassificationModel) {
        return gBTClassificationModel.uid();
    }

    /** The node and the model are the same object, so this is the identity. */
    public GBTClassificationModel model(GBTClassificationModel gBTClassificationModel) {
        return gBTClassificationModel;
    }

    /**
     * Rebuilds the loaded model with the bundle node's name as its uid, reusing the trees, tree
     * weights and feature count of the model produced by {@code OpModel.load}.
     *
     * <p>NOTE(review): only uid, trees, treeWeights and numFeatures are carried over here; the
     * feature/prediction column names recorded by {@link #shape} are not re-applied to the new
     * model — confirm whether callers rely on the Spark defaults for those columns.
     */
    public GBTClassificationModel load(BundleContext bundleContext, ReadableNode readableNode, GBTClassificationModel gBTClassificationModel) {
        return new GBTClassificationModel(readableNode.name(), gBTClassificationModel.trees(), gBTClassificationModel.treeWeights(), gBTClassificationModel.numFeatures());
    }

    /**
     * Declares the node's I/O: the model's features column on port {@code "features"} and its
     * prediction column on port {@code "prediction"}.
     */
    public Shape shape(GBTClassificationModel gBTClassificationModel) {
        return new Shape(Shape$.MODULE$.apply$default$1()).withInput(gBTClassificationModel.getFeaturesCol(), "features").withOutput(gBTClassificationModel.getPredictionCol(), "prediction");
    }

    private GBTClassifierOp$() {
        MODULE$ = this;
        this.Model = new OpModel<GBTClassificationModel>() { // from class: org.apache.spark.ml.bundle.ops.classification.GBTClassifierOp$$anon$1
            /** Bundle op name under which this model type is registered. */
            public String opName() {
                return Bundle$BuiltinOps$classification$.MODULE$.gbt_classifier();
            }

            /**
             * Writes the model attributes. {@code num_classes} is always written as 2; each tree is
             * converted to a string by the synthetic {@code $anonfun$1}, which receives the bundle
             * context and a counter starting at 0 (presumably the per-tree sub-model name — TODO
             * confirm against GBTClassifierOp.scala).
             */
            public WritableModel store(BundleContext bundleContext, WritableModel writableModel, GBTClassificationModel gBTClassificationModel) {
                return writableModel
                        .withAttr(new Attribute("num_features", Value$.MODULE$.long(gBTClassificationModel.numFeatures())))
                        .withAttr(new Attribute("num_classes", Value$.MODULE$.long(2L)))
                        .withAttr(new Attribute("tree_weights", Value$.MODULE$.doubleList(Predef$.MODULE$.wrapDoubleArray(gBTClassificationModel.treeWeights()))))
                        .withAttr(new Attribute("trees", Value$.MODULE$.stringList(Predef$.MODULE$.wrapRefArray((String[]) Predef$.MODULE$.refArrayOps(gBTClassificationModel.trees()).map(new GBTClassifierOp$$anon$1$$anonfun$1(this, bundleContext, new IntRef(0)), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class)))))));
            }

            /**
             * Reads the attributes written by {@link #store} and rebuilds a model with an empty uid
             * (the node-level {@code load} replaces it with the node name).
             *
             * @throws Error if {@code num_classes} is not 2
             * @throws ArithmeticException if {@code num_features} does not fit in an {@code int}
             */
            /* renamed from: load, reason: merged with bridge method [inline-methods] */
            public GBTClassificationModel m10load(BundleContext bundleContext, ReadableModel readableModel) {
                long numClasses = readableModel.value("num_classes").getLong();
                // store() always writes num_classes = 2; anything else is not a binary GBT model.
                if (numClasses != 2) {
                    throw new Error("MLeap only supports binary GBT classifiers, got num_classes=" + numClasses);
                }
                // Fail loudly rather than silently wrapping an out-of-range feature count.
                int i = Math.toIntExact(readableModel.value("num_features").getLong());
                return new GBTClassificationModel("", (DecisionTreeRegressionModel[]) ((TraversableOnce) readableModel.value("trees").getStringList().map(new GBTClassifierOp$$anon$1$$anonfun$2(this, bundleContext), Seq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(DecisionTreeRegressionModel.class)), (double[]) readableModel.value("tree_weights").getDoubleList().toArray(ClassTag$.MODULE$.Double()), i);
            }
        };
    }
}
