package com.intel.analytics.bigdl.orca.tfpark;

import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.tensor.QuantizedType$;
import com.intel.analytics.bigdl.dllib.tensor.Storage;
import com.intel.analytics.bigdl.dllib.tensor.Storage$;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import scala.Array$;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.immutable.Nil$;
import scala.math.Numeric$IntIsIntegral$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: TFModelBroadcast.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/orca/tfpark/Util$.class */
/**
 * Java rendering of the Scala {@code Util} object from {@code TFModelBroadcast.scala}.
 *
 * <p>Helpers for moving a module's parameter tensors in and out of the module around a broadcast:
 * extracting copies of weights/biases (and "extra" parameters) while clearing the originals,
 * putting them back into a local module, initialising gradient tensors, and cloning parameters.
 *
 * <p>The {@code Util$$anonfun$N} classes referenced below are compiled Scala closures defined
 * outside this file; presumably they map a tensor to its element count ({@code nElement()}) —
 * TODO confirm against the generated closure classes.
 */
public final class Util$ {
    /** The single instance, mirroring a Scala {@code object}. */
    public static final Util$ MODULE$ = new Util$();

    /**
     * Clears the tensors in {@code tuple2._2()} and returns copies of the tensors in
     * {@code tuple2._1()}, which are cleared as well (see {@link #getAndClearParameters}).
     *
     * @param tuple2 pair of (weight/bias tensors, gradient tensors)
     * @param classTag element class tag
     * @param tensorNumeric numeric type class for {@code T}
     * @return new tensors viewing the data of {@code tuple2._1()}, or {@code null} if it is null
     */
    public <T> Tensor<T>[] getAndClearWeightBias(Tuple2<Tensor<T>[], Tensor<T>[]> tuple2, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        clearTensor(tuple2._2(), classTag, tensorNumeric);
        return getAndClearParameters(tuple2._1(), classTag, tensorNumeric);
    }

    /**
     * Returns copies of the given extra parameters and clears the originals.
     * Delegates to {@link #getAndClearParameters}.
     */
    public <T> Tensor<T>[] getAndClearExtraParameters(Tensor<T>[] tensorArr, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        return getAndClearParameters(tensorArr, classTag, tensorNumeric);
    }

    /**
     * Builds new tensors that view the same backing arrays as {@code tensorArr} (same offset,
     * size and stride), then clears every original tensor via the no-arg {@code set()}.
     *
     * <p>If the summed per-tensor counts equal the length of the first tensor's storage, the
     * parameters are treated as "compacted": all results share one {@code Storage} wrapping that
     * array. Otherwise each result gets its own {@code Storage} wrapping its source's array.
     *
     * @param tensorArr parameters to extract; null slots are left null in the result
     * @return {@code null} for null input, an empty array for empty input, otherwise the new tensors
     */
    @SuppressWarnings("unchecked")
    public <T> Tensor<T>[] getAndClearParameters(Tensor<T>[] tensorArr, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        if (tensorArr == null) {
            return null;
        }
        if (tensorArr.length == 0) {
            return new Tensor[0];
        }
        Tensor<T>[] result = new Tensor[tensorArr.length];
        // NOTE(review): assumes tensorArr[0] is non-null, as the original code did.
        Storage sharedStorage = Storage$.MODULE$.apply(tensorArr[0].storage().array(), classTag);
        int[] elementCounts = (int[]) Predef$.MODULE$.refArrayOps(tensorArr)
                .map(new Util$$anonfun$1(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()));
        int totalElements = BoxesRunTime.unboxToInt(
                Predef$.MODULE$.intArrayOps(elementCounts).sum(Numeric$IntIsIntegral$.MODULE$));
        boolean isCompacted = totalElements == sharedStorage.length();

        // The index must advance for null slots too; the previous version only incremented
        // inside the null check and would loop forever on a null element.
        for (int i = 0; i < tensorArr.length; i++) {
            Tensor<T> tensor = tensorArr[i];
            if (tensor == null) {
                continue;
            }
            Storage storage = isCompacted
                    ? sharedStorage
                    : Storage$.MODULE$.apply(tensor.storage().array(), classTag);
            result[i] = Tensor$.MODULE$.apply(storage, tensor.storageOffset(), tensor.size(), tensor.stride(), classTag, tensorNumeric);
        }
        clearTensor(tensorArr, classTag, tensorNumeric);
        return result;
    }

    /** Calls the no-arg {@code set()} on every non-null tensor; does nothing for a null array. */
    private <T> void clearTensor(Tensor<T>[] tensorArr, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        if (tensorArr == null) {
            return;
        }
        for (Tensor<T> tensor : tensorArr) {
            if (tensor != null) {
                tensor.set();
            }
        }
    }

    /**
     * Points each non-null weight/bias tensor of {@code abstractModule} at the corresponding
     * tensor in {@code tensorArr} (by index) via {@code set(Tensor)}.
     */
    public <T> void putWeightBias(Tensor<T>[] tensorArr, AbstractModule<Activity, Activity, T> abstractModule, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Tensor[] localWeightBias = (Tensor[]) abstractModule.parameters()._1();
        for (int i = 0; i < localWeightBias.length; i++) {
            if (localWeightBias[i] != null) {
                clearAndSet$1(localWeightBias[i], tensorArr[i]);
            }
        }
    }

