package com.intel.analytics.bigdl.ppml.nn;

import com.intel.analytics.bigdl.dllib.nn.CAddTable$;
import com.intel.analytics.bigdl.dllib.nn.Sequential;
import com.intel.analytics.bigdl.dllib.nn.Sequential$;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractCriterion;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.optim.OptimMethod;
import com.intel.analytics.bigdl.dllib.optim.ValidationMethod;
import com.intel.analytics.bigdl.dllib.optim.ValidationResult;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.dllib.utils.T$;
import com.intel.analytics.bigdl.dllib.utils.Table;
import com.intel.analytics.bigdl.ppml.common.FLPhase;
import com.intel.analytics.bigdl.ppml.common.Storage;
import com.intel.analytics.bigdl.ppml.generated.FlBaseProto;
import com.intel.analytics.bigdl.ppml.utils.ProtoUtils$;
import org.apache.logging.log4j.Logger;
import scala.Array$;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: VflNNAggregator.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005Mc\u0001B\n\u0015\u0001\u0005B\u0001B\n\u0001\u0003\u0002\u0003\u0006Ia\n\u0005\t\u007f\u0001\u0011\t\u0011)A\u0005\u0001\"A\u0001\n\u0001B\u0001B\u0003%\u0011\n\u0003\u0005M\u0001\t\u0005\t\u0015!\u0003N\u0011\u0015\u0019\u0006\u0001\"\u0001U\u0011\u001dQ\u0006A1A\u0005\u0002mCa!\u0019\u0001!\u0002\u0013a\u0006b\u00022\u0001\u0001\u0004%\ta\u0019\u0005\ba\u0002\u0001\r\u0011\"\u0001r\u0011\u00199\b\u0001)Q\u0005I\")\u0001\u0010\u0001C!s\u001e9\u0011Q\u0001\u000b\t\u0002\u0005\u001daAB\n\u0015\u0011\u0003\tI\u0001\u0003\u0004T\u001b\u0011\u0005\u0011\u0011\u0003\u0005\n\u0003'i!\u0019!C\u0001\u0003+A\u0001\"a\f\u000eA\u0003%\u0011q\u0003\u0005\b\u0003ciA\u0011AA\u001a\u0011\u001d\t\t$\u0004C\u0001\u0003\u000f\u0012qB\u00164m\u001d:\u000bum\u001a:fO\u0006$xN\u001d\u0006\u0003+Y\t!A\u001c8\u000b\u0005]A\u0012\u0001\u00029q[2T!!\u0007\u000e\u0002\u000b\tLw\r\u001a7\u000b\u0005ma\u0012!C1oC2LH/[2t\u0015\tib$A\u0003j]R,GNC\u0001 \u0003\r\u0019w.\\\u0002\u0001'\t\u0001!\u0005\u0005\u0002$I5\tA#\u0003\u0002&)\taaJT!hOJ,w-\u0019;pe\u0006)Qn\u001c3fYB\u0019\u0001FN\u001d\u000f\u0005%\"dB\u0001\u00164\u001d\tY#G\u0004\u0002-c9\u0011Q\u0006M\u0007\u0002])\u0011q\u0006I\u0001\u0007yI|w\u000e\u001e 
\n\u0003}I!!\b\u0010\n\u0005ma\u0012BA\r\u001b\u0013\t)\u0004$A\u0004qC\u000e\\\u0017mZ3\n\u0005]B$AB'pIVdWM\u0003\u000261A\u0011!(P\u0007\u0002w)\tA(A\u0003tG\u0006d\u0017-\u0003\u0002?w\t)a\t\\8bi\u0006Yq\u000e\u001d;j[6+G\u000f[8e!\r\te)O\u0007\u0002\u0005*\u00111\tR\u0001\u0006_B$\u0018.\u001c\u0006\u0003\u000bb\tQ\u0001\u001a7mS\nL!a\u0012\"\u0003\u0017=\u0003H/[7NKRDw\u000eZ\u0001\nGJLG/\u001a:j_:\u00042\u0001\u000b&:\u0013\tY\u0005HA\u0005De&$XM]5p]\u0006\tb/\u00197jI\u0006$\u0018n\u001c8NKRDw\u000eZ:\u0011\u0007ir\u0005+\u0003\u0002Pw\t)\u0011I\u001d:bsB\u0019\u0011)U\u001d\n\u0005I\u0013%\u0001\u0005,bY&$\u0017\r^5p]6+G\u000f[8e\u0003\u0019a\u0014N\\5u}Q)QKV,Y3B\u00111\u0005\u0001\u0005\u0006M\u0015\u0001\ra\n\u0005\u0006\u007f\u0015\u0001\r\u0001\u0011\u0005\u0006\u0011\u0016\u0001\r!\u0013\u0005\u0006\u0019\u0016\u0001\r!T\u0001\u0007[>$W\u000f\\3\u0016\u0003q\u00032!X0:\u001b\u0005q&BA\u000bE\u0013\t\u0001gL\u0001\u0006TKF,XM\u001c;jC2\fq!\\8ek2,\u0007%\u0001\twC2LG-\u0019;j_:\u0014Vm];miV\tA\rE\u0002fU2l\u0011A\u001a\u0006\u0003O\"\f\u0011\"[7nkR\f'\r\\3\u000b\u0005%\\\u0014AC2pY2,7\r^5p]&\u00111N\u001a\u0002\u0005\u0019&\u001cH\u000fE\u0002;\u001d6\u0004\"!\u00118\n\u0005=\u0014%\u0001\u0005,bY&$\u0017\r^5p]J+7/\u001e7u\u0003Q1\u0018\r\\5eCRLwN\u001c*fgVdGo\u0018\u0013fcR\u0011!/\u001e\t\u0003uML!\u0001^\u001e\u0003\tUs\u0017\u000e\u001e\u0005\bm&\t\t\u00111\u0001e\u0003\rAH%M\u0001\u0012m\u0006d\u0017\u000eZ1uS>t'+Z:vYR\u0004\u0013!C1hOJ,w-\u0019;f)\t\u0011(\u0010C\u0003|\u0017\u0001\u0007A0A\u0004gYBC\u0017m]3\u0011\u0007u\f\t!D\u0001\u007f\u0015\tyh#\u0001\u0004d_6lwN\\\u0005\u0004\u0003\u0007q(a\u0002$M!\"\f7/Z\u0001\u0010-\u001adgJT!hOJ,w-\u0019;peB\u00111%D\n\u0004\u001b\u0005-\u0001c\u0001\u001e\u0002\u000e%\u0019\u0011qB\u001e\u0003\r\u0005s\u0017PU3g)\t\t9!\u0001\u0004m_\u001e<WM]\u000b\u0003\u0003/\u0001B!!\u0007\u0002,5\u0011\u00111\u0004\u0006\u0005\u0003;\ty\"A\u0003m_\u001e$$N\u0003\u0003\u0002\"\u0005\r\u0012a\u00027pO\u001eLgn\u001a\u00
06\u0005\u0003K\t9#\u0001\u0004ba\u0006\u001c\u0007.\u001a\u0006\u0003\u0003S\t1a\u001c:h\u0013\u0011\ti#a\u0007\u0003\r1{wmZ3s\u0003\u001dawnZ4fe\u0002\nQ!\u00199qYf$\u0012\"VA\u001b\u0003\u007f\t\u0019%!\u0012\t\u000f\u0005]\u0012\u00031\u0001\u0002:\u0005I1\r\\5f]RtU/\u001c\t\u0004u\u0005m\u0012bAA\u001fw\t\u0019\u0011J\u001c;\t\r\u0005\u0005\u0013\u00031\u0001(\u0003)\u0019G.Y:tS\u001aLWM\u001d\u0005\u0006\u007fE\u0001\r\u0001\u0011\u0005\u0006\u0011F\u0001\r!\u0013\u000b\f+\u0006%\u00131JA'\u0003\u001f\n\t\u0006C\u0004\u00028I\u0001\r!!\u000f\t\r\u0005\u0005#\u00031\u0001(\u0011\u0015y$\u00031\u0001A\u0011\u0015A%\u00031\u0001J\u0011\u0015a%\u00031\u0001N\u0001")
/* loaded from: input_file:com/intel/analytics/bigdl/ppml/nn/VflNNAggregator.class */
/**
 * Aggregator that runs a server-side {@link Sequential} network over the data
 * decoded from the phase storage and publishes a phase-specific reply.
 *
 * <p>The network always begins with a {@code CAddTable} layer; the module passed
 * to the constructor, when non-null, is appended after it.
 *
 * <p>NOTE(review): no synchronization is visible in this class; whether callers
 * serialize {@link #aggregate(FLPhase)} is not visible here — confirm before
 * sharing an instance across threads.
 */
