package com.intel.analytics.bigdl.ppml.utils;

import com.intel.analytics.bigdl.dllib.keras.models.InternalOptimizerUtil$;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.dllib.utils.T$;
import com.intel.analytics.bigdl.dllib.utils.Table;
import com.intel.analytics.bigdl.ppml.common.Storage;
import com.intel.analytics.bigdl.ppml.generated.FGBoostServiceProto;
import com.intel.analytics.bigdl.ppml.generated.FlBaseProto;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import scala.Array$;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Iterable;
import scala.collection.Iterable$;
import scala.collection.IterableLike;
import scala.collection.JavaConversions$;
import scala.collection.JavaConverters$;
import scala.collection.MapLike;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.collection.immutable.Map;
import scala.collection.immutable.Map$;
import scala.collection.mutable.ArrayOps;
import scala.collection.mutable.Buffer$;
import scala.math.Numeric$IntIsIntegral$;
import scala.math.Ordering$String$;
import scala.math.package$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;
import scala.util.Random;

/* compiled from: ProtoUtils.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/ppml/utils/ProtoUtils$.class */
/**
 * Decompiled Java view of the Scala singleton {@code object ProtoUtils} (see the "compiled from:
 * ProtoUtils.scala" marker above).
 *
 * <p>Helpers that convert between BigDL {@link Tensor}s / model parameters and the protobuf
 * messages {@code FlBaseProto.FloatTensor}, {@code FlBaseProto.TensorMap} and
 * {@code FGBoostServiceProto.BoostEval}. It also has a few small utilities: a seeded random array
 * split and an approximate float comparison.
 *
 * <p>The single instance is {@link #MODULE$}, created by the static initializer.
 *
 * <p>NOTE(review): this looks like decompiler output rather than hand-written source. Several
 * constructs are not valid Java as written:
 * <ul>
 *   <li>the {@code "target" != 0} comparison in {@code $anonfun$tableProtoToOutputTarget$1};
 *   <li>the nested lambda in {@code protoTableMapToTensorIterableMap} that re-declares
 *       {@code tuple22}.
 * </ul>
 * The {@code mNNNbuild()} method names appear to be decompiler renames of the protobuf builders'
 * {@code build()}. Presumably changes should be made in ProtoUtils.scala instead — confirm.
 */
public final class ProtoUtils$ {
    /** The singleton instance; assigned in the private constructor. */
    public static ProtoUtils$ MODULE$;
    /** Log4j logger for this class; not referenced by any method in this class. */
    private final Logger logger;

    static {
        // Constructing the instance is enough: the constructor stores itself into MODULE$.
        new ProtoUtils$();
    }

    /** Scala-style accessor for the {@code logger} field. */
    private Logger logger() {
        return this.logger;
    }

    /**
     * Packs a model output, and optionally its target, into a {@code TensorMap}.
     *
     * <p>The output is always stored under key {@code "output"}. The target, when non-null, is
     * stored under {@code "target"}. Both are converted with {@code toTensor} (float numeric) and
     * then {@link #toFloatTensor(Tensor)}.
     *
     * @param activity output activity, converted via {@code toTensor}
     * @param activity2 target activity, or {@code null} to omit the {@code "target"} entry
     * @param metaData metadata to attach, or {@code null} to leave it unset (the Scala default,
     *     see {@code outputTargetToTableProto$default$3})
     * @return the built {@code TensorMap}
     */
    public FlBaseProto.TensorMap outputTargetToTableProto(Activity activity, Activity activity2, FlBaseProto.MetaData metaData) {
        FlBaseProto.TensorMap.Builder putTensors = FlBaseProto.TensorMap.newBuilder().putTensors("output", toFloatTensor(activity.toTensor(TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)));
        if (metaData != null) {
            putTensors.setMetaData(metaData);
        } else {
            // The BoxedUnit assignments are decompiler artifacts of a Scala `if` with no else branch.
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        }
        if (activity2 != null) {
            putTensors.putTensors("target", toFloatTensor(activity2.toTensor(TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)));
        } else {
            BoxedUnit boxedUnit2 = BoxedUnit.UNIT;
        }
        return putTensors.m1062build();
    }

