package com.cloudera.oryx.app.serving.als.model;

import com.cloudera.oryx.common.math.VectorMath;
import com.cloudera.oryx.common.random.RandomManager;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.CombinatoricsUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/cloudera/oryx/app/serving/als/model/LocalitySensitiveHash.class */
/**
 * Locality-sensitive hash over float vectors. Each of {@link #getNumHashes()} hash vectors contributes one
 * bit to a partition index; the bit is set when the dot product of the input with that hash vector is
 * positive. Lookups examine the query's own partition plus every partition whose index differs from it in
 * at most {@link #getMaxBitsDiffering()} bits.
 */
final class LocalitySensitiveHash {

    /** Maximum number of hash vectors (index bits) the constructor will choose. */
    static final int MAX_HASHES = 16;

    private static final Logger log = LoggerFactory.getLogger(LocalitySensitiveHash.class);

    /** Hash vectors; bit {@code i} of a partition index is derived from {@code hashVectors[i]}. */
    private final float[][] hashVectors;
    /** Maximum Hamming distance between the query's partition index and a candidate partition index. */
    private final int maxBitsDiffering;
    /**
     * All 2^numHashes partition indices ordered by number of set bits: 0 first, then all indices with
     * 1 bit set, then 2 bits, and so on. XOR-ing a prefix of this array with a query's index yields every
     * partition within a given Hamming distance of it.
     */
    private final int[] candidateIndicesPrototype;
    /** The indices 0 .. 2^numHashes - 1, returned when every partition must be searched. */
    private final int[] allIndices;
    /** Number of indices within maxBitsDiffering bits of any given index (prefix length used above). */
    private final int numCandidates;

    /**
     * Creates a hash sized for the given sample rate, targeting one candidate partition per available
     * processor.
     *
     * @param sampleRate maximum desired fraction of partitions searched per lookup
     * @param numFeatures dimension of the vectors being hashed
     */
    LocalitySensitiveHash(double sampleRate, int numFeatures) {
        this(sampleRate, numFeatures, Runtime.getRuntime().availableProcessors());
    }

    /**
     * Creates a hash using as few hash vectors as possible such that the fraction of partitions searched
     * per lookup is at most {@code sampleRate}.
     *
     * @param sampleRate maximum desired fraction of partitions searched per lookup
     * @param numFeatures dimension of the vectors being hashed
     * @param numCores target number of partitions to search per lookup; enough bits are allowed to differ
     *     to reach at least this many partitions, unless all partitions are already included
     */
    LocalitySensitiveHash(double sampleRate, int numFeatures, int numCores) {
        int numHashes = 0;
        int bitsDiffering = 0;
        for (; numHashes < MAX_HASHES; numHashes++) {
            // Let more bits differ until the number of partitions to try reaches numCores.
            // Exactly C(numHashes, k) indices differ from a given index in exactly k bits.
            bitsDiffering = 0;
            long numPartitionsToTry = 1;
            while (bitsDiffering < numHashes && numPartitionsToTry < numCores) {
                bitsDiffering++;
                numPartitionsToTry += CombinatoricsUtils.binomialCoefficient(numHashes, bitsDiffering);
            }
            // All partitions are searched yet there still are fewer than numCores: try more hashes
            if (bitsDiffering == numHashes && numPartitionsToTry < numCores) {
                continue;
            }
            // Accept the smallest hash count whose searched fraction of partitions is within sampleRate
            if (numPartitionsToTry <= sampleRate * (1 << numHashes)) {
                break;
            }
        }
        // NOTE(review): if no count below MAX_HASHES is accepted, the loop exits with
        // numHashes == MAX_HASHES while bitsDiffering still reflects MAX_HASHES - 1 hashes.

        log.info("LSH with {} hashes, querying partitions with up to {} bits differing", numHashes, bitsDiffering);
        this.maxBitsDiffering = bitsDiffering;
        // Must be float[numHashes][] (an array of vectors); the rows are filled by chooseHashVectors
        this.hashVectors = chooseHashVectors(numHashes, numFeatures);
        log.info("Chose {} random hash vectors", hashVectors.length);

        this.candidateIndicesPrototype = buildCandidateIndicesPrototype(numHashes);

        this.allIndices = new int[1 << numHashes];
        for (int i = 0; i < allIndices.length; i++) {
            allIndices[i] = i;
        }

        this.numCandidates = countIndicesWithin(numHashes, bitsDiffering);
    }

