package com.intel.analytics.bigdl.dllib.keras.layers.utils;

import caffe.Caffe;
import com.intel.analytics.bigdl.dllib.keras.Net;
import com.intel.analytics.bigdl.dllib.keras.layers.SoftMax$;
import com.intel.analytics.bigdl.dllib.keras.metrics.BinaryAccuracy;
import com.intel.analytics.bigdl.dllib.keras.metrics.CategoricalAccuracy;
import com.intel.analytics.bigdl.dllib.keras.metrics.SparseCategoricalAccuracy;
import com.intel.analytics.bigdl.dllib.keras.models.KerasNet;
import com.intel.analytics.bigdl.dllib.keras.objectives.BinaryCrossEntropy$;
import com.intel.analytics.bigdl.dllib.keras.objectives.CategoricalCrossEntropy$;
import com.intel.analytics.bigdl.dllib.keras.objectives.CosineProximity$;
import com.intel.analytics.bigdl.dllib.keras.objectives.Hinge$;
import com.intel.analytics.bigdl.dllib.keras.objectives.KullbackLeiblerDivergence$;
import com.intel.analytics.bigdl.dllib.keras.objectives.MeanAbsoluteError$;
import com.intel.analytics.bigdl.dllib.keras.objectives.MeanAbsolutePercentageError$;
import com.intel.analytics.bigdl.dllib.keras.objectives.MeanSquaredError$;
import com.intel.analytics.bigdl.dllib.keras.objectives.MeanSquaredLogarithmicError$;
import com.intel.analytics.bigdl.dllib.keras.objectives.Poisson$;
import com.intel.analytics.bigdl.dllib.keras.objectives.RankHinge$;
import com.intel.analytics.bigdl.dllib.keras.objectives.SparseCategoricalCrossEntropy$;
import com.intel.analytics.bigdl.dllib.keras.objectives.SquaredHinge$;
import com.intel.analytics.bigdl.dllib.nn.Container;
import com.intel.analytics.bigdl.dllib.nn.HardSigmoid;
import com.intel.analytics.bigdl.dllib.nn.HardSigmoid$;
import com.intel.analytics.bigdl.dllib.nn.Identity;
import com.intel.analytics.bigdl.dllib.nn.Identity$;
import com.intel.analytics.bigdl.dllib.nn.InitializationMethod;
import com.intel.analytics.bigdl.dllib.nn.LogSigmoid;
import com.intel.analytics.bigdl.dllib.nn.LogSigmoid$;
import com.intel.analytics.bigdl.dllib.nn.LogSoftMax;
import com.intel.analytics.bigdl.dllib.nn.LogSoftMax$;
import com.intel.analytics.bigdl.dllib.nn.Ones$;
import com.intel.analytics.bigdl.dllib.nn.RandomNormal;
import com.intel.analytics.bigdl.dllib.nn.RandomUniform;
import com.intel.analytics.bigdl.dllib.nn.ReLU;
import com.intel.analytics.bigdl.dllib.nn.ReLU$;
import com.intel.analytics.bigdl.dllib.nn.ReLU6;
import com.intel.analytics.bigdl.dllib.nn.ReLU6$;
import com.intel.analytics.bigdl.dllib.nn.Sigmoid;
import com.intel.analytics.bigdl.dllib.nn.Sigmoid$;
import com.intel.analytics.bigdl.dllib.nn.SoftMax;
import com.intel.analytics.bigdl.dllib.nn.SoftMin;
import com.intel.analytics.bigdl.dllib.nn.SoftMin$;
import com.intel.analytics.bigdl.dllib.nn.SoftPlus;
import com.intel.analytics.bigdl.dllib.nn.SoftPlus$;
import com.intel.analytics.bigdl.dllib.nn.SoftSign;
import com.intel.analytics.bigdl.dllib.nn.SoftSign$;
import com.intel.analytics.bigdl.dllib.nn.Tanh;
import com.intel.analytics.bigdl.dllib.nn.Tanh$;
import com.intel.analytics.bigdl.dllib.nn.TanhShrink;
import com.intel.analytics.bigdl.dllib.nn.TanhShrink$;
import com.intel.analytics.bigdl.dllib.nn.Xavier$;
import com.intel.analytics.bigdl.dllib.nn.Zeros$;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractCriterion;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.nn.internal.KerasIdentityWrapper;
import com.intel.analytics.bigdl.dllib.nn.internal.KerasLayer;
import com.intel.analytics.bigdl.dllib.nn.internal.KerasLayerWrapper;
import com.intel.analytics.bigdl.dllib.nn.internal.Sequential;
import com.intel.analytics.bigdl.dllib.nn.internal.Sequential$;
import com.intel.analytics.bigdl.dllib.optim.Adadelta;
import com.intel.analytics.bigdl.dllib.optim.Adagrad;
import com.intel.analytics.bigdl.dllib.optim.Adagrad$;
import com.intel.analytics.bigdl.dllib.optim.Adam;
import com.intel.analytics.bigdl.dllib.optim.Adam$;
import com.intel.analytics.bigdl.dllib.optim.Adamax;
import com.intel.analytics.bigdl.dllib.optim.Adamax$;
import com.intel.analytics.bigdl.dllib.optim.OptimMethod;
import com.intel.analytics.bigdl.dllib.optim.RMSprop;
import com.intel.analytics.bigdl.dllib.optim.RMSprop$;
import com.intel.analytics.bigdl.dllib.optim.SGD;
import com.intel.analytics.bigdl.dllib.optim.SGD$;
import com.intel.analytics.bigdl.dllib.optim.ValidationMethod;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import com.intel.analytics.bigdl.dllib.utils.MultiShape;
import com.intel.analytics.bigdl.dllib.utils.Node;
import com.intel.analytics.bigdl.dllib.utils.Shape;
import com.intel.analytics.bigdl.dllib.utils.Shape$;
import com.intel.analytics.bigdl.dllib.utils.SingleShape;
import com.intel.analytics.bigdl.package$;
import java.lang.reflect.Method;
import org.apache.spark.rdd.RDD;
import scala.Array$;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayBuffer$;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;

/* compiled from: KerasUtils.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/keras/layers/utils/KerasUtils$.class */
public final class KerasUtils$ {
    /**
     * Singleton instance of the Scala {@code object KerasUtils}; Java code in this file reaches
     * other Scala objects the same way ({@code X$.MODULE$}).
     *
     * <p>NOTE(review): as decompiled, this field is initialised to {@code null} and the static block
     * below only constructs an instance. Presumably the constructor (not visible in this chunk)
     * assigns MODULE$, as Scala-compiled bytecode does — confirm, since Java source cannot reassign
     * a {@code static final} field from a constructor.
     */
    public static final KerasUtils$ MODULE$ = null;

    // Construct the singleton eagerly when the class is initialised.
    static {
        new KerasUtils$();
    }

