package com.cloudera.oryx.app.speed.als;

import com.cloudera.oryx.api.speed.AbstractSpeedModelManager;
import com.cloudera.oryx.app.als.ALSUtils;
import com.cloudera.oryx.app.common.fn.MLFunctions;
import com.cloudera.oryx.app.pmml.AppPMMLUtils;
import com.cloudera.oryx.common.lang.RateLimitCheck;
import com.cloudera.oryx.common.math.SingularMatrixSolverException;
import com.cloudera.oryx.common.math.Solver;
import com.cloudera.oryx.common.text.TextUtils;
import com.google.common.base.Preconditions;
import com.typesafe.config.Config;
import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import oryx.org.dmg.pmml.PMML;
import scala.Tuple2;

/* loaded from: input_file:com/cloudera/oryx/app/speed/als/ALSSpeedModelManager.class */
/**
 * Speed-layer manager for an ALS model.
 *
 * <p>It keeps an in-memory {@link ALSSpeedModel} current by consuming {@code MODEL} /
 * {@code MODEL-REF} messages (which (re)load the model from PMML) and {@code UP} messages (which
 * set individual user "X" or item "Y" vectors). From new input interactions it computes
 * incremental user and item vector updates, serialized as JSON strings.
 */
public final class ALSSpeedModelManager extends AbstractSpeedModelManager<String, String, String> {

    private static final Logger log = LoggerFactory.getLogger(ALSSpeedModelManager.class);

    /** Current model; {@code null} until the first model message has been loaded. */
    private ALSSpeedModel model;
    /** When true, emitted update messages omit the trailing "other ID" element. */
    private final boolean noKnownItems;
    /** Minimum value of the model's loaded fraction required before any updates are produced. */
    private final double minModelLoadFraction;
    /** Limits how often the model summary is logged after {@code UP} messages (1 / minute). */
    private final RateLimitCheck logRateLimit;

    /**
     * Creates a manager from configuration.
     *
     * @param config configuration; reads {@code oryx.als.no-known-items} and
     *     {@code oryx.speed.min-model-load-fraction}
     * @throws IllegalArgumentException if {@code oryx.speed.min-model-load-fraction} is not in
     *     [0,1] (including NaN)
     */
    public ALSSpeedModelManager(Config config) {
        noKnownItems = config.getBoolean("oryx.als.no-known-items");
        minModelLoadFraction = config.getDouble("oryx.speed.min-model-load-fraction");
        Preconditions.checkArgument(minModelLoadFraction >= 0.0 && minModelLoadFraction <= 1.0,
                "oryx.speed.min-model-load-fraction must be in [0,1]: %s", minModelLoadFraction);
        logRateLimit = new RateLimitCheck(1L, TimeUnit.MINUTES);
    }

    /**
     * Applies one message from the update topic to the in-memory model.
     *
     * @param key message type: {@code "UP"} for a single vector update, {@code "MODEL"} or
     *     {@code "MODEL-REF"} for a new model
     * @param message message payload
     * @param hadoopConf Hadoop configuration, used when reading a referenced model
     * @throws IOException if the model cannot be read
     * @throws IllegalArgumentException for an unknown key or a malformed {@code UP} payload
     */
    @Override // com.cloudera.oryx.api.speed.AbstractSpeedModelManager
    public void consumeKeyMessage(String key, String message, Configuration hadoopConf) throws IOException {
        switch (key) {
            case "UP":
                applyVectorUpdate(message);
                break;
            case "MODEL":
            case "MODEL-REF":
                loadModel(key, message, hadoopConf);
                break;
            default:
                throw new IllegalArgumentException("Bad key: " + key);
        }
    }

    /**
     * Sets one user ("X") or item ("Y") vector from a JSON payload {@code [matrix, id, vector, ...]}.
     * The update is ignored if no model has been loaded yet.
     */
    private void applyVectorUpdate(String message) {
        if (model == null) {
            return; // nothing to apply the vector to yet
        }
        List<?> update = TextUtils.readJSON(message, List.class);
        if (update == null || update.size() < 3) {
            throw new IllegalArgumentException("Bad message: " + message);
        }
        String id = update.get(1).toString();
        float[] vector = TextUtils.convertViaJSON(update.get(2), float[].class);
        switch (update.get(0).toString()) {
            case "X":
                model.setUserVector(id, vector);
                break;
            case "Y":
                model.setItemVector(id, vector);
                break;
            default:
                throw new IllegalArgumentException("Bad message: " + message);
        }
        if (logRateLimit.test()) {
            log.info("{}", model);
        }
    }