    /**
     * Points each non-null extra parameter of {@code abstractModule} at the corresponding tensor
     * in {@code tensorArr} (by index). Does nothing if the module has no extra parameters.
     */
    public <T> void putExtraParams(Tensor<T>[] tensorArr, AbstractModule<Activity, Activity, T> abstractModule, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Tensor[] extraParameter = abstractModule.getExtraParameter();
        if (extraParameter == null) {
            return;
        }
        for (int i = 0; i < extraParameter.length; i++) {
            if (extraParameter[i] != null) {
                extraParameter[i].set(tensorArr[i]);
            }
        }
    }

    /**
     * Initialises the module's gradient tensors. All gradients share one newly allocated
     * {@code Storage} sized to the summed counts of the gradient tensors. Each one adopts the
     * offset/size/stride of the matching broadcast tensor. A broadcast tensor of
     * {@code QuantizedType} instead gets a fresh one-element tensor.
     *
     * @param tensorArr broadcast weight/bias tensors, aligned by index with the module's weights
     * @throws MatchError if {@code abstractModule.parameters()} returns null
     */
    public <T> void initGradWeightBias(Tensor<T>[] tensorArr, AbstractModule<Activity, Activity, T> abstractModule, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Tuple2 parameters = abstractModule.parameters();
        if (parameters == null) {
            // Mirrors the Scala tuple destructuring failure.
            throw new MatchError(parameters);
        }
        Tensor[] localWeightBias = (Tensor[]) parameters._1();
        Tensor[] localGradWeightBias = (Tensor[]) parameters._2();
        int[] gradCounts = (int[]) Predef$.MODULE$.refArrayOps(localGradWeightBias)
                .map(new Util$$anonfun$2(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()));
        Storage gradStorage = Storage$.MODULE$.apply(
                BoxesRunTime.unboxToInt(Predef$.MODULE$.intArrayOps(gradCounts).sum(Numeric$IntIsIntegral$.MODULE$)),
                classTag);
        for (int i = 0; i < localWeightBias.length; i++) {
            if (localWeightBias[i] == null) {
                continue;
            }
            Tensor<T> broadcast = tensorArr[i];
            if (QuantizedType$.MODULE$.equals(broadcast.getTensorType())) {
                localGradWeightBias[i].set(Tensor$.MODULE$.apply(1, classTag, tensorNumeric));
            } else {
                localGradWeightBias[i].set(gradStorage, broadcast.storageOffset(), broadcast.size(), broadcast.stride());
            }
        }
    }

    /**
     * Returns deep copies of {@code tensorArr}. When the parameters are "compacted" (summed counts
     * equal the first tensor's storage length), the whole backing array is copied once and every
     * clone views it at its original offset/size/stride. Otherwise each tensor is {@code clone()}d.
     *
     * @param tensorArr parameters to clone; null slots are left null in the result
     * @return {@code null} for null input, an empty array for empty input, otherwise the clones
     */
    @SuppressWarnings("unchecked")
    public <T> Tensor<T>[] cloneParameters(Tensor<T>[] tensorArr, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        if (tensorArr == null) {
            return null;
        }
        if (tensorArr.length == 0) {
            return new Tensor[0];
        }
        Tensor<T>[] result = new Tensor[tensorArr.length];
        // NOTE(review): assumes tensorArr[0] is non-null, as the original code did.
        Storage sourceStorage = Storage$.MODULE$.apply(tensorArr[0].storage().array(), classTag);
        int[] elementCounts = (int[]) Predef$.MODULE$.refArrayOps(tensorArr)
                .map(new Util$$anonfun$4(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()));
        int totalElements = BoxesRunTime.unboxToInt(
                Predef$.MODULE$.intArrayOps(elementCounts).sum(Numeric$IntIsIntegral$.MODULE$));
        boolean isCompacted = totalElements == sourceStorage.length();

        Storage clonedStorage = null;
        if (isCompacted) {
            clonedStorage = Storage$.MODULE$.apply(sourceStorage.length(), classTag);
            // Copy position-for-position: the clones below reuse the ORIGINAL storage offsets,
            // so the new array must mirror the source from index 0. (Starting at the first
            // tensor's offset would overrun the source whenever that offset is > 1.)
            System.arraycopy(sourceStorage.array(), 0, clonedStorage.array(), 0, sourceStorage.length());
        }

        // The index must advance for null slots too; the previous version only incremented
        // inside the null check and would loop forever on a null element.
        for (int i = 0; i < tensorArr.length; i++) {
            Tensor<T> tensor = tensorArr[i];
            if (tensor == null) {
                continue;
            }
            result[i] = isCompacted
                    ? Tensor$.MODULE$.apply(clonedStorage, tensor.storageOffset(), tensor.size(), tensor.stride(), classTag, tensorNumeric)
                    : tensor.clone();
        }
        return result;
    }

    /** Makes {@code tensor} share {@code tensor2}'s storage and geometry via {@code set(Tensor)}. */
    private final void clearAndSet$1(Tensor tensor, Tensor tensor2) {
        tensor.set(tensor2);
    }

    private Util$() {
    }
}
