package com.intel.analytics.bigdl.dllib.utils.tf;

import com.intel.analytics.bigdl.dllib.nn.SpatialBatchNormalization;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath$TensorNumeric$NumericInt$;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import java.nio.ByteOrder;
import org.tensorflow.framework.NodeDef;
import scala.Predef$;
import scala.StringContext;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;

/* compiled from: BigDLToTensorflow.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/utils/tf/BatchNorm2DToTF$.class */
/**
 * Converts a BigDL {@link SpatialBatchNormalization} layer into a chain of elementary TensorFlow
 * nodes (constants, reshapes, rsqrt, multiply, subtract, add).
 *
 * <p>This is the Java view of the Scala singleton {@code object BatchNorm2DToTF}; the single
 * instance is exposed as {@link #MODULE$}.
 */
public final class BatchNorm2DToTF$ implements BigDLToTensorflow {
    /**
     * The singleton instance. It is initialised here rather than inside the constructor, because a
     * {@code static final} field cannot be reassigned from a constructor.
     */
    public static final BatchNorm2DToTF$ MODULE$ = new BatchNorm2DToTF$();

    private BatchNorm2DToTF$() {
    }

    /**
     * Emits the TensorFlow nodes that reproduce this batch-norm layer's evaluation-mode output.
     *
     * <p>The running statistics (and, when the layer has weights, its scale and offset) are
     * exported as constants. Each is reshaped with a shape tensor of length {@code nDim()}; element
     * 2 (1-based) of that shape is set to {@code runningVar().size(1)}. The other elements are
     * initialised by {@code BatchNorm2DToTF$$anonfun$toTFDef$2}, which is not visible here
     * (presumably to 1, so the statistics broadcast against the input — TODO confirm).
     *
     * <p>NOTE(review): the variance is passed straight to {@code rsqrt} with no epsilon added.
     * Confirm this matches BigDL's own inference formula.
     *
     * @param abstractModule the layer to convert; cast to {@link SpatialBatchNormalization}
     * @param seq the input nodes; exactly one is accepted
     * @param byteOrder byte order used when serialising constant tensors
     * @return the generated nodes, with the {@code <name>/output} node first
     */
    @Override // com.intel.analytics.bigdl.dllib.utils.tf.BigDLToTensorflow
    public Seq<NodeDef> toTFDef(AbstractModule<?, ?, ?> abstractModule, Seq<NodeDef> seq, ByteOrder byteOrder) {
        Log4Error$.MODULE$.invalidInputError(seq.length() == 1, "BatchNorm only accept one input", Log4Error$.MODULE$.invalidInputError$default$3());
        SpatialBatchNormalization bn = (SpatialBatchNormalization) abstractModule;
        Log4Error$.MODULE$.invalidInputError(!bn.isTraining(), "Only support evaluate mode batch norm", Log4Error$.MODULE$.invalidInputError$default$3());

        // Shape used to reshape every per-channel constant: nDim() entries, entry 2 (1-based)
        // holding the length of runningVar along its first dimension.
        Tensor<?> shape = Tensor$.MODULE$.apply(bn.nDim(), ClassTag$.MODULE$.Int(), TensorNumericMath$TensorNumeric$NumericInt$.MODULE$);
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), bn.nDim()).foreach(new BatchNorm2DToTF$$anonfun$toTFDef$2(shape));
        shape.update(2, (int) BoxesRunTime.boxToInteger(bn.runningVar().size(1)));

        NodeDef input = (NodeDef) seq.apply(0);
        if (bn.weight() == null) {
            return buildWithoutAffine(bn, input, shape, byteOrder);
        }
        return buildWithAffine(bn, input, shape, byteOrder);
    }

    /**
     * Builds {@code x * rsqrt(var) - mean * rsqrt(var)} for a layer without scale/offset.
     * The returned order is: output, mul2, mul1, reshape_2, reshape_2/shape, mean, sqrtvar,
     * reshape_1, reshape_1/shape, var.
     */
    private static Seq<NodeDef> buildWithoutAffine(SpatialBatchNormalization bn, NodeDef input, Tensor<?> shape, ByteOrder byteOrder) {
        NodeDef varShape = constant(shape, bn, "/reshape_1/shape", byteOrder);
        NodeDef meanShape = constant(shape, bn, "/reshape_2/shape", byteOrder);
        NodeDef varConst = constant(bn.runningVar(), bn, "/var", byteOrder);
        NodeDef meanConst = constant(bn.runningMean(), bn, "/mean", byteOrder);

        NodeDef varReshaped = Tensorflow$.MODULE$.reshape(varConst, varShape, nodeName(bn, "/reshape_1"));
        NodeDef meanReshaped = Tensorflow$.MODULE$.reshape(meanConst, meanShape, nodeName(bn, "/reshape_2"));
        NodeDef invStd = Tensorflow$.MODULE$.rsqrt(varReshaped, nodeName(bn, "/sqrtvar"));

        NodeDef scaledInput = Tensorflow$.MODULE$.multiply(input, invStd, nodeName(bn, "/mul1"));
        NodeDef scaledMean = Tensorflow$.MODULE$.multiply(meanReshaped, invStd, nodeName(bn, "/mul2"));
        NodeDef output = Tensorflow$.MODULE$.subtract(scaledInput, scaledMean, nodeName(bn, "/output"));

        return Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new NodeDef[]{
            output, scaledMean, scaledInput, meanReshaped, meanShape, meanConst,
            invStd, varReshaped, varShape, varConst}));
    }

    /**
     * Builds {@code x * (scale * rsqrt(var)) + (offset - mean * (scale * rsqrt(var)))} for a
     * layer with scale ({@code weight()}) and offset ({@code bias()}).
     * The returned order is: output, sub, mul2, mul1, mul0, reshape_4, reshape_2, reshape_3,
     * reshape_4/shape, reshape_2/shape, reshape_3/shape, offset, scale, mean, sqrtvar,
     * reshape_1, reshape_1/shape, var.
     */
    private static Seq<NodeDef> buildWithAffine(SpatialBatchNormalization bn, NodeDef input, Tensor<?> shape, ByteOrder byteOrder) {
        NodeDef varShape = constant(shape, bn, "/reshape_1/shape", byteOrder);
        NodeDef meanShape = constant(shape, bn, "/reshape_2/shape", byteOrder);
        NodeDef scaleShape = constant(shape, bn, "/reshape_3/shape", byteOrder);
        NodeDef offsetShape = constant(shape, bn, "/reshape_4/shape", byteOrder);
        NodeDef varConst = constant(bn.runningVar(), bn, "/var", byteOrder);
        NodeDef meanConst = constant(bn.runningMean(), bn, "/mean", byteOrder);
        NodeDef scaleConst = constant(bn.weight(), bn, "/scale", byteOrder);
        NodeDef offsetConst = constant(bn.bias(), bn, "/offset", byteOrder);

        NodeDef varReshaped = Tensorflow$.MODULE$.reshape(varConst, varShape, nodeName(bn, "/reshape_1"));
        NodeDef meanReshaped = Tensorflow$.MODULE$.reshape(meanConst, meanShape, nodeName(bn, "/reshape_2"));
        NodeDef scaleReshaped = Tensorflow$.MODULE$.reshape(scaleConst, scaleShape, nodeName(bn, "/reshape_3"));
        NodeDef offsetReshaped = Tensorflow$.MODULE$.reshape(offsetConst, offsetShape, nodeName(bn, "/reshape_4"));
        NodeDef invStd = Tensorflow$.MODULE$.rsqrt(varReshaped, nodeName(bn, "/sqrtvar"));

        // Fold scale into the inverse std-dev once, then reuse it for both input and mean.
        NodeDef factor = Tensorflow$.MODULE$.multiply(scaleReshaped, invStd, nodeName(bn, "/mul0"));
        NodeDef scaledInput = Tensorflow$.MODULE$.multiply(input, factor, nodeName(bn, "/mul1"));
        NodeDef scaledMean = Tensorflow$.MODULE$.multiply(meanReshaped, factor, nodeName(bn, "/mul2"));
        NodeDef shift = Tensorflow$.MODULE$.subtract(offsetReshaped, scaledMean, nodeName(bn, "/sub"));
        NodeDef output = Tensorflow$.MODULE$.add(scaledInput, shift, nodeName(bn, "/output"));

        return Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new NodeDef[]{
            output, shift, scaledMean, scaledInput, factor,
            offsetReshaped, meanReshaped, scaleReshaped,
            offsetShape, meanShape, scaleShape,
            offsetConst, scaleConst, meanConst,
            invStd, varReshaped, varShape, varConst}));
    }

    /** Exports {@code tensor} as a TensorFlow constant named {@code <layer name><suffix>}. */
    private static NodeDef constant(Tensor<?> tensor, SpatialBatchNormalization bn, String suffix, ByteOrder byteOrder) {
        return Tensorflow$.MODULE$.m2302const(tensor, nodeName(bn, suffix), byteOrder);
    }

    /** Returns the layer's name with {@code suffix} appended, e.g. {@code "bn1" + "/var"}. */
    private static String nodeName(SpatialBatchNormalization bn, String suffix) {
        return bn.getName() + suffix;
    }
}
