package com.intel.analytics.bigdl.orca.tfpark;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import org.json4s.DefaultFormats$;
import org.json4s.jackson.JsonMethods$;
import org.json4s.package$;
import org.tensorflow.DataType;
import org.tensorflow.Tensor;
import org.tensorflow.framework.GraphDef;
import scala.Array$;
import scala.MatchError;
import scala.Predef$;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.immutable.Set;
import scala.io.BufferedSource;
import scala.io.Codec$;
import scala.io.Source$;
import scala.math.Numeric$IntIsIntegral$;
import scala.reflect.ClassTag$;
import scala.reflect.ManifestFactory$;
import scala.reflect.io.Path;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: TFUtils.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/orca/tfpark/TFUtils$.class */
/**
 * Java view of the Scala {@code object TFUtils}.
 *
 * <p>It provides the following helpers:
 * <ul>
 *   <li>reading TFPark training metadata (JSON);</li>
 *   <li>parsing a serialized {@link GraphDef} from disk;</li>
 *   <li>copying a TensorFlow {@link Tensor} into a BigDL tensor;</li>
 *   <li>mapping between {@link DataType} and TensorFlow's integer dtype codes.</li>
 * </ul>
 */
public final class TFUtils$ {
    /** Singleton instance of the Scala object. */
    public static final TFUtils$ MODULE$ = new TFUtils$();

    private final SessionConfig defaultSessionConfig;

    /**
     * Returns the shared {@link SessionConfig}.
     *
     * <p>It is built once from {@code SessionConfig$}'s default constructor arguments.
     */
    public SessionConfig defaultSessionConfig() {
        return this.defaultSessionConfig;
    }

    /**
     * Reads a JSON file, camelizes its keys and extracts it as a {@link TrainMeta}.
     *
     * @param path location of the JSON metadata file (read with the system fallback codec)
     * @return the extracted training metadata
     */
    public TrainMeta getTrainMeta(Path path) {
        BufferedSource source = Source$.MODULE$.fromFile(path.jfile(), Codec$.MODULE$.fallbackSystemCodec());
        String json;
        try {
            // getLines() strips line terminators, so the lines are joined without separators.
            json = source.getLines().mkString();
        } finally {
            // Release the underlying file handle; it was previously never closed.
            source.close();
        }
        return (TrainMeta) package$.MODULE$.jvalue2extractable(
                package$.MODULE$.jvalue2monadic(
                        JsonMethods$.MODULE$.parse(
                                package$.MODULE$.string2JsonInput(json),
                                JsonMethods$.MODULE$.parse$default$2(),
                                JsonMethods$.MODULE$.parse$default$3())).camelizeKeys())
                .extract(DefaultFormats$.MODULE$, ManifestFactory$.MODULE$.classType(TrainMeta.class));
    }

    /**
     * Parses a binary-serialized {@link GraphDef} from a file.
     *
     * @param str path of the serialized graph file
     * @return the parsed graph definition
     * @throws IOException if the file cannot be opened, read or parsed
     */
    public GraphDef parseGraph(String str) throws IOException {
        // try-with-resources keeps the parse failure as the primary exception
        // if close() also fails (the close failure is attached as suppressed).
        try (FileInputStream in = new FileInputStream(new File(str))) {
            return GraphDef.parseFrom(in);
        }
    }

    /**
     * Decodes an unsigned base-128 varint starting at {@code offset}.
     *
     * <p>Each byte carries 7 data bits, least-significant group first. The high bit is set on
     * every byte except the last.
     *
     * @param buffer bytes to read from
     * @param offset index of the first varint byte
     * @return a tuple of (decoded value as boxed Long, index of the first byte after the varint).
     *     If the buffer ends, or more than 63 bits of shift are reached before a terminating
     *     byte, the partially decoded value and the current index are returned.
     */
    private Tuple2<Object, Object> decodeUVarInt64(byte[] buffer, int offset) {
        int pos = offset;
        long value = 0L;
        for (int shift = 0; shift <= 63 && pos < buffer.length; shift += 7) {
            byte b = buffer[pos];
            pos++;
            if ((b & 0x80) == 0) {
                // Widen to long before shifting: an int shift distance is masked to 5 bits,
                // which would corrupt any group at shift >= 32.
                return new Tuple2<Object, Object>(
                        BoxesRunTime.boxToLong(value | ((long) b << shift)),
                        BoxesRunTime.boxToInteger(pos));
            }
            value |= ((long) (b & 0x7F)) << shift;
        }
        return new Tuple2<Object, Object>(BoxesRunTime.boxToLong(value), BoxesRunTime.boxToInteger(pos));
    }

