package com.cloudera.oryx.app.speed.rdf;

import com.cloudera.oryx.api.KeyMessage;
import com.cloudera.oryx.app.pmml.AppPMMLUtils;
import com.cloudera.oryx.app.schema.CategoricalValueEncodings;
import com.cloudera.oryx.common.pmml.PMMLUtils;
import com.cloudera.oryx.common.random.RandomManager;
import com.cloudera.oryx.common.settings.ConfigUtils;
import com.cloudera.oryx.common.text.TextUtils;
import com.cloudera.oryx.lambda.speed.AbstractSpeedIT;
import com.typesafe.config.Config;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.distribution.BinomialDistribution;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Integration tests for {@link RDFSpeedModelManager}.
 *
 * <p>Each test configures the speed layer with a mock RDF model generator and a mock input
 * generator. It then inspects the key/message list returned by
 * {@code startServerProduceConsumeTopics}. The first message must have key "MODEL" and every
 * later one key "UP"; each "UP" message is a JSON array that is checked field by field.
 */
public final class RDFSpeedIT extends AbstractSpeedIT {

    private static final Logger log = LoggerFactory.getLogger(RDFSpeedIT.class);

    /** Input size passed to {@code startServerProduceConsumeTopics} by both tests. */
    private static final int NUM_INPUT = 500;

    /** Slack added on both sides of the expected-mean range to absorb floating-point error. */
    private static final double EPSILON = 1.0e-12;

    /**
     * Regression case: features "foo" and "bar" (none categorical), target "bar".
     *
     * <p>Every update must name node "r-" or "r+" with field 0 equal to 0. Its reported mean
     * must lie in the range computed by {@link #minMaxExpectedMean(int, boolean)} for its
     * count. Updates must also come in consecutive pairs covering both "r-" and "r+", with
     * counts differing by at most one.
     */
    @Test
    public void testRDFSpeedRegression() throws Exception {
        Map<String, Object> overlayConfig = new HashMap<>();
        overlayConfig.put("oryx.speed.model-manager-class", RDFSpeedModelManager.class.getName());
        overlayConfig.put("oryx.speed.streaming.generation-interval-sec", 10);
        overlayConfig.put("oryx.input-schema.feature-names", "[\"foo\",\"bar\"]");
        overlayConfig.put("oryx.input-schema.categorical-features", "[]");
        overlayConfig.put("oryx.input-schema.target-feature", "bar");
        Config config = ConfigUtils.overlayOn(overlayConfig, getConfig());

        startMessaging();

        List<KeyMessage<String, String>> updates = startServerProduceConsumeTopics(
                config,
                new MockRDFRegressionInputGenerator(),
                new MockRDFRegressionModelGenerator(),
                NUM_INPUT,
                1);
        logUpdates(updates);
        int numMessages = assertModelThenUpdates(updates);

        // Layout of each UP message, as used below (inferred from this test, not from the
        // producer): [0 = tree index?, 1 = node ID, 2 = mean, 3 = count].
        for (int i = 1; i < numMessages; i++) {
            KeyMessage<String, String> update = updates.get(i);
            assertEquals("UP", update.getKey());
            List<?> fields = TextUtils.readJSON(update.getMessage(), List.class);
            int treeIndex = (Integer) fields.get(0);
            String nodeID = fields.get(1).toString();
            double mean = (Double) fields.get(2);
            int count = (Integer) fields.get(3);
            assertEquals(0L, treeIndex);
            assertContains(Arrays.asList("r-", "r+"), nodeID);
            double[] expectedRange = minMaxExpectedMean(count, "r+".equals(nodeID));
            assertRange(mean, expectedRange[0] - EPSILON, expectedRange[1] + EPSILON);
        }

        // Updates arrive in consecutive pairs (the odd total was checked above, so i + 1 is
        // always in bounds): one for each child node, with nearly equal counts.
        for (int i = 1; i < numMessages; i += 2) {
            List<?> first = TextUtils.readJSON(updates.get(i).getMessage(), List.class);
            List<?> second = TextUtils.readJSON(updates.get(i + 1).getMessage(), List.class);
            int firstCount = (Integer) first.get(3);
            int secondCount = (Integer) second.get(3);
            assertLessOrEqual(Math.abs(firstCount - secondCount), 1.0);
            String firstNode = first.get(1).toString();
            String secondNode = second.get(1).toString();
            assertEquals("r-".equals(firstNode) ? "r+" : "r-", secondNode);
        }
    }

    /**
     * Computes lower and upper bounds for the mean reported by a regression node that has
     * seen {@code count} values.
     *
     * <p>The arithmetic models the values as a repeating cycle of five. For the "r+" node the
     * cycle is 1, 3, 5, 7, 9; for the "r-" node it is 0, -2, -4, -6, -8. The values are
     * summed over {@code count} steps from two starting phases.
     * <ul>
     *   <li>Phase 0 gives the smallest possible sum for the ascending "r+" cycle and the
     *       largest for the descending "r-" cycle.</li>
     *   <li>Phase {@code 5 - count % 5} gives the opposite extreme.</li>
     * </ul>
     * When {@code count} is a multiple of 5 both phases coincide and the bounds are equal.
     * Presumably this mirrors {@code MockRDFRegressionInputGenerator}; confirm there.
     *
     * @param count number of values the node has seen; a count of 0 yields NaN bounds
     * @param positive {@code true} for the "r+" node, {@code false} for "r-"
     * @return a two-element array {@code {minMean, maxMean}}
     */
    private static double[] minMaxExpectedMean(int count, boolean positive) {
        double minSum = 0.0;
        double maxSum = 0.0;
        int offset = 5 - (count % 5);
        for (int i = 0; i < count; i++) {
            if (positive) {
                minSum += 1 + 2 * (i % 5);
                maxSum += 1 + 2 * ((i + offset) % 5);
            } else {
                minSum += -2 * ((i + offset) % 5);
                maxSum += -2 * (i % 5);
            }
        }
        return new double[] {minSum / count, maxSum / count};
    }

