package com.cloudera.oryx.app.serving.als.model;

import com.cloudera.oryx.api.serving.AbstractServingModelManager;
import com.cloudera.oryx.app.als.MultiRescorerProvider;
import com.cloudera.oryx.app.als.RescorerProvider;
import com.cloudera.oryx.app.pmml.AppPMMLUtils;
import com.cloudera.oryx.common.lang.RateLimitCheck;
import com.cloudera.oryx.common.settings.ConfigUtils;
import com.cloudera.oryx.common.text.TextUtils;
import com.google.common.base.Preconditions;
import com.typesafe.config.Config;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.apache.hadoop.conf.Configuration;
import org.dmg.pmml.PMML;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/cloudera/oryx/app/serving/als/model/ALSServingModelManager.class */
/**
 * Maintains the in-memory {@link ALSServingModel} from a stream of keyed update messages.
 *
 * <p>Three message keys are understood by {@link #consumeKeyMessage}:
 * <ul>
 *   <li>{@code "UP"} &mdash; an incremental update carrying a single user ({@code "X"}) or
 *       item ({@code "Y"}) feature vector;</li>
 *   <li>{@code "MODEL"} / {@code "MODEL-REF"} &mdash; a PMML model description, read via
 *       {@link AppPMMLUtils#readPMMLFromUpdateKeyMessage}.</li>
 * </ul>
 *
 * <p>NOTE(review): this class performs no synchronization of its own on {@code model} or
 * {@code triggeredSolver}; whether callers serialize access is not visible here &mdash; confirm.
 */
public final class ALSServingModelManager extends AbstractServingModelManager<String> {
    private static final Logger log = LoggerFactory.getLogger(ALSServingModelManager.class);

    /** Current model; {@code null} until the first "MODEL"/"MODEL-REF" message has been applied. */
    private ALSServingModel model;
    /** Set once {@link ALSServingModel#precomputeSolvers()} has been requested; never reset here. */
    private boolean triggeredSolver;
    private final double sampleRate;
    private final double minModelLoadFraction;
    /** May be {@code null} when no rescorer provider class is configured. */
    private final RescorerProvider rescorerProvider;
    /** Limits model-state logging (and the solver check tied to it) to once per minute. */
    private final RateLimitCheck logRateLimit;

    /**
     * Reads the manager's settings from configuration.
     *
     * @param config configuration supplying {@code oryx.als.rescorer-provider-class} (optional),
     *  {@code oryx.als.sample-rate} and {@code oryx.serving.min-model-load-fraction}
     * @throws IllegalArgumentException if the sample rate is not in (0,1], the minimum load
     *  fraction is not in [0,1], or a configured rescorer provider cannot be loaded
     */
    public ALSServingModelManager(Config config) {
        super(config);
        this.rescorerProvider =
            loadRescorerProviders(ConfigUtils.getOptionalString(config, "oryx.als.rescorer-provider-class"));
        this.sampleRate = config.getDouble("oryx.als.sample-rate");
        this.minModelLoadFraction = config.getDouble("oryx.serving.min-model-load-fraction");
        Preconditions.checkArgument(
            this.sampleRate > 0.0 && this.sampleRate <= 1.0,
            "oryx.als.sample-rate must be in (0,1]: %s", this.sampleRate);
        Preconditions.checkArgument(
            this.minModelLoadFraction >= 0.0 && this.minModelLoadFraction <= 1.0,
            "oryx.serving.min-model-load-fraction must be in [0,1]: %s", this.minModelLoadFraction);
        this.logRateLimit = new RateLimitCheck(1L, TimeUnit.MINUTES);
    }

    /**
     * Applies one keyed message to the model.
     *
     * @param key message key: {@code "UP"}, {@code "MODEL"} or {@code "MODEL-REF"}
     * @param message message payload; a JSON array for {@code "UP"}, otherwise whatever
     *  {@link AppPMMLUtils#readPMMLFromUpdateKeyMessage} expects for the given key
     * @param hadoopConf Hadoop configuration passed through when reading the PMML model
     * @throws IOException declared for failures while reading the model
     * @throws IllegalArgumentException if the key, or the type of an "UP" message, is not recognized
     */
    public void consumeKeyMessage(String key, String message, Configuration hadoopConf) throws IOException {
        switch (key) {
            case "UP":
                if (model == null) {
                    // Nothing to apply the update to yet; drop it.
                    return;
                }
                applyUpdate(message);
                // Logging and the solver check share one rate limiter, so the solver
                // precomputation is only considered when the limiter fires.
                if (logRateLimit.test()) {
                    log.info("{}", model);
                    maybeTriggerSolver();
                }
                return;
            case "MODEL":
            case "MODEL-REF":
                loadModel(key, message, hadoopConf);
                return;
            default:
                throw new IllegalArgumentException("Bad key: " + key);
        }
    }

