package com.cloudera.oryx.app.rdf;

import com.cloudera.oryx.app.classreg.predict.CategoricalPrediction;
import com.cloudera.oryx.app.classreg.predict.NumericPrediction;
import com.cloudera.oryx.app.classreg.predict.Prediction;
import com.cloudera.oryx.app.pmml.AppPMMLUtils;
import com.cloudera.oryx.app.rdf.decision.CategoricalDecision;
import com.cloudera.oryx.app.rdf.decision.Decision;
import com.cloudera.oryx.app.rdf.decision.NumericDecision;
import com.cloudera.oryx.app.rdf.tree.DecisionForest;
import com.cloudera.oryx.app.rdf.tree.DecisionNode;
import com.cloudera.oryx.app.rdf.tree.DecisionTree;
import com.cloudera.oryx.app.rdf.tree.TerminalNode;
import com.cloudera.oryx.app.rdf.tree.TreeNode;
import com.cloudera.oryx.app.schema.CategoricalValueEncodings;
import com.cloudera.oryx.app.schema.InputSchema;
import com.cloudera.oryx.common.collection.Pair;
import com.cloudera.oryx.common.text.TextUtils;
import com.google.common.base.Preconditions;
import java.util.BitSet;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.SimpleSetPredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;

/* loaded from: input_file:com/cloudera/oryx/app/rdf/RDFPMMLUtils.class */
/**
 * Utilities for reading random decision forest / decision tree models out of PMML into
 * {@link DecisionForest}, and for checking a PMML model against a configured {@link InputSchema}.
 */
public final class RDFPMMLUtils {

    private RDFPMMLUtils() {
    }

    /**
     * Validates that a PMML model agrees with the configured input schema.
     *
     * @param pmml PMML to check
     * @param inputSchema configured input schema
     * @throws IllegalArgumentException if the PMML does not contain exactly one model; if its
     *     mining function is not CLASSIFICATION (for a classification schema) or REGRESSION
     *     (otherwise); if the data dictionary or mining schema feature names differ from the
     *     schema's; or if the target index in the PMML disagrees with the schema's
     */
    public static void validatePMMLVsSchema(PMML pmml, InputSchema inputSchema) {
        List<Model> models = pmml.getModels();
        Preconditions.checkArgument(models.size() == 1,
            "Should have exactly one model, but had %s", models.size());
        Model model = models.get(0);

        MiningFunction function = model.getMiningFunction();
        if (inputSchema.isClassification()) {
            Preconditions.checkArgument(function == MiningFunction.CLASSIFICATION,
                "Expected classification function type but got %s", function);
        } else {
            Preconditions.checkArgument(function == MiningFunction.REGRESSION,
                "Expected regression function type but got %s", function);
        }

        Preconditions.checkArgument(
            inputSchema.getFeatureNames().equals(AppPMMLUtils.getFeatureNames(pmml.getDataDictionary())),
            "Feature names in schema don't match names in PMML");
        MiningSchema miningSchema = model.getMiningSchema();
        Preconditions.checkArgument(
            inputSchema.getFeatureNames().equals(AppPMMLUtils.getFeatureNames(miningSchema)),
            "Feature names in schema don't match names in PMML mining schema");

        Integer pmmlTargetIndex = AppPMMLUtils.findTargetIndex(miningSchema);
        if (inputSchema.hasTarget()) {
            int schemaTargetIndex = inputSchema.getTargetFeatureIndex();
            Preconditions.checkArgument(
                pmmlTargetIndex != null && schemaTargetIndex == pmmlTargetIndex,
                "Configured schema expects target at index %s, but PMML has target at index %s",
                schemaTargetIndex, pmmlTargetIndex);
        } else {
            Preconditions.checkArgument(pmmlTargetIndex == null,
                "Configured schema has no target, but PMML has target at index %s", pmmlTargetIndex);
        }
    }

