package com.cloudera.oryx.app.serving.als.model;

import com.cloudera.oryx.app.als.RescorerProvider;
import com.cloudera.oryx.app.serving.als.CosineAverageFunction;
import com.cloudera.oryx.app.serving.als.DotsFunction;
import com.cloudera.oryx.common.OryxTest;
import com.cloudera.oryx.common.collection.Pair;
import com.cloudera.oryx.common.math.VectorMath;
import com.cloudera.oryx.common.random.RandomManager;
import com.koloboke.function.ObjDoubleToDoubleFunction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.apache.commons.math3.distribution.PoissonDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/cloudera/oryx/app/serving/als/model/ALSServingModelTest.class */
/**
 * Tests for {@link ALSServingModel}. Covered behavior:
 * <ul>
 *   <li>user/item vector storage;</li>
 *   <li>known-item and count bookkeeping;</li>
 *   <li>retention of recently updated IDs;</li>
 *   <li>{@code toString()};</li>
 *   <li>how much an LSH-configured model's top-N results agree with an exhaustive one;</li>
 *   <li>the fraction-loaded metric.</li>
 * </ul>
 */
public final class ALSServingModelTest extends OryxTest {

    private static final Logger log = LoggerFactory.getLogger(ALSServingModelTest.class);

    @Test
    public void testUserItemVector() {
        ALSServingModel model = newImplicitModel(2, 1.0);
        assertEquals(2, model.getFeatures());
        assertTrue(model.isImplicit());
        assertNull(model.getRescorerProvider());
        model.setUserVector("U1", new float[] {1.5f, -2.5f});
        assertArrayEquals(new float[] {1.5f, -2.5f}, model.getUserVector("U1"));
        model.setItemVector("I0", new float[] {0.5f, 0.0f});
        assertArrayEquals(new float[] {0.5f, 0.0f}, model.getItemVector("I0"));
        assertContainsSame(Arrays.asList("U1"), model.getAllUserIDs());
        assertContainsSame(Arrays.asList("I0"), model.getAllItemIDs());
    }

    @Test
    public void testKnownItems() {
        ALSServingModel model = newImplicitModel(2, 1.0);
        populateKnownItems(model);
        // User Ui knows items I(i-1), Ii, I(i+1), clipped to the range 0..9.
        assertContainsSame(Arrays.asList("I0", "I1"), model.getKnownItems("U0"));
        assertContainsSame(Arrays.asList("I0", "I1", "I2"), model.getKnownItems("U1"));
        assertContainsSame(Arrays.asList("I8", "I9"), model.getKnownItems("U9"));

        Map<String,Integer> userCounts = model.getUserCounts();
        assertEquals(2, userCounts.get("U0").intValue());
        assertEquals(3, userCounts.get("U1").intValue());
        assertEquals(2, userCounts.get("U9").intValue());

        // Item counts mirror the band structure: edge items are known by 2 users, inner ones by 3.
        Map<String,Integer> itemCounts = model.getItemCounts();
        assertEquals(2, itemCounts.get("I0").intValue());
        assertEquals(3, itemCounts.get("I1").intValue());
        assertEquals(2, itemCounts.get("I9").intValue());
    }

    @Test
    public void testRetainUsersItems() {
        ALSServingModel model = newImplicitModel(2, 1.0);

        // A freshly set vector survives one retain call that does not list its ID.
        // It is removed by the next such call.
        model.setUserVector("U0", new float[] {1.0f, 1.0f});
        model.retainRecentAndUserIDs(Collections.emptyList());
        assertNotNull(model.getUserVector("U0"));
        model.retainRecentAndUserIDs(Collections.emptyList());
        assertNull(model.getUserVector("U0"));

        // An explicitly listed ID survives repeated retain calls.
        model.setUserVector("U0", new float[] {1.0f, 1.0f});
        model.retainRecentAndUserIDs(Arrays.asList("U0"));
        assertNotNull(model.getUserVector("U0"));
        model.retainRecentAndUserIDs(Arrays.asList("U0"));
        assertNotNull(model.getUserVector("U0"));

        // Same expectations on the item side.
        model.setItemVector("I0", new float[] {1.0f, 1.0f});
        model.retainRecentAndItemIDs(Collections.emptyList());
        assertNotNull(model.getItemVector("I0"));
        model.retainRecentAndItemIDs(Collections.emptyList());
        assertNull(model.getItemVector("I0"));

        model.setItemVector("I0", new float[] {1.0f, 1.0f});
        model.retainRecentAndItemIDs(Arrays.asList("I0"));
        assertNotNull(model.getItemVector("I0"));
        model.retainRecentAndItemIDs(Arrays.asList("I0"));
        assertNotNull(model.getItemVector("I0"));
    }

