package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.modelinfo.DecisionTreeModelInfo;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Stack;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.tree.model.Node;
import org.apache.spark.mllib.tree.model.Split;
import org.apache.spark.sql.DataFrame;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/flipkart/fdp/ml/adapter/DecisionTreeModelInfoAdapter.class */
/**
 * Adapter that flattens a Spark MLlib {@link DecisionTreeModel} into a {@link DecisionTreeModelInfo}.
 *
 * <p>The tree is walked iteratively starting at {@link DecisionTreeModel#topNode()}. For every node
 * reached, a {@link DecisionTreeModelInfo.DecisionNode} is stored in the node-info map keyed by the
 * node id, and the ids of its left/right children (when present) are stored in the left/right child
 * maps keyed by the parent id.
 */
public class DecisionTreeModelInfoAdapter extends AbstractModelInfoAdapter<DecisionTreeModel, DecisionTreeModelInfo> {
    private static final Logger log = LoggerFactory.getLogger(DecisionTreeModelInfoAdapter.class);

    /**
     * Records a single tree node in {@code decisionTreeModelInfo} and schedules its children.
     *
     * @param node                  the node to record
     * @param pending               LIFO work-list of nodes still to be visited; children of
     *                              {@code node} are pushed onto it
     * @param decisionTreeModelInfo the model info being populated
     */
    private void visit(Node node, Deque<Node> pending, DecisionTreeModelInfo decisionTreeModelInfo) {
        final int nodeId = node.id();

        DecisionTreeModelInfo.DecisionNode decisionNode = new DecisionTreeModelInfo.DecisionNode();
        decisionNode.setId(nodeId);
        decisionNode.setLeaf(node.isLeaf());

        // Split details are copied only when the node actually carries a split.
        if (node.split().nonEmpty()) {
            Split split = (Split) node.split().get();
            decisionNode.setFeature(split.feature());
            decisionNode.setThreshold(split.threshold());
            decisionNode.setFeatureType(split.featureType().toString());
        }

        decisionNode.setPredict(node.predict().predict());
        decisionNode.setProbability(node.predict().prob());
        decisionTreeModelInfo.getNodeInfo().put(nodeId, decisionNode);

        // The right child is pushed first so that the left child is popped (and visited) first.
        if (node.rightNode().nonEmpty()) {
            Node rightChild = (Node) node.rightNode().get();
            decisionTreeModelInfo.getRightChildMap().put(nodeId, rightChild.id());
            pending.push(rightChild);
        }
        if (node.leftNode().nonEmpty()) {
            Node leftChild = (Node) node.leftNode().get();
            decisionTreeModelInfo.getLeftChildMap().put(nodeId, leftChild.id());
            pending.push(leftChild);
        }
    }

    /**
     * Builds a {@link DecisionTreeModelInfo} from the given decision tree.
     *
     * @param decisionTreeModel the tree to convert; its top node becomes the root of the result
     * @param dataFrame         not read by this adapter
     * @return model info whose root is the id of the tree's top node and whose maps describe every
     *         node reached from it
     */
    @Override // com.flipkart.fdp.ml.adapter.AbstractModelInfoAdapter
    public DecisionTreeModelInfo getModelInfo(DecisionTreeModel decisionTreeModel, DataFrame dataFrame) {
        DecisionTreeModelInfo decisionTreeModelInfo = new DecisionTreeModelInfo();
        Node root = decisionTreeModel.topNode();
        decisionTreeModelInfo.setRoot(root.id());

        // ArrayDeque used as a stack: same LIFO order as java.util.Stack, without the
        // per-call synchronization of the legacy Vector-based class.
        Deque<Node> pending = new ArrayDeque<>();
        pending.push(root);
        while (!pending.isEmpty()) {
            visit(pending.pop(), pending, decisionTreeModelInfo);
        }
        return decisionTreeModelInfo;
    }

    /**
     * @return the source model class handled by this adapter, {@link DecisionTreeModel}
     */
    @Override // com.flipkart.fdp.ml.adapter.ModelInfoAdapter
    public Class<DecisionTreeModel> getSource() {
        return DecisionTreeModel.class;
    }

    /**
     * @return the model-info class produced by this adapter, {@link DecisionTreeModelInfo}
     */
    @Override // com.flipkart.fdp.ml.adapter.ModelInfoAdapter
    public Class<DecisionTreeModelInfo> getTarget() {
        return DecisionTreeModelInfo.class;
    }
}