    /**
     * Reads a decision forest, plus the categorical value encodings built from the data
     * dictionary, out of a PMML model. Only the first model in the PMML is read.
     *
     * <p>Two model shapes are accepted:
     * <ul>
     *   <li>a {@link MiningModel} whose segmentation uses WEIGHTED_AVERAGE or
     *       WEIGHTED_MAJORITY_VOTE, and whose segments each have a {@code True} predicate and a
     *       {@link TreeModel}; each segment's weight becomes its tree's weight</li>
     *   <li>a single {@link TreeModel}, read as a one-tree forest with weight 1.0</li>
     * </ul>
     *
     * <p>Feature importances are copied positionally from the mining schema's
     * {@code MiningField} importance attributes. Fields without an importance get 0.
     *
     * @param pmml PMML to read
     * @return the forest paired with the categorical value encodings used to decode it
     * @throws NullPointerException if the mining schema has no target field
     * @throws IllegalArgumentException if the model structure is not one of the supported shapes
     */
    public static Pair<DecisionForest, CategoricalValueEncodings> read(PMML pmml) {
        DataDictionary dictionary = pmml.getDataDictionary();
        List<String> featureNames = AppPMMLUtils.getFeatureNames(dictionary);
        CategoricalValueEncodings encodings = AppPMMLUtils.buildCategoricalValueEncodings(dictionary);

        Model model = pmml.getModels().get(0);
        MiningSchema miningSchema = model.getMiningSchema();
        int targetIndex = Objects.requireNonNull(
            AppPMMLUtils.findTargetIndex(miningSchema), "PMML mining schema has no target field");

        DecisionTree[] trees;
        double[] weights;
        if (model instanceof MiningModel) {
            Segmentation segmentation = ((MiningModel) model).getSegmentation();
            Segmentation.MultipleModelMethod method = segmentation.getMultipleModelMethod();
            Preconditions.checkArgument(
                method == Segmentation.MultipleModelMethod.WEIGHTED_AVERAGE
                    || method == Segmentation.MultipleModelMethod.WEIGHTED_MAJORITY_VOTE,
                "Unsupported multiple model method %s", method);
            List<Segment> segments = segmentation.getSegments();
            Preconditions.checkArgument(!segments.isEmpty(), "Segmentation has no segments");

            trees = new DecisionTree[segments.size()];
            weights = new double[trees.length];
            for (int i = 0; i < trees.length; i++) {
                Segment segment = segments.get(i);
                Preconditions.checkArgument(segment.getPredicate() instanceof True,
                    "Segment %s must have a True predicate", i);
                weights[i] = segment.getWeight().doubleValue();
                Model segmentModel = segment.getModel();
                Preconditions.checkArgument(segmentModel instanceof TreeModel,
                    "Segment %s must contain a TreeModel", i);
                TreeNode root = translateFromPMML(
                    ((TreeModel) segmentModel).getNode(), encodings, featureNames, targetIndex);
                trees[i] = new DecisionTree(root);
            }
        } else {
            Preconditions.checkArgument(model instanceof TreeModel,
                "Expected a MiningModel or TreeModel");
            TreeNode root = translateFromPMML(
                ((TreeModel) model).getNode(), encodings, featureNames, targetIndex);
            trees = new DecisionTree[] {new DecisionTree(root)};
            weights = new double[] {1.0};
        }

        // Mining fields are assumed to be in the same order as the data dictionary features;
        // validatePMMLVsSchema checks both lists against the schema's feature names.
        List<MiningField> miningFields = miningSchema.getMiningFields();
        double[] featureImportances = new double[featureNames.size()];
        for (int i = 0; i < miningFields.size(); i++) {
            Double importance = miningFields.get(i).getImportance();
            if (importance != null) {
                featureImportances[i] = importance.doubleValue();
            }
        }

        return new Pair<>(new DecisionForest(trees, weights, featureImportances), encodings);
    }

    /**
     * Recursively converts a PMML tree node into a {@link TreeNode}.
     *
     * <p>A node without children becomes a {@link TerminalNode}. Any other node must have
     * exactly two children, one of which has a {@code True} predicate. The child with the
     * {@code True} predicate is the "negative" branch. The other child's predicate defines
     * the {@link Decision} and that child is the "positive" branch.
     *
     * @param root PMML node to translate
     * @param categoricalValueEncodings encodings for categorical features and the target
     * @param featureNames feature names, whose positions give feature indices
     * @param targetIndex index of the target feature
     * @return translated subtree
     * @throws IllegalArgumentException if the node structure or predicates are unsupported
     */
    private static TreeNode translateFromPMML(Node root,
                                              CategoricalValueEncodings categoricalValueEncodings,
                                              List<String> featureNames,
                                              int targetIndex) {
        String id = root.getId();
        List<Node> children = root.getNodes();
        if (children.isEmpty()) {
            return new TerminalNode(id, translatePrediction(root, categoricalValueEncodings, targetIndex));
        }

        Preconditions.checkArgument(children.size() == 2,
            "Node %s should have 2 children but has %s", id, children.size());
        Node first = children.get(0);
        Node second = children.get(1);
        Node negativeChild;
        Node positiveChild;
        if (first.getPredicate() instanceof True) {
            negativeChild = first;
            positiveChild = second;
        } else {
            Preconditions.checkArgument(second.getPredicate() instanceof True,
                "Node %s has no child with a True predicate", id);
            negativeChild = second;
            positiveChild = first;
        }

        // The default child (used when the decision can't be evaluated) is either the positive
        // child or, implicitly, the negative one.
        boolean defaultDecision = positiveChild.getId().equals(root.getDefaultChild());
        Decision decision;
        if (positiveChild.getPredicate() instanceof SimplePredicate) {
            decision = translateNumericDecision(
                (SimplePredicate) positiveChild.getPredicate(), featureNames, defaultDecision);
        } else {
            Preconditions.checkArgument(positiveChild.getPredicate() instanceof SimpleSetPredicate,
                "Node %s has an unsupported predicate type", id);
            decision = translateCategoricalDecision(
                (SimpleSetPredicate) positiveChild.getPredicate(),
                categoricalValueEncodings, featureNames, defaultDecision);
        }

        return new DecisionNode(
            id,
            decision,
            translateFromPMML(negativeChild, categoricalValueEncodings, featureNames, targetIndex),
            translateFromPMML(positiveChild, categoricalValueEncodings, featureNames, targetIndex));
    }

