package com.intel.analytics.bigdl.dllib.optim;

import com.intel.analytics.bigdl.dllib.feature.dataset.AbstractDataSet;
import com.intel.analytics.bigdl.dllib.feature.dataset.DistributedDataSet;
import com.intel.analytics.bigdl.dllib.feature.dataset.MiniBatch;
import com.intel.analytics.bigdl.dllib.models.utils.ModelBroadcast;
import com.intel.analytics.bigdl.dllib.models.utils.ModelBroadcast$;
import com.intel.analytics.bigdl.dllib.nn.Container;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.optim.DistriOptimizer;
import com.intel.analytics.bigdl.dllib.optim.DistriOptimizerV2;
import com.intel.analytics.bigdl.dllib.optim.parameters.AllReduceParameter;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.Engine$;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import com.intel.analytics.bigdl.dllib.utils.Table;
import com.intel.analytics.bigdl.dllib.utils.intermediate.ConversionUtils$;
import com.intel.analytics.bigdl.dllib.visualization.TrainSummary;
import com.intel.analytics.bigdl.dllib.visualization.ValidationSummary;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.spark.SparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.RDD;
import org.apache.spark.util.DoubleAccumulator;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.Iterable$;
import scala.collection.immutable.Map;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayBuffer$;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;
import scala.runtime.VolatileObjectRef;

/* compiled from: DistriOptimizerV2.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/optim/DistriOptimizerV2$.class */
public final class DistriOptimizerV2$ extends AbstractOptimizer {
    public static final DistriOptimizerV2$ MODULE$ = null;
    private final Logger logger;

    static {
        new DistriOptimizerV2$();
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v0 */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v5 */
    private DistriOptimizerV2$TrainingConfig$4$ TrainingConfig$2$lzycompute(VolatileObjectRef volatileObjectRef) {
        ?? r0 = this;
        synchronized (r0) {
            if (volatileObjectRef.elem == null) {
                volatileObjectRef.elem = new DistriOptimizerV2$TrainingConfig$4$();
            }
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
            r0 = r0;
            return (DistriOptimizerV2$TrainingConfig$4$) volatileObjectRef.elem;
        }
    }

    /**
     * Accessor for the logger held by this singleton.
     *
     * @return the logger stored in the {@code logger} field
     */
    public Logger logger() {
        return logger;
    }

    /**
     * Runs training iterations until {@code trigger} returns true for the training
     * context's state.
     *
     * <p>Before looping, the state of the first optim method in {@code masterCache} is
     * loaded into {@code trainingContext} and logged, and the dataset is shuffled when
     * the {@code RECORDS_PROCESSED} state entry equals 0. After every iteration the
     * dataset is re-shuffled if the context reports that all samples were consumed, and
     * then {@code validate}, {@code checkpoint} and the {@code option6} callback run.
     *
     * @param masterCache        driver-side cache supplying optim methods, model and parameter
     * @param rdd                RDD of per-partition caches handed to every iteration
     * @param distributedDataSet training data; shuffled and re-read by this method
     * @param trigger            end condition evaluated on the context state before each iteration
     * @param option             forwarded to {@code validate} (presumably the validation trigger)
     * @param option2            forwarded to {@code validate} (presumably the validation dataset)
     * @param option3            forwarded to {@code validate} (presumably the validation methods)
     * @param option4            forwarded to {@code checkpoint} (presumably the checkpoint trigger)
     * @param option5            forwarded to {@code checkpoint} (presumably the checkpoint path)
     * @param option6            optional train summary; a callback is applied to it each iteration
     * @param option7            forwarded to {@code validate} (optional validation summary)
     * @param z                  forwarded to {@code checkpoint}; meaning not visible here
     * @param trainingContext    mutable training state and configuration
     * @param classTag           class tag for {@code T}
     * @param tensorNumeric      numeric operations for {@code T}
     */
    public <T> void optimize(MasterCache<T> masterCache, RDD<DistriOptimizerV2.Cache<T>> rdd, DistributedDataSet<MiniBatch<T>> distributedDataSet, Trigger trigger, Option<Trigger> option, Option<AbstractDataSet<MiniBatch<T>, ?>> option2, Option<ValidationMethod<T>[]> option3, Option<Trigger> option4, Option<String> option5, Option<TrainSummary> option6, Option<ValidationSummary> option7, boolean z, TrainingContext<T> trainingContext, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        OptimMethod headMethod = (OptimMethod) masterCache.optimMethods().values().head();
        trainingContext.loadState(headMethod.state());
        logger().info(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"config ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{trainingContext.state()})));

        // Shuffle up front only when the loaded state records no processed samples.
        int recordsProcessed = BoxesRunTime.unboxToInt(headMethod.state().apply(StateEntry$.MODULE$.RECORDS_PROCESSED()));
        if (recordsProcessed == 0) {
            long shuffleStart = System.nanoTime();
            logger().info("Shuffle data");
            distributedDataSet.shuffle();
            logger().info(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"Shuffle data complete. Takes ", "s"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToDouble((System.nanoTime() - shuffleStart) / 1.0E9d)})));
        }

