package hivemall.factorization.mf;

import java.util.List;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.io.FloatWritable;
import org.apache.lucene.util.packed.PackedInts;

@UDFType(deterministic = true, stateful = false)
@Description(name = "bprmf_predict", value = "_FUNC_(List<Float> Pu, List<Float> Qi[, double Bi]) - Returns the prediction value")
public final class BPRMFPredictionUDF extends UDF {

    /**
     * Computes the prediction {@code Pu . Qi} with an item bias of zero.
     *
     * @param pu latent factor vector of the user; may be {@code null} or empty
     * @param qi latent factor vector of the item; may be {@code null} or empty
     * @return the prediction value, see {@link #evaluate(List, List, double)}
     * @throws HiveException if both vectors are non-empty but differ in length, or contain a
     *         NULL element
     */
    public FloatWritable evaluate(List<Float> pu, List<Float> qi) throws HiveException {
        return evaluate(pu, qi, 0.d);
    }

    /**
     * Computes the prediction {@code Bi + Pu . Qi}.
     *
     * <p>Missing inputs are handled leniently instead of failing:
     * <ul>
     * <li>{@code Pu} null or empty, {@code Qi} present and non-empty: returns {@code Bi}
     * (only the item bias is known).</li>
     * <li>{@code Pu} null and {@code Qi} null: returns {@code 0}.</li>
     * <li>{@code Pu} empty and {@code Qi} null or empty: returns {@code 0}.</li>
     * <li>{@code Pu} present and non-empty, {@code Qi} null or empty: returns {@code 0}.</li>
     * </ul>
     *
     * @param pu latent factor vector of the user; may be {@code null} or empty
     * @param qi latent factor vector of the item; may be {@code null} or empty
     * @param bi bias term added to the dot product
     * @return the prediction value; the dot product is accumulated in {@code float}
     * @throws HiveException if both vectors are non-empty but differ in length, or contain a
     *         NULL element
     */
    public FloatWritable evaluate(List<Float> pu, List<Float> qi, double bi) throws HiveException {
        if (pu == null) {
            // Without user factors only the item bias is usable, and only if the item exists.
            return new FloatWritable(qi == null ? 0.f : (float) bi);
        }
        if (qi == null) {
            return new FloatWritable(0.f);
        }

        final int puSize = pu.size();
        final int qiSize = qi.size();
        if (puSize == 0) {
            // Same rule as a null Pu: fall back to the bias only when the item has factors.
            return new FloatWritable(qiSize == 0 ? 0.f : (float) bi);
        }
        if (qiSize == 0) {
            return new FloatWritable(0.f);
        }
        if (qiSize != puSize) {
            throw new HiveException("|Pu| " + puSize + " was not equal to |Qi| " + qiSize);
        }

        float score = (float) bi;
        for (int i = 0; i < puSize; i++) {
            final Float pf = pu.get(i);
            final Float qf = qi.get(i);
            // SQL arrays may hold NULL elements; report them instead of an opaque NPE.
            if (pf == null || qf == null) {
                throw new HiveException("NULL factor at index " + i + " of "
                        + (pf == null ? "Pu" : "Qi"));
            }
            score += pf.floatValue() * qf.floatValue();
        }
        return new FloatWritable(score);
    }
}
