package com.intel.analytics.bigdl.dllib.nn.mkldnn;

import com.intel.analytics.bigdl.dllib.nn.MklInt8Convertible;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import com.intel.analytics.bigdl.dllib.utils.Node;
import scala.Predef$;
import scala.StringContext;
import scala.collection.IterableLike;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.generic.GenericTraversableTemplate;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.StringOps;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.RichInt$;

/* compiled from: Fusion.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/nn/mkldnn/Fusion$.class */
/**
 * Graph-level fusion passes for the MKL-DNN backend (Java rendering of the Scala {@code object Fusion}).
 *
 * <p>Each pass receives a single graph {@link Node} and rewrites {@code element} fields in place,
 * typically folding a module into a preceding {@link SpatialConvolution} and replacing the folded
 * module with an {@code Identity}. All passes except {@link #fuseScale} are no-ops unless the system
 * property {@code bigdl.mkldnn.fusion} parses as {@code true} (its default).
 *
 * <p>Methods named {@code com$intel$...$Fusion$$xxx} carry Scala-mangled names and are presumably
 * referenced by the compiled {@code Fusion$$anonfun$...} closure classes — do not rename them.
 */
public final class Fusion$ {
    /** The singleton instance, mirroring the Scala {@code object Fusion}. */
    public static final Fusion$ MODULE$ = new Fusion$();

    /**
     * Reports whether fusion is enabled. Reads the {@code bigdl.mkldnn.fusion} system property on
     * every call (default {@code "true"}) and parses it with Scala's {@code StringOps.toBoolean},
     * which accepts "true"/"false" case-insensitively and rejects other strings with an exception.
     */
    private boolean fuse() {
        String flag = System.getProperty("bigdl.mkldnn.fusion", "true");
        return new StringOps(Predef$.MODULE$.augmentString(flag)).toBoolean();
    }

    /**
     * Fuses a {@link ReLU} or {@link SpatialBatchNormalization} node into the node(s) feeding it.
     * Any other element type, or a disabled fusion flag, leaves the graph untouched.
     *
     * @param node candidate node; the per-predecessor closures may rewrite it and its predecessors
     */
    public void fuseModule(Node<AbstractModule<Activity, Activity, Object>> node) {
        if (!fuse()) {
            return;
        }
        AbstractModule<Activity, Activity, Object> element = node.element();
        if (element instanceof ReLU) {
            fusionRelu(node);
        } else if (element instanceof SpatialBatchNormalization) {
            fusionBN(node);
        }
    }

    /**
     * Fuses a {@link CAddTable} node into a preceding convolution as a sum post-op, when fusion is
     * enabled. Nodes of any other type are ignored.
     *
     * @param node candidate node
     */
    public void fuseCAdd(Node<AbstractModule<Activity, Activity, Object>> node) {
        if (fuse() && node.element() instanceof CAddTable) {
            fusionCAddTable(node);
        }
    }

    /** Applies the batch-norm fusion closure to every predecessor of a batch-norm node. */
    private void fusionBN(Node<AbstractModule<Activity, Activity, Object>> node) {
        SpatialBatchNormalization bn = (SpatialBatchNormalization) node.element();
        node.prevNodes().foreach(new Fusion$$anonfun$fusionBN$1(node, bn));
    }

    /** Applies the ReLU fusion closure to every predecessor of a ReLU node. */
    private void fusionRelu(Node<AbstractModule<Activity, Activity, Object>> node) {
        node.prevNodes().foreach(new Fusion$$anonfun$fusionRelu$1(node));
    }

    /**
     * Walks backwards through a chain of {@code Identity} nodes that each have exactly one
     * predecessor, and returns the first node that breaks the chain.
     *
     * @param node starting node
     * @return {@code node} itself if it is not such an Identity, otherwise the first non-Identity (or
     *     multi-input Identity) ancestor
     */
    public Node<AbstractModule<Activity, Activity, Object>> com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$findPrevious(Node<AbstractModule<Activity, Activity, Object>> node) {
        Node<AbstractModule<Activity, Activity, Object>> current = node;
        while (current.element() instanceof Identity && current.prevNodes().length() == 1) {
            current = (Node) current.prevNodes().apply(0);
        }
        return current;
    }

