package com.intel.analytics.bigdl.dllib.nn.mkldnn;

import com.intel.analytics.bigdl.dllib.nn.Container;
import com.intel.analytics.bigdl.dllib.nn.CrossEntropyCriterion;
import com.intel.analytics.bigdl.dllib.nn.CrossEntropyCriterion$;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.models.Vgg_16$;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.dllib.utils.Engine$;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import com.intel.analytics.bigdl.dllib.utils.T$;
import scala.Array$;
import scala.Function0;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.immutable.StringOps;
import scala.reflect.ClassTag$;
import scala.runtime.AbstractFunction1;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: Perf.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/nn/mkldnn/Perf$$anonfun$main$1.class */
/**
 * Closure body compiled from {@code Perf.scala} ({@code main}).
 *
 * <p>For one {@link ResNet50PerfParams} instance it:
 * <ol>
 *   <li>creates a random float tensor of shape {@code {batchSize, 3, 224, 224}} and a second
 *       float tensor of size {@code batchSize} filled via {@code apply1};</li>
 *   <li>builds the container selected by {@code params.model()} (one of {@code vgg16},
 *       {@code resnet50}, {@code vgg16_graph}, {@code resnet50_graph});</li>
 *   <li>submits one preparation task to {@code Engine.dnnComputing} and waits for it;</li>
 *   <li>submits {@code params.iteration()} timed tasks, logging the elapsed time and the
 *       throughput in images per second after each one.</li>
 * </ol>
 *
 * <p>What the preparation and per-iteration tasks actually do is defined in the nested closure
 * classes {@code Perf$$anonfun$main$1$$anonfun$apply$2} and
 * {@code Perf$$anonfun$main$1$$anonfun$apply$3}.
 */
public final class Perf$$anonfun$main$1 extends AbstractFunction1<ResNet50PerfParams, BoxedUnit> implements Serializable {
    public static final long serialVersionUID = 0;

    /**
     * Runs the benchmark described by {@code resNet50PerfParams}.
     *
     * @param resNet50PerfParams supplies {@code batchSize()}, {@code training()},
     *                           {@code iteration()} and {@code model()}
     */
    public final void apply(ResNet50PerfParams resNet50PerfParams) {
        int batchSize = resNet50PerfParams.batchSize();
        boolean training = resNet50PerfParams.training();
        int iteration = resNet50PerfParams.iteration();

        int[] inputShape = {batchSize, 3, 224, 224};
        Tensor<Object> input = Tensor$.MODULE$.apply$mFc$sp(inputShape, ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$).rand();
        // Second tensor handed to the timed task; its element values come from the $$anonfun$1 closure.
        Tensor<Object> target = Tensor$.MODULE$.apply$mFc$sp(batchSize, ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$).apply1(new Perf$$anonfun$main$1$$anonfun$1(this));

        // Select the network by name. "literal".equals(name) keeps a null model name on the
        // error path instead of throwing a NullPointerException here.
        String modelName = resNet50PerfParams.model();
        Container model;
        if ("vgg16".equals(modelName)) {
            model = Vgg_16$.MODULE$.apply(batchSize, 1000, true);
        } else if ("resnet50".equals(modelName)) {
            model = ResNet$.MODULE$.apply(batchSize, 1000, T$.MODULE$.apply(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("depth"), BoxesRunTime.boxToInteger(50)), (Seq<Tuple2<Object, Object>>) Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("dataSet"), ResNet$DatasetType$ImageNet$.MODULE$)})));
        } else if ("vgg16_graph".equals(modelName)) {
            model = Vgg_16$.MODULE$.graph(batchSize, 1000, true);
        } else if ("resnet50_graph".equals(modelName)) {
            model = ResNet$.MODULE$.graph(batchSize, 1000, T$.MODULE$.apply(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("depth"), BoxesRunTime.boxToInteger(50)), (Seq<Tuple2<Object, Object>>) Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("dataSet"), ResNet$DatasetType$ImageNet$.MODULE$)})));
        } else {
            // NOTE(review): invalidInputError is called with condition=false, which presumably
            // always throws; if it ever returns, the tasks below receive a null model. Confirm.
            Log4Error$.MODULE$.invalidInputError(false, new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"Unkown model ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{resNet50PerfParams.model()})), "only support vgg16, resnet50, vgg16_graph, resnet50_graph");
            model = null;
        }

        // Criterion built with the companion's default arguments. The first default accessor is
        // evaluated and discarded, and null is passed explicitly, exactly as in the compiled code.
        CrossEntropyCriterion$ criterionFactory = CrossEntropyCriterion$.MODULE$;
        CrossEntropyCriterion$.MODULE$.apply$default$1();
        CrossEntropyCriterion<Object> criterion = criterionFactory.apply$mFc$sp(null, CrossEntropyCriterion$.MODULE$.apply$default$2(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);

        // One-off preparation task run on the DNN engine before timing starts.
        // NOTE(review): the literal 7 is forwarded untouched to the closure; it is presumably
        // a memory-format code - confirm against Perf.scala.
        Engine$.MODULE$.dnnComputing().invokeAndWait2(Predef$.MODULE$.wrapRefArray((Object[]) Predef$.MODULE$.intArrayOps(new int[]{1}).map(new Perf$$anonfun$main$1$$anonfun$apply$2(this, training, 7, inputShape, model), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Function0.class)))), Engine$.MODULE$.dnnComputing().invokeAndWait2$default$2(), Engine$.MODULE$.dnnComputing().invokeAndWait2$default$3());

        for (int i = 0; i < iteration; i++) {
            long startNanos = System.nanoTime();
            Engine$.MODULE$.dnnComputing().invokeAndWait2(Predef$.MODULE$.wrapRefArray((Object[]) Predef$.MODULE$.intArrayOps(new int[]{1}).map(new Perf$$anonfun$main$1$$anonfun$apply$3(this, training, input, target, model, criterion), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Function0.class)))), Engine$.MODULE$.dnnComputing().invokeAndWait2$default$2(), Engine$.MODULE$.dnnComputing().invokeAndWait2$default$3());
            long elapsedNanos = System.nanoTime() - startNanos;

            // The message labels the duration "s", so report seconds rather than the raw
            // nanosecond delta. The same value drives the throughput figure.
            double elapsedSeconds = elapsedNanos / 1.0E9d;
            String throughput = new StringOps(Predef$.MODULE$.augmentString("%.2f")).format(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToDouble(batchSize / elapsedSeconds)}));
            Perf$.MODULE$.logger().info(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"Iteration ", ", takes ", " s, throughput is ", " imgs/sec"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(i), BoxesRunTime.boxToDouble(elapsedSeconds), throughput})));
        }
    }

    /**
     * Erased bridge for {@code Function1.apply}: casts the argument and delegates to
     * {@link #apply(ResNet50PerfParams)}.
     *
     * @param obj the parameters object; must be a {@link ResNet50PerfParams}
     * @return {@link BoxedUnit#UNIT}
     */
    public final /* bridge */ /* synthetic */ Object apply(Object obj) {
        apply((ResNet50PerfParams) obj);
        return BoxedUnit.UNIT;
    }
}
