package com.intel.analytics.bigdl.dllib.utils.tf;

import com.intel.analytics.bigdl.dllib.nn.Contiguous$;
import com.intel.analytics.bigdl.dllib.nn.SelectTable$;
import com.intel.analytics.bigdl.dllib.nn.Sequential;
import com.intel.analytics.bigdl.dllib.nn.Sequential$;
import com.intel.analytics.bigdl.dllib.nn.SpatialBatchNormalization;
import com.intel.analytics.bigdl.dllib.nn.SpatialBatchNormalization$;
import com.intel.analytics.bigdl.dllib.nn.Transpose$;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.DirectedGraph;
import com.intel.analytics.bigdl.dllib.utils.Node;
import com.intel.analytics.bigdl.dllib.utils.Node$;
import com.intel.analytics.bigdl.dllib.utils.tf.TensorflowToBigDL;
import java.nio.ByteOrder;
import java.util.Map;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.NodeDef;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Tuple2;
import scala.collection.Seq;
import scala.reflect.ClassTag;

/* compiled from: TensorflowToBigDL.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/utils/tf/BatchNormV2NHWCTF$.class */
/**
 * Singleton converter (a Scala {@code object} compiled to Java) that recognizes the TensorFlow
 * op pattern returned by {@link #topology()} and replaces it with a BigDL {@link Sequential}:
 * {@code SelectTable(1) -> Transpose(2,4) -> Contiguous -> SpatialBatchNormalization ->
 * Transpose(2,4) -> Contiguous}.
 *
 * <p>The pattern has the shape of a shifted-moments computation followed by the
 * batch-normalization affine transform (x is the {@code "*"} node, presumably a wildcard
 * matching any producer):
 * <pre>
 *   shift    = StopGradient(Mean(x, Const))
 *   mean     = Squeeze(Mean(x - shift, Const) + shift)
 *   variance = Squeeze(Mean(SquaredDifference(x, shift), Const) - Square(Mean(x - shift, Const)))
 *   factor   = Rsqrt(variance + Const) * Identity(Const)      // presumably gamma / scale
 *   y        = x * factor + (Identity(Const) - mean * factor)  // presumably beta / offset
 * </pre>
 *
 * <p>NOTE(review): this is decompiler output. Constructs such as assigning the
 * {@code static final MODULE$} from the constructor, {@code Tuple2.mcII.sp} and
 * {@code mo1321add} are decompiler renderings and may not compile as plain Java. Prefer
 * editing the Scala source.
 */
public final class BatchNormV2NHWCTF$ implements TensorflowToBigDL {
    /** The singleton instance; set by the private constructor, which the static initializer runs. */
    public static final BatchNormV2NHWCTF$ MODULE$ = null;
    /** Op-type pattern graph built in the constructor and exposed through {@link #topology()}. */
    private final DirectedGraph<String> graph;

    // Static initializer: constructing the instance assigns MODULE$ (see the constructor).
    static {
        new BatchNormV2NHWCTF$();
    }

    /**
     * Delegates to the trait implementation {@code TensorflowToBigDL.Cclass.getOrSetTensor}.
     * {@link #layer} uses the returned pair's two elements as separate arguments to
     * {@code SpatialBatchNormalization.apply}.
     */
    @Override // com.intel.analytics.bigdl.dllib.utils.tf.TensorflowToBigDL
    public <T> Tuple2<Tensor<T>, Tensor<T>> getOrSetTensor(NodeDef nodeDef, Context<T> context, ByteOrder byteOrder, Option<Seq<Tuple2<Object, Object>>> option, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        return TensorflowToBigDL.Cclass.getOrSetTensor(this, nodeDef, context, byteOrder, option, classTag, tensorNumeric);
    }

    /** Delegates to {@code TensorflowToBigDL.Cclass.getString} (presumably reads attribute {@code str} from {@code map}). */
    @Override // com.intel.analytics.bigdl.dllib.utils.tf.TensorflowToBigDL
    public String getString(Map<String, AttrValue> map, String str) {
        return TensorflowToBigDL.Cclass.getString(this, map, str);
    }

    /** Delegates to {@code TensorflowToBigDL.Cclass.getInt} (presumably reads attribute {@code str} from {@code map}). */
    @Override // com.intel.analytics.bigdl.dllib.utils.tf.TensorflowToBigDL
    public int getInt(Map<String, AttrValue> map, String str) {
        return TensorflowToBigDL.Cclass.getInt(this, map, str);
    }

