package com.cloudera.oryx.ml;

import com.cloudera.oryx.api.TopicProducer;
import com.cloudera.oryx.api.batch.BatchLayerUpdate;
import com.cloudera.oryx.common.collection.Pair;
import com.cloudera.oryx.common.lang.ExecUtils;
import com.cloudera.oryx.common.pmml.PMMLUtils;
import com.cloudera.oryx.common.random.RandomManager;
import com.cloudera.oryx.common.settings.ConfigUtils;
import com.cloudera.oryx.ml.param.HyperParamValues;
import com.cloudera.oryx.ml.param.HyperParams;
import com.google.common.base.Preconditions;
import com.typesafe.config.Config;
import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.rdd.RDD;
import org.dmg.pmml.PMML;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/cloudera/oryx/ml/MLUpdate.class */
/**
 * Skeleton {@link BatchLayerUpdate} for model-building batch layers.
 *
 * <p>Each {@link #runUpdate} call does the following:
 * <ul>
 *   <li>builds {@code oryx.ml.eval.candidates} candidate models, one per hyperparameter
 *       combination, cycling through the combinations when there are fewer of them than
 *       candidates;</li>
 *   <li>evaluates each candidate on a held-out test split when one is available;</li>
 *   <li>moves the best candidate's directory into the model directory;</li>
 *   <li>publishes the model's {@value #MODEL_FILE_NAME} file, or a reference to it, to the update
 *       topic.</li>
 * </ul>
 *
 * <p>Subclasses implement {@link #buildModel} and {@link #evaluate}. They may override
 * {@link #getHyperParameterValues}, {@link #canPublishAdditionalModelData} and
 * {@link #publishAdditionalModelData}.
 *
 * @param <M> type of the input message values
 */
public abstract class MLUpdate<M> implements BatchLayerUpdate<Object, M, String> {

    private static final Logger log = LoggerFactory.getLogger(MLUpdate.class);

    /** File name under which each candidate (and the final) model is written. */
    public static final String MODEL_FILE_NAME = "model.pmml";

    private final double testFraction;
    private final int candidates;
    private final String hyperParamSearch;
    private final int evalParallelism;
    /** Minimum eval the best model must reach to be kept; {@code null} when not configured. */
    private final Double threshold;
    /** Models whose PMML file is larger than this many bytes are published by reference. */
    private final int maxMessageSize;

    /**
     * Reads and validates the {@code oryx.ml.eval.*} and update-topic settings.
     *
     * @param config configuration to read from
     * @throws IllegalArgumentException if any setting is out of range
     */
    protected MLUpdate(Config config) {
        this.testFraction = config.getDouble("oryx.ml.eval.test-fraction");
        int candidateCount = config.getInt("oryx.ml.eval.candidates");
        this.evalParallelism = config.getInt("oryx.ml.eval.parallelism");
        this.threshold = ConfigUtils.getOptionalDouble(config, "oryx.ml.eval.threshold");
        this.maxMessageSize = config.getInt("oryx.update-topic.message.max-size");
        Preconditions.checkArgument(testFraction >= 0.0 && testFraction <= 1.0,
                                    "oryx.ml.eval.test-fraction must be in [0,1]: %s", testFraction);
        Preconditions.checkArgument(candidateCount > 0,
                                    "oryx.ml.eval.candidates must be positive: %s", candidateCount);
        Preconditions.checkArgument(evalParallelism > 0,
                                    "oryx.ml.eval.parallelism must be positive: %s", evalParallelism);
        Preconditions.checkArgument(maxMessageSize > 0,
                                    "oryx.update-topic.message.max-size must be positive: %s", maxMessageSize);
        if (testFraction == 0.0 && candidateCount > 1) {
            // Without a test split no candidate can be evaluated, so there is no basis to choose
            // among several; build just one.
            log.info("Eval is disabled (test fraction = 0) so candidates is overridden to 1");
            candidateCount = 1;
        }
        this.candidates = candidateCount;
        this.hyperParamSearch = config.getString("oryx.ml.eval.hyperparam-search");
    }

    /**
     * @return fraction of new data held out for testing, in [0,1]
     */
    protected final double getTestFraction() {
        return testFraction;
    }

    /**
     * @return the hyperparameter ranges to search; empty by default
     */
    public List<HyperParamValues<?>> getHyperParameterValues() {
        return Collections.emptyList();
    }

