package com.intel.analytics.bigdl.dllib.utils;

import com.intel.analytics.bigdl.dllib.nn.Container;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.nn.quantized.DescParams;
import com.intel.analytics.bigdl.dllib.optim.DistriOptimizer;
import com.intel.analytics.bigdl.dllib.tensor.QuantizedTensor;
import com.intel.analytics.bigdl.dllib.tensor.QuantizedTensor$;
import com.intel.analytics.bigdl.dllib.tensor.QuantizedType$;
import com.intel.analytics.bigdl.dllib.tensor.Storage;
import com.intel.analytics.bigdl.dllib.tensor.Storage$;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.tensor.TensorType;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import org.apache.commons.io.serialization.ValidatingObjectInputStream;
import org.apache.spark.rdd.RDD;
import scala.Array$;
import scala.MatchError;
import scala.Predef$;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Map;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayBuffer$;
import scala.collection.mutable.StringBuilder;
import scala.math.Numeric$IntIsIntegral$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.package$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;
import scala.util.Try$;

/* compiled from: Util.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/utils/Util$.class */
public final class Util$ {
    /**
     * Singleton instance of this utility class (the file header says it was compiled from
     * {@code Util.scala}). The private constructor assigns {@code this} to the field, so the
     * {@code null} initializer here is only the decompiler's placeholder.
     */
    public static final Util$ MODULE$ = null;

    // Creates the single instance at class-initialization time; the constructor publishes it.
    static {
        new Util$();
    }

    /**
     * Selects an element of {@code jArr[i..i2]} by rank, using repeated partitioning around a
     * pivot chosen by {@code randomPartition}. The array is reordered in place.
     *
     * <p>NOTE(review): the ordering is decided inside the partition closure, which is not
     * visible here; going by the method name the rank presumably counts from the largest
     * element - confirm against the closure.
     *
     * @param jArr array to search; it is permuted by the partitioning steps
     * @param i    first index of the range (inclusive)
     * @param i2   last index of the range (inclusive)
     * @param i3   1-based rank of the wanted element within the range
     * @return the element with the requested rank, or {@code Long.MAX_VALUE} when the rank is 0
     */
    public long kthLargest(long[] jArr, int i, int i2, int i3) {
        int low = i;
        int rank = i3;
        while (rank != 0) {
            final int pivot = randomPartition(jArr, low, i2);
            // Number of elements the partition step placed in front of the pivot.
            final int ahead = pivot - low;
            if (ahead == rank - 1) {
                return jArr[pivot];
            }
            if (ahead > rank - 1) {
                // The wanted element sits in front of the pivot.
                return kthLargest(jArr, low, pivot - 1, rank);
            }
            // Otherwise continue behind the pivot, discounting everything skipped so far.
            rank = ((rank - pivot) + low) - 1;
            low = pivot + 1;
        }
        return Long.MAX_VALUE;
    }

    /**
     * Exchanges two elements of an array in place.
     *
     * @param jArr array to modify
     * @param i    index of the first element
     * @param i2   index of the second element
     */
    public void swap(long[] jArr, int i, int i2) {
        final long first = jArr[i];
        final long second = jArr[i2];
        jArr[i] = second;
        jArr[i2] = first;
    }

    /**
     * Partitions {@code jArr[i..i2]} around the value stored at {@code i2}.
     *
     * <p>A closure visits every index from {@code i} to {@code i2 - 1} and receives the array,
     * the pivot value and a mutable boundary counter that starts at {@code i}. The pivot is then
     * swapped into the position the counter ends on.
     *
     * @return the final index of the pivot element
     */
    private int partition(long[] jArr, int i, int i2) {
        final long pivotValue = jArr[i2];
        final IntRef boundary = IntRef.create(i);
        final int lastCandidate = i2 - 1;
        // The comparison against the pivot lives in the closure, which is not visible here.
        RichInt$.MODULE$.to$extension0(Predef$.MODULE$.intWrapper(i), lastCandidate)
                .foreach$mVc$sp(new Util$$anonfun$partition$1(jArr, pivotValue, boundary));
        final int pivotIndex = boundary.elem;
        swap(jArr, pivotIndex, i2);
        return pivotIndex;
    }

