package com.intel.analytics.bigdl.dllib.keras.models;

import com.intel.analytics.bigdl.dllib.feature.dataset.MiniBatch;
import com.intel.analytics.bigdl.dllib.keras.layers.utils.KerasUtils$;
import com.intel.analytics.bigdl.dllib.models.utils.ModelBroadcast;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.optim.DistriOptimizer;
import com.intel.analytics.bigdl.dllib.optim.DistriOptimizer$;
import com.intel.analytics.bigdl.dllib.optim.DistriOptimizerV2;
import com.intel.analytics.bigdl.dllib.optim.DistriOptimizerV2$;
import com.intel.analytics.bigdl.dllib.optim.OptimMethod;
import com.intel.analytics.bigdl.dllib.optim.Optimizer;
import com.intel.analytics.bigdl.dllib.optim.parameters.AllReduceParameter;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.Table;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import org.apache.spark.rdd.RDD;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;

/* compiled from: Topology.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/keras/models/InternalOptimizerUtil$.class */
/**
 * Java (decompiled) rendering of the Scala singleton {@code object InternalOptimizerUtil}.
 *
 * <p>Every helper here reaches into members of BigDL optimizer / module classes. It does this
 * either directly through {@code java.lang.reflect} ({@code setAccessible(true)} plus
 * {@code get}/{@code invoke}) or by delegating to {@code KerasUtils$.invokeMethod} /
 * {@code invokeMethodWithEv}. Those members are presumably not publicly accessible from this
 * package, since they are always reached reflectively; confirm against the declaring classes.
 *
 * <p>NOTE(review): this is decompiler output and does not compile as Java source as-is:
 * <ul>
 *   <li>{@code MODULE$} is {@code final} with an initializer, yet is reassigned in the
 *       constructor.</li>
 *   <li>The reflective calls throw checked exceptions that no method declares:
 *       {@code NoSuchFieldException}, {@code NoSuchMethodException},
 *       {@code IllegalAccessException} and {@code InvocationTargetException}.
 *       The Scala original propagates them unchecked.</li>
 * </ul>
 *
 * <p>The {@code ClassTag<T>} parameters on several methods are Scala implicit context-bound
 * evidence. Where they are not forwarded they are unused in the body.
 */
public final class InternalOptimizerUtil$ {
    /** The singleton instance (Scala {@code object} convention); published by the constructor. */
    public static final InternalOptimizerUtil$ MODULE$ = null;

    // Class initialization constructs the single instance; the constructor stores it in MODULE$.
    static {
        new InternalOptimizerUtil$();
    }

    /**
     * Runs {@code InternalOptimizerUtil$$anonfun$setExecutorMklThread$1} over every partition of
     * {@code rdd}, then calls {@code count()} to force the otherwise lazy {@code mapPartitions} to
     * execute.
     *
     * <p>The anonymous function's body is not visible here. Judging by the method name, it
     * configures the MKL thread count on each executor; TODO confirm.
     *
     * @param rdd any RDD; only used as a vehicle to run code once per partition
     */
    public void setExecutorMklThread(RDD<?> rdd) {
        rdd.mapPartitions(new InternalOptimizerUtil$$anonfun$setExecutorMklThread$1(), rdd.mapPartitions$default$2(), ClassTag$.MODULE$.Int()).count();
    }

    /**
     * Reads the field {@code "models"} declared on {@link DistriOptimizer} from {@code optimizer}.
     *
     * <p>The field is resolved on {@code DistriOptimizer.class}, so {@code optimizer} must actually
     * be a {@code DistriOptimizer}. Otherwise {@code Field.get} throws
     * {@code IllegalArgumentException}. The result is an unchecked cast to {@code RDD}.
     *
     * @param optimizer the optimizer to read from (expected to be a {@code DistriOptimizer})
     * @param classTag  Scala context-bound evidence; unused here
     * @return the value of the {@code models} field
     */
    public <T> RDD<DistriOptimizer.Cache<T>> getModelCacheFromOptimizer(Optimizer<T, MiniBatch<T>> optimizer, ClassTag<T> classTag) {
        Field declaredField = DistriOptimizer.class.getDeclaredField("models");
        declaredField.setAccessible(true);
        return (RDD) declaredField.get(optimizer);
    }

