package com.cloudera.oryx.app.serving.als.model;

import com.cloudera.oryx.api.serving.ServingModel;
import com.cloudera.oryx.app.als.FeatureVectorsPartition;
import com.cloudera.oryx.app.als.PartitionedFeatureVectors;
import com.cloudera.oryx.app.als.RescorerProvider;
import com.cloudera.oryx.app.als.SolverCache;
import com.cloudera.oryx.app.serving.als.CosineDistanceSensitiveFunction;
import com.cloudera.oryx.common.collection.Pair;
import com.cloudera.oryx.common.collection.Pairs;
import com.cloudera.oryx.common.lang.AutoLock;
import com.cloudera.oryx.common.lang.AutoReadWriteLock;
import com.cloudera.oryx.common.math.Solver;
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.koloboke.collect.ObjCursor;
import com.koloboke.collect.map.ObjObjMap;
import com.koloboke.collect.map.hash.HashObjIntMap;
import com.koloboke.collect.map.hash.HashObjIntMaps;
import com.koloboke.collect.map.hash.HashObjObjMaps;
import com.koloboke.collect.set.ObjSet;
import com.koloboke.collect.set.hash.HashObjSet;
import com.koloboke.collect.set.hash.HashObjSets;
import com.koloboke.function.ObjDoubleToDoubleFunction;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Predicate;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * In-memory serving model for an ALS matrix factorization. It holds:
 * <ul>
 *   <li>user feature vectors ({@code X});</li>
 *   <li>item feature vectors ({@code Y}), partitioned by a {@link LocalitySensitiveHash};</li>
 *   <li>the set of items known for each user;</li>
 *   <li>bookkeeping of user/item IDs that were named in the last "retain" call but do not yet
 *       have a vector, used by {@link #getFractionLoaded()}.</li>
 * </ul>
 *
 * <p>Locking: the structure of {@code knownItems} is guarded by {@code knownItemsLock}, while the
 * contents of each per-user set are guarded by synchronizing on that set. Each expected-ID set has
 * its own read/write lock. Every lock is acquired with try-with-resources so it is released exactly
 * once, including when the guarded code throws.
 */
public final class ALSServingModel implements ServingModel {

    private static final Logger log = LoggerFactory.getLogger(ALSServingModel.class);

    /** Daemon thread pool shared by all instances; passed to the item vectors and the YTY solver cache. */
    private static final ExecutorService executor = Executors.newFixedThreadPool(
        Runtime.getRuntime().availableProcessors(),
        new ThreadFactoryBuilder().setDaemon(true).setNameFormat("ALSServingModel-%d").build());

    /** Hash used to assign item vectors to partitions and to choose candidate partitions in topN. */
    private final LocalitySensitiveHash lsh;
    /** User ID to user feature vector. */
    private final FeatureVectorsPartition X;
    /** Item ID to item feature vector, partitioned by {@link #lsh}. */
    private final PartitionedFeatureVectors Y;
    /** User ID to the item IDs known for that user; each value set is guarded by its own monitor. */
    private final ObjObjMap<String, ObjSet<String>> knownItems;
    /** Guards the key/value structure of {@link #knownItems}, not the contents of its sets. */
    private final AutoReadWriteLock knownItemsLock;
    /** User IDs named in the last retain call that have not yet received a vector. */
    private final ObjSet<String> expectedUserIDs;
    /** Guards {@link #expectedUserIDs}. */
    private final AutoReadWriteLock expectedUserIDsLock;
    /** Item IDs named in the last retain call that have not yet received a vector. */
    private final ObjSet<String> expectedItemIDs;
    /** Guards {@link #expectedItemIDs}. */
    private final AutoReadWriteLock expectedItemIDsLock;
    /** Cached solver built over {@link #Y}; marked dirty whenever an item vector is set. */
    private final SolverCache cachedYTYSolver;
    private final int features;
    private final boolean implicit;
    private final RescorerProvider rescorerProvider;

    /**
     * Creates an empty model.
     *
     * <p>NOTE(review): the decompiler reports this was package-private; it is kept public so
     * existing callers keep compiling.
     *
     * @param features number of latent features per vector; must be positive
     * @param implicit whether the model is an implicit-feedback model (reported by {@link #isImplicit()})
     * @param sampleRate first argument to {@link LocalitySensitiveHash} (presumably its sample rate);
     *     must be in (0, 1]
     * @param rescorerProvider value returned by {@link #getRescorerProvider()}
     * @throws IllegalArgumentException if {@code features} or {@code sampleRate} is out of range
     */
    public ALSServingModel(int features, boolean implicit, double sampleRate, RescorerProvider rescorerProvider) {
        Preconditions.checkArgument(features > 0, "features must be positive: %s", features);
        // Written so that NaN is rejected as well.
        Preconditions.checkArgument(sampleRate > 0.0 && sampleRate <= 1.0,
                                    "sampleRate must be in (0,1]: %s", sampleRate);
        LocalitySensitiveHash hash = new LocalitySensitiveHash(sampleRate, features);
        this.lsh = hash;
        this.X = new FeatureVectorsPartition();
        // Each item vector is routed to the partition the LSH assigns to it.
        this.Y = new PartitionedFeatureVectors(hash.getNumPartitions(), executor,
                                               (id, vector) -> hash.getIndexFor(vector));
        this.knownItems = HashObjObjMaps.newMutableMap();
        this.knownItemsLock = new AutoReadWriteLock();
        this.expectedUserIDs = HashObjSets.newMutableSet();
        this.expectedUserIDsLock = new AutoReadWriteLock();
        this.expectedItemIDs = HashObjSets.newMutableSet();
        this.expectedItemIDsLock = new AutoReadWriteLock();
        this.cachedYTYSolver = new SolverCache(executor, this.Y);
        this.features = features;
        this.implicit = implicit;
        this.rescorerProvider = rescorerProvider;
    }

    /** @return number of latent features per vector */
    public int getFeatures() {
        return features;
    }

    /** @return the {@code implicit} flag given at construction */
    public boolean isImplicit() {
        return implicit;
    }

    /** @return the rescorer provider given at construction */
    public RescorerProvider getRescorerProvider() {
        return rescorerProvider;
    }

    /** @return the stored vector for {@code user}, as returned by the user partition */
    public float[] getUserVector(String user) {
        return X.getVector(user);
    }

    /** @return the stored vector for {@code item}, as returned by the item partitions */
    public float[] getItemVector(String item) {
        return Y.getVector(item);
    }

    /**
     * Stores a user vector and removes the user from the expected-but-missing set.
     *
     * @throws IllegalArgumentException if {@code vector.length != getFeatures()}
     */
    public void setUserVector(String user, float[] vector) {
        Preconditions.checkArgument(vector.length == features,
                                    "vector length %s != features %s", vector.length, features);
        X.setVector(user, vector);
        try (AutoLock al = expectedUserIDsLock.autoWriteLock()) {
            expectedUserIDs.remove(user);
        }
    }

    /**
     * Stores an item vector, removes the item from the expected-but-missing set, and marks the
     * cached YTY solver dirty.
     *
     * @throws IllegalArgumentException if {@code vector.length != getFeatures()}
     */
    public void setItemVector(String item, float[] vector) {
        Preconditions.checkArgument(vector.length == features,
                                    "vector length %s != features %s", vector.length, features);
        Y.setVector(item, vector);
        try (AutoLock al = expectedItemIDsLock.autoWriteLock()) {
            expectedItemIDs.remove(item);
        }
        // Done outside the lock; a failure here no longer causes the lock to be released twice.
        cachedYTYSolver.setDirty();
    }

    /**
     * @return an immutable snapshot of the item IDs recorded for {@code user} via
     *     {@link #addKnownItems}, or an empty set if there are none
     */
    public Set<String> getKnownItems(String user) {
        ObjSet<String> knownItemsForUser = doGetKnownItems(user);
        if (knownItemsForUser == null) {
            return Collections.emptySet();
        }
        synchronized (knownItemsForUser) {
            if (knownItemsForUser.isEmpty()) {
                return Collections.emptySet();
            }
            return HashObjSets.newImmutableSet(knownItemsForUser);
        }
    }

    /**
     * @return the live (mutable, shared) known-item set for {@code user}, or null if none exists;
     *     callers must synchronize on it before reading or writing it
     */
    private ObjSet<String> doGetKnownItems(String user) {
        try (AutoLock al = knownItemsLock.autoReadLock()) {
            return knownItems.get(user);
        }
    }

    /** @return map of user ID to the number of known items recorded for that user */
    public Map<String, Integer> getUserCounts() {
        HashObjIntMap<String> counts = HashObjIntMaps.newUpdatableMap();
        try (AutoLock al = knownItemsLock.autoReadLock()) {
            knownItems.forEach((user, knownItemsForUser) -> {
                int count;
                synchronized (knownItemsForUser) {
                    count = knownItemsForUser.size();
                }
                counts.addValue(user, count);
            });
        }
        return counts;
    }

    /** @return map of item ID to the number of users whose known-item set contains it */
    public Map<String, Integer> getItemCounts() {
        HashObjIntMap<String> counts = HashObjIntMaps.newUpdatableMap();
        try (AutoLock al = knownItemsLock.autoReadLock()) {
            knownItems.values().forEach(knownItemsForUser -> {
                synchronized (knownItemsForUser) {
                    knownItemsForUser.forEach(item -> counts.addValue(item, 1));
                }
            });
        }
        return counts;
    }

    /**
     * Adds {@code items} to the known-item set of {@code user}, creating the set if needed.
     * Does nothing if {@code items} is empty.
     */
    public void addKnownItems(String user, Collection<String> items) {
        if (items.isEmpty()) {
            return;
        }
        // Fast path under the read lock; only take the write lock when the set must be created.
        ObjSet<String> knownItemsForUser = doGetKnownItems(user);
        if (knownItemsForUser == null) {
            try (AutoLock al = knownItemsLock.autoWriteLock()) {
                // computeIfAbsent covers the race where another thread created it in between.
                knownItemsForUser = knownItems.computeIfAbsent(user, k -> HashObjSets.newMutableSet());
            }
        }
        synchronized (knownItemsForUser) {
            knownItemsForUser.addAll(items);
        }
    }

    /**
     * @return (item ID, item vector) pairs for the user's known items that currently have a
     *     vector, or null if the user has no vector, no known items, or none of them has a vector
     */
    public List<Pair<String, float[]>> getKnownItemVectorsForUser(String user) {
        if (getUserVector(user) == null) {
            return null;
        }
        ObjSet<String> knownItemsForUser = doGetKnownItems(user);
        if (knownItemsForUser == null) {
            return null;
        }
        synchronized (knownItemsForUser) {
            int size = knownItemsForUser.size();
            if (size == 0) {
                return null;
            }
            List<Pair<String, float[]>> idVectors = new ArrayList<>(size);
            for (String item : knownItemsForUser) {
                float[] vector = getItemVector(item);
                if (vector != null) {
                    idVectors.add(new Pair<>(item, vector));
                }
            }
            return idVectors.isEmpty() ? null : idVectors;
        }
    }

    /**
     * Runs a {@link TopNConsumer} over each item partition the LSH selects as a candidate for
     * {@code scoreFn}'s target vector, then merges the per-partition results.
     *
     * @param scoreFn scoring function; its target vector selects the candidate partitions
     * @param rescoreFn passed through to {@link TopNConsumer}
     * @param howMany maximum number of results
     * @param allowedFn passed through to {@link TopNConsumer}
     * @return at most {@code howMany} pairs, sorted by their second element, descending
     */
    public Stream<Pair<String, Double>> topN(CosineDistanceSensitiveFunction scoreFn,
                                             ObjDoubleToDoubleFunction<String> rescoreFn,
                                             int howMany,
                                             Predicate<String> allowedFn) {
        return Y.mapPartitionsParallel(partition -> {
            TopNConsumer consumer = new TopNConsumer(howMany, scoreFn, rescoreFn, allowedFn);
            partition.forEach(consumer);
            return consumer.getTopN();
        }, lsh.getCandidateIndices(scoreFn.getTargetVector()), false)
            .sorted(Pairs.orderBySecond(Pairs.SortOrder.DESCENDING))
            .limit(howMany);
    }

    /** @return a new mutable set containing every user ID that has a vector */
    public Collection<String> getAllUserIDs() {
        HashObjSet<String> ids = HashObjSets.newMutableSet();
        X.addAllIDsTo(ids);
        return ids;
    }

    /** @return a new mutable set containing every item ID that has a vector */
    public Collection<String> getAllItemIDs() {
        HashObjSet<String> ids = HashObjSets.newMutableSet();
        Y.addAllIDsTo(ids);
        return ids;
    }

    /** @return the cached YTY solver, via {@code SolverCache.get(true)} */
    public Solver getYTYSolver() {
        return cachedYTYSolver.get(true);
    }

    /** Asks the YTY solver cache to compute its solver now. */
    public void precomputeSolvers() {
        cachedYTYSolver.compute();
    }

    /**
     * Drops user vectors except recent ones and those in {@code users}, then records which of
     * {@code users} still lack a vector.
     */
    public void retainRecentAndUserIDs(Collection<String> users) {
        X.retainRecentAndIDs(users);
        try (AutoLock al = expectedUserIDsLock.autoWriteLock()) {
            expectedUserIDs.clear();
            expectedUserIDs.addAll(users);
            X.removeAllIDsFrom(expectedUserIDs);
        }
    }

    /**
     * Drops item vectors except recent ones and those in {@code items}, then records which of
     * {@code items} still lack a vector.
     */
    public void retainRecentAndItemIDs(Collection<String> items) {
        Y.retainRecentAndIDs(items);
        try (AutoLock al = expectedItemIDsLock.autoWriteLock()) {
            expectedItemIDs.clear();
            expectedItemIDs.addAll(items);
            Y.removeAllIDsFrom(expectedItemIDs);
        }
    }

    /**
     * Prunes known-item data:
     * <ol>
     *   <li>removes the entries of users that are neither in {@code users} nor recent in X;</li>
     *   <li>removes, from each remaining set, items that are neither in {@code items} nor recent
     *       in Y. Non-String elements are logged and removed as well.</li>
     * </ol>
     */
    public void retainRecentAndKnownItems(Collection<String> users, Collection<String> items) {
        HashObjSet<String> recentUserIDs = HashObjSets.newMutableSet();
        X.addAllRecentTo(recentUserIDs);
        try (AutoLock al = knownItemsLock.autoWriteLock()) {
            knownItems.removeIf((user, ignored) -> !users.contains(user) && !recentUserIDs.contains(user));
        }

        // Snapshot recent item IDs up front so the predicate below needs no further locking.
        HashObjSet<String> recentItemIDs = HashObjSets.newMutableSet();
        Y.addAllRecentTo(recentItemIDs);
        Predicate<String> notKeptOrRecent = item -> !items.contains(item) && !recentItemIDs.contains(item);

        // Only the contents of the sets change here, so the map's read lock suffices; each set
        // is mutated under its own monitor.
        try (AutoLock al = knownItemsLock.autoReadLock()) {
            knownItems.values().forEach(knownItemsForUser -> {
                synchronized (knownItemsForUser) {
                    ObjCursor<String> cursor = knownItemsForUser.cursor();
                    while (cursor.moveNext()) {
                        // Read as Object so a non-String element is reported rather than
                        // triggering a ClassCastException.
                        Object elem = cursor.elem();
                        if (!(elem instanceof String)) {
                            log.warn("Found non-String collection: {}", elem);
                            cursor.remove();
                        } else if (notKeptOrRecent.test((String) elem)) {
                            cursor.remove();
                        }
                    }
                }
            });
        }
    }

    /** @return number of user vectors */
    public int getNumUsers() {
        return X.size();
    }

    /** @return number of item vectors */
    public int getNumItems() {
        return Y.size();
    }

    /**
     * @return loaded / (loaded + expected), where "loaded" is the number of user and item vectors
     *     and "expected" is the number of IDs recorded by the retain calls that still lack a
     *     vector; returns 1.0 when nothing is expected
     */
    public float getFractionLoaded() {
        int expected = 0;
        try (AutoLock al = expectedUserIDsLock.autoReadLock()) {
            expected += expectedUserIDs.size();
        }
        try (AutoLock al = expectedItemIDsLock.autoReadLock()) {
            expected += expectedItemIDs.size();
        }
        if (expected == 0) {
            return 1.0f;
        }
        float loaded = getNumUsers() + getNumItems();
        return loaded / (loaded + expected);
    }

    @Override
    public String toString() {
        return "ALSServingModel[features:" + features + ", implicit:" + implicit
            + ", X:(" + getNumUsers() + " users), Y:(" + getNumItems() + " items, partitions: " + Y
            + "...), fractionLoaded:" + getFractionLoaded() + "]";
    }
}
