package hivemall.smile.tools;

import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Counter;
import hivemall.utils.lang.Preconditions;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;
import javax.annotation.CheckForNull;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableDoubleObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.IntWritable;
import smile.math.Math;

@Description(name = "rf_ensemble", value = "_FUNC_(int yhat [, array<double> proba [, double model_weight=1.0]]) - Returns ensembled prediction results in <int label, double probability, array<double> probabilities>")
/* loaded from: input_file:hivemall/smile/tools/RandomForestEnsembleUDAF.class */
public final class RandomForestEnsembleUDAF extends AbstractGenericUDAFResolver {

    /**
     * Aggregation buffer of the legacy (V1) ensemble, which only receives predicted labels.
     *
     * <p>It counts how many times each label was predicted. On {@link #terminate()} it reports the
     * most voted label, that label's vote ratio, and a dense vector of vote ratios indexed by label.
     */
    public static final class RfAggregationBufferV1 extends GenericUDAFEvaluator.AbstractAggregationBuffer {

        /** Number of votes per predicted label. */
        @Nonnull
        private Counter<Integer> partial;

        public RfAggregationBufferV1() {
            reset();
        }

        /** Discards all collected votes. */
        void reset() {
            this.partial = new Counter<>();
        }

        /**
         * Records a single vote.
         *
         * @param label the predicted label
         */
        void iterate(int label) {
            partial.increment(Integer.valueOf(label));
        }

        /**
         * @return the counter's label-to-vote-count map (emitted as the partial aggregation)
         */
        @Nonnull
        Map<Integer, Integer> terminatePartial() {
            return partial.getMap();
        }

        /**
         * Adds {@code count} votes for {@code label}, e.g. coming from another partial aggregation.
         *
         * @param label the label
         * @param count number of votes to add
         */
        void merge(int label, int count) {
            partial.increment(Integer.valueOf(label), count);
        }

        /**
         * Builds the final result.
         *
         * @return {@code null} when no vote was collected; otherwise
         *         {@code [IntWritable label, DoubleWritable probability, List<DoubleWritable> probabilities]}
         *         where {@code probabilities[i]} is the vote ratio of label {@code i}
         */
        @Nullable
        Object[] terminate() {
            final Map<Integer, Integer> counts = partial.getMap();
            if (counts.isEmpty()) {
                return null;
            }

            long totalVotes = 0L;
            Integer bestLabel = null;
            int bestVotes = Integer.MIN_VALUE;
            int maxLabel = Integer.MIN_VALUE;
            for (Map.Entry<Integer, Integer> e : counts.entrySet()) {
                final Integer label = e.getKey();
                final int votes = e.getValue().intValue();
                totalVotes += votes;
                // ">=" keeps the label visited last among ties (map iteration order)
                if (votes >= bestVotes) {
                    bestVotes = votes;
                    bestLabel = label;
                }
                if (label.intValue() > maxLabel) {
                    maxLabel = label.intValue();
                }
            }

            final double total = totalVotes;
            // Dense vector indexed by label with at least two slots
            // (presumably so a binary task always yields [p0, p1]).
            final double[] probabilities = new double[Math.max(2, maxLabel + 1)];
            for (int label = 0; label < probabilities.length; label++) {
                final Integer votes = counts.get(Integer.valueOf(label));
                probabilities[label] = (votes == null) ? 0.0d : votes.intValue() / total;
            }

            return new Object[] {new IntWritable(bestLabel.intValue()), new DoubleWritable(bestVotes / total),
                    WritableUtils.toWritableList(probabilities)};
        }
    }

    /* loaded from: input_file:hivemall/smile/tools/RandomForestEnsembleUDAF$RfAggregationBufferV2.class */
    public static final class RfAggregationBufferV2 extends GenericUDAFEvaluator.AbstractAggregationBuffer {

        @Nullable
        private double[] _posteriori;
        private int _k;

        public RfAggregationBufferV2() {
            reset();
        }

        void reset() {
            this._posteriori = null;
            this._k = -1;
        }