    /**
     * Reads a model from a MODEL / MODEL-REF message.
     *
     * <p>A new {@link ALSSpeedModel} is created when there is no model yet or when the number of
     * features changed. Users and items are then pruned to the IDs listed in the PMML (the model's
     * {@code retainRecentAnd*IDs} methods decide exactly what is kept).
     */
    private void loadModel(String key, String message, Configuration hadoopConf) throws IOException {
        log.info("Loading new model");
        PMML pmml = AppPMMLUtils.readPMMLFromUpdateKeyMessage(key, message, hadoopConf);
        if (pmml == null) {
            return;
        }

        int features = Integer.parseInt(AppPMMLUtils.getExtensionValue(pmml, "features"));
        boolean implicit = Boolean.parseBoolean(AppPMMLUtils.getExtensionValue(pmml, "implicit"));
        boolean logStrength = Boolean.parseBoolean(AppPMMLUtils.getExtensionValue(pmml, "logStrength"));
        // epsilon is only present / meaningful when strengths are log-transformed
        double epsilon = logStrength
                ? Double.parseDouble(AppPMMLUtils.getExtensionValue(pmml, "epsilon"))
                : Double.NaN;
        if (model == null || features != model.getFeatures()) {
            log.warn("No previous model, or # features has changed; creating new one");
            model = new ALSSpeedModel(features, implicit, logStrength, epsilon);
        }

        log.info("Updating model");
        HashSet<String> userIDs = new HashSet<>(AppPMMLUtils.getExtensionContent(pmml, "XIDs"));
        HashSet<String> itemIDs = new HashSet<>(AppPMMLUtils.getExtensionContent(pmml, "YIDs"));
        model.retainRecentAndUserIDs(userIDs);
        model.retainRecentAndItemIDs(itemIDs);
        log.info("Model updated: {}", model);
    }

    /**
     * Computes incremental vector updates from a batch of new input.
     *
     * @param newData new input; only the values (input lines) are used
     * @return JSON update strings ({@code ["X", user, vector(, [item])]} and
     *     {@code ["Y", item, vector(, [user])]}); empty if there is no sufficiently loaded model or
     *     no usable solver yet
     */
    @Override // com.cloudera.oryx.api.speed.SpeedModelManager
    public Iterable<String> buildUpdates(JavaPairRDD<String, String> newData) {
        // Read the field once so the whole batch works against a single model instance.
        // NOTE(review): matters if consumeKeyMessage can replace the model concurrently — confirm.
        ALSSpeedModel currentModel = model;
        if (currentModel == null || currentModel.getFractionLoaded() < minModelLoadFraction) {
            return Collections.emptyList();
        }
        currentModel.precomputeSolvers();

        boolean logStrength = currentModel.isLogStrength();
        List<UserItemStrength> inputs = aggregateInputs(
                newData,
                currentModel.isImplicit(),
                logStrength,
                logStrength ? currentModel.getEpsilon() : Double.NaN);

        try {
            Solver xTXSolver = currentModel.getXTXSolver();
            Solver yTYSolver = currentModel.getYTYSolver();
            if (xTXSolver == null || yTYSolver == null) {
                log.info("No solver available yet for model; skipping inputs");
                return Collections.emptyList();
            }
            return inputs.parallelStream()
                    .flatMap(input -> computeUpdates(currentModel, xTXSolver, yTYSolver, input).stream())
                    .collect(Collectors.toList());
        } catch (SingularMatrixSolverException e) {
            log.info("Not enough data for solver yet ({}); skipping inputs", e.getMessage());
            return Collections.emptyList();
        }
    }