    /**
     * Invokes the no-argument method {@code state()} declared on {@link OptimMethod} and returns
     * its result as a {@link Table}.
     *
     * @param optimMethod the optimization method whose state is read
     * @return the {@code Table} returned by {@code state()}
     */
    public <T> Table getStateFromOptiMethod(OptimMethod<T> optimMethod) {
        Method declaredMethod = OptimMethod.class.getDeclaredMethod("state", new Class[0]);
        declaredMethod.setAccessible(true);
        return (Table) declaredMethod.invoke(optimMethod, new Object[0]);
    }

    /**
     * Invokes the no-argument method {@code state()} declared on {@link Optimizer} and returns its
     * result as a {@link Table}.
     *
     * @param optimizer the optimizer whose state is read
     * @param classTag  Scala context-bound evidence; unused here
     * @return the {@code Table} returned by {@code state()}
     */
    public <T> Table getStateFromOptimizer(Optimizer<T, MiniBatch<T>> optimizer, ClassTag<T> classTag) {
        Method declaredMethod = Optimizer.class.getDeclaredMethod("state", new Class[0]);
        declaredMethod.setAccessible(true);
        return (Table) declaredMethod.invoke(optimizer, new Object[0]);
    }

    /**
     * Invokes the no-argument method {@code endEpoch()} declared on {@link DistriOptimizer}.
     *
     * @param distriOptimizer the optimizer to call {@code endEpoch()} on
     * @param classTag        Scala context-bound evidence; unused here
     */
    public <T> void endEpoch(DistriOptimizer<T> distriOptimizer, ClassTag<T> classTag) {
        Method declaredMethod = DistriOptimizer.class.getDeclaredMethod("endEpoch", new Class[0]);
        declaredMethod.setAccessible(true);
        declaredMethod.invoke(distriOptimizer, new Object[0]);
    }

    /**
     * Invokes the no-argument method {@code endEpoch()} declared on {@link DistriOptimizerV2}.
     * This is the V2 counterpart of {@code endEpoch(DistriOptimizer, ClassTag)}.
     *
     * @param distriOptimizerV2 the optimizer to call {@code endEpoch()} on
     * @param classTag          Scala context-bound evidence; unused here
     */
    public <T> void endEpochV2(DistriOptimizerV2<T> distriOptimizerV2, ClassTag<T> classTag) {
        Method declaredMethod = DistriOptimizerV2.class.getDeclaredMethod("endEpoch", new Class[0]);
        declaredMethod.setAccessible(true);
        declaredMethod.invoke(distriOptimizerV2, new Object[0]);
    }

    /**
     * Invokes the no-argument method {@code getParameters()} declared on {@link AbstractModule}
     * and returns its result as a pair of tensors.
     *
     * <p>Presumably the pair is (weights, gradients); confirm against {@code AbstractModule}.
     *
     * @param abstractModule the module whose parameters are fetched
     * @param classTag       Scala context-bound evidence; unused here
     * @return the {@code Tuple2} returned by {@code getParameters()} (unchecked cast)
     */
    public <T> Tuple2<Tensor<T>, Tensor<T>> getParametersFromModel(AbstractModule<Activity, Activity, T> abstractModule, ClassTag<T> classTag) {
        Method declaredMethod = AbstractModule.class.getDeclaredMethod("getParameters", new Class[0]);
        declaredMethod.setAccessible(true);
        return (Tuple2) declaredMethod.invoke(abstractModule, new Object[0]);
    }

    /**
     * Calls {@code initThreadModels} on the {@code DistriOptimizer} companion object through
     * {@code KerasUtils$.invokeMethodWithEv}.
     *
     * <p>The target is looked up by its Scala-mangled JVM name
     * ({@code com$intel$...$DistriOptimizer$$initThreadModels}). A rename of the Scala member will
     * break this at runtime, not at compile time.
     *
     * @param seq           the argument list forwarded to the target method
     * @param classTag      forwarded as implicit evidence (per {@code invokeMethodWithEv}); confirm
     * @param tensorNumeric forwarded as implicit evidence (per {@code invokeMethodWithEv}); confirm
     * @return the target's result, cast (unchecked) to a (model-cache RDD, {@code ModelBroadcast}) pair
     */
    public <T> Tuple2<RDD<DistriOptimizer.CacheV1<T>>, ModelBroadcast<T>> initThreadModels(Seq<Object> seq, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        return (Tuple2) KerasUtils$.MODULE$.invokeMethodWithEv(DistriOptimizer$.MODULE$, "com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$initThreadModels", seq, classTag, tensorNumeric);
    }

