package com.intel.analytics.bigdl.dllib.nn.mkldnn;

import com.intel.analytics.bigdl.dllib.utils.Shape;
import scala.Serializable;
import scala.reflect.ClassTag$;
import scala.runtime.AbstractFunction1;

/* compiled from: BlasWrapper.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/nn/mkldnn/BlasWrapper$$anonfun$inferOutputFormats$1.class */
/**
 * Scala closure used by {@code BlasWrapper.inferOutputFormats}: maps one output {@link Shape}
 * to a {@link HeapData} memory descriptor.
 *
 * <p>The field names below are compiler-generated and must stay unchanged. This class is
 * {@code Serializable} with a fixed {@code serialVersionUID}, so renaming them would break the
 * serialized form.
 */
public final class BlasWrapper$$anonfun$inferOutputFormats$1 extends AbstractFunction1<Shape, HeapData> implements Serializable {
    public static final long serialVersionUID = 0;

    /** Enclosing BlasWrapper; consulted for the default format of a given rank. */
    private final /* synthetic */ BlasWrapper $outer;

    /** Input memory descriptors captured from the enclosing call; only element 0 is read. */
    private final MemoryData[] inputs$1;

    /**
     * Builds the heap memory descriptor for a single output shape.
     *
     * <p>Format selection:
     * <ul>
     *   <li>If the shape has rank 4 and the first captured input reports heap format code 8,
     *       format code 8 is used. NOTE(review): 8 is presumably an NHWC-style layout code; confirm
     *       against {@code Memory.Format}.</li>
     *   <li>Otherwise the format is taken from the enclosing wrapper's {@code getFormats(rank)}.</li>
     * </ul>
     * When the chosen format is 8, the dimensions are reordered as {@code [d0, d3, d1, d2]} before
     * the descriptor is built. The descriptor's heap format is then set to the same code.
     *
     * <p>NOTE(review): the reorder indexes {@code d3} unconditionally. This relies on
     * {@code getFormats} never returning 8 for a rank other than 4; confirm.
     *
     * @param shape the output shape; flattened via {@code toSingle()} into an int array
     * @return a new {@link HeapData} describing {@code shape}
     */
    public final HeapData apply(Shape shape) {
        final int[] dims = (int[]) shape.toSingle().toArray(ClassTag$.MODULE$.Int());
        final int rank = dims.length;

        // Short-circuit: inputs$1[0] is only inspected for rank-4 shapes.
        final int format;
        if (rank == 4 && this.inputs$1[0].heapFormat() == 8) {
            format = 8;
        } else {
            format = this.$outer.com$intel$analytics$bigdl$dllib$nn$mkldnn$BlasWrapper$$getFormats(rank);
        }

        // Format 8 moves the last dimension into position 1; other formats keep the raw order.
        final int[] layout;
        if (format == 8) {
            layout = new int[]{dims[0], dims[3], dims[1], dims[2]};
        } else {
            layout = dims;
        }

        final HeapData descriptor = new HeapData(layout, format, HeapData$.MODULE$.apply$default$3());
        return (HeapData) descriptor.setHeapFormat(format);
    }

    /**
     * Captures the enclosing wrapper and its inputs.
     *
     * @param outer  the enclosing BlasWrapper; must not be null
     * @param inputs the input memory descriptors of the enclosing call
     * @throws NullPointerException if {@code outer} is null (Scala's outer-reference check)
     */
    public BlasWrapper$$anonfun$inferOutputFormats$1(BlasWrapper outer, MemoryData[] inputs) {
        // Scala emits `throw null` for a missing outer reference; this raises NullPointerException.
        if (outer == null) {
            throw null;
        }
        this.inputs$1 = inputs;
        this.$outer = outer;
    }
}
