package com.feedzai.openml.provider.lightgbm;

import com.feedzai.openml.data.Instance;
import com.feedzai.openml.data.schema.DatasetSchema;
import com.feedzai.openml.provider.exception.ModelLoadingException;
import com.microsoft.ml.lightgbm.lightgbmlibJNI;
import java.nio.file.Path;
import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:com/feedzai/openml/provider/lightgbm/LightGBMSWIG.class */
public class LightGBMSWIG {
    private static final Logger logger = LoggerFactory.getLogger(LightGBMSWIG.class);

    /** Index of the target field in the schema, or {@code -1} when the schema has no target. */
    private final int schemaTargetIndex;

    /** Total number of fields in the schema (including the target field, if any). */
    private final int schemaNumFields;

    /**
     * Number of classes reported by the native booster.
     * Read once at construction time and immutable afterwards.
     */
    private final int boosterNumClasses;

    /**
     * Native handles and buffers shared by all calls on this wrapper.
     * Prediction methods synchronize on this object because they write to shared native buffers.
     */
    private final SWIGResources swigResources;

    /**
     * Creates the wrapper, allocates the SWIG resources and queries the booster's number of classes.
     *
     * @param str           first argument forwarded to {@code SWIGResources}
     *                      (presumably the model location — TODO confirm)
     * @param datasetSchema schema whose field count and target index define the input row layout
     * @param str2          second argument forwarded to {@code SWIGResources}
     *                      (presumably LightGBM parameters — TODO confirm)
     * @throws ModelLoadingException as declared for model-loading failures
     */
    public LightGBMSWIG(String str, DatasetSchema datasetSchema, String str2) throws ModelLoadingException {
        this.schemaNumFields = datasetSchema.getFieldSchemas().size();
        this.schemaTargetIndex = ((Integer) datasetSchema.getTargetIndex().orElse(-1)).intValue();
        this.swigResources = new SWIGResources(str, str2);
        // NOTE(review): if the call below throws, the SWIGResources allocated above are not released
        // here. Confirm whether SWIGResources owns native memory that needs explicit cleanup.
        this.boosterNumClasses = readBoosterNumClasses();
    }

    /**
     * Copies the instance's feature values into the shared native input row.
     * The schema's target field is skipped, so later fields shift one slot to the left.
     * Must be called while holding the {@code swigResources} monitor.
     *
     * @param instance the instance whose values are copied
     */
    private void copyDataToSWIGInstance(Instance instance) {
        final long instancePtr = this.swigResources.swigInstancePtr.longValue();
        int nativeIndex = 0;
        for (int fieldIndex = 0; fieldIndex < this.schemaNumFields; fieldIndex++) {
            if (fieldIndex != this.schemaTargetIndex) {
                lightgbmlibJNI.doubleArray_setitem(instancePtr, nativeIndex, instance.getValue(fieldIndex));
                nativeIndex++;
            }
        }
    }

    /**
     * Scores one instance with the fast single-row prediction config and returns a two-class distribution.
     *
     * @param instance the instance to score
     * @return {@code {1 - score, score}}, where {@code score} is the first value written by the booster
     * @throws LightGBMException if the native prediction call reports failure
     */
    public double[] getBinaryClassDistribution(Instance instance) {
        synchronized (this.swigResources) {
            copyDataToSWIGInstance(instance);
            final int returnCode = lightgbmlibJNI.LGBM_BoosterPredictForMatSingleRowFast(
                    this.swigResources.swigFastConfigHandle.longValue(),
                    this.swigResources.swigInstancePtr.longValue(),
                    this.swigResources.swigOutLengthInt64Ptr.longValue(),
                    this.swigResources.swigOutScoresPtr.longValue());
            if (returnCode == -1) {
                throw new LightGBMException();
            }
            final double score = lightgbmlibJNI.doubleArray_getitem(this.swigResources.swigOutScoresPtr.longValue(), 0L);
            logger.trace("Prediction: {}", Double.valueOf(score));
            return new double[]{1.0d - score, score};
        }
    }

