package edu.columbia.tjw.item.spark;

import edu.columbia.tjw.item.ItemParameters;
import edu.columbia.tjw.item.ItemSettings;
import edu.columbia.tjw.item.base.SimpleRegressor;
import edu.columbia.tjw.item.base.SimpleStatus;
import edu.columbia.tjw.item.base.StandardCurveType;
import edu.columbia.tjw.item.base.raw.RawFittingGrid;
import edu.columbia.tjw.item.data.ItemStatusGrid;
import edu.columbia.tjw.item.fit.FitResult;
import edu.columbia.tjw.item.fit.GradientResult;
import edu.columbia.tjw.item.fit.ItemFitter;
import edu.columbia.tjw.item.optimize.ConvergenceException;
import edu.columbia.tjw.item.util.random.RandomTool;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.classification.ProbabilisticClassifier;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/* loaded from: input_file:edu/columbia/tjw/item/spark/ItemClassifier.class */
/**
 * Spark ML {@link ProbabilisticClassifier} that fits its models with an {@link ItemFitter} and wraps the
 * resulting parameters in an {@link ItemClassificationModel}.
 *
 * <p>The static helpers {@link #prepareSettings} and {@link #prepareData} build the settings object and the
 * assembled feature column from a raw dataset.
 */
public class ItemClassifier extends ProbabilisticClassifier<Vector, ItemClassifier, ItemClassificationModel> implements Cloneable {
    private static final long serialVersionUID = 8990051165227355116L;

    /** Settings shared by every fitter and model this classifier creates; never null. */
    private final ItemClassifierSettings _settings;

    /** Optional parameters pushed into the fitter before {@link #train}; may be null. */
    private final ItemParameters<SimpleStatus, SimpleRegressor, StandardCurveType> _startingParams;

    /** Lazily generated identifier, see {@link #uid()}. */
    private String _uid;

    /**
     * Creates a classifier with no starting parameters.
     *
     * @param itemClassifierSettings the settings to use, must not be null
     * @throws NullPointerException if {@code itemClassifierSettings} is null
     */
    public ItemClassifier(ItemClassifierSettings itemClassifierSettings) {
        this(itemClassifierSettings, null);
    }

    /**
     * Creates a classifier that optionally starts fitting from existing parameters.
     *
     * @param itemClassifierSettings the settings to use, must not be null
     * @param itemParameters         parameters pushed into the fitter at the start of {@link #train}, or null
     * @throws NullPointerException if {@code itemClassifierSettings} is null
     */
    public ItemClassifier(ItemClassifierSettings itemClassifierSettings, ItemParameters<SimpleStatus, SimpleRegressor, StandardCurveType> itemParameters) {
        if (null == itemClassifierSettings) {
            throw new NullPointerException("Settings cannot be null.");
        }
        this._settings = itemClassifierSettings;
        this._startingParams = itemParameters;
    }

    /* Decompiler note: renamed from "copy" (merged with its bridge methods); the name is kept so that other
     * decompiled call sites continue to resolve. */
    public ItemClassifier m8copy(ParamMap paramMap) {
        return defaultCopy(paramMap);
    }

    /**
     * @return the settings this classifier was constructed with
     */
    public ItemClassifierSettings getSettings() {
        return this._settings;
    }

    /**
     * Builds a fitter for the dataset and constructs a {@link RawFittingGrid} from that fitter's grid.
     *
     * @param dataset the data to build the grid from
     * @return a new {@code RawFittingGrid} created from the fitter's grid
     */
    public RawFittingGrid<SimpleStatus, SimpleRegressor> generateMaterializedGrid(Dataset<?> dataset) {
        return new RawFittingGrid<>(generateFitter(dataset).getGrid());
    }

    /** Adapts the dataset to a status grid using this classifier's label/features columns and settings. */
    private ItemStatusGrid<SimpleStatus, SimpleRegressor> generateGrid(Dataset<?> dataset) {
        return new SparkGridAdapter(
                dataset,
                getLabelCol(),
                getFeaturesCol(),
                this._settings.getRegressors(),
                this._settings.getFromStatus(),
                this._settings.getRegressorFamily());
    }

    /** Creates a fresh fitter over the dataset, configured from this classifier's settings. */
    private ItemFitter<SimpleStatus, SimpleRegressor, StandardCurveType> generateFitter(Dataset<?> dataset) {
        return new ItemFitter<>(
                this._settings.getFactory(),
                this._settings.getRegressorFamily(),
                this._settings.getFromStatus(),
                generateGrid(dataset),
                this._settings.getSettings());
    }