    /**
     * Picks {@code numHashes} random vectors of dimension {@code numFeatures}. Each is the candidate with
     * the lowest total absolute cosine similarity to the previously chosen vectors found by a random
     * search that stops after 1000 consecutive non-improving candidates, or at a score of exactly 0.
     */
    private static float[][] chooseHashVectors(int numHashes, int numFeatures) {
        float[][] vectors = new float[numHashes][];
        RandomGenerator random = RandomManager.getRandom();
        for (int i = 0; i < numHashes; i++) {
            double bestScore = Double.POSITIVE_INFINITY;
            float[] best = null;
            int candidatesSinceBest = 0;
            while (candidatesSinceBest < 1000) {
                float[] candidate = VectorMath.randomVectorF(numFeatures, random);
                double score = totalAbsCos(vectors, i, candidate);
                if (score < bestScore) {
                    best = candidate;
                    if (score == 0.0) {
                        // Orthogonal to every previous vector: cannot do better
                        break;
                    }
                    bestScore = score;
                    candidatesSinceBest = 0;
                } else {
                    candidatesSinceBest++;
                }
            }
            vectors[i] = best;
        }
        return vectors;
    }

    /** Returns all indices 0 .. 2^numHashes - 1 sorted (stably) by their number of set bits. */
    private static int[] buildCandidateIndicesPrototype(int numHashes) {
        int[] prototype = new int[1 << numHashes];
        // nextSlot[b] starts at the first position of the group of indices with b bits set
        int[] nextSlot = new int[numHashes + 1];
        for (int bits = 1; bits <= numHashes; bits++) {
            nextSlot[bits] = nextSlot[bits - 1] + (int) CombinatoricsUtils.binomialCoefficient(numHashes, bits - 1);
        }
        for (int index = 0; index < prototype.length; index++) {
            prototype[nextSlot[Integer.bitCount(index)]++] = index;
        }
        return prototype;
    }

    /** Returns the number of numHashes-bit indices within maxBits bits of a given index (maxBits <= numHashes). */
    private static int countIndicesWithin(int numHashes, int maxBits) {
        int count = 0;
        for (int bits = 0; bits <= maxBits; bits++) {
            count += (int) CombinatoricsUtils.binomialCoefficient(numHashes, bits);
        }
        return count;
    }

    /** @return number of hash vectors, i.e. bits in a partition index */
    int getNumHashes() {
        return hashVectors.length;
    }

    /** @return total number of partitions, 2^{@link #getNumHashes()} */
    int getNumPartitions() {
        return 1 << getNumHashes();
    }

    /** @return maximum number of bits in which a searched partition's index may differ from the query's */
    int getMaxBitsDiffering() {
        return maxBitsDiffering;
    }

    /**
     * @param vector vector to hash
     * @return partition index whose bit {@code i} is set iff {@code vector} has a positive dot product
     *     with hash vector {@code i}
     */
    int getIndexFor(float[] vector) {
        int index = 0;
        for (int i = 0; i < hashVectors.length; i++) {
            if (VectorMath.dot(hashVectors[i], vector) > 0.0) {
                index |= 1 << i;
            }
        }
        return index;
    }

    /**
     * @param vector query vector
     * @return indices of all partitions within {@link #getMaxBitsDiffering()} bits of the vector's own
     *     partition, which comes first. When every partition qualifies, the shared internal array of all
     *     indices is returned; callers should not modify the result.
     */
    int[] getCandidateIndices(float[] vector) {
        int mainIndex = getIndexFor(vector);
        int numHashes = getNumHashes();
        if (numHashes == maxBitsDiffering) {
            return allIndices;
        }
        if (maxBitsDiffering == 0) {
            return new int[] { mainIndex };
        }
        int[] candidates = new int[numCandidates];
        System.arraycopy(candidateIndicesPrototype, 0, candidates, 0, numCandidates);
        // XOR with a mask of k set bits yields an index differing from mainIndex in exactly k bits
        for (int i = 0; i < numCandidates; i++) {
            candidates[i] ^= mainIndex;
        }
        return candidates;
    }

    /**
     * @return sum of the absolute cosine similarities between {@code candidate} and the first
     *     {@code count} entries of {@code vectors}; 0 when {@code count} is 0
     */
    private static double totalAbsCos(float[][] vectors, int count, float[] candidate) {
        // Candidate's norm is computed once and reused for every comparison
        double candidateNorm = VectorMath.norm(candidate);
        double total = 0.0;
        for (int i = 0; i < count; i++) {
            total += Math.abs(VectorMath.cosineSimilarity(vectors[i], candidate, candidateNorm));
        }
        return total;
    }
}
