package com.intel.analytics.bigdl.dllib.keras;

import com.intel.analytics.bigdl.dllib.common.zooUtils$;
import com.intel.analytics.bigdl.dllib.nn.StaticGraph;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import com.intel.analytics.bigdl.dllib.utils.Node;
import com.intel.analytics.bigdl.dllib.utils.python.api.PythonBigDL;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.InputStreamReader;
import java.nio.file.Path;
import org.apache.commons.io.FileUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.spark.bigdl.api.python.BigDLSerDe$;
import scala.Array$;
import scala.Predef$;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;

/* compiled from: Net.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/keras/Net$NetSaver$.class */
public class Net$NetSaver$ {
    // Singleton instance. The constructor below assigns `this` to it; the `= null`
    // initializer together with `final` is how the decompiler rendered that.
    public static final Net$NetSaver$ MODULE$ = null;
    // Logger obtained from LogManager for this class in the constructor.
    private final Logger logger;
    // Python preamble (Keras imports plus the load_to_numpy helper) built in the constructor.
    private final String header;
    // Extra Python imports (export_tf, Keras backend, tensorflow) built in the constructor.
    private final String tfHeader;

    // Instantiates the singleton once when the class is initialized; the
    // constructor is what stores the instance into MODULE$.
    static {
        new Net$NetSaver$();
    }

    /** Returns the logger created for this class in the constructor. */
    private Logger logger() {
        return logger;
    }

    /**
     * Returns the Python preamble that {@code save} writes at the top of every generated script.
     */
    public String header() {
        return header;
    }

    /**
     * Returns the additional Python imports that {@code saveToTf} prepends to its export statement.
     */
    public String tfHeader() {
        return tfHeader;
    }

    /**
     * Generates a Python script that rebuilds {@code abstractModule}, runs it with the given
     * command and removes the temporary working directory afterwards.
     *
     * <p>The script is made of {@link #header()}, the exported layer definitions, the text
     * returned by {@code saveWeights}, and finally {@code str3}. The script writer is always
     * closed, and the temporary directory is removed even when generation or execution fails.
     *
     * @param abstractModule module to export; anything other than {@code Sequential} or
     *                       {@code Model} is reported through {@code Log4Error$}
     * @param str            not read by this method
     * @param str2           command used to run the script; the script path is appended after a space
     * @param str3           trailing Python code appended verbatim to the script
     * @param classTag       class tag for {@code T}, forwarded to the export/weight helpers
     * @param tensorNumeric  numeric helper for {@code T}, forwarded to {@code saveWeights}
     */
    public <T> void save(AbstractModule<Activity, Activity, T> abstractModule, String str, String str2, String str3, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Path tmpDir = zooUtils$.MODULE$.createTmpDir("ZooKeras", zooUtils$.MODULE$.createTmpDir$default$2());
        boolean completed = false;
        try {
            logger().info("Write model's temp file to " + tmpDir);
            String scriptPath = tmpDir.toString() + "/" + abstractModule.getName() + ".py";

            // try-with-resources: the writer is flushed and closed even if an export step throws.
            try (BufferedWriter script = new BufferedWriter(new FileWriter(scriptPath))) {
                script.write(header());
                if (abstractModule instanceof Sequential) {
                    export((Sequential) abstractModule, script, classTag);
                } else if (abstractModule instanceof Model) {
                    export((Model) abstractModule, script, classTag);
                } else {
                    Log4Error$.MODULE$.invalidOperationError(false, abstractModule.getClass().getName() + " is not supported.", "Only support Sequential and Model", Log4Error$.MODULE$.invalidOperationError$default$4());
                }
                script.write(saveWeights(abstractModule, tmpDir.toString(), classTag, tensorNumeric));
                script.write(str3);
            }

            execCommand(str2 + " " + scriptPath);
            completed = true;
        } finally {
            File tmpDirFile = tmpDir.toFile();
            if (completed) {
                // Success path: a cleanup failure is still reported to the caller.
                FileUtils.deleteDirectory(tmpDirFile);
            } else {
                // Failure path: clean up best-effort so the original exception is not masked.
                FileUtils.deleteQuietly(tmpDirFile);
            }
        }
    }