    /**
     * Moves a randomly chosen element of {@code jArr[i..i2]} to position {@code i2} and then
     * partitions the range around it.
     *
     * <p>The previous implementation computed {@code (int) (Math.random() % span)}. Since
     * {@code Math.random()} is always below 1.0, the remainder equals the random value itself
     * and the cast always produced 0, so the pivot was never random. The offset is now drawn
     * uniformly from {@code [0, span)}.
     *
     * @return the final index of the pivot element, as reported by {@code partition}
     */
    private int randomPartition(long[] jArr, int i, int i2) {
        final int span = (i2 - i) + 1;
        // Ranges with at most one element keep offset 0, matching the old behaviour there.
        final int offset = span > 1 ? (int) (Math.random() * span) : 0;
        swap(jArr, i + offset, i2);
        return partition(jArr, i, i2);
    }

    /**
     * Moves the element at index {@code i} to index {@code i2} by swapping it with its
     * neighbour one step at a time, so the elements in between each move one slot the other
     * way. The array is modified in place.
     *
     * @param obj an array, accessed through the Scala runtime's generic array helpers
     * @param i   index of the element to move; checked to lie within the array
     * @param i2  destination index; checked to lie within the array
     * @return the same array instance that was passed in
     */
    public <B> Object shift(Object obj, int i, int i2) {
        final ScalaRunTime$ arrays = ScalaRunTime$.MODULE$;
        final Log4Error$ errors = Log4Error$.MODULE$;

        final boolean fromInRange = i < arrays.array_length(obj) && i >= 0;
        final String fromMessage = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"invalid from ", " array length is ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(i), BoxesRunTime.boxToInteger(arrays.array_length(obj))}));
        errors.unKnowExceptionError(fromInRange, fromMessage, errors.unKnowExceptionError$default$3(), errors.unKnowExceptionError$default$4());

        final boolean toInRange = i2 < arrays.array_length(obj) && i2 >= 0;
        final String toMessage = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"invalid to ", " array length is ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(i2), BoxesRunTime.boxToInteger(arrays.array_length(obj))}));
        errors.unKnowExceptionError(toInRange, toMessage, errors.unKnowExceptionError$default$3(), errors.unKnowExceptionError$default$4());

        // Walk from the source towards the destination, swapping with the next slot each step.
        final int step = i < i2 ? 1 : -1;
        for (int pos = i; pos != i2; pos += step) {
            final int neighbour = pos + step;
            final Object moved = arrays.array_apply(obj, pos);
            arrays.array_update(obj, pos, arrays.array_apply(obj, neighbour));
            arrays.array_update(obj, neighbour, moved);
        }
        return obj;
    }

