package com.intel.analytics.bigdl.dllib.optim;

import com.intel.analytics.bigdl.dllib.feature.dataset.MiniBatch;
import com.intel.analytics.bigdl.dllib.optim.DistriOptimizer;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.Engine$;
import com.intel.analytics.bigdl.dllib.utils.ThreadPool;
import java.util.concurrent.Future;
import scala.Function0;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple3;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.mutable.Buffer;
import scala.collection.mutable.Buffer$;
import scala.package$;
import scala.runtime.AbstractFunction2;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.IntRef;
import scala.runtime.LongRef;
import scala.runtime.ObjectRef;

/* JADX INFO: Add missing generic type declarations: [T] */
/* compiled from: ParallelOptimizer.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/optim/ParallelOptimizer$$anonfun$4.class */
/**
 * Per-partition training closure generated from {@code ParallelOptimizer.scala}.
 *
 * <p>{@link #apply} is given an iterator of mini-batches and an iterator of
 * {@link DistriOptimizer.Cache}. It takes the first cache. It then runs
 * {@code iterationPerTime$1} steps, and each step submits one compute task to the engine thread pool
 * with a timeout. It returns a single tuple of
 * {@code (finishedTaskCount, lossSum, recordsNum)}.
 *
 * <p>Most state is held in captured {@code *Ref} cells shared with the enclosing Scala scope.
 * NOTE(review): when this closure is shipped to executors, those cells are presumably
 * deserialized copies, so writes here are not seen by the driver. Confirm against the caller.
 */
public final class ParallelOptimizer$$anonfun$4<T> extends AbstractFunction2<Iterator<MiniBatch<T>>, Iterator<DistriOptimizer.Cache<T>>, Iterator<Tuple3<Object, Object, Object>>> implements Serializable {
    public static final long serialVersionUID = 0;
    /** Tensor numeric type class captured from the enclosing scope. */
    public final TensorNumericMath.TensorNumeric ev$1;
    /** Accumulates elapsed nanoseconds since {@link #start$1}; updated at the end of {@link #apply}. */
    private final LongRef wallClockTime$1;
    /** Multiplier applied to the per-iteration offset passed to the compute task. */
    private final int _subModelNumber$1;
    /** Value copied into {@link #timeout$1} once dropping is active. */
    private final LongRef threshold$1;
    /** Timeout passed to {@code invokeAndWait2}. NOTE(review): unit is presumably nanoseconds; confirm. */
    private final LongRef timeout$1;
    /** Current iteration counter. It is read, but not advanced, by {@link #apply}. */
    private final IntRef iteration$1;
    /** When {@code > 0}, enables replacing {@link #timeout$1} with {@link #threshold$1} after warm-up. */
    private final double dropPercentage$1;
    private final int warmupIterationNum$1;
    private final int computeThresholdbatchSize$1;
    /** Number of mini-batches consumed per call to {@link #apply}. */
    private final int iterationPerTime$1;
    /** Holds a {@code double[]} of per-task losses, indexed by the values produced for finished tasks. */
    public final ObjectRef lossArray$1;
    private final DoubleRef lossSum$1;
    private final IntRef recordsNum$1;
    private final Metrics driverMetrics$1;
    /** Reference timestamp (from {@code System.nanoTime()}) for {@link #wallClockTime$1}. */
    private final long start$1;

