package com.cloudera.oryx.app.serving.rdf.model;

import com.cloudera.oryx.app.classreg.predict.CategoricalPrediction;
import com.cloudera.oryx.app.rdf.tree.DecisionForest;
import com.cloudera.oryx.app.rdf.tree.DecisionNode;
import com.cloudera.oryx.app.rdf.tree.DecisionTree;
import com.cloudera.oryx.app.rdf.tree.TerminalNode;
import com.cloudera.oryx.app.schema.CategoricalValueEncodings;
import com.cloudera.oryx.app.speed.rdf.MockRDFClassificationModelGenerator;
import com.cloudera.oryx.common.settings.ConfigUtils;
import com.cloudera.oryx.lambda.serving.AbstractServingIT;
import com.typesafe.config.Config;
import java.util.HashMap;
import java.util.Map;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Integration test for {@link RDFServingModelManager}.
 *
 * <p>The test:
 * <ul>
 *   <li>starts messaging and a serving layer configured with an RDF classification schema
 *       (features "color" and "fruit", target "fruit");</li>
 *   <li>publishes updates from {@link MockRDFClassificationModelGenerator};</li>
 *   <li>checks that the model manager found in the servlet context holds the expected model.</li>
 * </ul>
 * The expected encodings, tree shape and counts asserted below are what this test expects the
 * mock generator to produce.
 */
public final class RDFServingModelManagerIT extends AbstractServingIT {

    private static final Logger log = LoggerFactory.getLogger(RDFServingModelManagerIT.class);

    /**
     * Servlet-context attribute under which the serving layer is expected to publish its model
     * manager. NOTE(review): presumably mirrors the key used by ModelManagerListener; keep in sync.
     */
    private static final String MODEL_MANAGER_KEY =
        "com.cloudera.oryx.lambda.serving.ModelManagerListener.ModelManager";

    /**
     * Starts the serving layer, feeds it mock RDF model updates, and verifies the loaded model's
     * encodings, forest, input schema and single tree.
     *
     * @throws Exception if messaging or the serving layer fails to start
     */
    @Test
    public void testRDFServingModel() throws Exception {
        Map<String, Object> overlay = new HashMap<>();
        overlay.put("oryx.serving.application-resources", "\"com.cloudera.oryx.app.serving,com.cloudera.oryx.app.serving.classreg,com.cloudera.oryx.app.serving.rdf\"");
        overlay.put("oryx.serving.model-manager-class", RDFServingModelManager.class.getName());
        overlay.put("oryx.input-schema.feature-names", "[\"color\",\"fruit\"]");
        overlay.put("oryx.input-schema.numeric-features", "[]");
        overlay.put("oryx.input-schema.target-feature", "fruit");
        Config config = ConfigUtils.overlayOn(overlay, getConfig());

        startMessaging();
        startServer(config);
        startUpdateTopics(new MockRDFClassificationModelGenerator(), 5);
        // Give the serving layer a moment to consume the published updates.
        sleepSeconds(1);

        RDFServingModelManager manager = (RDFServingModelManager)
            getServingLayer().getContext().getServletContext().getAttribute(MODEL_MANAGER_KEY);
        assertNotNull("Manager must initialize in web context", manager);

        RDFServingModel model = manager.getModel();
        log.debug("{}", model);

        checkEncodings(model.getEncodings());

        DecisionForest forest = model.getForest();
        DecisionTree[] trees = forest.getTrees();
        assertEquals(1L, trees.length);
        assertArrayEquals(new double[] {1.0d}, forest.getWeights());
        assertEquals(2L, model.getInputSchema().getNumFeatures());

        checkTree(trees[0]);
    }

    /** Asserts the categorical encodings for feature 0 (color) and feature 1 (fruit). */
    private void checkEncodings(CategoricalValueEncodings encodings) {
        assertEquals(2L, encodings.getValueCount(0));
        assertEquals(2L, encodings.getValueCount(1));

        Map<?, ?> colorValues = encodings.getEncodingValueMap(0);
        assertEquals("yellow", colorValues.get(0));
        assertEquals("red", colorValues.get(1));

        Map<?, ?> fruitValues = encodings.getEncodingValueMap(1);
        assertEquals("banana", fruitValues.get(0));
        assertEquals("apple", fruitValues.get(1));
    }

    /**
     * Asserts the tree has root "r" with terminal children "r-" (left) and "r+" (right), and
     * checks each leaf's count and per-category counts.
     */
    private void checkTree(DecisionTree tree) {
        // Explicit casts: the node-lookup and prediction accessors may return supertypes, and the
        // casts are harmless if they already return the concrete type.
        DecisionNode root = (DecisionNode) tree.findByID("r");
        TerminalNode left = (TerminalNode) tree.findByID("r-");
        TerminalNode right = (TerminalNode) tree.findByID("r+");

        // JUnit convention: expected first, actual second.
        assertSame(left, root.getLeft());
        assertSame(right, root.getRight());

        assertEquals(7L, left.getCount());
        assertEquals(7L, right.getCount());

        CategoricalPrediction leftPrediction = (CategoricalPrediction) left.getPrediction();
        CategoricalPrediction rightPrediction = (CategoricalPrediction) right.getPrediction();
        assertEquals(2.0d, leftPrediction.getCategoryCounts()[0]);
        assertEquals(5.0d, leftPrediction.getCategoryCounts()[1]);
        assertEquals(3.0d, rightPrediction.getCategoryCounts()[0]);
        assertEquals(4.0d, rightPrediction.getCategoryCounts()[1]);
    }
}
