package com.intel.analytics.bigdl.ppml.nn;

import com.intel.analytics.bigdl.ppml.common.FLPhase;
import com.intel.analytics.bigdl.ppml.common.Storage;
import com.intel.analytics.bigdl.ppml.generated.FlBaseProto;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import scala.Predef$;
import scala.collection.IterableLike;
import scala.collection.JavaConversions$;
import scala.collection.JavaConverters$;
import scala.collection.MapLike;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.RichInt$;

/* compiled from: HflNNAggregator.scala */
@ScalaSignature(bytes = "\u0006\u0001\t3Aa\u0002\u0005\u0001+!)!\u0004\u0001C\u00017!9Q\u0004\u0001a\u0001\n#q\u0002bB\u0014\u0001\u0001\u0004%\t\u0002\u000b\u0005\u0007c\u0001\u0001\u000b\u0015B\u0010\t\u000bI\u0002A\u0011I\u001a\t\u0015q\u0002\u0001\u0013!A\u0001\u0002\u0013\u0005QHA\bIM2te*Q4he\u0016<\u0017\r^8s\u0015\tI!\"\u0001\u0002o]*\u00111\u0002D\u0001\u0005aBlGN\u0003\u0002\u000e\u001d\u0005)!-[4eY*\u0011q\u0002E\u0001\nC:\fG.\u001f;jGNT!!\u0005\n\u0002\u000b%tG/\u001a7\u000b\u0003M\t1aY8n\u0007\u0001\u0019\"\u0001\u0001\f\u0011\u0005]AR\"\u0001\u0005\n\u0005eA!\u0001\u0004(O\u0003\u001e<'/Z4bi>\u0014\u0018A\u0002\u001fj]&$h\bF\u0001\u001d!\t9\u0002!A\u0005n_\u0012,GNT1nKV\tq\u0004\u0005\u0002!K5\t\u0011E\u0003\u0002#G\u0005!A.\u00198h\u0015\u0005!\u0013\u0001\u00026bm\u0006L!AJ\u0011\u0003\rM#(/\u001b8h\u00035iw\u000eZ3m\u001d\u0006lWm\u0018\u0013fcR\u0011\u0011f\f\t\u0003U5j\u0011a\u000b\u0006\u0002Y\u0005)1oY1mC&\u0011af\u000b\u0002\u0005+:LG\u000fC\u00041\u0007\u0005\u0005\t\u0019A\u0010\u0002\u0007a$\u0013'\u0001\u0006n_\u0012,GNT1nK\u0002\n\u0011\"Y4he\u0016<\u0017\r^3\u0015\u0005%\"\u0004\"B\u001b\u0006\u0001\u00041\u0014a\u00024m!\"\f7/\u001a\t\u0003oij\u0011\u0001\u000f\u0006\u0003s)\taaY8n[>t\u0017BA\u001e9\u0005\u001d1E\n\u00155bg\u0016\f1\u0003\u001d:pi\u0016\u001cG/\u001a3%G2LWM\u001c;Ok6$\"AP!\u0011\u0005\u0001z\u0014B\u0001!\"\u0005\u001dIe\u000e^3hKJDq\u0001\r\u0004\u0002\u0002\u0003\u0007A\u0004")
/* loaded from: input_file:com/intel/analytics/bigdl/ppml/nn/HflNNAggregator.class */
public class HflNNAggregator extends NNAggregator {
    /** Name written into the metadata of every aggregated model produced by {@link #aggregate}. */
    private String modelName = "averaged";

    /**
     * Synthetic accessor emitted by the Scala compiler so closures could read the inherited
     * {@code clientNum} field. No longer used by {@link #aggregate}; kept for binary compatibility.
     */
    public /* synthetic */ Integer protected$clientNum(HflNNAggregator hflNNAggregator) {
        return hflNNAggregator.clientNum;
    }

    /** Returns the name stamped into the aggregated model's metadata. */
    public String modelName() {
        return this.modelName;
    }

    /** Scala-style setter for {@link #modelName()}. */
    public void modelName_$eq(String str) {
        this.modelName = str;
    }

