package com.intel.analytics.bigdl.dllib.nn.mkldnn;

import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.runtime.AbstractFunction1;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: Fusion.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/nn/mkldnn/Fusion$$anonfun$com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$fusionConvBn$1.class */
/**
 * Compiler-generated closure body for the per-output-channel loop of
 * {@code Fusion.fusionConvBn}. For a given channel index it rescales the
 * matching slice of the convolution weight and rewrites the matching
 * convolution bias entry using the batch-normalization statistics and
 * parameters captured at construction time.
 *
 * <p>For channel {@code i}, with {@code base = sqrt(var[i] + eps)},
 * {@code alpha = bnWeight[i]} and {@code beta = bnWeight[nOutput + i]}:
 * <pre>
 *   weight_i := weight_i / base * alpha
 *   bias[i]  := alpha / base * bias[i] + beta - alpha * mean[i] / base
 * </pre>
 *
 * <p>All captured tensors are mutated in place through their backing storage;
 * the class itself holds no other state.
 */
public final class Fusion$$anonfun$com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$fusionConvBn$1 extends AbstractFunction1.mcVI.sp implements Serializable {
    public static final long serialVersionUID = 0;
    // Captured variables of the enclosing Scala method (names are compiler-generated).
    private final SpatialConvolution conv$1;
    private final SpatialBatchNormalization bn$2;
    private final Tensor originVar$1;
    private final Tensor originMean$1;
    private final Tensor convWeight$1;
    private final Tensor convBias$1;
    private final Tensor bnWeight$1;

    /**
     * Unboxed entry point; delegates to the specialized implementation.
     *
     * @param i zero-based output channel index
     */
    public final void apply(int i) {
        apply$mcVI$sp(i);
    }

    /**
     * Folds the batch-normalization parameters of channel {@code i} into the
     * captured convolution weight and bias tensors.
     *
     * @param i zero-based output channel index
     */
    public void apply$mcVI$sp(int i) {
        // storageOffset() is 1-based, hence the "- 1" on every raw storage access below.
        float[] varData = (float[]) this.originVar$1.storage().array();
        float base = (float) Math.sqrt(varData[(i + this.originVar$1.storageOffset()) - 1] + this.bn$2.eps());
        Log4Error$.MODULE$.invalidInputError(
                ((double) base) != 0.0d,
                "the eps of " + this.bn$2.getName() + " should be more than 0",
                Log4Error$.MODULE$.invalidInputError$default$3());

        // bnWeight packs two consecutive runs of nOutput values; the first run is
        // used as the multiplicative term (alpha), the second as the additive term (beta).
        float[] bnData = (float[]) this.bnWeight$1.storage().array();
        int bnStart = this.bnWeight$1.storageOffset() - 1;
        float alpha = bnData[bnStart + i];
        float beta = bnData[bnStart + this.bn$2.nOutput() + i];

        // Pick the weight slice that belongs to this output channel.
        Tensor channelWeight;
        if (this.conv$1.nGroup() == 1) {
            channelWeight = this.convWeight$1.select(1, i + 1);
        } else {
            // Grouped convolution: first dimension selects the group, second the
            // channel within that group.
            int channelsPerGroup = this.conv$1.nOutputPlane() / this.conv$1.nGroup();
            channelWeight = this.convWeight$1
                    .select(1, (i / channelsPerGroup) + 1)
                    .select(2, (i % channelsPerGroup) + 1);
        }

        // Scalar division / multiplication. The decompiler had emitted a bogus
        // (Tensor) cast on the boxed float here, which can never succeed.
        channelWeight.div(BoxesRunTime.boxToFloat(base));
        channelWeight.mul(BoxesRunTime.boxToFloat(alpha));

        // Honour the storage offset for bias and mean as well, consistent with the
        // variance and bnWeight accesses above (no change when the offset is 1).
        float[] biasData = (float[]) this.convBias$1.storage().array();
        float[] meanData = (float[]) this.originMean$1.storage().array();
        int biasIdx = (this.convBias$1.storageOffset() - 1) + i;
        int meanIdx = (this.originMean$1.storageOffset() - 1) + i;
        biasData[biasIdx] = (((alpha / base) * biasData[biasIdx]) + beta) - ((alpha * meanData[meanIdx]) / base);
    }

    /**
     * Generic {@code Function1} bridge: unboxes the index, runs the closure and
     * returns {@code BoxedUnit.UNIT}.
     */
    public final /* bridge */ /* synthetic */ Object apply(Object obj) {
        apply(BoxesRunTime.unboxToInt(obj));
        return BoxedUnit.UNIT;
    }

    /**
     * Captures the values the closure operates on.
     *
     * @param spatialConvolution        convolution layer (read for group / output-plane counts)
     * @param spatialBatchNormalization batch-norm layer (read for eps, nOutput and name)
     * @param tensor                    variance values, one per output channel
     * @param tensor2                   mean values, one per output channel
     * @param tensor3                   convolution weight, modified in place
     * @param tensor4                   convolution bias, modified in place
     * @param tensor5                   packed batch-norm parameters ({@code 2 * nOutput} values are read)
     */
    public Fusion$$anonfun$com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$fusionConvBn$1(SpatialConvolution spatialConvolution, SpatialBatchNormalization spatialBatchNormalization, Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4, Tensor tensor5) {
        this.conv$1 = spatialConvolution;
        this.bn$2 = spatialBatchNormalization;
        this.originVar$1 = tensor;
        this.originMean$1 = tensor2;
        this.convWeight$1 = tensor3;
        this.convBias$1 = tensor4;
        this.bnWeight$1 = tensor5;
    }
}
