package com.cloudera.oryx.app.als;

import com.cloudera.oryx.common.lang.AutoLock;
import com.cloudera.oryx.common.lang.AutoReadWriteLock;
import com.cloudera.oryx.common.lang.LoggingCallable;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.function.Function;
import java.util.function.ToIntBiFunction;
import java.util.stream.Stream;
import org.eclipse.collections.impl.map.mutable.primitive.ObjectIntHashMap;

/* loaded from: input_file:com/cloudera/oryx/app/als/PartitionedFeatureVectors.class */
/**
 * {@link FeatureVectors} implementation that spreads vectors over a fixed number of
 * {@link FeatureVectorsPartition}s. A partitioner function picks the partition for each ID,
 * and an ID-to-partition-index map, guarded by a read/write lock, records which partition
 * currently holds each ID so that single-ID lookups touch only one partition.
 */
public final class PartitionedFeatureVectors implements FeatureVectors {

    /** Value returned by the partition map for IDs it does not contain; any negative works. */
    private static final int NO_PARTITION = Integer.MIN_VALUE;
    /** Maximum number of "index:size" entries listed by {@link #toString()}. */
    private static final int MAX_TO_STRING_PARTITIONS = 16;

    private final FeatureVectorsPartition[] partitions;
    private final ToIntBiFunction<String, float[]> partitioner;
    /** ID -> index into {@link #partitions}; read and written only under {@link #partitionMapLock}. */
    private final ObjectIntHashMap<String> partitionMap;
    private final AutoReadWriteLock partitionMapLock;
    private final ExecutorService executor;

    /**
     * Creates an instance that assigns each ID to a partition by its (non-negative) hash code
     * modulo the number of partitions.
     *
     * @param numPartitions number of partitions; must be positive
     * @param executor executor used by {@link #mapPartitionsParallel} for multi-partition work
     * @throws IllegalArgumentException if {@code numPartitions <= 0}
     * @throws NullPointerException if {@code executor} is null
     */
    public PartitionedFeatureVectors(int numPartitions, ExecutorService executor) {
        this(numPartitions, executor,
             (id, vector) -> (id.hashCode() & Integer.MAX_VALUE) % numPartitions);
    }

    /**
     * Creates an instance with a custom partitioner.
     *
     * @param numPartitions number of partitions; must be positive
     * @param executor executor used by {@link #mapPartitionsParallel} for multi-partition work
     * @param partitioner maps (ID, vector) to a partition index, which must lie in
     *  {@code [0, numPartitions)}; {@link #setVector} rejects any other value
     * @throws IllegalArgumentException if {@code numPartitions <= 0}
     * @throws NullPointerException if {@code executor} or {@code partitioner} is null
     */
    public PartitionedFeatureVectors(int numPartitions,
                                     ExecutorService executor,
                                     ToIntBiFunction<String, float[]> partitioner) {
        Preconditions.checkArgument(numPartitions > 0,
                                    "numPartitions must be positive: %s", numPartitions);
        Objects.requireNonNull(executor, "executor");
        Objects.requireNonNull(partitioner, "partitioner");
        this.partitions = new FeatureVectorsPartition[numPartitions];
        for (int i = 0; i < numPartitions; i++) {
            this.partitions[i] = new FeatureVectorsPartition();
        }
        this.partitionMap = ObjectIntHashMap.newMap();
        this.partitionMapLock = new AutoReadWriteLock();
        this.partitioner = partitioner;
        this.executor = executor;
    }

    /**
     * @return total number of vectors, summed over all partitions
     */
    @Override
    public int size() {
        int total = 0;
        for (FeatureVectorsPartition partition : partitions) {
            total += partition.size();
        }
        return total;
    }