    /**
     * Returns the successors of {@code node} with {@code Identity} nodes expanded. Expansion is
     * delegated to a closure; it presumably recurses through this method.
     *
     * @param node starting node
     * @return {@code Seq(node)} if it is not an Identity, otherwise the flat-mapped successors
     */
    public Seq<Node<AbstractModule<Activity, Activity, Object>>> com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$findNext(Node<AbstractModule<Activity, Activity, Object>> node) {
        if (!(node.element() instanceof Identity)) {
            return Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Node[]{node}));
        }
        return (Seq) node.nextNodes().flatMap(
                new Fusion$$anonfun$com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$findNext$1(),
                Seq$.MODULE$.canBuildFrom());
    }

    /**
     * Folds a two-input {@link CAddTable} into one of its producing {@link SpatialConvolution}s.
     *
     * <p>Steps:
     * <ol>
     *   <li>Each input is traced back through single-input Identity nodes.</li>
     *   <li>The first traced input is checked first, then the second. If it is a convolution
     *       without a sum op yet, its module is moved into this node. That module is told to add the
     *       other input ({@code setSumOp(other, otherIndex + 1)}), and the convolution's old node
     *       becomes an Identity.</li>
     *   <li>A directly following {@link ReLU} is also folded in, provided the convolution has no
     *       ReLU yet.</li>
     *   <li>The fused convolution's output scales are pushed to the module producing the other
     *       addend. When that producer is a ReLU, they are also pushed, via closures, to the
     *       producer's other consumers.</li>
     * </ol>
     */
    private void fusionCAddTable(Node<AbstractModule<Activity, Activity, Object>> node) {
        if (!(node.element() instanceof CAddTable) || node.prevNodes().length() != 2) {
            return;
        }
        Node<AbstractModule<Activity, Activity, Object>>[] inputs =
                (Node[]) node.prevNodes().toArray(ClassTag$.MODULE$.apply(Node.class));
        Node<AbstractModule<Activity, Activity, Object>> first =
                com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$findPrevious(inputs[0]);
        Node<AbstractModule<Activity, Activity, Object>> second =
                com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$findPrevious(inputs[1]);

        // Only the first convolution found is considered; if it already carries a sum op we do not
        // fall back to the other input.
        Node<AbstractModule<Activity, Activity, Object>> convNode = null;
        int otherIndex = 0;
        if (first.element() instanceof SpatialConvolution) {
            if (requirements(first)) {
                convNode = first;
            }
            otherIndex = 1;
        } else if (second.element() instanceof SpatialConvolution) {
            if (requirements(second)) {
                convNode = second;
            }
            otherIndex = 0;
        }
        if (convNode == null) {
            return;
        }

        node.element_$eq(convNode.element());
        SpatialConvolution fused = (SpatialConvolution) node.element();
        fused.setSumOp(inputs[otherIndex].element(), otherIndex + 1);
        convNode.element_$eq(Identity$.MODULE$.apply$mFc$sp(
                ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$));

        // Guard against a sink node: indexing an empty successor list would throw.
        if (!node.nextNodes().isEmpty()) {
            Node next = (Node) node.nextNodes().apply(0);
            if (next.element() instanceof ReLU && !fused.relu()) {
                fused.setReLU(true);
                fused.setOutputScales(((ReLU) next.element()).getOutputScales());
                next.element_$eq(new Identity());
            }
        }

        // Re-trace on purpose: the conv node was just replaced by an Identity, which may change
        // where the walk over the other input stops.
        Node<AbstractModule<Activity, Activity, Object>> producer =
                com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$findPrevious(inputs[otherIndex]);
        AbstractModule<Activity, Activity, Object> producerModule = producer.element();
        if (producerModule instanceof SpatialConvolution) {
            ((SpatialConvolution) producerModule).setOutputScales(fused.getOutputScales());
        } else if (producerModule instanceof ReLU) {
            ((ReLU) producerModule).setOutputScales(fused.getOutputScales());
            TraversableLike consumers = (TraversableLike) producer.nextNodes().flatMap(
                    new Fusion$$anonfun$fusionCAddTable$1(), Seq$.MODULE$.canBuildFrom());
            ((IterableLike) consumers.filter(new Fusion$$anonfun$fusionCAddTable$2(node)))
                    .foreach(new Fusion$$anonfun$fusionCAddTable$3(node));
        }
    }

    /** A convolution node may absorb a CAddTable only if it has no sum op attached yet. */
    private boolean requirements(Node<AbstractModule<Activity, Activity, Object>> node) {
        return !((SpatialConvolution) node.element()).sum();
    }

    /**
     * Folds batch-norm statistics into a convolution's weight and bias.
     *
     * <p>The method first marks the convolution as batch-norm fused. It then snapshots the following
     * into fresh float tensors:
     * <ul>
     *   <li>the BN running variance and running mean;</li>
     *   <li>the convolution weight and bias;</li>
     *   <li>the BN weightAndBias.</li>
     * </ul>
     * A per-output-channel closure runs over {@code [0, nOutput)}; its arithmetic is not visible
     * here and presumably updates the weight/bias snapshots. Finally the snapshots are installed as
     * the convolution's weight and bias, the weight scales are refreshed, and the BN output scales
     * are adopted.
     */
    public void com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$fusionConvBn(SpatialConvolution spatialConvolution, SpatialBatchNormalization spatialBatchNormalization) {
        spatialConvolution.setBatchNorm(true);

        Tensor varianceBuf = Tensor$.MODULE$.apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
        Tensor variance = varianceBuf.resize(spatialBatchNormalization.runningVariance().size(), varianceBuf.resize$default$2())
                .copy(spatialBatchNormalization.runningVariance().dense());

        Tensor meanBuf = Tensor$.MODULE$.apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
        Tensor mean = meanBuf.resize(spatialBatchNormalization.runningMean().size(), meanBuf.resize$default$2())
                .copy(spatialBatchNormalization.runningMean().dense());

        Tensor weightBuf = Tensor$.MODULE$.apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
        Tensor<Object> convWeight = weightBuf.resize(spatialConvolution.weight().size(), weightBuf.resize$default$2())
                .copy(spatialConvolution.weight().dense());

        Tensor biasBuf = Tensor$.MODULE$.apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
        Tensor<Object> convBias = biasBuf.resize(spatialConvolution.bias().size(), biasBuf.resize$default$2())
                .copy(spatialConvolution.bias().dense());

        Tensor bnWeightAndBias = Tensor$.MODULE$.apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)
                .resizeAs(spatialBatchNormalization.weightAndBias().dense())
                .copy(spatialBatchNormalization.weightAndBias().dense());

        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), spatialBatchNormalization.nOutput())
                .foreach$mVc$sp(new Fusion$$anonfun$com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$fusionConvBn$1(
                        spatialConvolution, spatialBatchNormalization, variance, mean, convWeight, convBias, bnWeightAndBias));

        spatialConvolution.weight().dense().set(convWeight);
        spatialConvolution.bias().dense().set(convBias);
        spatialConvolution.flushWeightScales(spatialConvolution.weight().dense());
        spatialConvolution.setOutputScales(spatialBatchNormalization.getOutputScales());
    }

    /**
     * Clears {@code negativeInput} on a convolution node when fusion is enabled and every
     * (closure-expanded) predecessor satisfies the closure predicate. The predicate is not visible
     * here; presumably it checks that the producer cannot emit negative values — TODO confirm.
     */
    public void setNegativeInputOfConv(Node<AbstractModule<Activity, Activity, Object>> node) {
        if (!fuse() || !(node.element() instanceof SpatialConvolution)) {
            return;
        }
        IterableLike mapped = (IterableLike) ((TraversableLike) node.prevNodes()
                .flatMap(new Fusion$$anonfun$1(), Seq$.MODULE$.canBuildFrom()))
                .map(new Fusion$$anonfun$2(), Seq$.MODULE$.canBuildFrom());
        if (mapped.forall(new Fusion$$anonfun$3())) {
            ((SpatialConvolution) node.element()).negativeInput_$eq(false);
        }
    }

    /**
     * Harmonizes the output scales of the convolutions feeding a {@link JoinTable} node.
     *
     * <p>Runs only when fusion is enabled and at least one collected predecessor matches the
     * {@code exists} closure. In that case the method:
     * <ol>
     *   <li>Requires all collected convolutions to share the same mask, and reports an invalid input
     *       otherwise.</li>
     *   <li>Picks the scales to apply. If no qualifying successor exists, it uses a single row built
     *       by transposing and reducing the predecessors' scales. Otherwise it uses the input scales
     *       of the first qualifying successor.</li>
     *   <li>Passes those scales to the per-predecessor closure.</li>
     * </ol>
     */
    public void setScalesPrevousJoinTable(Node<AbstractModule<Activity, Activity, Object>> node) {
        if (!fuse() || !(node.element() instanceof JoinTable)) {
            return;
        }
        Seq convs = (Seq) ((TraversableLike) ((TraversableLike) node.prevNodes()
                .flatMap(new Fusion$$anonfun$4(), Seq$.MODULE$.canBuildFrom()))
                .filter(new Fusion$$anonfun$5()))
                .map(new Fusion$$anonfun$6(), Seq$.MODULE$.canBuildFrom());
        if (!convs.exists(new Fusion$$anonfun$setScalesPrevousJoinTable$1())) {
            return;
        }
        boolean sameMask = ((TraversableOnce) convs.map(new Fusion$$anonfun$7(), Seq$.MODULE$.canBuildFrom()))
                .toSet().size() == 1;
        Log4Error$.MODULE$.invalidInputError(sameMask,
                "all preceding convolutions must have the same mask",
                Log4Error$.MODULE$.invalidInputError$default$3());

        Seq consumers = (Seq) ((TraversableLike) node.nextNodes()
                .flatMap(new Fusion$$anonfun$8(), Seq$.MODULE$.canBuildFrom()))
                .filter(new Fusion$$anonfun$9());

        float[][] scales;
        if (consumers.isEmpty()) {
            float[] reduced = (float[]) ((TraversableOnce) ((GenericTraversableTemplate) convs
                    .map(new Fusion$$anonfun$10(), Seq$.MODULE$.canBuildFrom()))
                    .transpose(new Fusion$$anonfun$11())
                    .map(new Fusion$$anonfun$12(), Seq$.MODULE$.canBuildFrom()))
                    .toArray(ClassTag$.MODULE$.Float());
            // A single row of per-channel scales (the decompiled form built a float[] here by mistake).
            scales = new float[][]{reduced};
        } else {
            scales = ((MklInt8Convertible) ((IterableLike) consumers
                    .map(new Fusion$$anonfun$13(), Seq$.MODULE$.canBuildFrom())).head())
                    .getInputScales();
        }
        convs.foreach(new Fusion$$anonfun$setScalesPrevousJoinTable$2(scales));
    }

    /**
     * Folds a {@code BlasWrapper}-wrapped {@code nn.Scale} into its predecessors.
     *
     * <p>This runs only when every predecessor satisfies the closure predicate. The Scale node then
     * becomes an Identity. Unlike the other passes, this one does not consult the fusion flag.
     */
    public void fuseScale(Node<AbstractModule<Activity, Activity, Object>> node) {
        AbstractModule<Activity, Activity, Object> element = node.element();
        boolean isBlasScale = element instanceof BlasWrapper
                && ((BlasWrapper) element).module() instanceof com.intel.analytics.bigdl.dllib.nn.Scale;
        if (isBlasScale && node.prevNodes().forall(new Fusion$$anonfun$14())) {
            node.prevNodes().foreach(new Fusion$$anonfun$fuseScale$1(node));
            node.element_$eq(Identity$.MODULE$.apply$mFc$sp(
                    ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$));
        }
    }

    /**
     * Returns the producers of {@code node} with pass-through nodes expanded. The node types
     * treated as pass-through are Identity, MaxPooling, AvgPooling and JoinTable; their expansion is
     * delegated to a closure that presumably recurses through this method.
     *
     * @param node starting node
     * @return {@code Seq(node)} for any other node type
     */
    public Seq<Node<AbstractModule<Activity, Activity, Object>>> com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$findAllNonIdentityPrevs(Node<AbstractModule<Activity, Activity, Object>> node) {
        AbstractModule<Activity, Activity, Object> element = node.element();
        boolean passThrough = element instanceof Identity
                || element instanceof MaxPooling
                || element instanceof AvgPooling
                || element instanceof JoinTable;
        if (!passThrough) {
            return Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Node[]{node}));
        }
        return (Seq) node.prevNodes().flatMap(
                new Fusion$$anonfun$com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$findAllNonIdentityPrevs$1(),
                Seq$.MODULE$.canBuildFrom());
    }

    private Fusion$() {
    }
}
