package com.cloudera.oryx.app.rdf;

import com.cloudera.oryx.app.rdf.tree.DecisionForest;
import com.cloudera.oryx.app.schema.CategoricalValueEncodings;
import com.cloudera.oryx.app.schema.InputSchema;
import com.cloudera.oryx.common.OryxTest;
import com.cloudera.oryx.common.collection.Pair;
import com.cloudera.oryx.common.pmml.PMMLUtils;
import com.cloudera.oryx.common.settings.ConfigUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.Array;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.SimpleSetPredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.Value;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.junit.Test;

/* loaded from: input_file:com/cloudera/oryx/app/rdf/RDFPMMLUtilsTest.class */
/**
 * Tests {@link RDFPMMLUtils} against small hand-built PMML documents: one
 * classification tree (optionally wrapped in a multi-segment mining model)
 * and one regression tree.
 */
public final class RDFPMMLUtilsTest extends OryxTest {

    /**
     * Validates the dummy classification model against a schema declaring
     * {@code color} and {@code fruit} as features, no numeric features, and
     * {@code fruit} as the target. Passes if validation does not throw.
     */
    @Test
    public void testValidateClassification() {
        PMML pmml = buildDummyClassificationModel();
        Map<String, Object> overlay = new HashMap<>();
        overlay.put("oryx.input-schema.feature-names", "[\"color\",\"fruit\"]");
        overlay.put("oryx.input-schema.numeric-features", "[]");
        overlay.put("oryx.input-schema.target-feature", "fruit");
        InputSchema schema = new InputSchema(ConfigUtils.overlayOn(overlay, ConfigUtils.getDefault()));
        RDFPMMLUtils.validatePMMLVsSchema(pmml, schema);
    }

    /**
     * Validates the dummy regression model against a schema declaring
     * {@code foo} and {@code bar} as features, no categorical features, and
     * {@code bar} as the target. Passes if validation does not throw.
     */
    @Test
    public void testValidateRegression() {
        PMML pmml = buildDummyRegressionModel();
        Map<String, Object> overlay = new HashMap<>();
        overlay.put("oryx.input-schema.feature-names", "[\"foo\",\"bar\"]");
        overlay.put("oryx.input-schema.categorical-features", "[]");
        overlay.put("oryx.input-schema.target-feature", "bar");
        InputSchema schema = new InputSchema(ConfigUtils.overlayOn(overlay, ConfigUtils.getDefault()));
        RDFPMMLUtils.validatePMMLVsSchema(pmml, schema);
    }

    /**
     * Reads the single-tree classification model and checks the resulting
     * forest (one tree, weight 1.0, feature importances 0.5 and 0.0) and
     * that both features are reported with two categorical values each.
     */
    @Test
    public void testReadClassification() {
        PMML pmml = buildDummyClassificationModel();
        Pair<DecisionForest, CategoricalValueEncodings> forestAndEncodings = RDFPMMLUtils.read(pmml);

        DecisionForest forest = forestAndEncodings.getFirst();
        assertEquals(1, forest.getTrees().length);
        assertArrayEquals(new double[] {1.0}, forest.getWeights());
        assertArrayEquals(new double[] {0.5, 0.0}, forest.getFeatureImportances());

        CategoricalValueEncodings encodings = forestAndEncodings.getSecond();
        assertEquals(2, encodings.getValueCount(0));
        assertEquals(2, encodings.getValueCount(1));
    }

    /**
     * Reads the regression model and checks that no categorical features
     * are reported for it.
     */
    @Test
    public void testReadRegression() {
        PMML pmml = buildDummyRegressionModel();
        Pair<DecisionForest, CategoricalValueEncodings> forestAndEncodings = RDFPMMLUtils.read(pmml);
        CategoricalValueEncodings encodings = forestAndEncodings.getSecond();
        assertEquals(0, encodings.getCategoryCounts().size());
    }

    /**
     * Reads a three-segment classification mining model and checks that
     * the resulting forest contains three trees.
     */
    @Test
    public void testReadClassificationForest() {
        PMML pmml = buildDummyClassificationModel(3);
        Pair<DecisionForest, CategoricalValueEncodings> forestAndEncodings = RDFPMMLUtils.read(pmml);
        assertEquals(3, forestAndEncodings.getFirst().getTrees().length);
    }

    /**
     * Builds a classification PMML containing a single tree model.
     *
     * @return PMML equivalent to {@code buildDummyClassificationModel(1)}
     */
    public static PMML buildDummyClassificationModel() {
        return buildDummyClassificationModel(1);
    }

