package hivemall.tools.array;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableDoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableLongObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.IntWritable;
import org.apache.lucene.util.packed.PackedInts;

/**
 * Hive UDAF resolver for {@code array_avg}: computes the element-wise mean of a column of
 * equally-sized numeric arrays.
 *
 * <p>Null array rows are skipped. Null elements are excluded from both the sum and the count of
 * their position. A position that never saw a non-null element yields {@code 0.0}. Arrays whose
 * length differs from the first array seen cause a {@link HiveException}.
 */
@Description(name = "array_avg", value = "_FUNC_(array<number>) - Returns an array<double> in which each element is the mean of a set of numbers", extended = "WITH input as (\n  select array(1.0, 2.0, 3.0) as nums\n  UNION ALL\n  select array(2.0, 3.0, 4.0) as nums\n)\nselect\n  array_avg(nums)\nfrom\n  input;\n\n[\"1.5\",\"2.5\",\"3.5\"]")
public final class ArrayAvgGenericUDAF extends AbstractGenericUDAFResolver {

    /**
     * Aggregation state: per-position running sums and non-null element counts.
     *
     * <p>{@code _size == -1} means no array has been aggregated yet (sums and counts are then
     * {@code null}).
     */
    @GenericUDAFEvaluator.AggregationType(estimable = true)
    public static final class ArrayAvgAggregationBuffer
            extends GenericUDAFEvaluator.AbstractAggregationBuffer {

        /** Number of array positions being aggregated, or -1 before the first input. */
        int _size;
        /** Running sum of the non-null elements seen at each position. */
        double[] _sum;
        /** Number of non-null elements seen at each position. */
        long[] _count;

        /** Returns the buffer to its empty ("nothing seen yet") state. */
        void reset() {
            this._size = -1;
            this._sum = null;
            this._count = null;
        }

        /**
         * Allocates zeroed sum/count arrays for {@code size} positions.
         *
         * @param size number of array positions; must not be negative
         */
        void init(int size) throws HiveException {
            // An empty input array is valid user data, so only a negative size is treated as a
            // programming error (previously size 0 tripped this assertion under -ea only).
            assert size >= 0 : size;
            this._size = size;
            this._sum = new double[size];
            this._count = new long[size];
        }

        /**
         * Adds one input array to the running sums/counts.
         *
         * @param list non-null input array
         * @param listOI inspector for {@code list}
         * @param elemOI inspector for the array elements, used to read them as doubles
         * @throws HiveException if the array length differs from previously seen arrays
         */
        void doIterate(@Nonnull Object list, @Nonnull ListObjectInspector listOI,
                @Nonnull PrimitiveObjectInspector elemOI) throws HiveException {
            final int length = listOI.getListLength(list);
            if (_size == -1) {
                init(length);
            }
            if (length != _size) {
                throw new HiveException("Mismatch in the number of elements at tuple: " + list.toString());
            }

            final double[] sum = _sum;
            final long[] count = _count;
            for (int i = 0; i < length; i++) {
                Object elem = listOI.getListElement(list, i);
                if (elem == null) {
                    continue; // nulls contribute to neither the sum nor the count
                }
                sum[i] += PrimitiveObjectInspectorUtils.getDouble(elem, elemOI);
                count[i]++;
            }
        }

        /**
         * Folds a partial aggregation (as produced by {@code terminatePartial}) into this buffer.
         *
         * @param size number of positions in the partial aggregation
         * @param sumList list of {@code DoubleWritable} partial sums
         * @param countList list of {@code LongWritable} partial counts
         * @param sumListOI inspector for {@code sumList}
         * @param countListOI inspector for {@code countList}
         * @throws HiveException if {@code size} differs from this buffer's established size
         */
        void merge(int size, @Nonnull Object sumList, @Nonnull Object countList,
                @Nonnull StandardListObjectInspector sumListOI,
                @Nonnull StandardListObjectInspector countListOI) throws HiveException {
            if (size != _size) {
                if (_size != -1) {
                    throw new HiveException("Mismatch in the number of elements");
                }
                init(size);
            }

            final WritableDoubleObjectInspector sumElemOI =
                    PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
            final WritableLongObjectInspector countElemOI =
                    PrimitiveObjectInspectorFactory.writableLongObjectInspector;
            final double[] sum = _sum;
            final long[] count = _count;
            for (int i = 0; i < size; i++) {
                sum[i] += sumElemOI.get(sumListOI.getListElement(sumList, i));
                count[i] += countElemOI.get(countListOI.getListElement(countList, i));
            }
        }

        /**
         * Returns a rough size estimate of this buffer in bytes.
         *
         * <p>Empty buffer: 8. Otherwise 4 bytes for the size plus two arrays, each costed at 32
         * bytes of overhead and 8 bytes per element. NOTE(review): these constants look like
         * Hive's JavaDataModel estimates; confirm if exact accounting matters.
         */
        @Override
        public int estimate() {
            if (_size == -1) {
                return 8;
            }
            return 4 + 2 * (32 + 8 * _size);
        }
    }

    /** Evaluator implementing the partial and final aggregation phases of {@code array_avg}. */
    public static class Evaluator extends GenericUDAFEvaluator {

        // PARTIAL1 / COMPLETE input: the raw array column and its element inspector.
        private ListObjectInspector inputListOI;
        private PrimitiveObjectInspector inputListElemOI;

        // PARTIAL2 / FINAL input: the partial-aggregation struct {size, sum, count}.
        private StructObjectInspector internalMergeOI;
        private StructField sizeField;
        private StructField sumField;
        private StructField countField;
        private WritableIntObjectInspector sizeOI;
        private StandardListObjectInspector sumOI;
        private StandardListObjectInspector countOI;

