package com.cloudera.oryx.app.speed.kmeans;

import com.cloudera.oryx.api.KeyMessage;
import com.cloudera.oryx.app.kmeans.ClusterInfo;
import com.cloudera.oryx.common.math.VectorMath;
import com.cloudera.oryx.common.pmml.PMMLUtils;
import com.cloudera.oryx.common.settings.ConfigUtils;
import com.cloudera.oryx.common.text.TextUtils;
import com.cloudera.oryx.lambda.speed.AbstractSpeedIT;
import com.typesafe.config.Config;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.Model;
import org.dmg.pmml.clustering.Cluster;
import org.dmg.pmml.clustering.ClusteringModel;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/cloudera/oryx/app/speed/kmeans/KMeansSpeedIT.class */
/**
 * Integration test of the k-means speed layer, driven by {@link KMeansSpeedModelManager}.
 *
 * <p>The test runs the speed layer against mock input and mock model generators, then checks
 * that the update topic contains the initial {@code MODEL} message followed by {@code UP}
 * messages that move every cluster center and grow every cluster's count.
 */
public final class KMeansSpeedIT extends AbstractSpeedIT {
    private static final Logger log = LoggerFactory.getLogger(KMeansSpeedIT.class);

    /** Number of clusters expected in the mock model, and so the number of distinct updates. */
    private static final int NUM_CLUSTERS = 3;

    /**
     * Starts the speed layer with a k-means configuration, then verifies the model message and
     * the per-cluster update messages it publishes.
     *
     * @throws Exception if the messaging infrastructure or speed layer fails
     */
    @Test
    public void testKMeansSpeed() throws Exception {
        Map<String, Object> overlayConfig = new HashMap<>();
        overlayConfig.put("oryx.speed.model-manager-class", KMeansSpeedModelManager.class.getName());
        overlayConfig.put("oryx.speed.streaming.generation-interval-sec", 6);
        overlayConfig.put("oryx.input-schema.feature-names", "[\"x\",\"y\"]");
        overlayConfig.put("oryx.input-schema.categorical-features", "[]");
        Config config = ConfigUtils.overlayOn(overlayConfig, getConfig());

        startMessaging();

        // NOTE(review): 300 and 1 are presumably the number of inputs and of model updates to
        // generate -- confirm against AbstractSpeedIT.startServerProduceConsumeTopics.
        List<KeyMessage<String, String>> updates = startServerProduceConsumeTopics(
                config,
                new MockKMeansInputGenerator(),
                new MockKMeansModelGenerator(),
                300,
                1);

        updates.forEach(update -> log.info("{}", update));

        int numUpdates = updates.size();

        // The model message itself, then at least one update per cluster.
        assertGreaterOrEqual(numUpdates, 1 + NUM_CLUSTERS);
        assertEquals("MODEL", updates.get(0).getKey());

        // Check the runtime type before downcasting so that a wrong model type fails the
        // assertion instead of throwing ClassCastException.
        Model model = PMMLUtils.fromString(updates.get(0).getMessage()).getModels().get(0);
        assertInstanceOf(model, ClusteringModel.class);
        ClusteringModel clusteringModel = (ClusteringModel) model;
        // Assigning to an int local works whether the accessor returns int or Integer.
        int numberOfClusters = clusteringModel.getNumberOfClusters();
        assertEquals(NUM_CLUSTERS, numberOfClusters);

        List<Cluster> initialClusters = clusteringModel.getClusters();

        // Keyed by cluster ID; a later update for the same cluster replaces an earlier one.
        Map<Integer, ClusterInfo> clusterInfoByID = new HashMap<>();
        for (int i = 1; i < numUpdates; i++) {
            KeyMessage<String, String> update = updates.get(i);
            assertEquals("UP", update.getKey());
            // Each update is a JSON array: [clusterID, center, count]
            List<?> fields = TextUtils.readJSON(update.getMessage(), List.class);
            int clusterID = (Integer) fields.get(0);
            double[] updatedCenter = TextUtils.convertViaJSON(fields.get(1), double[].class);
            int updatedCount = (Integer) fields.get(2);
            clusterInfoByID.put(clusterID, new ClusterInfo(clusterID, updatedCenter, updatedCount));
        }

        assertEquals(NUM_CLUSTERS, clusterInfoByID.size());

        for (ClusterInfo clusterInfo : clusterInfoByID.values()) {
            int id = clusterInfo.getID();
            Cluster initialCluster = initialClusters.get(id);
            double[] initialCenter = VectorMath.parseVector(
                    TextUtils.parseDelimited(initialCluster.getArray().getValue(), ' '));
            double[] updatedCenter = clusterInfo.getCenter();
            double[] expectedCenter = MockKMeansInputGenerator.UPDATE_POINTS[id];

            // The center must keep its dimension, must have moved, and must land near the
            // point the mock input generator emits for this cluster.
            assertEquals(initialCenter.length, updatedCenter.length);
            assertFalse(Arrays.equals(initialCenter, updatedCenter));
            assertArrayEquals(expectedCenter, updatedCenter, 0.1);

            // The count must have grown by exactly 100 over the initial cluster size.
            int initialSize = initialCluster.getSize().intValue();
            long updatedCount = clusterInfo.getCount();
            assertGreater(updatedCount, initialSize);
            assertEquals(100 + initialSize, updatedCount);
        }
    }
}
