package hivemall.classifier;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Preconditions;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;

/**
 * Hive UDAF resolver for {@code kpa_predict}.
 *
 * <p>The aggregate is a running sum of per-row contributions. Each input row
 * {@code (xh, xk, w0, w1, w2, w3)} adds exactly one of the following terms
 * (or nothing), depending on which arguments are {@code NULL}:
 * <ul>
 *   <li>{@code xh} is NULL: adds {@code w0} (if {@code w0} is not NULL);</li>
 *   <li>{@code xh} and {@code xk} are set: adds {@code w3 * xh * xk}
 *       (if {@code w3} is not NULL);</li>
 *   <li>{@code xh} is set and {@code xk} is NULL: adds
 *       {@code w1 * xh + w2 * xh * xh} (if {@code w1} is not NULL).</li>
 * </ul>
 */
@Description(name = "kpa_predict", value = "_FUNC_(@Nonnull double xh, @Nonnull double xk, @Nullable float w0, @Nonnull float w1, @Nonnull float w2, @Nullable float w3) - Returns a prediction value in Double")
public final class KPAPredictUDAF extends AbstractGenericUDAFResolver {

    /**
     * Aggregation state: a single running {@code double} score.
     */
    @GenericUDAFEvaluator.AggregationType(estimable = true)
    public static class AggrBuffer extends GenericUDAFEvaluator.AbstractAggregationBuffer {
        /** Sum of all contributions accumulated so far. */
        double score;

        AggrBuffer() {
            reset();
        }

        /**
         * @return the estimated size of this buffer; 8, presumably the byte
         *         size of the single {@code double} field
         */
        public int estimate() {
            return 8;
        }

        /** Clears the accumulated score back to zero. */
        void reset() {
            // Plain zero literal; no need to borrow a constant from an unrelated optimizer class.
            this.score = 0.0d;
        }

        /** @return the score accumulated so far */
        double get() {
            return this.score;
        }

        /**
         * Adds the {@code w0} term to the score.
         *
         * @param w0 value to add
         */
        void addW0(double w0) {
            this.score += w0;
        }

        /**
         * Adds {@code w1 * xh + w2 * xh * xh} to the score.
         *
         * @param xh feature value
         * @param w1 weight applied to {@code xh}
         * @param w2 weight applied to {@code xh * xh}
         */
        void addW1W2(double xh, double w1, double w2) {
            this.score += (w1 * xh) + (w2 * xh * xh);
        }

        /**
         * Adds {@code w3 * xh * xk} to the score.
         *
         * @param xh first feature value
         * @param xk second feature value
         * @param w3 weight applied to the product {@code xh * xk}
         */
        void addW3(double xh, double xk, double w3) {
            this.score += w3 * xh * xk;
        }

        /**
         * Folds a partial score produced by another buffer into this one.
         *
         * @param partialScore partial sum to add
         */
        void merge(double partialScore) {
            this.score += partialScore;
        }
    }

    /**
     * Evaluator that accumulates the score in an {@link AggrBuffer} and emits
     * it as a {@link DoubleWritable} for both partial and final results.
     */
    public static class Evaluator extends GenericUDAFEvaluator {

        // Argument inspectors; only assigned in PARTIAL1/COMPLETE mode (see init),
        // hence @Nullable.
        @Nullable
        private transient PrimitiveObjectInspector xhOI;

        @Nullable
        private transient PrimitiveObjectInspector xkOI;

        @Nullable
        private transient PrimitiveObjectInspector w0OI;

        @Nullable
        private transient PrimitiveObjectInspector w1OI;

        @Nullable
        private transient PrimitiveObjectInspector w2OI;

        @Nullable
        private transient PrimitiveObjectInspector w3OI;

        /**
         * Captures the inspectors of the six original arguments when this
         * evaluator consumes raw rows (PARTIAL1 or COMPLETE mode).
         *
         * @param mode evaluation mode
         * @param objectInspectorArr inspectors of the inputs for this mode
         * @return the writable-double inspector, used for every output
         * @throws HiveException if the superclass initialisation or an
         *         inspector conversion fails
         */
        public ObjectInspector init(GenericUDAFEvaluator.Mode mode, ObjectInspector[] objectInspectorArr) throws HiveException {
            super.init(mode, objectInspectorArr);
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.COMPLETE) {
                this.xhOI = HiveUtils.asNumberOI(objectInspectorArr[0]);
                this.xkOI = HiveUtils.asNumberOI(objectInspectorArr[1]);
                this.w0OI = HiveUtils.asNumberOI(objectInspectorArr[2]);
                this.w1OI = HiveUtils.asNumberOI(objectInspectorArr[3]);
                this.w2OI = HiveUtils.asNumberOI(objectInspectorArr[4]);
                this.w3OI = HiveUtils.asNumberOI(objectInspectorArr[5]);
            }
            return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
        }

        /**
         * Creates a fresh, zeroed aggregation buffer.
         *
         * <p>Implements the abstract
         * {@code GenericUDAFEvaluator#getNewAggregationBuffer()}; the method
         * must carry exactly this name for the class to be concrete.
         *
         * @return a new {@link AggrBuffer}
         * @throws HiveException declared by the overridden method
         */
        public AggrBuffer getNewAggregationBuffer() throws HiveException {
            return new AggrBuffer();
        }