        SparkContext sc = distributedDataSet.originRDD().sparkContext();
        RDD<MiniBatch<T>> trainingData = distributedDataSet.data(true);
        TrainingTrace trace = TrainingTrace$.MODULE$.apply(headMethod.state());

        while (!trigger.apply(trainingContext.state())) {
            iteration(sc, trainingData, rdd, masterCache, trainingContext, trace, classTag, tensorNumeric);

            // Once every sample has been consumed, reshuffle and pick up the new data RDD.
            boolean allSamplesDone = trainingContext.hasCompleteAllSamples(trace.recordsOfEpoch(), masterCache.model());
            if (allSamplesDone) {
                distributedDataSet.shuffle();
                trainingData = distributedDataSet.data(true);
            }

            int subModels = trainingContext.subModelNumber();
            Table currentState = trainingContext.state();
            String header = Optimizer$.MODULE$.header(trace.epochs() - 1, trace.recordsOfEpoch(), trainingContext.numSamples(), trace.iterations(), trace.trainingTakes());
            validate(option, option2, option3, subModels, rdd, currentState, option7, header, masterCache.parameter());

            checkpoint(option4, option5, z, trace.trainingTakes(), rdd, trainingContext.state(), masterCache.parameter(), masterCache.optimMethods(), masterCache.model(), classTag, tensorNumeric);

            option6.foreach(new DistriOptimizerV2$$anonfun$optimize$1(masterCache, rdd, trainingContext, classTag, tensorNumeric));
        }
    }

    /**
     * Resets the nine metric entries used by one training iteration.
     *
     * <p>The two per-node entries receive a fresh empty buffer each; three entries are
     * set to 0.0 with {@code i} as the last argument; the remaining four are set to 0.0
     * with the engine's node number as the last argument.
     *
     * @param sparkContext Spark context passed through to every {@code metrics.set} call
     * @param metrics      metrics holder to reset
     * @param i            value passed as the last argument for the first three scalar entries
     *                     (the caller passes the master cache's partition number)
     */
    private void initMetrics(SparkContext sparkContext, Metrics metrics, int i) {
        final Engine$ engine = Engine$.MODULE$;
        final double zero = 0.0d;

        // Buffer-valued entries: each gets its own new empty buffer.
        metrics.set(
                COMPUTING_TIME_EACH_NODE$.MODULE$.value(),
                (ArrayBuffer<Object>) ArrayBuffer$.MODULE$.apply(Nil$.MODULE$),
                sparkContext);
        metrics.set(
                GET_WEIGHTS_EACH_NODE$.MODULE$.value(),
                (ArrayBuffer<Object>) ArrayBuffer$.MODULE$.apply(Nil$.MODULE$),
                sparkContext);

        // Scalar entries registered with the caller-supplied count.
        metrics.set(COMPUTING_TIME_AVERAGE$.MODULE$.value(), zero, sparkContext, i);
        metrics.set(AGGREGATE_GRADIENT_TIME$.MODULE$.value(), zero, sparkContext, i);
        metrics.set(GET_WEIGHTS_AVERAGE$.MODULE$.value(), zero, sparkContext, i);

        // Scalar entries registered with the engine's node number.
        metrics.set(PUT_GRADIENT$.MODULE$.value(), zero, sparkContext, engine.nodeNumber());
        metrics.set(AGGREGATE_PARTITION_GRADIENT$.MODULE$.value(), zero, sparkContext, engine.nodeNumber());
        metrics.set(COMPUTE_WEIGHT_AVERAGE$.MODULE$.value(), zero, sparkContext, engine.nodeNumber());
        metrics.set(SEND_WEIGHTS_AVERAGE$.MODULE$.value(), zero, sparkContext, engine.nodeNumber());
    }

    /**
     * Executes one training iteration and then updates the driver-side state.
     *
     * <p>Creates the "loss sum" and "record number" accumulators, resets the metrics,
     * hands the iteration body to {@code trainingTrace.traceIteration}, and finally calls
     * {@code driverStatesUpdate} with the "record number" accumulator truncated to an int.
     *
     * @param sparkContext    context used to create accumulators and reset metrics
     * @param rdd             RDD of mini-batches for this iteration
     * @param rdd2            RDD of per-partition caches
     * @param masterCache     driver-side cache supplying metrics and partition number
     * @param trainingContext mutable training state and configuration
     * @param trainingTrace   trace object that wraps the iteration body
     * @param classTag        class tag for {@code T}
     * @param tensorNumeric   numeric operations for {@code T}
     */
    private <T> void iteration(SparkContext sparkContext, RDD<MiniBatch<T>> rdd, RDD<DistriOptimizerV2.Cache<T>> rdd2, MasterCache<T> masterCache, TrainingContext<T> trainingContext, TrainingTrace trainingTrace, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        DoubleAccumulator lossSum = sparkContext.doubleAccumulator("loss sum");
        DoubleAccumulator recordsNum = sparkContext.doubleAccumulator("record number");

        Metrics metrics = masterCache.metrics();
        initMetrics(sparkContext, metrics, masterCache.partitionNum());

        // The iteration body lives in a compiler-generated closure; it receives both
        // accumulators, so presumably it is what fills them.
        trainingTrace.traceIteration(new DistriOptimizerV2$$anonfun$iteration$1(rdd, rdd2, masterCache, trainingContext, classTag, tensorNumeric, lossSum, recordsNum, metrics));

        int trainedRecords = (int) Predef$.MODULE$.Double2double(recordsNum.value());
        driverStatesUpdate(masterCache, trainedRecords, trainingContext, trainingTrace, metrics, classTag, tensorNumeric);
    }

    /**
     * Builds and materializes the per-partition cache RDD and broadcasts the model.
     *
     * <p>Steps, in order: assemble a training config from {@code masterCache} and
     * broadcast it; convert the model, call {@code getParameters()} on it and broadcast
     * it; map the dataset's origin RDD through a compiler-generated closure into an RDD
     * of caches, persist it, name it "Thread Model RDD" and force it with {@code count()}.
     *
     * @param masterCache        driver-side cache supplying criterion, methods, model and parameter
     * @param distributedDataSet dataset whose origin RDD provides the partitions and Spark context
     * @param trainingContext    supplies the sub-model number and state given to the closure
     * @param classTag           class tag for {@code T}
     * @param tensorNumeric      numeric operations for {@code T}
     * @return pair of (persisted cache RDD, model broadcast)
     */
    public <T> Tuple2<RDD<DistriOptimizerV2.Cache<T>>, ModelBroadcast<T>> com$intel$analytics$bigdl$dllib$optim$DistriOptimizerV2$$initCacheOfSlave(MasterCache<T> masterCache, DistributedDataSet<MiniBatch<T>> distributedDataSet, TrainingContext<T> trainingContext, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        DistriOptimizerV2$TrainingConfig$3<T> trainingConfig = TrainingConfig$2(VolatileObjectRef.zero()).apply(masterCache.criterion(), masterCache.validationMethods(), masterCache.optimMethods(), masterCache.parameterSplits(), masterCache.parameterProcessers());
        SparkContext sc = distributedDataSet.originRDD().sparkContext();
        Broadcast configBroadcast = sc.broadcast(trainingConfig, ClassTag$.MODULE$.apply(DistriOptimizerV2$TrainingConfig$3.class));

        AbstractModule<Activity, Activity, T> convertedModel = ConversionUtils$.MODULE$.convert(masterCache.model(), classTag);
        // Result intentionally discarded; the call is kept for whatever effect it has on the model.
        convertedModel.getParameters();
        ModelBroadcast<T> modelBroadcast = ModelBroadcast$.MODULE$.apply(classTag, tensorNumeric).broadcast(sc, convertedModel);

        // Return values are unused here; the calls themselves are preserved from the original.
        Engine$.MODULE$.nodeNumber();
        Engine$.MODULE$.coreNumber();

        AllReduceParameter<T> parameter = masterCache.parameter();
        int subModelNumber = trainingContext.subModelNumber();
        Table state = trainingContext.state();
        RDD<?> origin = distributedDataSet.originRDD();

        RDD mapped = origin.mapPartitions(new DistriOptimizerV2$$anonfun$5(classTag, tensorNumeric, configBroadcast, modelBroadcast, parameter, subModelNumber, state), origin.mapPartitions$default$2(), ClassTag$.MODULE$.apply(DistriOptimizerV2.Cache.class));
        RDD cacheRdd = mapped.persist();
        cacheRdd.setName("Thread Model RDD");

        // count() forces evaluation so the persisted caches exist before training starts.
        logger().info("Cache thread models...");
        cacheRdd.count();
        logger().info("Cache thread models... done");

        return new Tuple2<>(cacheRdd, modelBroadcast);
    }

    /**
     * Assigns {@code i} as the id of {@code abstractModule} and, when the module is a
     * {@code Container}, applies a compiler-generated closure (constructed with the same
     * id) to each of its sub-modules.
     *
     * @param abstractModule module whose id is set
     * @param i              id to assign
     * @param classTag       class tag for {@code T}, passed to the closure
     */
    public <T> void com$intel$analytics$bigdl$dllib$optim$DistriOptimizerV2$$setModelId(AbstractModule<Activity, Activity, T> abstractModule, int i, ClassTag<T> classTag) {
        abstractModule.setId(i);
        if (!(abstractModule instanceof Container)) {
            return;
        }
        Container container = (Container) abstractModule;
        container.modules().foreach(new DistriOptimizerV2$$anonfun$com$intel$analytics$bigdl$dllib$optim$DistriOptimizerV2$$setModelId$1(i, classTag));
    }

    /**
     * Updates {@code abstractModule} from the distributed caches and returns it.
     *
     * <p>Sets the module's extra parameters from the first element of the cache RDD,
     * runs a closure over every index of the second array of {@code parameters()},
     * reduces a pair of maps from all partitions, and then runs a closure once per
     * partition with the flattened parameters, the two maps, the per-partition size and
     * the remainder.
     *
     * @param rdd                RDD of per-partition caches
     * @param allReduceParameter distributed parameter; its size is split across partitions
     * @param abstractModule     model to update in place
     * @param classTag           class tag for {@code T}
     * @param tensorNumeric      numeric operations for {@code T}
     * @return the same {@code abstractModule} instance
     * @throws MatchError if {@code getParameters()} or the reduce result is null
     */
    @Override // com.intel.analytics.bigdl.dllib.optim.AbstractOptimizer
    public <T> AbstractModule<Activity, Activity, T> getModel(RDD<DistriOptimizer.Cache<T>> rdd, AllReduceParameter<T> allReduceParameter, AbstractModule<Activity, Activity, T> abstractModule, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        int partitionNum = rdd.partitions().length;

        Tensor[] extraParameters = (Tensor[]) rdd.map(new DistriOptimizerV2$$anonfun$8(), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Tensor.class))).first();
        abstractModule.setExtraParameter(extraParameters);

        Tuple2<Tensor<T>[], Tensor<T>[]> layerParameters = abstractModule.parameters();
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), ((Tensor[]) layerParameters._2()).length).foreach(new DistriOptimizerV2$$anonfun$getModel$1(layerParameters));

        // NOTE(review): presumably (weights, gradients) - confirm against AbstractModule.getParameters.
        Tuple2<Tensor<T>, Tensor<T>> flattened = abstractModule.getParameters();
        if (flattened == null) {
            throw new MatchError(flattened);
        }
        Tensor flatFirst = (Tensor) flattened._1();
        Tensor flatSecond = (Tensor) flattened._2();

        Tuple2 reduced = (Tuple2) rdd.mapPartitions(new DistriOptimizerV2$$anonfun$9(allReduceParameter), rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)).reduce(new DistriOptimizerV2$$anonfun$10());
        if (reduced == null) {
            throw new MatchError(reduced);
        }
        Map reducedFirst = (Map) reduced._1();
        Map reducedSecond = (Map) reduced._2();

        // Each partition handles sizePerPartition elements; the remainder is passed separately.
        int sizePerPartition = allReduceParameter.size() / partitionNum;
        Log4Error$.MODULE$.invalidOperationError(sizePerPartition != 0, "parameter length should not less than partition number", Log4Error$.MODULE$.invalidOperationError$default$3(), Log4Error$.MODULE$.invalidOperationError$default$4());
        int remainder = allReduceParameter.size() % partitionNum;

        // map is used only for the closure's effect; its result is discarded.
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), partitionNum).map(new DistriOptimizerV2$$anonfun$getModel$2(allReduceParameter, flatFirst, flatSecond, reducedFirst, reducedSecond, sizePerPartition, remainder), IndexedSeq$.MODULE$.canBuildFrom());
        return abstractModule;
    }

    /**
     * Trains on one group of mini-batches and reports the outcome.
     *
     * <p>Runs the timed training closure to obtain one result per model, sums the
     * losses while recording each elapsed time in {@code cache.moduleTimeList()}, runs
     * the timed aggregate-gradient and put-gradient closures, and finally submits one
     * task per sub-model through the engine's default pool.
     *
     * @param cache           per-partition cache; its module time list is overwritten
     * @param miniBatchArr    mini-batches to train on; must be non-empty because the size
     *                        of the first element is read
     * @param trainingContext supplies the sub-model number
     * @param metrics         receives the timings
     * @param classTag        class tag for {@code T}
     * @param tensorNumeric   numeric operations for {@code T}
     * @return results built from (result count, loss sum, result count * first batch size)
     */
    public <T> DistriOptimizerV2.TrainingResults com$intel$analytics$bigdl$dllib$optim$DistriOptimizerV2$$train(DistriOptimizerV2.Cache<T> cache, MiniBatch<T>[] miniBatchArr, TrainingContext<T> trainingContext, Metrics metrics, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        int stackSize = ((MiniBatch) Predef$.MODULE$.refArrayOps(miniBatchArr).head()).size();
        // The submitted tasks are collected here but not read again in this method.
        ArrayBuffer submittedTasks = new ArrayBuffer();

        Seq modelResults = (Seq) TrainingTrace$.MODULE$.time(new DistriOptimizerV2$$anonfun$11(cache, miniBatchArr, trainingContext, classTag, tensorNumeric), metrics, new MetricEntry[]{COMPUTING_TIME_EACH_NODE$.MODULE$, COMPUTING_TIME_AVERAGE$.MODULE$});

        double lossSum = 0.0d;
        for (int idx = 0; idx < modelResults.size(); idx++) {
            LossWithElapsedTime result = (LossWithElapsedTime) modelResults.apply(idx);
            lossSum += result.loss();
            cache.moduleTimeList()[idx] = result.elapsed();
        }

        Tensor gradients = (Tensor) TrainingTrace$.MODULE$.time(new DistriOptimizerV2$$anonfun$12(cache, trainingContext, classTag, modelResults), metrics, new MetricEntry[]{AGGREGATE_GRADIENT_TIME$.MODULE$});
        TrainingTrace$.MODULE$.time(new DistriOptimizerV2$$anonfun$com$intel$analytics$bigdl$dllib$optim$DistriOptimizerV2$$train$1(cache, gradients), metrics, new MetricEntry[]{PUT_GRADIENT$.MODULE$});

        submittedTasks.$plus$plus$eq(Engine$.MODULE$.m2082default().invoke((Seq) RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), trainingContext.subModelNumber()).map(new DistriOptimizerV2$$anonfun$com$intel$analytics$bigdl$dllib$optim$DistriOptimizerV2$$train$2(cache), IndexedSeq$.MODULE$.canBuildFrom())));

        return new DistriOptimizerV2.TrainingResults(modelResults.size(), lossSum, modelResults.size() * stackSize);
    }

    /**
     * Applies a compiler-generated closure, constructed with {@code table} and
     * {@code z}, to every entry of {@code map}; the mapped result is discarded.
     *
     * @param map   optim methods keyed by name
     * @param table state table handed to the closure
     * @param z     flag handed to the closure (the driver passes whether validation
     *              methods are defined)
     */
    public <T> void com$intel$analytics$bigdl$dllib$optim$DistriOptimizerV2$$updateStates(Map<String, OptimMethod<T>> map, Table table, boolean z) {
        map.map(
                new DistriOptimizerV2$$anonfun$com$intel$analytics$bigdl$dllib$optim$DistriOptimizerV2$$updateStates$1(table, z),
                Iterable$.MODULE$.canBuildFrom());
    }

    /**
     * Logs the outcome of one iteration and refreshes the driver-side training state.
     *
     * <p>Writes the {@code THROUGHPUT}, {@code NEVAL}, {@code LEARNING_RATE},
     * {@code EPOCH} and {@code RECORDS_PROCESSED} state entries, starts a new epoch on
     * the trace when the context reports all samples consumed, and finally pushes the
     * state to the optim methods via {@code updateStates}.
     *
     * @param masterCache     driver-side cache supplying optim methods, validation methods and model
     * @param i               number of records trained in this iteration
     * @param trainingContext mutable training state and configuration
     * @param trainingTrace   trace providing timings, epoch and iteration counters
     * @param metrics         metrics whose summary is logged at debug level
     * @param classTag        class tag for {@code T}
     * @param tensorNumeric   numeric operations for {@code T}
     */
    private <T> void driverStatesUpdate(MasterCache<T> masterCache, int i, TrainingContext<T> trainingContext, TrainingTrace trainingTrace, Metrics metrics, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Map<String, OptimMethod<T>> optimMethods = masterCache.optimMethods();
        boolean hasValidation = masterCache.validationMethods().isDefined();
        optimMethods.foreach(new DistriOptimizerV2$$anonfun$driverStatesUpdate$1());

        long totalNanos = trainingTrace.trainingTakes();
        long iterationNanos = trainingTrace.iterationTakes();
        float iterationSeconds = (float) (iterationNanos / 1.0E9d);
        float throughput = i / iterationSeconds;

        // updateRecords(i) is invoked while building the header, before any state entry is written.
        String header = Optimizer$.MODULE$.header(trainingTrace.epochs(), trainingTrace.updateRecords(i).recordsOfEpoch(), trainingContext.numSamples(), trainingTrace.iterations(), totalNanos);

        String progressPart = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"", " Trained ", " records in ", " seconds. "})).s(Predef$.MODULE$.genericWrapArray(new Object[]{header, BoxesRunTime.boxToInteger(i), BoxesRunTime.boxToFloat(iterationSeconds)}));
        String throughputPart = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"Throughput is ", " records/second. "})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToFloat(throughput)}));
        String lossPart = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"Loss is ", ". "})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToFloat(BoxesRunTime.unboxToFloat(trainingContext.state().apply(StateEntry$.MODULE$.LOSS())))}));
        String hyperParameterPart = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{Optimizer$.MODULE$.getHyperParameterLog(optimMethods)}));
        logger().info(progressPart + throughputPart + lossPart + hyperParameterPart);

        logger().debug("\n" + metrics.summary(metrics.summary$default$1(), metrics.summary$default$2()));

        trainingContext.state().update(StateEntry$.MODULE$.THROUGHPUT(), BoxesRunTime.boxToFloat(throughput));
        trainingContext.state().update(StateEntry$.MODULE$.NEVAL(), BoxesRunTime.boxToInteger(trainingTrace.iterations() + 1));
        trainingContext.state().update(StateEntry$.MODULE$.LEARNING_RATE(), BoxesRunTime.boxToFloat((float) ((OptimMethod) ((Tuple2) optimMethods.head())._2()).getLearningRate()));

        boolean epochFinished = trainingContext.hasCompleteAllSamples(trainingTrace.recordsOfEpoch(), masterCache.model());
        if (epochFinished) {
            trainingTrace.startNewEpoch();
            logger().info(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"", " Epoch finished. Wall clock time is ", " ms"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{header, BoxesRunTime.boxToDouble(totalNanos / 1000000.0d)})));
        }

        // EPOCH and RECORDS_PROCESSED are written after the possible epoch rollover above.
        trainingContext.state().update(StateEntry$.MODULE$.EPOCH(), BoxesRunTime.boxToInteger(trainingTrace.epochs()));
        trainingContext.state().update(StateEntry$.MODULE$.RECORDS_PROCESSED(), BoxesRunTime.boxToInteger(trainingTrace.recordsOfEpoch()));
        com$intel$analytics$bigdl$dllib$optim$DistriOptimizerV2$$updateStates(optimMethods, trainingContext.state(), hasValidation);
    }

    /**
     * Synchronizes parameters across partitions after a training step and records the
     * resulting loss.
     *
     * <p>Writes {@code NUM_FINISHED_MODELS} and clears {@code IS_GRADIENT_UPDATED},
     * runs every parameter processor closure, runs the per-partition sync closure
     * (forced with {@code count()}), then sets {@code IS_GRADIENT_UPDATED} to true and
     * stores {@code d / i} as {@code LOSS}.
     *
     * @param d               summed loss of the step
     * @param i               number of finished models; used as the loss divisor
     * @param masterCache     driver-side cache supplying metrics, parameter and processors
     * @param rdd             RDD of per-partition caches
     * @param trainingContext mutable training state and configuration
     * @param classTag        class tag for {@code T}
     * @param tensorNumeric   numeric operations for {@code T}
     */
    public <T> void com$intel$analytics$bigdl$dllib$optim$DistriOptimizerV2$$parameterSync(double d, int i, MasterCache<T> masterCache, RDD<DistriOptimizerV2.Cache<T>> rdd, TrainingContext<T> trainingContext, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Metrics metrics = masterCache.metrics();
        AllReduceParameter<T> parameter = masterCache.parameter();
        boolean hasValidation = masterCache.validationMethods().isDefined();

        trainingContext.state().update(StateEntry$.MODULE$.NUM_FINISHED_MODELS(), BoxesRunTime.boxToInteger(i));
        trainingContext.state().update(StateEntry$.MODULE$.IS_GRADIENT_UPDATED(), BoxesRunTime.boxToBoolean(false));

        Predef$.MODULE$.refArrayOps(masterCache.parameterProcessers()).foreach(new DistriOptimizerV2$$anonfun$com$intel$analytics$bigdl$dllib$optim$DistriOptimizerV2$$parameterSync$1(rdd, trainingContext, tensorNumeric, metrics, parameter));

        // Read the flag only after the processors ran: they receive the context and may have changed it.
        boolean gradientUpdated = BoxesRunTime.unboxToBoolean(trainingContext.state().apply(StateEntry$.MODULE$.IS_GRADIENT_UPDATED()));
        rdd.mapPartitions(new DistriOptimizerV2$$anonfun$com$intel$analytics$bigdl$dllib$optim$DistriOptimizerV2$$parameterSync$2(d, i, trainingContext, classTag, tensorNumeric, metrics, parameter, hasValidation, gradientUpdated), rdd.mapPartitions$default$2(), classTag).count();

        trainingContext.state().update(StateEntry$.MODULE$.IS_GRADIENT_UPDATED(), BoxesRunTime.boxToBoolean(true));
        float averageLoss = ((float) d) / i;
        trainingContext.state().update(StateEntry$.MODULE$.LOSS(), BoxesRunTime.boxToFloat(averageLoss));
    }

    /**
     * Lazy accessor for the {@code TrainingConfig} companion held in the given holder.
     * Returns the cached instance when present and otherwise delegates to
     * {@code TrainingConfig$2$lzycompute} to create it.
     *
     * @param volatileObjectRef holder whose {@code elem} caches the instance
     * @return the cached or newly created instance
     */
    private final DistriOptimizerV2$TrainingConfig$4$ TrainingConfig$2(VolatileObjectRef volatileObjectRef) {
        if (volatileObjectRef.elem != null) {
            return (DistriOptimizerV2$TrainingConfig$4$) volatileObjectRef.elem;
        }
        return TrainingConfig$2$lzycompute(volatileObjectRef);
    }

    /**
     * Private constructor of the singleton: publishes this instance through
     * {@code MODULE$} and then creates the logger for this class.
     */
    private DistriOptimizerV2$() {
        // The static initializer discards the reference it constructs, so this
        // assignment is what makes the singleton reachable.
        MODULE$ = this;
        this.logger = LogManager.getLogger(getClass());
    }
}
