package com.cloudera.oryx.app.serving.als.model;

import com.cloudera.oryx.common.OryxTest;
import com.cloudera.oryx.common.math.VectorMath;
import com.cloudera.oryx.common.random.RandomManager;
import java.util.Arrays;
import java.util.IntSummaryStatistics;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.CombinatoricsUtils;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/cloudera/oryx/app/serving/als/model/LocalitySensitiveHashTest.class */
/**
 * Tests for {@link LocalitySensitiveHash}: how the requested sample rate and core count map to
 * the number of hash bits and the number of bits allowed to differ, how evenly random vectors
 * spread across the resulting partitions, and which candidate partition indices are returned
 * for a query vector.
 */
public final class LocalitySensitiveHashTest extends OryxTest {

    private static final Logger log = LoggerFactory.getLogger(LocalitySensitiveHashTest.class);

    /** Number of random vectors hashed per configuration in {@code doTestHashDistribution}. */
    private static final int NUM_SAMPLES = 100000;

    @Test
    public void testOneCore() {
        doTestHashesBits(1.0d, 1, 0, 0);
        doTestHashesBits(0.5d, 1, 1, 0);
        doTestHashesBits(0.1d, 1, 4, 0);
    }

    @Test
    public void testTwoCores() {
        doTestHashesBits(1.0d, 2, 1, 1);
        doTestHashesBits(0.75d, 3, 2, 1);
    }

    @Test
    public void testManyCores() {
        doTestHashesBits(0.75d, 3, 2, 1);
        doTestHashesBits(0.5d, 3, 3, 1);
        doTestHashesBits(0.1d, 8, 7, 1);
        doTestHashesBits(0.01d, 8, 11, 1);
        doTestHashesBits(0.001d, 8, 14, 1);
        doTestHashesBits(1.0E-4d, 8, 16, 1);
        doTestHashesBits(1.0E-5d, 8, 16, 1);
    }

    @Test
    public void testHashDistribution() {
        doTestHashDistribution(200, 1.0d, 16);
        doTestHashDistribution(200, 0.1d, 16);
        doTestHashDistribution(40, 1.0d, 8);
        doTestHashDistribution(40, 0.1d, 8);
        doTestHashDistribution(40, 1.0d, 1);
        doTestHashDistribution(40, 0.1d, 1);
        doTestHashDistribution(10, 1.0d, 1);
        doTestHashDistribution(10, 0.1d, 1);
    }

    @Test
    public void testCandidateIndicesNoSample() {
        LocalitySensitiveHash lsh = new LocalitySensitiveHash(1.0d, 10, 8);
        int[] candidateIndices = lsh.getCandidateIndices(new float[10]);
        // With a sample rate of 1.0 every partition is a candidate, listed in index order.
        int numPartitions = 1 << lsh.getNumHashes();
        assertEquals(numPartitions, candidateIndices.length);
        for (int partition = 0; partition < numPartitions; partition++) {
            assertEquals(partition, candidateIndices[partition]);
        }
    }

    @Test
    public void testCandidateIndicesOneBit() {
        LocalitySensitiveHash lsh = new LocalitySensitiveHash(0.1d, 10, 8);
        assertEquals(1L, lsh.getMaxBitsDiffering());

        // The zero vector lands in partition 0; the remaining candidates flip exactly one bit each,
        // lowest bit first.
        int[] zeroCandidates = lsh.getCandidateIndices(new float[10]);
        assertEquals(1 + lsh.getNumHashes(), zeroCandidates.length);
        assertEquals(0L, zeroCandidates[0]);
        for (int i = 1; i < zeroCandidates.length; i++) {
            assertEquals(1 << (i - 1), zeroCandidates[i]);
        }

        // For an arbitrary vector the same one-bit-flip pattern applies relative to its own index.
        float[] ones = new float[10];
        Arrays.fill(ones, 1.0f);
        int[] onesCandidates = lsh.getCandidateIndices(ones);
        for (int i = 1; i < onesCandidates.length; i++) {
            assertEquals(onesCandidates[0] ^ (1 << (i - 1)), onesCandidates[i]);
        }
    }