    /**
     * Builds one candidate model.
     *
     * @param sparkContext active Spark context
     * @param trainData training data
     * @param hyperParameters one combination of hyperparameter values, in the order of
     *  {@link #getHyperParameterValues()}
     * @param candidatePath this candidate's own directory. The model file is written here
     *  afterwards, and the whole directory is moved into the model directory if this candidate
     *  is chosen.
     * @return the built model, or {@code null} if no model could be built
     */
    public abstract PMML buildModel(JavaSparkContext sparkContext,
                                    JavaRDD<M> trainData,
                                    List<?> hyperParameters,
                                    Path candidatePath);

    /**
     * @return whether {@link #publishAdditionalModelData} should be called after the model is
     *  published; {@code false} by default
     */
    public boolean canPublishAdditionalModelData() {
        return false;
    }

    /**
     * Hook to publish extra data after the model itself has been sent to the update topic.
     * It is only called when {@link #canPublishAdditionalModelData()} returns {@code true}.
     * Does nothing by default.
     *
     * @param sparkContext active Spark context
     * @param pmml the published model, as read back from its file
     * @param newData new input data
     * @param pastData previous input data, possibly {@code null}
     * @param modelParentPath final directory holding the published model
     * @param modelUpdateTopic topic to publish to
     */
    public void publishAdditionalModelData(JavaSparkContext sparkContext,
                                           PMML pmml,
                                           JavaRDD<M> newData,
                                           JavaRDD<M> pastData,
                                           Path modelParentPath,
                                           TopicProducer<String, String> modelUpdateTopic) {
        // no-op by default
    }

    /**
     * Evaluates one candidate model.
     *
     * @param sparkContext active Spark context
     * @param model the candidate model
     * @param modelParentPath the candidate's directory
     * @param testData held-out test data (never empty when called)
     * @param trainData data the model was trained on
     * @return the evaluation score, where larger is better. {@code NaN} means no evaluation was
     *  possible.
     */
    public abstract double evaluate(JavaSparkContext sparkContext,
                                    PMML model,
                                    Path modelParentPath,
                                    JavaRDD<M> testData,
                                    JavaRDD<M> trainData);

    /**
     * Builds and evaluates candidate models, keeps the best one under a new timestamped directory
     * of {@code modelDirString}, and publishes it to {@code modelUpdateTopic} if one is given.
     * Input RDDs are cached for the duration of the call and unpersisted even if it fails.
     *
     * @param sparkContext active Spark context
     * @param timestamp unused by this implementation
     * @param newKeyMessageData new input data; must not be {@code null}
     * @param pastKeyMessageData previous input data, or {@code null}
     * @param modelDirString directory under which models are stored
     * @param modelUpdateTopic topic to publish to, or {@code null} to skip publishing
     * @throws IOException if the model directory cannot be accessed or the model cannot be read
     * @throws InterruptedException declared by the interface contract
     */
    @Override
    public void runUpdate(JavaSparkContext sparkContext,
                          long timestamp,
                          JavaPairRDD<Object, M> newKeyMessageData,
                          JavaPairRDD<Object, M> pastKeyMessageData,
                          String modelDirString,
                          TopicProducer<String, String> modelUpdateTopic)
            throws IOException, InterruptedException {
        Objects.requireNonNull(newKeyMessageData);

        JavaRDD<M> newData = newKeyMessageData.values();
        JavaRDD<M> pastData = pastKeyMessageData == null ? null : pastKeyMessageData.values();

        cacheEagerly(newData);
        cacheEagerly(pastData);
        try {
            List<List<?>> hyperParameterCombos = HyperParams.chooseHyperParameterCombos(
                    getHyperParameterValues(), hyperParamSearch, candidates);

            Path modelDir = new Path(modelDirString);
            Path candidatesPath = new Path(new Path(modelDir, ".temporary"),
                                           Long.toString(System.currentTimeMillis()));
            FileSystem fs = FileSystem.get(modelDir.toUri(), sparkContext.hadoopConfiguration());
            fs.mkdirs(candidatesPath);

            Path bestCandidatePath = findBestCandidatePath(
                    sparkContext, newData, pastData, hyperParameterCombos, candidatesPath);

            Path finalPath = new Path(modelDir, Long.toString(System.currentTimeMillis()));
            if (bestCandidatePath == null) {
                log.info("Unable to build any model");
            } else if (!fs.rename(bestCandidatePath, finalPath)) {
                // Nothing will be published below because finalPath won't contain a model.
                log.warn("Unable to move best model from {} to {}", bestCandidatePath, finalPath);
            }
            // Discard all remaining candidates.
            fs.delete(candidatesPath, true);

            if (modelUpdateTopic == null) {
                log.info("No update topic configured, not publishing models to a topic");
            } else {
                publishModel(sparkContext, fs, finalPath, newData, pastData, modelUpdateTopic);
            }
        } finally {
            unpersistIfPresent(newData);
            unpersistIfPresent(pastData);
        }
    }