public class VflNNAggregator extends NNAggregator {
    /** Criterion used in the TRAIN phase for the loss and for the gradient fed to the network. */
    private final AbstractCriterion<Activity, Activity, Object> criterion;

    /** Validation methods applied to the network output on every EVAL call. */
    private final ValidationMethod<Object>[] validationMethods;

    /** Server-side network: a {@code CAddTable} layer, optionally followed by the constructor-supplied module. */
    private final Sequential<Object> module = Sequential$.MODULE$.apply$mFc$sp(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$).add(CAddTable$.MODULE$.apply(CAddTable$.MODULE$.apply$default$1(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$));

    /** One entry per EVAL call; each entry holds one result per validation method. */
    private List<ValidationResult[]> validationResult;

    /** Delegates to the Scala companion object's five-argument factory. */
    public static VflNNAggregator apply(int i, AbstractModule<Activity, Activity, Object> abstractModule, OptimMethod<Object> optimMethod, AbstractCriterion<Activity, Activity, Object> abstractCriterion, ValidationMethod<Object>[] validationMethodArr) {
        return VflNNAggregator$.MODULE$.apply(i, abstractModule, optimMethod, abstractCriterion, validationMethodArr);
    }

    /** Delegates to the Scala companion object's four-argument factory. */
    public static VflNNAggregator apply(int i, AbstractModule<Activity, Activity, Object> abstractModule, OptimMethod<Object> optimMethod, AbstractCriterion<Activity, Activity, Object> abstractCriterion) {
        return VflNNAggregator$.MODULE$.apply(i, abstractModule, optimMethod, abstractCriterion);
    }

    /** Returns the logger held by the Scala companion object. */
    public static Logger logger() {
        return VflNNAggregator$.MODULE$.logger();
    }

    /** Returns the server-side network. */
    public Sequential<Object> module() {
        return this.module;
    }

    /** Returns the validation results accumulated so far, one array per EVAL call. */
    public List<ValidationResult[]> validationResult() {
        return this.validationResult;
    }

    /** Replaces the accumulated validation results (Scala setter for {@code validationResult}). */
    public void validationResult_$eq(List<ValidationResult[]> list) {
        this.validationResult = list;
    }

    /**
     * Runs the network on the data stored for {@code fLPhase} and writes the
     * reply back through {@code storage.clearClientAndUpdateServer}.
     *
     * <ul>
     *   <li>TRAIN: computes the loss and back-propagates; the reply carries the
     *       tensors {@code "gradInput"} (entry 1 of the backward result) and
     *       {@code "loss"}.</li>
     *   <li>EVAL: applies every validation method, appends the batch results to
     *       {@link #validationResult()}, and, when {@code shouldReturn()} is true,
     *       sets the return message to the element-wise sum of all accumulated
     *       results. The reply carries metadata only.</li>
     *   <li>PREDICT: the reply carries the network output as
     *       {@code "predictOutput"}.</li>
     * </ul>
     *
     * @param fLPhase the phase to aggregate
     * @throws MatchError if the decoded storage content is {@code null} or
     *                    {@code fLPhase} is not TRAIN, EVAL or PREDICT
     */
    @Override // com.intel.analytics.bigdl.ppml.common.Aggregator
    public void aggregate(FLPhase fLPhase) {
        Storage<FlBaseProto.TensorMap> storage = getStorage(fLPhase);
        Tuple2<Table, Tensor<Object>> decoded = ProtoUtils$.MODULE$.tableProtoToOutputTarget(storage);
        if (decoded == null) {
            throw new MatchError(decoded);
        }
        Table input = decoded._1();
        Tensor<Object> target = decoded._2();

        // The forward pass is shared by all three phases.
        Activity output = module().forward(input);
        FlBaseProto.MetaData.Builder metaBuilder = FlBaseProto.MetaData.newBuilder();

        FlBaseProto.TensorMap reply;
        if (FLPhase.TRAIN.equals(fLPhase)) {
            float loss = BoxesRunTime.unboxToFloat(this.criterion.forward(output, target));
            Activity gradInput = module().backward(input, this.criterion.backward(output, target));
            reply = FlBaseProto.TensorMap.newBuilder()
                    .setMetaData(metaBuilder.setName("gradInput").setVersion(storage.version).m1013build())
                    .putTensors("gradInput", ProtoUtils$.MODULE$.toFloatTensor((Tensor<Object>) gradInput.toTable().apply(BoxesRunTime.boxToInteger(1))))
                    .putTensors("loss", ProtoUtils$.MODULE$.toFloatTensor(Tensor$.MODULE$.apply(T$.MODULE$.apply(BoxesRunTime.boxToFloat(loss), Predef$.MODULE$.genericWrapArray(new Object[0])), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)))
                    .m1062build();
        } else if (FLPhase.EVAL.equals(fLPhase)) {
            List<ValidationResult[]> history = validationResult();
            ValidationResult[] batchResults = evaluateBatch(output, target);
            validationResult_$eq((List) history.$colon$plus(batchResults, List$.MODULE$.canBuildFrom()));
            if (shouldReturn()) {
                ValidationResult[] summed = (ValidationResult[]) validationResult().reduce(
                        (ValidationResult[] left, ValidationResult[] right) -> sumPairwise(left, right));
                // Calling toString() on the array itself would only yield its identity
                // string ("[L...;@hash"); Arrays.toString renders each result instead.
                setReturnMessage(java.util.Arrays.toString(summed));
            }
            reply = FlBaseProto.TensorMap.newBuilder()
                    .setMetaData(metaBuilder.setName("evaluateResult").setVersion(storage.version).m1013build())
                    .m1062build();
        } else if (FLPhase.PREDICT.equals(fLPhase)) {
            reply = FlBaseProto.TensorMap.newBuilder()
                    .setMetaData(metaBuilder.setName("predictResult").setVersion(storage.version).m1013build())
                    .putTensors("predictOutput", ProtoUtils$.MODULE$.toFloatTensor(output.toTensor(TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)))
                    .m1062build();
        } else {
            throw new MatchError(fLPhase);
        }
        storage.clearClientAndUpdateServer(reply);
    }

    /** Applies every configured validation method to {@code output} and {@code target}, in array order. */
    private ValidationResult[] evaluateBatch(Activity output, Tensor<Object> target) {
        ValidationResult[] results = new ValidationResult[this.validationMethods.length];
        for (int i = 0; i < results.length; i++) {
            results[i] = this.validationMethods[i].apply(output, target);
        }
        return results;
    }

    /** Adds two result arrays element by element, truncating to the shorter array. */
    private static ValidationResult[] sumPairwise(ValidationResult[] left, ValidationResult[] right) {
        int length = Math.min(left.length, right.length);
        ValidationResult[] summed = new ValidationResult[length];
        for (int i = 0; i < length; i++) {
            summed[i] = left[i].$plus(right[i]);
        }
        return summed;
    }

    /**
     * Creates the aggregator.
     *
     * @param abstractModule      module appended after the initial {@code CAddTable}
     *                            layer; skipped when {@code null}
     * @param optimMethod         not referenced by the code visible in this class
     * @param abstractCriterion   criterion used in the TRAIN phase
     * @param validationMethodArr validation methods applied in the EVAL phase
     */
    public VflNNAggregator(AbstractModule<Activity, Activity, Object> abstractModule, OptimMethod<Object> optimMethod, AbstractCriterion<Activity, Activity, Object> abstractCriterion, ValidationMethod<Object>[] validationMethodArr) {
        this.criterion = abstractCriterion;
        this.validationMethods = validationMethodArr;
        if (abstractModule != null) {
            module().add(abstractModule);
        }
        // Start with an empty history of validation results.
        this.validationResult = Nil$.MODULE$;
    }
}
