package hivemall.knn.lsh;

import hivemall.HivemallConstants;
import hivemall.UDTFWithOptions;
import hivemall.model.FeatureValue;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hashing.HashFunction;
import hivemall.utils.hashing.HashFunctionFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.lucene.util.packed.PackedInts;

/**
 * Hive UDTF that emits MinHash-based cluster ids for an item.
 *
 * <p>For every input row {@code (item, features[, options])} the function forwards
 * {@code -n} rows of the form {@code <clusterid, item>}: one row per hash function.
 * Each cluster id is a non-negative int derived from up to {@code -k} of the smallest
 * candidate hash values recorded for that hash function.
 *
 * <p>Supported options (see {@link #getOptions()}):
 * <ul>
 * <li>{@code -n / -hashes}: number of hash functions, i.e. rows emitted per item (default 5)</li>
 * <li>{@code -k / -keygroups}: number of hash values folded into one signature (default 2)</li>
 * </ul>
 */
@UDFType(deterministic = true, stateful = false)
@Description(name = "minhash", value = "_FUNC_(ANY item, array<int|bigint|string> features [, constant string options]) - Returns n different k-depth signatures (i.e., clusterid) for each item <clusterid, item>")
public final class MinHashUDTF extends UDTFWithOptions {
    /** Inspector of the 1st argument; reused as the inspector of the forwarded "item" column. */
    private ObjectInspector itemOI;
    /** Inspector of the 2nd argument (the feature array). */
    private ListObjectInspector featureListOI;
    /** True when feature elements are strings, false for int/bigint elements. */
    private boolean parseFeature;
    /** Reused output row: [0] = clusterid, [1] = item. */
    private Object[] forwardObjs;
    /** Number of hash functions, i.e. signatures emitted per input row. */
    private int numHashes = 5;
    /** Maximum number of hash values combined into a single signature. */
    private int numKeyGroups = 2;
    /** Hash functions created by {@link HashFunctionFactory#create(int)}; one per signature. */
    private HashFunction[] hashFuncs;

    /**
     * Validates the argument inspectors, parses the optional option string and declares the
     * output schema {@code struct<clusterid:int, item:(type of 1st argument)>}.
     *
     * @param objectInspectorArr inspectors for {@code item}, {@code features} and, optionally,
     *        the constant option string
     * @return the struct inspector describing forwarded rows
     * @throws UDFArgumentException if fewer than 2 arguments are given, the 2nd argument is not
     *         an array of int/bigint/string, or an option value is invalid
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        if (objectInspectorArr.length < 2) {
            throw new UDFArgumentException("_FUNC_ takes at least 2 arguments: ANY item, Array<Int|BigInt|Text> features [, constant String options]");
        }
        this.itemOI = objectInspectorArr[0];

        // Check the category before casting so a wrong argument type is reported as a
        // UDF argument error rather than a ClassCastException.
        final ObjectInspector featuresOI = objectInspectorArr[1];
        if (!(featuresOI instanceof ListObjectInspector)) {
            throw new UDFArgumentTypeException(1, "2nd argument must be an Array of [Int|BigInt|Text]: " + featuresOI.getTypeName());
        }
        this.featureListOI = (ListObjectInspector) featuresOI;

        final String elemTypeName = this.featureListOI.getListElementObjectInspector().getTypeName();
        final boolean isString = HivemallConstants.STRING_TYPE_NAME.equals(elemTypeName);
        final boolean isInt = HivemallConstants.INT_TYPE_NAME.equals(elemTypeName);
        final boolean isBigInt = HivemallConstants.BIGINT_TYPE_NAME.equals(elemTypeName);
        if (!isString && !isInt && !isBigInt) {
            throw new UDFArgumentTypeException(1, "2nd argument must be an Array of [Int|BigInt|Text]: " + elemTypeName);
        }
        this.parseFeature = isString;
        this.forwardObjs = new Object[2];

        processOptions(objectInspectorArr);

        final List<String> fieldNames = new ArrayList<String>(2);
        final List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(2);
        fieldNames.add("clusterid");
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaIntObjectInspector);
        fieldNames.add("item");
        fieldOIs.add(this.itemOI);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Declares the command-line style options accepted in the 3rd argument.
     *
     * @return options {@code -n/-hashes} and {@code -k/-keygroups}
     */
    @Override // hivemall.UDTFWithOptions
    protected Options getOptions() {
        Options options = new Options();
        options.addOption("n", "hashes", true, "Generate N sets of minhash values for each row (DEFAULT: 5)");
        options.addOption("k", "keygroups", true, "Use K minhash value (DEFAULT: 2)");
        return options;
    }

    /**
     * Parses the optional 3rd argument and (re)creates the hash functions.
     *
     * @param objectInspectorArr the argument inspectors passed to {@code initialize}
     * @return the parsed command line, or {@code null} when no option string was supplied
     * @throws UDFArgumentException if an option value is not a positive integer, or if the
     *         option string itself cannot be obtained/parsed by the inherited helpers
     */
    @Override // hivemall.UDTFWithOptions
    protected CommandLine processOptions(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        CommandLine commandLine = null;
        if (objectInspectorArr.length >= 3) {
            commandLine = parseOptions(HiveUtils.getConstString(objectInspectorArr[2]));
            final String hashesOpt = commandLine.getOptionValue("hashes");
            if (hashesOpt != null) {
                this.numHashes = parsePositiveInt("hashes", hashesOpt);
            }
            final String keyGroupsOpt = commandLine.getOptionValue("keygroups");
            if (keyGroupsOpt != null) {
                this.numKeyGroups = parsePositiveInt("keygroups", keyGroupsOpt);
            }
        }
        // Created even when no options are given so that the defaults take effect.
        this.hashFuncs = HashFunctionFactory.create(this.numHashes);
        return commandLine;
    }

