package com.cloudera.oryx.app.als;

import com.cloudera.oryx.common.lang.LoggingCallable;
import com.cloudera.oryx.common.math.LinearSystemSolver;
import com.cloudera.oryx.common.math.Solver;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Caches a {@link Solver} built from the {@code VTV} matrix of a {@link FeatureVectors} and
 * rebuilds it asynchronously on the supplied {@link Executor} after it has been marked dirty.
 * <p>
 * At most one computation is scheduled at a time. State is held in atomics and a latch so the
 * methods can be called from multiple threads.
 */
public final class SolverCache {

    private static final Logger log = LoggerFactory.getLogger(SolverCache.class);

    /** Most recently computed solver; {@code null} until a computation produces one. */
    private final AtomicReference<Solver> solver = new AtomicReference<>();
    /** Set when the cached solver may be stale; cleared by a computation when it starts. */
    private final AtomicBoolean solverDirty = new AtomicBoolean(true);
    /** Ensures at most one computation is scheduled or running at a time. */
    private final AtomicBoolean solverUpdating = new AtomicBoolean(false);
    /** Released once the first computation attempt finishes, whether or not it produced a solver. */
    private final CountDownLatch solverInitialized = new CountDownLatch(1);
    private final Executor executor;
    private final FeatureVectors vectorPartitions;

    /**
     * @param executor runs solver computations asynchronously
     * @param featureVectors source of the {@code VTV} matrix the solver is built from
     */
    public SolverCache(Executor executor, FeatureVectors featureVectors) {
        this.executor = executor;
        this.vectorPartitions = featureVectors;
    }

    /**
     * Marks the cached solver as stale, so the next {@link #get(boolean)} schedules a rebuild.
     * The mark persists even if a computation is already in progress.
     */
    public void setDirty() {
        solverDirty.set(true);
    }

    /**
     * Schedules an asynchronous solver computation unless one is already scheduled or running.
     * Does not wait for the computation.
     *
     * @throws RejectedExecutionException if the executor refuses the task; the in-progress guard
     *     is released first so that a later call can try again
     */
    public void compute() {
        if (!solverUpdating.compareAndSet(false, true)) {
            return;
        }
        try {
            executor.execute(LoggingCallable.log(() -> {
                try {
                    // Clear the flag before reading the vectors: any setDirty() after this
                    // point leaves it set, so that change triggers another computation.
                    solverDirty.set(false);
                    log.info("Computing cached solver");
                    // NOTE(review): getVTV receives whether a solver already exists; presumably
                    // a priority hint -- confirm in FeatureVectors.
                    Solver newSolver =
                        LinearSystemSolver.getSolver(vectorPartitions.getVTV(solver.get() != null));
                    if (newSolver != null) {
                        log.info("Computed new solver {}", newSolver);
                        solver.set(newSolver);
                    }
                } finally {
                    // Release waiters even when no solver was produced.
                    solverInitialized.countDown();
                    solverUpdating.set(false);
                }
            }).asRunnable());
        } catch (RejectedExecutionException e) {
            // The task never ran, so its finally block will not reset the guard. Without this,
            // no computation could ever be scheduled again.
            solverUpdating.set(false);
            throw e;
        }
    }

    /**
     * Returns the cached solver. If it has been marked dirty, a recomputation is scheduled first.
     *
     * @param blocking if true and no computation attempt has finished yet, waits for the first
     *     attempt to finish before returning
     * @return the most recently computed solver, or {@code null} if none is available yet
     */
    public Solver get(boolean blocking) {
        // Read the flag without clearing it. If a computation is already running, compute()
        // does nothing and the flag survives, so a later call picks the change up.
        if (solverDirty.get()) {
            compute();
        }
        if (blocking && solverInitialized.getCount() > 0) {
            try {
                solverInitialized.await();
            } catch (InterruptedException e) {
                // Restore the interrupt status so callers can still observe the interruption.
                Thread.currentThread().interrupt();
                log.warn("Interrupted while waiting for model", e);
            }
        }
        return solver.get();
    }
}