    /**
     * Caches {@code rdd} (if non-null) and runs a no-op action over it, so the cache is populated
     * now rather than lazily by the first candidate build.
     * NOTE(review): presumably done to avoid concurrent candidate builds materializing it at once
     * — confirm.
     */
    private static <T> void cacheEagerly(JavaRDD<T> rdd) {
        if (rdd != null) {
            rdd.cache();
            rdd.foreachPartition(partition -> { });
        }
    }

    /** Unpersists {@code rdd} if it is non-null. */
    private static void unpersistIfPresent(JavaRDD<?> rdd) {
        if (rdd != null) {
            rdd.unpersist();
        }
    }

    /**
     * Publishes the model file in {@code modelDir}, if it exists.
     *
     * <p>If the file is at most {@code maxMessageSize} bytes, its content is sent inline as a
     * "MODEL" message. Otherwise a "MODEL-REF" message carrying its fully-qualified path is sent.
     * Additional model data is published afterwards when the subclass supports it.
     */
    private void publishModel(JavaSparkContext sparkContext,
                              FileSystem fs,
                              Path modelDir,
                              JavaRDD<M> newData,
                              JavaRDD<M> pastData,
                              TopicProducer<String, String> modelUpdateTopic) throws IOException {
        Path modelPath = new Path(modelDir, MODEL_FILE_NAME);
        if (!fs.exists(modelPath)) {
            return;
        }
        FileStatus modelStatus = fs.getFileStatus(modelPath);
        boolean modelNeededForUpdates = canPublishAdditionalModelData();
        boolean modelNotTooLarge = modelStatus.getLen() <= maxMessageSize;

        // The model is read only when it is needed: to send inline, or for additional data.
        PMML model = null;
        if (modelNeededForUpdates || modelNotTooLarge) {
            try (FSDataInputStream in = fs.open(modelPath)) {
                model = PMMLUtils.read(in);
            }
        }

        if (modelNotTooLarge) {
            modelUpdateTopic.send("MODEL", PMMLUtils.toString(model));
        } else {
            modelUpdateTopic.send("MODEL-REF", fs.makeQualified(modelPath).toString());
        }

        if (modelNeededForUpdates) {
            publishAdditionalModelData(sparkContext, model, newData, pastData, modelDir, modelUpdateTopic);
        }
    }

    /**
     * Builds all candidates in parallel and picks the one with the highest evaluation.
     *
     * @return path of the chosen candidate directory, or {@code null} if none qualifies
     */
    private Path findBestCandidatePath(JavaSparkContext sparkContext,
                                       JavaRDD<M> newData,
                                       JavaRDD<M> pastData,
                                       List<List<?>> hyperParameterCombos,
                                       Path candidatesPath) throws IOException {
        Map<Path, Double> pathToEval = ExecUtils.collectInParallel(
                candidates,
                Math.min(evalParallelism, candidates),
                true,
                i -> buildAndEval(i, hyperParameterCombos, sparkContext, newData, pastData, candidatesPath),
                Collectors.toMap(Pair::getFirst, Pair::getSecond));

        FileSystem fs = null;
        Path bestCandidatePath = null;
        double bestEval = Double.NEGATIVE_INFINITY;
        for (Map.Entry<Path, Double> pathEval : pathToEval.entrySet()) {
            Path path = pathEval.getKey();
            if (path == null) {
                continue;
            }
            if (fs == null) {
                fs = FileSystem.get(path.toUri(), sparkContext.hadoopConfiguration());
            }
            if (!fs.exists(path)) {
                continue; // no model was written for this candidate
            }
            double eval = pathEval.getValue();
            if (Double.isNaN(eval)) {
                // Not evaluable. With eval disabled (test fraction 0) the constructor limits
                // candidates to one, so keep that model if nothing else was chosen.
                if (bestCandidatePath == null && testFraction == 0.0) {
                    bestCandidatePath = path;
                }
            } else if (eval > bestEval) {
                log.info("Best eval / model path is now {} / {}", eval, path);
                bestEval = eval;
                bestCandidatePath = path;
            }
        }
        // NOTE(review): when no candidate was evaluated (e.g. test fraction 0), bestEval is still
        // -Infinity, so a configured threshold always discards the model. Confirm this is intended.
        if (threshold != null && bestEval < threshold) {
            log.info("Best model at {} had eval {}, but did not exceed threshold {}; discarding model",
                     bestCandidatePath, bestEval, threshold);
            bestCandidatePath = null;
        }
        return bestCandidatePath;
    }

