package com.intel.analytics.bigdl.ppml.nn;

import com.intel.analytics.bigdl.dllib.feature.dataset.LocalDataSet;
import com.intel.analytics.bigdl.dllib.feature.dataset.MiniBatch;
import com.intel.analytics.bigdl.dllib.keras.models.InternalOptimizerUtil$;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.ppml.generated.FlBaseProto;
import com.intel.analytics.bigdl.ppml.utils.ProtoUtils$;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.collection.Iterator;
import scala.runtime.AbstractFunction1;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;

/* compiled from: VflNNEstimator.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/ppml/nn/VflNNEstimator$$anonfun$train$1.class */
/**
 * Closure object holding the body that is run once per epoch index by the enclosing
 * {@link VflNNEstimator} (judging by the generated name, the loop inside its {@code train} method).
 *
 * <p>For one epoch it pulls mini-batches from the training data set until at least
 * {@code size$1} samples have been consumed. For every batch it runs a local forward pass,
 * sends the output and target to the remote service through {@code flClient().nnStub().train(...)},
 * reads the {@code "gradInput"} and {@code "loss"} tensors from the reply, runs a local backward
 * pass and performs one optimizer step. When a validation data set was supplied, the model is
 * switched to evaluate mode and {@code $outer.evaluate(...)} is invoked after the epoch.
 */
public final class VflNNEstimator$$anonfun$train$1 extends AbstractFunction1.mcVI.sp implements Serializable {
    public static final long serialVersionUID = 0;
    /** Enclosing estimator whose model, optimizer, logger and client are used by the loop. */
    private final /* synthetic */ VflNNEstimator $outer;
    /** Total number of epochs; only used in the progress log message. */
    private final int endEpoch$1;
    /** Source of the training mini-batches. */
    private final LocalDataSet trainDataSet$1;
    /** Optional validation data set; {@code null} skips the per-epoch evaluation. */
    private final LocalDataSet valDataSet$1;
    /** Number of samples to consume per epoch. */
    private final long size$1;
    /** Iteration counter shared across epochs; incremented once per processed batch. */
    private final IntRef iteration$1;

    /**
     * Runs one epoch; delegates to the specialised {@link #apply$mcVI$sp(int)}.
     *
     * @param i zero-based epoch index
     */
    public final void apply(int i) {
        apply$mcVI$sp(i);
    }

