package hivemall.xgboost;

import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.Writable;

/**
 * UDTF that emits XGBoost predictions in "triple" form: one output row per element of the
 * prediction array produced for an input row.
 *
 * <p>Output columns (as declared by {@link #getReturnOI(PrimitiveObjectInspector)}):
 * <ul>
 * <li>{@code rowid} - a writable of the same primitive category as the input rowid</li>
 * <li>{@code label} - {@code int}, the position {@code i} within the prediction array
 * (presumably the class index for multi-class models - confirm against the superclass)</li>
 * <li>{@code proba} - {@code double}, the predicted value at position {@code i}</li>
 * </ul>
 */
@Description(name = "xgboost_predict_triple",
        value = "_FUNC_(PRIMITIVE rowid, array<string|double> features, string model_id, array<string> pred_model [, string options])"
                + " - Returns a prediction result as (PRIMITIVE rowid, int label, double probability)",
        extended = "select\n"
                + "  rowid,\n"
                + "  label,\n"
                + "  avg(prob) as prob\n"
                + "from (\n"
                + "  select\n"
                + "    xgboost_predict_triple(rowid, features, model_id, model) as (rowid, label, prob)\n"
                + "  from\n"
                + "    xgb_model l\n"
                + "    LEFT OUTER JOIN xgb_input r\n"
                + ") t\n"
                + "group by rowid, label;")
public final class XGBoostPredictTripleUDTF extends XGBoostOnlinePredictUDTF {

    /**
     * Creates the UDTF with a three-slot output buffer, one slot per output column
     * ({@code rowid}, {@code label}, {@code proba}).
     */
    public XGBoostPredictTripleUDTF() {
        // NOTE(review): assumes the superclass stores this array as _forwardObj; verify in
        // XGBoostOnlinePredictUDTF.
        super(new Object[3]);
    }

    /**
     * Declares the output row schema {@code (rowid, label, proba)}.
     *
     * @param rowIdOI inspector of the input rowid; its primitive category determines the writable
     *        type of the {@code rowid} output column
     * @return a standard struct inspector with the fields {@code rowid}, {@code label} (java int)
     *         and {@code proba} (java double)
     */
    @Override
    protected StructObjectInspector getReturnOI(@Nonnull PrimitiveObjectInspector rowIdOI) {
        final List<String> fieldNames = new ArrayList<String>(3);
        final List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(3);

        fieldNames.add("rowid");
        fieldOIs.add(PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
            rowIdOI.getPrimitiveCategory()));

        fieldNames.add("label");
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaIntObjectInspector);

        fieldNames.add("proba");
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector);

        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Forwards one row per prediction value: {@code (rowId, i, predicted[i])} for every index
     * {@code i}. An empty {@code predicted} array forwards nothing.
     *
     * @param rowId the row identifier, copied into every forwarded row
     * @param predicted the prediction values for this row
     * @throws HiveException if forwarding a row fails
     */
    @Override
    protected void forwardPredicted(@Nonnull Writable rowId, @Nonnull double[] predicted)
            throws HiveException {
        // The same buffer is reused for every forwarded row; only the label/proba slots change
        // inside the loop.
        final Object[] forwardObj = this._forwardObj;
        forwardObj[0] = rowId;
        for (int i = 0, size = predicted.length; i < size; i++) {
            forwardObj[1] = Integer.valueOf(i);
            forwardObj[2] = Double.valueOf(predicted[i]);
            forward(forwardObj);
        }
    }
}
