package com.intel.analytics.bigdl.dllib.utils.intermediate;

import com.intel.analytics.bigdl.dllib.nn.Graph;
import com.intel.analytics.bigdl.dllib.nn.StaticGraph;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.DnnGraph;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.MklDnnContainer;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.MklDnnModule;
import com.intel.analytics.bigdl.dllib.nn.quantized.Quantization$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.Engine$;
import com.intel.analytics.bigdl.dllib.utils.EngineType;
import com.intel.analytics.bigdl.dllib.utils.MklDnn$;
import org.apache.spark.rdd.RDD;
import scala.Option;
import scala.collection.immutable.Nil$;
import scala.reflect.ClassTag;

/* compiled from: ConversionUtils.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/utils/intermediate/ConversionUtils$.class */
/**
 * Java view of the Scala singleton object {@code ConversionUtils}.
 *
 * <p>Offers helpers that turn a module into its {@link IRGraph} form when the
 * engine type is MklDnn, optionally request quantization of the result, and
 * coalesce an {@link RDD} to the engine's node number.
 *
 * <p>All access goes through the single instance {@link #MODULE$}.
 */
public final class ConversionUtils$ {
    /**
     * The one and only instance. It is created eagerly here instead of being
     * assigned from inside the constructor: a {@code final} field that already
     * has an initializer cannot legally be assigned again.
     */
    public static final ConversionUtils$ MODULE$ = new ConversionUtils$();

    /** Not instantiable from outside; use {@link #MODULE$}. */
    private ConversionUtils$() {
    }

    /**
     * Converts a module to an {@link IRGraph} when that is applicable.
     *
     * <ul>
     *   <li>An {@code IRGraph} is returned as is when {@code isBuild()} is true,
     *       otherwise the result of its {@code build()} is returned.</li>
     *   <li>A {@link MklDnnModule}, or any module while the engine type is not
     *       MklDnn, is returned unchanged.</li>
     *   <li>Otherwise the module is used as a graph (directly if it already is a
     *       {@link Graph}, else via {@code toGraph} with an empty list). If the
     *       graph is a {@link StaticGraph} it is converted with
     *       {@code toIRgraph()} and put into the same training/evaluation state
     *       as the input; if not, the original module is returned unchanged.</li>
     * </ul>
     *
     * @param abstractModule the module to convert
     * @param classTag       class tag for {@code T}; not used by this method body
     * @param <T>            numeric type of the module
     * @return the converted module, or {@code abstractModule} itself when no
     *         conversion applies
     */
    public <T> AbstractModule<Activity, Activity, T> convert(AbstractModule<Activity, Activity, T> abstractModule, ClassTag<T> classTag) {
        if (abstractModule instanceof IRGraph) {
            // Raw type kept from the decompiled original.
            IRGraph irGraph = (IRGraph) abstractModule;
            return irGraph.isBuild() ? irGraph : irGraph.build();
        }
        // Short-circuit keeps the original evaluation order: the engine type is
        // only queried when the module is not already an MklDnnModule.
        if (abstractModule instanceof MklDnnModule || !isMklDnnEngine()) {
            return abstractModule;
        }

        AbstractModule<Activity, Activity, T> graph =
                abstractModule instanceof Graph ? abstractModule : abstractModule.toGraph(Nil$.MODULE$);
        if (!(graph instanceof StaticGraph)) {
            return abstractModule;
        }

        IRGraph<T> converted = ((StaticGraph) graph).toIRgraph();
        // Mirror the training/evaluation state of the source module.
        // (training2/evaluate2 are the names produced by the decompiler.)
        if (abstractModule.isTraining()) {
            converted.training2();
        } else {
            converted.evaluate2();
        }
        return converted;
    }

    /**
     * Converts a module as {@link #convert(AbstractModule, ClassTag)} does and
     * then, when {@code z} is true, requests quantization of the result.
     *
     * @param abstractModule the module to convert
     * @param z              whether the converted module should be quantized
     * @param classTag       class tag for {@code T}
     * @param tensorNumeric  numeric operations for {@code T}
     * @param <T>            numeric type of the module
     * @return the converted (and possibly quantized) module
     */
    public <T> AbstractModule<Activity, Activity, T> convert(AbstractModule<Activity, Activity, T> abstractModule, boolean z, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        AbstractModule<Activity, Activity, T> converted = convert(abstractModule, classTag);
        return getInt8ModelIfNeeded(converted, z, classTag, tensorNumeric);
    }

    /**
     * Coalesces an RDD to {@code Engine.nodeNumber()} partitions without shuffle.
     *
     * <p>The RDD is returned unchanged when its partition count already equals
     * the node number, or when {@code Engine.isMultiModels()} is true.
     *
     * @param rdd      the RDD to coalesce
     * @param classTag class tag for {@code T}; not used by this method body
     * @param <T>      element type of the RDD
     * @return {@code rdd} itself or the coalesced RDD
     */
    public <T> RDD<T> coalesce(RDD<T> rdd, ClassTag<T> classTag) {
        if (rdd.partitions().length == Engine$.MODULE$.nodeNumber() || Engine$.MODULE$.isMultiModels()) {
            return rdd;
        }
        int nodeNumber = Engine$.MODULE$.nodeNumber();
        // The remaining two arguments are the Scala default arguments of
        // RDD.coalesce, obtained through their compiler-generated accessors.
        Option partitionCoalescer = rdd.coalesce$default$3();
        return rdd.coalesce(nodeNumber, false, partitionCoalescer, rdd.coalesce$default$4(nodeNumber, false, partitionCoalescer));
    }

    /**
     * Returns the module with quantization requested when {@code z} is true,
     * or the module itself when {@code z} is false.
     *
     * <p>With {@code z} true:
     * <ul>
     *   <li>{@link IRGraph}: {@code setQuantize(true)} is called on the module
     *       itself.</li>
     *   <li>{@link DnnGraph} / {@link MklDnnContainer}: {@code setQuantize(true)}
     *       is called on a clone obtained from {@code cloneModule()}.</li>
     *   <li>anything else: delegated to {@code Quantization.quantize}.</li>
     * </ul>
     *
     * <p>Raw types are kept from the decompiled original (JADX reported
     * "Multi-variable type inference failed" for this method).
     */
    private <T> AbstractModule<Activity, Activity, T> getInt8ModelIfNeeded(AbstractModule<Activity, Activity, T> abstractModule, boolean z, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        if (!z) {
            // Every branch below hands back the input untouched in this case.
            return abstractModule;
        }
        if (abstractModule instanceof IRGraph) {
            IRGraph irGraph = (IRGraph) abstractModule;
            return irGraph.setQuantize(true);
        }
        if (abstractModule instanceof DnnGraph) {
            DnnGraph dnnGraph = (DnnGraph) abstractModule;
            return ((DnnGraph) dnnGraph.cloneModule()).setQuantize(true);
        }
        if (abstractModule instanceof MklDnnContainer) {
            Object container = (MklDnnContainer) abstractModule;
            return (AbstractModule) ((MklDnnContainer) ((AbstractModule) container).cloneModule()).setQuantize(true);
        }
        return Quantization$.MODULE$.quantize(abstractModule, classTag, tensorNumeric);
    }

    /**
     * Null-safe check of whether {@code Engine.getEngineType()} equals the
     * {@code MklDnn} engine type object.
     */
    private boolean isMklDnnEngine() {
        EngineType engineType = Engine$.MODULE$.getEngineType();
        MklDnn$ mklDnn = MklDnn$.MODULE$;
        return engineType != null ? engineType.equals(mklDnn) : mklDnn == null;
    }
}
