package com.feedzai.openml.provider.lightgbm;

import com.feedzai.openml.data.Dataset;
import com.feedzai.openml.data.schema.CategoricalValueSchema;
import com.feedzai.openml.data.schema.DatasetSchema;
import com.feedzai.openml.data.schema.FieldSchema;
import com.feedzai.openml.data.schema.StringValueSchema;
import com.feedzai.openml.model.MachineLearningModel;
import com.feedzai.openml.provider.descriptor.fieldtype.ParamValidationError;
import com.feedzai.openml.provider.exception.ModelLoadingException;
import com.feedzai.openml.provider.exception.ModelTrainingException;
import com.feedzai.openml.provider.model.MachineLearningModelTrainer;
import com.feedzai.openml.util.load.LoadSchemaUtils;
import com.feedzai.openml.util.validate.ValidationUtils;
import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/feedzai/openml/provider/lightgbm/LightGBMModelCreator.class */
/**
 * Trains, validates and loads {@link LightGBMBinaryClassificationModel} instances.
 *
 * <p>Training writes the model to a temporary file through
 * {@code LightGBMBinaryClassificationModelTrainer} and then loads it back with
 * {@link #m8loadModel(Path, DatasetSchema)}. Loading checks that the model is binary and that the
 * schema's predictive fields match the model's features in number and in name.
 *
 * <p>NOTE(review): {@code m8loadModel} and {@code m7fit} carry decompiler-generated names (see the
 * "renamed from" comments on them); they are kept as-is so existing references still resolve.
 */
public class LightGBMModelCreator implements MachineLearningModelTrainer<LightGBMBinaryClassificationModel> {
    private static final Logger logger = LoggerFactory.getLogger(LightGBMModelCreator.class);
    public static final String MODEL_BINARY_RESOURCE_FILE_NAME = "LightGBM_model.txt";
    static final String ERROR_MSG_CANNOT_LOAD_NON_BINARY_LIGHTGBM_MODEL = "Cannot load a non-binary LightGBM model.";
    static final String ERROR_MSG_SCHEMA_HAS_STRING_FIELDS = "Schema has string fields.";
    static final String ERROR_MSG_NON_BINARY_TARGET = "Target field must be binary.";
    static final String ERROR_MSG_PREFIX_CANNOT_FIND_MODEL_FILE = "Cannot find model file";
    static final String ERROR_MSG_SCHEMA_WITH_WRONG_PREDICTIVE_FIELDS_SIZE = "Received schema with wrong number of predictive fields.";
    static final String ERROR_MSG_SCHEMA_WITH_WRONG_PREDICTIVE_FIELD_NAMES = "Received schema with wrong predictive field names.";
    static final String ERROR_MSG_RANDOM_FOREST_REQUIRES_BAGGING = "Random Forest Boosting type requires bagging. Please see bagging parameters.";

    /** Tolerance used when comparing the bagging parameters against their "disabled" values. */
    private static final double BAGGING_TOLERANCE = 1.0E-60d;

    /** Number of classes a binary target must have. */
    private static final int BINARY_NUM_CLASSES = 2;

    /**
     * Creates the trainer, loading the LightGBM libraries through {@code LightGBMUtils.loadLibs()}.
     */
    public LightGBMModelCreator() {
        LightGBMUtils.loadLibs();
    }