        void iterate(int i, double d, @Nonnull double[] dArr) throws HiveException {
            if (this._posteriori == null) {
                this._k = dArr.length;
                this._posteriori = new double[this._k];
            }
            if (i >= this._k) {
                throw new HiveException("Predicted class " + i + " is out of bounds: " + this._k);
            }
            if (dArr.length != this._k) {
                throw new HiveException("Given |a posteriori| " + dArr.length + " is differs from expected one: " + this._k);
            }
            double[] dArr2 = this._posteriori;
            dArr2[i] = dArr2[i] + (dArr[i] * d);
        }

        void merge(int i, @Nonnull Object obj, @Nonnull StandardListObjectInspector standardListObjectInspector) throws HiveException {
            if (i != this._k) {
                if (this._k != -1) {
                    throw new HiveException("Mismatch in the number of elements: _k=" + this._k + ", size=" + i);
                }
                this._k = i;
                this._posteriori = new double[i];
            }
            double[] dArr = this._posteriori;
            WritableDoubleObjectInspector writableDoubleObjectInspector = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
            int i2 = this._k;
            for (int i3 = 0; i3 < i2; i3++) {
                int i4 = i3;
                dArr[i4] = dArr[i4] + writableDoubleObjectInspector.get(standardListObjectInspector.getListElement(obj, i3));
            }
        }

        public int estimate() {
            if (this._k == -1) {
                return 0;
            }
            return 4 + (this._k * 8);
        }
    }

    @Deprecated
    /* loaded from: input_file:hivemall/smile/tools/RandomForestEnsembleUDAF$RfEvaluatorV1.class */
    public static final class RfEvaluatorV1 extends GenericUDAFEvaluator {
        private PrimitiveObjectInspector yhatOI;
        private StandardMapObjectInspector internalMergeOI;
        private IntObjectInspector keyOI;
        private IntObjectInspector valueOI;