    /**
     * Resolves the two paddings of a 2D layer from its border mode or from explicit paddings.
     *
     * <p>When {@code iArr} is {@code null} or empty, the border mode decides:
     * <ul>
     *   <li>{@code "same"} yields {@code (-1, -1)};</li>
     *   <li>any other value, including {@code null}, yields {@code (0, 0)}.</li>
     * </ul>
     * Otherwise {@code iArr} must contain exactly two entries, which are returned unchanged.
     *
     * @param str border mode, e.g. {@code "valid"} or {@code "same"}
     * @param iArr optional explicit paddings; must have length 2 when non-empty
     * @return a tuple of boxed {@code Integer} paddings
     */
    public Tuple2<Object, Object> getPadsFromBorderMode(String str, int[] iArr) {
        if (iArr == null || iArr.length == 0) {
            // Null-safe comparison: a null border mode behaves like any non-"same" mode.
            int pad = "same".equals(str) ? -1 : 0;
            return new Tuple2<Object, Object>(BoxesRunTime.boxToInteger(pad), BoxesRunTime.boxToInteger(pad));
        }
        Log4Error$.MODULE$.invalidOperationError(iArr.length == 2,
                "expect paddings length is 2, but got " + iArr.length,
                Log4Error$.MODULE$.invalidOperationError$default$3(),
                Log4Error$.MODULE$.invalidOperationError$default$4());
        return new Tuple2<Object, Object>(BoxesRunTime.boxToInteger(iArr[0]), BoxesRunTime.boxToInteger(iArr[1]));
    }

    /** Default border mode for {@link #getPadsFromBorderMode}: {@code "valid"}. */
    public String getPadsFromBorderMode$default$1() {
        final String defaultBorderMode = "valid";
        return defaultBorderMode;
    }

    /** Default explicit paddings for {@link #getPadsFromBorderMode}: none ({@code null}). */
    public int[] getPadsFromBorderMode$default$2() {
        final int[] noExplicitPads = null;
        return noExplicitPads;
    }

    /**
     * Maps a Keras initializer name to a BigDL {@link InitializationMethod}.
     *
     * <p>Supported names (lower-cased before matching): {@code glorot_uniform}, {@code one},
     * {@code zero}, {@code uniform} and {@code normal}.
     *
     * <p>For {@code uniform} and {@code normal}, the first two values of {@code dArr} are passed to
     * {@code RandomUniform} / {@code RandomNormal}. When {@code dArr} is {@code null}, the defaults
     * {@code RandomUniform(-0.05, 0.05)} and {@code RandomNormal(0.0, 0.05)} are used.
     *
     * @param str initializer name
     * @param dArr optional initializer parameters; when non-null must hold at least two values
     * @return the initialization method, or {@code null} if {@code Log4Error.invalidInputError}
     *     returns instead of throwing for an unsupported name
     */
    public InitializationMethod getInitMethod(String str, double[] dArr) {
        String lowerCase = str.toLowerCase();
        if ("glorot_uniform".equals(lowerCase)) {
            return Xavier$.MODULE$;
        }
        if ("one".equals(lowerCase)) {
            return Ones$.MODULE$;
        }
        if ("zero".equals(lowerCase)) {
            return Zeros$.MODULE$;
        }
        if ("uniform".equals(lowerCase)) {
            if (dArr == null) {
                return new RandomUniform(-0.05d, 0.05d);
            }
            // Validate explicitly; a shorter array would otherwise fail with an opaque
            // NoSuchElementException / ArrayIndexOutOfBoundsException.
            Log4Error$.MODULE$.invalidInputError(dArr.length >= 2,
                    lowerCase + " initialization expects 2 parameters, but got " + dArr.length,
                    Log4Error$.MODULE$.invalidInputError$default$3());
            return new RandomUniform(dArr[0], dArr[1]);
        }
        if ("normal".equals(lowerCase)) {
            if (dArr == null) {
                return new RandomNormal(0.0d, 0.05d);
            }
            Log4Error$.MODULE$.invalidInputError(dArr.length >= 2,
                    lowerCase + " initialization expects 2 parameters, but got " + dArr.length,
                    Log4Error$.MODULE$.invalidInputError$default$3());
            return new RandomNormal(dArr[0], dArr[1]);
        }
        Log4Error$.MODULE$.invalidInputError(false,
                "Unsupported initialization method: " + lowerCase,
                "only support glorot_uniform, one, zero, uniform, normal");
        return null;
    }

    /** Default initializer parameters for {@link #getInitMethod}: none ({@code null}). */
    public double[] getInitMethod$default$2() {
        final double[] noParameters = null;
        return noParameters;
    }

    /**
     * Builds a Keras-style activation layer from its name.
     *
     * <p>{@code "softmax"} (case-insensitive) produces the Keras {@code SoftMax} layer. Every other
     * name is delegated to {@link #getTorchActivation} and wrapped in a
     * {@code KerasIdentityWrapper}.
     *
     * @param str activation name; {@code null} yields {@code null}
     * @return the activation layer, or {@code null} when {@code str} is {@code null}
     */
    public <T> KerasLayer<Tensor<T>, Tensor<T>, T> getKerasActivation(String str, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        if (str == null) {
            return null;
        }
        // str is non-null here, so a plain equals on the literal suffices.
        if ("softmax".equals(str.toLowerCase())) {
            return SoftMax$.MODULE$.apply(SoftMax$.MODULE$.apply$default$1(), classTag, tensorNumeric);
        }
        return new KerasIdentityWrapper(package$.MODULE$.convModule(getTorchActivation(str, classTag, tensorNumeric)), classTag, tensorNumeric);
    }

    /**
     * Returns the Keras-style name of a known activation module.
     *
     * <p>This is the inverse of {@link #getTorchActivation}. Both the BigDL {@code nn.SoftMax} and
     * the Keras {@code layers.SoftMax} map to {@code "softmax"}, and {@code Identity} maps to
     * {@code "linear"}.
     *
     * @param abstractModule activation module to name
     * @return the activation name; {@code ""} for a {@code null} module and {@code null} for an
     *     unknown module type, in both cases only if {@code Log4Error.invalidInputError} returns
     *     instead of throwing
     */
    public <T> String getActivationName(AbstractModule<?, ?, T> abstractModule, ClassTag<T> classTag) {
        if (abstractModule == null) {
            Log4Error$.MODULE$.invalidInputError(false, "activation is null", Log4Error$.MODULE$.invalidInputError$default$3());
            return "";
        }
        if (abstractModule instanceof Tanh) {
            return "tanh";
        }
        if (abstractModule instanceof Sigmoid) {
            return "sigmoid";
        }
        if (abstractModule instanceof ReLU) {
            return "relu";
        }
        if (abstractModule instanceof SoftMax) {
            return "softmax";
        }
        if (abstractModule instanceof SoftPlus) {
            return "softplus";
        }
        if (abstractModule instanceof SoftSign) {
            return "softsign";
        }
        if (abstractModule instanceof HardSigmoid) {
            return "hard_sigmoid";
        }
        if (abstractModule instanceof ReLU6) {
            return "relu6";
        }
        if (abstractModule instanceof TanhShrink) {
            return "tanh_shrink";
        }
        if (abstractModule instanceof SoftMin) {
            return "softmin";
        }
        if (abstractModule instanceof LogSigmoid) {
            return "log_sigmoid";
        }
        if (abstractModule instanceof LogSoftMax) {
            return "log_softmax";
        }
        if (abstractModule instanceof Identity) {
            return "linear";
        }
        if (abstractModule instanceof com.intel.analytics.bigdl.dllib.keras.layers.SoftMax) {
            return "softmax";
        }
        Log4Error$.MODULE$.invalidInputError(false,
                "unknown activation " + abstractModule.getClass().getName(),
                Log4Error$.MODULE$.invalidInputError$default$3());
        return null;
    }

