package com.cloudera.oryx.app.classreg.predict;

import com.cloudera.oryx.app.classreg.example.CategoricalFeature;
import com.cloudera.oryx.app.classreg.example.Example;
import com.cloudera.oryx.app.classreg.example.Feature;
import com.cloudera.oryx.app.classreg.example.FeatureType;
import com.cloudera.oryx.common.OryxTest;
import org.junit.Test;

/**
 * Tests for {@link CategoricalPrediction}: construction from raw counts and from
 * probabilities, incremental updates, and the {@code hashCode}, {@code toString} and
 * {@code equals} behavior.
 */
public final class CategoricalPredictionTest extends OryxTest {

    /**
     * Builds a new array of per-category counts (total 8) shared by several tests.
     * A fresh array is returned each time because {@link #testUpdate()} mutates its copy.
     */
    private static int[] sampleCounts() {
        return new int[] {0, 1, 3, 0, 4, 0};
    }

    /** Builds a new array holding the probabilities implied by {@link #sampleCounts()}. */
    private static double[] sampleProbabilities() {
        return new double[] {0.0, 0.125, 0.375, 0.0, 0.5, 0.0};
    }

    @Test
    public void testConstruct() {
        int[] counts = sampleCounts();
        CategoricalPrediction prediction = new CategoricalPrediction(counts);
        assertEquals(FeatureType.CATEGORICAL, prediction.getFeatureType());
        // Category 4 holds the largest count (4 of 8).
        assertEquals(4, prediction.getMostProbableCategoryEncoding());
        assertArrayEquals(toDoubles(counts), prediction.getCategoryCounts());
        assertArrayEquals(sampleProbabilities(), prediction.getCategoryProbabilities());
    }

    @Test
    public void testConstructFromProbability() {
        double[] probabilities = sampleProbabilities();
        CategoricalPrediction prediction = new CategoricalPrediction(probabilities);
        assertEquals(FeatureType.CATEGORICAL, prediction.getFeatureType());
        assertEquals(4, prediction.getMostProbableCategoryEncoding());
        assertArrayEquals(probabilities, prediction.getCategoryProbabilities());
    }

    @Test
    public void testUpdate() {
        int[] counts = sampleCounts();
        CategoricalPrediction prediction = new CategoricalPrediction(counts);
        Example example = new Example(CategoricalFeature.forEncoding(2), new Feature[0]);
        prediction.update(example);
        prediction.update(example);
        // Each update is expected to add one observation to category 2, making it 5 of 10.
        assertEquals(2, prediction.getMostProbableCategoryEncoding());
        counts[2] += 2;
        assertArrayEquals(toDoubles(counts), prediction.getCategoryCounts());
        assertArrayEquals(new double[] {0.0, 0.1, 0.5, 0.0, 0.4, 0.0},
                          prediction.getCategoryProbabilities());
    }

    @Test
    public void testUpdate2() {
        CategoricalPrediction prediction = new CategoricalPrediction(sampleCounts());
        prediction.update(0, 3);
        prediction.update(1, 9);
        // Expected: 3 added to category 0 and 9 to category 1, for a new total of 20.
        assertArrayEquals(new double[] {3.0, 10.0, 3.0, 0.0, 4.0, 0.0},
                          prediction.getCategoryCounts());
        assertArrayEquals(new double[] {0.15, 0.5, 0.15, 0.0, 0.2, 0.0},
                          prediction.getCategoryProbabilities());
    }

    @Test
    public void testHashCode() {
        CategoricalPrediction prediction = new CategoricalPrediction(sampleCounts());
        // Pins the current value so an unintended change to hashCode() is noticed.
        assertEquals(566115137, prediction.hashCode());
    }

    @Test
    public void testToString() {
        CategoricalPrediction prediction = new CategoricalPrediction(sampleCounts());
        assertEquals(":[0.0, 0.125, 0.375, 0.0, 0.5, 0.0]", prediction.toString());
    }

    @Test
    public void testEquals() {
        CategoricalPrediction first = new CategoricalPrediction(sampleCounts());
        CategoricalPrediction second = new CategoricalPrediction(new int[] {1, 2, 4, 5, 6, 7});
        assertNotEquals(first, second);
    }

    /** Widens every element of {@code values} to {@code double}, keeping the same order. */
    private static double[] toDoubles(int[] values) {
        double[] widened = new double[values.length];
        int index = 0;
        for (int value : values) {
            widened[index++] = value;
        }
        return widened;
    }
}
