package com.cloudera.oryx.app.speed.rdf;

import com.cloudera.oryx.api.speed.AbstractSpeedModelManager;
import com.cloudera.oryx.app.classreg.example.CategoricalFeature;
import com.cloudera.oryx.app.classreg.example.Example;
import com.cloudera.oryx.app.classreg.example.ExampleUtils;
import com.cloudera.oryx.app.classreg.example.Feature;
import com.cloudera.oryx.app.classreg.example.NumericFeature;
import com.cloudera.oryx.app.common.fn.MLFunctions;
import com.cloudera.oryx.app.pmml.AppPMMLUtils;
import com.cloudera.oryx.app.rdf.RDFPMMLUtils;
import com.cloudera.oryx.app.rdf.tree.DecisionForest;
import com.cloudera.oryx.app.rdf.tree.DecisionTree;
import com.cloudera.oryx.app.schema.CategoricalValueEncodings;
import com.cloudera.oryx.app.schema.InputSchema;
import com.cloudera.oryx.common.collection.Pair;
import com.cloudera.oryx.common.text.TextUtils;
import com.typesafe.config.Config;
import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.DoubleSummaryStatistics;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import oryx.org.dmg.pmml.PMML;
import scala.Tuple2;

/* loaded from: input_file:com/cloudera/oryx/app/speed/rdf/RDFSpeedModelManager.class */
public final class RDFSpeedModelManager extends AbstractSpeedModelManager<String, String, String> {
    private static final Logger log = LoggerFactory.getLogger((Class<?>) RDFSpeedModelManager.class);
    private final InputSchema inputSchema;
    private RDFSpeedModel model;

    public RDFSpeedModelManager(Config config) {
        this.inputSchema = new InputSchema(config);
    }