    /** Scala default-argument accessor for {@code metaData}: {@code null}, i.e. no metadata. */
    public FlBaseProto.MetaData outputTargetToTableProto$default$3() {
        return null;
    }

    /**
     * Unpacks the {@code TensorMap}s in {@code storage.clientData} into a table of non-target
     * tensors plus a single target tensor.
     *
     * <p>Steps:
     * <ol>
     *   <li>All tensors are grouped by name via {@link #protoTableMapToTensorIterableMap}.
     *   <li>If a {@code "target"} group exists, the first tensor of that group is copied into a
     *       fresh float tensor. Otherwise the returned target is an empty float tensor.
     *   <li>After the {@code "target"} group is dropped, exactly one name must remain. Its tensors
     *       become the returned table.
     * </ol>
     *
     * <p>NOTE(review): which client's tensor is "first" in the target group depends on the
     * iteration order of the grouped collections and appears unspecified — confirm.
     *
     * @param storage storage whose {@code clientData} maps keys (presumably client ids — confirm)
     *     to {@code TensorMap}s
     * @return (table of the remaining tensors, target tensor)
     * @throws IllegalArgumentException from Scala {@code require} if, once {@code "target"} is
     *     removed, the number of distinct tensor names is not exactly 1
     */
    public Tuple2<Table, Tensor<Object>> tableProtoToOutputTarget(Storage<FlBaseProto.TensorMap> storage) {
        Map<String, Iterable<Tensor<Object>>> protoTableMapToTensorIterableMap = protoTableMapToTensorIterableMap(storage.clientData);
        Tensor apply = Tensor$.MODULE$.apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
        if (protoTableMapToTensorIterableMap.contains("target")) {
            Tensor tensor = (Tensor) ((IterableLike) protoTableMapToTensorIterableMap.apply("target")).head();
            apply.resizeAs(tensor).copy(tensor);
        } else {
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        }
        // Keep every group whose name is not "target".
        Map map = (Map) protoTableMapToTensorIterableMap.filter(tuple2 -> {
            return BoxesRunTime.boxToBoolean($anonfun$tableProtoToOutputTarget$1(tuple2));
        });
        Predef$.MODULE$.require(map.size() == 1);
        return new Tuple2<>(T$.MODULE$.seq(((TraversableOnce) map.values().head()).toSeq()), apply);
    }