    /**
     * Runs one training epoch over the training data set.
     *
     * @param i zero-based epoch index; {@code i + 1} is stored as the optimizer's {@code "epoch"} state
     */
    public void apply$mcVI$sp(int i) {
        Iterator batches = (Iterator) this.trainDataSet$1.data(true);
        // Samples consumed so far in this epoch. Kept as a long because it is compared against
        // the long size$1; an int counter would wrap around for very large data sets and the
        // loop condition would then never become false.
        long consumed = 0L;
        while (consumed < this.size$1) {
            this.$outer.logger().debug(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"training next batch, progress: ", "/", ", epoch: ", "/", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToLong(consumed), BoxesRunTime.boxToLong(this.size$1), BoxesRunTime.boxToInteger(i), BoxesRunTime.boxToInteger(this.endEpoch$1)})));
            MiniBatch miniBatch = (MiniBatch) batches.next();
            // Read the batch size once and reuse it when advancing the progress counter.
            int batchSize = miniBatch.size();

            // Publish the 1-based epoch and iteration numbers into the optimizer state.
            InternalOptimizerUtil$.MODULE$.getStateFromOptiMethod(this.$outer.com$intel$analytics$bigdl$ppml$nn$VflNNEstimator$$optimMethod).update("epoch", BoxesRunTime.boxToInteger(i + 1));
            InternalOptimizerUtil$.MODULE$.getStateFromOptiMethod(this.$outer.com$intel$analytics$bigdl$ppml$nn$VflNNEstimator$$optimMethod).update("neval", BoxesRunTime.boxToInteger(this.iteration$1.elem + 1));

            Activity input = miniBatch.getInput();
            // The target may be null; it is forwarded unchanged to the proto conversion below.
            Activity target = miniBatch.getTarget();

            // Local forward pass.
            this.$outer.com$intel$analytics$bigdl$ppml$nn$VflNNEstimator$$model.training();
            this.$outer.com$intel$analytics$bigdl$ppml$nn$VflNNEstimator$$model.forward(input);

            // Package the model output and target, tagged with "<modelName>_output" and the current iteration as version.
            // NOTE(review): m1027build looks like a decompiler-renamed build() - confirm before recompiling.
            FlBaseProto.TensorMap request = ProtoUtils$.MODULE$.outputTargetToTableProto(this.$outer.com$intel$analytics$bigdl$ppml$nn$VflNNEstimator$$model.output(), target, FlBaseProto.MetaData.newBuilder().setName(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"", "_output"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{this.$outer.com$intel$analytics$bigdl$ppml$nn$VflNNEstimator$$model.getName()}))).setVersion(this.iteration$1.elem).m1027build());
            this.$outer.com$intel$analytics$bigdl$ppml$nn$VflNNEstimator$$model.zeroGradParameters();

            // Remote step: the reply carries the gradient w.r.t. the local output and the loss value.
            FlBaseProto.TensorMap response = this.$outer.flClient().nnStub().train(request, this.$outer.com$intel$analytics$bigdl$ppml$nn$VflNNEstimator$$algorithm).getData();
            Tensor<Object> gradInput = ProtoUtils$.MODULE$.getTensor("gradInput", response);
            float loss = BoxesRunTime.unboxToFloat(ProtoUtils$.MODULE$.getTensor("loss", response).value());

            // Local backward pass followed by one optimizer step on the estimator's weights.
            this.$outer.com$intel$analytics$bigdl$ppml$nn$VflNNEstimator$$model.backward(input, gradInput);
            this.$outer.logger().debug(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"Model doing backward, version: ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(this.iteration$1.elem)})));
            this.$outer.com$intel$analytics$bigdl$ppml$nn$VflNNEstimator$$optimMethod.optimize$mcF$sp(new VflNNEstimator$$anonfun$train$1$$anonfun$apply$mcVI$sp$1(this, loss), this.$outer.weight());

            this.iteration$1.elem++;
            consumed += batchSize;
        }
        // Optional validation after the epoch.
        if (this.valDataSet$1 != null) {
            this.$outer.com$intel$analytics$bigdl$ppml$nn$VflNNEstimator$$model.evaluate();
            this.$outer.evaluate(this.valDataSet$1);
        }
    }

    /**
     * Accessor for the enclosing estimator, used by the nested closure created in the loop.
     *
     * @return the enclosing {@link VflNNEstimator}
     */
    public /* synthetic */ VflNNEstimator com$intel$analytics$bigdl$ppml$nn$VflNNEstimator$$anonfun$$$outer() {
        return this.$outer;
    }

    /**
     * Generic bridge: unboxes the epoch index, runs the epoch and returns {@link BoxedUnit#UNIT}.
     *
     * @param obj boxed epoch index
     * @return {@code BoxedUnit.UNIT}
     */
    public final /* bridge */ /* synthetic */ Object apply(Object obj) {
        apply(BoxesRunTime.unboxToInt(obj));
        return BoxedUnit.UNIT;
    }

    /**
     * Captures the enclosing estimator and the local values the epoch body needs.
     *
     * @param vflNNEstimator enclosing estimator; must not be {@code null}
     * @param i              total number of epochs (stored as {@code endEpoch$1})
     * @param localDataSet   training data set
     * @param localDataSet2  validation data set, may be {@code null}
     * @param j              number of samples to consume per epoch
     * @param intRef         shared iteration counter
     * @throws NullPointerException if {@code vflNNEstimator} is {@code null}
     */
    public VflNNEstimator$$anonfun$train$1(VflNNEstimator vflNNEstimator, int i, LocalDataSet localDataSet, LocalDataSet localDataSet2, long j, IntRef intRef) {
        if (vflNNEstimator == null) {
            // Compiler-generated outer-instance check; "throw null" raises a NullPointerException.
            throw null;
        }
        this.$outer = vflNNEstimator;
        this.endEpoch$1 = i;
        this.trainDataSet$1 = localDataSet;
        this.valDataSet$1 = localDataSet2;
        this.size$1 = j;
        this.iteration$1 = intRef;
    }
}