    /**
     * Parses, orders and aggregates raw input lines into (user, item, strength) triples on Spark.
     *
     * <p>Static so the Spark lambdas below cannot capture the (non-serializable) manager instance.
     * Lines are sorted by timestamp first.
     * <ul>
     *   <li>Implicit models combine values per (user, item) with {@code MLFunctions.SUM_WITH_NAN}.</li>
     *   <li>Otherwise the last value per (user, item) wins.</li>
     * </ul>
     * Pairs whose aggregate is NaN are dropped. An empty strength field parses as NaN. When
     * {@code logStrength} is set, a strength {@code s} becomes {@code log1p(s / epsilon)}.
     */
    private static List<UserItemStrength> aggregateInputs(JavaPairRDD<String, String> newData,
                                                          boolean implicit,
                                                          boolean logStrength,
                                                          double epsilon) {
        JavaRDD<String> sortedLines =
                newData.values().sortBy(MLFunctions.TO_TIMESTAMP_FN, true, newData.partitions().size());

        JavaPairRDD<Tuple2<String, String>, Double> parsed = sortedLines.mapToPair(line -> {
            try {
                String[] tokens = MLFunctions.PARSE_FN.call(line);
                double strength = tokens[2].isEmpty() ? Double.NaN : Double.parseDouble(tokens[2]);
                return new Tuple2<>(new Tuple2<>(tokens[0], tokens[1]), strength);
            } catch (ArrayIndexOutOfBoundsException | NumberFormatException e) {
                log.warn("Bad input: {}", line);
                throw e;
            }
        });

        JavaPairRDD<Tuple2<String, String>, Double> aggregated;
        if (implicit) {
            aggregated = parsed.groupByKey().mapValues(MLFunctions.SUM_WITH_NAN);
        } else {
            aggregated = parsed.foldByKey(Double.NaN, (current, next) -> next);
        }

        JavaPairRDD<Tuple2<String, String>, Double> present =
                aggregated.filter(kv -> !Double.isNaN(kv._2()));

        JavaRDD<UserItemStrength> strengths;
        if (logStrength) {
            strengths = present.map(kv -> new UserItemStrength(
                    kv._1()._1(), kv._1()._2(), (float) Math.log1p(kv._2() / epsilon)));
        } else {
            strengths = present.map(kv -> new UserItemStrength(
                    kv._1()._1(), kv._1()._2(), kv._2().floatValue()));
        }
        return strengths.collect();
    }

    /**
     * Computes the updated user and item vectors for one interaction.
     *
     * @return zero, one or two JSON update strings; a side is omitted when
     *     {@code ALSUtils.computeUpdatedXu} returns {@code null} for it
     */
    private List<String> computeUpdates(ALSSpeedModel currentModel,
                                        Solver xTXSolver,
                                        Solver yTYSolver,
                                        UserItemStrength input) {
        String user = input.getUser();
        String item = input.getItem();
        double strength = input.getStrength();
        float[] userVector = currentModel.getUserVector(user);
        float[] itemVector = currentModel.getItemVector(item);
        boolean implicit = currentModel.isImplicit();

        // The user vector is solved against Y^T*Y, the item vector against X^T*X.
        float[] newUserVector = ALSUtils.computeUpdatedXu(yTYSolver, strength, userVector, itemVector, implicit);
        float[] newItemVector = ALSUtils.computeUpdatedXu(xTXSolver, strength, itemVector, userVector, implicit);

        List<String> updates = new ArrayList<>(2);
        if (newUserVector != null) {
            updates.add(toUpdateJSON("X", user, newUserVector, item));
        }
        if (newItemVector != null) {
            updates.add(toUpdateJSON("Y", item, newItemVector, user));
        }
        return updates;
    }

    /**
     * Serializes one vector update as a JSON array.
     *
     * <p>The array is {@code [matrix, id, vector]}. Unless {@code noKnownItems} is set, it also
     * carries a fourth element: a single-element list holding {@code otherID}.
     */
    private String toUpdateJSON(String matrix, String id, float[] vector, String otherID) {
        List<?> args;
        if (noKnownItems) {
            args = Arrays.asList(matrix, id, vector);
        } else {
            args = Arrays.asList(matrix, id, vector, Collections.singletonList(otherID));
        }
        return TextUtils.joinJSON(args);
    }
}
