package com.cloudera.oryx.app.rdf.tree;

import com.cloudera.oryx.app.classreg.example.Example;
import com.cloudera.oryx.app.classreg.example.Feature;
import com.cloudera.oryx.app.classreg.example.NumericFeature;
import com.cloudera.oryx.app.classreg.predict.NumericPrediction;
import com.cloudera.oryx.app.rdf.decision.NumericDecision;
import com.cloudera.oryx.common.OryxTest;
import org.junit.Test;

/**
 * Tests for {@link DecisionTree}: prediction, terminal-node lookup, lookup by node ID,
 * and the textual rendering of the tree.
 */
public final class DecisionTreeTest extends OryxTest {

    /**
     * Builds the small fixed tree shared by the tests in this class (and possibly by
     * other tests in this package).
     *
     * <p>Both decisions test feature #0. Judging by the expected results in the tests
     * below, the first child of a {@link DecisionNode} is followed when the threshold
     * test fails and the second child when it holds.
     *
     * @return a two-level tree whose terminals "r--", "r-+" and "r+" predict 0.0, 1.0
     *         and 2.0 respectively
     */
    public static DecisionTree buildTestTree() {
        // Inner subtree "r-": splits feature #0 at -1.0.
        // The trailing 'false' is presumably the default decision for a missing value
        // -- NOTE(review): confirm against NumericDecision.
        TerminalNode belowBoth = new TerminalNode("r--", new NumericPrediction(0.0, 1));
        TerminalNode betweenThresholds = new TerminalNode("r-+", new NumericPrediction(1.0, 1));
        DecisionNode innerSplit =
            new DecisionNode("r-", new NumericDecision(0, -1.0, false), belowBoth, betweenThresholds);

        // Terminal hanging directly off the root's second branch.
        TerminalNode aboveRoot = new TerminalNode("r+", new NumericPrediction(2.0, 1));

        // Root "r": splits feature #0 at 1.0.
        DecisionNode root =
            new DecisionNode("r", new NumericDecision(0, 1.0, false), innerSplit, aboveRoot);
        return new DecisionTree(root);
    }

    /**
     * Creates an example with no target and exactly one numeric feature.
     *
     * @param value the value of feature #0
     * @return the example
     */
    private static Example exampleOf(double value) {
        Feature[] features = { NumericFeature.forValue(value) };
        return new Example((Feature) null, features);
    }

    @Test
    public void testPredict() {
        DecisionTree tree = buildTestTree();
        // 0.5 is below the root threshold (1.0) but not below the inner one (-1.0),
        // so the example should end up at "r-+", which predicts 1.0.
        assertEquals(1.0, tree.predict(exampleOf(0.5)).getPrediction());
    }

    @Test
    public void testFindTerminal() {
        DecisionTree tree = buildTestTree();
        // Same routing as testPredict, but checks the terminal node reached.
        assertEquals(1.0, tree.findTerminal(exampleOf(0.5)).getPrediction().getPrediction());
    }

    @Test
    public void testFindByID() {
        DecisionTree tree = buildTestTree();
        assertEquals(1.0, tree.findByID("r-+").getPrediction().getPrediction());
    }

    @Test
    public void testToString() {
        String rendered = buildTestTree().toString();
        // The root's decision is rendered first; the inner decision appears somewhere after it.
        assertTrue(rendered.startsWith("(#0 >= 1.0)"));
        assertContains(rendered, "(#0 >= -1.0)");
    }
}
