package com.cloudera.oryx.app.batch.mllib.kmeans;

import com.cloudera.oryx.app.common.fn.MLFunctions;
import com.cloudera.oryx.app.kmeans.KMeansPMMLUtils;
import com.cloudera.oryx.app.kmeans.KMeansUtils;
import com.cloudera.oryx.app.pmml.AppPMMLUtils;
import com.cloudera.oryx.app.schema.CategoricalValueEncodings;
import com.cloudera.oryx.app.schema.InputSchema;
import com.cloudera.oryx.common.pmml.PMMLUtils;
import com.cloudera.oryx.ml.MLUpdate;
import com.cloudera.oryx.ml.param.HyperParamValues;
import com.cloudera.oryx.ml.param.HyperParams;
import com.google.common.base.Preconditions;
import com.typesafe.config.Config;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.hadoop.fs.Path;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.clustering.KMeans;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.dmg.pmml.ComparisonMeasure;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.dmg.pmml.SquaredEuclidean;
import org.dmg.pmml.clustering.Cluster;
import org.dmg.pmml.clustering.ClusteringField;
import org.dmg.pmml.clustering.ClusteringModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link MLUpdate} implementation over {@code String} input that trains a Spark MLlib
 * {@link KMeansModel}, exports it as a PMML {@link ClusteringModel}, and scores a PMML model
 * with the cluster-quality metric selected by {@code oryx.kmeans.evaluation-strategy}.
 *
 * <p>Only schemas without a target and without categorical features are accepted; this is
 * enforced in the constructor.
 *
 * <p>The lambda in {@link #parsedToVectorRDD(JavaRDD)} reads {@code this.inputSchema}, so it
 * captures this instance. The compiler generates the lambda deserialization hook
 * ({@code $deserializeLambda$}) on its own; it must not be written by hand.
 */
public final class KMeansUpdate extends MLUpdate<String> {
    private static final Logger log = LoggerFactory.getLogger(KMeansUpdate.class);

    private final String initializationStrategy;
    private final int maxIterations;
    private final int numberOfRuns;
    private final List<HyperParamValues<?>> hyperParamValues;
    private final InputSchema inputSchema;
    private final KMeansEvalStrategy evaluationStrategy;

    /**
     * Reads the k-means settings from configuration and validates them.
     *
     * @param config configuration supplying the {@code oryx.kmeans.*} keys read below, plus
     *     whatever {@link InputSchema} and the superclass read from it
     * @throws IllegalArgumentException if iterations or runs are not positive, the initialization
     *     strategy is not one of the two MLlib strategies, the evaluation strategy is not a
     *     {@code KMeansEvalStrategy} constant, the schema has a target, or any feature is
     *     categorical
     */
    public KMeansUpdate(Config config) {
        super(config);
        this.initializationStrategy = config.getString("oryx.kmeans.initialization-strategy");
        this.evaluationStrategy =
                KMeansEvalStrategy.valueOf(config.getString("oryx.kmeans.evaluation-strategy"));
        this.numberOfRuns = config.getInt("oryx.kmeans.runs");
        this.maxIterations = config.getInt("oryx.kmeans.iterations");
        this.hyperParamValues = new ArrayList<>();
        this.hyperParamValues.add(HyperParams.fromConfig(config, "oryx.kmeans.hyperparams.k"));
        this.inputSchema = new InputSchema(config);

        Preconditions.checkArgument(this.maxIterations > 0,
                "oryx.kmeans.iterations must be positive: %s", this.maxIterations);
        Preconditions.checkArgument(this.numberOfRuns > 0,
                "oryx.kmeans.runs must be positive: %s", this.numberOfRuns);
        Preconditions.checkArgument(
                this.initializationStrategy.equals(KMeans.K_MEANS_PARALLEL())
                        || this.initializationStrategy.equals(KMeans.RANDOM()),
                "Unsupported oryx.kmeans.initialization-strategy: %s", this.initializationStrategy);
        Preconditions.checkArgument(!this.inputSchema.hasTarget(),
                "k-means input schema must not declare a target feature");
        for (int feature = 0; feature < this.inputSchema.getNumFeatures(); feature++) {
            Preconditions.checkArgument(!this.inputSchema.isCategorical(feature),
                    "k-means does not support categorical features; feature index %s", feature);
        }
    }

    /**
     * @return the hyperparameter value ranges; a single entry, built from
     *     {@code oryx.kmeans.hyperparams.k}
     */
    public List<HyperParamValues<?>> getHyperParameterValues() {
        return this.hyperParamValues;
    }

    /**
     * Trains a k-means model on the given lines and converts it to PMML.
     *
     * @param javaSparkContext not used by this implementation
     * @param javaRDD input lines; each is parsed with {@code MLFunctions.PARSE_FN} and turned
     *     into a dense feature vector
     * @param list hyperparameter values; element 0 must be an {@code Integer} cluster count
     * @param path not used by this implementation
     * @return PMML document holding one clustering model with per-cluster centers and sizes
     * @throws IllegalArgumentException if the cluster count is not greater than 1
     */
    public PMML buildModel(JavaSparkContext javaSparkContext, JavaRDD<String> javaRDD, List<?> list, Path path) {
        int numClusters = (Integer) list.get(0);
        Preconditions.checkArgument(numClusters > 1);
        log.info("Building KMeans Model with {} clusters", numClusters);

        JavaRDD<Vector> trainingVectors = parsedToVectorRDD(javaRDD.map(MLFunctions.PARSE_FN));
        KMeansModel model = KMeans.train(trainingVectors.rdd(), numClusters, this.maxIterations,
                this.numberOfRuns, this.initializationStrategy);
        return kMeansModelToPMML(model, fetchClusterCountsFromModel(trainingVectors, model));
    }

    /**
     * Counts how many of the given vectors the model assigns to each cluster.
     *
     * @return map from predicted cluster index to number of vectors; a cluster that no vector
     *     was assigned to has no entry
     */
    private static Map<Integer, Long> fetchClusterCountsFromModel(JavaRDD<Vector> javaRDD, KMeansModel kMeansModel) {
        return javaRDD.map(kMeansModel::predict).countByValue();
    }

    /**
     * Scores a PMML k-means model against the union of both input data sets, using the
     * configured evaluation strategy.
     *
     * @param javaSparkContext not used by this implementation
     * @param pmml model to evaluate; validated against the input schema first
     * @param path not used by this implementation
     * @param javaRDD first set of input lines
     * @param javaRDD2 second set of input lines; evaluation runs on {@code javaRDD2.union(javaRDD)}
     * @return the Dunn index or silhouette coefficient as computed, or the negation of the
     *     Davies-Bouldin index or sum of squared errors (presumably so that a larger return
     *     value always means a better model; confirm against the caller in {@code MLUpdate})
     * @throws IllegalArgumentException if the evaluation strategy is not handled here
     */
    public double evaluate(JavaSparkContext javaSparkContext, PMML pmml, Path path, JavaRDD<String> javaRDD, JavaRDD<String> javaRDD2) {
        KMeansPMMLUtils.validatePMMLVsSchema(pmml, this.inputSchema);
        JavaRDD<Vector> evalVectors = parsedToVectorRDD(javaRDD2.union(javaRDD).map(MLFunctions.PARSE_FN));
        // Left raw on purpose: the element type returned by KMeansPMMLUtils.read is declared
        // outside this file and is passed straight through to the metric constructors.
        List clusters = KMeansPMMLUtils.read(pmml);
        log.info("Evaluation Strategy is {}", this.evaluationStrategy);

        switch (this.evaluationStrategy) {
            case DAVIES_BOULDIN: {
                double daviesBouldin = new DaviesBouldinIndex(clusters).evaluate(evalVectors);
                log.info("Davies-Bouldin index: {}", daviesBouldin);
                return -daviesBouldin;
            }
            case DUNN: {
                double dunn = new DunnIndex(clusters).evaluate(evalVectors);
                log.info("Dunn index: {}", dunn);
                return dunn;
            }
            case SILHOUETTE: {
                double silhouette = new SilhouetteCoefficient(clusters).evaluate(evalVectors);
                log.info("Silhouette Coefficient: {}", silhouette);
                return silhouette;
            }
            case SSE: {
                double sumSquaredError = new SumSquaredError(clusters).evaluate(evalVectors);
                log.info("Sum squared error: {}", sumSquaredError);
                return -sumSquaredError;
            }
            default:
                throw new IllegalArgumentException("Unknown evaluation strategy " + this.evaluationStrategy);
        }
    }

    /**
     * Wraps the clustering model in a skeleton PMML document with a data dictionary built from
     * the input schema (no categorical encodings are supplied).
     */
    private PMML kMeansModelToPMML(KMeansModel kMeansModel, Map<Integer, Long> map) {
        ClusteringModel clusteringModel = pmmlClusteringModel(kMeansModel, map);
        PMML pmml = PMMLUtils.buildSkeletonPMML();
        // The cast selects the overload; a bare null could be ambiguous.
        pmml.setDataDictionary(AppPMMLUtils.buildDataDictionary(this.inputSchema, (CategoricalValueEncodings) null));
        pmml.addModels(clusteringModel);
        return pmml;
    }

    /**
     * Builds the PMML clustering model: one clustering field per active schema feature and one
     * cluster per model center, compared by squared Euclidean distance.
     *
     * @param kMeansModel trained model supplying the cluster centers
     * @param map cluster index to member count; a missing index is reported as size 0
     */
    private ClusteringModel pmmlClusteringModel(KMeansModel kMeansModel, Map<Integer, Long> map) {
        Vector[] centers = kMeansModel.clusterCenters();

        List<ClusteringField> clusteringFields = new ArrayList<>();
        for (int feature = 0; feature < this.inputSchema.getNumFeatures(); feature++) {
            if (this.inputSchema.isActive(feature)) {
                FieldName fieldName = FieldName.create(this.inputSchema.getFeatureNames().get(feature));
                clusteringFields.add(
                        new ClusteringField(fieldName).setCenterField(ClusteringField.CenterField.TRUE));
            }
        }

        List<Cluster> clusters = new ArrayList<>(centers.length);
        for (int id = 0; id < centers.length; id++) {
            // The counts come from countByValue() over predictions, so a center that attracted
            // no vectors has no entry; treat that as an empty cluster rather than unboxing null.
            long memberCount = map.getOrDefault(id, 0L);
            clusters.add(new Cluster()
                    .setId(Integer.toString(id))
                    .setSize((int) memberCount)
                    .setArray(AppPMMLUtils.toArray(centers[id].toArray())));
        }

        return new ClusteringModel(
                MiningFunction.CLUSTERING,
                ClusteringModel.ModelClass.CENTER_BASED,
                clusters.size(),
                AppPMMLUtils.buildMiningSchema(this.inputSchema),
                new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE).setMeasure(new SquaredEuclidean()),
                clusteringFields,
                clusters);
    }

    /**
     * Converts tokenized input rows to dense feature vectors according to the input schema.
     *
     * <p>A row that fails with {@link ArrayIndexOutOfBoundsException} or
     * {@link NumberFormatException} is logged at WARN and the exception is rethrown, so bad
     * input still fails the Spark task rather than being dropped.
     */
    private JavaRDD<Vector> parsedToVectorRDD(JavaRDD<String[]> javaRDD) {
        return javaRDD.map(tokens -> {
            try {
                return Vectors.dense(KMeansUtils.featuresFromTokens(tokens, this.inputSchema));
            } catch (ArrayIndexOutOfBoundsException | NumberFormatException e) {
                log.warn("Bad input: {}", Arrays.toString(tokens));
                throw e;
            }
        });
    }
}