    /**
     * Applies {@code fn} to every non-empty partition and concatenates the results.
     *
     * @param fn function producing a stream of results from one partition
     * @param background if true, partitions are processed one after another on the calling
     *  thread; otherwise they are submitted to the executor
     * @param <T> result element type
     * @return concatenation of the per-partition results (empty if no partition has vectors)
     * @see #mapPartitionsParallel(Function, int[], boolean)
     */
    public <T> Stream<T> mapPartitionsParallel(Function<FeatureVectorsPartition, Stream<T>> fn,
                                               boolean background) {
        return mapPartitionsParallel(fn, null, background);
    }

    /**
     * Applies {@code fn} to the selected non-empty partitions and concatenates the results.
     * Only the work done inside {@code fn.apply} runs on the executor. Operations left lazy in
     * the returned per-partition streams are evaluated later by whoever consumes the result.
     *
     * @param fn function producing a stream of results from one partition
     * @param candidateIndices indices of partitions to consider, or {@code null} for all
     * @param background if true, partitions are processed one after another on the calling
     *  thread; otherwise, when more than one partition qualifies, they are submitted to the
     *  executor and this call blocks until all finish
     * @param <T> result element type
     * @return concatenation of the per-partition results (empty if no partition qualifies)
     * @throws IllegalStateException if {@code fn} fails or the waiting thread is interrupted
     * @throws ArrayIndexOutOfBoundsException if a candidate index is out of range
     */
    public <T> Stream<T> mapPartitionsParallel(Function<FeatureVectorsPartition, Stream<T>> fn,
                                               int[] candidateIndices,
                                               boolean background) {
        List<Callable<Stream<T>>> tasks;
        if (candidateIndices == null) {
            tasks = new ArrayList<>(partitions.length);
            for (FeatureVectorsPartition partition : partitions) {
                addTaskIfNonEmpty(tasks, fn, partition);
            }
        } else {
            tasks = new ArrayList<>(candidateIndices.length);
            for (int index : candidateIndices) {
                addTaskIfNonEmpty(tasks, fn, partitions[index]);
            }
        }

        int numTasks = tasks.size();
        if (numTasks == 0) {
            return Stream.empty();
        }
        if (numTasks == 1) {
            // No point paying the executor hand-off for a single partition.
            return callUnchecked(tasks.get(0));
        }
        // NOTE: reduce(Stream::concat) nests one concat per task. Stream.concat's docs warn that
        // very deep nesting can cause deep call chains, which is acceptable for a modest number
        // of partitions.
        if (background) {
            return tasks.stream()
                .map(PartitionedFeatureVectors::callUnchecked)
                .reduce(Stream::concat)
                .orElse(null); // unreachable: numTasks >= 2 here
        }
        try {
            return executor.invokeAll(tasks).stream()
                .map(PartitionedFeatureVectors::getUnchecked)
                .reduce(Stream::concat)
                .orElse(null); // unreachable: numTasks >= 2 here
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
    }

    /** Adds a logged task applying {@code fn} to {@code partition}, if the partition has vectors. */
    private static <T> void addTaskIfNonEmpty(List<Callable<Stream<T>>> tasks,
                                              Function<FeatureVectorsPartition, Stream<T>> fn,
                                              FeatureVectorsPartition partition) {
        if (partition.size() > 0) {
            tasks.add(LoggingCallable.log(() -> fn.apply(partition)));
        }
    }

    /** Runs {@code task} on the calling thread, wrapping any failure in IllegalStateException. */
    private static <V> V callUnchecked(Callable<V> task) {
        try {
            return task.call();
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            throw new IllegalStateException(e);
        }
    }

    /** Returns a completed future's value, unwrapping task failures into IllegalStateException. */
    private static <V> V getUnchecked(Future<V> future) {
        try {
            return future.get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        }
    }

    /**
     * @param id ID whose vector to look up
     * @return the vector from the partition recorded for {@code id}, or {@code null} if the ID
     *  has no recorded partition
     */
    @Override
    public float[] getVector(String id) {
        int partition;
        // Only the map lookup needs the lock; the partition read happens after it is released.
        try (AutoLock ignored = partitionMapLock.autoReadLock()) {
            partition = partitionMap.getIfAbsent(id, NO_PARTITION);
        }
        return partition < 0 ? null : partitions[partition].getVector(id);
    }

    /**
     * Stores {@code vector} for {@code id} in the partition chosen by the partitioner. If the ID
     * previously lived in a different partition, it is removed from that one first.
     *
     * @param id ID to store
     * @param vector vector to store
     * @throws IndexOutOfBoundsException if the partitioner returns an index outside
     *  {@code [0, number of partitions)}; in that case nothing is modified
     */
    @Override
    public void setVector(String id, float[] vector) {
        int newPartition = partitioner.applyAsInt(id, vector);
        // Validate before mutating anything. Otherwise a bad index would only be noticed after
        // the ID had already been removed from its old partition, leaving the map pointing at a
        // partition that no longer holds it.
        Preconditions.checkElementIndex(newPartition, partitions.length, "partition");
        try (AutoLock ignored = partitionMapLock.autoWriteLock()) {
            int oldPartition = partitionMap.getIfAbsent(id, NO_PARTITION);
            if (oldPartition >= 0 && oldPartition != newPartition) {
                partitions[oldPartition].removeVector(id);
            }
            partitions[newPartition].setVector(id, vector);
            partitionMap.put(id, newPartition);
        }
    }

    /** Delegates {@code addAllIDsTo} to every partition in turn. */
    @Override
    public void addAllIDsTo(Collection<String> allIDs) {
        for (FeatureVectorsPartition partition : partitions) {
            partition.addAllIDsTo(allIDs);
        }
    }

    /** Delegates {@code removeAllIDsFrom} to every partition in turn. */
    @Override
    public void removeAllIDsFrom(Collection<String> allIDs) {
        for (FeatureVectorsPartition partition : partitions) {
            partition.removeAllIDsFrom(allIDs);
        }
    }

    /** Delegates {@code addAllRecentTo} to every partition in turn. */
    @Override
    public void addAllRecentTo(Collection<String> allRecent) {
        for (FeatureVectorsPartition partition : partitions) {
            partition.addAllRecentTo(allRecent);
        }
    }

    /** Delegates {@code retainRecentAndIDs} to every partition in turn. */
    @Override
    public void retainRecentAndIDs(Collection<String> newModelIDs) {
        for (FeatureVectorsPartition partition : partitions) {
            partition.retainRecentAndIDs(newModelIDs);
        }
    }

    /**
     * Sums the per-partition {@code getVTV} results element-wise.
     *
     * @param background passed to each partition's {@code getVTV} and to
     *  {@link #mapPartitionsParallel(Function, boolean)}
     * @return element-wise sum over non-empty partitions, or {@code null} if all are empty
     */
    @Override
    public double[] getVTV(boolean background) {
        return mapPartitionsParallel(partition -> Stream.of(partition.getVTV(background)), background)
            .reduce((sum, partial) -> {
                // Accumulates in place into the first partition's result array.
                // NOTE(review): assumes FeatureVectorsPartition.getVTV returns a fresh array - confirm.
                for (int i = 0; i < sum.length; i++) {
                    sum[i] += partial[i];
                }
                return sum;
            })
            .orElse(null);
    }

    /**
     * @return "index:size" for up to the first 16 non-empty partitions, followed by "..." only
     *  when further non-empty partitions were omitted
     */
    @Override
    public String toString() {
        List<String> partitionSizes = new ArrayList<>(MAX_TO_STRING_PARTITIONS + 1);
        for (int i = 0; i < partitions.length; i++) {
            int size = partitions[i].size();
            if (size > 0) {
                if (partitionSizes.size() == MAX_TO_STRING_PARTITIONS) {
                    // A further non-empty partition exists beyond the listing limit.
                    partitionSizes.add("...");
                    break;
                }
                partitionSizes.add(i + ":" + size);
            }
        }
        return partitionSizes.toString();
    }
}