    @Test
    public void testCandidateIndices() {
        LocalitySensitiveHash lsh = new LocalitySensitiveHash(0.5d, 10, 32);
        assertEquals(3L, lsh.getMaxBitsDiffering());
        assertEquals(7L, lsh.getNumHashes());
        float[] ones = new float[10];
        Arrays.fill(ones, 1.0f);
        int[] candidateIndices = lsh.getCandidateIndices(ones);
        // 1 + C(7,1) + C(7,2) + C(7,3) = 1 + 7 + 21 + 35 = 64 candidates, grouped by Hamming
        // distance from the vector's own partition: [1,8) differ by 1 bit, [8,29) by 2, [29,64) by 3.
        assertEquals(64L, candidateIndices.length);
        int own = candidateIndices[0];
        for (int i = 1; i < 8; i++) {
            assertEquals(1L, Integer.bitCount(own ^ candidateIndices[i]));
        }
        for (int i = 8; i < 29; i++) {
            assertEquals(2L, Integer.bitCount(own ^ candidateIndices[i]));
        }
        for (int i = 29; i < 64; i++) {
            assertEquals(3L, Integer.bitCount(own ^ candidateIndices[i]));
        }
    }

    /**
     * Hashes {@link #NUM_SAMPLES} random vectors and checks that partitions are roughly balanced.
     *
     * @param numFeatures dimension of the random vectors (also passed to the hash)
     * @param sampleRate sample rate passed to the hash constructor
     * @param numCores core count passed to the hash constructor
     */
    private static void doTestHashDistribution(int numFeatures, double sampleRate, int numCores) {
        LocalitySensitiveHash lsh = new LocalitySensitiveHash(sampleRate, numFeatures, numCores);
        int numHashes = lsh.getNumHashes();
        RandomGenerator random = RandomManager.getRandom();
        int[] countsPerPartition = new int[1 << numHashes];
        for (int sample = 0; sample < NUM_SAMPLES; sample++) {
            countsPerPartition[lsh.getIndexFor(VectorMath.randomVectorF(numFeatures, random))]++;
        }
        log.info("{}", Arrays.toString(countsPerPartition));
        IntSummaryStatistics stats = Arrays.stream(countsPerPartition).summaryStatistics();
        log.info("Total {} / Max {} / Min {}", stats.getSum(), stats.getMax(), stats.getMin());
        assertEquals(NUM_SAMPLES, stats.getSum());
        // Loose balance requirement: no partition holds more than twice the emptiest one.
        assertLessOrEqual(stats.getMax(), 2 * stats.getMin());
    }

    /**
     * Builds a hash over 10-dimensional vectors and checks the parameters it derives.
     *
     * @param sampleRate requested sample rate passed to the constructor
     * @param numCores core count passed to the constructor
     * @param expectedNumHashes expected value of {@code getNumHashes()}
     * @param expectedMaxBitsDiffering expected value of {@code getMaxBitsDiffering()}
     */
    private static void doTestHashesBits(double sampleRate,
                                         int numCores,
                                         int expectedNumHashes,
                                         int expectedMaxBitsDiffering) {
        LocalitySensitiveHash lsh = new LocalitySensitiveHash(sampleRate, 10, numCores);
        assertEquals(expectedNumHashes, lsh.getNumHashes());
        assertEquals(1 << expectedNumHashes, lsh.getNumPartitions());
        assertEquals(expectedMaxBitsDiffering, lsh.getMaxBitsDiffering());
        if (sampleRate == 1.0d) {
            // No sampling: any bit may differ, so every partition is a candidate.
            assertEquals(lsh.getMaxBitsDiffering(), lsh.getNumHashes());
        }
        // Partitions within maxBitsDiffering bits of a query's own partition: sum of C(numHashes, k).
        long partitionsToTry = 0;
        for (int bits = 0; bits <= expectedMaxBitsDiffering; bits++) {
            partitionsToTry += CombinatoricsUtils.binomialCoefficient(expectedNumHashes, bits);
        }
        // numHashes appears to be capped at 16 (sample rates 1e-4 and 1e-5 both yield 16 above),
        // where the requested rate may not be achievable, so only check below the cap.
        if (expectedNumHashes < 16) {
            // Must be floating-point division: long division truncates the fraction (usually to 0),
            // which would make this assertion vacuous.
            double fractionTried = (double) partitionsToTry / (1L << expectedNumHashes);
            assertLessOrEqual(fractionTried, sampleRate);
        }
    }
}