    /**
     * Builds candidate {@code i}, writes it under {@code candidatesPath/i}, and evaluates it.
     *
     * @return the candidate's directory paired with its eval. The eval is {@code NaN} if the model
     *  was not built or not evaluated.
     * @throws IllegalStateException wrapping any I/O failure while writing the model
     */
    private Pair<Path, Double> buildAndEval(int i,
                                            List<List<?>> hyperParameterCombos,
                                            JavaSparkContext sparkContext,
                                            JavaRDD<M> newData,
                                            JavaRDD<M> pastData,
                                            Path candidatesPath) {
        // Cycle through the combinations if there are more candidates than combinations.
        List<?> hyperParameters = hyperParameterCombos.get(i % hyperParameterCombos.size());
        Path candidatePath = new Path(candidatesPath, Integer.toString(i));
        log.info("Building candidate {} with params {}", i, hyperParameters);

        Pair<JavaRDD<M>, JavaRDD<M>> trainTestData = splitTrainTest(newData, pastData);
        JavaRDD<M> trainData = trainTestData.getFirst();
        JavaRDD<M> testData = trainTestData.getSecond();

        double eval = Double.NaN;
        if (empty(trainData)) {
            log.info("No train data to build a model");
        } else {
            PMML model = buildModel(sparkContext, trainData, hyperParameters, candidatePath);
            if (model == null) {
                log.info("Unable to build a model");
            } else {
                writeModel(sparkContext, model, candidatePath);
                if (empty(testData)) {
                    log.info("No test data available to evaluate model");
                } else {
                    log.info("Evaluating model");
                    eval = evaluate(sparkContext, model, candidatePath, testData, trainData);
                }
            }
        }

        log.info("Model eval for params {}: {} ({})", hyperParameters, eval, candidatePath);
        return new Pair<>(candidatePath, eval);
    }

    /**
     * Writes {@code model} as {@code candidatePath/model.pmml}. The output stream is always
     * closed, even if writing fails.
     *
     * @throws IllegalStateException wrapping any {@link IOException}
     */
    private static void writeModel(JavaSparkContext sparkContext, PMML model, Path candidatePath) {
        Path modelPath = new Path(candidatePath, MODEL_FILE_NAME);
        log.info("Writing model to {}", modelPath);
        try {
            FileSystem fs = FileSystem.get(candidatePath.toUri(), sparkContext.hadoopConfiguration());
            fs.mkdirs(candidatePath);
            try (FSDataOutputStream out = fs.create(modelPath)) {
                PMMLUtils.write(model, out);
            }
        } catch (IOException ioe) {
            throw new IllegalStateException(ioe);
        }
    }

    /**
     * Splits the input into (train, test) data according to {@code testFraction}:
     * <ul>
     *   <li>{@code <= 0}: train on new and past data, no test data;</li>
     *   <li>{@code >= 1}: train on past data, test on new data;</li>
     *   <li>otherwise: randomly split the new data, and add past data to the train side.</li>
     * </ul>
     * Either side may be {@code null}.
     */
    private Pair<JavaRDD<M>, JavaRDD<M>> splitTrainTest(JavaRDD<M> newData, JavaRDD<M> pastData) {
        Objects.requireNonNull(newData);
        if (testFraction <= 0.0) {
            return new Pair<>(pastData == null ? newData : newData.union(pastData), null);
        }
        if (testFraction >= 1.0) {
            return new Pair<>(pastData, newData);
        }
        if (empty(newData)) {
            return new Pair<>(pastData, null);
        }
        Pair<JavaRDD<M>, JavaRDD<M>> newTrainTest = splitNewDataToTrainTest(newData);
        JavaRDD<M> newTrainData = newTrainTest.getFirst();
        return new Pair<>(pastData == null ? newTrainData : newTrainData.union(pastData),
                          newTrainTest.getSecond());
    }

    /** @return true if {@code rdd} is {@code null} or has no elements */
    private static boolean empty(JavaRDD<?> rdd) {
        return rdd == null || rdd.isEmpty();
    }

    /**
     * Randomly splits new data into (train, test) with weights {@code 1 - testFraction} and
     * {@code testFraction}. The seed comes from {@link RandomManager}.
     *
     * @param newData data to split
     * @return (train, test) pair
     */
    protected Pair<JavaRDD<M>, JavaRDD<M>> splitNewDataToTrainTest(JavaRDD<M> newData) {
        RDD<M>[] trainTestRDDs = newData.rdd().randomSplit(
                new double[] {1.0 - testFraction, testFraction},
                RandomManager.getRandom().nextLong());
        return new Pair<>(newData.wrapRDD(trainTestRDDs[0]), newData.wrapRDD(trainTestRDDs[1]));
    }
}