    @Test
    public void testRetainKnown() {
        ALSServingModel model = newImplicitModel(2, 1.0);
        populateKnownItems(model);
        for (int i = 0; i < 10; i++) {
            model.setUserVector("U" + i, new float[] {0.0f, 0.0f});
            model.setItemVector("I" + i, new float[] {0.0f, 0.0f});
        }

        // Every user and item vector was just set. Known items for unlisted users (U2, U3)
        // and unlisted items (I7) are still kept after this call.
        model.retainRecentAndKnownItems(Arrays.asList("U4", "U5", "U6"), Arrays.asList("I4", "I5", "I6"));
        assertContains(model.getKnownItems("U3"), "I4");
        assertContains(model.getKnownItems("U4"), "I4");
        assertContains(model.getKnownItems("U6"), "I6");
        assertContains(model.getKnownItems("U6"), "I7");
        assertContains(model.getKnownItems("U2"), "I2");

        // Presumably these empty retain calls reset recency (compare testRetainUsersItems).
        // Afterwards, only listed users and listed items keep their known-item entries.
        model.retainRecentAndUserIDs(Collections.emptyList());
        model.retainRecentAndItemIDs(Collections.emptyList());
        model.retainRecentAndKnownItems(Arrays.asList("U4", "U5", "U6"), Arrays.asList("I4", "I5", "I6"));
        assertEquals(0, model.getKnownItems("U3").size());
        assertContains(model.getKnownItems("U4"), "I4");
        assertContains(model.getKnownItems("U6"), "I6");
        assertNotContains(model.getKnownItems("U6"), "I7");
        assertEquals(0, model.getKnownItems("U2").size());
    }

    @Test
    public void testToString() {
        String description = newImplicitModel(2, 1.0).toString();
        assertContains(description, "ALSServingModel");
        assertContains(description, "features:2");
        assertContains(description, "implicit:true");
    }

    /**
     * Links users U0..U9 with items I0..I9 in a band pattern. Each user Ui gets every item Ij
     * with {@code |i - j| <= 1}, added one at a time in ascending j order.
     *
     * @param model model to populate
     */
    private static void populateKnownItems(ALSServingModel model) {
        for (int user = 0; user < 10; user++) {
            String userID = "U" + user;
            for (int item = Math.max(0, user - 1); item <= Math.min(9, user + 1); item++) {
                model.addKnownItems(userID, Collections.singleton("I" + item));
            }
        }
    }

