package hivemall.classifier.multiclass;

import hivemall.model.FeatureValue;
import hivemall.model.Margin;
import javax.annotation.Nonnull;
import org.apache.commons.cli.CommandLine;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.lucene.analysis.wikipedia.WikipediaTokenizer;
import org.apache.lucene.util.packed.PackedInts;

/**
 * Hive UDTF that trains a multiclass linear model with the Passive-Aggressive (PA) algorithm.
 *
 * <p>For each training example the loss {@code 1 - margin} is computed (see {@link #loss(Margin)}).
 * When it is positive and the feature vector is non-zero, the model is updated using the actual
 * label and the max-incorrect label reported by the margin, with the step size returned by
 * {@link #eta(float, float)}. The nested {@link PA1} and {@link PA2} variants change only the step
 * size.
 */
@Description(name = "train_multiclass_pa",
        value = "_FUNC_(list<string|int|bigint> features, {int|string} label [, const string options]) - Returns a relation consists of <{int|string} label, {string|int|bigint} feature, float weight>",
        extended = "Build a prediction model by Passive-Aggressive (PA) multiclass classifier")
public class MulticlassPassiveAggressiveUDTF extends MulticlassOnlineClassifierUDTF {

    /**
     * PA-1 variant: the PA step size is clipped from above by the aggressiveness parameter
     * {@code C}.
     */
    @Description(name = "train_multiclass_pa1",
            value = "_FUNC_(list<string|int|bigint> features, {int|string} label [, const string options]) - Returns a relation consists of <{int|string} label, {string|int|bigint} feature, float weight>",
            extended = "Build a prediction model by Passive-Aggressive 1 (PA-1) multiclass classifier")
    public static class PA1 extends MulticlassPassiveAggressiveUDTF {

        /** Aggressiveness parameter C; always &gt; 0, defaults to 1.0 when {@code -c} is absent. */
        protected float c;

        /**
         * Parses the superclass options and then the aggressiveness parameter {@code C} taken
         * from the {@code -c} option.
         *
         * <p>NOTE(review): the {@code c} option must be registered by {@code getOptions()} for
         * commons-cli to parse it; that registration is not visible in this file — confirm.
         *
         * @param objectInspectorArr the UDTF argument object inspectors
         * @return the command line parsed by the superclass (may be {@code null})
         * @throws UDFArgumentException if {@code -c} is not a number or is not strictly positive
         */
        @Override
        public CommandLine processOptions(ObjectInspector[] objectInspectorArr)
                throws UDFArgumentException {
            final CommandLine cl = super.processOptions(objectInspectorArr);
            float aggressiveness = 1.0f;
            if (cl != null) {
                final String cValue = cl.getOptionValue("c");
                if (cValue != null) {
                    aggressiveness = parseAggressiveness(cValue);
                }
            }
            this.c = aggressiveness;
            return cl;
        }

        /**
         * Parses and validates a textual value of the aggressiveness parameter C.
         *
         * @param value raw option value
         * @return the parsed value, guaranteed to be &gt; 0 (and therefore not NaN)
         * @throws UDFArgumentException if the value is malformed, NaN, or not strictly positive
         */
        private static float parseAggressiveness(final String value) throws UDFArgumentException {
            final float parsed;
            try {
                parsed = Float.parseFloat(value);
            } catch (NumberFormatException e) {
                final UDFArgumentException ex = new UDFArgumentException(
                    "Aggressiveness parameter C must be a number: " + value);
                ex.initCause(e);
                throw ex;
            }
            // Negated comparison so that NaN is rejected along with zero and negative values.
            if (!(parsed > 0.0f)) {
                throw new UDFArgumentException("Aggressiveness parameter C must be C > 0: " + parsed);
            }
            return parsed;
        }

        /**
         * PA-1 step size: {@code min(C, loss / (2 * sqnorm))}.
         *
         * @param f the positive loss of the current example
         * @param f2 the squared norm of the feature vector (non-zero when called from train)
         * @return the clipped step size
         */
        @Override
        protected float eta(float f, float f2) {
            final float unclipped = f / (2.0f * f2);
            return Math.min(this.c, unclipped);
        }
    }

    /**
     * PA-2 variant: the step size is softened by an additive {@code 1 / (2C)} term in the
     * denominator instead of being clipped.
     */
    @Description(name = "train_multiclass_pa2",
            value = "_FUNC_(list<string|int|bigint> features, {int|string} label [, const string options]) - Returns a relation consists of <{int|string} label, {string|int|bigint} feature, float weight>",
            extended = "Build a prediction model by Passive-Aggressive 2 (PA-2) multiclass classifier")
    public static final class PA2 extends PA1 {

        /**
         * PA-2 step size: {@code loss / (2 * sqnorm + 1 / (2C))}.
         *
         * @param f the positive loss of the current example
         * @param f2 the squared norm of the feature vector
         * @return the step size
         */
        @Override
        protected float eta(float f, float f2) {
            return f / ((2.0f * f2) + (0.5f / this.c));
        }
    }

    /**
     * Validates the argument count (features, label, and optional constant options string)
     * before delegating to the superclass.
     *
     * @param objectInspectorArr the UDTF argument object inspectors
     * @return the output object inspector built by the superclass
     * @throws UDFArgumentException if the number of arguments is not 2 or 3
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] objectInspectorArr)
            throws UDFArgumentException {
        final int numArgs = objectInspectorArr.length;
        if (numArgs != 2 && numArgs != 3) {
            throw new UDFArgumentException(
                "MulticlassPassiveAggressiveUDTF takes 2 or 3 arguments: List<Text|Int|BigInt> features, {Int|Text} label [, constant text options]");
        }
        return super.initialize(objectInspectorArr);
    }

    /**
     * Performs one PA update for a single example.
     *
     * <p>Nothing happens when the loss is not positive, or when the feature vector has zero
     * squared norm (the step size would otherwise divide by zero).
     *
     * @param featureValueArr the features of the example
     * @param obj the actual label
     */
    @Override
    protected void train(@Nonnull FeatureValue[] featureValueArr, @Nonnull Object obj) {
        final Margin margin = getMargin(featureValueArr, obj);
        final float loss = loss(margin);
        if (!(loss > 0.0f)) {
            return;
        }
        final float sqnorm = squaredNorm(featureValueArr);
        if (sqnorm == 0.0f) {
            return;
        }
        update(featureValueArr, eta(loss, sqnorm), obj, margin.getMaxIncorrectLabel());
    }

    /**
     * Computes the loss {@code 1 - margin.get()} for the current example.
     *
     * @param margin the margin computed for the example
     * @return the loss; the model is only updated when it is positive
     */
    protected float loss(Margin margin) {
        return 1.0f - margin.get();
    }

    /**
     * Basic PA step size: {@code loss / (2 * sqnorm)}.
     *
     * <p>The factor 2 presumably accounts for updating two label weight vectors (actual and
     * max-incorrect), as in the multiclass PA formulation — confirm against Crammer et al. 2006.
     *
     * @param f the positive loss of the current example
     * @param f2 the squared norm of the feature vector
     * @return the step size
     */
    protected float eta(float f, float f2) {
        return f / (2.0f * f2);
    }
}
