package org.apache.spark.ml.bundle.ops.classification;

import ml.combust.bundle.dsl.Attribute;
import ml.combust.bundle.dsl.Bundle$BuiltinOps$classification$;
import ml.combust.bundle.dsl.ReadableModel;
import ml.combust.bundle.dsl.ReadableNode;
import ml.combust.bundle.dsl.Shape;
import ml.combust.bundle.dsl.Shape$;
import ml.combust.bundle.dsl.Value$;
import ml.combust.bundle.dsl.WritableModel;
import ml.combust.bundle.op.OpModel;
import ml.combust.bundle.op.OpNode;
import ml.combust.bundle.serializer.BundleContext;
import ml.combust.bundle.tree.TreeSerializer;
import org.apache.spark.ml.bundle.tree.SparkNodeWrapper$;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.tree.Node;

/* compiled from: DecisionTreeClassifierOp.scala */
/* loaded from: input_file:org/apache/spark/ml/bundle/ops/classification/DecisionTreeClassifierOp$.class */
/**
 * Bundle serialization op for Spark's {@link DecisionTreeClassificationModel}.
 *
 * <p>This is decompiler output for the Scala singleton object {@code DecisionTreeClassifierOp}:
 * the JVM class name carries a trailing {@code $} and the single instance is exposed through
 * {@code MODULE$}. The node type and the model type are the same class, so {@link #model}
 * simply returns its argument.
 *
 * <p>NOTE(review): this is not valid Java source as written. It assigns a {@code static final}
 * field inside an instance constructor, and it calls a Scala method named {@code long}, which is a
 * Java keyword. Treat it as a read-only view of the compiled Scala, not as code to rebuild with
 * javac.
 */
public final class DecisionTreeClassifierOp$ implements OpNode<DecisionTreeClassificationModel, DecisionTreeClassificationModel> {
    /** Singleton instance; assigned by the private constructor, which the static initializer runs. */
    public static final DecisionTreeClassifierOp$ MODULE$ = null;
    /** Node wrapper handed to {@link TreeSerializer} when writing and reading the tree. */
    private final SparkNodeWrapper$ nodeWrapper;
    /** Model-level op: op name, tree storage and the num_features/num_classes attributes. */
    private final OpModel<DecisionTreeClassificationModel> Model;

    // Scala object initialization: constructing the instance assigns MODULE$ as a side effect.
    static {
        new DecisionTreeClassifierOp$();
    }

    /** Returns the {@link SparkNodeWrapper$} singleton passed to {@link TreeSerializer}. */
    public SparkNodeWrapper$ nodeWrapper() {
        return this.nodeWrapper;
    }

    /** Returns the model-level op used to store and load the tree. */
    public OpModel<DecisionTreeClassificationModel> Model() {
        return this.Model;
    }

    /**
     * Returns the bundle node name for a model, which is its Spark uid.
     *
     * @param decisionTreeClassificationModel the model being serialized
     * @return {@code decisionTreeClassificationModel.uid()}
     */
    public String name(DecisionTreeClassificationModel decisionTreeClassificationModel) {
        return decisionTreeClassificationModel.uid();
    }

    /**
     * Returns the model backing a node. Node and model are the same object for this op.
     *
     * @param decisionTreeClassificationModel the node
     * @return the same instance, unchanged
     */
    public DecisionTreeClassificationModel model(DecisionTreeClassificationModel decisionTreeClassificationModel) {
        return decisionTreeClassificationModel;
    }

    /**
     * Builds the node-level model from a model produced by the model op.
     *
     * <p>This creates a new {@link DecisionTreeClassificationModel}:
     * <ul>
     *   <li>its uid is the bundle node name;</li>
     *   <li>the root node, feature count and class count are copied from
     *       {@code decisionTreeClassificationModel};</li>
     *   <li>the features column comes from the node shape's input port {@code "features"};</li>
     *   <li>the prediction column comes from the output port {@code "prediction"}.</li>
     * </ul>
     *
     * @param bundleContext context of the bundle being read (not used by this method)
     * @param readableNode node metadata supplying the name and the input/output shape
     * @param decisionTreeClassificationModel model holding the tree; presumably the result of the
     *     model op's load, whose uid is {@code ""}. Confirm this against the caller.
     * @return a new model carrying the node's name and column bindings
     */
    public DecisionTreeClassificationModel load(BundleContext bundleContext, ReadableNode readableNode, DecisionTreeClassificationModel decisionTreeClassificationModel) {
        return new DecisionTreeClassificationModel(readableNode.name(), decisionTreeClassificationModel.rootNode(), decisionTreeClassificationModel.numFeatures(), decisionTreeClassificationModel.numClasses()).setFeaturesCol(readableNode.shape().input("features").name()).setPredictionCol(readableNode.shape().output("prediction").name());
    }

