package com.intel.analytics.bigdl.dllib.optim;

import com.intel.analytics.bigdl.dllib.feature.dataset.MiniBatch;
import com.intel.analytics.bigdl.dllib.optim.DistriOptimizer;
import com.intel.analytics.bigdl.dllib.optim.parameters.AllReduceParameter;
import com.intel.analytics.bigdl.dllib.optim.parameters.FutureResult;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.Engine$;
import com.intel.analytics.bigdl.dllib.utils.ThreadPool;
import java.util.concurrent.Future;
import org.apache.spark.util.DoubleAccumulator;
import scala.Function0;
import scala.Predef$;
import scala.Serializable;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.TraversableLike;
import scala.collection.immutable.IndexedSeq;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.Buffer;
import scala.collection.mutable.Buffer$;
import scala.package$;
import scala.runtime.AbstractFunction2;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.LongRef;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;

/* JADX INFO: Add missing generic type declarations: [T] */
/* compiled from: DistriOptimizer.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/optim/DistriOptimizer$$anonfun$4.class */
/**
 * Decompiled Scala closure (anonymous function #4 of {@code DistriOptimizer}) taking an iterator of
 * mini-batches and an iterator of {@code DistriOptimizer.CacheV1} entries.
 *
 * <p>One invocation of {@link #apply} performs, in order: fetch weights through the captured
 * {@code AllReduceParameter}, run one task per sub-model with a timeout, accumulate record count
 * and loss for the tasks that completed, push gradients via {@code putGradients}, and finally
 * schedule a follow-up task per sub-model. It returns a single-element iterator holding the number
 * of tasks that completed.
 *
 * <p>All fields are values captured from the enclosing scope and injected through the constructor;
 * the {@code ObjectRef}/{@code LongRef}/{@code IntRef} fields are mutable cells shared with that
 * scope.
 */
public final class DistriOptimizer$$anonfun$4<T> extends AbstractFunction2<Iterator<MiniBatch<T>>, Iterator<DistriOptimizer.CacheV1<T>>, Iterator<Object>> implements Serializable {
    public static final long serialVersionUID = 0;
    private final AllReduceParameter parameters$1;
    public final TensorNumericMath.TensorNumeric ev$1;
    public final int _subModelNumber$1;
    // Holds an ArrayBuffer of pending task handles (cast to ArrayBuffer wherever it is used below).
    private final ObjectRef tasks$1;
    private final LongRef threshold$1;
    // Timeout handed to invokeAndWait2; rewritten below once the warm-up window has passed.
    private final LongRef timeout$1;
    private final IntRef iteration$1;
    private final double dropPercentage$1;
    private final int warmupIterationNum$1;
    private final int computeThresholdbatchSize$1;
    // Holds a double[] of losses, indexed by the ints returned from the finished tasks.
    public final ObjectRef lossArray$1;
    private final DoubleAccumulator lossSum$1;
    private final DoubleAccumulator recordsNum$1;
    private final Metrics driverMetrics$1;