    /**
     * Parses an option value that must be an integer greater than zero.
     *
     * @param optName long option name, used only in error messages
     * @param value raw option value
     * @return the parsed positive integer
     * @throws UDFArgumentException if {@code value} is not an integer or is not positive
     */
    private static int parsePositiveInt(String optName, String value) throws UDFArgumentException {
        final int parsed;
        try {
            parsed = Integer.parseInt(value);
        } catch (NumberFormatException e) {
            // The offending text is the only information the NumberFormatException carries,
            // so it is reported in the message of the UDF-level exception.
            throw new UDFArgumentException("Option -" + optName + " requires an integer value: " + value);
        }
        if (parsed <= 0) {
            throw new UDFArgumentException("Option -" + optName + " must be greater than 0: " + parsed);
        }
        return parsed;
    }

    /**
     * Processes one input row: forwards one {@code <clusterid, item>} row per hash function.
     *
     * @param objArr row arguments; {@code [0]} is the item, {@code [1]} the feature array
     * @throws HiveException if a feature has a negative weight or forwarding fails
     */
    @Override
    public void process(Object[] objArr) throws HiveException {
        final Object[] outputRow = this.forwardObjs;
        outputRow[1] = objArr[0];

        final List<?> rawFeatures = this.featureListOI.getList(objArr[1]);
        final ObjectInspector elemOI = this.featureListOI.getListElementObjectInspector();
        // parseFeatures is not defined in this class (presumably inherited from
        // UDTFWithOptions); NOTE(review): confirm how it handles a NULL feature array.
        final List<FeatureValue> features = parseFeatures(rawFeatures, elemOI, this.parseFeature);

        computeAndForwardSignatures(features, outputRow);
    }

    /**
     * Computes one signature per hash function and forwards it together with the item.
     *
     * <p>For each hash function the features are scanned in order; whenever a feature's
     * weighted hash value is a new running minimum, its (unweighted) hash is recorded as a
     * candidate. The signature is then built from the smallest candidates.
     *
     * @param list parsed features of the current row
     * @param objArr reusable output row; slot 0 is overwritten with each cluster id
     * @throws HiveException if a feature weight is negative or forwarding fails
     */
    private void computeAndForwardSignatures(List<FeatureValue> list, Object[] objArr) throws HiveException {
        final PriorityQueue<Integer> candidates = new PriorityQueue<Integer>();
        for (int i = 0; i < this.numHashes; i++) {
            final HashFunction hashFunc = this.hashFuncs[i];
            float runningMin = Float.MAX_VALUE;
            for (FeatureValue feature : list) {
                final int hashIndex = toNonNegative(hashFunc.hash(feature.getFeature()));
                final float weighted = calcWeightedHashValue(hashIndex, feature.getValueAsFloat());
                if (weighted < runningMin) {
                    runningMin = weighted;
                    candidates.offer(Integer.valueOf(hashIndex));
                }
            }
            objArr[0] = Integer.valueOf(getSignature(candidates, this.numKeyGroups));
            forward(objArr);
            candidates.clear();
        }
    }

    /**
     * Maps a hash to a non-negative int.
     *
     * <p>{@code Math.abs(Integer.MIN_VALUE)} overflows and remains negative, so that single
     * value is mapped to {@code Integer.MAX_VALUE}; every other input yields {@code Math.abs}.
     *
     * @param hash raw hash value
     * @return a value in {@code [0, Integer.MAX_VALUE]}
     */
    private static int toNonNegative(int hash) {
        return hash == Integer.MIN_VALUE ? Integer.MAX_VALUE : Math.abs(hash);
    }

    /**
     * Scales a hash value by the feature weight: a larger weight yields a smaller value, and
     * smaller values are the ones picked as running minimum by the caller.
     *
     * @param i non-negative hash value
     * @param f feature weight
     * @return {@code i / f}, or {@code Float.MAX_VALUE} when the weight is zero (such a
     *         feature can never become the running minimum)
     * @throws HiveException if the weight is negative
     */
    private static float calcWeightedHashValue(int i, float f) throws HiveException {
        if (f < 0.0f) {
            throw new HiveException("Negative value is not accepted for a feature weight");
        }
        if (f == 0.0f) {
            return Float.MAX_VALUE;
        }
        return i / f;
    }

    /**
     * Folds up to {@code i} of the smallest queued hash values into one non-negative int.
     *
     * <p>Elements are removed from the queue as they are consumed.
     *
     * @param priorityQueue candidate hash values; polled in ascending order
     * @param i maximum number of values to combine
     * @return 0 for an empty queue, otherwise a 31-based rolling hash of the polled values
     *         with the sign bit cleared
     */
    private static int getSignature(PriorityQueue<Integer> priorityQueue, int i) {
        final int available = priorityQueue.size();
        if (available == 0) {
            return 0;
        }
        final int depth = Math.min(available, i);
        int signature = 1;
        for (int n = 0; n < depth; n++) {
            // int overflow is intentional here; the sign bit is cleared below
            signature = (31 * signature) + priorityQueue.poll().intValue();
        }
        return signature & Integer.MAX_VALUE;
    }

    /** No resources to release. */
    @Override
    public void close() throws HiveException {
    }
}