    /**
     * Flattens several {@code TensorMap}s into one map from tensor name to every tensor with that
     * name.
     *
     * <p>Every {@code (name, FloatTensor)} entry of every input {@code TensorMap} is collected and
     * grouped by name. Each {@code FloatTensor} is then turned into a BigDL float tensor. Its data
     * is the proto's tensor list and its shape is the proto's shape list.
     *
     * @param map input TensorMaps; only the values are used, the keys are ignored
     * @return name -> tensors with that name, one per input {@code TensorMap} that contained it
     */
    public Map<String, Iterable<Tensor<Object>>> protoTableMapToTensorIterableMap(java.util.Map<String, FlBaseProto.TensorMap> map) {
        return (Map) ((TraversableLike) ((MapLike) JavaConverters$.MODULE$.mapAsScalaMapConverter(map).asScala()).mapValues(tensorMap -> {
            return tensorMap.getTensorsMap();
        }).values().flatMap(map2 -> {
            return (scala.collection.mutable.Map) JavaConverters$.MODULE$.mapAsScalaMapConverter(map2).asScala();
        }, Iterable$.MODULE$.canBuildFrom())).groupBy(tuple2 -> {
            return (String) tuple2._1();
        }).map(tuple22 -> {
            return new Tuple2(tuple22._1(), ((TraversableLike) tuple22._2()).map(tuple22 -> {
                return Tensor$.MODULE$.apply((float[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) ((TraversableOnce) JavaConverters$.MODULE$.asScalaBufferConverter(((FlBaseProto.FloatTensor) tuple22._2()).getTensorList()).asScala()).toArray(ClassTag$.MODULE$.apply(Float.class)))).map(f -> {
                    return BoxesRunTime.boxToFloat($anonfun$protoTableMapToTensorIterableMap$6(f));
                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Float())), (int[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) ((TraversableOnce) JavaConverters$.MODULE$.asScalaBufferConverter(((FlBaseProto.FloatTensor) tuple22._2()).getShapeList()).asScala()).toArray(ClassTag$.MODULE$.apply(Integer.class)))).map(num -> {
                    return BoxesRunTime.boxToInteger($anonfun$protoTableMapToTensorIterableMap$7(num));
                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int())), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
            }, Iterable$.MODULE$.canBuildFrom()));
        }, Map$.MODULE$.canBuildFrom());
    }

    /**
     * Builds a {@code FloatTensor} from raw data and an explicit shape.
     *
     * @param fArr tensor data; each value is boxed into the proto's tensor list
     * @param iArr shape; each value is boxed into the proto's shape list
     * @return the built {@code FloatTensor}
     */
    public FlBaseProto.FloatTensor toFloatTensor(float[] fArr, int[] iArr) {
        return FlBaseProto.FloatTensor.newBuilder().addAllTensor((Iterable) JavaConverters$.MODULE$.asJavaIterableConverter(new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofFloat(Predef$.MODULE$.floatArrayOps(fArr)).map(obj -> {
            return $anonfun$toFloatTensor$1(BoxesRunTime.unboxToFloat(obj));
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Float.class))))).toIterable()).asJava()).addAllShape((Iterable) JavaConverters$.MODULE$.asJavaIterableConverter(new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps(iArr)).map(obj2 -> {
            return $anonfun$toFloatTensor$2(BoxesRunTime.unboxToInt(obj2));
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Integer.class))))).toIterable()).asJava()).m966build();
    }

    /**
     * Serializes a BigDL tensor into a {@code FloatTensor}.
     *
     * <p>The data is taken from the backing array of {@code tensor.contiguous().storage()}. It is
     * sliced to {@code [storageOffset() - 1, storageOffset() - 1 + nElement())}; the {@code - 1}
     * suggests 1-based storage offsets. The shape is {@code tensor.size()}.
     *
     * <p>NOTE(review): the slice bounds use the ORIGINAL tensor's {@code storageOffset()}, not the
     * contiguous tensor's. If {@code contiguous()} returns a new copy for a non-contiguous input,
     * those bounds may not match the copy — confirm against BigDL's {@code contiguous()} contract.
     *
     * @param tensor tensor to serialize, read through the float numeric
     * @return the built {@code FloatTensor}
     */
    public FlBaseProto.FloatTensor toFloatTensor(Tensor<Object> tensor) {
        return FlBaseProto.FloatTensor.newBuilder().addAllTensor((Iterable) JavaConverters$.MODULE$.asJavaIterableConverter(new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofFloat(Predef$.MODULE$.floatArrayOps((float[]) new ArrayOps.ofFloat(Predef$.MODULE$.floatArrayOps((float[]) tensor.toTensor(TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$).contiguous().storage().array())).slice(tensor.storageOffset() - 1, (tensor.storageOffset() - 1) + tensor.nElement()))).map(obj -> {
            return $anonfun$toFloatTensor$3(BoxesRunTime.unboxToFloat(obj));
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Float.class))))).toIterable()).asJava()).addAllShape((Iterable) JavaConverters$.MODULE$.asJavaIterableConverter(new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps(tensor.toTensor(TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$).size())).map(obj2 -> {
            return $anonfun$toFloatTensor$4(BoxesRunTime.unboxToInt(obj2));
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Integer.class))))).toIterable()).asJava()).m966build();
    }

    /**
     * Builds a one-dimensional {@code FloatTensor} whose shape is {@code [fArr.length]}.
     *
     * @param fArr tensor data
     * @return the built {@code FloatTensor}
     */
    public FlBaseProto.FloatTensor toFloatTensor(float[] fArr) {
        return toFloatTensor(fArr, new int[]{fArr.length});
    }

    /**
     * Serializes a model's flattened parameters into a {@code TensorMap}.
     *
     * <p>The parameter tensor comes from
     * {@code InternalOptimizerUtil.getParametersFromModel(...)._1}. Its whole storage is written
     * under key {@code "weights"} with shape {@code tensor.size()}. Metadata with
     * {@code name = str} and {@code version = i} is attached.
     *
     * <p>NOTE(review): unlike {@link #toFloatTensor(Tensor)}, this writes the entire storage rather
     * than slicing it by storage offset and element count. Presumably the flattened parameters
     * exactly fill their storage — confirm.
     *
     * @param abstractModule model whose parameters are exported
     * @param i version number stored in the metadata
     * @param str model name stored in the metadata (Scala default {@code "test"})
     * @return the built {@code TensorMap}
     */
    public FlBaseProto.TensorMap getModelWeightTable(AbstractModule<Activity, Activity, Object> abstractModule, int i, String str) {
        Tensor tensor = (Tensor) InternalOptimizerUtil$.MODULE$.getParametersFromModel(abstractModule, ClassTag$.MODULE$.Float())._1();
        return FlBaseProto.TensorMap.newBuilder().putTensors("weights", FlBaseProto.FloatTensor.newBuilder().addAllTensor(JavaConversions$.MODULE$.deprecated$u0020seqAsJavaList((Seq) tensor.storage().toList().map(obj -> {
            return $anonfun$getModelWeightTable$1(BoxesRunTime.unboxToFloat(obj));
        }, List$.MODULE$.canBuildFrom()))).addAllShape(JavaConversions$.MODULE$.deprecated$u0020seqAsJavaList((Seq) new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps(tensor.size())).toList().map(obj2 -> {
            return $anonfun$getModelWeightTable$2(BoxesRunTime.unboxToInt(obj2));
        }, List$.MODULE$.canBuildFrom()))).m966build()).setMetaData(FlBaseProto.MetaData.newBuilder().setName(str).setVersion(i).m1013build()).m1062build();
    }

    /** Scala default-argument accessor for the model name: {@code "test"}. */
    public String getModelWeightTable$default$3() {
        return "test";
    }

    /**
     * Overwrites a model's flattened parameters with the {@code "weights"} entry of a
     * {@code TensorMap}.
     *
     * <p>The entry's tensor list and shape list are turned into a float tensor. That tensor is
     * copied into the tensor returned by {@code InternalOptimizerUtil.getParametersFromModel(...)._1}.
     *
     * <p>NOTE(review): if {@code "weights"} is missing, {@code get} returns {@code null} and the
     * following {@code getTensorList()} call throws a {@code NullPointerException}.
     *
     * @param abstractModule model whose parameters are updated in place
     * @param tensorMap message carrying the new weights under {@code "weights"}
     */
    public void updateModel(AbstractModule<Activity, Activity, Object> abstractModule, FlBaseProto.TensorMap tensorMap) {
        FlBaseProto.FloatTensor floatTensor = tensorMap.getTensorsMap().get("weights");
        ((Tensor) InternalOptimizerUtil$.MODULE$.getParametersFromModel(abstractModule, ClassTag$.MODULE$.Float())._1()).copy(Tensor$.MODULE$.apply$mFc$sp((float[]) ((TraversableOnce) ((TraversableLike) JavaConverters$.MODULE$.asScalaBufferConverter(floatTensor.getTensorList()).asScala()).map(f -> {
            return BoxesRunTime.boxToFloat($anonfun$updateModel$1(f));
        }, Buffer$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.Float()), (int[]) ((TraversableOnce) ((TraversableLike) JavaConverters$.MODULE$.asScalaBufferConverter(floatTensor.getShapeList()).asScala()).map(num -> {
            return BoxesRunTime.boxToInteger($anonfun$updateModel$2(num));
        }, Buffer$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.Int()), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$));
    }

    /**
     * Reads the named entry of a {@code TensorMap} back into a BigDL float tensor. The proto's
     * tensor list supplies the data and its shape list supplies the shape.
     *
     * <p>NOTE(review): a missing key makes {@code get} return {@code null}, so this throws a
     * {@code NullPointerException} rather than a descriptive error.
     *
     * @param str key of the tensor to read
     * @param tensorMap message to read from
     * @return the deserialized tensor
     */
    public Tensor<Object> getTensor(String str, FlBaseProto.TensorMap tensorMap) {
        FlBaseProto.FloatTensor floatTensor = tensorMap.getTensorsMap().get(str);
        return Tensor$.MODULE$.apply((float[]) ((TraversableOnce) ((TraversableLike) JavaConverters$.MODULE$.asScalaBufferConverter(floatTensor.getTensorList()).asScala()).map(f -> {
            return BoxesRunTime.boxToFloat($anonfun$getTensor$1(f));
        }, Buffer$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.Float()), (int[]) ((TraversableOnce) ((TraversableLike) JavaConverters$.MODULE$.asScalaBufferConverter(floatTensor.getShapeList()).asScala()).map(num -> {
            return BoxesRunTime.boxToInteger($anonfun$getTensor$2(num));
        }, Buffer$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.Int()), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
    }

    /**
     * Randomly partitions the elements of an array into buckets sized by the given fractions.
     *
     * <p>Bucket sizing: every bucket except the last gets {@code (int) (fArr[k] * length)} slots.
     * The last bucket gets whatever remains. As a result, the last fraction's value is ignored and
     * the bucket sizes always add up to the array length.
     *
     * <p>Placement: each element is assigned to a bucket chosen with a {@link Random} seeded by
     * {@code i}. If that bucket is full, the following buckets are tried in turn, wrapping around,
     * until one has room. Elements keep their relative input order within each bucket.
     *
     * <p>NOTE(review): {@code fArr} is not validated.
     * <ul>
     *   <li>If it is empty, {@code iArr[iArr.length - 1]} indexes out of bounds.
     *   <li>If the truncated sizes of the first buckets add up to more than the array length,
     *       the last size is negative and allocating that bucket should fail.
     * </ul>
     *
     * @param fArr split fractions, one per bucket
     * @param obj the array to split (a JVM array whose element type matches {@code classTag})
     * @param i random seed (Scala default 1)
     * @param classTag element class tag, used to allocate the bucket arrays
     * @param <T> element type
     * @return an array of {@code fArr.length} arrays of element type {@code T}
     */
    public <T> Object[] randomSplit(float[] fArr, Object obj, int i, ClassTag<T> classTag) {
        Random random = new Random(i);
        // Bucket capacities: truncated fraction * length for each bucket...
        int[] iArr = (int[]) new ArrayOps.ofFloat(Predef$.MODULE$.floatArrayOps(fArr)).map(f -> {
            return (int) (f * ScalaRunTime$.MODULE$.array_length(obj));
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()));
        // ...except the last, which absorbs the remainder so capacities sum to the length.
        iArr[iArr.length - 1] = ScalaRunTime$.MODULE$.array_length(obj) - BoxesRunTime.unboxToInt(new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps((int[]) new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps(iArr)).slice(0, iArr.length - 1))).sum(Numeric$IntIsIntegral$.MODULE$));
        // One output array per bucket, sized to its capacity.
        Object[] objArr = (Object[]) new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps(iArr)).map(obj2 -> {
            return classTag.newArray(BoxesRunTime.unboxToInt(obj2));
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(classTag.runtimeClass()))));
        // Per-bucket fill counters, all starting at 0.
        int[] iArr2 = (int[]) new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps(iArr)).map(i2 -> {
            return 0;
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()));
        Predef$.MODULE$.genericArrayOps(obj).foreach(obj3 -> {
            $anonfun$randomSplit$4(random, fArr, iArr2, iArr, objArr, obj3);
            return BoxedUnit.UNIT;
        });
        return objArr;
    }

    /** Scala default-argument accessor for the {@code randomSplit} seed: 1. */
    public <T> int randomSplit$default$3() {
        return 1;
    }

    /**
     * Converts maps of per-tree predictions into {@code BoostEval} messages.
     *
     * <p>Each map in {@code mapArr} becomes one {@code BoostEval}. The map's entries are sorted by
     * key, and each entry becomes a {@code TreePredict} with {@code treeID = key} and
     * {@code predicts = value}.
     *
     * @param mapArr one map per {@code BoostEval}; keys are used as tree IDs
     * @return the {@code BoostEval}s, in the same order as {@code mapArr}
     */
    public List<FGBoostServiceProto.BoostEval> toBoostEvals(Map<String, boolean[]>[] mapArr) {
        return new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(mapArr)).map(map -> {
            return FGBoostServiceProto.BoostEval.newBuilder().addAllEvaluates((Iterable) JavaConverters$.MODULE$.seqAsJavaListConverter(((TraversableOnce) ((TraversableLike) map.toSeq().sortBy(tuple2 -> {
                return (String) tuple2._1();
            }, Ordering$String$.MODULE$)).map(tuple22 -> {
                return FGBoostServiceProto.TreePredict.newBuilder().setTreeID((String) tuple22._1()).addAllPredicts((Iterable) JavaConverters$.MODULE$.seqAsJavaListConverter(new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofBoolean(Predef$.MODULE$.booleanArrayOps((boolean[]) tuple22._2())).map(obj -> {
                    return $anonfun$toBoostEvals$4(BoxesRunTime.unboxToBoolean(obj));
                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Boolean.class))))).toList()).asJava()).m730build();
            }, Seq$.MODULE$.canBuildFrom())).toList()).asJava()).m72build();
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(FGBoostServiceProto.BoostEval.class))))).toList();
    }

    /**
     * Extracts the data of the {@code "predictResult"} tensor from a prediction response as a
     * {@code float[]}. The tensor's shape list is ignored.
     *
     * <p>NOTE(review): throws {@code NullPointerException} if the response's data has no
     * {@code "predictResult"} entry.
     *
     * @param predictResponse response to read
     * @return the flat prediction values
     */
    public float[] toArrayFloat(FGBoostServiceProto.PredictResponse predictResponse) {
        return (float[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) ((TraversableOnce) JavaConverters$.MODULE$.asScalaBufferConverter(predictResponse.getData().getTensorsMap().get("predictResult").getTensorList()).asScala()).toArray(ClassTag$.MODULE$.apply(Float.class)))).map(f -> {
            return BoxesRunTime.boxToFloat($anonfun$toArrayFloat$1(f));
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Float()));
    }

    /**
     * Returns true when {@code |f - f2| <= 0.1}. The tolerance is a fixed absolute 0.1, not
     * relative to the magnitudes.
     */
    public boolean almostEqual(float f, float f2) {
        return package$.MODULE$.abs(f - f2) <= 0.1f;
    }

    /**
     * Filter predicate for {@code tableProtoToOutputTarget}: true when the entry's key is not
     * {@code "target"}.
     *
     * <p>The null branch ({@code "target" != 0}) appears to be the decompiler's rendering of Scala's
     * null-safe {@code !=}. That would yield true for a null key.
     */
    public static final /* synthetic */ boolean $anonfun$tableProtoToOutputTarget$1(Tuple2 tuple2) {
        Object _1 = tuple2._1();
        return _1 != null ? !_1.equals("target") : "target" != 0;
    }

    // The remaining small synthetic methods are compiled Scala lambda bodies that box or unbox
    // between primitive values and java.lang wrappers for the proto repeated fields.

    /** Lambda body: unboxes a proto {@code Float} to {@code float}. */
    public static final /* synthetic */ float $anonfun$protoTableMapToTensorIterableMap$6(Float f) {
        return Predef$.MODULE$.Float2float(f);
    }

    /** Lambda body: unboxes a proto {@code Integer} shape entry to {@code int}. */
    public static final /* synthetic */ int $anonfun$protoTableMapToTensorIterableMap$7(Integer num) {
        return Predef$.MODULE$.Integer2int(num);
    }

    /** Lambda body: boxes a {@code float} for the proto tensor list. */
    public static final /* synthetic */ Float $anonfun$toFloatTensor$1(float f) {
        return Predef$.MODULE$.float2Float(f);
    }

    /** Lambda body: boxes an {@code int} for the proto shape list. */
    public static final /* synthetic */ Integer $anonfun$toFloatTensor$2(int i) {
        return Predef$.MODULE$.int2Integer(i);
    }

    /** Lambda body: boxes a {@code float} for the proto tensor list. */
    public static final /* synthetic */ Float $anonfun$toFloatTensor$3(float f) {
        return Predef$.MODULE$.float2Float(f);
    }

    /** Lambda body: boxes an {@code int} for the proto shape list. */
    public static final /* synthetic */ Integer $anonfun$toFloatTensor$4(int i) {
        return Predef$.MODULE$.int2Integer(i);
    }

    /** Lambda body: boxes a {@code float} weight for the proto tensor list. */
    public static final /* synthetic */ Float $anonfun$getModelWeightTable$1(float f) {
        return Predef$.MODULE$.float2Float(f);
    }

    /** Lambda body: boxes an {@code int} for the proto shape list. */
    public static final /* synthetic */ Integer $anonfun$getModelWeightTable$2(int i) {
        return Predef$.MODULE$.int2Integer(i);
    }

    /** Lambda body: unboxes a proto {@code Float} weight to {@code float}. */
    public static final /* synthetic */ float $anonfun$updateModel$1(Float f) {
        return Predef$.MODULE$.Float2float(f);
    }

    /** Lambda body: unboxes a proto {@code Integer} shape entry to {@code int}. */
    public static final /* synthetic */ int $anonfun$updateModel$2(Integer num) {
        return Predef$.MODULE$.Integer2int(num);
    }

    /** Lambda body: unboxes a proto {@code Float} to {@code float}. */
    public static final /* synthetic */ float $anonfun$getTensor$1(Float f) {
        return Predef$.MODULE$.Float2float(f);
    }

    /** Lambda body: unboxes a proto {@code Integer} shape entry to {@code int}. */
    public static final /* synthetic */ int $anonfun$getTensor$2(Integer num) {
        return Predef$.MODULE$.Integer2int(num);
    }

    /**
     * Per-element body of {@code randomSplit}: places {@code obj} into a random bucket that still
     * has room.
     *
     * @param random seeded generator shared across all elements
     * @param fArr split fractions; only its length (the bucket count) is used here
     * @param iArr number of elements already placed in each bucket (incremented in place)
     * @param iArr2 capacity of each bucket
     * @param objArr the bucket arrays being filled
     * @param obj element to place
     */
    public static final /* synthetic */ void $anonfun$randomSplit$4(Random random, float[] fArr, int[] iArr, int[] iArr2, Object[] objArr, Object obj) {
        int nextInt = random.nextInt(fArr.length);
        // Try the random bucket first, then the following ones (wrapping) until one is not full.
        // randomSplit sizes the buckets so total capacity equals the element count, so a free
        // slot exists for every element.
        while (true) {
            int i = nextInt;
            if (iArr[i] != iArr2[i]) {
                ScalaRunTime$.MODULE$.array_update(objArr[i], iArr[i], obj);
                iArr[i] = iArr[i] + 1;
                return;
            }
            nextInt = (i + 1) % fArr.length;
        }
    }

    /** Lambda body: boxes a {@code boolean} prediction for the proto list. */
    public static final /* synthetic */ Boolean $anonfun$toBoostEvals$4(boolean z) {
        return Predef$.MODULE$.boolean2Boolean(z);
    }

    /** Lambda body: unboxes a proto {@code Float} prediction to {@code float}. */
    public static final /* synthetic */ float $anonfun$toArrayFloat$1(Float f) {
        return Predef$.MODULE$.Float2float(f);
    }

    /** Registers this instance as {@link #MODULE$}, then creates the class logger. */
    private ProtoUtils$() {
        MODULE$ = this;
        this.logger = LogManager.getLogger(getClass());
    }
}