    /** Delegates to {@code TensorflowToBigDL.Cclass.getIntList} (presumably reads attribute {@code str} from {@code map}). */
    @Override // com.intel.analytics.bigdl.dllib.utils.tf.TensorflowToBigDL
    public Seq<Object> getIntList(Map<String, AttrValue> map, String str) {
        return TensorflowToBigDL.Cclass.getIntList(this, map, str);
    }

    /** Delegates to {@code TensorflowToBigDL.Cclass.getBoolean} (presumably reads attribute {@code str} from {@code map}). */
    @Override // com.intel.analytics.bigdl.dllib.utils.tf.TensorflowToBigDL
    public boolean getBoolean(Map<String, AttrValue> map, String str) {
        return TensorflowToBigDL.Cclass.getBoolean(this, map, str);
    }

    /** Delegates to {@code TensorflowToBigDL.Cclass.getType} (presumably reads attribute {@code str} from {@code map}). */
    @Override // com.intel.analytics.bigdl.dllib.utils.tf.TensorflowToBigDL
    public DataType getType(Map<String, AttrValue> map, String str) {
        return TensorflowToBigDL.Cclass.getType(this, map, str);
    }

    /**
     * Scala default value for the fourth ({@code option}) parameter of {@link #getOrSetTensor}.
     *
     * @return always {@code None}
     */
    @Override // com.intel.analytics.bigdl.dllib.utils.tf.TensorflowToBigDL
    public <T> Option<Seq<Tuple2<Object, Object>>> getOrSetTensor$default$4() {
        Option<Seq<Tuple2<Object, Object>>> option;
        option = None$.MODULE$;
        return option;
    }

    /** Private accessor for the pattern graph built in the constructor. */
    private DirectedGraph<String> graph() {
        return this.graph;
    }

    /**
     * Returns the op-type pattern this converter matches.
     *
     * @return the graph built in the constructor from the final {@code Add} node
     */
    @Override // com.intel.analytics.bigdl.dllib.utils.tf.TensorflowToBigDL
    public DirectedGraph<String> topology() {
        return graph();
    }

