package hivemall.xgboost;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.util.FVec;
import hivemall.UDTFWithOptions;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.xgboost.utils.XGBoostUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.Writable;

/**
 * UDTF that scores feature vectors with XGBoost models supplied as table rows.
 *
 * <p>Each input row carries a row id, a feature array, a model id and a serialized model. A
 * {@link Predictor} is loaded the first time a model id is seen and cached by that id, so the
 * model column is only read on a cache miss. For every row whose feature array is non-null, one
 * {@code (rowid, predicted)} row is forwarded.
 *
 * <p>Features are either dense (an array of numbers; a null element is mapped to {@code NaN}) or
 * sparse (an array of {@code "index:value"} strings).
 */
@Description(name = "xgboost_predict", value = "_FUNC_(PRIMITIVE rowid, array<string|double> features, string model_id, array<string> pred_model [, string options]) - Returns a prediction result as (string rowid, array<double> predicted)", extended = "select\n  rowid, \n  array_avg(predicted) as predicted,\n  avg(predicted[0]) as predicted0\nfrom (\n  select\n    xgboost_predict(rowid, features, model_id, model) as (rowid, predicted)\n  from\n    xgb_model l\n    LEFT OUTER JOIN xgb_input r\n) t\ngroup by rowid;")
public class XGBoostOnlinePredictUDTF extends UDTFWithOptions {
    /** Inspector for argument 0 (row id); also determines the output row id type. */
    private PrimitiveObjectInspector rowIdOI;
    /** Inspector for argument 1 (feature array). */
    private ListObjectInspector featureListOI;
    /** {@code true} when the feature array holds numbers, {@code false} when it holds strings. */
    private boolean denseFeatures;

    /** Element inspector for dense features; left unset for sparse (string) features. */
    @Nullable
    private PrimitiveObjectInspector featureElemOI;
    /** Inspector for argument 2 (model id). */
    private StringObjectInspector modelIdOI;
    /** Inspector for argument 3 (serialized model). */
    private StringObjectInspector modelOI;

    /** Predictors loaded so far, keyed by model id; created lazily and released in close(). */
    @Nullable
    private transient Map<String, Predictor> mapToModel;

    /** Reused output row buffer: index 0 is the row id, index 1 the prediction list. */
    @Nonnull
    protected final transient Object[] _forwardObj;

    /** Reused list of prediction writables, handed back to WritableUtils on every row. */
    @Nullable
    protected transient List<DoubleWritable> _predictedCache;

    /** Creates the UDTF with a two-slot forward buffer. */
    public XGBoostOnlinePredictUDTF() {
        this(new Object[2]);
    }

    /**
     * Creates the UDTF with a caller-supplied forward buffer.
     *
     * @param objArr buffer reused for every forwarded row; forwardPredicted writes slots 0 and 1,
     *        so it needs at least two elements
     */
    public XGBoostOnlinePredictUDTF(@Nonnull Object[] objArr) {
        this._forwardObj = objArr;
    }

    /** No options are currently defined; the options argument is parsed against an empty set. */
    @Override // hivemall.UDTFWithOptions
    protected Options getOptions() {
        return new Options();
    }

    /**
     * Parses the optional fifth argument (read via {@code HiveUtils.getConstString}).
     *
     * @return the parsed command line, or {@code null} when fewer than five arguments are given
     */
    @Override // hivemall.UDTFWithOptions
    protected CommandLine processOptions(ObjectInspector[] objectInspectorArr)
            throws UDFArgumentException {
        CommandLine commandLine = null;
        if (objectInspectorArr.length >= 5) {
            commandLine = parseOptions(HiveUtils.getConstString(objectInspectorArr, 4));
        }
        return commandLine;
    }

    /**
     * Validates the arguments and captures their object inspectors.
     *
     * <p>Expects 4 or 5 arguments: rowid (primitive), features (array of numbers or strings),
     * model_id (string), model (string) and optional options.
     *
     * @return a struct inspector with fields {@code rowid} and {@code predicted}
     * @throws UDFArgumentException on a wrong argument count or an unsupported feature type
     */
    @Override
    public StructObjectInspector initialize(@Nonnull ObjectInspector[] objectInspectorArr)
            throws UDFArgumentException {
        if (objectInspectorArr.length != 4 && objectInspectorArr.length != 5) {
            showHelp("Invalid argment length=" + objectInspectorArr.length);
        }
        processOptions(objectInspectorArr);

        this.rowIdOI = HiveUtils.asPrimitiveObjectInspector(objectInspectorArr, 0);

        final ListObjectInspector listOI = HiveUtils.asListOI(objectInspectorArr, 1);
        this.featureListOI = listOI;
        final ObjectInspector elemOI = listOI.getListElementObjectInspector();
        if (HiveUtils.isNumberOI(elemOI)) {
            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
            this.denseFeatures = true;
        } else if (HiveUtils.isStringOI(elemOI)) {
            this.denseFeatures = false;
        } else {
            throw new UDFArgumentException(
                "Expected array<string|double> for the 2nd argment but got an unexpected features type: "
                        + listOI.getTypeName());
        }

        this.modelIdOI = HiveUtils.asStringOI(objectInspectorArr, 2);
        this.modelOI = HiveUtils.asStringOI(objectInspectorArr, 3);
        return getReturnOI(this.rowIdOI);
    }