    /**
     * Averages the client models stored for {@code fLPhase} element-wise, tensor by tensor, and
     * publishes the result as the new server model.
     *
     * <p>Each named tensor is summed over every client model that contains it. The sum is divided
     * by the number of those clients, so a tensor missing from some clients is still a true mean.
     * The shape of each output tensor is taken from the first client that supplied it. The result
     * is passed to {@code clearClientAndUpdateServer} with the storage version incremented by one.
     *
     * @param fLPhase the phase whose tensor-map storage is aggregated
     * @throws IllegalArgumentException if two clients submit a tensor with the same name but a
     *     different number of elements
     */
    @Override // com.intel.analytics.bigdl.ppml.common.Aggregator
    public void aggregate(FLPhase fLPhase) {
        Storage<FlBaseProto.TensorMap> tensorMapStorage =
                this.aggregateTypeMap.get(fLPhase).getTensorMapStorage();

        // Sum phase: one running float[] per tensor name instead of rebuilding a proto per client.
        Map<String, TensorSum> sums = new HashMap<>();
        for (FlBaseProto.TensorMap clientModel : tensorMapStorage.clientData.values()) {
            for (Map.Entry<String, FlBaseProto.FloatTensor> entry
                    : clientModel.getTensorsMap().entrySet()) {
                String name = entry.getKey();
                FlBaseProto.FloatTensor tensor = entry.getValue();
                TensorSum sum = sums.get(name);
                if (sum == null) {
                    sums.put(name, new TensorSum(tensor.getTensorList(), tensor.getShapeList()));
                } else {
                    sum.add(name, tensor.getTensorList());
                }
            }
        }

        // Average phase: divide by how many clients actually contributed each tensor.
        Map<String, FlBaseProto.FloatTensor> averaged = new HashMap<>();
        for (Map.Entry<String, TensorSum> entry : sums.entrySet()) {
            TensorSum sum = entry.getValue();
            averaged.put(entry.getKey(), FlBaseProto.FloatTensor.newBuilder()
                    .addAllTensor(sum.mean())
                    .addAllShape(sum.shape)
                    .m966build());
        }

        tensorMapStorage.clearClientAndUpdateServer(FlBaseProto.TensorMap.newBuilder()
                .setMetaData(FlBaseProto.MetaData.newBuilder()
                        .setName(modelName())
                        .setVersion(tensorMapStorage.version + 1)
                        .m1013build())
                .putAllTensors(averaged)
                .m1062build());
    }

    /**
     * Legacy Scala closure body: adds every tensor of {@code tensorMap} element-wise into
     * {@code hashMap}, inserting a copy for names not seen before. No longer used by
     * {@link #aggregate}; kept (same erased signature) for binary compatibility.
     *
     * @throws IllegalArgumentException if an incoming tensor's length differs from the one
     *     already accumulated under the same name
     */
    public static final /* synthetic */ void $anonfun$aggregate$1(
            HashMap<String, FlBaseProto.FloatTensor> hashMap, FlBaseProto.TensorMap tensorMap) {
        for (Map.Entry<String, FlBaseProto.FloatTensor> entry
                : tensorMap.getTensorsMap().entrySet()) {
            String name = entry.getKey();
            FlBaseProto.FloatTensor incoming = entry.getValue();
            FlBaseProto.FloatTensor accumulated = hashMap.get(name);
            if (accumulated == null) {
                hashMap.put(name, FlBaseProto.FloatTensor.newBuilder()
                        .addAllTensor(incoming.getTensorList())
                        .addAllShape(incoming.getShapeList())
                        .m966build());
                continue;
            }
            List<Float> incomingValues = incoming.getTensorList();
            List<Float> accumulatedValues = accumulated.getTensorList();
            requireSameLength(name, accumulatedValues.size(), incomingValues.size());
            List<Float> summed = new ArrayList<>(accumulatedValues.size());
            for (int i = 0; i < accumulatedValues.size(); i++) {
                summed.add(incomingValues.get(i) + accumulatedValues.get(i));
            }
            hashMap.put(name, FlBaseProto.FloatTensor.newBuilder()
                    .addAllTensor(summed)
                    .addAllShape(accumulated.getShapeList())
                    .m966build());
        }
    }

    /** Rejects element-wise arithmetic between tensors of different lengths. */
    private static void requireSameLength(String name, int expected, int actual) {
        if (expected != actual) {
            throw new IllegalArgumentException("Tensor '" + name + "' has " + actual
                    + " elements from one client but " + expected
                    + " from another; all clients must submit tensors of the same size");
        }
    }

    /** Running element-wise sum of one named tensor over the clients that submitted it. */
    private static final class TensorSum {
        final float[] values;
        // Shape of the first client's tensor; reused for the averaged output.
        final List<Integer> shape;
        int contributors = 1;

        TensorSum(List<Float> first, List<Integer> shape) {
            // Copy rather than 0 + x so -0.0f survives exactly as in the original accumulation.
            this.values = new float[first.size()];
            for (int i = 0; i < values.length; i++) {
                values[i] = first.get(i);
            }
            this.shape = shape;
        }

        void add(String name, List<Float> next) {
            requireSameLength(name, values.length, next.size());
            for (int i = 0; i < values.length; i++) {
                values[i] += next.get(i);
            }
            contributors++;
        }

        List<Float> mean() {
            List<Float> out = new ArrayList<>(values.length);
            for (float v : values) {
                out.add(v / contributors);
            }
            return out;
        }
    }
}