    /**
     * Builds the BigDL replacement for a matched subgraph.
     *
     * <p>The scale and offset tensors are found by walking {@code prevNodes()} by index from
     * {@code directedGraph.source()}. This assumes the matched subgraph mirrors
     * {@link #topology()}, with {@code source()} being the final {@code Add}, and that
     * {@code prevNodes()} order follows the edge-insertion order in the constructor.
     * TODO confirm both assumptions against Node/DirectedGraph.
     *
     * @param directedGraph matched TensorFlow subgraph
     * @param context       passed through to {@link #getOrSetTensor}
     * @param byteOrder     passed through to {@link #getOrSetTensor}
     * @return a {@link Sequential} of SelectTable(1), Transpose(2,4), Contiguous,
     *         SpatialBatchNormalization, Transpose(2,4), Contiguous
     * @throws MatchError if {@link #getOrSetTensor} returns {@code null}
     */
    @Override // com.intel.analytics.bigdl.dllib.utils.tf.TensorflowToBigDL
    public <T> AbstractModule<Activity, Activity, T> layer(DirectedGraph<NodeDef> directedGraph, Context<T> context, ByteOrder byteOrder, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        // Offset Const (presumably beta): final Add -> prevNodes(1) = Sub -> head = Identity -> head = Const.
        NodeDef nodeDef = (NodeDef) ((Node) ((Node) ((Node) directedGraph.source().prevNodes().apply(1)).prevNodes().head()).prevNodes().head()).element();
        // Scale Const (presumably gamma): final Add -> Sub -> prevNodes(1) = Mul(mean * factor)
        // -> prevNodes(1) = Mul(Rsqrt * Identity) -> prevNodes(1) = Identity -> head = Const.
        Tuple2<Tensor<T>, Tensor<T>> orSetTensor = getOrSetTensor((NodeDef) ((Node) ((Node) ((Node) ((Node) ((Node) directedGraph.source().prevNodes().apply(1)).prevNodes().apply(1)).prevNodes().apply(1)).prevNodes().apply(1)).prevNodes().head()).element(), context, byteOrder, getOrSetTensor$default$4(), classTag, tensorNumeric);
        // Compiled Scala tuple destructuring: a null result raises MatchError.
        if (orSetTensor == null) {
            throw new MatchError(orSetTensor);
        }
        // Decompiler artifact: the pair is re-wrapped before being unpacked.
        Tuple2 tuple2 = new Tuple2((Tensor) orSetTensor._1(), (Tensor) orSetTensor._2());
        Tensor<T> tensor = (Tensor) tuple2._1();
        Tensor<T> tensor2 = (Tensor) tuple2._2();
        Tuple2<Tensor<T>, Tensor<T>> orSetTensor2 = getOrSetTensor(nodeDef, context, byteOrder, getOrSetTensor$default$4(), classTag, tensorNumeric);
        if (orSetTensor2 == null) {
            throw new MatchError(orSetTensor2);
        }
        Tuple2 tuple22 = new Tuple2((Tensor) orSetTensor2._1(), (Tensor) orSetTensor2._2());
        // The first argument is the scale tensor's size(1). Positional args 5-8 are
        // (scale._1, offset._1, scale._2, offset._2), presumably initWeight, initBias,
        // initGradWeight, initGradBias. Confirm against SpatialBatchNormalization.apply.
        // NOTE(review): eps, momentum and the remaining options use SpatialBatchNormalization's
        // defaults. The epsilon Const feeding the Add before Rsqrt in the pattern is never read
        // here, so confirm the default matches the imported model.
        SpatialBatchNormalization<T> apply = SpatialBatchNormalization$.MODULE$.apply(tensor.size(1), SpatialBatchNormalization$.MODULE$.apply$default$2(), SpatialBatchNormalization$.MODULE$.apply$default$3(), SpatialBatchNormalization$.MODULE$.apply$default$4(), tensor, (Tensor) tuple22._1(), tensor2, (Tensor) tuple22._2(), SpatialBatchNormalization$.MODULE$.apply$default$9(), classTag, tensorNumeric);
        Sequential<T> apply2 = Sequential$.MODULE$.apply(classTag, tensorNumeric);
        // Keep only element 1 of the incoming table (presumably the data input x).
        apply2.mo1321add(SelectTable$.MODULE$.apply(1, classTag, tensorNumeric));
        // Swap dims 2 and 4. For NHWC input, as the class name suggests, this moves channels to
        // dim 2 before batch normalization. TODO confirm the layout SpatialBatchNormalization expects.
        apply2.mo1321add(Transpose$.MODULE$.apply(new Tuple2[]{new Tuple2.mcII.sp(2, 4)}, classTag, tensorNumeric));
        apply2.mo1321add(Contiguous$.MODULE$.apply(classTag, tensorNumeric));
        apply2.mo1321add(apply);
        // Swap dims 2 and 4 back to restore the original layout.
        apply2.mo1321add(Transpose$.MODULE$.apply(new Tuple2[]{new Tuple2.mcII.sp(2, 4)}, classTag, tensorNumeric));
        apply2.mo1321add(Contiguous$.MODULE$.apply(classTag, tensorNumeric));
        return apply2;
    }