        /**
         * Resets the given buffer to a zero score.
         *
         * @param aggregationBuffer buffer to reset; must be an {@link AggrBuffer}
         */
        public void reset(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            ((AggrBuffer) aggregationBuffer).reset();
        }

        /**
         * Adds the contribution of one input row to the buffer.
         *
         * @param aggregationBuffer buffer to update; must be an {@link AggrBuffer}
         * @param objArr the six arguments {@code (xh, xk, w0, w1, w2, w3)}
         * @throws HiveException if the argument count is not 6, or if
         *         {@code w1} is set but {@code w2} is NULL
         */
        public void iterate(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer, Object[] objArr) throws HiveException {
            Preconditions.checkArgument(objArr.length == 6, HiveException.class);
            AggrBuffer aggr = (AggrBuffer) aggregationBuffer;

            if (objArr[0] == null) {
                // No xh: the only possible contribution is w0.
                if (objArr[2] != null) {
                    aggr.addW0(HiveUtils.getDouble(objArr[2], this.w0OI));
                }
                return;
            }

            double xh = HiveUtils.getDouble(objArr[0], this.xhOI);

            if (objArr[1] != null) {
                // xh and xk present: w3 * xh * xk term; rows without w3 are ignored.
                if (objArr[5] == null) {
                    return;
                }
                double xk = HiveUtils.getDouble(objArr[1], this.xkOI);
                double w3 = HiveUtils.getDouble(objArr[5], this.w3OI);
                aggr.addW3(xh, xk, w3);
                return;
            }

            // Only xh present: w1/w2 term; rows without w1 are ignored.
            if (objArr[3] == null) {
                return;
            }
            // w1 without w2 is treated as an error rather than silently skipped.
            Preconditions.checkNotNull(objArr[4], HiveException.class);
            double w1 = HiveUtils.getDouble(objArr[3], this.w1OI);
            double w2 = HiveUtils.getDouble(objArr[4], this.w2OI);
            aggr.addW1W2(xh, w1, w2);
        }

        /**
         * @param aggregationBuffer buffer to read; must be an {@link AggrBuffer}
         * @return the current score wrapped in a new {@link DoubleWritable}
         */
        public Object terminatePartial(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            return new DoubleWritable(((AggrBuffer) aggregationBuffer).get());
        }

        /**
         * Adds a partial score to the buffer; a {@code null} partial is ignored.
         *
         * @param aggregationBuffer buffer to update; must be an {@link AggrBuffer}
         * @param obj partial result, cast to {@link DoubleWritable}
         */
        public void merge(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer, Object obj) throws HiveException {
            if (obj == null) {
                return;
            }
            ((AggrBuffer) aggregationBuffer).merge(((DoubleWritable) obj).get());
        }

        /**
         * Produces the final aggregate value.
         *
         * <p>Implements the abstract
         * {@code GenericUDAFEvaluator#terminate(AggregationBuffer)}; the
         * method must carry exactly this name for the class to be concrete.
         *
         * @param aggregationBuffer buffer to read; must be an {@link AggrBuffer}
         * @return the accumulated score wrapped in a new {@link DoubleWritable}
         */
        public DoubleWritable terminate(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            return new DoubleWritable(((AggrBuffer) aggregationBuffer).get());
        }
    }

    /**
     * Validates the call signature and returns the evaluator.
     *
     * @param typeInfoArr type information of the call arguments
     * @return a new {@link Evaluator}
     * @throws SemanticException if the argument count is not 6
     *         ({@link UDFArgumentException}) or any argument is not a number
     *         type according to {@code HiveUtils.isNumberTypeInfo}
     *         ({@link UDFArgumentTypeException})
     */
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] typeInfoArr) throws SemanticException {
        if (typeInfoArr.length != 6) {
            throw new UDFArgumentException("_FUNC_(double xh, double xk, float w0, float w1, float w2, float w3) takes exactly 6 arguments but got: " + typeInfoArr.length);
        }
        // Checked in positional order so the first offending argument is the one reported.
        requireNumberType(typeInfoArr, 0, "xh (1st argument)");
        requireNumberType(typeInfoArr, 1, "xk (2nd argument)");
        requireNumberType(typeInfoArr, 2, "w0 (3rd argument)");
        requireNumberType(typeInfoArr, 3, "w1 (4th argument)");
        requireNumberType(typeInfoArr, 4, "w2 (5th argument)");
        requireNumberType(typeInfoArr, 5, "w3 (6th argument)");
        return new Evaluator();
    }

    /**
     * Throws unless the argument at {@code index} is a number type.
     *
     * @param typeInfoArr type information of the call arguments
     * @param index zero-based position of the argument to check
     * @param label argument name and ordinal used in the error message
     * @throws UDFArgumentTypeException if the argument is not a number type
     */
    private static void requireNumberType(TypeInfo[] typeInfoArr, int index, String label) throws UDFArgumentTypeException {
        if (!HiveUtils.isNumberTypeInfo(typeInfoArr[index])) {
            throw new UDFArgumentTypeException(index, "Number type is expected for " + label + ": " + typeInfoArr[index].getTypeName());
        }
    }
}