    /**
     * Builds a plain BigDL activation module from its Keras-style name.
     *
     * <p>The name is lower-cased before matching. Accepted names are {@code tanh},
     * {@code sigmoid}, {@code relu}, {@code softmax}, {@code softplus}, {@code softsign},
     * {@code hard_sigmoid}, {@code relu6}, {@code tanh_shrink}, {@code softmin},
     * {@code log_sigmoid}, {@code log_softmax} and {@code linear} (identity).
     *
     * @param str activation name; {@code null} yields {@code null}
     * @return the activation module, or {@code null} for an unsupported name if
     *     {@code Log4Error.invalidInputError} returns instead of throwing
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public <T> AbstractModule<Tensor<T>, Tensor<T>, T> getTorchActivation(String str, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        if (str == null) {
            return null;
        }
        // Raw local: the candidate modules do not all share the same generic signature.
        AbstractModule activation;
        switch (str.toLowerCase()) {
            case "tanh":
                activation = Tanh$.MODULE$.apply(classTag, tensorNumeric);
                break;
            case "sigmoid":
                activation = Sigmoid$.MODULE$.apply(classTag, tensorNumeric);
                break;
            case "relu":
                activation = ReLU$.MODULE$.apply(ReLU$.MODULE$.apply$default$1(), classTag, tensorNumeric);
                break;
            case "softmax":
                activation = com.intel.analytics.bigdl.dllib.nn.SoftMax$.MODULE$.apply(classTag, tensorNumeric);
                break;
            case "softplus":
                activation = SoftPlus$.MODULE$.apply(SoftPlus$.MODULE$.apply$default$1(), classTag, tensorNumeric);
                break;
            case "softsign":
                activation = SoftSign$.MODULE$.apply(classTag, tensorNumeric);
                break;
            case "hard_sigmoid":
                activation = HardSigmoid$.MODULE$.apply(classTag, tensorNumeric);
                break;
            case "relu6":
                activation = ReLU6$.MODULE$.apply(ReLU6$.MODULE$.apply$default$1(), classTag, tensorNumeric);
                break;
            case "tanh_shrink":
                activation = TanhShrink$.MODULE$.apply(classTag, tensorNumeric);
                break;
            case "softmin":
                activation = SoftMin$.MODULE$.apply(classTag, tensorNumeric);
                break;
            case "log_sigmoid":
                activation = LogSigmoid$.MODULE$.apply(classTag, tensorNumeric);
                break;
            case "log_softmax":
                activation = LogSoftMax$.MODULE$.apply(classTag, tensorNumeric);
                break;
            case "linear":
                activation = Identity$.MODULE$.apply(classTag, tensorNumeric);
                break;
            default:
                Log4Error$.MODULE$.invalidInputError(false,
                        "Invalid activation: " + str.toLowerCase() + ". Only simple activations can be constructed using string",
                        Log4Error$.MODULE$.invalidInputError$default$3());
                activation = null;
                break;
        }
        return activation;
    }

    /**
     * Computes a convolution's output length along one dimension.
     *
     * <p>The effective filter length is {@code i2 + (i2 - 1) * (i4 - 1)}. The raw length is:
     * <ul>
     *   <li>{@code i - effective + 1} for {@code "valid"};</li>
     *   <li>{@code i} for {@code "same"}.</li>
     * </ul>
     * The result is {@code (raw + i3 - 1) / i3} using Java integer division.
     *
     * <p>Presumably {@code i2} is the filter size, {@code i3} the stride and {@code i4} the
     * dilation — confirm against callers.
     *
     * @throws MatchError if {@code str} is neither {@code "valid"} nor {@code "same"}, including
     *     {@code null}
     */
    public int computeConvOutputLength(int i, int i2, String str, int i3, int i4) {
        int effectiveFilter = i2 + (i2 - 1) * (i4 - 1);
        int rawLength;
        if ("same".equals(str)) {
            rawLength = i;
        } else if ("valid".equals(str)) {
            rawLength = i - effectiveFilter + 1;
        } else {
            throw new MatchError(str);
        }
        int stride = i3;
        return (rawLength + stride - 1) / stride;
    }

    /** Default value (1) of the fifth parameter of {@link #computeConvOutputLength}. */
    public int computeConvOutputLength$default$5() {
        final int defaultValue = 1;
        return defaultValue;
    }

    /**
     * Resolves the three paddings of a 3D layer from its border mode.
     *
     * <p>{@code "same"} yields {@code (-1, -1, -1)}; any other value, including {@code null},
     * yields {@code (0, 0, 0)}.
     *
     * @param str border mode, e.g. {@code "valid"} or {@code "same"}
     * @return a tuple of boxed {@code Integer} paddings
     */
    public Tuple3<Object, Object, Object> getPadsFromBorderMode3D(String str) {
        // Null-safe comparison replaces the non-compiling `"same" != 0` decompiler artefact.
        Integer pad = BoxesRunTime.boxToInteger("same".equals(str) ? -1 : 0);
        return new Tuple3<Object, Object, Object>(pad, pad, pad);
    }

    /** Default border mode for {@link #getPadsFromBorderMode3D}: {@code "valid"}. */
    public String getPadsFromBorderMode3D$default$1() {
        final String defaultBorderMode = "valid";
        return defaultBorderMode;
    }

    /**
     * Converts a Keras dim-ordering string into a BigDL 2D data format.
     *
     * <p>The input is lower-cased: {@code "tf"} maps to {@code DataFormat.NHWC} and {@code "th"}
     * to {@code DataFormat.NCHW}. Any other value is reported through
     * {@code Log4Error.invalidInputError}; should that call return, a {@link MatchError} is thrown.
     *
     * @param str dim ordering, {@code "tf"} or {@code "th"}
     * @return the matching data format
     */
    public com.intel.analytics.bigdl.dllib.nn.abstractnn.DataFormat toBigDLFormat(String str) {
        String lowerCase = str.toLowerCase();
        Log4Error$.MODULE$.invalidInputError("tf".equals(lowerCase) || "th".equals(lowerCase),
                "Dim ordering must be either tf or th, but got " + lowerCase,
                "Please set dimOrdering=tf or dimOrdering=th");
        if ("tf".equals(lowerCase)) {
            return com.intel.analytics.bigdl.dllib.nn.abstractnn.DataFormat$NHWC$.MODULE$;
        }
        if ("th".equals(lowerCase)) {
            return com.intel.analytics.bigdl.dllib.nn.abstractnn.DataFormat$NCHW$.MODULE$;
        }
        throw new MatchError(lowerCase);
    }