    /**
     * Builds the output row inspector.
     *
     * @param primitiveObjectInspector row id inspector; its primitive category decides the
     *        writable type of the {@code rowid} output field
     * @return struct of {@code rowid} and {@code predicted} (a list of writable doubles)
     */
    @Nonnull
    protected StructObjectInspector getReturnOI(
            @Nonnull PrimitiveObjectInspector primitiveObjectInspector) {
        final List<String> fieldNames = new ArrayList<>(2);
        final List<ObjectInspector> fieldOIs = new ArrayList<>(2);

        fieldNames.add("rowid");
        fieldOIs.add(PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
            primitiveObjectInspector.getPrimitiveCategory()));

        fieldNames.add("predicted");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));

        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Scores one input row and forwards the result.
     *
     * <p>Rows with a null feature array are skipped. The row id and model id are read through
     * {@code nonNullArgument}. The model column is read through {@code nonNullArgument} only when
     * the model id is not cached yet.
     */
    @Override
    public void process(Object[] objArr) throws HiveException {
        if (this.mapToModel == null) {
            this.mapToModel = new HashMap<>();
        }
        if (objArr[1] == null) {
            return;
        }

        final String modelId =
                PrimitiveObjectInspectorUtils.getString(nonNullArgument(objArr, 2), this.modelIdOI);
        Predictor predictor = this.mapToModel.get(modelId);
        if (predictor == null) {
            // Cache miss: deserialize this model once and reuse it for later rows with the same id.
            predictor = XGBoostUtils.loadPredictor(
                this.modelOI.getPrimitiveWritableObject(nonNullArgument(objArr, 3)));
            this.mapToModel.put(modelId, predictor);
        }

        // Copy the row id before parsing features (same evaluation order as before).
        final Writable rowId = HiveUtils.copyToWritable(nonNullArgument(objArr, 0), this.rowIdOI);
        final FVec features = this.denseFeatures ? parseDenseFeatures(objArr[1])
                : parseSparseFeatures(this.featureListOI.getList(objArr[1]));
        predictAndForward(predictor, rowId, features);
    }

    /**
     * Converts a numeric feature array into a dense feature vector.
     *
     * <p>Position {@code i} of the array becomes feature index {@code i}; null elements become
     * {@code Double.NaN}. Zeros are kept as values ({@code treatsZeroAsNA = false}).
     */
    @Nonnull
    private FVec parseDenseFeatures(@Nonnull Object obj) throws UDFArgumentException {
        final int length = this.featureListOI.getListLength(obj);
        final double[] values = new double[length];
        for (int i = 0; i < length; i++) {
            final Object elem = this.featureListOI.getListElement(obj, i);
            values[i] = (elem == null) ? Double.NaN
                    : PrimitiveObjectInspectorUtils.getDouble(elem, this.featureElemOI);
        }
        return FVec.Transformer.fromArray(values, false);
    }

    /**
     * Converts a list of {@code "index:value"} strings into a sparse feature vector.
     *
     * <p>Null entries are skipped. If an index appears more than once, the last value wins.
     *
     * @throws UDFArgumentException if an entry has no {@code ':'} after a non-empty index, or if
     *         its index/value cannot be parsed (the NumberFormatException is kept as the cause)
     */
    @Nonnull
    private static FVec parseSparseFeatures(@Nonnull List<?> list) throws UDFArgumentException {
        final Map<Integer, Double> features = new HashMap<>((int) (list.size() * 1.5d));
        for (Object obj : list) {
            if (obj == null) {
                continue;
            }
            final String str = obj.toString();
            final int pos = str.indexOf(':');
            if (pos < 1) {
                throw new UDFArgumentException("Invalid feature format: " + str);
            }
            final int index;
            final double value;
            try {
                index = Integer.parseInt(str.substring(0, pos));
                value = Double.parseDouble(str.substring(pos + 1));
            } catch (NumberFormatException e) {
                // UDFArgumentException has no (String, Throwable) constructor; attach the cause
                // explicitly so the original parse failure is not lost.
                final UDFArgumentException ex =
                        new UDFArgumentException("Failed to parse a feature value: " + str);
                ex.initCause(e);
                throw ex;
            }
            features.put(index, value);
        }
        return FVec.Transformer.fromMap(features);
    }

    /** Runs the predictor on the features and forwards the scores under the given row id. */
    private void predictAndForward(@Nonnull Predictor predictor, @Nonnull Writable writable,
            @Nonnull FVec fVec) throws HiveException {
        forwardPredicted(writable, predictor.predict(fVec));
    }

    /**
     * Forwards {@code (rowid, predicted)} using the shared output buffer.
     *
     * <p>The writable list is rebuilt via {@code WritableUtils.toWritableList}, which is handed
     * the previous list, and the result is stored back into {@code _predictedCache}.
     */
    protected void forwardPredicted(@Nonnull Writable writable, @Nonnull double[] dArr)
            throws HiveException {
        final List<DoubleWritable> predicted =
                WritableUtils.toWritableList(dArr, this._predictedCache);
        this._predictedCache = predicted;
        final Object[] row = this._forwardObj;
        row[0] = writable;
        row[1] = predicted;
        forward(row);
    }

    /** Drops the cached predictors. */
    @Override
    public void close() throws HiveException {
        this.mapToModel = null;
    }
}
