package hivemall.classifier;

import hivemall.model.FeatureValue;
import hivemall.model.IWeightValue;
import hivemall.model.PredictionResult;
import hivemall.model.WeightValue;
import hivemall.utils.math.StatsUtils;
import javax.annotation.Nonnull;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.lucene.util.packed.PackedInts;

@Description(name = "train_cw", value = "_FUNC_(list<string|int|bigint> features, int label [, const string options]) - Returns a relation consists of <string|int|bigint feature, float weight, float covar>", extended = "Build a prediction model by Confidence-Weighted (CW) binary classifier")
/* loaded from: input_file:hivemall/classifier/ConfidenceWeightedUDTF.class */
/**
 * Confidence-Weighted (CW) binary classifier exposed as the Hive UDTF {@code train_cw}.
 *
 * <p>Each model entry carries a weight and a covariance. For every training row the
 * learner derives a step size ("gamma") from the current score and variance of the row
 * and, when that step size is strictly positive, updates the weight and covariance of
 * every feature present in the row.
 *
 * <p>The confidence parameter {@link #phi} is configured through the {@code -phi} option,
 * or derived from the {@code -eta} option when {@code -phi} is absent.
 */
public final class ConfidenceWeightedUDTF extends BinaryOnlineClassifierUDTF {
    /** Confidence parameter used by the gamma computation and the covariance update. */
    protected float phi;

    /**
     * Checks the number of arguments and delegates the rest of the initialization to the
     * superclass.
     *
     * @param objectInspectorArr inspectors of the UDTF arguments; 2 or 3 are expected
     * @return the output struct inspector produced by the superclass
     * @throws UDFArgumentException propagated from the superclass (and presumably raised by
     *         {@code showHelp} on a wrong argument count; verify against its definition)
     */
    @Override // hivemall.classifier.BinaryOnlineClassifierUDTF
    public StructObjectInspector initialize(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        final int numArgs = objectInspectorArr.length;
        if (numArgs != 2 && numArgs != 3) {
            showHelp("_FUNC_ takes 2 or 3 arguments: List<String|Int|BigInt> features, Int label [, constant String options]");
        }
        return super.initialize(objectInspectorArr);
    }

    /**
     * Signals that this learner maintains a covariance next to each weight.
     *
     * @return always {@code true}
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override // hivemall.LearnerBaseUDTF
    public boolean useCovariance() {
        return true;
    }

    /**
     * Extends the inherited command-line options with {@code -phi}/{@code --confidence} and
     * {@code -eta}/{@code --hyper_c}.
     *
     * @return the inherited options plus the two options added here
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override // hivemall.LearnerBaseUDTF, hivemall.UDTFWithOptions
    public Options getOptions() {
        final Options opts = super.getOptions();
        opts.addOption("phi", "confidence", true, "Confidence parameter [default 1.0]");
        opts.addOption("eta", "hyper_c", true, "Confidence hyperparameter eta in range (0.5, 1] [default 0.85]");
        return opts;
    }

    /**
     * Parses the options and stores the resulting confidence parameter in {@link #phi}.
     *
     * <p>Resolution order: an explicit {@code -phi} value wins; otherwise, if {@code -eta}
     * is given, it must lie in (0.5, 1] and phi is set to {@code StatsUtils.probit(eta, 5.0)};
     * otherwise phi is 1.0.
     *
     * @param objectInspectorArr inspectors of the UDTF arguments
     * @return the parsed command line as returned by the superclass (may be {@code null})
     * @throws UDFArgumentException if {@code -eta} is outside (0.5, 1] or is NaN
     * @throws NumberFormatException if {@code -phi} or {@code -eta} is not a parsable number
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override // hivemall.LearnerBaseUDTF, hivemall.UDTFWithOptions
    public CommandLine processOptions(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        final CommandLine cl = super.processOptions(objectInspectorArr);
        float confidence = 1.0f;
        if (cl != null) {
            final String phiText = cl.getOptionValue("phi");
            if (phiText != null) {
                confidence = Float.parseFloat(phiText);
            } else {
                final String etaText = cl.getOptionValue("eta");
                if (etaText != null) {
                    final double eta = Double.parseDouble(etaText);
                    // Written as the negation of the accepted interval so that NaN, for which
                    // every ordered comparison is false, is rejected as well.
                    if (!(eta > 0.5d && eta <= 1.0d)) {
                        throw new UDFArgumentException("Confidence hyperparameter eta must be in range (0.5, 1]: " + etaText);
                    }
                    // NOTE(review): the meaning of the second probit argument (5.0) is not
                    // visible here; confirm against StatsUtils.probit.
                    confidence = (float) StatsUtils.probit(eta, 5.0d);
                }
            }
        }
        this.phi = confidence;
        return cl;
    }

    /**
     * Trains on a single row: computes gamma for the row and applies an update only when
     * gamma is strictly positive (a NaN gamma therefore results in no update).
     *
     * @param featureValueArr features of the row
     * @param i label of the row; any value greater than 0 is treated as +1, everything else as -1
     */
    @Override // hivemall.classifier.BinaryOnlineClassifierUDTF
    protected void train(@Nonnull FeatureValue[] featureValueArr, int i) {
        final int y = i > 0 ? 1 : -1;
        final PredictionResult margin = calcScoreAndVariance(featureValueArr);
        final float gamma = getGamma(margin, y);
        if (gamma > 0.0f) {
            update(featureValueArr, gamma * y, gamma);
        }
    }

