package com.cloudera.oryx.app.als;

import com.cloudera.oryx.common.OryxTest;
import com.cloudera.oryx.common.lang.ExecUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Test;

/* loaded from: input_file:com/cloudera/oryx/app/als/FeatureVectorsPartitionTest.class */
public final class FeatureVectorsPartitionTest extends OryxTest {
    @Test
    public void testGetSet() {
        FeatureVectorsPartition featureVectorsPartition = new FeatureVectorsPartition();
        assertEquals(0L, featureVectorsPartition.size());
        featureVectorsPartition.setVector("foo", new float[]{1.0f});
        assertEquals(1L, featureVectorsPartition.size());
        assertArrayEquals(new float[]{1.0f}, featureVectorsPartition.getVector("foo"));
        featureVectorsPartition.removeVector("foo");
        assertEquals(0L, featureVectorsPartition.size());
        assertNull(featureVectorsPartition.getVector("foo"));
    }

    @Test
    public void testVTV() {
        FeatureVectorsPartition featureVectorsPartition = new FeatureVectorsPartition();
        featureVectorsPartition.setVector("foo", new float[]{1.0f, 2.0f, 4.0f});
        featureVectorsPartition.setVector("bar", new float[]{1.5f, -1.0f, 0.0f});
        double[] dArr = {3.25d, 0.5d, 4.0d, 5.0d, 8.0d, 16.0d};
        assertArrayEquals(dArr, featureVectorsPartition.getVTV(false));
        assertArrayEquals(dArr, featureVectorsPartition.getVTV(true));
    }

    @Test
    public void testForEach() {
        FeatureVectorsPartition featureVectorsPartition = new FeatureVectorsPartition();
        featureVectorsPartition.setVector("foo", new float[]{1.0f, 2.0f, 4.0f});
        featureVectorsPartition.setVector("bar", new float[]{1.5f, -1.0f, 0.0f});
        ArrayList arrayList = new ArrayList();
        featureVectorsPartition.forEach((str, fArr) -> {
            arrayList.add(str + fArr[0]);
        });
        assertEquals(featureVectorsPartition.size(), arrayList.size());
        assertContains(arrayList, "foo1.0");
        assertContains(arrayList, "bar1.5");
    }

    @Test
    public void testRetainRecent() {
        FeatureVectorsPartition featureVectorsPartition = new FeatureVectorsPartition();
        featureVectorsPartition.setVector("foo", new float[]{1.0f});
        featureVectorsPartition.retainRecentAndIDs(Collections.singleton("foo"));
        assertEquals(1L, featureVectorsPartition.size());
        featureVectorsPartition.retainRecentAndIDs(Collections.singleton("bar"));
        assertEquals(0L, featureVectorsPartition.size());
    }

    @Test
    public void testAllIDs() {
        FeatureVectorsPartition featureVectorsPartition = new FeatureVectorsPartition();
        featureVectorsPartition.setVector("foo", new float[]{1.0f});
        HashSet hashSet = new HashSet();
        featureVectorsPartition.addAllIDsTo(hashSet);
        assertEquals(Collections.singleton("foo"), hashSet);
        featureVectorsPartition.removeAllIDsFrom(hashSet);
        assertEquals(0L, hashSet.size());
    }

    @Test
    public void testRecent() {
        FeatureVectorsPartition featureVectorsPartition = new FeatureVectorsPartition();
        featureVectorsPartition.setVector("foo", new float[]{1.0f});
        HashSet hashSet = new HashSet();
        featureVectorsPartition.addAllRecentTo(hashSet);
        assertEquals(Collections.singleton("foo"), hashSet);
        featureVectorsPartition.retainRecentAndIDs(Collections.singleton("foo"));
        hashSet.clear();
        featureVectorsPartition.addAllRecentTo(hashSet);
        assertEquals(0L, hashSet.size());
    }

    @Test
    public void testConcurrent() throws Exception {
        FeatureVectorsPartition featureVectorsPartition = new FeatureVectorsPartition();
        AtomicInteger atomicInteger = new AtomicInteger();
        int i = 10000;
        ExecUtils.doInParallel(16, num -> {
            for (int i2 = 0; i2 < i; i2++) {
                int andIncrement = atomicInteger.getAndIncrement();
                featureVectorsPartition.setVector(Integer.toString(andIncrement), new float[]{andIncrement});
            }
        });
        assertEquals(10000 * 16, featureVectorsPartition.size());
        assertEquals(10000 * 16, atomicInteger.get());
        ExecUtils.doInParallel(16, num2 -> {
            for (int i2 = 0; i2 < i; i2++) {
                featureVectorsPartition.removeVector(Integer.toString(atomicInteger.decrementAndGet()));
            }
        });
        assertEquals(0L, featureVectorsPartition.size());
    }

    @Test
    public void testToString() {
        FeatureVectorsPartition featureVectorsPartition = new FeatureVectorsPartition();
        featureVectorsPartition.setVector("A", new float[]{1.0f, 3.0f, 6.0f});
        assertEquals("FeatureVectors[size:1]", featureVectorsPartition.toString());
    }
}
