package com.intel.analytics.bigdl.dllib.utils.serializer;

import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.File$;
import com.intel.analytics.bigdl.dllib.utils.FileWriter;
import com.intel.analytics.bigdl.dllib.utils.FileWriter$;
import com.intel.analytics.bigdl.serialization.Bigdl;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.security.DigestOutputStream;
import java.security.MessageDigest;
import java.util.ArrayList;
import java.util.List;
import scala.collection.IterableLike;
import scala.collection.JavaConverters$;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.HashMap;
import scala.collection.mutable.HashSet;
import scala.reflect.ClassTag;
import scala.runtime.ObjectRef;

/* compiled from: ModuleLoader.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/utils/serializer/ModulePersister$.class */
/**
 * Java form of the Scala {@code object ModulePersister} (compiled from ModuleLoader.scala).
 *
 * <p>Serializes BigDL modules to protobuf, either as a single file with the tensor storages
 * inlined, or as a module file plus a separate weight file. It can also write a weight-free,
 * human-readable model definition.
 */
public final class ModulePersister$ {
    /** The singleton instance backing the Scala {@code object}. */
    public static final ModulePersister$ MODULE$ = new ModulePersister$();

    /**
     * Serializes {@code abstractModule} and writes it to {@code str}.
     *
     * <p>When {@code str2} is {@code null}, the storages are embedded in the module proto
     * ({@link #setTensorStorage}) and a single file is written. Otherwise the module is serialized
     * with {@code BigDLStorage}. The module proto goes to {@code str`} and the filtered storages go
     * to the weight file {@code str2}.
     *
     * @param str            destination path of the module proto
     * @param str2           destination path of the weight file, or {@code null} to inline weights
     * @param abstractModule module to serialize
     * @param z              flag forwarded to {@code File.saveBytes} and {@code FileWriter.create}
     *                       (presumably "overwrite"; the Scala default is {@code false})
     * @param classTag       element class tag of the module's tensors
     * @param tensorNumeric  numeric type class for {@code T}
     */
    public <T> void saveToFile(String str, String str2, AbstractModule<Activity, Activity, T> abstractModule, boolean z, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        if (str2 == null) {
            // Single-file layout: the storages travel inside the module proto itself.
            SerializeResult result = serializeModule(abstractModule, ProtoStorageType$.MODULE$, classTag, tensorNumeric);
            setTensorStorage(result.bigDLModule(), result.storages());
            File$.MODULE$.saveBytes(result.bigDLModule().build().toByteArray(), str, z);
        } else {
            SerializeResult result = serializeModule(abstractModule, BigDLStorage$.MODULE$, classTag, tensorNumeric);
            // Which entries qualify as weights is decided by the compiled closure anonfun$1 (not visible here).
            HashMap<Object, Object> weights = (HashMap) result.storages().filter(new ModulePersister$$anonfun$1());
            File$.MODULE$.saveBytes(result.bigDLModule().build().toByteArray(), str, z);
            saveWeightsToFile(str2, weights, z);
        }
    }

    /** Scala default for {@code saveToFile}'s weight path: {@code null} (weights inlined). */
    public <T> String saveToFile$default$2() {
        return null;
    }

    /** Scala default for {@code saveToFile}'s {@code z} flag. */
    public <T> boolean saveToFile$default$4() {
        return false;
    }

    /**
     * Runs {@code ModuleSerializer.serialize} on a fresh context for {@code abstractModule}.
     *
     * <p>The context has empty pre/next module buffers, an empty storage map and the given
     * {@code storageType}.
     */
    private <T> SerializeResult serializeModule(AbstractModule<Activity, Activity, T> abstractModule, StorageType storageType, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        ModuleData moduleData = new ModuleData(abstractModule, new ArrayBuffer(), new ArrayBuffer(), classTag);
        SerializeContext context = new SerializeContext<>(moduleData, new HashMap(), storageType, SerializeContext$.MODULE$.apply$default$4(), SerializeContext$.MODULE$.apply$default$5(), classTag);
        return ModuleSerializer$.MODULE$.serialize(context, classTag, tensorNumeric);
    }

    /**
     * Writes {@code hashMap}'s storages to a standalone weight file at {@code str}.
     *
     * <p>Layout written, in order:
     * <ol>
     *   <li>int {@code MAGIC_NO}</li>
     *   <li>int entry count</li>
     *   <li>the entries, serialized by the compiled closure {@code anonfun$saveWeightsToFile$1}</li>
     *   <li>int digest length</li>
     *   <li>the digest bytes</li>
     * </ol>
     * The digest ({@code SerConst.DIGEST_TYPE}) covers everything written before it.
     */
    private void saveWeightsToFile(String str, HashMap<Object, Object> hashMap, boolean z) {
        int magicNo = SerConst$.MODULE$.MAGIC_NO();
        int total = hashMap.size();
        // Resolve the algorithm before touching the file system, so an unsupported digest type
        // fails without creating or truncating the target file.
        MessageDigest messageDigest = MessageDigest.getInstance(SerConst$.MODULE$.DIGEST_TYPE());
        FileWriter fileWriter = FileWriter$.MODULE$.apply(str);
        try {
            OutputStream outputStream = fileWriter.create(z);
            DigestOutputStream digestOutputStream = new DigestOutputStream(outputStream, messageDigest);
            // Only the outermost stream is managed. Closing it flushes and closes the whole chain
            // (DataOutputStream -> DigestOutputStream -> outputStream) in the right order, and a
            // close failure is added as suppressed instead of hiding the original exception.
            try (DataOutputStream dataOutputStream = new DataOutputStream(digestOutputStream)) {
                // The per-entry writer closure receives the stream wrapped in an ObjectRef.
                ObjectRef streamRef = ObjectRef.create(dataOutputStream);
                digestOutputStream.on(true);
                dataOutputStream.writeInt(magicNo);
                dataOutputStream.writeInt(total);
                hashMap.foreach(new ModulePersister$$anonfun$saveWeightsToFile$1(streamRef));
                // Stop hashing so that the digest trailer is not part of its own checksum.
                digestOutputStream.on(false);
                byte[] digest = messageDigest.digest();
                dataOutputStream.writeInt(digest.length);
                dataOutputStream.write(digest);
            }
        } finally {
            fileWriter.close();
        }
    }