    /**
     * Runs one step over the next mini-batch and the next cache entry.
     *
     * @param iterator  source of mini-batches; exactly one element is consumed
     * @param iterator2 source of cache entries; exactly one element is consumed
     * @return a single-element iterator containing the boxed count of tasks that were kept after
     *         filtering the results of {@code invokeAndWait2}
     */
    public final Iterator<Object> apply(Iterator<MiniBatch<T>> iterator, Iterator<DistriOptimizer.CacheV1<T>> iterator2) {
        DistriOptimizer.CacheV1 cached = (DistriOptimizer.CacheV1) iterator2.next();

        // ---- Fetch weights and split the batch ------------------------------------------------
        long weightSyncStart = System.nanoTime();
        FutureResult<Object> weightsResult = this.parameters$1.getWeights(((Tensor) Predef$.MODULE$.refArrayOps(cached.modelWeights()).head()).narrow(1, this.parameters$1.paramOffset(), this.parameters$1.size()));
        MiniBatch[] miniBatchBuffer = new MiniBatch[this._subModelNumber$1];
        MiniBatch batch = (MiniBatch) iterator.next();
        // Batch size divided evenly across sub-models; handed to the splitting task and reused
        // below when counting processed records.
        int stackSize = batch.size() / this._subModelNumber$1;
        ((ArrayBuffer) this.tasks$1.elem).$plus$eq(Engine$.MODULE$.m2086default().invoke((Function0) new DistriOptimizer$$anonfun$4$$anonfun$apply$1(this, miniBatchBuffer, batch, stackSize)));
        ThreadPool syncPool = Engine$.MODULE$.m2086default();
        syncPool.sync((ArrayBuffer) this.tasks$1.elem, syncPool.sync$default$2());
        weightsResult.waitResult();
        long weightSyncTime = System.nanoTime() - weightSyncStart;
        this.driverMetrics$1.add("get weights average", weightSyncTime);
        this.driverMetrics$1.add("get weights for each node", weightSyncTime);
        ((ArrayBuffer) this.tasks$1.elem).clear();

        // ---- Run one task per sub-model -------------------------------------------------------
        long computeStart = System.nanoTime();
        // Only after warm-up plus one full threshold window, and only when dropping is enabled,
        // is the timeout tightened to the threshold minus the time already spent on weights.
        if (this.dropPercentage$1 > 0.0d && this.iteration$1.elem > (this.warmupIterationNum$1 + this.computeThresholdbatchSize$1) - 1) {
            this.timeout$1.elem = this.threshold$1.elem - weightSyncTime;
        }
        int pre = (this.iteration$1.elem % this.computeThresholdbatchSize$1) * this._subModelNumber$1;
        ThreadPool trainPool = Engine$.MODULE$.m2086default();
        Buffer<Future<T>> trainingThreads = trainPool.invokeAndWait2((IndexedSeq) RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), this._subModelNumber$1).map(new DistriOptimizer$$anonfun$4$$anonfun$5(this, cached, miniBatchBuffer, weightSyncTime, pre), IndexedSeq$.MODULE$.canBuildFrom()), this.timeout$1.elem, trainPool.invokeAndWait2$default$3());
        long computingTime = System.nanoTime() - computeStart;
        this.driverMetrics$1.add("computing time average", computingTime);
        this.driverMetrics$1.add("computing time for each node", computingTime);

        // ---- Collect results of the tasks that are kept ---------------------------------------
        // The filter/map closures are defined elsewhere; the elements are unboxed as ints below.
        Buffer finishedThreads = (Buffer) ((TraversableLike) trainingThreads.filter(new DistriOptimizer$$anonfun$4$$anonfun$6(this))).map(new DistriOptimizer$$anonfun$4$$anonfun$7(this), Buffer$.MODULE$.canBuildFrom());
        // NOTE(review): the decompiled code multiplied by an undeclared register here. The
        // per-sub-model slice size is the only matching value in scope; confirm against the
        // original Scala source.
        this.recordsNum$1.add(finishedThreads.size() * stackSize);
        for (int k = 0; k < finishedThreads.size(); k++) {
            this.lossSum$1.add(((double[]) this.lossArray$1.elem)[BoxesRunTime.unboxToInt(finishedThreads.apply(k))]);
        }

        // ---- Push gradients -------------------------------------------------------------------
        if (finishedThreads.nonEmpty()) {
            Buffer finishedGradients = (Buffer) finishedThreads.map(new DistriOptimizer$$anonfun$4$$anonfun$8(this, cached), Buffer$.MODULE$.canBuildFrom());
            long aggregateStart = System.nanoTime();
            int pOffset = this.parameters$1.paramOffset();
            int pLength = this.parameters$1.size();
            int taskSize = pLength / this._subModelNumber$1;
            int extraTask = pLength % this._subModelNumber$1;
            // When each task's share would be empty, use the remainder as the task count instead.
            int parallelNum = taskSize == 0 ? extraTask : this._subModelNumber$1;
            // With exactly one task the parallel pass (and its metric) is skipped entirely.
            if (parallelNum != 1) {
                ThreadPool aggregatePool = Engine$.MODULE$.m2086default();
                aggregatePool.invokeAndWait((IndexedSeq) RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), parallelNum).map(new DistriOptimizer$$anonfun$4$$anonfun$9(this, finishedGradients, pOffset, taskSize, extraTask), IndexedSeq$.MODULE$.canBuildFrom()), aggregatePool.invokeAndWait$default$2());
                this.driverMetrics$1.add("aggregate gradient time", System.nanoTime() - aggregateStart);
            }
            long putStart = System.nanoTime();
            this.parameters$1.putGradients(((Tensor) finishedGradients.apply(0)).narrow(1, pOffset, pLength));
            this.driverMetrics$1.add("put gradient", System.nanoTime() - putStart);
        } else {
            // Nothing was kept: zero the first gradient tensor and push that instead.
            long putStart = System.nanoTime();
            cached.modelGradients()[0].zero();
            this.parameters$1.putGradients(cached.modelGradients()[0].narrow(1, this.parameters$1.paramOffset(), this.parameters$1.size()));
            this.driverMetrics$1.add("put gradient", System.nanoTime() - putStart);
        }

        // Queue one follow-up task per sub-model; these are not awaited here, only appended to
        // the shared task buffer.
        ((ArrayBuffer) this.tasks$1.elem).$plus$plus$eq(Engine$.MODULE$.m2086default().invoke((Seq) RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), this._subModelNumber$1).map(new DistriOptimizer$$anonfun$4$$anonfun$apply$8(this, cached), IndexedSeq$.MODULE$.canBuildFrom())));
        return package$.MODULE$.Iterator().single(BoxesRunTime.boxToInteger(finishedThreads.size()));
    }

    /**
     * Stores the captured values; each argument is assigned to the field listed below, in order.
     *
     * @param allReduceParameter assigned to {@code parameters$1}
     * @param tensorNumeric      assigned to {@code ev$1}
     * @param i                  assigned to {@code _subModelNumber$1}
     * @param objectRef          assigned to {@code tasks$1}
     * @param longRef            assigned to {@code threshold$1}
     * @param longRef2           assigned to {@code timeout$1}
     * @param intRef             assigned to {@code iteration$1}
     * @param d                  assigned to {@code dropPercentage$1}
     * @param i2                 assigned to {@code warmupIterationNum$1}
     * @param i3                 assigned to {@code computeThresholdbatchSize$1}
     * @param objectRef2         assigned to {@code lossArray$1}
     * @param doubleAccumulator  assigned to {@code lossSum$1}
     * @param doubleAccumulator2 assigned to {@code recordsNum$1}
     * @param metrics            assigned to {@code driverMetrics$1}
     */
    public DistriOptimizer$$anonfun$4(AllReduceParameter allReduceParameter, TensorNumericMath.TensorNumeric tensorNumeric, int i, ObjectRef objectRef, LongRef longRef, LongRef longRef2, IntRef intRef, double d, int i2, int i3, ObjectRef objectRef2, DoubleAccumulator doubleAccumulator, DoubleAccumulator doubleAccumulator2, Metrics metrics) {
        this.parameters$1 = allReduceParameter;
        this.ev$1 = tensorNumeric;
        this._subModelNumber$1 = i;
        this.tasks$1 = objectRef;
        this.threshold$1 = longRef;
        this.timeout$1 = longRef2;
        this.iteration$1 = intRef;
        this.dropPercentage$1 = d;
        this.warmupIterationNum$1 = i2;
        this.computeThresholdbatchSize$1 = i3;
        this.lossArray$1 = objectRef2;
        this.lossSum$1 = doubleAccumulator;
        this.recordsNum$1 = doubleAccumulator2;
        this.driverMetrics$1 = metrics;
    }
}