    /**
     * Converts a Keras dim-ordering string into the BigDL 5D format name.
     *
     * <p>The input is lower-cased: {@code "tf"} maps to {@code "CHANNEL_LAST"} and {@code "th"}
     * to {@code "CHANNEL_FIRST"}. Any other value is reported through
     * {@code Log4Error.invalidInputError}; should that call return, a {@link MatchError} is thrown.
     *
     * @param str dim ordering, {@code "tf"} or {@code "th"}
     * @return {@code "CHANNEL_LAST"} or {@code "CHANNEL_FIRST"}
     */
    public String toBigDLFormat5D(String str) {
        String lowerCase = str.toLowerCase();
        Log4Error$.MODULE$.invalidInputError("tf".equals(lowerCase) || "th".equals(lowerCase),
                "Dim ordering must be either tf or th, but got " + lowerCase,
                Log4Error$.MODULE$.invalidInputError$default$3());
        if ("tf".equals(lowerCase)) {
            return "CHANNEL_LAST";
        }
        if ("th".equals(lowerCase)) {
            return "CHANNEL_FIRST";
        }
        throw new MatchError(lowerCase);
    }

    /**
     * Maps a Keras loss name to the corresponding BigDL Keras objective.
     *
     * <p>The name is lower-cased before matching. Short aliases ({@code mse}, {@code mae},
     * {@code mape}, {@code msle}, {@code kld}) select the same criterion as their long forms.
     *
     * @param str loss name
     * @return the criterion, or {@code null} for an unsupported name if
     *     {@code Log4Error.invalidInputError} returns instead of throwing
     */
    public <T> AbstractCriterion<Activity, Activity, T> toBigDLCriterion(String str, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        switch (str.toLowerCase()) {
            case "binary_crossentropy":
                return package$.MODULE$.convCriterion(BinaryCrossEntropy$.MODULE$.apply(null, BinaryCrossEntropy$.MODULE$.apply$default$2(), classTag, tensorNumeric));
            case "categorical_crossentropy":
                return package$.MODULE$.convCriterion(CategoricalCrossEntropy$.MODULE$.apply(classTag, tensorNumeric));
            case "mse":
            case "mean_squared_error":
                return package$.MODULE$.convCriterion(MeanSquaredError$.MODULE$.apply(MeanSquaredError$.MODULE$.apply$default$1(), classTag, tensorNumeric));
            case "mae":
            case "mean_absolute_error":
                return package$.MODULE$.convCriterion(MeanAbsoluteError$.MODULE$.apply(MeanAbsoluteError$.MODULE$.apply$default$1(), classTag, tensorNumeric));
            case "hinge":
                return package$.MODULE$.convCriterion(Hinge$.MODULE$.apply(Hinge$.MODULE$.apply$default$1(), Hinge$.MODULE$.apply$default$2(), classTag, tensorNumeric));
            case "mape":
            case "mean_absolute_percentage_error":
                return package$.MODULE$.convCriterion(MeanAbsolutePercentageError$.MODULE$.apply(classTag, tensorNumeric));
            case "msle":
            case "mean_squared_logarithmic_error":
                return package$.MODULE$.convCriterion(MeanSquaredLogarithmicError$.MODULE$.apply(classTag, tensorNumeric));
            case "squared_hinge":
                return package$.MODULE$.convCriterion(SquaredHinge$.MODULE$.apply(SquaredHinge$.MODULE$.apply$default$1(), SquaredHinge$.MODULE$.apply$default$2(), classTag, tensorNumeric));
            case "sparse_categorical_crossentropy": {
                SparseCategoricalCrossEntropy$ sparse = SparseCategoricalCrossEntropy$.MODULE$;
                boolean firstDefault = sparse.apply$default$1();
                boolean secondDefault = sparse.apply$default$2();
                return package$.MODULE$.convCriterion(sparse.apply(firstDefault, secondDefault, null, sparse.apply$default$4(), sparse.apply$default$5(), classTag, tensorNumeric));
            }
            case "kld":
            case "kullback_leibler_divergence":
                return package$.MODULE$.convCriterion(KullbackLeiblerDivergence$.MODULE$.apply(classTag, tensorNumeric));
            case "cosine_proximity":
                return package$.MODULE$.convCriterion(CosineProximity$.MODULE$.apply(classTag, tensorNumeric));
            case "poisson":
                return package$.MODULE$.convCriterion(Poisson$.MODULE$.apply(classTag, tensorNumeric));
            case "rank_hinge":
                return package$.MODULE$.convCriterion(RankHinge$.MODULE$.apply(RankHinge$.MODULE$.apply$default$1(), classTag, tensorNumeric));
            default:
                Log4Error$.MODULE$.invalidInputError(false, "Unsupported loss: " + str, Log4Error$.MODULE$.invalidInputError$default$3());
                return null;
        }
    }