    /**
     * Builds fresh tensor handles over the data of the tensors in {@code tuple2._1()} and then
     * clears the tensors of both arrays in the tuple via {@code clearTensor}.
     *
     * <p>Quantized tensors are rebuilt from their storage, per-row max/min/sum, size and params.
     * Other tensors are rebuilt over either one shared storage (the "compacted" case) or their
     * own storage array, keeping offset, size and stride.
     *
     * <p>Fix: the loop index used to be advanced only inside the non-null branch, so a
     * {@code null} entry made the loop spin forever. Null entries are now skipped and leave a
     * {@code null} slot in the result.
     *
     * @param tuple2        pair of tensor arrays; only the first array is copied, both are cleared
     * @param classTag      element class tag forwarded to the storage/tensor factories
     * @param tensorNumeric numeric helper forwarded to the tensor factories
     * @return the new handles, index-aligned with {@code tuple2._1()}; an empty array when that
     *         array is empty
     */
    public <T> Tensor<T>[] getAndClearWeightBias(Tuple2<Tensor<T>[], Tensor<T>[]> tuple2, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Tensor[] weights = (Tensor[]) tuple2._1();
        if (weights.length == 0) {
            return (Tensor[]) Array$.MODULE$.apply(Nil$.MODULE$, ClassTag$.MODULE$.apply(Tensor.class));
        }
        Tensor<T>[] detached = new Tensor[weights.length];

        // "Compacted" means the mapped per-tensor values sum to the length of the first
        // tensor's storage. NOTE(review): the two closures are not visible here; presumably
        // the first detects null entries and the second yields element counts - confirm.
        boolean compacted = false;
        Storage<T> sharedStorage = null;
        if (!Predef$.MODULE$.refArrayOps((Object[]) weights).exists(new Util$$anonfun$2())) {
            sharedStorage = Storage$.MODULE$.apply(weights[0].storage().array(), classTag);
            int total = BoxesRunTime.unboxToInt(Predef$.MODULE$.intArrayOps((int[]) Predef$.MODULE$.refArrayOps((Object[]) weights).map(new Util$$anonfun$3(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()))).sum(Numeric$IntIsIntegral$.MODULE$));
            compacted = total == sharedStorage.length();
        }

        for (int idx = 0; idx < weights.length; idx++) {
            Tensor tensor = weights[idx];
            if (tensor == null) {
                // Skip the slot but keep iterating (this is the infinite-loop fix).
                continue;
            }
            if (QuantizedType$.MODULE$.equals(tensor.getTensorType())) {
                QuantizedTensor quantizedTensor = (QuantizedTensor) tensor;
                detached[idx] = QuantizedTensor$.MODULE$.apply(quantizedTensor.getStorage(), quantizedTensor.maxOfRow(), quantizedTensor.minOfRow(), quantizedTensor.sumOfRow(), quantizedTensor.size(), quantizedTensor.params(), classTag, tensorNumeric);
            } else if (compacted) {
                detached[idx] = Tensor$.MODULE$.apply(sharedStorage, tensor.storageOffset(), tensor.size(), tensor.stride(), classTag, tensorNumeric);
            } else {
                detached[idx] = Tensor$.MODULE$.apply(Storage$.MODULE$.apply(tensor.storage().array(), classTag), tensor.storageOffset(), tensor.size(), tensor.stride(), classTag, tensorNumeric);
            }
        }

        clearTensor((Tensor[]) tuple2._1(), classTag, tensorNumeric);
        clearTensor((Tensor[]) tuple2._2(), classTag, tensorNumeric);
        return detached;
    }