    /**
     * Builds the prediction for a leaf node. A leaf with score distributions yields a
     * {@link CategoricalPrediction} of per-category record counts, indexed by the target's
     * value encoding. Otherwise it yields a {@link NumericPrediction} from the node's score
     * and rounded record count.
     */
    private static Prediction translatePrediction(Node leaf,
                                                  CategoricalValueEncodings categoricalValueEncodings,
                                                  int targetIndex) {
        List<ScoreDistribution> scoreDistributions = leaf.getScoreDistributions();
        if (scoreDistributions == null || scoreDistributions.isEmpty()) {
            return new NumericPrediction(Double.parseDouble(leaf.getScore()),
                                         (int) Math.round(leaf.getRecordCount().doubleValue()));
        }
        Map<String, Integer> targetEncoding = categoricalValueEncodings.getValueEncodingMap(targetIndex);
        double[] categoryCounts = new double[targetEncoding.size()];
        for (ScoreDistribution distribution : scoreDistributions) {
            categoryCounts[encodingOf(targetEncoding, distribution.getValue())] =
                distribution.getRecordCount();
        }
        return new CategoricalPrediction(categoryCounts);
    }

    /**
     * Converts a {@code >=} or {@code >} {@link SimplePredicate} into a {@link NumericDecision}.
     * {@code >} is expressed in the same {@code >=} form by raising the threshold to the
     * next representable double above it.
     */
    private static Decision translateNumericDecision(SimplePredicate predicate,
                                                     List<String> featureNames,
                                                     boolean defaultDecision) {
        SimplePredicate.Operator operator = predicate.getOperator();
        Preconditions.checkArgument(
            operator == SimplePredicate.Operator.GREATER_OR_EQUAL
                || operator == SimplePredicate.Operator.GREATER_THAN,
            "Unsupported numeric operator %s", operator);
        double threshold = Double.parseDouble(predicate.getValue());
        if (operator == SimplePredicate.Operator.GREATER_THAN) {
            // x > t  <=>  x >= nextUp(t). Math.nextUp is exact. The previous t + ulp(t)
            // skipped a value at negative powers of two and produced NaN for -Infinity.
            threshold = Math.nextUp(threshold);
        }
        int featureNumber = featureIndex(featureNames, predicate.getField().getValue());
        return new NumericDecision(featureNumber, threshold, defaultDecision);
    }

    /**
     * Converts an IS_IN / IS_NOT_IN {@link SimpleSetPredicate} into a {@link CategoricalDecision}
     * whose bit set marks the encoded categories that take the positive branch.
     */
    private static Decision translateCategoricalDecision(SimpleSetPredicate predicate,
                                                         CategoricalValueEncodings categoricalValueEncodings,
                                                         List<String> featureNames,
                                                         boolean defaultDecision) {
        SimpleSetPredicate.BooleanOperator operator = predicate.getBooleanOperator();
        Preconditions.checkArgument(
            operator == SimpleSetPredicate.BooleanOperator.IS_IN
                || operator == SimpleSetPredicate.BooleanOperator.IS_NOT_IN,
            "Unsupported set operator %s", operator);
        int featureNumber = featureIndex(featureNames, predicate.getField().getValue());
        Map<String, Integer> valueEncodingMap = categoricalValueEncodings.getValueEncodingMap(featureNumber);
        String[] categories = TextUtils.parseDelimited(predicate.getArray().getValue(), ' ');

        BitSet activeCategories = new BitSet(valueEncodingMap.size());
        if (operator == SimpleSetPredicate.BooleanOperator.IS_IN) {
            for (String category : categories) {
                activeCategories.set(encodingOf(valueEncodingMap, category));
            }
        } else {
            // "not in": start from every known category and remove the listed ones
            valueEncodingMap.values().forEach(activeCategories::set);
            for (String category : categories) {
                activeCategories.clear(encodingOf(valueEncodingMap, category));
            }
        }
        return new CategoricalDecision(featureNumber, activeCategories, defaultDecision);
    }

    /** Returns the index of {@code featureName}, failing clearly if the feature is unknown. */
    private static int featureIndex(List<String> featureNames, String featureName) {
        int index = featureNames.indexOf(featureName);
        Preconditions.checkArgument(index >= 0, "Unknown feature %s", featureName);
        return index;
    }

    /** Returns the encoding of a categorical value, failing clearly if the value is unknown. */
    private static int encodingOf(Map<String, Integer> valueEncodingMap, Object value) {
        Integer encoding = valueEncodingMap.get(value);
        Preconditions.checkArgument(encoding != null, "Unknown categorical value %s", value);
        return encoding;
    }
}