    @Test
    public void testLSHEffect() {
        RandomGenerator random = RandomManager.getRandom();
        PoissonDistribution itemsPerUserDist = new PoissonDistribution(
            random,
            20.0,
            PoissonDistribution.DEFAULT_EPSILON,
            PoissonDistribution.DEFAULT_MAX_ITERATIONS);
        int features = 20;
        int userItemCount = 20000;
        int numRecs = 10;

        // Both models receive identical data. Only the third constructor argument differs,
        // presumably the LSH sample rate (1.0 = exhaustive, 0.5 = sampled). TODO confirm.
        ALSServingModel mainModel = newImplicitModel(features, 1.0);
        ALSServingModel lshModel = newImplicitModel(features, 0.5);

        for (int user = 0; user < userItemCount; user++) {
            String userID = "U" + user;
            float[] vector = VectorMath.randomVectorF(features, random);
            mainModel.setUserVector(userID, vector);
            lshModel.setUserVector(userID, vector);
            int itemsPerUser = itemsPerUserDist.sample();
            List<String> knownIDs = new ArrayList<>(itemsPerUser);
            for (int i = 0; i < itemsPerUser; i++) {
                knownIDs.add("I" + random.nextInt(userItemCount));
            }
            mainModel.addKnownItems(userID, knownIDs);
            lshModel.addKnownItems(userID, knownIDs);
        }

        for (int item = 0; item < userItemCount; item++) {
            String itemID = "I" + item;
            float[] vector = VectorMath.randomVectorF(features, random);
            mainModel.setItemVector(itemID, vector);
            lshModel.setItemVector(itemID, vector);
        }

        // For each user, measure how many of the LSH model's top recommendations match
        // the exhaustive model's recommendations, in order.
        Mean meanMatchLength = new Mean();
        for (int user = 0; user < userItemCount; user++) {
            String userID = "U" + user;
            List<Pair<String,Double>> mainRecs = mainModel.topN(
                new DotsFunction(mainModel.getUserVector(userID)), null, numRecs, null)
                .collect(Collectors.toList());
            List<Pair<String,Double>> lshRecs = lshModel.topN(
                new DotsFunction(lshModel.getUserVector(userID)), null, numRecs, null)
                .collect(Collectors.toList());
            meanMatchLength.increment(matchingPrefixLength(lshRecs, mainRecs));
        }
        log.info("Mean matching prefix: {}", meanMatchLength.getResult());
        assertGreaterOrEqual(meanMatchLength.getResult(), 4.0);

        // Same measurement for item-to-item similarity.
        meanMatchLength.clear();
        for (int item = 0; item < userItemCount; item++) {
            String itemID = "I" + item;
            List<Pair<String,Double>> mainRecs = mainModel.topN(
                new CosineAverageFunction(mainModel.getItemVector(itemID)), null, numRecs, null)
                .collect(Collectors.toList());
            List<Pair<String,Double>> lshRecs = lshModel.topN(
                new CosineAverageFunction(lshModel.getItemVector(itemID)), null, numRecs, null)
                .collect(Collectors.toList());
            meanMatchLength.increment(matchingPrefixLength(lshRecs, mainRecs));
        }
        log.info("Mean matching prefix: {}", meanMatchLength.getResult());
        assertGreaterOrEqual(meanMatchLength.getResult(), 5.0);
    }

    /**
     * Counts how many leading positions hold equal elements in both lists.
     *
     * @param first list whose elements' {@code equals} is invoked
     * @param second list compared against
     * @return length of the common prefix, at most {@code min(first.size(), second.size())}
     */
    private static int matchingPrefixLength(List<?> first, List<?> second) {
        int limit = Math.min(first.size(), second.size());
        int length = 0;
        while (length < limit && first.get(length).equals(second.get(length))) {
            length++;
        }
        return length;
    }

    @Test
    public void testFractionLoaded() {
        assertEquals(1.0f, newImplicitModel(2, 1.0).getFractionLoaded());
        ALSServingModel model = newImplicitModel(2, 1.0);
        assertNotNull(model.toString());
        // After U1 and I0 are declared via the retain calls, 0.0 is reported while neither
        // has a vector. Each loaded vector then adds half.
        model.retainRecentAndUserIDs(Collections.singleton("U1"));
        model.retainRecentAndItemIDs(Collections.singleton("I0"));
        assertEquals(0.0f, model.getFractionLoaded());
        model.setUserVector("U1", new float[] {1.5f, -2.5f});
        assertEquals(0.5f, model.getFractionLoaded());
        model.setItemVector("I0", new float[] {0.5f, 0.0f});
        assertEquals(1.0f, model.getFractionLoaded());
    }

    /**
     * Creates an implicit-feedback model with no rescorer provider.
     *
     * @param features number of latent features
     * @param sampleRate third constructor argument; every test here passes 1.0, and the LSH
     *     test also passes 0.5 (presumably the LSH sample rate)
     * @return a new, empty model
     */
    private static ALSServingModel newImplicitModel(int features, double sampleRate) {
        return new ALSServingModel(features, true, sampleRate, null);
    }
}
