package com.intel.analytics.bigdl.dllib.keras.models;

import com.intel.analytics.bigdl.dllib.optim.DistriOptimizer;
import com.intel.analytics.bigdl.dllib.optim.parameters.AllReduceParameter;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import org.apache.spark.TaskContext$;
import scala.MatchError;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.immutable.Map;
import scala.package$;
import scala.reflect.ClassTag;
import scala.runtime.AbstractFunction1;
import scala.runtime.BoxesRunTime;

/* JADX INFO: Add missing generic type declarations: [T] */
/* compiled from: Topology.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/keras/models/InternalDistriOptimizer$$anonfun$21.class */
/**
 * Compiled Scala closure (an {@code AbstractFunction1}) that is applied to an iterator of
 * {@code DistriOptimizer.CacheV1} values and produces a single-element iterator.
 *
 * <p>The one emitted element is a pair of maps, each keyed by the current Spark partition id:
 * <ul>
 *   <li>the first map holds a freshly allocated copy of this partition's slice of the
 *       cached model weights;</li>
 *   <li>the second map holds {@code parameters.gradientPartition()} as returned (not copied here).</li>
 * </ul>
 *
 * <p>NOTE(review): usage looks like a {@code mapPartitions}-style function (one cache per
 * partition) — confirm against the call site in Topology.scala.
 */
public final class InternalDistriOptimizer$$anonfun$21<T> extends AbstractFunction1<Iterator<DistriOptimizer.CacheV1<T>>, Iterator<Tuple2<Map<Object, Tensor<T>>, Map<Object, Tensor<T>>>>> implements Serializable {
    public static final long serialVersionUID = 0;
    /** Captured parameter holder; used to look up the local range and the gradient partition. */
    private final AllReduceParameter parameters$1;
    /** Captured ClassTag for the element type T, passed to tensor construction and the range lookup. */
    private final ClassTag evidence$19$1;
    /** Captured numeric type class for T, passed to tensor construction. */
    private final TensorNumericMath.TensorNumeric ev$1;

    /**
     * Extracts this partition's slice of the cached weights together with its gradient partition.
     *
     * @param iterator iterator over the partition's caches; only its first element is read
     *                 (assumes the iterator is non-empty — {@code next()} is called unconditionally)
     * @return a single-element iterator of
     *         {@code (Map(partitionId -> weightSliceCopy), Map(partitionId -> gradientPartition))}
     * @throws MatchError if the local partition range lookup returns {@code null}
     */
    public final Iterator<Tuple2<Map<Object, Tensor<T>>, Map<Object, Tensor<T>>>> apply(Iterator<DistriOptimizer.CacheV1<T>> iterator) {
        // Only the first cache in the partition is used; any remaining elements are ignored.
        DistriOptimizer.CacheV1 cacheV1 = (DistriOptimizer.CacheV1) iterator.next();
        // Spark partition id of the running task; used as the key of both output maps.
        int partitionId = TaskContext$.MODULE$.getPartitionId();
        // Two-int range for this partition. Presumably (start offset, length) within the
        // flattened parameter vector — TODO confirm against InternalOptimizerUtil.
        Tuple2<Object, Object> localPartitionRangeFromParameters = InternalOptimizerUtil$.MODULE$.getLocalPartitionRangeFromParameters(this.parameters$1, this.evidence$19$1);
        // Compiled form of a Scala tuple pattern match: a null result fails the match.
        if (localPartitionRangeFromParameters == null) {
            throw new MatchError(localPartitionRangeFromParameters);
        }
        // Unboxes the range into two primitive ints (Scala specialized Tuple2).
        Tuple2.mcII.sp spVar = new Tuple2.mcII.sp(localPartitionRangeFromParameters._1$mcI$sp(), localPartitionRangeFromParameters._2$mcI$sp());
        int _1$mcI$sp = spVar._1$mcI$sp();
        int _2$mcI$sp = spVar._2$mcI$sp();
        // New tensor sized by the second range component, so the result does not alias the cache.
        Tensor<T> apply = Tensor$.MODULE$.apply(_2$mcI$sp, this.evidence$19$1, this.ev$1);
        // Copies narrow(1, first, second) of the first tensor in modelWeights into the new tensor.
        // NOTE(review): assumes BigDL narrow(dim, index, size) with 1-based index — confirm.
        apply.copy(((Tensor) Predef$.MODULE$.refArrayOps(cacheV1.modelWeights()).head()).narrow(1, _1$mcI$sp, _2$mcI$sp));
        // Emits (Map(partitionId -> weight slice copy), Map(partitionId -> gradientPartition()));
        // the gradient tensor is placed in the map as returned, without copying.
        return package$.MODULE$.Iterator().single(new Tuple2(Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(BoxesRunTime.boxToInteger(partitionId)), apply)})), Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(BoxesRunTime.boxToInteger(partitionId)), this.parameters$1.gradientPartition())}))));
    }

    /**
     * Captures the enclosing-scope values the closure body needs.
     *
     * @param allReduceParameter parameter holder used for range lookup and gradient access
     * @param classTag           ClassTag for the element type T
     * @param tensorNumeric      numeric type class for T
     */
    public InternalDistriOptimizer$$anonfun$21(AllReduceParameter allReduceParameter, ClassTag classTag, TensorNumericMath.TensorNumeric tensorNumeric) {
        this.parameters$1 = allReduceParameter;
        this.evidence$19$1 = classTag;
        this.ev$1 = tensorNumeric;
    }
}