    /**
     * Reads the leading offset table of a string tensor's raw buffer.
     *
     * <p>The table holds {@code i} native-order 8-byte entries starting at the buffer's array
     * offset. Each entry is narrowed through {@code TFUtils$$anonfun$getOffsets$1}, which
     * presumably narrows it to {@code int} — TODO confirm.
     *
     * @param byteBuffer array-backed buffer holding the tensor's raw bytes
     * @param i number of table entries (the tensor's element count)
     * @return the decoded offsets
     */
    private int[] getOffsets(ByteBuffer byteBuffer, int i) {
        long[] jArr = new long[i];
        int from = byteBuffer.arrayOffset();
        // slice(from, until) takes an end index, not a length.
        byte[] table = (byte[]) Predef$.MODULE$.byteArrayOps(byteBuffer.array()).slice(from, from + i * 8);
        ByteBuffer.wrap(table).order(ByteOrder.nativeOrder()).asLongBuffer().get(jArr);
        return (int[]) Predef$.MODULE$.longArrayOps(jArr).map(new TFUtils$$anonfun$getOffsets$1(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()));
    }

    /**
     * Copies a TensorFlow tensor into a BigDL tensor.
     *
     * <p>The BigDL tensor is first resized to the TensorFlow tensor's shape. How each dtype is
     * handled:
     * <ul>
     *   <li>FLOAT is written directly into the float storage.</li>
     *   <li>UINT8 is read as unsigned and converted to float.</li>
     *   <li>INT32, INT64 and DOUBLE are converted to float.</li>
     *   <li>BOOL bytes are converted to float.</li>
     *   <li>STRING (rank 0 or 1 only) is copied element-wise into a {@code byte[][]} storage.</li>
     * </ul>
     *
     * @param tensor source TensorFlow tensor
     * @param tensor2 destination BigDL tensor. Its storage is assumed to be {@code float[]} for
     *     numeric/bool dtypes and {@code byte[][]} for STRING (casts below) — TODO confirm at
     *     call sites.
     * @throws IllegalArgumentException if the dtype is not one of the handled types
     */
    public void tf2bigdl(Tensor<?> tensor, com.intel.analytics.bigdl.dllib.tensor.Tensor<?> tensor2) {
        int[] shape = (int[]) Predef$.MODULE$.longArrayOps(tensor.shape()).map(new TFUtils$$anonfun$1(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()));
        tensor2.resize(shape, tensor2.resize$default$2());
        DataType dataType = tensor.dataType();
        if (dataType == DataType.STRING) {
            copyStrings(tensor, tensor2);
            return;
        }
        int size = product(shape);
        // BigDL storage offsets are 1-based; Java array indices are 0-based.
        int offset = tensor2.storageOffset() - 1;
        if (dataType == DataType.FLOAT) {
            tensor.writeTo(FloatBuffer.wrap((float[]) tensor2.storage().array(), offset, size));
        } else if (dataType == DataType.UINT8) {
            byte[] bytes = new byte[size];
            tensor.writeTo(ByteBuffer.wrap(bytes));
            // UINT8 must be read as unsigned; the signed byte2float would turn 128..255 negative.
            ubyte2float(bytes, (float[]) tensor2.storage().array(), offset);
        } else if (dataType == DataType.INT32) {
            int[] ints = new int[size];
            tensor.writeTo(IntBuffer.wrap(ints));
            int2float(ints, (float[]) tensor2.storage().array(), offset);
        } else if (dataType == DataType.INT64) {
            long[] longs = new long[size];
            tensor.writeTo(LongBuffer.wrap(longs));
            long2float(longs, (float[]) tensor2.storage().array(), offset);
        } else if (dataType == DataType.DOUBLE) {
            double[] doubles = new double[size];
            tensor.writeTo(DoubleBuffer.wrap(doubles));
            double2float(doubles, (float[]) tensor2.storage().array(), offset);
        } else if (dataType == DataType.BOOL) {
            byte[] bools = new byte[tensor.numBytes()];
            Predef$.MODULE$.assert(tensor.numBytes() == size, new TFUtils$$anonfun$tf2bigdl$2());
            tensor.writeTo(ByteBuffer.wrap(bools));
            byte2float(bools, (float[]) tensor2.storage().array(), offset);
        } else {
            throw new IllegalArgumentException("data type " + dataType + " are not supported");
        }
    }