    /**
     * Describes the node's I/O. The model's features column is bound to input port
     * {@code "features"}, and its prediction column to output port {@code "prediction"}.
     *
     * <p>NOTE(review): no other columns are mapped here. Confirm whether additional classifier
     * outputs (e.g. probability / raw prediction) should also be recorded in the shape.
     *
     * @param decisionTreeClassificationModel the model whose column names are recorded
     * @return the shape with one input and one output port
     */
    public Shape shape(DecisionTreeClassificationModel decisionTreeClassificationModel) {
        return new Shape(Shape$.MODULE$.apply$default$1()).withInput(decisionTreeClassificationModel.getFeaturesCol(), "features").withOutput(decisionTreeClassificationModel.getPredictionCol(), "prediction");
    }

    /**
     * Private constructor, invoked once from the static initializer.
     *
     * <p>It publishes the singleton first, then initializes the fields. The anonymous
     * {@link OpModel} below reads the node wrapper back through {@code MODULE$} only when its
     * methods run, not during construction.
     */
    private DecisionTreeClassifierOp$() {
        MODULE$ = this;
        this.nodeWrapper = SparkNodeWrapper$.MODULE$;
        this.Model = new OpModel<DecisionTreeClassificationModel>() { // from class: org.apache.spark.ml.bundle.ops.classification.DecisionTreeClassifierOp$$anon$1
            /** Returns the builtin classification op name for decision tree classifiers. */
            public String opName() {
                return Bundle$BuiltinOps$classification$.MODULE$.decision_tree_classifier();
            }

            /**
             * Writes the tree and its size attributes.
             *
             * <p>The root node is written with {@link TreeSerializer} to the location returned by
             * {@code bundleContext.file("nodes")}. The {@code true} argument is a TreeSerializer flag
             * whose meaning is not visible here. Then {@code num_features} and {@code num_classes}
             * are attached to the model as long-valued attributes.
             *
             * @return {@code writableModel} with the two attributes added
             */
            public WritableModel store(BundleContext bundleContext, WritableModel writableModel, DecisionTreeClassificationModel decisionTreeClassificationModel) {
                new TreeSerializer(bundleContext.file("nodes"), true, DecisionTreeClassifierOp$.MODULE$.nodeWrapper()).write(decisionTreeClassificationModel.rootNode());
                // Value$.MODULE$.long(...) is the Scala method Value.long; "long" is a Java keyword,
                // so this call only exists in this form as decompiler output.
                return writableModel.withAttr(new Attribute("num_features", Value$.MODULE$.long(decisionTreeClassificationModel.numFeatures()))).withAttr(new Attribute("num_classes", Value$.MODULE$.long(decisionTreeClassificationModel.numClasses())));
            }

            /**
             * Reads the tree back from {@code "nodes"} and rebuilds a model from it.
             *
             * <p>The uid is left as {@code ""}. The node-level {@code load} builds a fresh model
             * named after the bundle node.
             *
             * <p>The stored long attributes are narrowed with an {@code (int)} cast. A value outside
             * the int range would therefore be truncated silently, not rejected.
             */
            /* renamed from: load, reason: merged with bridge method [inline-methods] */
            public DecisionTreeClassificationModel m8load(BundleContext bundleContext, ReadableModel readableModel) {
                return new DecisionTreeClassificationModel("", (Node) new TreeSerializer(bundleContext.file("nodes"), true, DecisionTreeClassifierOp$.MODULE$.nodeWrapper()).read(), (int) readableModel.value("num_features").getLong(), (int) readableModel.value("num_classes").getLong());
            }
        };
    }
}