    /**
     * Calls {@code clearState} on the {@code DistriOptimizer} companion object with the arguments
     * {@code (rdd, classTag)}. Here {@code Predef.implicitly(classTag)} simply yields
     * {@code classTag}.
     *
     * @param rdd      the model-cache RDD passed as the first argument
     * @param classTag passed as the second argument
     */
    public <T> void clearState(RDD<DistriOptimizer.CacheV1<T>> rdd, ClassTag<T> classTag) {
        KerasUtils$.MODULE$.invokeMethod(DistriOptimizer$.MODULE$, "clearState", Predef$.MODULE$.wrapRefArray(new Object[]{rdd, Predef$.MODULE$.implicitly(classTag)}));
    }

    /**
     * V2 counterpart of {@code clearState}: calls {@code clearState} on the
     * {@code DistriOptimizerV2} companion object with the arguments {@code (rdd, classTag)}.
     *
     * @param rdd      the model-cache RDD passed as the first argument
     * @param classTag passed as the second argument
     */
    public <T> void clearStateV2(RDD<DistriOptimizerV2.Cache<T>> rdd, ClassTag<T> classTag) {
        KerasUtils$.MODULE$.invokeMethod(DistriOptimizerV2$.MODULE$, "clearState", Predef$.MODULE$.wrapRefArray(new Object[]{rdd, Predef$.MODULE$.implicitly(classTag)}));
    }

    /**
     * Calls {@code optimize} on the {@code DistriOptimizer} companion object through
     * {@code KerasUtils$.invokeMethodWithEv}. The result, if any, is discarded.
     *
     * @param seq           the argument list forwarded to {@code optimize}
     * @param classTag      forwarded as implicit evidence; confirm
     * @param tensorNumeric forwarded as implicit evidence; confirm
     */
    public <T> void optimizeModels(Seq<Object> seq, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        KerasUtils$.MODULE$.invokeMethodWithEv(DistriOptimizer$.MODULE$, "optimize", seq, classTag, tensorNumeric);
    }

    /**
     * Calls {@code getModel} on the {@code DistriOptimizer} companion object through
     * {@code KerasUtils$.invokeMethodWithEv}.
     *
     * <p>Despite the name, this method returns nothing: whatever {@code getModel} returns is
     * discarded. Any useful effect must be a side effect of the target; confirm.
     *
     * @param seq           the argument list forwarded to {@code getModel}
     * @param classTag      forwarded as implicit evidence; confirm
     * @param tensorNumeric forwarded as implicit evidence; confirm
     */
    public <T> void getModel(Seq<Object> seq, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        KerasUtils$.MODULE$.invokeMethodWithEv(DistriOptimizer$.MODULE$, "getModel", seq, classTag, tensorNumeric);
    }

    /**
     * Calls {@code deleteKey(str)} on
     * {@code com.intel.analytics.bigdl.dllib.models.utils.CachedModels}, passing the target by
     * fully-qualified class name.
     *
     * <p>Presumably this selects a {@code KerasUtils$.invokeMethodWithEv} overload that resolves
     * the object by name; confirm.
     *
     * @param str           the single argument passed to {@code deleteKey}
     * @param classTag      forwarded as implicit evidence; confirm
     * @param tensorNumeric forwarded as implicit evidence; confirm
     */
    public <T> void releaseBroadcast(String str, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        KerasUtils$.MODULE$.invokeMethodWithEv("com.intel.analytics.bigdl.dllib.models.utils.CachedModels", "deleteKey", (Seq<Object>) Predef$.MODULE$.wrapRefArray(new Object[]{str}), (ClassTag) classTag, (TensorNumericMath.TensorNumeric) tensorNumeric);
    }

    /**
     * Calls the no-argument member {@code localPartitionRange} on {@code allReduceParameter}
     * through {@code KerasUtils$.invokeMethod}.
     *
     * <p>The elements of the returned pair are typed {@code Object}, i.e. boxed values. Presumably
     * they are integer start/length of this node's parameter slice; confirm against
     * {@code AllReduceParameter}.
     *
     * @param allReduceParameter the parameter holder to query
     * @param classTag           Scala context-bound evidence; unused here
     * @return the result, cast (unchecked) to a {@code Tuple2}
     */
    public <T> Tuple2<Object, Object> getLocalPartitionRangeFromParameters(AllReduceParameter<T> allReduceParameter, ClassTag<T> classTag) {
        return (Tuple2) KerasUtils$.MODULE$.invokeMethod(allReduceParameter, "localPartitionRange", Predef$.MODULE$.wrapRefArray(new Object[0]));
    }

    // Private constructor: only the static initializer creates the instance, which is published here.
    private InternalOptimizerUtil$() {
        MODULE$ = this;
    }
}