    /**
     * Trains a model on {@code dataset} and returns it loaded from a temporary model file.
     *
     * <p>The temporary file is always deleted afterwards; a failure to delete it is only logged.
     *
     * @param dataset the training data; its schema is used to load the trained model
     * @param random  not used by this implementation
     * @param map     the training parameters handed to the trainer
     * @return the trained model
     * @throws RuntimeException wrapping the cause if the temporary file cannot be created or if
     *                          training or loading fails
     */
    public LightGBMBinaryClassificationModel fit(Dataset dataset, Random random, Map<String, String> map) {
        final Path tempModelFile;
        try {
            tempModelFile = Files.createTempFile("pulse_lightgbm_model_", null);
        } catch (IOException e) {
            logger.error("Could not create temporary file.", e);
            throw new RuntimeException(e);
        }

        try {
            LightGBMBinaryClassificationModelTrainer.fit(dataset, map, tempModelFile);
            return m8loadModel(tempModelFile, dataset.getSchema());
        } catch (Exception e) {
            logger.error("Could not train the model.", e);
            throw new RuntimeException(e);
        } finally {
            try {
                Files.delete(tempModelFile);
            } catch (IOException e) {
                // Best effort: a leftover temp file must not mask the training outcome.
                logger.error("Could not delete temporary model file: {}", tempModelFile, e);
            }
        }
    }

    /**
     * Collects every validation error that would prevent training.
     *
     * @param path          the path where the model is to be trained
     * @param datasetSchema the schema of the training data
     * @param map           the training parameters
     * @return the path errors, followed by the schema errors, followed by the parameter errors
     */
    public List<ParamValidationError> validateForFit(Path path, DatasetSchema datasetSchema, Map<String, String> map) {
        return ImmutableList.<ParamValidationError>builder()
                .addAll(ValidationUtils.validateModelPathToTrain(path))
                .addAll(validateSchema(datasetSchema))
                .addAll(validateFitParams(map))
                .build();
    }

    /**
     * Validates the schema: categorical schema check, no string fields, and a binary target.
     *
     * @param datasetSchema the schema to validate
     * @return the errors found, in the order the checks are listed above
     */
    private List<ParamValidationError> validateSchema(DatasetSchema datasetSchema) {
        final ImmutableList.Builder<ParamValidationError> errors = ImmutableList.builder();
        final Optional<ParamValidationError> categoricalError = ValidationUtils.validateCategoricalSchema(datasetSchema);
        categoricalError.ifPresent(errors::add);
        if (schemaHasStringFields(datasetSchema)) {
            errors.add(new ParamValidationError(ERROR_MSG_SCHEMA_HAS_STRING_FIELDS));
        }
        // A missing or non-categorical target yields -1 and is therefore reported as non-binary.
        if (getNumTargetClasses(datasetSchema).orElse(-1) != BINARY_NUM_CLASSES) {
            errors.add(new ParamValidationError(ERROR_MSG_NON_BINARY_TARGET));
        }
        return errors.build();
    }

    /**
     * Validates the training parameters against the algorithm descriptor and checks that the
     * "rf" boosting type is not combined with disabled bagging.
     *
     * @param map the training parameters
     * @return the errors found
     */
    private List<ParamValidationError> validateFitParams(Map<String, String> map) {
        final ImmutableList.Builder<ParamValidationError> errors = ImmutableList.builder();
        errors.addAll(ValidationUtils.checkParams(LightGBMAlgorithms.LIGHTGBM_BINARY_CLASSIFIER.getAlgorithmDescriptor(), map));
        // Constant-first comparison: an absent boosting type must not crash validation with an NPE.
        final boolean isRandomForest = "rf".equals(map.get(LightGBMDescriptorUtil.BOOSTING_TYPE_PARAMETER_NAME));
        if (isRandomForest && baggingDisabled(map)) {
            logger.warn("RF requires bagging. Set bagging fraction < 1 and bagging frequency > 0.");
            errors.add(new ParamValidationError(ERROR_MSG_RANDOM_FOREST_REQUIRES_BAGGING));
        }
        return errors.build();
    }

