package edu.columbia.tjw.item.spark;

import java.util.Arrays;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.classification.ProbabilisticClassificationModel;
import org.apache.spark.ml.classification.ProbabilisticClassifier;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions;

/* loaded from: input_file:edu/columbia/tjw/item/spark/ClassificationModelEvaluator.class */
public class ClassificationModelEvaluator {

    /* loaded from: input_file:edu/columbia/tjw/item/spark/ClassificationModelEvaluator$EntropyResult.class */
    /**
     * Immutable summary of a single entropy calculation over one data set.
     *
     * <p>Instances only store the four values handed to the constructor; nothing is validated or
     * derived. Within this file they are produced by {@code computeEntropy}, which supplies the
     * elapsed wall-clock milliseconds, the aggregated row count, and the two per-row averages.
     */
    public static final class EntropyResult {
        /** Time spent producing this result (milliseconds when created by {@code computeEntropy}). */
        private final long _calcTime;
        /** Number of rows that were aggregated. */
        private final long _rowCount;
        /** Average cross entropy per row. */
        private final double _crossEntropy;
        /** Average entropy of the predicted distribution per row. */
        private final double _distEntropy;

        /**
         * Creates a result from already-computed values.
         *
         * @param calcTime     time taken by the calculation
         * @param rowCount     number of rows covered by the calculation
         * @param crossEntropy average cross entropy
         * @param distEntropy  average entropy of the predicted distribution
         */
        public EntropyResult(long calcTime, long rowCount, double crossEntropy, double distEntropy) {
            _calcTime = calcTime;
            _rowCount = rowCount;
            _crossEntropy = crossEntropy;
            _distEntropy = distEntropy;
        }

        /** @return the calculation time given at construction */
        public long getCalcTime() {
            return _calcTime;
        }

        /** @return the row count given at construction */
        public long getRowCount() {
            return _rowCount;
        }

        /** @return the cross entropy given at construction */
        public double getCrossEntropy() {
            return _crossEntropy;
        }

        /** @return the distribution entropy given at construction */
        public double getDistEntropy() {
            return _distEntropy;
        }
    }

    /* loaded from: input_file:edu/columbia/tjw/item/spark/ClassificationModelEvaluator$EvaluationResult.class */
    /**
     * Immutable record of one fit-and-score run: the fitted model, how long fitting took, and the
     * entropy results for the fitting and testing data sets, together with the descriptive values
     * (label, seed, layer description) supplied by the caller.
     *
     * <p>No argument is validated or copied; every getter returns exactly what was passed in.
     */
    public static final class EvaluationResult {
        private final String _label;
        private final long _prngSeed;
        private final String _layerString;
        private final MultilayerPerceptronClassificationModel _model;
        private final long _fittingTime;
        private final EntropyResult _fittingEntropy;
        private final EntropyResult _testingEntropy;

        /**
         * Stores the given values without modification.
         *
         * @param label          caller-chosen label for this run
         * @param prngSeed       seed value associated with this run
         * @param layerString    textual description of the layer layout
         * @param model          the fitted model
         * @param fittingTime    time taken to fit the model
         * @param fittingEntropy entropy result computed on the fitting data
         * @param testingEntropy entropy result computed on the testing data
         */
        public EvaluationResult(String label, long prngSeed, String layerString, MultilayerPerceptronClassificationModel model, long fittingTime, EntropyResult fittingEntropy, EntropyResult testingEntropy) {
            _label = label;
            _prngSeed = prngSeed;
            _layerString = layerString;
            _model = model;
            _fittingTime = fittingTime;
            _fittingEntropy = fittingEntropy;
            _testingEntropy = testingEntropy;
        }

        /** @return the label given at construction */
        public String getLabel() {
            return _label;
        }

        /** @return the fitted model given at construction */
        public MultilayerPerceptronClassificationModel getModel() {
            return _model;
        }

        /** @return the fitting time given at construction */
        public long getFittingTime() {
            return _fittingTime;
        }

        /** @return the entropy result for the fitting data */
        public EntropyResult getFittingEntropy() {
            return _fittingEntropy;
        }

        /** @return the entropy result for the testing data */
        public EntropyResult getTestingEntropy() {
            return _testingEntropy;
        }

        /**
         * Reports the number of entries in the model's weight vector.
         *
         * @return {@code weights().size()} of the stored model
         */
        public int getParamCount() {
            return _model.weights().size();
        }

        /** @return the seed given at construction */
        public long getPrngSeed() {
            return _prngSeed;
        }

        /** @return the layer description given at construction */
        public String getLayerString() {
            return _layerString;
        }
    }

