package com.feedzai.openml.provider.lightgbm;

import com.feedzai.openml.data.Dataset;
import com.feedzai.openml.data.Instance;
import com.feedzai.openml.data.schema.CategoricalValueSchema;
import com.feedzai.openml.data.schema.DatasetSchema;
import com.feedzai.openml.data.schema.FieldSchema;
import com.google.common.collect.ImmutableSet;
import com.microsoft.ml.lightgbm.SWIGTYPE_p_float;
import com.microsoft.ml.lightgbm.SWIGTYPE_p_int;
import com.microsoft.ml.lightgbm.SWIGTYPE_p_void;
import com.microsoft.ml.lightgbm.lightgbmlib;
import com.microsoft.ml.lightgbm.lightgbmlibConstants;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainer.class */
public final class LightGBMBinaryClassificationModelTrainer {
    /** Class-wide SLF4J logger; assigned in the static initializer of this class. */
    private static final Logger logger;
    /** Chunk size (in instances) that the three-argument {@code fit} overload forwards when the caller does not pick one. */
    static final long DEFAULT_TRAIN_DATA_CHUNK_INSTANCES_SIZE = 200000;
    /**
     * Compiler-generated assertion switch: the static initializer sets it to the negation of
     * {@code desiredAssertionStatus()}, so it is {@code true} when assertions are disabled for this class.
     */
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * Hidden constructor: every member of this class is static, so no instance is ever needed.
     */
    private LightGBMBinaryClassificationModelTrainer() {
        // Intentionally empty.
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Trains a model using the default train-data chunk size.
     *
     * <p>Delegates to the four-argument overload with
     * {@link #DEFAULT_TRAIN_DATA_CHUNK_INSTANCES_SIZE} as the chunk size.
     *
     * @param dataset train dataset
     * @param map     training parameters
     * @param path    location the trained model file is written to
     */
    public static void fit(Dataset dataset, Map<String, String> map, Path path) {
        final long instancesPerChunk = DEFAULT_TRAIN_DATA_CHUNK_INSTANCES_SIZE;
        fit(dataset, map, path, instancesPerChunk);
    }

    /**
     * Trains a LightGBM model on {@code dataset} and saves it to {@code path}.
     *
     * <p>Steps: build the LightGBM parameter string, copy the data into the SWIG train-data
     * holder, create the native dataset and booster, run the training iterations and write the
     * model file. The booster wrapper is closed in a {@code finally} block so that it is released
     * even when one of the steps throws.
     *
     * @param dataset train dataset
     * @param map     training parameters; the value stored under
     *                {@code LightGBMDescriptorUtil.NUM_ITERATIONS_PARAMETER_NAME} is parsed as the
     *                number of training iterations
     * @param path    location the trained model file is written to
     * @param j       number of instances per chunk, forwarded to the {@code SWIGTrainData} constructor
     * @throws NumberFormatException if the number-of-iterations parameter is missing or not an integer
     */
    static void fit(Dataset dataset, Map<String, String> map, Path path, long j) {
        final DatasetSchema schema = dataset.getSchema();
        final int numFeatures = schema.getPredictiveFields().size();
        final String trainParams = getLightGBMTrainParamsString(map, schema);
        final int numIterations = Integer.parseInt(map.get(LightGBMDescriptorUtil.NUM_ITERATIONS_PARAMETER_NAME));
        logger.debug("LightGBM model trainParams: {}", trainParams);

        final SWIGTrainData trainData = new SWIGTrainData(numFeatures, j, FairGBMParamParserUtil.isFairnessConstrained(map));
        final SWIGTrainBooster trainBooster = new SWIGTrainBooster();
        try {
            // -1 is the "no constraint group column" marker understood by createTrainDataset.
            final int constraintGroupColumnIndex =
                    FairGBMParamParserUtil.getConstraintGroupColumnIndex(map, schema).orElse(-1).intValue();
            createTrainDataset(dataset, numFeatures, trainParams, constraintGroupColumnIndex, trainData);
            createBoosterStructure(trainBooster, trainData, trainParams);
            trainBooster(trainBooster.swigBoosterHandle, numIterations);
            saveModelFileToDisk(trainBooster.swigBoosterHandle, path);
        } finally {
            // Release the booster on failures too, not only after a successful save.
            // NOTE(review): assumes close() tolerates a booster whose handle was never initialized - confirm.
            trainBooster.close();
        }
    }

    /**
     * Lists the indices of the categorical predictive fields, re-based as if the target column
     * had been removed from the schema.
     *
     * <p>An index located after the target index is shifted down by one; indices before it are
     * returned unchanged.
     *
     * @param datasetSchema schema of the train dataset
     * @return shifted indices of the predictive fields whose value schema is categorical
     * @throws java.util.NoSuchElementException if the schema has no target index
     */
    private static List<Integer> getCategoricalFeaturesIndicesWithoutLabel(DatasetSchema datasetSchema) {
        final List<FieldSchema> predictiveFields = datasetSchema.getPredictiveFields();
        final int targetIndex = (Integer) datasetSchema.getTargetIndex().get();
        return predictiveFields.stream()
                .filter(field -> field.getValueSchema() instanceof CategoricalValueSchema)
                .map(FieldSchema::getFieldIndex)
                .map(fieldIndex -> fieldIndex > targetIndex ? fieldIndex - 1 : fieldIndex)
                .collect(Collectors.toList());
    }

    /**
     * Extracts the field names of the given fields, preserving list order.
     *
     * @param list fields to read the names from
     * @return array with one name per field, in the same order as {@code list}
     */
    private static String[] getFieldNames(List<FieldSchema> list) {
        final String[] names = new String[list.size()];
        int position = 0;
        for (final FieldSchema field : list) {
            names[position++] = field.getFieldName();
        }
        return names;
    }

    /**
     * Builds the LightGBM train dataset inside {@code swigTrainData}.
     *
     * <p>Copies the instances into the SWIG arrays, initializes the dataset features, sets the
     * labels, sets the constraint group data when a constraint group column is given, and finally
     * sets the feature names.
     *
     * @param dataset                    train dataset
     * @param numFeatures                number of predictive fields
     * @param trainParams                LightGBM parameter string
     * @param constraintGroupColumnIndex index of the constraint group column, or -1 when there is none
     * @param swigTrainData              holder of the SWIG train resources, filled in by this call
     */
    private static void createTrainDataset(Dataset dataset, int numFeatures, String trainParams, int constraintGroupColumnIndex, SWIGTrainData swigTrainData) {
        logger.info("Creating LightGBM dataset");

        logger.debug("Copying train data through SWIG.");
        copyTrainDataToSWIGArrays(dataset, swigTrainData, constraintGroupColumnIndex);

        initializeLightGBMTrainDatasetFeatures(swigTrainData, numFeatures, trainParams);
        setLightGBMDatasetLabelData(swigTrainData);

        final boolean hasConstraintGroupColumn = constraintGroupColumnIndex != -1;
        if (hasConstraintGroupColumn) {
            setLightGBMDatasetConstraintGroupData(swigTrainData);
        }

        setLightGBMDatasetFeatureNames(swigTrainData.swigDatasetHandle, dataset.getSchema());

        logger.info("Created LightGBM dataset.");
    }

    /**
     * Creates the native LightGBM dataset from the chunked feature data held by {@code sWIGTrainData}.
     *
     * <p>On success the dataset handle is initialized and the chunked feature array is released.
     * The temporary chunk-sizes array is freed in a {@code finally} block so it does not leak when
     * dataset creation fails.
     *
     * @param sWIGTrainData holder of the chunked feature data and of the output dataset handle
     * @param i             number of features (columns) per instance
     * @param str           LightGBM parameter string
     * @throws LightGBMException if LightGBM reports failure (return code -1)
     */
    private static void initializeLightGBMTrainDatasetFeatures(SWIGTrainData sWIGTrainData, int i, String str) {
        logger.debug("Initializing LightGBM in-memory structure and setting feature data.");
        final SWIGTYPE_p_int chunkSizes = genSWIGFeatureChunkSizesArray(sWIGTrainData, i);
        try {
            logger.debug("Creating LGBM_Dataset from chunked data...");
            final int returnCode = lightgbmlib.LGBM_DatasetCreateFromMats(
                    (int) sWIGTrainData.swigFeaturesChunkedArray.get_chunks_count(),
                    sWIGTrainData.swigFeaturesChunkedArray.data_as_void(),
                    lightgbmlibConstants.C_API_DTYPE_FLOAT64,
                    chunkSizes,
                    i,
                    1, // presumably the row-major flag of the C API - TODO confirm
                    str,
                    null,
                    sWIGTrainData.swigOutDatasetHandlePtr);
            if (returnCode == -1) {
                logger.error("Could not create LightGBM dataset.");
                throw new LightGBMException();
            }
            sWIGTrainData.initSwigDatasetHandle();
            sWIGTrainData.releaseSwigTrainFeaturesChunkedArray();
        } finally {
            // Always free the native int array, including on the failure path.
            lightgbmlib.delete_intArray(chunkSizes);
        }
    }

    /**
     * Allocates a native int array holding the number of instances stored in each feature chunk.
     *
     * <p>Every chunk but the last is reported with the full chunk size; the last chunk's size is
     * its number of added values divided by the number of features. The caller owns the returned
     * array.
     *
     * @param swigTrainData holder of the chunked feature data
     * @param numFeatures   number of features (values) per instance
     * @return newly allocated native array with one size entry per chunk
     */
    private static SWIGTYPE_p_int genSWIGFeatureChunkSizesArray(SWIGTrainData swigTrainData, int numFeatures) {
        logger.debug("Retrieving chunked data block sizes...");

        final long numChunks = swigTrainData.swigFeaturesChunkedArray.get_chunks_count();
        final int fullChunkSize = (int) swigTrainData.getNumInstancesChunk();
        final long lastChunkIndex = numChunks - 1;
        final SWIGTYPE_p_int chunkSizes = lightgbmlib.new_intArray(numChunks);

        int chunk = 0;
        while (chunk < lastChunkIndex) {
            lightgbmlib.intArray_setitem(chunkSizes, chunk, fullChunkSize);
            logger.debug("FTL: chunk-size report: chunk #{} is full-chunk of size {}", chunk, fullChunkSize);
            chunk++;
        }

        // The last chunk may be only partially filled: convert its value count to an instance count.
        final int lastChunkSize = ((int) swigTrainData.swigFeaturesChunkedArray.get_last_chunk_add_count()) / numFeatures;
        lightgbmlib.intArray_setitem(chunkSizes, lastChunkIndex, lastChunkSize);
        logger.debug("FTL: chunk-size report: chunk #{} is partial-chunk of size {}", lastChunkIndex, lastChunkSize);

        return chunkSizes;
    }

    /**
     * Sets the "label" field of the LightGBM dataset from the labels collected in {@code sWIGTrainData}.
     *
     * <p>The chunked labels are coalesced into a single native array, handed to LightGBM, and the
     * array is then destroyed. The destroy step runs in a {@code finally} block so the array is
     * also released when setting the field fails.
     *
     * @param sWIGTrainData holder of the label data and of the dataset handle
     * @throws LightGBMException if LightGBM reports failure (return code -1)
     */
    private static void setLightGBMDatasetLabelData(SWIGTrainData sWIGTrainData) {
        final long numLabels = sWIGTrainData.swigLabelsChunkedArray.get_add_count();
        final SWIGTYPE_p_float labels = sWIGTrainData.coalesceChunkedSwigTrainLabelDataArray();
        try {
            logger.debug("FTL: #labels={}", numLabels);
            logger.debug("Setting label data.");
            final int returnCode = lightgbmlib.LGBM_DatasetSetField(
                    sWIGTrainData.swigDatasetHandle,
                    "label",
                    lightgbmlib.float_to_voidp_ptr(labels),
                    (int) numLabels,
                    lightgbmlibConstants.C_API_DTYPE_FLOAT32);
            if (returnCode == -1) {
                logger.error("Could not set label.");
                throw new LightGBMException();
            }
        } finally {
            // Release the coalesced array on both the success and the failure path.
            sWIGTrainData.destroySwigTrainLabelDataArray();
        }
    }

    /**
     * Sets the "constraint_group" field of the LightGBM dataset from the constraint group values
     * collected in {@code sWIGTrainData}.
     *
     * <p>The chunked values are coalesced into a single native array, handed to LightGBM, and the
     * array is then destroyed. The destroy step runs in a {@code finally} block so the array is
     * also released when setting the field fails.
     *
     * @param sWIGTrainData holder of the constraint group data and of the dataset handle
     * @throws LightGBMException if LightGBM reports failure (return code -1)
     */
    private static void setLightGBMDatasetConstraintGroupData(SWIGTrainData sWIGTrainData) {
        final long numConstraintGroupValues = sWIGTrainData.swigConstraintGroupChunkedArray.get_add_count();
        final SWIGTYPE_p_int constraintGroups = sWIGTrainData.coalesceChunkedSwigConstraintGroupDataArray();
        try {
            // The count logged here is of constraint group values, not labels.
            logger.debug("FTL: #constraint_group_values={}", numConstraintGroupValues);
            logger.debug("Setting constraint group data.");
            final int returnCode = lightgbmlib.LGBM_DatasetSetField(
                    sWIGTrainData.swigDatasetHandle,
                    "constraint_group",
                    lightgbmlib.int_to_voidp_ptr(constraintGroups),
                    (int) numConstraintGroupValues,
                    lightgbmlibConstants.C_API_DTYPE_INT32);
            if (returnCode == -1) {
                logger.error("Could not set constraint group data.");
                throw new LightGBMException();
            }
        } finally {
            // Release the coalesced array on both the success and the failure path.
            sWIGTrainData.destroySwigConstraintGroupDataArray();
        }
    }

    /**
     * Registers the names of the schema's predictive fields as the feature names of the LightGBM dataset.
     *
     * @param swigDatasetHandle handle of the LightGBM dataset
     * @param schema            schema whose predictive fields provide the names
     * @throws LightGBMException if LightGBM reports failure (return code -1)
     */
    private static void setLightGBMDatasetFeatureNames(SWIGTYPE_p_void swigDatasetHandle, DatasetSchema schema) {
        final List<FieldSchema> predictiveFields = schema.getPredictiveFields();
        final int numFeatures = predictiveFields.size();
        final String[] featureNames = getFieldNames(predictiveFields);
        logger.debug("featureNames {}", Arrays.toString(featureNames));

        final int returnCode = lightgbmlib.LGBM_DatasetSetFeatureNames(swigDatasetHandle, featureNames, numFeatures);
        if (returnCode != -1) {
            return;
        }
        logger.error("Could not set feature names.");
        throw new LightGBMException();
    }

    /**
     * Creates the native LightGBM booster for the given train dataset and initializes the booster
     * handle of {@code sWIGTrainBooster}.
     *
     * @param sWIGTrainBooster holder receiving the booster handle
     * @param sWIGTrainData    holder of the train dataset handle
     * @param str              LightGBM parameter string
     * @throws LightGBMException if LightGBM reports failure (return code -1)
     */
    static void createBoosterStructure(SWIGTrainBooster sWIGTrainBooster, SWIGTrainData sWIGTrainData, String str) {
        logger.debug("Initializing LightGBM model structure.");
        final int returnCode = lightgbmlib.LGBM_BoosterCreate(
                sWIGTrainData.swigDatasetHandle,
                str,
                sWIGTrainBooster.swigOutBoosterHandlePtr);
        if (returnCode != -1) {
            sWIGTrainBooster.initSwigBoosterHandle();
            return;
        }
        logger.error("LightGBM model structure creation failed.");
        throw new LightGBMException();
    }

    /**
     * Runs up to {@code i} training iterations on the booster, stopping early when LightGBM
     * signals that training has finished.
     *
     * <p>The native out-pointer receiving the "finished" flag is allocated once, shared by all
     * iterations, and freed exactly once after the loop ends (normally, by early stop, or by an
     * exception).
     *
     * @param sWIGTYPE_p_void handle of the booster to train
     * @param i               maximum number of training iterations
     * @throws LightGBMException if an iteration reports failure (return code -1)
     */
    private static void trainBooster(SWIGTYPE_p_void sWIGTYPE_p_void, int i) {
        logger.info("Training LightGBM model.");
        final SWIGTYPE_p_int isFinishedPtr = lightgbmlib.new_intp();
        try {
            for (int iteration = 0; iteration < i; iteration++) {
                logger.debug("Starting model training iteration #{}/{}.", iteration + 1, i);
                if (lightgbmlib.LGBM_BoosterUpdateOneIter(sWIGTYPE_p_void, isFinishedPtr) == -1) {
                    logger.error("Failed to train model!");
                    throw new LightGBMException();
                }
                if (lightgbmlib.intp_value(isFinishedPtr) == 1) {
                    logger.info("LightGBM backend signalled the end of the model train.");
                    break;
                }
            }
        } finally {
            // Free the pointer only after the whole loop: every iteration reads and writes it.
            lightgbmlib.delete_intp(isFinishedPtr);
        }
        logger.info("Finished model training.");
    }

    /**
     * Writes the trained booster to a model file at the absolute form of {@code path}.
     *
     * @param sWIGTYPE_p_void handle of the trained booster
     * @param path            destination of the model file
     * @throws LightGBMException if LightGBM reports failure (return code -1)
     */
    static void saveModelFileToDisk(SWIGTYPE_p_void sWIGTYPE_p_void, Path path) {
        logger.debug("Saving trained model to disk at {}.", path);
        final String absoluteModelPath = path.toAbsolutePath().toString();
        // NOTE(review): 0 and -1 are presumably the start iteration and "all iterations" - confirm against the C API.
        final int returnCode = lightgbmlib.LGBM_BoosterSaveModel(
                sWIGTYPE_p_void,
                0,
                -1,
                lightgbmlib.C_API_FEATURE_IMPORTANCE_GAIN,
                absoluteModelPath);
        if (returnCode == -1) {
            logger.error("Could not save model to disk.");
            throw new LightGBMException();
        }
        logger.info("Saved model to disk");
    }

    /**
     * Copies the dataset into the SWIG arrays without collecting constraint group values.
     *
     * @param dataset       train dataset
     * @param swigTrainData holder receiving the label and feature values
     */
    private static void copyTrainDataToSWIGArrays(Dataset dataset, SWIGTrainData swigTrainData) {
        // -1 tells the three-argument overload that there is no constraint group column.
        final int noConstraintGroupColumn = -1;
        copyTrainDataToSWIGArrays(dataset, swigTrainData, noConstraintGroupColumn);
    }

    /**
     * Copies every instance of the dataset into the SWIG arrays of {@code swigTrainData}.
     *
     * <p>For each instance the target value is added as a label, the value of the constraint
     * group column is added when such a column is given, and every field other than the target is
     * added as a feature value, in schema order.
     *
     * @param dataset                    train dataset
     * @param swigTrainData              holder receiving the label, constraint group and feature values
     * @param constraintGroupColumnIndex index of the constraint group column, or -1 when there is none
     * @throws IllegalArgumentException if the dataset has no instances
     */
    private static void copyTrainDataToSWIGArrays(Dataset dataset, SWIGTrainData swigTrainData, int constraintGroupColumnIndex) {
        final DatasetSchema schema = dataset.getSchema();
        final int numFields = schema.getFieldSchemas().size();
        final int targetIndex = (Integer) schema.getTargetIndex().get();
        final boolean hasConstraintGroupColumn = constraintGroupColumnIndex != -1;

        for (final Iterator<?> iterator = dataset.getInstances(); iterator.hasNext(); ) {
            final Instance instance = (Instance) iterator.next();
            swigTrainData.addLabelValue((float) instance.getValue(targetIndex));
            if (hasConstraintGroupColumn) {
                swigTrainData.addConstraintGroupValue((int) instance.getValue(constraintGroupColumnIndex));
            }
            for (int field = 0; field < numFields; field++) {
                if (field == targetIndex) {
                    continue; // the target is stored as the label, not as a feature
                }
                swigTrainData.addFeatureValue(instance.getValue(field));
            }
        }

        // Sanity checks, evaluated only when assertions are enabled for this class.
        if (!$assertionsDisabled) {
            if (swigTrainData.swigLabelsChunkedArray.get_add_count()
                    != swigTrainData.swigFeaturesChunkedArray.get_add_count() / swigTrainData.numFeatures) {
                throw new AssertionError();
            }
            if (swigTrainData.fairnessConstrained
                    && swigTrainData.swigConstraintGroupChunkedArray.get_add_count()
                    != swigTrainData.swigLabelsChunkedArray.get_add_count()) {
                throw new AssertionError();
            }
        }

        final long numInstances = swigTrainData.swigLabelsChunkedArray.get_add_count();
        final long numChunks = swigTrainData.swigLabelsChunkedArray.get_chunks_count();
        logger.debug("Copied train data of size {} into {} chunks.", numInstances, numChunks);

        if (numInstances == 0) {
            logger.error("Received empty train dataset!");
            throw new IllegalArgumentException("Received empty train dataset for LightGBM!");
        }
    }

    /**
     * Builds the LightGBM parameter string, formatted as space-terminated {@code key=value} pairs.
     *
     * <p>The string always carries {@code categorical_feature} (comma-separated categorical
     * feature indices computed without the label column) and, when the caller did not set an
     * objective under any of its aliases, {@code objective=binary}. All caller parameters are
     * then added, except those whose key was already set by this method.
     *
     * @param map           caller-supplied training parameters
     * @param datasetSchema schema of the train dataset
     * @return the parameter string passed to the LightGBM native calls
     */
    private static String getLightGBMTrainParamsString(Map<String, String> map, DatasetSchema datasetSchema) {
        final Map<String, String> trainParams = new HashMap<>();
        trainParams.put("categorical_feature", StringUtils.join(getCategoricalFeaturesIndicesWithoutLabel(datasetSchema), ","));
        if (!getLightGBMObjective(map).isPresent()) {
            // Default objective, used only when the caller did not choose one.
            trainParams.put("objective", "binary");
        }
        // NOTE(review): the consumer body is empty in this source. Presumably it was meant to
        // register the label-adjusted constraint group column index in trainParams - TODO confirm
        // against the original sources. The call is kept so any work done by the lookup still happens.
        FairGBMParamParserUtil.getConstraintGroupColumnIndexWithoutLabel(map, datasetSchema).ifPresent(num -> {
        });
        // Caller parameters never override the entries computed above.
        map.forEach(trainParams::putIfAbsent);

        final StringBuilder sb = new StringBuilder();
        trainParams.forEach((name, value) -> sb.append(name).append('=').append(value).append(' '));
        return sb.toString();
    }

    /**
     * Looks up the objective configured in the given parameters.
     *
     * <p>The keys {@code objective}, {@code objective_type}, {@code app}, {@code application} and
     * {@code loss} are all treated as the objective parameter. The value of the first matching
     * entry, in the map's iteration order, is returned.
     *
     * @param map training parameters
     * @return the configured objective, or an empty {@code Optional} when none of the keys is present
     * @throws NullPointerException if the matching entry has a {@code null} value
     */
    public static Optional<String> getLightGBMObjective(Map<String, String> map) {
        final ImmutableSet<String> objectiveAliases =
                ImmutableSet.of("objective", "objective_type", "app", "application", "loss");
        for (final Map.Entry<String, String> parameter : map.entrySet()) {
            if (objectiveAliases.contains(parameter.getKey())) {
                return Optional.of(parameter.getValue());
            }
        }
        return Optional.empty();
    }

    static {
        $assertionsDisabled = !LightGBMBinaryClassificationModelTrainer.class.desiredAssertionStatus();
        logger = LoggerFactory.getLogger(LightGBMBinaryClassificationModelTrainer.class);
    }
}