    /**
     * Tells whether the parameters disable bagging, i.e. the bagging frequency is 0 or the bagging
     * fraction is 1 (both within {@link #BAGGING_TOLERANCE}).
     *
     * <p>A missing or non-numeric value makes the answer undeterminable, in which case
     * {@code false} is returned instead of throwing. NOTE(review): this presumes such values are
     * already reported by {@code ValidationUtils.checkParams} - confirm.
     *
     * @param map the training parameters
     * @return {@code true} if bagging is disabled by the given parameters
     */
    private boolean baggingDisabled(Map<String, String> map) {
        try {
            final String rawFrequency = map.get(LightGBMDescriptorUtil.BAGGING_FREQUENCY_PARAMETER_NAME);
            if (rawFrequency == null) {
                return false;
            }
            if (Math.abs(Double.parseDouble(rawFrequency)) < BAGGING_TOLERANCE) {
                return true;
            }
            final String rawFraction = map.get(LightGBMDescriptorUtil.BAGGING_FRACTION_PARAMETER_NAME);
            if (rawFraction == null) {
                return false;
            }
            return Math.abs(1.0d - Double.parseDouble(rawFraction)) < BAGGING_TOLERANCE;
        } catch (NumberFormatException e) {
            logger.debug("Could not parse bagging parameters; skipping the bagging check.", e);
            return false;
        }
    }

    /**
     * Loads a model and verifies that it is compatible with {@code datasetSchema}.
     *
     * <p>If {@code path} is a directory, the model is read from the file
     * {@link #MODEL_BINARY_RESOURCE_FILE_NAME} inside it.
     *
     * @param path          the model file, or a directory containing it
     * @param datasetSchema the schema the model must match
     * @return the loaded model
     * @throws ModelLoadingException if the model is not binary, or if the schema's predictive
     *                               fields differ from the model's features in number or in name
     */
    /* renamed from: loadModel, reason: merged with bridge method [inline-methods] */
    public LightGBMBinaryClassificationModel m8loadModel(Path path, DatasetSchema datasetSchema) throws ModelLoadingException {
        final Path modelFile = getPath(path);
        logger.info("Loading LightGBM model from {}", modelFile.toAbsolutePath());
        final LightGBMBinaryClassificationModel model = new LightGBMBinaryClassificationModel(modelFile, datasetSchema);
        if (!model.isModelBinary()) {
            throw new ModelLoadingException(ERROR_MSG_CANNOT_LOAD_NON_BINARY_LIGHTGBM_MODEL);
        }
        if (model.getBoosterNumFeatures() != datasetSchema.getPredictiveFields().size()) {
            throw new ModelLoadingException(ERROR_MSG_SCHEMA_WITH_WRONG_PREDICTIVE_FIELDS_SIZE);
        }
        if (!schemaMatchAllFeatures(datasetSchema, model.getBoosterFeatureNames())) {
            throw new ModelLoadingException(ERROR_MSG_SCHEMA_WITH_WRONG_PREDICTIVE_FIELD_NAMES);
        }
        return model;
    }

    /**
     * Resolves the model file: the given path itself, or the default model file name inside it
     * when the path is a directory.
     */
    private Path getPath(Path path) {
        return Files.isDirectory(path) ? path.resolve(MODEL_BINARY_RESOURCE_FILE_NAME) : path;
    }

    /**
     * Collects every validation error that would prevent loading a model.
     *
     * <p>If {@code path} does not exist, the method returns right after reporting that, without
     * checking for the model file inside a directory.
     *
     * @param path          the model file, or a directory expected to contain it
     * @param datasetSchema the schema the model will be loaded with
     * @param map           the load parameters
     * @return the base load errors, followed by the schema errors, followed by the file errors
     */
    public List<ParamValidationError> validateForLoad(Path path, DatasetSchema datasetSchema, Map<String, String> map) {
        final ImmutableList.Builder<ParamValidationError> errors = ImmutableList.builder();
        errors.addAll(ValidationUtils.baseLoadValidations(datasetSchema, map));
        errors.addAll(validateSchema(datasetSchema));

        if (!Files.exists(path)) {
            logger.error("Cannot find model file in filesystem ({}).", path);
            errors.add(new ParamValidationError("Cannot find model file in filesystem."));
            return errors.build();
        }
        if (Files.isDirectory(path) && !Files.exists(path.resolve(MODEL_BINARY_RESOURCE_FILE_NAME))) {
            logger.error("Error loading model from directory ({}). File {} not found.", path, MODEL_BINARY_RESOURCE_FILE_NAME);
            errors.add(new ParamValidationError(String.format("%s %s inside folder.", ERROR_MSG_PREFIX_CANNOT_FIND_MODEL_FILE, MODEL_BINARY_RESOURCE_FILE_NAME)));
        }
        return errors.build();
    }