    @Override // com.cloudera.oryx.api.speed.AbstractSpeedModelManager
    public void consumeKeyMessage(String str, String str2, Configuration configuration) throws IOException {
        boolean z = -1;
        switch (str.hashCode()) {
            case 2715:
                if (str.equals("UP")) {
                    z = false;
                    break;
                }
                break;
            case 73532169:
                if (str.equals("MODEL")) {
                    z = true;
                    break;
                }
                break;
            case 775751599:
                if (str.equals("MODEL-REF")) {
                    z = 2;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                return;
            case true:
            case true:
                log.info("Loading new model");
                PMML readPMMLFromUpdateKeyMessage = AppPMMLUtils.readPMMLFromUpdateKeyMessage(str, str2, configuration);
                if (readPMMLFromUpdateKeyMessage == null) {
                    return;
                }
                RDFPMMLUtils.validatePMMLVsSchema(readPMMLFromUpdateKeyMessage, this.inputSchema);
                Pair<DecisionForest, CategoricalValueEncodings> read = RDFPMMLUtils.read(readPMMLFromUpdateKeyMessage);
                this.model = new RDFSpeedModel(read.getFirst(), read.getSecond());
                log.info("New model loaded: {}", this.model);
                return;
            default:
                throw new IllegalArgumentException("Bad key: " + str);
        }
    }

    @Override // com.cloudera.oryx.api.speed.SpeedModelManager
    public Iterable<String> buildUpdates(JavaPairRDD<String, String> javaPairRDD) {
        if (this.model == null) {
            return Collections.emptyList();
        }
        InputSchema inputSchema = this.inputSchema;
        CategoricalValueEncodings encodings = this.model.getEncodings();
        JavaRDD map = javaPairRDD.values().map(MLFunctions.PARSE_FN).map(strArr -> {
            return ExampleUtils.dataToExample(strArr, inputSchema, encodings);
        });
        DecisionForest forest = this.model.getForest();
        JavaPairRDD groupByKey = map.flatMapToPair(example -> {
            Feature target = example.getTarget();
            DecisionTree[] trees = forest.getTrees();
            ArrayList arrayList = new ArrayList(trees.length);
            for (int i = 0; i < trees.length; i++) {
                arrayList.add(new Tuple2(new Pair(Integer.valueOf(i), trees[i].findTerminal(example).getID()), target));
            }
            return arrayList.iterator();
        }).groupByKey();
        return inputSchema.isClassification() ? (Iterable) groupByKey.mapValues(iterable -> {
            return (Map) StreamSupport.stream(iterable.spliterator(), false).collect(Collectors.groupingBy(feature -> {
                return Integer.valueOf(((CategoricalFeature) feature).getEncoding());
            }, Collectors.counting()));
        }).collect().stream().map(tuple2 -> {
            return TextUtils.joinJSON(Arrays.asList((Integer) ((Pair) tuple2._1()).getFirst(), (String) ((Pair) tuple2._1()).getSecond(), tuple2._2()));
        }).collect(Collectors.toList()) : (Iterable) groupByKey.mapValues(iterable2 -> {
            return (DoubleSummaryStatistics) StreamSupport.stream(iterable2.spliterator(), false).collect(Collectors.summarizingDouble(feature -> {
                return ((NumericFeature) feature).getValue();
            }));
        }).collect().stream().map(tuple22 -> {
            Integer num = (Integer) ((Pair) tuple22._1()).getFirst();
            String str = (String) ((Pair) tuple22._1()).getSecond();
            DoubleSummaryStatistics doubleSummaryStatistics = (DoubleSummaryStatistics) tuple22._2();
            return TextUtils.joinJSON(Arrays.asList(num, str, Double.valueOf(doubleSummaryStatistics.getAverage()), Long.valueOf(doubleSummaryStatistics.getCount())));
        }).collect(Collectors.toList());
    }

    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        boolean z = -1;
        switch (implMethodName.hashCode()) {
            case -736304167:
                if (implMethodName.equals("lambda$buildUpdates$e2ffef2e$1")) {
                    z = true;
                    break;
                }
                break;
            case -736304166:
                if (implMethodName.equals("lambda$buildUpdates$e2ffef2e$2")) {
                    z = 2;
                    break;
                }
                break;
            case -319688507:
                if (implMethodName.equals("lambda$buildUpdates$cdbf5838$1")) {
                    z = false;
                    break;
                }
                break;
            case 1977420322:
                if (implMethodName.equals("lambda$buildUpdates$4f375e37$1")) {
                    z = 3;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/spark/api/java/function/Function") && serializedLambda.getFunctionalInterfaceMethodName().equals("call") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("com/cloudera/oryx/app/speed/rdf/RDFSpeedModelManager") && serializedLambda.getImplMethodSignature().equals("(Lcom/cloudera/oryx/app/schema/InputSchema;Lcom/cloudera/oryx/app/schema/CategoricalValueEncodings;[Ljava/lang/String;)Lcom/cloudera/oryx/app/classreg/example/Example;")) {
                    InputSchema inputSchema = (InputSchema) serializedLambda.getCapturedArg(0);
                    CategoricalValueEncodings categoricalValueEncodings = (CategoricalValueEncodings) serializedLambda.getCapturedArg(1);
                    return strArr -> {
                        return ExampleUtils.dataToExample(strArr, inputSchema, categoricalValueEncodings);
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/spark/api/java/function/Function") && serializedLambda.getFunctionalInterfaceMethodName().equals("call") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("com/cloudera/oryx/app/speed/rdf/RDFSpeedModelManager") && serializedLambda.getImplMethodSignature().equals("(Ljava/lang/Iterable;)Ljava/util/Map;")) {
                    return iterable -> {
                        return (Map) StreamSupport.stream(iterable.spliterator(), false).collect(Collectors.groupingBy(feature -> {
                            return Integer.valueOf(((CategoricalFeature) feature).getEncoding());
                        }, Collectors.counting()));
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/spark/api/java/function/Function") && serializedLambda.getFunctionalInterfaceMethodName().equals("call") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("com/cloudera/oryx/app/speed/rdf/RDFSpeedModelManager") && serializedLambda.getImplMethodSignature().equals("(Ljava/lang/Iterable;)Ljava/util/DoubleSummaryStatistics;")) {
                    return iterable2 -> {
                        return (DoubleSummaryStatistics) StreamSupport.stream(iterable2.spliterator(), false).collect(Collectors.summarizingDouble(feature -> {
                            return ((NumericFeature) feature).getValue();
                        }));
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/spark/api/java/function/PairFlatMapFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("call") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/util/Iterator;") && serializedLambda.getImplClass().equals("com/cloudera/oryx/app/speed/rdf/RDFSpeedModelManager") && serializedLambda.getImplMethodSignature().equals("(Lcom/cloudera/oryx/app/rdf/tree/DecisionForest;Lcom/cloudera/oryx/app/classreg/example/Example;)Ljava/util/Iterator;")) {
                    DecisionForest decisionForest = (DecisionForest) serializedLambda.getCapturedArg(0);
                    return example -> {
                        Feature target = example.getTarget();
                        DecisionTree[] trees = decisionForest.getTrees();
                        ArrayList arrayList = new ArrayList(trees.length);
                        for (int i = 0; i < trees.length; i++) {
                            arrayList.add(new Tuple2(new Pair(Integer.valueOf(i), trees[i].findTerminal(example).getID()), target));
                        }
                        return arrayList.iterator();
                    };
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
