package com.intel.analytics.bigdl.dllib.nn.mkldnn;

import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.dllib.utils.Node;
import scala.Serializable;
import scala.reflect.ClassTag$;
import scala.runtime.AbstractFunction1;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: Perf.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/nn/mkldnn/ResNet$$anonfun$modelInit$1$1.class */
/**
 * Per-node body of the closure that {@code ResNet.modelInit} applies to graph nodes.
 *
 * <p>It re-initialises, in place, the parameters of the module held by one node:
 *
 * <ul>
 *   <li>{@code SpatialConvolution}: the weight is refilled by {@code $$anonfun$7}, which is given
 *       {@code kernelW * kernelW * nOutputPlane} (computed with int arithmetic, then widened to
 *       float). The bias is refilled by {@code $$anonfun$8}.
 *   <li>{@code SpatialBatchNormalization}: the packed weight/bias buffer is overwritten with
 *       {@code nOutput} ones followed by {@code nOutput} zeros.
 *   <li>{@code Linear}: only the bias is refilled, by {@code $$anonfun$9}; the weight is left as is.
 *   <li>Any other module: nothing is changed.
 * </ul>
 *
 * <p>NOTE(review): the value generators {@code $$anonfun$7/8/9} live in other class files. Their
 * exact fill values are presumably a scaled random normal and zeros — confirm against Perf.scala.
 */
public final class ResNet$$anonfun$modelInit$1$1 extends AbstractFunction1<Node<AbstractModule<Activity, Activity, Object>>, BoxedUnit> implements Serializable {
    public static final long serialVersionUID = 0;

    /**
     * Dispatches on the node's module type and re-initialises its parameters.
     *
     * @param node graph node whose {@code element()} is inspected
     */
    public final void apply(Node<AbstractModule<Activity, Activity, Object>> node) {
        AbstractModule<Activity, Activity, Object> module = node.element();
        // Check order matters if the module types are related: convolution first, then batch-norm, then linear.
        if (module instanceof SpatialConvolution) {
            initConvolution((SpatialConvolution) module);
        } else if (module instanceof SpatialBatchNormalization) {
            initBatchNorm((SpatialBatchNormalization) module);
        } else if (module instanceof Linear) {
            initLinear((Linear) module);
        }
    }

    /** Builds fresh weight and bias tensors of matching size, then copies them into the layer. */
    private void initConvolution(SpatialConvolution conv) {
        float fan = conv.kernelW() * conv.kernelW() * conv.nOutputPlane();

        Tensor weightScratch = Tensor$.MODULE$.apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
        Tensor<Object> newWeight = weightScratch
                .resize(conv.weight().size(), weightScratch.resize$default$2())
                .apply1(new ResNet$$anonfun$modelInit$1$1$$anonfun$7(this, fan));

        Tensor biasScratch = Tensor$.MODULE$.apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
        Tensor<Object> newBias = biasScratch
                .resize(conv.bias().size(), biasScratch.resize$default$2())
                .apply1(new ResNet$$anonfun$modelInit$1$1$$anonfun$8(this));

        // Both tensors are fully built before either copy, matching the original generator call order.
        conv.weight().copy(newWeight);
        conv.bias().copy(newBias);
    }

    /** Writes [1 x nOutput, 0 x nOutput] into the batch-norm's packed weight/bias buffer. */
    private void initBatchNorm(SpatialBatchNormalization bn) {
        Tensor scratch = Tensor$.MODULE$.apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
        int channels = bn.nOutput();
        Tensor rows = scratch.resize(new int[]{2, channels}, scratch.resize$default$2());
        // Row 1 is the scale part, row 2 the shift part of the flattened buffer.
        rows.select(1, 1).fill(BoxesRunTime.boxToFloat(1.0f));
        rows.select(1, 2).fill(BoxesRunTime.boxToFloat(0.0f));
        Tensor flat = rows.view(new int[]{channels * 2});
        bn.weightAndBias().copy(flat);
    }

    /** Refills only the linear layer's bias; its weight is not touched. */
    private void initLinear(Linear linear) {
        linear.bias().copy(
                Tensor$.MODULE$
                        .apply(linear.bias().size(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)
                        .apply1(new ResNet$$anonfun$modelInit$1$1$$anonfun$9(this)));
    }

    /** Erased {@code Function1} bridge: forwards to the typed overload and yields Scala's unit. */
    public final /* bridge */ /* synthetic */ Object apply(Object obj) {
        apply((Node<AbstractModule<Activity, Activity, Object>>) obj);
        return BoxedUnit.UNIT;
    }
}
