package com.intel.analytics.bigdl.dllib.optim;

import com.intel.analytics.bigdl.dllib.feature.dataset.AbstractDataSet;
import com.intel.analytics.bigdl.dllib.feature.dataset.DataSet$;
import com.intel.analytics.bigdl.dllib.feature.dataset.DistributedDataSet;
import com.intel.analytics.bigdl.dllib.feature.dataset.LocalDataSet;
import com.intel.analytics.bigdl.dllib.feature.dataset.MiniBatch;
import com.intel.analytics.bigdl.dllib.feature.dataset.PaddingParam;
import com.intel.analytics.bigdl.dllib.feature.dataset.Sample;
import com.intel.analytics.bigdl.dllib.feature.dataset.SampleToMiniBatch$;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractCriterion;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.optim.parameters.ParameterProcessor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.Engine$;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import com.intel.analytics.bigdl.dllib.utils.OptimizerV1$;
import com.intel.analytics.bigdl.dllib.utils.OptimizerV2$;
import com.intel.analytics.bigdl.dllib.utils.OptimizerVersion;
import com.intel.analytics.bigdl.dllib.utils.Table;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.spark.rdd.RDD;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Iterable$;
import scala.collection.immutable.Map;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.package$;
import scala.runtime.BoxesRunTime;
import scala.runtime.Null$;

/* compiled from: Optimizer.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/optim/Optimizer$.class */
/**
 * Java rendering of the compiled Scala object {@code Optimizer} (see the
 * {@code compiled from: Optimizer.scala} header).
 *
 * <p>It provides:
 * <ul>
 *   <li>factory {@code apply} overloads that choose a {@code DistriOptimizer},
 *       {@code DistriOptimizerV2} or {@code LocalOptimizer};</li>
 *   <li>helpers that save model, state and optim-method checkpoints;</li>
 *   <li>a log-header formatter and a sub-module parameter-overlap check.</li>
 * </ul>
 *
 * <p>Members such as {@code MODULE$}, {@code apply$default$N} and
 * {@code com$intel$...$$logger} follow Scala's compiled naming scheme. They are
 * presumably referenced by other compiled classes, so keep them as-is.
 */
public final class Optimizer$ {
    /** The singleton instance of this Scala object. */
    public static final Optimizer$ MODULE$ = new Optimizer$();

    private final Logger com$intel$analytics$bigdl$dllib$optim$Optimizer$$logger;

    /** Returns the log4j logger created for this object at construction time. */
    public Logger com$intel$analytics$bigdl$dllib$optim$Optimizer$$logger() {
        return this.com$intel$analytics$bigdl$dllib$optim$Optimizer$$logger;
    }

    /**
     * Builds a progress header of the form
     * {@code [Epoch i i2/j][Iteration i3][Wall Clock <j2 / 1e9>s]}.
     *
     * @param i  value printed after "Epoch"
     * @param i2 value printed before the slash (presumably records processed so far — confirm)
     * @param j  value printed after the slash (presumably the total — confirm)
     * @param i3 value printed after "Iteration"
     * @param j2 elapsed time; it is divided by 1e9 before printing, so presumably nanoseconds
     * @return the formatted header string
     */
    public String header(int i, int i2, long j, int i3, long j2) {
        // Java's string conversion of int/long/double goes through String.valueOf /
        // Double.toString, exactly like the Scala `s` interpolator this replaces.
        double wallClockSeconds = j2 / 1.0E9d;
        return "[Epoch " + i + " " + i2 + "/" + j + "][Iteration " + i3
                + "][Wall Clock " + wallClockSeconds + "s]";
    }