    /**
     * Initializes the singleton and builds the op-type pattern returned by {@link #topology()}.
     *
     * <p>{@code $minus$greater} is Scala's {@code ->}. It adds an edge from the receiver to the
     * argument and presumably returns the argument, so each chain reads left to right.
     *
     * <p>Do NOT reorder the edge statements. {@link #layer} indexes into {@code prevNodes()},
     * whose order presumably follows edge insertion.
     *
     * <p>Node legend (operand order as wired below):
     * <pre>
     *   apply   "*"                x (input)
     *   apply2  Mean               Mean(x, Const)
     *   apply3  StopGradient       shift = StopGradient(apply2)
     *   apply4  Sub                x - shift
     *   apply5  SquaredDifference  SquaredDifference(x, shift)
     *   apply6  Mean               Mean(apply4, Const)
     *   apply7  Add                apply6 + shift
     *   apply8  Mean               Mean(apply5, Const)
     *   apply9  Sub                apply8 - Square(apply6)            (variance)
     *   apply10 Add                apply16 + Const                    (presumably + epsilon)
     *   apply11 Mul                Rsqrt(apply10) * Identity(Const)   (factor)
     *   apply12 Mul                x * apply11
     *   apply13 Mul                apply17 * apply11
     *   apply14 Sub                Identity(Const) - apply13
     *   apply15 Add                apply12 + apply14                  (output)
     *   apply16 Squeeze            Squeeze(apply9)
     *   apply17 Squeeze            Squeeze(apply7)                    (mean)
     * </pre>
     */
    private BatchNormV2NHWCTF$() {
        MODULE$ = this;
        TensorflowToBigDL.Cclass.$init$(this);
        Node apply = Node$.MODULE$.apply("*");
        Node apply2 = Node$.MODULE$.apply("Mean");
        Node apply3 = Node$.MODULE$.apply("StopGradient");
        Node apply4 = Node$.MODULE$.apply("Sub");
        Node apply5 = Node$.MODULE$.apply("SquaredDifference");
        Node apply6 = Node$.MODULE$.apply("Mean");
        Node apply7 = Node$.MODULE$.apply("Add");
        Node apply8 = Node$.MODULE$.apply("Mean");
        Node apply9 = Node$.MODULE$.apply("Sub");
        Node apply10 = Node$.MODULE$.apply("Add");
        Node apply11 = Node$.MODULE$.apply("Mul");
        Node apply12 = Node$.MODULE$.apply("Mul");
        Node apply13 = Node$.MODULE$.apply("Mul");
        Node apply14 = Node$.MODULE$.apply("Sub");
        Node apply15 = Node$.MODULE$.apply("Add");
        Node apply16 = Node$.MODULE$.apply("Squeeze");
        Node apply17 = Node$.MODULE$.apply("Squeeze");
        // x -> Mul(x * factor) -> final Add (first operand).
        apply.$minus$greater(apply12).$minus$greater(apply15);
        // Const -> Identity (presumably offset/beta) -> Sub (first operand); layer() reads this Const.
        Node$.MODULE$.apply("Const").$minus$greater(Node$.MODULE$.apply("Identity")).$minus$greater(apply14);
        // x -> Sub(x - shift) -> Mean -> Add(+ shift) -> Squeeze (mean) -> Mul(mean * factor).
        apply.$minus$greater(apply4).$minus$greater(apply6).$minus$greater(apply7).$minus$greater(apply17).$minus$greater(apply13);
        // x -> Mean -> StopGradient (shift) -> Add (second operand).
        apply.$minus$greater(apply2).$minus$greater(apply3).$minus$greater(apply7);
        // Second input of the shift Mean (presumably reduction axes).
        Node$.MODULE$.apply("Const").$minus$greater(apply2);
        // shift -> Sub(x - shift), second operand.
        apply3.$minus$greater(apply4);
        // x -> SquaredDifference -> Mean -> Sub (variance, first operand).
        apply.$minus$greater(apply5).$minus$greater(apply8).$minus$greater(apply9);
        // Second input of the squared-difference Mean (presumably reduction axes).
        Node$.MODULE$.apply("Const").$minus$greater(apply8);
        // shift -> SquaredDifference, second operand.
        apply3.$minus$greater(apply5);
        // Const -> Mean(x - shift) as its second input; then Mean -> Square -> Sub (variance)
        // -> Squeeze -> Add(variance + Const).
        Node$.MODULE$.apply("Const").$minus$greater(apply6).$minus$greater(Node$.MODULE$.apply("Square")).$minus$greater(apply9).$minus$greater(apply16).$minus$greater(apply10);
        // Const (presumably epsilon) -> Add -> Rsqrt -> Mul (factor, first operand) -> Mul(x * factor).
        Node$.MODULE$.apply("Const").$minus$greater(apply10).$minus$greater(Node$.MODULE$.apply("Rsqrt")).$minus$greater(apply11).$minus$greater(apply12);
        // Const -> Identity (presumably scale/gamma) -> Mul (factor) -> Mul(mean * factor)
        // -> Sub(offset - mean * factor) -> final Add (second operand); layer() reads this Const.
        Node$.MODULE$.apply("Const").$minus$greater(Node$.MODULE$.apply("Identity")).$minus$greater(apply11).$minus$greater(apply13).$minus$greater(apply14).$minus$greater(apply15);
        // Build the pattern from the final Add. The 'true' flag presumably selects the reverse
        // (prevNodes) direction so that source() is this output node. TODO confirm.
        this.graph = apply15.graph(true);
    }
}