    /**
     * Loads the dataset schema from {@code path} by delegating to
     * {@code LoadSchemaUtils.datasetSchemaFromJson}.
     *
     * @param path the location to read the schema from
     * @return the loaded schema
     * @throws ModelLoadingException if the delegate fails to load the schema
     */
    public DatasetSchema loadSchema(Path path) throws ModelLoadingException {
        return LoadSchemaUtils.datasetSchemaFromJson(path);
    }

    /** Tells whether any field of the schema has a {@link StringValueSchema}. */
    private static boolean schemaHasStringFields(DatasetSchema datasetSchema) {
        return datasetSchema.getFieldSchemas().stream()
                .anyMatch(fieldSchema -> fieldSchema.getValueSchema() instanceof StringValueSchema);
    }

    /**
     * Returns the number of nominal values of the target field.
     *
     * @return the count, or empty if the schema has no target field or the target is not categorical
     */
    private static Optional<Integer> getNumTargetClasses(DatasetSchema datasetSchema) {
        return datasetSchema.getTargetFieldSchema()
                .map(FieldSchema::getValueSchema)
                .filter(CategoricalValueSchema.class::isInstance)
                .map(CategoricalValueSchema.class::cast)
                .map(categorical -> categorical.getNominalValues().size());
    }

    /** Returns the predictive field names, in schema order, with spaces replaced by underscores. */
    private static String[] getFeatureNamesFrom(DatasetSchema datasetSchema) {
        return datasetSchema.getPredictiveFields().stream()
                .map(FieldSchema::getFieldName)
                .map(fieldName -> fieldName.replace(" ", "_"))
                .toArray(String[]::new);
    }

    /**
     * Checks that the schema's feature names equal the model's feature names, position by position.
     * Every mismatch is logged, not just the first one.
     *
     * @param datasetSchema the schema providing the predictive field names
     * @param strArr        the feature names reported by the model
     * @return {@code true} if both name lists have the same length and the same contents
     */
    private boolean schemaMatchAllFeatures(DatasetSchema datasetSchema, String[] strArr) {
        final String[] schemaFeatureNames = getFeatureNamesFrom(datasetSchema);
        // Guard the index loop below: differing lengths can never match and must not overflow.
        if (schemaFeatureNames.length != strArr.length) {
            logger.error("Schema with wrong predictive field names: '{}' - Expected: '{}'", String.join(", ", schemaFeatureNames), String.join(", ", strArr));
            return false;
        }

        boolean allMatch = true;
        for (int i = 0; i < strArr.length; i++) {
            if (!schemaFeatureNames[i].equals(strArr[i])) {
                logger.error("Schema with wrong predictive field name at index {}: '{}' Expected: '{}'", i, schemaFeatureNames[i], strArr[i]);
                allMatch = false;
            }
        }
        if (!allMatch) {
            logger.error("Schema with wrong predictive field names: '{}' - Expected: '{}'", String.join(", ", schemaFeatureNames), String.join(", ", strArr));
        }
        return allMatch;
    }

    /**
     * Bridge counterpart of {@link #fit(Dataset, Random, Map)}; simply delegates to it.
     */
    /* renamed from: fit, reason: collision with other method in class */
    @SuppressWarnings({"unchecked", "rawtypes"}) // raw Map comes from the erased bridge signature
    public /* bridge */ /* synthetic */ MachineLearningModel m7fit(Dataset dataset, Random random, Map map) throws ModelTrainingException {
        return fit(dataset, random, (Map<String, String>) map);
    }
}