    /**
     * Copies a rank-0/1 STRING tensor into a BigDL {@code byte[][]} storage.
     *
     * <p>Raw layout, as read here: a table of {@code numElements} 8-byte offsets. Each offset is
     * relative to the end of that table and points at a varint length followed by that many
     * bytes.
     */
    private void copyStrings(Tensor<?> tensor, com.intel.analytics.bigdl.dllib.tensor.Tensor<?> tensor2) {
        Predef$.MODULE$.require(tensor.numDimensions() <= 1, new TFUtils$$anonfun$tf2bigdl$1());
        int numElements = tensor.numElements();
        ByteBuffer raw = ByteBuffer.allocate(tensor.numBytes());
        tensor.writeTo(raw);
        int[] offsets = getOffsets(raw, numElements);
        byte[][] dest = (byte[][]) tensor2.storage().array();
        byte[] bytes = raw.array();
        int dataStart = numElements * 8;
        int destOffset = tensor2.storageOffset() - 1;
        for (int k = 0; k < numElements; k++) {
            Tuple2<Object, Object> decoded = decodeUVarInt64(bytes, raw.arrayOffset() + offsets[k] + dataStart);
            long length = BoxesRunTime.unboxToLong(decoded._1());
            int start = BoxesRunTime.unboxToInt(decoded._2());
            dest[destOffset + k] = (byte[]) Predef$.MODULE$.byteArrayOps(bytes).slice(start, start + (int) length);
        }
    }

    /** Multiplies the dimensions with int arithmetic; an empty shape (scalar) yields 1. */
    private static int product(int[] dims) {
        int result = 1;
        for (int d : dims) {
            result *= d;
        }
        return result;
    }

    /** Copies bytes, interpreted as unsigned (0..255), into {@code fArr} starting at {@code i}. */
    private static void ubyte2float(byte[] bArr, float[] fArr, int i) {
        for (int k = 0; k < bArr.length; k++) {
            fArr[i + k] = bArr[k] & 0xFF;
        }
    }

    /**
     * Copies bytes into {@code fArr} starting at {@code i}.
     *
     * <p>Bytes are interpreted as signed Java {@code byte} values (-128..127).
     */
    public void byte2float(byte[] bArr, float[] fArr, int i) {
        for (int k = 0; k < bArr.length; k++) {
            fArr[i + k] = bArr[k];
        }
    }

    /** Copies ints, widened to float, into {@code fArr} starting at index {@code i}. */
    public void int2float(int[] iArr, float[] fArr, int i) {
        for (int k = 0; k < iArr.length; k++) {
            fArr[i + k] = iArr[k];
        }
    }

    /** Copies longs, cast to float, into {@code fArr} starting at index {@code i}. */
    public void long2float(long[] jArr, float[] fArr, int i) {
        for (int k = 0; k < jArr.length; k++) {
            fArr[i + k] = (float) jArr[k];
        }
    }

    /** Copies doubles, narrowed to float, into {@code fArr} starting at index {@code i}. */
    public void double2float(double[] dArr, float[] fArr, int i) {
        for (int k = 0; k < dArr.length; k++) {
            fArr[i + k] = (float) dArr[k];
        }
    }

    /**
     * Maps a TensorFlow integer dtype code to a {@link DataType}.
     *
     * @param i dtype code (1, 2, 3, 4, 7, 9 or 10 are supported)
     * @return the matching data type
     * @throws IllegalArgumentException for any other code
     */
    public DataType tfenum2datatype(int i) {
        switch (i) {
            case 1:
                return DataType.FLOAT;
            case 2:
                return DataType.DOUBLE;
            case 3:
                return DataType.INT32;
            case 4:
                return DataType.UINT8;
            case 7:
                return DataType.STRING;
            case 9:
                return DataType.INT64;
            case 10:
                return DataType.BOOL;
            default:
                throw new IllegalArgumentException("unsupported tensorflow datatype " + i);
        }
    }

    /**
     * Maps a {@link DataType} to its TensorFlow integer dtype code.
     *
     * <p>This is the inverse of {@link #tfenum2datatype(int)}.
     *
     * @param dataType the data type (may be null, which is rejected)
     * @return the dtype code
     * @throws IllegalArgumentException for unsupported or null data types
     */
    public int tfdatatype2enum(DataType dataType) {
        if (dataType == DataType.FLOAT) {
            return 1;
        }
        if (dataType == DataType.DOUBLE) {
            return 2;
        }
        if (dataType == DataType.INT32) {
            return 3;
        }
        if (dataType == DataType.UINT8) {
            return 4;
        }
        if (dataType == DataType.STRING) {
            return 7;
        }
        if (dataType == DataType.INT64) {
            return 9;
        }
        if (dataType == DataType.BOOL) {
            return 10;
        }
        throw new IllegalArgumentException("unsupported tensorflow datatype " + dataType);
    }

    private TFUtils$() {
        this.defaultSessionConfig = new SessionConfig(SessionConfig$.MODULE$.apply$default$1(), SessionConfig$.MODULE$.apply$default$2(), SessionConfig$.MODULE$.apply$default$3());
    }
}