    /**
     * Computes the step size gamma from the score and variance of a prediction:
     *
     * <pre>
     *   m     = score * i
     *   b     = 1 + 2 * phi * m
     *   gamma = (-b + sqrt(b^2 - 8 * phi * (m - phi * variance))) / (4 * phi * variance)
     * </pre>
     *
     * @param predictionResult score and variance of the current row
     * @param i signed label multiplied into the score ({@code train} passes +1 or -1)
     * @return gamma, or 0 when the denominator {@code 4 * phi * variance} is exactly 0;
     *         NaN when the term under the square root is negative
     */
    protected final float getGamma(PredictionResult predictionResult, int i) {
        final float score = predictionResult.getScore() * i;
        final float variance = predictionResult.getVariance();

        final float b = 1.0f + (2.0f * this.phi * score);
        final float numerator = -b + (float) Math.sqrt((b * b) - ((8.0f * this.phi) * (score - (this.phi * variance))));
        final float denominator = 4.0f * this.phi * variance;
        if (denominator == 0.0f) {
            // avoid divide-by-zero
            return 0.0f;
        }
        return numerator / denominator;
    }

    /**
     * Replaces the model entry of every non-null feature in the row with a freshly computed
     * weight/covariance pair.
     *
     * @param featureValueArr features of the row; {@code null} elements are skipped
     * @param f coefficient applied to the weight update ({@code train} passes gamma * label)
     * @param f2 coefficient applied to the covariance update ({@code train} passes gamma)
     */
    @Override // hivemall.classifier.BinaryOnlineClassifierUDTF
    protected void update(@Nonnull FeatureValue[] featureValueArr, float f, float f2) {
        for (FeatureValue fv : featureValueArr) {
            if (fv == null) {
                continue;
            }
            final Object key = fv.getFeature();
            final float x = fv.getValueAsFloat();
            final IWeightValue previous = this.model.get(key);
            this.model.set(key, getNewWeight(previous, x, f, f2, this.phi));
        }
    }

    /**
     * Computes the updated weight and covariance of a single feature:
     *
     * <pre>
     *   newWeight = oldWeight + coeff * oldCovar * x
     *   newCovar  = 1 / (1 / oldCovar + 2 * alpha * phi * x^2)
     * </pre>
     *
     * A missing entry is treated as weight 0 and covariance 1.
     *
     * @param iWeightValue current model entry of the feature, or {@code null} if absent
     * @param f feature value x
     * @param f2 coefficient of the weight update
     * @param f3 coefficient (alpha) of the covariance update
     * @param f4 confidence parameter phi
     * @return a new weight value holding the updated weight and covariance
     */
    private static IWeightValue getNewWeight(IWeightValue iWeightValue, float f, float f2, float f3, float f4) {
        final float oldWeight;
        final float oldCovar;
        if (iWeightValue == null) {
            oldWeight = 0.0f;
            oldCovar = 1.0f;
        } else {
            oldWeight = iWeightValue.get();
            oldCovar = iWeightValue.getCovariance();
        }
        final float newWeight = oldWeight + (f2 * oldCovar * f);
        final float newCovar = 1.0f / ((1.0f / oldCovar) + ((((2.0f * f3) * f4) * f) * f));
        return new WeightValue.WeightValueWithCovar(newWeight, newCovar);
    }
}