    /**
     * Maps a Keras optimizer name to a BigDL {@link OptimMethod}.
     *
     * <p>The name is lower-cased before matching. Overrides applied:
     * <ul>
     *   <li>{@code sgd}: learning rate 0.01;</li>
     *   <li>{@code rmsprop}: learning rate 0.001 and 0.9 for the third constructor argument;</li>
     *   <li>{@code adamax}: epsilon 1e-8;</li>
     *   <li>{@code adagrad}: learning rate 0.01;</li>
     *   <li>{@code adadelta}: (0.95, 1e-8);</li>
     *   <li>{@code adam}: all constructor defaults.</li>
     * </ul>
     * Unsupported names are reported through {@code Log4Error.invalidInputError}, like the other
     * converters in this class, instead of surfacing as a bare {@code scala.MatchError}.
     *
     * @param str optimizer name
     * @return the optimizer, or {@code null} for an unsupported name if
     *     {@code Log4Error.invalidInputError} returns instead of throwing
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public <T> OptimMethod<T> toBigDLOptimMethod(String str, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        String lowerCase = str.toLowerCase();
        switch (lowerCase) {
            case "sgd":
                // Trailing nulls are the (null) defaults of SGD's 8th and 9th constructor params.
                return new SGD(0.01d,
                        SGD$.MODULE$.$lessinit$greater$default$2(),
                        SGD$.MODULE$.$lessinit$greater$default$3(),
                        SGD$.MODULE$.$lessinit$greater$default$4(),
                        SGD$.MODULE$.$lessinit$greater$default$5(),
                        SGD$.MODULE$.$lessinit$greater$default$6(),
                        SGD$.MODULE$.$lessinit$greater$default$7(),
                        null, null, classTag, tensorNumeric);
            case "rmsprop":
                return new RMSprop(0.001d, RMSprop$.MODULE$.$lessinit$greater$default$2(), 0.9d, RMSprop$.MODULE$.$lessinit$greater$default$4(), classTag, tensorNumeric);
            case "adamax":
                return new Adamax(Adamax$.MODULE$.$lessinit$greater$default$1(), Adamax$.MODULE$.$lessinit$greater$default$2(), Adamax$.MODULE$.$lessinit$greater$default$3(), 1.0E-8d, classTag, tensorNumeric);
            case "adagrad":
                return new Adagrad(0.01d, Adagrad$.MODULE$.$lessinit$greater$default$2(), Adagrad$.MODULE$.$lessinit$greater$default$3(), classTag, tensorNumeric);
            case "adadelta":
                return new Adadelta(0.95d, 1.0E-8d, classTag, tensorNumeric);
            case "adam":
                return new Adam(Adam$.MODULE$.$lessinit$greater$default$1(), Adam$.MODULE$.$lessinit$greater$default$2(), Adam$.MODULE$.$lessinit$greater$default$3(), Adam$.MODULE$.$lessinit$greater$default$4(), Adam$.MODULE$.$lessinit$greater$default$5(), classTag, tensorNumeric);
            default:
                Log4Error$.MODULE$.invalidInputError(false,
                        "Unsupported optimizer: " + str,
                        "only support sgd, rmsprop, adamax, adagrad, adadelta, adam");
                return null;
        }
    }

    /**
     * Chooses the accuracy metric that matches a loss name.
     *
     * <p>The loss name is lower-cased before matching:
     * <ul>
     *   <li>{@code sparse_categorical_crossentropy} selects {@code SparseCategoricalAccuracy};</li>
     *   <li>{@code categorical_crossentropy} selects {@code CategoricalAccuracy};</li>
     *   <li>{@code binary_crossentropy} selects {@code BinaryAccuracy}.</li>
     * </ul>
     *
     * @param str loss name
     * @return the accuracy metric, or {@code null} for an unsupported loss if
     *     {@code Log4Error.invalidInputError} returns instead of throwing
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public <T> ValidationMethod<T> com$intel$analytics$bigdl$dllib$keras$layers$utils$KerasUtils$$mappingForAcc(String str, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        switch (str.toLowerCase()) {
            case "sparse_categorical_crossentropy":
                return new SparseCategoricalAccuracy(classTag, tensorNumeric);
            case "categorical_crossentropy":
                return new CategoricalAccuracy(classTag, tensorNumeric);
            case "binary_crossentropy":
                return new BinaryAccuracy(classTag, tensorNumeric);
            default:
                Log4Error$.MODULE$.invalidInputError(false,
                        "Unsupported metric: accuracy and loss: " + str + " combination",
                        Log4Error$.MODULE$.invalidInputError$default$3());
                return null;
        }
    }

    /**
     * Converts metric names into BigDL validation methods.
     *
     * <p>Each element of {@code list} is mapped through {@code KerasUtils$$anonfun$toBigDLMetrics$1},
     * which receives the loss name {@code str}.
     *
     * @param list metric names; {@code null} yields {@code null}
     * @param str loss name passed to the per-metric mapping
     * @return the mapped validation methods, in the same order as {@code list}
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public <T> List<ValidationMethod<T>> toBigDLMetrics(List<String> list, String str, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        return list == null
                ? null
                : (List) list.map(new KerasUtils$$anonfun$toBigDLMetrics$1(str, classTag, tensorNumeric), List$.MODULE$.canBuildFrom());
    }

    /**
     * Prepends a batch dimension ({@code -1}) to a shape.
     *
     * <p>For a {@code SingleShape} the result is {@code [-1] ++ dims}. Any other shape is treated
     * as multi-shape and each component is mapped with {@code KerasUtils$$anonfun$addBatch$1}.
     *
     * @param shape shape without batch dimension; {@code null} yields {@code null}
     * @return the shape with a leading batch dimension
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public Shape addBatch(Shape shape) {
        if (shape == null) {
            return null;
        }
        if (!(shape instanceof SingleShape)) {
            return new MultiShape((List) shape.toMulti().map(new KerasUtils$$anonfun$addBatch$1(), List$.MODULE$.canBuildFrom()));
        }
        TraversableOnce withBatch = (TraversableOnce) List$.MODULE$
                .apply(Predef$.MODULE$.wrapIntArray(new int[]{-1}))
                .$plus$plus(shape.toSingle(), List$.MODULE$.canBuildFrom());
        int[] dims = (int[]) withBatch.toArray(ClassTag$.MODULE$.Int());
        return Shape$.MODULE$.apply(dims);
    }

    /**
     * Drops the leading (batch) dimension from a shape.
     *
     * <p>For a {@code SingleShape} every dimension after the first is kept. Any other shape is
     * treated as multi-shape and each component is mapped with
     * {@code KerasUtils$$anonfun$removeBatch$1}.
     *
     * @param shape shape including batch dimension; {@code null} yields {@code null}
     * @return the shape without its first dimension
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public Shape removeBatch(Shape shape) {
        if (shape == null) {
            return null;
        }
        if (shape instanceof SingleShape) {
            int rank = shape.toSingle().length();
            int[] withoutBatch = (int[]) shape.toSingle().slice(1, rank).toArray(ClassTag$.MODULE$.Int());
            return Shape$.MODULE$.apply(withoutBatch);
        }
        return new MultiShape((List) shape.toMulti().map(new KerasUtils$$anonfun$removeBatch$1(), List$.MODULE$.canBuildFrom()));
    }

    /**
     * Appends an activation layer after a module by wrapping both in a built {@code Sequential}.
     *
     * <p>The module is wrapped in a {@code KerasLayerWrapper} whose input shape is {@code shape}
     * without its batch dimension. The activation is then added, the container takes the module's
     * name, and it is built with {@code shape}.
     *
     * @param abstractModule module to fuse
     * @param kerasLayer activation to append; {@code null} returns {@code abstractModule} unchanged
     * @param shape input shape including the batch dimension
     * @return the fused container, or {@code abstractModule} when there is no activation
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public <T> AbstractModule<Activity, Activity, T> fuse(AbstractModule<Activity, Activity, T> abstractModule, KerasLayer<Tensor<T>, Tensor<T>, T> kerasLayer, Shape shape, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        if (kerasLayer == null) {
            return abstractModule;
        }
        Sequential<T> fused = Sequential$.MODULE$.apply(classTag, tensorNumeric);
        Shape inputShapeWithoutBatch = removeBatch(shape);
        fused.add(new KerasLayerWrapper(abstractModule, inputShapeWithoutBatch, classTag, tensorNumeric));
        fused.add(kerasLayer);
        fused.setName(abstractModule.getName());
        fused.build(shape);
        return fused;
    }

    /**
     * Reflectively invokes the public method named {@code str} on {@code obj}.
     *
     * <p>Resolution is attempted in two steps. The first is an exact {@code getMethod} lookup.
     * Its parameter types are derived from each argument by {@code KerasUtils$$anonfun$1}. If
     * anything in that step throws, the receiver's public methods are filtered by
     * {@code KerasUtils$$anonfun$2(str)} instead, and exactly one candidate must remain.
     * Otherwise {@code Log4Error.invalidOperationError} is raised.
     *
     * @param obj the receiver; its runtime class is searched
     * @param str the method name
     * @param seq the call arguments, passed positionally
     * @return the value returned by the invoked method
     */
    public Object invokeMethod(Object obj, String str, Seq<Object> seq) {
        Class<?> receiverClass = obj.getClass();
        Method target;
        try {
            Class[] argTypes = (Class[]) ((TraversableOnce) seq.map(new KerasUtils$$anonfun$1(), Seq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(Class.class));
            target = receiverClass.getMethod(str, argTypes);
        } catch (Throwable exactLookupFailed) {
            // getMethod needs the exact declared parameter types, which the runtime argument
            // classes often are not (primitives, supertypes). The broad catch is deliberate:
            // failures while deriving argument classes also fall back to the name-based search.
            Method[] candidates = (Method[]) Predef$.MODULE$.refArrayOps(receiverClass.getMethods()).filter(new KerasUtils$$anonfun$2(str));
            String errorMessage = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"We should only found one result, but got ", ": ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{str, BoxesRunTime.boxToInteger(candidates.length)}));
            Log4Error$.MODULE$.invalidOperationError(candidates.length == 1, errorMessage, Log4Error$.MODULE$.invalidOperationError$default$3(), Log4Error$.MODULE$.invalidOperationError$default$4());
            target = candidates[0];
        }
        Object[] callArgs = (Object[]) seq.toArray(ClassTag$.MODULE$.Object());
        return target.invoke(obj, callArgs);
    }

    /**
     * Reflectively invokes method {@code str2} of the class named {@code str}. The class tag and
     * the numeric evidence are appended as two extra trailing arguments.
     *
     * <p>The exact {@code getMethod} lookup uses only the types of {@code seq}, without the two
     * appended arguments, so it presumably fails for most targets. The name-based fallback
     * ({@code KerasUtils$$anonfun$4(str2)}, requiring exactly one match) is then what resolves
     * the method.
     *
     * <p>NOTE(review): the receiver handed to {@code Method.invoke} is the class-name string
     * {@code str}. That only works when the resolved method is static (the receiver is then
     * ignored); an instance method would be rejected by {@code invoke}.
     *
     * @param str fully qualified class name, loaded with {@code Class.forName}
     * @param str2 the method name
     * @param seq the leading call arguments
     * @param classTag class tag appended as the second-to-last argument
     * @param tensorNumeric numeric evidence appended as the last argument
     * @return the value returned by the invoked method
     */
    public <T> Object invokeMethodWithEv(String str, String str2, Seq<Object> seq, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Class<?> ownerClass = Class.forName(str);
        Method target;
        try {
            Class[] argTypes = (Class[]) ((TraversableOnce) seq.map(new KerasUtils$$anonfun$3(), Seq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(Class.class));
            target = ownerClass.getMethod(str2, argTypes);
        } catch (Throwable exactLookupFailed) {
            Method[] candidates = (Method[]) Predef$.MODULE$.refArrayOps(ownerClass.getMethods()).filter(new KerasUtils$$anonfun$4(str2));
            String errorMessage = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"We should only found one result, but got ", ": ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{str2, BoxesRunTime.boxToInteger(candidates.length)}));
            Log4Error$.MODULE$.invalidOperationError(candidates.length == 1, errorMessage, Log4Error$.MODULE$.invalidOperationError$default$3(), Log4Error$.MODULE$.invalidOperationError$default$4());
            target = candidates[0];
        }
        Object[] evidence = new Object[]{Predef$.MODULE$.implicitly(classTag), tensorNumeric};
        Seq<Object> argsWithEvidence = (Seq) seq.$plus$plus(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(evidence)), Seq$.MODULE$.canBuildFrom());
        return target.invoke(str, (Object[]) argsWithEvidence.toArray(ClassTag$.MODULE$.Object()));
    }

    /**
     * Calls {@code invokeMethod(obj, str, ...)} with the class tag and numeric evidence appended
     * after {@code seq} as two extra trailing arguments.
     *
     * @param obj the receiver
     * @param str the method name
     * @param seq the leading call arguments
     * @param classTag class tag appended as the second-to-last argument
     * @param tensorNumeric numeric evidence appended as the last argument
     * @return the value returned by the invoked method
     */
    public <T> Object invokeMethodWithEv(Object obj, String str, Seq<Object> seq, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Object[] evidence = new Object[]{Predef$.MODULE$.implicitly(classTag), tensorNumeric};
        Seq<Object> argsWithEvidence = (Seq) seq.$plus$plus(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(evidence)), Seq$.MODULE$.canBuildFrom());
        return invokeMethod(obj, str, argsWithEvidence);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Counts the parameters of a Keras layer.
     *
     * <p>The first element is accumulated over the first array returned by
     * {@code kerasLayer.parameters()} via {@code KerasUtils$$anonfun$countParams$1} (presumably
     * summing element counts of the weight tensors). The second element depends on the layer:
     * <ul>
     *   <li>for a {@code KerasNet}, it is accumulated by {@code KerasUtils$$anonfun$countParams$2}
     *       over the child modules of its {@code labor()} container;</li>
     *   <li>otherwise, it is {@code 0} when the layer is frozen and equal to the first element
     *       when it is not.</li>
     * </ul>
     *
     * @param kerasLayer the layer to inspect
     * @param classTag class tag of the numeric type {@code T}
     * @return a pair of ints: (total count, trainable count)
     */
    public <T> Tuple2<Object, Object> countParams(KerasLayer<Activity, Activity, T> kerasLayer, ClassTag<T> classTag) {
        Tuple2 parameters = kerasLayer.parameters();
        if (parameters == null) {
            // Mirrors the failed tuple pattern match of the Scala source.
            throw new MatchError(parameters);
        }
        Tensor[] weights = (Tensor[]) parameters._1();
        IntRef total = IntRef.create(0);
        Predef$.MODULE$.refArrayOps(weights).foreach(new KerasUtils$$anonfun$countParams$1(total));
        if (kerasLayer instanceof KerasNet) {
            ArrayBuffer<AbstractModule<Activity, Activity, T>> children = ((Container) kerasLayer.labor()).modules();
            IntRef trainable = IntRef.create(0);
            children.foreach(new KerasUtils$$anonfun$countParams$2(classTag, trainable));
            return new Tuple2.mcII.sp(total.elem, trainable.elem);
        }
        int trainableCount = ((Net) kerasLayer).isFrozen(classTag) ? 0 : total.elem;
        return new Tuple2.mcII.sp(total.elem, trainableCount);
    }

    /**
     * Builds the three summary columns for one layer.
     *
     * @param kerasLayer the layer to describe
     * @param classTag class tag of the numeric type {@code T}
     * @return {@code {"name (SimpleClassName)", output shape, total parameter count}}
     */
    public <T> String[] getLayerSummary(KerasLayer<Activity, Activity, T> kerasLayer, ClassTag<T> classTag) {
        String outputShape = strShape(kerasLayer.getOutputShape());
        String nameColumn = kerasLayer.getName() + " (" + kerasLayer.getClass().getSimpleName() + ")";
        int totalParams = countParams(kerasLayer, classTag)._1$mcI$sp();
        return new String[]{nameColumn, outputShape, String.valueOf(totalParams)};
    }

    /**
     * Builds the summary columns for a graph node.
     *
     * <p>The node's element must be a {@code KerasLayer}. The result is the layer summary
     * columns followed by one more column, which {@code KerasUtils$$anonfun$getNodeSummary$1}
     * accumulates over the node's predecessors (presumably their names).
     *
     * @param node the graph node to describe
     * @param classTag class tag of the numeric type {@code T}
     * @return the layer summary columns plus the predecessor column
     */
    public <T> String[] getNodeSummary(Node<AbstractModule<Activity, Activity, T>> node, ClassTag<T> classTag) {
        String[] layerColumns = getLayerSummary((KerasLayer) node.element(), classTag);
        ObjectRef predecessorText = ObjectRef.create("");
        Seq<Node<AbstractModule<Activity, Activity, T>>> predecessors = node.prevNodes();
        predecessors.indices().foreach$mVc$sp(new KerasUtils$$anonfun$getNodeSummary$1(predecessorText, predecessors));
        String[] row = new String[layerColumns.length + 1];
        System.arraycopy(layerColumns, 0, row, 0, layerColumns.length);
        row[layerColumns.length] = (String) predecessorText.elem;
        return row;
    }

    /**
     * Appends the summary row of {@code node} to {@code arrayBuffer} and returns its parameter
     * counts.
     *
     * <p>The row is rendered with {@code printRow}, using that method's default split-line flag
     * and split character.
     *
     * @param node the graph node; its element must be a {@code KerasLayer}
     * @param i line width passed to {@code printRow}
     * @param dArr relative column positions passed to {@code printRow}
     * @param arrayBuffer output lines are appended here; {@code printRow} requires it to be non-null
     * @param classTag class tag of the numeric type {@code T}
     * @return the (total, trainable) pair from {@code countParams}
     */
    public <T> Tuple2<Object, Object> printNodeSummary(Node<AbstractModule<Activity, Activity, T>> node, int i, double[] dArr, ArrayBuffer<String> arrayBuffer, ClassTag<T> classTag) {
        String[] row = getNodeSummary(node, classTag);
        boolean withSplitLine = printRow$default$4();
        char splitChar = printRow$default$5();
        printRow(row, i, dArr, withSplitLine, splitChar, arrayBuffer);
        KerasLayer layer = (KerasLayer) node.element();
        return countParams(layer, classTag);
    }

    /**
     * Default line width, in characters, for {@code printNodeSummary}.
     *
     * <p>Written as a literal. The previous code returned
     * {@code Caffe.LayerParameter.MVN_PARAM_FIELD_NUMBER}, an unrelated protobuf field number that
     * merely also equals 120, which coupled this default to the Caffe bindings.
     *
     * @return 120
     */
    public <T> int printNodeSummary$default$2() {
        return 120;
    }

    /**
     * Default relative column positions for {@code printNodeSummary}.
     *
     * <p>Fractions of the line width, presumably where each of the four columns ends. A fresh
     * array is returned on every call.
     *
     * @return {@code {0.33, 0.55, 0.67, 1.0}}
     */
    public <T> double[] printNodeSummary$default$3() {
        double[] positions = {0.33d, 0.55d, 0.67d, 1.0d};
        return positions;
    }

    /**
     * Default output buffer for {@code printNodeSummary}: none.
     *
     * <p>NOTE(review): {@code printRow} appends to its buffer unconditionally, so relying on this
     * default would fail there. Callers presumably always pass a buffer.
     *
     * @return {@code null}
     */
    public <T> ArrayBuffer<String> printNodeSummary$default$4() {
        final ArrayBuffer<String> noBuffer = null;
        return noBuffer;
    }

    /**
     * Renders one table row into {@code arrayBuffer}.
     *
     * <p>Steps:
     * <ol>
     *   <li>{@code KerasUtils$$anonfun$printRow$1} converts the relative positions {@code dArr}
     *       into column boundaries (presumably scaled by {@code i}).</li>
     *   <li>{@code KerasUtils$$anonfun$printRow$2} lays the fields out, producing the first line
     *       plus an array of additional lines.</li>
     *   <li>The first line is appended. Every element of the additional-lines array after index 0
     *       is handed to {@code KerasUtils$$anonfun$printRow$3} together with the buffer.</li>
     *   <li>A split line is appended when {@code z} is set.</li>
     * </ol>
     *
     * @param strArr the field values of the row
     * @param i line width; also the length of the split line
     * @param dArr relative column positions
     * @param z whether to append a split line after the row
     * @param c character the split line is made of
     * @param arrayBuffer destination for the rendered lines. It must be non-null, because it is
     *     appended to unconditionally, even though {@code printRow$default$6} supplies {@code null}.
     */
    public void printRow(String[] strArr, int i, double[] dArr, boolean z, char c, ArrayBuffer<String> arrayBuffer) {
        ArrayBuffer columnBounds = ArrayBuffer$.MODULE$.apply(Nil$.MODULE$);
        KerasUtils$$anonfun$printRow$1 computeBounds = new KerasUtils$$anonfun$printRow$1(i, dArr, columnBounds);
        Predef$.MODULE$.doubleArrayOps(dArr).indices().foreach$mVc$sp(computeBounds);

        ObjectRef firstLine = ObjectRef.create("");
        String[] noExtraLines = (String[]) Array$.MODULE$.apply(Nil$.MODULE$, ClassTag$.MODULE$.apply(String.class));
        ObjectRef extraLines = ObjectRef.create(noExtraLines);
        KerasUtils$$anonfun$printRow$2 layOut = new KerasUtils$$anonfun$printRow$2(strArr, columnBounds, firstLine, extraLines);
        Predef$.MODULE$.refArrayOps(strArr).indices().foreach$mVc$sp(layOut);

        arrayBuffer.append(Predef$.MODULE$.wrapRefArray(new String[]{(String) firstLine.elem}));

        String[] pending = (String[]) extraLines.elem;
        Object[] afterFirst = (Object[]) Predef$.MODULE$.refArrayOps(pending).slice(1, pending.length);
        Predef$.MODULE$.refArrayOps(afterFirst).foreach(new KerasUtils$$anonfun$printRow$3(i, dArr, arrayBuffer));

        if (!z) {
            return;
        }
        printSplitLine(c, i, arrayBuffer);
    }

    /**
     * Default line width, in characters, for {@code printRow}.
     *
     * <p>Written as a literal. The previous code returned
     * {@code Caffe.LayerParameter.MVN_PARAM_FIELD_NUMBER}, an unrelated protobuf field number that
     * merely also equals 120, which coupled this default to the Caffe bindings.
     *
     * @return 120
     */
    public int printRow$default$2() {
        return 120;
    }

    /**
     * Default relative column positions for {@code printRow}.
     *
     * <p>Fractions of the line width, one per column. A fresh array is returned on every call.
     *
     * @return {@code {0.33, 0.55, 0.67, 1.0}}
     */
    public double[] printRow$default$3() {
        double[] positions = {0.33d, 0.55d, 0.67d, 1.0d};
        return positions;
    }

    /**
     * Default for {@code printRow}'s split-line flag: a split line is appended after each row.
     *
     * @return {@code true}
     */
    public boolean printRow$default$4() {
        final boolean appendSplitLine = true;
        return appendSplitLine;
    }

    /**
     * Default character used by {@code printRow} to draw split lines.
     *
     * @return {@code '_'}
     */
    public char printRow$default$5() {
        final char underscore = '_';
        return underscore;
    }

    /**
     * Default output buffer for {@code printRow}: none.
     *
     * <p>NOTE(review): {@code printRow} appends to its buffer unconditionally, so this default
     * cannot actually be used. Callers presumably always pass a buffer.
     *
     * @return {@code null}
     */
    public ArrayBuffer<String> printRow$default$6() {
        final ArrayBuffer<String> noBuffer = null;
        return noBuffer;
    }

    /**
     * Appends a line made of {@code i} copies of {@code c} to {@code arrayBuffer}.
     *
     * <p>A non-positive {@code i} appends an empty line.
     *
     * @param c the character to repeat
     * @param i the number of repetitions
     * @param arrayBuffer destination buffer; must be non-null
     */
    public void printSplitLine(char c, int i, ArrayBuffer<String> arrayBuffer) {
        StringBuilder line = new StringBuilder();
        for (int k = 0; k < i; k++) {
            line.append(c);
        }
        arrayBuffer.append(Predef$.MODULE$.wrapRefArray(new String[]{line.toString()}));
    }

    /**
     * Default split-line length, in characters, for {@code printSplitLine}.
     *
     * <p>Written as a literal. The previous code returned
     * {@code Caffe.LayerParameter.MVN_PARAM_FIELD_NUMBER}, an unrelated protobuf field number that
     * merely also equals 120, which coupled this default to the Caffe bindings.
     *
     * @return 120
     */
    public int printSplitLine$default$2() {
        return 120;
    }

    /**
     * Formats a shape for display.
     *
     * <p>A single shape is rendered as {@code "(d1, d2, ...)"}, with the first occurrence of the
     * text {@code -1} replaced by {@code None}. For a multi shape, the pieces are concatenated
     * by {@code KerasUtils$$anonfun$strShape$1}.
     *
     * <p>NOTE(review): {@code replaceFirst} matches text rather than whole dimensions, so only
     * the first {@code -1} becomes {@code None}. A value such as {@code -12} would also be
     * partially rewritten.
     *
     * @param shape the shape to format
     * @return the display string
     * @throws MatchError if {@code shape} is neither a {@code SingleShape} nor a {@code MultiShape}
     */
    public String strShape(Shape shape) {
        if (shape instanceof SingleShape) {
            String dims = ((SingleShape) shape).toSingle().mkString(", ");
            return ("(" + dims + ")").replaceFirst("-1", "None");
        }
        if (shape instanceof MultiShape) {
            List<Shape> parts = ((MultiShape) shape).toMulti();
            ObjectRef joined = ObjectRef.create("");
            parts.foreach(new KerasUtils$$anonfun$strShape$1(joined));
            return (String) joined.elem;
        }
        throw new MatchError(shape);
    }

    /**
     * Optionally maps every label in {@code rdd} through
     * {@code KerasUtils$$anonfun$toZeroBasedLabel$1}, producing an {@code Int} RDD (presumably
     * shifting 1-based labels down by one).
     *
     * @param z whether to apply the conversion
     * @param rdd the label RDD
     * @return the converted RDD, or {@code rdd} itself when {@code z} is false
     */
    public RDD<Object> toZeroBasedLabel(boolean z, RDD<Object> rdd) {
        if (!z) {
            return rdd;
        }
        return rdd.map(new KerasUtils$$anonfun$toZeroBasedLabel$1(), ClassTag$.MODULE$.Int());
    }

    /**
     * Default for {@code toZeroBasedLabel}'s flag: the conversion is applied.
     *
     * @return {@code true}
     */
    public boolean toZeroBasedLabel$default$1() {
        final boolean convert = true;
        return convert;
    }

    /**
     * Validates that a batch size is positive and splits evenly across every core of every node.
     *
     * <p>The total core count is {@code EngineRef.getCoreNumber() * EngineRef.getNodeNumber()}.
     * Each failed check is reported through {@code Log4Error.invalidInputError}.
     *
     * @param i the batch size; must be positive and a multiple of the total core count
     */
    public void validateBatchSize(int i) {
        // Without this check, 0 (and any negative multiple of the core count) would pass the
        // divisibility test below, because 0 % n == 0.
        Log4Error$.MODULE$.invalidInputError(i > 0, "BatchSize: " + i + " should be positive", Log4Error$.MODULE$.invalidInputError$default$3());
        int coreNumber = EngineRef$.MODULE$.getCoreNumber() * EngineRef$.MODULE$.getNodeNumber();
        // Test coreNumber > 0 first, so a zero core count is reported through Log4Error
        // instead of surfacing as an ArithmeticException from the modulo.
        boolean divisible = coreNumber > 0 && i % coreNumber == 0;
        Log4Error$.MODULE$.invalidInputError(divisible, "BatchSize: " + i + " cannot be divided by " + coreNumber, Log4Error$.MODULE$.invalidInputError$default$3());
    }

    /**
     * Applies {@code KerasUtils$$anonfun$tril$1} to every row index of a 2-D tensor in place.
     * Following the NumPy/PyTorch {@code tril} convention, this presumably zeroes the entries
     * above the main diagonal.
     *
     * @param tensor the matrix to modify; must have exactly 2 dimensions
     * @param classTag class tag of the numeric type {@code T}
     * @param tensorNumeric numeric operations for {@code T}
     * @return {@code tensor} itself, after modification
     */
    public <T> Tensor<T> tril(Tensor<T> tensor, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Log4Error$.MODULE$.invalidInputError(tensor.dim() == 2, "tril expects a matrix!", Log4Error$.MODULE$.invalidInputError$default$3());
        int rowCount = tensor.size(1);
        int strideDim1 = tensor.stride(1);
        int strideDim2 = tensor.stride(2);
        KerasUtils$$anonfun$tril$1 perRow = new KerasUtils$$anonfun$tril$1(tensor, tensorNumeric, strideDim1, strideDim2);
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), rowCount).foreach$mVc$sp(perRow);
        return tensor;
    }

    /** Private singleton constructor; registers this instance as {@code KerasUtils$.MODULE$}. */
    private KerasUtils$() {
        KerasUtils$.MODULE$ = this;
    }
}
