package com.intel.analytics.bigdl.dllib.optim;

import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.optim.DistriOptimizer;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.BlockManagerParameterSynchronizer;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.immutable.Map;
import scala.math.package$;
import scala.reflect.ClassTag;
import scala.runtime.AbstractFunction1;
import scala.runtime.BoxesRunTime;

/* JADX INFO: Add missing generic type declarations: [T] */
/* compiled from: ParallelOptimizer.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/optim/ParallelOptimizer$$anonfun$14.class */
/**
 * Closure (compiled from ParallelOptimizer.scala) that copies out the slice of the model's
 * flattened parameter tensor owned by one partition.
 *
 * <p>Slices are laid out contiguously: every partition owns {@code taskSize$1} elements, and
 * each of the first {@code extraSize$1} partitions owns one additional element.
 *
 * <p>The captured field names are kept exactly as emitted by scalac because they are part of
 * this closure's serialized state.
 */
public final class ParallelOptimizer$$anonfun$14<T> extends AbstractFunction1<Iterator<DistriOptimizer.Cache<T>>, Iterator<Map<Object, Tensor<T>>>> implements Serializable {
    public static final long serialVersionUID = 0;
    // Captured ClassTag, passed through when allocating the result tensor.
    private final ClassTag evidence$7$1;
    // Captured TensorNumeric, passed through when allocating the result tensor.
    private final TensorNumericMath.TensorNumeric ev$3;
    // Base number of elements owned by every partition.
    private final int taskSize$1;
    // Number of leading partitions that own one extra element.
    private final int extraSize$1;

    public ParallelOptimizer$$anonfun$14(ClassTag classTag, TensorNumericMath.TensorNumeric numeric, int taskSize, int extraSize) {
        this.evidence$7$1 = classTag;
        this.ev$3 = numeric;
        this.taskSize$1 = taskSize;
        this.extraSize$1 = extraSize;
    }

    /**
     * Builds a single-entry map from this partition's id to a copy of its parameter slice.
     *
     * <p>Only the first cache of {@code iterator} is consumed. The parameter tensor is taken from
     * that cache's first local model via {@code getParameters()._1()}. The partition id comes from
     * the cache's {@code BlockManagerParameterSynchronizer}.
     *
     * @param iterator the partition's caches; must contain at least one element
     * @return a one-element iterator holding {@code Map(partitionID -> slice)}
     */
    public final Iterator<Map<Object, Tensor<T>>> apply(Iterator<DistriOptimizer.Cache<T>> iterator) {
        DistriOptimizer.Cache firstCache = (DistriOptimizer.Cache) iterator.next();
        AbstractModule firstModel = (AbstractModule) Predef$.MODULE$.refArrayOps(firstCache.localModels()).head();
        Tensor flatParameters = (Tensor) firstModel.getParameters()._1();
        BlockManagerParameterSynchronizer synchronizer = (BlockManagerParameterSynchronizer) firstCache.parameterSynchronizer();
        int pid = synchronizer.partitionID();

        int offset = sliceOffset(pid);
        int length = sliceLength(pid);

        // Copy into a freshly allocated tensor rather than returning the narrowed view itself
        // (presumably so the result does not alias the model's parameter storage).
        Tensor<T> slice = Tensor$.MODULE$.apply(length, this.evidence$7$1, this.ev$3);
        // narrow takes a 1-based start index, hence offset + 1.
        slice.copy(flatParameters.narrow(1, offset + 1, length));

        Tuple2 entry = new Tuple2(BoxesRunTime.boxToInteger(pid), slice);
        return scala.package$.MODULE$.Iterator().single(Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{entry})));
    }

    /** Zero-based index of the first element owned by {@code partitionId}. */
    private int sliceOffset(int partitionId) {
        // Each earlier partition owns taskSize$1 elements; at most extraSize$1 of them own one more.
        return partitionId * this.taskSize$1 + Math.min(partitionId, this.extraSize$1);
    }

    /** Number of elements owned by {@code partitionId}. */
    private int sliceLength(int partitionId) {
        if (partitionId < this.extraSize$1) {
            return this.taskSize$1 + 1;
        }
        return this.taskSize$1;
    }
}