    /**
     * Collects the modules that {@code container.findModules("Const")} returns, runs a
     * side-effecting closure over each of them, and returns a map built from them.
     *
     * <p>The map must have as many entries as there are collected modules; otherwise the error
     * helper is invoked with a "name is duplicated" message.
     *
     * @param container     container that is searched for "Const" modules
     * @param classTag      element class tag (not used directly in this method)
     * @param tensorNumeric numeric helper (not used directly in this method)
     * @return immutable map keyed by string; presumably node name to constant tensor - the
     *         mapping closure is not visible here
     */
    public <T> Map<String, Tensor<?>> getAndClearConsts(Container<?, ?, T> container, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        TraversableLike found = (TraversableLike) container.findModules("Const").map(new Util$$anonfun$4(), ArrayBuffer$.MODULE$.canBuildFrom());
        ArrayBuffer constNodes = (ArrayBuffer) found.map(new Util$$anonfun$5(), ArrayBuffer$.MODULE$.canBuildFrom());
        constNodes.foreach(new Util$$anonfun$getAndClearConsts$1());

        TraversableOnce entries = (TraversableOnce) constNodes.map(new Util$$anonfun$6(), ArrayBuffer$.MODULE$.canBuildFrom());
        Map<String, Tensor<?>> constsByKey = entries.toMap(Predef$.MODULE$.$conforms());

        // Fewer map entries than nodes means two nodes produced the same key.
        final boolean keysUnique = constsByKey.size() == constNodes.length();
        final String headline = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"", "'s Const node's name is duplicated,"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{container}));
        final String advice = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"please check your model."})).s(Nil$.MODULE$);
        final String message = new StringBuilder().append(headline).append(advice).toString();
        Log4Error$.MODULE$.unKnowExceptionError(keysUnique, message, Log4Error$.MODULE$.unKnowExceptionError$default$3(), Log4Error$.MODULE$.unKnowExceptionError$default$4());
        return constsByKey;
    }

    /**
     * Applies a closure holding {@code map} to every module that
     * {@code container.findModules("Const")} returns.
     *
     * <p>NOTE(review): presumably this restores the tensors captured by
     * {@code getAndClearConsts}; the closure body is not visible here.
     *
     * @param container     container that is searched for "Const" modules
     * @param map           values handed to the per-module closure
     * @param classTag      element class tag (not used directly in this method)
     * @param tensorNumeric numeric helper (not used directly in this method)
     */
    public <T> void putConsts(Container<?, ?, T> container, Map<String, Tensor<?>> map, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        ArrayBuffer constNodes = (ArrayBuffer) container.findModules("Const").map(new Util$$anonfun$7(), ArrayBuffer$.MODULE$.canBuildFrom());
        constNodes.foreach(new Util$$anonfun$putConsts$1(map));
    }

    /**
     * Resets every non-null tensor in the array by calling its no-argument {@code set()}.
     * Tensors whose type is the quantized type are first converted with
     * {@code toQuantizedTensor} and released.
     *
     * @param tensorArr     tensors to clear; {@code null} entries are skipped
     * @param classTag      element class tag (not used directly in this method)
     * @param tensorNumeric numeric helper (not used directly in this method)
     */
    public <T> void clearTensor(Tensor<T>[] tensorArr, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        for (int idx = 0; idx < tensorArr.length; idx++) {
            if (tensorArr[idx] == null) {
                continue;
            }
            TensorType type = tensorArr[idx].getTensorType();
            QuantizedType$ quantized = QuantizedType$.MODULE$;
            // Null-safe equality, as emitted for a Scala == comparison.
            boolean isQuantized = type != null ? type.equals(quantized) : quantized == null;
            if (isQuantized) {
                tensorArr[idx].mo1996toQuantizedTensor().release();
            }
            tensorArr[idx].set();
        }
    }

    /**
     * Points each non-null tensor in {@code abstractModule.parameters()._1()} at the tensor
     * with the same index in {@code tensorArr}, via {@code clearAndSet$1}.
     *
     * @param tensorArr      replacement tensors, index-aligned with the module's first
     *                       parameter array
     * @param abstractModule module whose parameter tensors are overwritten
     * @param classTag       element class tag (not used directly in this method)
     * @param tensorNumeric  numeric helper (not used directly in this method)
     */
    public <T> void putWeightBias(Tensor<T>[] tensorArr, AbstractModule<Activity, Activity, T> abstractModule, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Tensor[] moduleWeights = (Tensor[]) abstractModule.parameters()._1();
        for (int idx = 0; idx < moduleWeights.length; idx++) {
            Tensor target = moduleWeights[idx];
            if (target != null) {
                clearAndSet$1(target, tensorArr[idx]);
            }
        }
    }

    /**
     * Re-points the tensors in {@code abstractModule.parameters()._2()} at one newly allocated
     * storage.
     *
     * <p>The storage length is the sum of the values a closure extracts from those tensors. For
     * each index whose entry in {@code parameters()._1()} is non-null, the second-array tensor
     * is set to a new one-element tensor when {@code tensorArr[idx]} is quantized; otherwise it
     * is set over the new storage with the offset, size and stride of {@code tensorArr[idx]}.
     *
     * @param tensorArr      tensors supplying the layout, index-aligned with the parameters
     * @param abstractModule module whose second parameter array is re-initialized
     * @param classTag       element class tag forwarded to the storage/tensor factories
     * @param tensorNumeric  numeric helper forwarded to the tensor factory
     * @throws MatchError if {@code abstractModule.parameters()} returns {@code null}
     */
    public <T> void initGradWeightBias(Tensor<T>[] tensorArr, AbstractModule<Activity, Activity, T> abstractModule, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Tuple2<Tensor<T>[], Tensor<T>[]> parameters = abstractModule.parameters();
        if (parameters == null) {
            throw new MatchError(parameters);
        }
        Tensor[] weights = (Tensor[]) parameters._1();
        Tensor[] gradients = (Tensor[]) parameters._2();

        int totalLength = BoxesRunTime.unboxToInt(Predef$.MODULE$.intArrayOps((int[]) Predef$.MODULE$.refArrayOps(gradients).map(new Util$$anonfun$8(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()))).sum(Numeric$IntIsIntegral$.MODULE$));
        Storage<T> gradientStorage = Storage$.MODULE$.apply(totalLength, classTag);

        // The result of this check is discarded; the call is kept because its closure is opaque.
        Predef$.MODULE$.refArrayOps(tensorArr).exists(new Util$$anonfun$9());

        for (int idx = 0; idx < weights.length; idx++) {
            if (weights[idx] == null) {
                continue;
            }
            Tensor<T> layout = tensorArr[idx];
            if (QuantizedType$.MODULE$.equals(layout.getTensorType())) {
                gradients[idx].set(Tensor$.MODULE$.apply(1, classTag, tensorNumeric));
            } else {
                gradients[idx].set(gradientStorage, layout.storageOffset(), layout.size(), layout.stride());
            }
        }
    }

    /**
     * Deserializes an object from a byte array by wrapping it in a
     * {@link ByteArrayInputStream} and delegating to the stream-based overload.
     *
     * @param bArr     serialized bytes; a {@code null} value is reported through the error helper
     * @param classTag class tag of the expected result type
     * @return whatever the stream-based overload returns
     */
    public <T> T deserialize(byte[] bArr, ClassTag<T> classTag) {
        final boolean missing = bArr == null;
        if (missing) {
            Log4Error$.MODULE$.invalidOperationError(false, "The byte[] must not be null", Log4Error$.MODULE$.invalidOperationError$default$3(), Log4Error$.MODULE$.invalidOperationError$default$4());
        }
        InputStream source = new ByteArrayInputStream(bArr);
        return (T) deserialize(source, classTag);
    }

    /**
     * Deserializes one object from {@code inputStream} through a validating object stream that
     * accepts only the runtime class of {@code classTag}.
     *
     * <p>Fixes over the previous body: it read a local variable ({@code obj}) that was never
     * assigned, and it ran the stream-closing closure both in-line and again in {@code finally}.
     * Cleanup now happens once, in {@code finally}.
     *
     * <p>NOTE(review): each catch branch reports through the error helper and then falls back
     * to a boxed {@code 0}, as before. Presumably the helper throws, which would make the
     * fallback unreachable - confirm.
     *
     * @param inputStream source of the serialized object; a {@code null} value is reported
     *                    through the error helper
     * @param classTag    class tag naming the only class the stream accepts
     * @return the object that was read
     */
    public <T> T deserialize(InputStream inputStream, ClassTag<T> classTag) {
        if (inputStream == null) {
            Log4Error$.MODULE$.invalidOperationError(false, "The InputStream must not be null", Log4Error$.MODULE$.invalidOperationError$default$3(), Log4Error$.MODULE$.invalidOperationError$default$4());
        }
        // Held in an ObjectRef because the cleanup closure takes the reference cell.
        ObjectRef create = ObjectRef.create((Object) null);
        try {
            create.elem = new Util$$anon$1(inputStream);
            ValidatingObjectInputStream objectStream = (ValidatingObjectInputStream) create.elem;
            objectStream.accept(new Class[]{package$.MODULE$.classTag(classTag).runtimeClass()});
            return (T) objectStream.readObject();
        } catch (ClassNotFoundException e) {
            Log4Error$.MODULE$.unKnowExceptionError(false, "class not found", Log4Error$.MODULE$.unKnowExceptionError$default$3(), e);
            return (T) BoxesRunTime.boxToInteger(0);
        } catch (IOException e) {
            Log4Error$.MODULE$.unKnowExceptionError(false, "io exception", Log4Error$.MODULE$.unKnowExceptionError$default$3(), e);
            return (T) BoxesRunTime.boxToInteger(0);
        } catch (ClassCastException e) {
            Log4Error$.MODULE$.unKnowExceptionError(false, "class cast error", Log4Error$.MODULE$.unKnowExceptionError$default$3(), e);
            return (T) BoxesRunTime.boxToInteger(0);
        } finally {
            // Runs on every exit path once the stream exists; Try absorbs cleanup failures.
            if (create.elem != null) {
                Try$.MODULE$.apply(new Util$$anonfun$deserialize$1(create));
            }
        }
    }

    /**
     * Produces copies of the given parameter tensors.
     *
     * <p>When the tensors are "compacted" (the closure-mapped values sum to the length of the
     * first tensor's storage), that storage is copied once and every non-quantized clone is
     * built over the copy with the original offset, size and stride. Otherwise each
     * non-quantized tensor is cloned individually. Quantized tensors always get their own copy
     * of storage, size, params and per-row max/min/sum arrays.
     *
     * @param tensorArr     tensors to clone; may be {@code null}
     * @param classTag      element class tag used to allocate storage and per-row arrays
     * @param tensorNumeric numeric helper forwarded to the tensor factories
     * @return {@code null} for a {@code null} input, an empty array for an empty input,
     *         otherwise an index-aligned array of clones with {@code null} entries kept
     */
    public <T> Tensor<T>[] cloneParameters(Tensor<T>[] tensorArr, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        if (tensorArr == null) {
            return null;
        }
        if (tensorArr.length == 0) {
            return (Tensor[]) Array$.MODULE$.apply(Nil$.MODULE$, ClassTag$.MODULE$.apply(Tensor.class));
        }
        Tensor<T>[] clones = new Tensor[tensorArr.length];

        // NOTE(review): the two closures are not visible here; presumably the first detects
        // null entries and the second yields per-tensor element counts - confirm.
        boolean compacted = false;
        Storage sourceStorage = null;
        if (!Predef$.MODULE$.refArrayOps(tensorArr).exists(new Util$$anonfun$10())) {
            sourceStorage = Storage$.MODULE$.apply(tensorArr[0].storage().array(), classTag);
            int total = BoxesRunTime.unboxToInt(Predef$.MODULE$.intArrayOps((int[]) Predef$.MODULE$.refArrayOps(tensorArr).map(new Util$$anonfun$11(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()))).sum(Numeric$IntIsIntegral$.MODULE$));
            compacted = total == sourceStorage.length();
        }

        // In the compacted case, copy the backing array once for all clones to share.
        Storage<T> clonedStorage = null;
        if (compacted) {
            clonedStorage = Storage$.MODULE$.apply(sourceStorage.length(), classTag);
            System.arraycopy(sourceStorage.array(), tensorArr[0].storageOffset() - 1, clonedStorage.array(), 0, sourceStorage.length());
        }

        for (int idx = 0; idx < tensorArr.length; idx++) {
            Tensor<T> original = tensorArr[idx];
            if (original == null) {
                continue;
            }
            if (QuantizedType$.MODULE$.equals(original.getTensorType())) {
                QuantizedTensor quantized = (QuantizedTensor) original;

                byte[] bytes = new byte[quantized.nElement()];
                System.arraycopy(quantized.getStorage(), 0, bytes, 0, quantized.nElement());

                int rank = quantized.size().length;
                int[] sizeCopy = new int[rank];
                System.arraycopy(quantized.size(), 0, sizeCopy, 0, rank);

                DescParams paramsCopy = quantized.params().copy();

                // The max/min/sum row arrays are all copied with the length of maxOfRow.
                int rows = ScalaRunTime$.MODULE$.array_length(quantized.maxOfRow());
                Object maxCopy = classTag.newArray(rows);
                System.arraycopy(quantized.maxOfRow(), 0, maxCopy, 0, rows);
                Object minCopy = classTag.newArray(rows);
                System.arraycopy(quantized.minOfRow(), 0, minCopy, 0, rows);
                Object sumCopy = classTag.newArray(rows);
                System.arraycopy(quantized.sumOfRow(), 0, sumCopy, 0, rows);

                clones[idx] = QuantizedTensor$.MODULE$.apply(bytes, maxCopy, minCopy, sumCopy, sizeCopy, paramsCopy, classTag, tensorNumeric);
            } else if (compacted) {
                clones[idx] = Tensor$.MODULE$.apply(clonedStorage, original.storageOffset(), original.size(), original.stride(), classTag, tensorNumeric);
            } else {
                clones[idx] = original.m1995clone();
            }
        }
        return clones;
    }

    /**
     * Replaces the module's extra parameters with values taken from the cached models in
     * {@code rdd}. Does nothing when the module has no extra parameters.
     *
     * <p>A value mapped from the first RDD element is compared with {@code i}. When it is
     * smaller, the tensor array mapped from the first element is used directly. Otherwise an
     * int array is read from the first element and a closure fills one tensor per entry.
     * NOTE(review): the mapping closures are not visible here; their exact meaning (for example
     * a total size versus a per-fetch limit) should be confirmed.
     *
     * @param rdd            RDD of optimizer caches to read from
     * @param abstractModule module that receives the extra parameters
     * @param i              threshold compared with the value mapped from the first cache
     * @param classTag       element class tag forwarded to the fill closure
     * @param tensorNumeric  numeric helper forwarded to the fill closure
     */
    public <T> void setExtraParametersFromModelRDD(RDD<DistriOptimizer.Cache<T>> rdd, AbstractModule<Activity, Activity, T> abstractModule, int i, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        final boolean hasExtraParameters = abstractModule.getExtraParameter() != null && abstractModule.getExtraParameter().length > 0;
        if (!hasExtraParameters) {
            return;
        }

        final int measured = BoxesRunTime.unboxToInt(rdd.map(new Util$$anonfun$12(), ClassTag$.MODULE$.Int()).first());
        Tensor<T>[] extraParameters;
        if (measured < i) {
            extraParameters = (Tensor[]) rdd.map(new Util$$anonfun$13(), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Tensor.class))).first();
        } else {
            int[] perEntry = (int[]) rdd.map(new Util$$anonfun$14(), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Integer.TYPE))).first();
            int count = perEntry.length;
            Tensor<T>[] assembled = new Tensor[count];
            // The closure writes into `assembled`, one slot per index in [0, count).
            RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), count).foreach$mVc$sp(new Util$$anonfun$1(rdd, abstractModule, i, classTag, tensorNumeric, perEntry, assembled));
            extraParameters = assembled;
        }
        abstractModule.setExtraParameter(extraParameters);
    }

    /**
     * Points {@code tensor} at {@code tensor2}. When both are quantized and backed by different
     * native storage, {@code tensor} is released first.
     *
     * <p>Fix: the quantized path used to call {@code tensor.set(tensor2)} inside the branch and
     * then again at the end of the method. The assignment now happens exactly once on every
     * path.
     *
     * @param tensor  tensor to overwrite
     * @param tensor2 tensor whose contents {@code tensor} should reference afterwards
     */
    private final void clearAndSet$1(Tensor tensor, Tensor tensor2) {
        QuantizedType$ quantized = QuantizedType$.MODULE$;

        // Null-safe equality, as emitted for a Scala == comparison.
        TensorType currentType = tensor.getTensorType();
        boolean currentQuantized = currentType != null ? currentType.equals(quantized) : quantized == null;
        if (currentQuantized) {
            TensorType expectedType = tensor2.getTensorType();
            boolean expectedQuantized = expectedType != null ? expectedType.equals(quantized) : quantized == null;
            if (expectedQuantized) {
                QuantizedTensor current = (QuantizedTensor) tensor;
                // Release only when the two tensors do not already share native storage.
                if (current.getNativeStorage() != ((QuantizedTensor) tensor2).getNativeStorage()) {
                    current.release();
                }
            }
        }
        tensor.set(tensor2);
    }

    /** Private constructor; stores the new instance in the {@code MODULE$} singleton field. */
    private Util$() {
        MODULE$ = this;
    }
}