    /**
     * Fits the classifier on the fitting data, then computes entropy statistics for the fitted
     * model on both the fitting data and the testing data.
     *
     * <p>The reported fitting time is the wall-clock time, in milliseconds, spent inside
     * {@code fit} only; the two entropy passes are timed separately and reported in their own
     * {@link EntropyResult}.
     *
     * @param classifier  classifier to fit; the model it produces must be a
     *                    {@link MultilayerPerceptronClassificationModel}, because that is the only
     *                    model type {@link EvaluationResult} can store
     * @param label       label copied into the result
     * @param fittingData data used to fit the model and to compute the fitting entropy
     * @param testingData data used to compute the testing entropy
     * @param prngSeed    seed copied into the result; this method does not apply it to the classifier
     * @param layers      layer sizes; only rendered with {@link Arrays#toString(int[])} for the
     *                    result's layer string, not applied to the classifier here
     * @param <W>         model type produced by the classifier
     * @param <M>         classifier type
     * @return the fitted model together with its timing and entropy results
     * @throws ClassCastException if the fitted model is not a
     *                            {@link MultilayerPerceptronClassificationModel}
     */
    public static <W extends ProbabilisticClassificationModel<Vector, W>, M extends ProbabilisticClassifier<Vector, M, W>> EvaluationResult evaluate(M classifier, String label, Dataset<Row> fittingData, Dataset<Row> testingData, long prngSeed, int[] layers) {
        long fitStart = System.currentTimeMillis();
        W fitted = classifier.fit(fittingData);
        long fittingTime = System.currentTimeMillis() - fitStart;

        // fit() is only known to return W, so the narrowing to the concrete model type has to be
        // explicit. Class.cast keeps the ClassCastException behaviour of a plain cast.
        MultilayerPerceptronClassificationModel model = MultilayerPerceptronClassificationModel.class.cast(fitted);

        EntropyResult fittingEntropy = computeEntropy(fittingData, model);
        EntropyResult testingEntropy = computeEntropy(testingData, model);
        return new EvaluationResult(label, prngSeed, Arrays.toString(layers), model, fittingTime, fittingEntropy, testingEntropy);
    }

    /**
     * Applies the model to the data set and aggregates two per-row averages in Spark SQL:
     * the cross entropy {@code -log(p[next_status])} and the entropy of the predicted
     * distribution {@code -sum(p_i * log(p_i))}.
     *
     * <p>Requirements visible from the expressions used here:
     * <ul>
     *   <li>the transformed data must contain a {@code probability} column;</li>
     *   <li>a SQL function named {@code toArrayLambda} must be callable in the session (it is not
     *       defined in this file; presumably it converts the probability vector to an array,
     *       TODO confirm);</li>
     *   <li>the data must contain a {@code next_status} column usable as an index into the
     *       probability array;</li>
     *   <li>exactly the first three probability entries (indices 0, 1, 2) enter the distribution
     *       entropy, so the calculation is hard-wired to three classes.</li>
     * </ul>
     *
     * <p>NOTE(review): Spark SQL's {@code log} is understood to yield NULL for non-positive input,
     * in which case rows with a zero probability would drop out of {@code sum(...)} while still
     * being counted by {@code count(*)}. Confirm whether that skew matters for the models used.
     *
     * @param dataset             rows to score
     * @param classificationModel fitted model used to produce the {@code probability} column
     * @return elapsed wall-clock milliseconds, row count, and the two averages; an average is
     *         {@code NaN} when Spark reports it as NULL (for example, for an empty data set)
     */
    private static EntropyResult computeEntropy(Dataset<Row> dataset, ClassificationModel<?, ?> classificationModel) {
        long start = System.currentTimeMillis();

        Dataset<Row> scored = classificationModel.transform(dataset)
                .withColumn("prob_array", functions.expr("toArrayLambda(probability)"))
                .withColumn("prob_p", functions.expr("prob_array[0]"))
                .withColumn("prob_c", functions.expr("prob_array[1]"))
                .withColumn("prob_3", functions.expr("prob_array[2]"))
                .withColumn("distEntropy", functions.expr("-1.0 * ((prob_p*log(prob_p))  + (prob_c*log(prob_c)) + (prob_3*log(prob_3)))"))
                .withColumn("crossEntropy", functions.expr(" -1.0 * log(prob_array[next_status])"));

        // A global aggregate always yields exactly one row, even for an empty input.
        Row totals = scored.select(
                functions.expr("count(*)"),
                functions.expr("sum(crossEntropy)/count(*)"),
                functions.expr("sum(distEntropy)/count(*)"))
                .toLocalIterator()
                .next();

        long elapsed = System.currentTimeMillis() - start;
        return new EntropyResult(elapsed, totals.getLong(0), doubleOrNaN(totals, 1), doubleOrNaN(totals, 2));
    }

    /**
     * Reads a double column from an aggregate row, mapping SQL NULL to {@code NaN} instead of
     * letting {@code Row.getDouble} fail on the missing value.
     */
    private static double doubleOrNaN(Row row, int index) {
        return row.isNullAt(index) ? Double.NaN : row.getDouble(index);
    }
}