    /**
     * Applies a single "UP" message of the form {@code [type, id, vector, (knownItems)]} to the
     * current model. Must only be called when {@code model} is non-null.
     */
    private void applyUpdate(String message) throws IOException {
        List<?> update = (List<?>) TextUtils.readJSON(message, List.class);
        String id = update.get(1).toString();
        float[] vector = (float[]) TextUtils.convertViaJSON(update.get(2), float[].class);
        String type = update.get(0).toString();
        switch (type) {
            case "X":
                model.setUserVector(id, vector);
                if (update.size() > 3) {
                    // Optional 4th element is handed to the model as the user's known items.
                    // NOTE(review): elements are assumed to be String item IDs -- confirm.
                    @SuppressWarnings("unchecked")
                    Collection<String> knownItems = (Collection<String>) update.get(3);
                    model.addKnownItems(id, knownItems);
                }
                break;
            case "Y":
                model.setItemVector(id, vector);
                break;
            default:
                throw new IllegalArgumentException("Bad message: " + message);
        }
    }

    /**
     * Requests solver precomputation at most once, the first time the model reports a loaded
     * fraction that is not below the configured minimum.
     */
    private void maybeTriggerSolver() {
        if (triggeredSolver || model.getFractionLoaded() < minModelLoadFraction) {
            return;
        }
        triggeredSolver = true;
        model.precomputeSolvers();
    }

    /**
     * Handles a "MODEL"/"MODEL-REF" message: creates a new model if there is none or the feature
     * count changed, then asks the model to retain data for the IDs listed in the PMML
     * "XIDs"/"YIDs" extensions. Does nothing beyond logging if no PMML could be read.
     */
    private void loadModel(String key, String message, Configuration hadoopConf) throws IOException {
        log.info("Loading new model");
        PMML pmml = AppPMMLUtils.readPMMLFromUpdateKeyMessage(key, message, hadoopConf);
        if (pmml == null) {
            return;
        }
        int features = Integer.parseInt(AppPMMLUtils.getExtensionValue(pmml, "features"));
        boolean implicit = Boolean.parseBoolean(AppPMMLUtils.getExtensionValue(pmml, "implicit"));
        if (model == null || features != model.getFeatures()) {
            log.warn("No previous model, or # features has changed; creating new one");
            model = new ALSServingModel(features, implicit, sampleRate, rescorerProvider);
        }
        log.info("Updating model");
        // Copy into hash sets: the same ID collections are passed to several retain calls.
        Collection<String> userIDs = new HashSet<>(AppPMMLUtils.getExtensionContent(pmml, "XIDs"));
        Collection<String> itemIDs = new HashSet<>(AppPMMLUtils.getExtensionContent(pmml, "YIDs"));
        model.retainRecentAndKnownItems(userIDs, itemIDs);
        model.retainRecentAndUserIDs(userIDs);
        model.retainRecentAndItemIDs(itemIDs);
        log.info("Model updated: {}", model);
    }

    /**
     * Returns the current model.
     *
     * <p>NOTE(review): the name {@code m3getModel} was assigned by the decompiler (originally
     * {@code getModel}, renamed because it was merged with a bridge method). It is kept here so
     * the class's visible members are unchanged; restore the original name when reconciling with
     * the superclass.
     *
     * @return the current model, or {@code null} if none has been loaded yet
     */
    public ALSServingModel m3getModel() {
        return this.model;
    }

    /**
     * Instantiates the configured rescorer provider(s).
     *
     * @param classNamesString comma-separated class names; surrounding whitespace around each
     *  name is ignored
     * @return {@code null} if the argument is null or empty; the single provider if exactly one
     *  name is given; otherwise a {@link MultiRescorerProvider} combining all of them
     * @throws IllegalArgumentException if a class cannot be loaded or instantiated
     */
    static RescorerProvider loadRescorerProviders(String classNamesString) {
        if (classNamesString == null || classNamesString.isEmpty()) {
            return null;
        }
        String[] classNames = classNamesString.split(",");
        if (classNames.length == 1) {
            return loadInstanceOf(classNames[0].trim());
        }
        RescorerProvider[] providers = new RescorerProvider[classNames.length];
        for (int i = 0; i < classNames.length; i++) {
            // Trim so that "a.B, c.D" style configuration values load correctly.
            providers[i] = loadInstanceOf(classNames[i].trim());
        }
        return MultiRescorerProvider.of(providers);
    }

    /**
     * Loads the named class with the class loader of {@link RescorerProvider} and creates an
     * instance through its public no-arg constructor.
     *
     * @throws IllegalArgumentException if the class is missing or cannot be instantiated
     * @throws IllegalStateException if the constructor itself throws; the constructor's
     *  exception is attached as the cause
     */
    private static RescorerProvider loadInstanceOf(String className) {
        try {
            return Class.forName(className, true, RescorerProvider.class.getClassLoader())
                .asSubclass(RescorerProvider.class)
                .getConstructor()
                .newInstance();
        } catch (ClassNotFoundException e) {
            throw new IllegalArgumentException("Could not load " + className + " due to exception", e);
        } catch (IllegalAccessException | InstantiationException | NoSuchMethodException e) {
            throw new IllegalArgumentException("Could not instantiate " + className + " due to exception", e);
        } catch (InvocationTargetException e) {
            // Unwrap so the reported cause is the constructor's own exception.
            throw new IllegalStateException("Could not instantiate " + className + " due to exception", e.getCause());
        }
    }
}