    /**
     * Verifies that the parameter ranges of the named sub-modules do not overlap in the
     * model's parameter storage.
     *
     * <p>Each name in {@code seq} is mapped by {@code Optimizer$$anonfun$1} to a
     * {@code (name, Tensor)} pair. That function also receives
     * {@code abstractModule.getParameters()}. The pairs are then sorted with
     * {@code Optimizer$$anonfun$2}, presumably by storage offset (confirm). Every adjacent pair
     * must satisfy {@code offset + nElement <= nextOffset}. Any violation is reported through
     * {@code Log4Error.invalidOperationError}, with a message naming both sub-modules.
     *
     * @param abstractModule the model whose parameters are inspected
     * @param seq            names of the sub-modules to check
     * @param classTag       element ClassTag (implicit in Scala; unused here)
     * @param tensorNumeric  numeric type class (implicit in Scala; unused here)
     */
    public <T> void checkSubModules(AbstractModule<Activity, Activity, T> abstractModule, Seq<String> seq, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Tuple2[] named = (Tuple2[]) ((TraversableOnce) seq.map(
                new Optimizer$$anonfun$1(abstractModule, abstractModule.getParameters()),
                Seq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(Tuple2.class));
        // With zero or one sub-module there is no pair that could overlap.
        if (named.length <= 1) {
            return;
        }
        Tuple2[] sorted = (Tuple2[]) Predef$.MODULE$.refArrayOps(named).sortWith(new Optimizer$$anonfun$2());
        for (int k = 0; k + 1 < sorted.length; k++) {
            Tuple2 current = sorted[k];
            Tuple2 next = sorted[k + 1];
            Tensor currentParams = (Tensor) current._2();
            Tensor nextParams = (Tensor) next._2();
            // The current slice must end at or before the point where the next one starts.
            boolean disjoint = currentParams.storageOffset() + currentParams.nElement()
                    <= nextParams.storageOffset();
            String message = "Optimizer: " + current._1() + " and " + next._1()
                    + "'s parameters are duplicated. Please check your model and optimMethods.";
            Log4Error$.MODULE$.invalidOperationError(disjoint, message,
                    Log4Error$.MODULE$.invalidOperationError$default$3(),
                    Log4Error$.MODULE$.invalidOperationError$default$4());
        }
    }

    /**
     * Builds a single hyper-parameter log string from all optim methods.
     *
     * <p>Each {@code (name, OptimMethod)} entry is mapped to a String by
     * {@code Optimizer$$anonfun$getHyperParameterLog$1}. The results are then reduced with
     * {@code Optimizer$$anonfun$getHyperParameterLog$2}, presumably string concatenation
     * (confirm).
     *
     * @param map optim methods keyed by sub-module name
     * @return the reduced log string, or {@code ""} when {@code map} is empty
     */
    public String getHyperParameterLog(Map<String, OptimMethod<?>> map) {
        // reduce() throws UnsupportedOperationException on an empty collection; an empty
        // map simply has nothing to log.
        if (map.isEmpty()) {
            return "";
        }
        return (String) ((TraversableOnce) map.map(
                new Optimizer$$anonfun$getHyperParameterLog$1(),
                Iterable$.MODULE$.canBuildFrom())).reduce(new Optimizer$$anonfun$getHyperParameterLog$2());
    }

    /**
     * Saves the model to {@code <dir>/model<str>} when a directory is given.
     *
     * @param abstractModule model to save
     * @param option         target directory; nothing is saved when it is empty
     * @param z              forwarded as the second argument of {@code save}
     *                       (presumably the overwrite flag — confirm)
     * @param str            suffix appended to the file name
     */
    public <T> void saveModel(AbstractModule<Activity, Activity, T> abstractModule, Option<String> option, boolean z, String str) {
        if (option.isDefined()) {
            abstractModule.save(option.get() + "/model" + str, z);
        }
    }

    /** Default value of {@code saveModel}'s fourth parameter: an empty suffix. */
    public <T> String saveModel$default$4() {
        return "";
    }

    /**
     * Saves the optimizer state table to {@code <dir>/state<str>} when a directory is given.
     *
     * @param table  state to save
     * @param option target directory; nothing is saved when it is empty
     * @param z      forwarded as the second argument of {@code save}
     *               (presumably the overwrite flag — confirm)
     * @param str    suffix appended to the file name
     */
    public void saveState(Table table, Option<String> option, boolean z, String str) {
        if (option.isDefined()) {
            table.save(option.get() + "/state" + str, z);
        }
    }

    /** Default value of {@code saveState}'s fourth parameter: an empty suffix. */
    public String saveState$default$4() {
        return "";
    }

    /**
     * Saves the optim method to {@code <dir>/optimMethod<str>} when a directory is given.
     *
     * @param optimMethod optim method to save
     * @param option      target directory; nothing is saved when it is empty
     * @param z           forwarded as the second argument of {@code save}
     *                    (presumably the overwrite flag — confirm)
     * @param str         suffix appended to the file name
     * @param classTag    element ClassTag (implicit in Scala; unused here)
     */
    public <T> void saveOptimMethod(OptimMethod<T> optimMethod, Option<String> option, boolean z, String str, ClassTag<T> classTag) {
        if (option.isDefined()) {
            optimMethod.save(option.get() + "/optimMethod" + str, z);
        }
    }

    /** Default value of {@code saveOptimMethod}'s fourth parameter: an empty suffix. */
    public <T> String saveOptimMethod$default$4() {
        return "";
    }

    /**
     * Creates a distributed optimizer over an RDD of samples that are batched with
     * {@code SampleToMiniBatch}.
     *
     * <p>The optimizer is {@code DistriOptimizer} or {@code DistriOptimizerV2}, chosen by
     * {@code Engine.getOptimizerVersion()}.
     *
     * @param i             batch size passed to {@code SampleToMiniBatch}
     * @param paddingParam  first padding parameter; {@code null} means none
     * @param paddingParam2 second padding parameter; {@code null} means none
     * @throws MatchError if the engine reports an optimizer version other than V1/V2
     */
    public <T> Optimizer<T, MiniBatch<T>> apply(AbstractModule<Activity, Activity, T> abstractModule, RDD<Sample<T>> rdd, AbstractCriterion<Activity, Activity, T> abstractCriterion, int i, PaddingParam<T> paddingParam, PaddingParam<T> paddingParam2, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        // Option.apply maps null to None and any other value to Some(value). This matches the
        // original null checks while being well-typed (a None$ local cannot hold a Some).
        Option<PaddingParam<T>> paddingOpt = Option.apply(paddingParam);
        Option<PaddingParam<T>> paddingOpt2 = Option.apply(paddingParam2);
        boolean useV1 = useOptimizerV1();
        DistributedDataSet miniBatches = DataSet$.MODULE$.rdd(rdd, DataSet$.MODULE$.rdd$default$2(), DataSet$.MODULE$.rdd$default$3(), DataSet$.MODULE$.rdd$default$4(), ClassTag$.MODULE$.apply(Sample.class))
                .$minus$greater(SampleToMiniBatch$.MODULE$.apply(i, paddingOpt, paddingOpt2, SampleToMiniBatch$.MODULE$.apply$default$4(), SampleToMiniBatch$.MODULE$.apply$default$5(), classTag, tensorNumeric), ClassTag$.MODULE$.apply(MiniBatch.class))
                .toDistributed();
        return newDistriOptimizer(useV1, abstractModule, miniBatches, abstractCriterion, classTag, tensorNumeric);
    }

    /**
     * Creates a distributed optimizer over an RDD of samples, batched with
     * {@code SampleToMiniBatch} using the given MiniBatch template.
     *
     * @param i         batch size passed to {@code SampleToMiniBatch}
     * @param miniBatch MiniBatch instance handed to {@code SampleToMiniBatch}
     * @throws MatchError if the engine reports an optimizer version other than V1/V2
     */
    public <T> Optimizer<T, MiniBatch<T>> apply(AbstractModule<Activity, Activity, T> abstractModule, RDD<Sample<T>> rdd, AbstractCriterion<Activity, Activity, T> abstractCriterion, int i, MiniBatch<T> miniBatch, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        boolean useV1 = useOptimizerV1();
        // Option.empty() returns the None singleton, typed to whatever Option the callee expects.
        DistributedDataSet miniBatches = DataSet$.MODULE$.rdd(rdd, DataSet$.MODULE$.rdd$default$2(), DataSet$.MODULE$.rdd$default$3(), DataSet$.MODULE$.rdd$default$4(), ClassTag$.MODULE$.apply(Sample.class))
                .$minus$greater(SampleToMiniBatch$.MODULE$.apply(miniBatch, i, Option.empty(), classTag, tensorNumeric), ClassTag$.MODULE$.apply(MiniBatch.class))
                .toDistributed();
        return newDistriOptimizer(useV1, abstractModule, miniBatches, abstractCriterion, classTag, tensorNumeric);
    }

    /**
     * Creates an optimizer that matches the kind of dataset supplied.
     *
     * <ul>
     *   <li>{@code DistributedDataSet}: {@code DistriOptimizer}, or {@code DistriOptimizerV2},
     *       depending on {@code Engine.getOptimizerVersion()}.</li>
     *   <li>{@code LocalDataSet}: {@code LocalOptimizer}.</li>
     *   <li>Anything else: reported through {@code Log4Error.invalidOperationError}. If that
     *       call returns, the result is {@code null}.</li>
     * </ul>
     *
     * @throws MatchError for a distributed dataset when the optimizer version is not V1/V2
     */
    public <T, D> Optimizer<T, D> apply(AbstractModule<Activity, Activity, T> abstractModule, AbstractDataSet<D, ?> abstractDataSet, AbstractCriterion<Activity, Activity, T> abstractCriterion, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        if (abstractDataSet instanceof DistributedDataSet) {
            boolean useV1 = useOptimizerV1();
            DistributedDataSet distributed = ((DistributedDataSet) abstractDataSet).toDistributed();
            return newDistriOptimizer(useV1, abstractModule, distributed, abstractCriterion, classTag, tensorNumeric);
        }
        if (abstractDataSet instanceof LocalDataSet) {
            return new LocalOptimizer(abstractModule, ((LocalDataSet) abstractDataSet).toLocal(), abstractCriterion, classTag, tensorNumeric);
        }
        Log4Error$.MODULE$.invalidOperationError(false, "unexpected type " + abstractDataSet, "only support DistributedDataSet and  LocalDataSet", Log4Error$.MODULE$.invalidOperationError$default$4());
        return null;
    }

    /** Default value of the first padding parameter of {@code apply}: {@code null}, i.e. no padding. */
    public <T> Null$ apply$default$5() {
        return null;
    }

    /** Default value of the second padding parameter of {@code apply}: {@code null}, i.e. no padding. */
    public <T> Null$ apply$default$6() {
        return null;
    }

    /**
     * Returns the index of the first processor that is an instance of {@code T}'s runtime class.
     *
     * @param arrayBuffer processors to search
     * @param classTag    ClassTag of the processor type being looked for
     * @return the index of the first match, or -1 if there is none
     */
    public <T extends ParameterProcessor> int findIndex(ArrayBuffer<ParameterProcessor> arrayBuffer, ClassTag<T> classTag) {
        // scala.reflect.classTag(ct) just returns ct, so read the runtime class once, up front.
        Class<?> target = classTag.runtimeClass();
        for (int idx = 0; idx < arrayBuffer.size(); idx++) {
            if (target.isInstance(arrayBuffer.apply(idx))) {
                return idx;
            }
        }
        return -1;
    }

    /**
     * Resolves {@code Engine.getOptimizerVersion()} to one of the two distributed optimizer
     * implementations.
     *
     * @return {@code true} for OptimizerV1, {@code false} for OptimizerV2
     * @throws MatchError for any other version value
     */
    private boolean useOptimizerV1() {
        OptimizerVersion version = Engine$.MODULE$.getOptimizerVersion();
        if (OptimizerV1$.MODULE$.equals(version)) {
            return true;
        }
        if (OptimizerV2$.MODULE$.equals(version)) {
            return false;
        }
        throw new MatchError(version);
    }

    /** Instantiates {@code DistriOptimizer} (V1) or {@code DistriOptimizerV2} over an already-built dataset. */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private <T> Optimizer newDistriOptimizer(boolean useV1, AbstractModule<Activity, Activity, T> model, DistributedDataSet dataSet, AbstractCriterion<Activity, Activity, T> criterion, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        if (useV1) {
            return new DistriOptimizer(model, dataSet, criterion, classTag, tensorNumeric);
        }
        return new DistriOptimizerV2(model, dataSet, criterion, classTag, tensorNumeric);
    }

    private Optimizer$() {
        this.com$intel$analytics$bigdl$dllib$optim$Optimizer$$logger = LogManager.getLogger(getClass());
    }
}