        /**
         * Configures inspectors for the given mode.
         *
         * @return the partial-aggregation struct inspector for PARTIAL1/PARTIAL2, otherwise an
         *         inspector for {@code array<double>}
         */
        @Override
        public ObjectInspector init(GenericUDAFEvaluator.Mode mode, ObjectInspector[] parameters)
                throws HiveException {
            assert parameters.length == 1;
            super.init(mode, parameters);

            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1
                    || mode == GenericUDAFEvaluator.Mode.COMPLETE) {
                this.inputListOI = (ListObjectInspector) parameters[0];
                this.inputListElemOI =
                        HiveUtils.asDoubleCompatibleOI(inputListOI.getListElementObjectInspector());
            } else {
                StructObjectInspector mergeOI = (StructObjectInspector) parameters[0];
                this.internalMergeOI = mergeOI;
                this.sizeField = mergeOI.getStructFieldRef("size");
                this.sumField = mergeOI.getStructFieldRef("sum");
                this.countField = mergeOI.getStructFieldRef("count");
                this.sizeOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
                this.sumOI = ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
                this.countOI = ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.writableLongObjectInspector);
            }

            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1
                    || mode == GenericUDAFEvaluator.Mode.PARTIAL2) {
                return createInternalMergeOI();
            }
            return ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
        }

        /** Builds the inspector for the partial struct {size: int, sum: array<double>, count: array<bigint>}. */
        private static StructObjectInspector createInternalMergeOI() {
            List<String> fieldNames = new ArrayList<String>();
            List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();

            fieldNames.add("size");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
            fieldNames.add("sum");
            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
            fieldNames.add("count");
            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.writableLongObjectInspector));

            return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
        }

        @Override
        public ArrayAvgAggregationBuffer getNewAggregationBuffer() throws HiveException {
            ArrayAvgAggregationBuffer buf = new ArrayAvgAggregationBuffer();
            reset(buf);
            return buf;
        }

        @Override
        public void reset(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer)
                throws HiveException {
            ((ArrayAvgAggregationBuffer) aggregationBuffer).reset();
        }

        /** Adds one input row; null arrays are ignored. */
        @Override
        public void iterate(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer,
                Object[] parameters) throws HiveException {
            Object list = parameters[0];
            if (list == null) {
                return;
            }
            ((ArrayAvgAggregationBuffer) aggregationBuffer).doIterate(list, inputListOI,
                inputListElemOI);
        }

        /**
         * Emits the partial state as {@code [size, sum list, count list]}, or {@code null} when
         * nothing has been aggregated.
         */
        @Override
        public Object terminatePartial(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer)
                throws HiveException {
            ArrayAvgAggregationBuffer buf = (ArrayAvgAggregationBuffer) aggregationBuffer;
            if (buf._size == -1) {
                return null;
            }
            return new Object[] {new IntWritable(buf._size), WritableUtils.toWritableList(buf._sum),
                    WritableUtils.toWritableList(buf._count)};
        }

        /** Merges a partial state produced by {@link #terminatePartial}; null partials are ignored. */
        @Override
        public void merge(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer, Object partial)
                throws HiveException {
            if (partial == null) {
                return;
            }
            ArrayAvgAggregationBuffer buf = (ArrayAvgAggregationBuffer) aggregationBuffer;

            int size = sizeOI.get(internalMergeOI.getStructFieldData(partial, sizeField));
            assert size != -1;

            Object sumList = internalMergeOI.getStructFieldData(partial, sumField);
            Object countList = internalMergeOI.getStructFieldData(partial, countField);
            // Unwrap lazily-deserialized arrays so the standard list inspectors can read them.
            if (sumList instanceof LazyBinaryArray) {
                sumList = ((LazyBinaryArray) sumList).getList();
            }
            if (countList instanceof LazyBinaryArray) {
                countList = ((LazyBinaryArray) countList).getList();
            }
            buf.merge(size, sumList, countList, sumOI, countOI);
        }

        /**
         * Returns the per-position means, {@code 0.0} for positions with no non-null element, or
         * {@code null} when nothing has been aggregated.
         */
        @Override
        public List<DoubleWritable> terminate(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer)
                throws HiveException {
            ArrayAvgAggregationBuffer buf = (ArrayAvgAggregationBuffer) aggregationBuffer;
            final int size = buf._size;
            if (size == -1) {
                return null;
            }

            final double[] sum = buf._sum;
            final long[] count = buf._count;
            final DoubleWritable[] averages = new DoubleWritable[size];
            for (int i = 0; i < size; i++) {
                final long c = count[i];
                // Keep full double precision: the declared result type is array<double>.
                averages[i] = new DoubleWritable(c == 0L ? 0.0d : sum[i] / c);
            }
            return Arrays.asList(averages);
        }
    }

    /**
     * Validates that exactly one array argument was supplied and returns a new evaluator.
     *
     * @throws UDFArgumentTypeException on wrong arity or a non-array argument
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] argTypes) throws SemanticException {
        if (argTypes.length != 1) {
            throw new UDFArgumentTypeException(argTypes.length - 1,
                "One argument is expected, taking an array as an argument");
        }
        if (argTypes[0].getCategory() != ObjectInspector.Category.LIST) {
            throw new UDFArgumentTypeException(0,
                "One argument is expected, taking an array as an argument, but "
                        + argTypes[0].getTypeName() + " was passed");
        }
        return new Evaluator();
    }
}