        public ObjectInspector init(@Nonnull GenericUDAFEvaluator.Mode mode, @Nonnull ObjectInspector[] objectInspectorArr) throws HiveException {
            StandardStructObjectInspector standardMapObjectInspector;
            super.init(mode, objectInspectorArr);
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.COMPLETE) {
                this.yhatOI = HiveUtils.asIntegerOI(objectInspectorArr, 0);
            } else {
                this.internalMergeOI = (StandardMapObjectInspector) objectInspectorArr[0];
                this.keyOI = HiveUtils.asIntOI(this.internalMergeOI.getMapKeyObjectInspector());
                this.valueOI = HiveUtils.asIntOI(this.internalMergeOI.getMapValueObjectInspector());
            }
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.PARTIAL2) {
                standardMapObjectInspector = ObjectInspectorFactory.getStandardMapObjectInspector(PrimitiveObjectInspectorFactory.javaIntObjectInspector, PrimitiveObjectInspectorFactory.javaIntObjectInspector);
            } else {
                ArrayList arrayList = new ArrayList(3);
                ArrayList arrayList2 = new ArrayList(3);
                arrayList.add("label");
                arrayList2.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
                arrayList.add("probability");
                arrayList2.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
                arrayList.add("probabilities");
                arrayList2.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
                standardMapObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector(arrayList, arrayList2);
            }
            return standardMapObjectInspector;
        }

        /* renamed from: getNewAggregationBuffer, reason: merged with bridge method [inline-methods] */
        public RfAggregationBufferV1 m260getNewAggregationBuffer() throws HiveException {
            RfAggregationBufferV1 rfAggregationBufferV1 = new RfAggregationBufferV1();
            rfAggregationBufferV1.reset();
            return rfAggregationBufferV1;
        }

        public void reset(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            ((RfAggregationBufferV1) aggregationBuffer).reset();
        }

        public void iterate(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer, Object[] objArr) throws HiveException {
            Preconditions.checkNotNull(objArr[0]);
            ((RfAggregationBufferV1) aggregationBuffer).iterate(PrimitiveObjectInspectorUtils.getInt(objArr[0], this.yhatOI));
        }

        public Object terminatePartial(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            return ((RfAggregationBufferV1) aggregationBuffer).terminatePartial();
        }

        public void merge(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer, Object obj) throws HiveException {
            RfAggregationBufferV1 rfAggregationBufferV1 = (RfAggregationBufferV1) aggregationBuffer;
            for (Map.Entry entry : this.internalMergeOI.getMap(obj).entrySet()) {
                putIntoMap(entry.getKey(), entry.getValue(), rfAggregationBufferV1);
            }
        }

        private void putIntoMap(@CheckForNull Object obj, @CheckForNull Object obj2, @Nonnull RfAggregationBufferV1 rfAggregationBufferV1) {
            Preconditions.checkNotNull(obj);
            Preconditions.checkNotNull(obj2);
            rfAggregationBufferV1.merge(this.keyOI.get(obj), this.valueOI.get(obj2));
        }

        public Object terminate(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            return ((RfAggregationBufferV1) aggregationBuffer).terminate();
        }
    }

    /* loaded from: input_file:hivemall/smile/tools/RandomForestEnsembleUDAF$RfEvaluatorV2.class */
    public static final class RfEvaluatorV2 extends GenericUDAFEvaluator {
        private PrimitiveObjectInspector yhatOI;
        private ListObjectInspector posterioriOI;
        private PrimitiveObjectInspector posterioriElemOI;

        @Nullable
        private PrimitiveObjectInspector weightOI;
        private StructObjectInspector internalMergeOI;
        private StructField sizeField;
        private StructField posterioriField;
        private IntObjectInspector sizeFieldOI;
        private StandardListObjectInspector posterioriFieldOI;

        public ObjectInspector init(@Nonnull GenericUDAFEvaluator.Mode mode, @Nonnull ObjectInspector[] objectInspectorArr) throws HiveException {
            StandardStructObjectInspector standardStructObjectInspector;
            super.init(mode, objectInspectorArr);
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.COMPLETE) {
                this.yhatOI = HiveUtils.asIntegerOI(objectInspectorArr[0]);
                this.posterioriOI = HiveUtils.asListOI(objectInspectorArr[1]);
                this.posterioriElemOI = HiveUtils.asDoubleCompatibleOI(this.posterioriOI.getListElementObjectInspector());
                if (objectInspectorArr.length == 3) {
                    this.weightOI = HiveUtils.asDoubleCompatibleOI(objectInspectorArr[2]);
                }
            } else {
                StructObjectInspector structObjectInspector = (StructObjectInspector) objectInspectorArr[0];
                this.internalMergeOI = structObjectInspector;
                this.sizeField = structObjectInspector.getStructFieldRef("size");
                this.posterioriField = structObjectInspector.getStructFieldRef("posteriori");
                this.sizeFieldOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
                this.posterioriFieldOI = ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
            }
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.PARTIAL2) {
                ArrayList arrayList = new ArrayList(3);
                ArrayList arrayList2 = new ArrayList(3);
                arrayList.add("size");
                arrayList2.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
                arrayList.add("posteriori");
                arrayList2.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
                standardStructObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector(arrayList, arrayList2);
            } else {
                ArrayList arrayList3 = new ArrayList(3);
                ArrayList arrayList4 = new ArrayList(3);
                arrayList3.add("label");
                arrayList4.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
                arrayList3.add("probability");
                arrayList4.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
                arrayList3.add("probabilities");
                arrayList4.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
                standardStructObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector(arrayList3, arrayList4);
            }
            return standardStructObjectInspector;
        }

        /* renamed from: getNewAggregationBuffer, reason: merged with bridge method [inline-methods] */
        public RfAggregationBufferV2 m261getNewAggregationBuffer() throws HiveException {
            RfAggregationBufferV2 rfAggregationBufferV2 = new RfAggregationBufferV2();
            reset(rfAggregationBufferV2);
            return rfAggregationBufferV2;
        }

        public void reset(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            ((RfAggregationBufferV2) aggregationBuffer).reset();
        }

        public void iterate(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer, Object[] objArr) throws HiveException {
            RfAggregationBufferV2 rfAggregationBufferV2 = (RfAggregationBufferV2) aggregationBuffer;
            Preconditions.checkNotNull(objArr[0]);
            int i = PrimitiveObjectInspectorUtils.getInt(objArr[0], this.yhatOI);
            Preconditions.checkNotNull(objArr[1]);
            double[] asDoubleArray = HiveUtils.asDoubleArray(objArr[1], this.posterioriOI, this.posterioriElemOI);
            double d = 1.0d;
            if (objArr.length == 3) {
                Preconditions.checkNotNull(objArr[2]);
                d = PrimitiveObjectInspectorUtils.getDouble(objArr[2], this.weightOI);
            }
            rfAggregationBufferV2.iterate(i, d, asDoubleArray);
        }

        public Object terminatePartial(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            RfAggregationBufferV2 rfAggregationBufferV2 = (RfAggregationBufferV2) aggregationBuffer;
            if (rfAggregationBufferV2._k == -1) {
                return null;
            }
            return new Object[]{new IntWritable(rfAggregationBufferV2._k), WritableUtils.toWritableList(rfAggregationBufferV2._posteriori)};
        }

        public void merge(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer, Object obj) throws HiveException {
            if (obj == null) {
                return;
            }
            RfAggregationBufferV2 rfAggregationBufferV2 = (RfAggregationBufferV2) aggregationBuffer;
            int i = this.sizeFieldOI.get(this.internalMergeOI.getStructFieldData(obj, this.sizeField));
            Object structFieldData = this.internalMergeOI.getStructFieldData(obj, this.posterioriField);
            if (structFieldData instanceof LazyBinaryArray) {
                structFieldData = ((LazyBinaryArray) structFieldData).getList();
            }
            rfAggregationBufferV2.merge(i, structFieldData, this.posterioriFieldOI);
        }

        public Object terminate(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            RfAggregationBufferV2 rfAggregationBufferV2 = (RfAggregationBufferV2) aggregationBuffer;
            if (rfAggregationBufferV2._k == -1) {
                return null;
            }
            double[] dArr = rfAggregationBufferV2._posteriori;
            int whichMax = Math.whichMax(dArr);
            Math.unitize1(dArr);
            return new Object[]{new IntWritable(whichMax), new DoubleWritable(dArr[whichMax]), WritableUtils.toWritableList(dArr)};
        }
    }

    /**
     * Selects the evaluator from the argument types.
     *
     * <ul>
     * <li>{@code (int yhat)} uses majority voting ({@link RfEvaluatorV1}).</li>
     * <li>{@code (int yhat, array<double> proba [, double model_weight])} uses weighted a-posteriori
     * scores ({@link RfEvaluatorV2}).</li>
     * </ul>
     *
     * @throws SemanticException on a wrong number or type of arguments
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfoArr) throws SemanticException {
        final int argc = typeInfoArr.length;
        if (argc < 1 || argc > 3) {
            throw new UDFArgumentLengthException("Expected 1~3 arguments but got " + argc);
        }
        // the weight type is validated before yhat, matching the original error precedence
        if (argc == 3 && !HiveUtils.isFloatingPointTypeInfo(typeInfoArr[2])) {
            throw new UDFArgumentTypeException(2, "Expected DOUBLE or FLOAT for model_weight: " + typeInfoArr[2]);
        }
        if (!HiveUtils.isIntegerTypeInfo(typeInfoArr[0])) {
            throw new UDFArgumentTypeException(0, "Expected INT for yhat: " + typeInfoArr[0]);
        }
        if (argc == 1) {
            return new RfEvaluatorV1();
        }
        if (!HiveUtils.isFloatingPointListTypeInfo(typeInfoArr[1])) {
            throw new UDFArgumentTypeException(1, "ARRAY<double> is expected for a posteriori: " + typeInfoArr[1]);
        }
        return new RfEvaluatorV2();
    }
}