    /** Scala default for {@code saveWeightsToFile}'s {@code z} flag. */
    private boolean saveWeightsToFile$default$3() {
        return false;
    }

    /**
     * Embeds the tensor storages into {@code builder} as a single attribute named
     * {@code SerConst.GLOBAL_STORAGE}.
     *
     * <p>The values of {@code hashMap} are filtered by {@code anonfun$setTensorStorage$1}. Each
     * selected value is handed to {@code anonfun$setTensorStorage$2}, which receives a de-duplication
     * set, a filtered copy of {@code hashMap} ({@code anonfun$2}) and the {@code NameAttrList}
     * builder. NOTE(review): the exact filter predicates live in those compiled closures.
     *
     * @param builder module builder that receives the {@code GLOBAL_STORAGE} attribute
     * @param hashMap storages collected during serialization
     */
    public void setTensorStorage(Bigdl.BigDLModule.Builder builder, HashMap<Object, Object> hashMap) {
        HashSet seen = new HashSet();
        HashMap filtered = (HashMap) hashMap.filter(new ModulePersister$$anonfun$2());
        ObjectRef attrListRef = ObjectRef.create(Bigdl.NameAttrList.newBuilder().setName(SerConst$.MODULE$.GLOBAL_STORAGE()));
        ((IterableLike) hashMap.values().filter(new ModulePersister$$anonfun$setTensorStorage$1())).foreach(new ModulePersister$$anonfun$setTensorStorage$2(seen, filtered, attrListRef));
        Bigdl.AttrValue.Builder attrValue = Bigdl.AttrValue.newBuilder();
        attrValue.setNameAttrListValue((Bigdl.NameAttrList.Builder) attrListRef.elem);
        builder.putAttr(SerConst$.MODULE$.GLOBAL_STORAGE(), attrValue.build());
    }

    /**
     * Writes the protobuf text form of {@code abstractModule} to {@code str}, without weights or
     * biases.
     *
     * <p>The text is encoded as UTF-8.
     *
     * @param str            destination path
     * @param abstractModule module to describe
     * @param z              flag forwarded to {@code File.saveBytes} (presumably "overwrite")
     * @param classTag       element class tag of the module's tensors
     * @param tensorNumeric  numeric type class for {@code T}
     */
    public <T> void saveModelDefinitionToFile(String str, AbstractModule<Activity, Activity, T> abstractModule, boolean z, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Bigdl.BigDLModule.Builder bigDLModule = serializeModule(abstractModule, ProtoStorageType$.MODULE$, classTag, tensorNumeric).bigDLModule();
        com$intel$analytics$bigdl$dllib$utils$serializer$ModulePersister$$cleantWeightAndBias(bigDLModule);
        // Explicit charset: String.getBytes() without one depends on the platform default.
        byte[] definition = bigDLModule.build().toString().getBytes(StandardCharsets.UTF_8);
        File$.MODULE$.saveBytes(definition, str, z);
    }

    /** Scala default for {@code saveModelDefinitionToFile}'s {@code z} flag. */
    public <T> boolean saveModelDefinitionToFile$default$3() {
        return false;
    }

    /**
     * Clears the weight and bias of {@code builder} and processes its sub-modules.
     *
     * <p>The sub-modules are removed from the builder and each one is handed to the compiled closure
     * {@code anonfun$...cleantWeightAndBias$1}. Presumably that closure cleans each sub-module
     * recursively and re-adds it to {@code builder}; verify this in the closure class.
     *
     * @param builder module builder to strip in place
     */
    public void com$intel$analytics$bigdl$dllib$utils$serializer$ModulePersister$$cleantWeightAndBias(Bigdl.BigDLModule.Builder builder) {
        builder.clearWeight();
        builder.clearBias();
        if (builder.getSubModulesCount() > 0) {
            // Snapshot before clearing. A protobuf builder's repeated-field list can be a live view of
            // its storage (e.g. when nested field builders are in use), and clearSubModules() would
            // then empty it before the loop below runs.
            List<Bigdl.BigDLModule> subModulesList = new ArrayList<>(builder.getSubModulesList());
            builder.clearSubModules();
            ((IterableLike) JavaConverters$.MODULE$.asScalaBufferConverter(subModulesList).asScala()).foreach(new ModulePersister$$anonfun$com$intel$analytics$bigdl$dllib$utils$serializer$ModulePersister$$cleantWeightAndBias$1(builder));
        }
    }

    private ModulePersister$() {
        // Singleton: the only instance is MODULE$.
    }
}