    /**
     * Classification case: features "color" and "fruit" (none numeric), target "fruit".
     *
     * <p>The categorical encodings are read from the PMML in the initial MODEL message. Every
     * update must name node "r-" or "r+" with field 0 equal to 0. Its target-count map must
     * contain a positive total for the "red" and "yellow" encodings. The "yellow" count for
     * "r+" (or the "red" count for "r-") is checked against Binomial(total, 0.9).
     */
    @Test
    public void testRDFSpeedClassification() throws Exception {
        Map<String, Object> overlayConfig = new HashMap<>();
        overlayConfig.put("oryx.speed.model-manager-class", RDFSpeedModelManager.class.getName());
        overlayConfig.put("oryx.speed.streaming.generation-interval-sec", 5);
        overlayConfig.put("oryx.input-schema.feature-names", "[\"color\",\"fruit\"]");
        overlayConfig.put("oryx.input-schema.numeric-features", "[]");
        overlayConfig.put("oryx.input-schema.target-feature", "fruit");
        Config config = ConfigUtils.overlayOn(overlayConfig, getConfig());

        startMessaging();

        List<KeyMessage<String, String>> updates = startServerProduceConsumeTopics(
                config,
                new MockRDFClassificationInputGenerator(),
                new MockRDFClassificationModelGenerator(),
                NUM_INPUT,
                1);
        logUpdates(updates);
        int numMessages = assertModelThenUpdates(updates);

        CategoricalValueEncodings encodings = AppPMMLUtils.buildCategoricalValueEncodings(
                PMMLUtils.fromString(updates.get(0).getMessage()).getDataDictionary());
        log.info("{}", encodings);
        // NOTE(review): index 0 is looked up here, yet "fruit" is the target at index 1 in
        // this test's schema. Confirm which feature MockRDFClassificationModelGenerator's data
        // dictionary puts at index 0.
        Map<String, Integer> featureZeroEncodings = encodings.getValueEncodingMap(0);
        // Target-count maps parsed from JSON have string keys, so stringify the encodings.
        String redKey = Integer.toString(featureZeroEncodings.get("red"));
        String yellowKey = Integer.toString(featureZeroEncodings.get("yellow"));

        // Layout of each UP message, as used below: [0 = tree index?, 1 = node ID, 2 = counts].
        for (int i = 1; i < numMessages; i++) {
            KeyMessage<String, String> update = updates.get(i);
            assertEquals("UP", update.getKey());
            List<?> fields = TextUtils.readJSON(update.getMessage(), List.class);
            int treeIndex = (Integer) fields.get(0);
            String nodeID = fields.get(1).toString();
            Map<?, ?> targetCounts = (Map<?, ?>) fields.get(2);
            assertEquals(0L, treeIndex);
            assertContains(Arrays.asList("r-", "r+"), nodeID);
            int yellowCount = countFor(targetCounts, yellowKey);
            int redCount = countFor(targetCounts, redKey);
            int total = yellowCount + redCount;
            assertGreater(total, 0.0);
            BinomialDistribution expected =
                    new BinomialDistribution(RandomManager.getRandom(), total, 0.9);
            checkDiscreteProbability("r+".equals(nodeID) ? yellowCount : redCount, expected);
        }
    }

    /**
     * Logs every message at debug level. The whole loop is skipped when debug logging is
     * disabled.
     *
     * @param updates messages returned by {@code startServerProduceConsumeTopics}
     */
    private static void logUpdates(List<KeyMessage<String, String>> updates) {
        if (log.isDebugEnabled()) {
            updates.forEach(update -> log.debug("{}", update));
        }
    }

    /**
     * Checks the overall shape shared by both tests. There must be at least three messages
     * and an odd total (one MODEL message followed by pairs of updates), and the first
     * message must have key "MODEL".
     *
     * @param updates messages returned by {@code startServerProduceConsumeTopics}
     * @return the number of messages
     */
    private int assertModelThenUpdates(List<KeyMessage<String, String>> updates) {
        int numMessages = updates.size();
        assertGreaterOrEqual(numMessages, 3.0);
        assertNotEquals(0L, numMessages % 2);
        assertEquals("MODEL", updates.get(0).getKey());
        return numMessages;
    }

    /**
     * Reads one entry of a JSON-parsed target-count map. A missing (or null) entry is
     * treated as zero.
     *
     * @param counts map parsed from an update message; values are expected to be
     *     {@link Integer}s
     * @param key stringified category encoding
     * @return the count for {@code key}, or 0 if absent
     */
    private static int countFor(Map<?, ?> counts, String key) {
        Object value = counts.get(key);
        return value == null ? 0 : (Integer) value;
    }
}