    /**
     * Exports the module through {@code save}, using a trailing script made of
     * {@link #tfHeader()} followed by an {@code export_tf(...)} call that targets {@code str}.
     *
     * <p>NOTE(review): {@code str} is placed inside a single-quoted Python literal without
     * escaping; paths containing quotes or backslashes look unsafe. Confirm with callers.
     *
     * @param abstractModule module to export
     * @param str            destination passed to {@code export_tf}
     * @param str2           command used to run the generated script
     * @param classTag       class tag for {@code T}
     * @param tensorNumeric  numeric helper for {@code T}
     */
    public <T> void saveToTf(AbstractModule<Activity, Activity, T> abstractModule, String str, String str2, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        String preamble = tfHeader();
        String exportCall = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"export_tf(K.get_session(), '", "', model.inputs, model.outputs)\\n"}))
                .s(Predef$.MODULE$.genericWrapArray(new Object[]{str}));
        String trailer = new StringBuilder().append(preamble).append(exportCall).toString();
        save(abstractModule, str, str2, trailer, classTag, tensorNumeric);
    }

    /**
     * Exports the module through {@code save}, using a trailing {@code model.save('<str>')}
     * statement as the final line of the generated script.
     *
     * <p>NOTE(review): {@code str} is placed inside a single-quoted Python literal without
     * escaping; paths containing quotes or backslashes look unsafe. Confirm with callers.
     *
     * @param abstractModule module to export
     * @param str            destination passed to {@code model.save}
     * @param str2           command used to run the generated script
     * @param classTag       class tag for {@code T}
     * @param tensorNumeric  numeric helper for {@code T}
     */
    public <T> void saveToKeras2(AbstractModule<Activity, Activity, T> abstractModule, String str, String str2, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        String saveStatement = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"model.save('", "')\\n"}))
                .s(Predef$.MODULE$.genericWrapArray(new Object[]{str}));
        save(abstractModule, str, str2, saveStatement, classTag, tensorNumeric);
    }

    /**
     * Runs {@code str} as an external command and reports a failure through
     * {@code Log4Error$} when the process exits with a non-zero status.
     *
     * <p>The command is split on whitespace exactly as {@code Runtime.exec(String)} does.
     * The child's stdout is merged into stderr and the combined stream is read to the end
     * <em>before</em> waiting for the exit status. Waiting first can block forever once the
     * child fills a pipe buffer nobody is reading. As a consequence the failure message
     * contains both stdout and stderr of the child.
     *
     * @param str command line to execute
     * @throws IllegalArgumentException if {@code str} is empty
     */
    public void execCommand(String str) {
        if (str.isEmpty()) {
            // Mirrors the check Runtime.exec(String) performs.
            throw new IllegalArgumentException("Empty command");
        }
        java.util.StringTokenizer tokenizer = new java.util.StringTokenizer(str);
        String[] command = new String[tokenizer.countTokens()];
        for (int i = 0; tokenizer.hasMoreTokens(); i++) {
            command[i] = tokenizer.nextToken();
        }

        Process process = new ProcessBuilder(command).redirectErrorStream(true).start();

        // Drain the merged output first so the child can never stall on a full pipe.
        java.lang.StringBuilder output = new java.lang.StringBuilder();
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()))) {
            String line;
            while ((line = reader.readLine()) != null) {
                output.append(line).append("\n");
            }
        }

        if (process.waitFor() != 0) {
            Log4Error$.MODULE$.unKnowExceptionError(false, "Export Keras2 model failed:\n" + output, Log4Error$.MODULE$.unKnowExceptionError$default$3(), Log4Error$.MODULE$.unKnowExceptionError$default$4());
        }
    }

    /**
     * Writes the Python code for a graph model: every node of the sorted forward execution
     * order is handed to the per-node export callback, then a
     * {@code <name> = Model(inputs=[...], outputs=[...])} statement is written.
     *
     * @param model          model whose graph is exported
     * @param bufferedWriter destination of the generated Python code
     * @param classTag       class tag for {@code T}
     */
    public <T> void export(Model<T> model, BufferedWriter bufferedWriter, ClassTag<T> classTag) {
        Seq<Node<AbstractModule<Activity, Activity, T>>> inputNodes = model.getInputs();
        Seq<Node<AbstractModule<Activity, Activity, T>>> outputNodes = model.getOutputs();

        // Layer statements go first so the names used by Model(...) are already defined.
        Predef$.MODULE$.refArrayOps(((StaticGraph) model.labor()).getSortedForwardExecutions()).foreach(new Net$NetSaver$$anonfun$export$1(bufferedWriter, classTag));

        String modelName = model.getName();
        String inputList = ((TraversableOnce) inputNodes.map(new Net$NetSaver$$anonfun$2(), Seq$.MODULE$.canBuildFrom())).mkString(", ");
        String inputsPart = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"", " = Model(inputs=[", "],"}))
                .s(Predef$.MODULE$.genericWrapArray(new Object[]{modelName, inputList}));

        String outputList = ((TraversableOnce) outputNodes.map(new Net$NetSaver$$anonfun$3(), Seq$.MODULE$.canBuildFrom())).mkString(", ");
        String outputsPart = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{" outputs=[", "])\\n"}))
                .s(Predef$.MODULE$.genericWrapArray(new Object[]{outputList}));

        bufferedWriter.write(new StringBuilder().append(inputsPart).append(outputsPart).toString());
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Writes one Python statement of the form {@code <name> = <keras2 definition><call>} for
     * a graph node and flushes the writer.
     *
     * <p>{@code <call>} is {@code (prev)} for a single predecessor, {@code ([p1, p2, ...])}
     * for several, and empty when the node has none. Elements that are not a {@code Net}
     * are reported through {@code Log4Error$} and nothing is written.
     *
     * @param node           graph node to export
     * @param bufferedWriter destination of the generated Python code
     * @param classTag       class tag for {@code T}
     */
    public <T> void export(Node<AbstractModule<Activity, Activity, T>> node, BufferedWriter bufferedWriter, ClassTag<T> classTag) {
        AbstractModule<Activity, Activity, T> layer = node.element();
        if (!(layer instanceof Net)) {
            Log4Error$.MODULE$.invalidOperationError(false, new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"Unsupported layer ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{layer.getName()})), "only support Net", Log4Error$.MODULE$.invalidOperationError$default$4());
            return;
        }

        Object layerName = layer.getName();
        Object kerasDefinition = ((Net) layer).toKeras2();

        // Build the "(...)" call suffix from the node's predecessors.
        int predecessorCount = node.prevNodes().length();
        String callSuffix;
        if (predecessorCount == 1) {
            callSuffix = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"(", ")"}))
                    .s(Predef$.MODULE$.genericWrapArray(new Object[]{((AbstractModule) ((Node) node.prevNodes().apply(0)).element()).getName()}));
        } else if (predecessorCount > 1) {
            callSuffix = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"([", "])"}))
                    .s(Predef$.MODULE$.genericWrapArray(new Object[]{((TraversableOnce) node.prevNodes().map(new Net$NetSaver$$anonfun$4(), Seq$.MODULE$.canBuildFrom())).mkString(", ")}));
        } else {
            callSuffix = "";
        }

        String statement = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"", " = ", "", "\\n"}))
                .s(Predef$.MODULE$.genericWrapArray(new Object[]{layerName, kerasDefinition, callSuffix}));
        bufferedWriter.write(statement);
        bufferedWriter.flush();
    }

    /**
     * Writes {@code <name> = Sequential(name='<name>')} and then passes each module of the
     * wrapped container (the first entry of {@code sequential.modules()}) to the per-layer
     * export callback.
     *
     * @param sequential     sequential model to export
     * @param bufferedWriter destination of the generated Python code
     * @param classTag       class tag for {@code T}
     */
    public <T> void export(Sequential<T> sequential, BufferedWriter bufferedWriter, ClassTag<T> classTag) {
        String assignment = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"", " = "}))
                .s(Predef$.MODULE$.genericWrapArray(new Object[]{sequential.getName()}));
        String constructorCall = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"Sequential(name='", "')\\n"}))
                .s(Predef$.MODULE$.genericWrapArray(new Object[]{sequential.getName()}));
        bufferedWriter.write(new StringBuilder().append(assignment).append(constructorCall).toString());

        ((com.intel.analytics.bigdl.dllib.nn.Sequential) sequential.modules().apply(0)).modules().foreach(new Net$NetSaver$$anonfun$export$2(sequential, bufferedWriter, classTag));
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Builds the Python snippet that restores the module's weights: one line per Keras
     * weight, a blank separator, then {@code <name>.set_weights([...])}.
     *
     * <p>Each weight is first mapped to a pair; the first mapping of the pairs yields the
     * lines joined by newlines, the second yields the comma-joined {@code set_weights}
     * arguments. NOTE(review): the mapping functions are defined outside this method;
     * presumably they write each tensor under {@code str} and emit a matching
     * {@code load_to_numpy} line. Confirm against {@code Net.scala}.
     *
     * @param abstractModule module whose Keras weights are exported; must be a {@code Net}
     * @param str            directory path handed to the weight-mapping function
     * @param classTag       class tag for {@code T}
     * @param tensorNumeric  numeric helper for {@code T}
     * @return the generated Python code
     */
    public <T> String saveWeights(AbstractModule<?, ?, T> abstractModule, String str, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        String moduleName = abstractModule.getName();

        Tuple2[] weightPairs = (Tuple2[]) Predef$.MODULE$.refArrayOps(((Net) abstractModule).getKerasWeights()).map(new Net$NetSaver$$anonfun$5(str, moduleName, IntRef.create(0)), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)));

        String perWeightLines = Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.refArrayOps(weightPairs).map(new Net$NetSaver$$anonfun$6(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class)))).mkString("\n");

        String weightArguments = Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.refArrayOps(weightPairs).map(new Net$NetSaver$$anonfun$7(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class)))).mkString(",");
        String setWeightsCall = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"", ".set_weights([", "])\\n"}))
                .s(Predef$.MODULE$.genericWrapArray(new Object[]{moduleName, weightArguments}));

        return new StringBuilder().append(perWeightLines).append("\n").append(setWeightsCall).toString();
    }

    /**
     * Returns a {@code File} for {@code str} if nothing exists at that path; otherwise tries
     * {@code str.0}, {@code str.1}, ... and returns the first candidate that does not exist.
     * The file itself is not created.
     *
     * @param str preferred file path
     * @return a file object whose path did not exist when it was checked
     */
    public File com$intel$analytics$bigdl$dllib$keras$Net$NetSaver$$getUniqueFile(String str) {
        File candidate = new File(str);
        for (int suffix = 0; candidate.exists(); suffix++) {
            candidate = new File(str + "." + suffix);
        }
        return candidate;
    }

    /**
     * Converts {@code tensor} to a JTensor, serializes it with {@code BigDLSerDe$.dumps} and
     * writes the resulting bytes to {@code file}.
     *
     * <p>The output stream is closed even when the write fails, so no file handle leaks.
     *
     * @param tensor        tensor to serialize
     * @param file          destination file
     * @param classTag      class tag for {@code T}
     * @param tensorNumeric numeric helper for {@code T}
     */
    public <T> void com$intel$analytics$bigdl$dllib$keras$Net$NetSaver$$saveToJTensor(Tensor<T> tensor, File file, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        byte[] serialized = BigDLSerDe$.MODULE$.dumps(new PythonBigDL(classTag, tensorNumeric).toJTensor(tensor));
        try (FileOutputStream out = new FileOutputStream(file)) {
            out.write(serialized);
            out.flush();
        }
    }

    public Net$NetSaver$() {
        MODULE$ = this;
        this.logger = LogManager.getLogger(getClass());
        this.header = new StringBuilder().append(new StringOps(Predef$.MODULE$.augmentString("\n        |from tensorflow.keras.models import Sequential, Model\n        |from tensorflow.keras.layers import *\n        |from pyspark.serializers import PickleSerializer\n        |\n        |def load_to_numpy(file):\n        |    in_file = open(file, \"rb\")\n        |    data = in_file.read()\n        |    in_file.close()\n        |    r=PickleSerializer().loads(data, encoding=\"bytes\")\n        |    return r.to_ndarray()\n      ")).stripMargin()).append("\n").toString();
        this.tfHeader = new StringBuilder().append(new StringOps(Predef$.MODULE$.augmentString("\n        |from zoo.util.tf import export_tf\n        |from tensorflow.keras import backend as K\n        |import tensorflow as tf\n      ")).stripMargin()).append("\n").toString();
    }
}