    /**
     * Builds a classification PMML over two categorical string fields:
     * {@code color} (yellow, red; active, importance 0.5) and {@code fruit}
     * (banana, apple; predicted). The tree has a root with two leaves:
     * "r+" (color not in {red}) scoring banana, and "r-" (always true)
     * scoring apple.
     *
     * @param numTrees number of trees; when greater than 1 the tree is
     *  wrapped in a {@link MiningModel} with that many unit-weight segments
     *  (all sharing the same tree instance), otherwise the bare tree model
     *  is added to the PMML
     * @return the assembled PMML document
     */
    private static PMML buildDummyClassificationModel(int numTrees) {
        PMML pmml = PMMLUtils.buildSkeletonPMML();

        List<DataField> dataFields = new ArrayList<>();
        DataField colorField = new DataField(FieldName.create("color"), OpType.CATEGORICAL, DataType.STRING);
        colorField.addValues(new Value("yellow"), new Value("red"));
        dataFields.add(colorField);
        DataField fruitField = new DataField(FieldName.create("fruit"), OpType.CATEGORICAL, DataType.STRING);
        fruitField.addValues(new Value("banana"), new Value("apple"));
        dataFields.add(fruitField);
        pmml.setDataDictionary(new DataDictionary(dataFields).setNumberOfFields(dataFields.size()));

        List<MiningField> miningFields = new ArrayList<>();
        miningFields.add(new MiningField(FieldName.create("color"))
            .setOpType(OpType.CATEGORICAL)
            .setUsageType(MiningField.UsageType.ACTIVE)
            .setImportance(0.5));
        miningFields.add(new MiningField(FieldName.create("fruit"))
            .setOpType(OpType.CATEGORICAL)
            .setUsageType(MiningField.UsageType.PREDICTED));
        MiningSchema miningSchema = new MiningSchema(miningFields);

        double totalCount = 2.0;
        double halfCount = totalCount / 2.0;

        Node root = new Node().setId("r").setRecordCount(totalCount).setPredicate(new True());

        Node negativeLeaf = new Node().setId("r-").setRecordCount(halfCount).setPredicate(new True());
        negativeLeaf.addScoreDistributions(new ScoreDistribution("apple", halfCount));

        Node positiveLeaf = new Node().setId("r+").setRecordCount(halfCount)
            .setPredicate(new SimpleSetPredicate(FieldName.create("color"),
                                                 SimpleSetPredicate.BooleanOperator.IS_NOT_IN,
                                                 new Array(Array.Type.STRING, "red")));
        positiveLeaf.addScoreDistributions(new ScoreDistribution("banana", halfCount));

        // Positive child first, then the catch-all negative child.
        root.addNodes(positiveLeaf, negativeLeaf);

        Model treeModel = new TreeModel(MiningFunction.CLASSIFICATION, miningSchema, root)
            .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT)
            .setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD);

        if (numTrees > 1) {
            // Must be typed as MiningModel: setSegmentation is not declared on Model.
            MiningModel miningModel = new MiningModel(MiningFunction.CLASSIFICATION, miningSchema);
            List<Segment> segments = new ArrayList<>();
            for (int treeID = 0; treeID < numTrees; treeID++) {
                segments.add(new Segment()
                    .setId(Integer.toString(treeID))
                    .setPredicate(new True())
                    .setModel(treeModel)
                    .setWeight(1.0));
            }
            miningModel.setSegmentation(
                new Segmentation(Segmentation.MultipleModelMethod.WEIGHTED_MAJORITY_VOTE, segments));
            pmml.addModels(miningModel);
        } else {
            pmml.addModels(treeModel);
        }
        return pmml;
    }

    /**
     * Builds a regression PMML over two continuous double fields:
     * {@code foo} (active, importance 0.5) and {@code bar} (predicted).
     * The single tree has a root with two leaves: "r+" ({@code foo > 3.14})
     * scoring 2.0, and "r-" (always true) scoring -2.0.
     *
     * @return the assembled PMML document
     */
    public static PMML buildDummyRegressionModel() {
        PMML pmml = PMMLUtils.buildSkeletonPMML();

        List<DataField> dataFields = new ArrayList<>();
        dataFields.add(new DataField(FieldName.create("foo"), OpType.CONTINUOUS, DataType.DOUBLE));
        dataFields.add(new DataField(FieldName.create("bar"), OpType.CONTINUOUS, DataType.DOUBLE));
        pmml.setDataDictionary(new DataDictionary(dataFields).setNumberOfFields(dataFields.size()));

        List<MiningField> miningFields = new ArrayList<>();
        miningFields.add(new MiningField(FieldName.create("foo"))
            .setOpType(OpType.CONTINUOUS)
            .setUsageType(MiningField.UsageType.ACTIVE)
            .setImportance(0.5));
        miningFields.add(new MiningField(FieldName.create("bar"))
            .setOpType(OpType.CONTINUOUS)
            .setUsageType(MiningField.UsageType.PREDICTED));
        MiningSchema miningSchema = new MiningSchema(miningFields);

        double totalCount = 2.0;
        double halfCount = totalCount / 2.0;

        Node root = new Node().setId("r").setRecordCount(totalCount).setPredicate(new True());

        Node positiveLeaf = new Node().setId("r+").setRecordCount(halfCount)
            .setPredicate(new SimplePredicate(FieldName.create("foo"),
                                              SimplePredicate.Operator.GREATER_THAN).setValue("3.14"))
            .setScore("2.0");
        Node negativeLeaf = new Node().setId("r-").setRecordCount(halfCount)
            .setPredicate(new True())
            .setScore("-2.0");

        // Positive child first, then the catch-all negative child.
        root.addNodes(positiveLeaf, negativeLeaf);

        Model treeModel = new TreeModel(MiningFunction.REGRESSION, miningSchema, root)
            .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT)
            .setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD)
            .setMiningSchema(miningSchema);
        pmml.addModels(treeModel);
        return pmml;
    }
}