    /**
     * Assembles the regressor columns named in the settings into a single vector column.
     *
     * @param dataset                the input data; must contain one column per regressor name
     * @param itemClassifierSettings supplies the regressors, whose names are used as input columns (in order)
     * @param str                    name of the output (features) column
     * @return the dataset transformed by a {@link VectorAssembler} configured with those columns
     */
    public static Dataset<Row> prepareData(Dataset<?> dataset, ItemClassifierSettings itemClassifierSettings, String str) {
        final String[] inputColumns = itemClassifierSettings.getRegressors().stream()
                .map(SimpleRegressor::name)
                .toArray(String[]::new);

        final VectorAssembler assembler = new VectorAssembler();
        assembler.setInputCols(inputColumns);
        assembler.setOutputCol(str);
        return assembler.transform(dataset);
    }

    /**
     * Same as {@link #prepareSettings(Dataset, String, List, Set, int, ItemSettings)} with a default-constructed
     * {@link ItemSettings}.
     */
    public static ItemClassifierSettings prepareSettings(Dataset<?> dataset, String str, List<String> list, Set<String> set, int i) {
        return prepareSettings(dataset, str, list, set, i, new ItemSettings());
    }

    /**
     * Builds classifier settings from the labels found in a dataset.
     *
     * <p>The distinct non-null values of the label column are read (as integers), a status family is generated
     * from their string forms in ascending numeric order, and one of them is selected as the "from" status.
     *
     * @param dataset      the data whose label column is scanned
     * @param str          name of the label column; its values must be {@link Number}s
     * @param list         feature names; must be distinct
     * @param set          curve regressor names; every one must also appear in {@code list}
     * @param i            forwarded to the {@link ItemClassifierSettings} constructor
     *                     (presumably the maximum parameter count -- confirm against that class)
     * @param itemSettings fitter settings forwarded to the {@link ItemClassifierSettings} constructor
     * @return the assembled settings
     * @throws IllegalArgumentException if the label column contains no non-null values
     * @throws RuntimeException         if the features are not distinct, or a curve regressor is not a feature
     */
    public static ItemClassifierSettings prepareSettings(Dataset<?> dataset, String str, List<String> list, Set<String> set, int i, ItemSettings itemSettings) {
        final Iterator<Row> labelRows = dataset.select(str).distinct().toLocalIterator();

        // TreeMap: the ascending key order below defines the ordinal of each generated status.
        final Map<Integer, Integer> labelCounts = new TreeMap<>();
        while (labelRows.hasNext()) {
            final Object rawLabel = labelRows.next().get(0);
            if (null != rawLabel) {
                labelCounts.merge(((Number) rawLabel).intValue(), 1, Integer::sum);
            }
        }
        if (labelCounts.isEmpty()) {
            throw new IllegalArgumentException("No non-null labels found in column: " + str);
        }

        // NOTE(review): the rows were made distinct above, so every count here is 1 and the strict '>' below
        // always selects the smallest label. If the most frequent label was intended, a grouped count is needed.
        int chosenLabel = -1;
        int chosenCount = 0;
        for (Map.Entry<Integer, Integer> entry : labelCounts.entrySet()) {
            final int count = entry.getValue();
            if (count > chosenCount) {
                chosenLabel = entry.getKey();
                chosenCount = count;
            }
        }

        final List<String> statusNames = new ArrayList<>(labelCounts.size());
        int fromOrdinal = -1;
        for (Integer label : labelCounts.keySet()) {
            if (label.intValue() == chosenLabel) {
                // Record the position this name is about to occupy. Taking size() after the add (as before)
                // pointed one past the chosen status.
                fromOrdinal = statusNames.size();
            }
            statusNames.add(label.toString());
        }
        final SimpleStatus fromStatus = SimpleStatus.generateFamily(statusNames).getFromOrdinal(fromOrdinal);

        final Set<String> distinctFeatures = new HashSet<>(list);
        if (distinctFeatures.size() != list.size()) {
            throw new RuntimeException("Non distinct features: " + list.size());
        }
        if (!distinctFeatures.containsAll(set)) {
            throw new RuntimeException("All curve regressors must also be in the feature list.");
        }
        return new ItemClassifierSettings(itemSettings, fromStatus, i, list, set);
    }