    /**
     * Computes per-feature contributions for one instance with the contributions prediction config.
     *
     * NOTE(review): LightGBM's contribution output is documented to hold num_features + 1 values per
     * row, with the last one being the bias/expected value. This method reads {@code schemaNumFields}
     * values. When the schema includes the target, that count covers the features plus the bias term.
     * Confirm this is the intended layout.
     *
     * @param instance the instance to explain
     * @return the first {@code schemaNumFields} values of the native contributions buffer
     * @throws LightGBMException if the native prediction call reports failure
     */
    public double[] getFeaturesContributions(Instance instance) {
        synchronized (this.swigResources) {
            copyDataToSWIGInstance(instance);
            final long contributionsPtr = this.swigResources.swigOutContributionsPtr.longValue();
            final int returnCode = lightgbmlibJNI.LGBM_BoosterPredictForMatSingleRowFast(
                    this.swigResources.swigFastConfigContributionsHandle.longValue(),
                    this.swigResources.swigInstancePtr.longValue(),
                    this.swigResources.swigOutLengthInt64Ptr.longValue(),
                    contributionsPtr);
            if (returnCode == -1) {
                throw new LightGBMException();
            }
            final double[] contributions = new double[this.schemaNumFields];
            for (int i = 0; i < contributions.length; i++) {
                contributions[i] = lightgbmlibJNI.doubleArray_getitem(contributionsPtr, i);
            }
            if (logger.isTraceEnabled()) {
                logger.trace("Features Contributions: {}", Arrays.toString(contributions));
            }
            return contributions;
        }
    }

    /**
     * @return the number of features reported by the underlying SWIG resources
     */
    public int getBoosterNumFeatures() {
        return this.swigResources.getBoosterNumFeatures();
    }

    /**
     * @return the feature names reported by the underlying SWIG resources
     */
    public String[] getBoosterFeatureNames() {
        return this.swigResources.getBoosterFeatureNames();
    }

    /**
     * Asks the native booster for its number of classes.
     * Called once, from the constructor, after {@code swigResources} has been assigned.
     *
     * @return the class count written by {@code LGBM_BoosterGetNumClasses}
     * @throws LightGBMException if the native call reports failure
     */
    private int readBoosterNumClasses() throws LightGBMException {
        final long outIntPtr = this.swigResources.swigOutIntPtr.longValue();
        if (lightgbmlibJNI.LGBM_BoosterGetNumClasses(this.swigResources.swigBoosterHandle.longValue(), outIntPtr) == -1) {
            throw new LightGBMException();
        }
        return lightgbmlibJNI.intp_value(outIntPtr);
    }

    /**
     * @return the number of classes reported by the booster at construction time
     */
    public int getBoosterNumClasses() {
        return this.boosterNumClasses;
    }

    /**
     * @return {@code true} when the booster reports exactly one class
     *         (LightGBM's representation of a binary model)
     */
    public boolean isModelBinary() {
        return this.boosterNumClasses == 1;
    }

    /**
     * @return the number of iterations reported by the underlying SWIG resources
     */
    public int getBoosterNumIterations() {
        return this.swigResources.getBoosterNumIterations();
    }

    /**
     * Saves the booster to {@code path}.
     * All iterations are saved (start iteration 0, count -1), using gain-based feature importance.
     * This method does not synchronize on {@code swigResources}.
     *
     * @param path destination file; it is converted to an absolute path before the native call
     * @throws LightGBMException if the native save call reports failure
     */
    public void saveModelToDisk(Path path) {
        logger.info("Saving model to disk.");
        logger.debug("Saving model to disk @ {}.", path);
        final int returnCode = lightgbmlibJNI.LGBM_BoosterSaveModel(
                this.swigResources.swigBoosterHandle.longValue(),
                0,
                -1,
                lightgbmlibJNI.C_API_FEATURE_IMPORTANCE_GAIN_get(),
                path.toAbsolutePath().toString());
        if (returnCode == -1) {
            logger.error("Could not save model to disk.");
            throw new LightGBMException();
        }
    }
}