    /**
     * Runs {@code iterationPerTime$1} compute steps on this partition.
     *
     * <p>Each step does the following:
     * <ol>
     *   <li>Pulls the next mini-batch from {@code iterator}. This throws if fewer than
     *       {@code iterationPerTime$1} batches remain.</li>
     *   <li>If dropping is enabled and warm-up is over, sets the timeout to the threshold.</li>
     *   <li>Submits one task to the engine thread pool and waits up to the timeout.</li>
     *   <li>Records the elapsed time under two metric names.</li>
     *   <li>Adds the losses of the finished tasks to {@code lossSum$1}, and
     *       {@code finished * batchSize} to {@code recordsNum$1}.</li>
     * </ol>
     *
     * @param iterator  mini-batches for this partition
     * @param iterator2 per-partition caches; only the first element is used
     * @return a single-element iterator holding the boxed
     *         {@code (Integer finishedTasks, Double lossSum, Integer recordsNum)}
     */
    public final Iterator<Tuple3<Object, Object, Object>> apply(Iterator<MiniBatch<T>> iterator, Iterator<DistriOptimizer.Cache<T>> iterator2) {
        int finishedTotal = 0;
        DistriOptimizer.Cache cache = (DistriOptimizer.Cache) iterator2.next();
        // Mutable cell shared with the submitted task closure, which reads the current batch from it.
        ObjectRef batchRef = ObjectRef.create((Object) null);
        for (int step = 0; step < this.iterationPerTime$1; step++) {
            batchRef.elem = (MiniBatch) iterator.next();
            long computeStart = System.nanoTime();

            // After warm-up plus one threshold-computation window, cap the wait at the threshold.
            if (this.dropPercentage$1 > 0.0d
                    && this.iteration$1.elem > (this.warmupIterationNum$1 + this.computeThresholdbatchSize$1) - 1) {
                this.timeout$1.elem = this.threshold$1.elem;
            }

            int offset = (this.iteration$1.elem % this.computeThresholdbatchSize$1) * this._subModelNumber$1;
            ThreadPool pool = Engine$.MODULE$.m2086default();
            Buffer<Future<T>> futures = pool.invokeAndWait2(
                    (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Function0[]{
                            new ParallelOptimizer$$anonfun$4$$anonfun$2(this, cache, batchRef, offset)})),
                    this.timeout$1.elem,
                    pool.invokeAndWait2$default$3());

            long computeTime = System.nanoTime() - computeStart;
            this.driverMetrics$1.add("computing time average", computeTime);
            this.driverMetrics$1.add("computing time for each node", computeTime);

            // NOTE(review): anonfun$5 presumably keeps completed futures and anonfun$6 maps each
            // to an Int index into lossArray; both inner classes are not visible here.
            Buffer finished = (Buffer) ((TraversableLike) futures.filter(new ParallelOptimizer$$anonfun$4$$anonfun$5(this)))
                    .map(new ParallelOptimizer$$anonfun$4$$anonfun$6(this), Buffer$.MODULE$.canBuildFrom());
            int finishedCount = finished.size();
            finishedTotal += finishedCount;
            this.recordsNum$1.elem += finishedCount * ((MiniBatch) batchRef.elem).size();

            // Bounded loop: the decompiled `while (true)` had no exit and never terminated.
            double[] losses = (double[]) this.lossArray$1.elem;
            for (int k = 0; k < finishedCount; k++) {
                this.lossSum$1.elem += losses[BoxesRunTime.unboxToInt(finished.apply(k))];
            }
        }
        this.wallClockTime$1.elem += System.nanoTime() - this.start$1;
        return package$.MODULE$.Iterator().single(new Tuple3(
                BoxesRunTime.boxToInteger(finishedTotal),
                BoxesRunTime.boxToDouble(this.lossSum$1.elem),
                BoxesRunTime.boxToInteger(this.recordsNum$1.elem)));
    }

    /**
     * Captures the enclosing scope's state. Arguments are stored verbatim in the
     * correspondingly named fields, in declaration order.
     */
    public ParallelOptimizer$$anonfun$4(TensorNumericMath.TensorNumeric tensorNumeric, LongRef longRef, int i, LongRef longRef2, LongRef longRef3, IntRef intRef, double d, int i2, int i3, int i4, ObjectRef objectRef, DoubleRef doubleRef, IntRef intRef2, Metrics metrics, long j) {
        this.ev$1 = tensorNumeric;
        this.wallClockTime$1 = longRef;
        this._subModelNumber$1 = i;
        this.threshold$1 = longRef2;
        this.timeout$1 = longRef3;
        this.iteration$1 = intRef;
        this.dropPercentage$1 = d;
        this.warmupIterationNum$1 = i2;
        this.computeThresholdbatchSize$1 = i3;
        this.iterationPerTime$1 = i4;
        this.lossArray$1 = objectRef;
        this.lossSum$1 = doubleRef;
        this.recordsNum$1 = intRef2;
        this.driverMetrics$1 = metrics;
        this.start$1 = j;
    }
}