    /**
     * Computes gradients for an existing model's parameters against the given data.
     *
     * @param dataset                 the data to evaluate on
     * @param itemClassificationModel the model whose parameters are evaluated
     * @return the gradient result produced by the fitter's calculator
     */
    public GradientResult computeGradients(Dataset<?> dataset, ItemClassificationModel itemClassificationModel) {
        return generateFitter(dataset).getCalculator().computeGradients(itemClassificationModel.getParams());
    }

    /**
     * Computes a fit result for an existing model's parameters against the given data, with no previous result.
     *
     * @param dataset                 the data to evaluate on
     * @param itemClassificationModel the model whose parameters are evaluated
     * @return the fit result produced by the fitter's calculator
     */
    public FitResult<SimpleStatus, SimpleRegressor, StandardCurveType> computeFitResult(Dataset<?> dataset, ItemClassificationModel itemClassificationModel) {
        return generateFitter(dataset).getCalculator().computeFitResult(itemClassificationModel.getParams(), (FitResult) null);
    }

    /**
     * Runs annealing by entry over the curve regressors, starting from an existing model's parameters.
     *
     * @param dataset                 the data to fit against
     * @param itemClassificationModel the model supplying the starting parameters
     * @return a new model built from the fitter's latest results
     * @throws RuntimeException wrapping a {@link ConvergenceException} raised by the fitter
     */
    public ItemClassificationModel runAnnealing(Dataset<?> dataset, ItemClassificationModel itemClassificationModel) {
        final ItemFitter<SimpleStatus, SimpleRegressor, StandardCurveType> fitter = generateFitter(dataset);
        try {
            fitter.pushParameters("PrevModel", itemClassificationModel.getParams());
            fitter.runAnnealingByEntry(this._settings.getCurveRegressors(), true);
            return new ItemClassificationModel(fitter.getChain().getLatestResults(), this._settings);
        } catch (ConvergenceException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Re-fits a model on the given data, starting from an existing model's parameters.
     *
     * @param dataset                 the data to fit against
     * @param itemClassificationModel the model supplying the starting parameters
     * @return a new model built from the fitter's latest results
     * @throws RuntimeException wrapping a {@link ConvergenceException} raised by the fitter
     */
    public ItemClassificationModel retrainModel(Dataset<?> dataset, ItemClassificationModel itemClassificationModel) {
        final ItemFitter<SimpleStatus, SimpleRegressor, StandardCurveType> fitter = generateFitter(dataset);
        fitter.pushParameters("PrevModel", itemClassificationModel.getParams());
        return fitAndWrap(fitter);
    }

    /**
     * Fits a model on the given data, starting from the constructor-supplied parameters when present.
     *
     * @param dataset the data to fit against
     * @return a new model built from the fitter's latest results
     * @throws RuntimeException wrapping a {@link ConvergenceException} raised by the fitter
     */
    public ItemClassificationModel train(Dataset<?> dataset) {
        final ItemFitter<SimpleStatus, SimpleRegressor, StandardCurveType> fitter = generateFitter(dataset);
        if (null != this._startingParams) {
            fitter.pushParameters("InitialParams", this._startingParams);
        }
        return fitAndWrap(fitter);
    }

    /**
     * Runs {@code fitModel} on an already-seeded fitter and wraps its latest results in a model.
     *
     * <p>The parameter budget passed to the fitter is the settings' max parameter count minus the effective
     * parameter count of the fitter's current best parameters.
     */
    private ItemClassificationModel fitAndWrap(ItemFitter<SimpleStatus, SimpleRegressor, StandardCurveType> fitter) {
        try {
            final int remainingParams = this._settings.getMaxParamCount() - fitter.getBestParameters().getEffectiveParamCount();
            fitter.fitModel(this._settings.getNonCurveRegressors(), this._settings.getCurveRegressors(), remainingParams, false);
            return new ItemClassificationModel(fitter.getChain().getLatestResults(), this._settings);
        } catch (ConvergenceException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns this classifier's identifier, generating a random 64-character string on first use.
     * Synchronized so that concurrent first calls observe the same value.
     */
    public synchronized String uid() {
        if (null == this._uid) {
            this._uid = RandomTool.randomString(64);
        }
        return this._uid;
    }

    /* Decompiler note: synthetic bridge for train(Dataset), renamed because it collides with that method. */
    public /* bridge */ /* synthetic */ PredictionModel m5train(Dataset dataset) {
        return train((Dataset<?>) dataset);
    }
}
