package com.intel.analytics.bigdl.dllib.nn;

import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.TensorModule;
import com.intel.analytics.bigdl.dllib.optim.Regularizer;
import com.intel.analytics.bigdl.dllib.tensor.ConvertableFrom$ConvertableFromDouble$;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.package$;
import scala.Predef$;
import scala.Tuple2;
import scala.reflect.ClassTag;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.Null$;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;

/* compiled from: TransformerOperation.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/nn/TransformerOperation$.class */
/**
 * Helpers used to assemble Transformer-style models.
 *
 * <p>The helpers are:
 * <ul>
 *   <li>a time-distributed dense layer,</li>
 *   <li>a softmax wrapped between two dimension-2/4 transposes,</li>
 *   <li>padding and attention bias tensors,</li>
 *   <li>position encodings.</li>
 * </ul>
 *
 * <p>NOTE(review): this class is decompiler output of the Scala singleton
 * {@code TransformerOperation}. {@code MODULE$} holds the singleton instance, which is assigned by
 * the private constructor that the static initializer invokes. Some constructs are not valid Java
 * source as written, for example assigning the {@code static final} {@code MODULE$} in the
 * constructor and {@code Tuple2.mcII.sp}. Treat the Scala source as authoritative.
 */
public final class TransformerOperation$ {
    /** The singleton instance; assigned in the private constructor (decompiled Scala object). */
    public static final TransformerOperation$ MODULE$ = null;
    /** Mask value used for attention biases; set to {@code -1.0E9} in the constructor. */
    private final double com$intel$analytics$bigdl$dllib$nn$TransformerOperation$$maskValue;

    static {
        // Creates the singleton; the constructor stores itself into MODULE$.
        new TransformerOperation$();
    }

    /**
     * Builds a {@link Sequential} containing a {@link Linear} layer wrapped in
     * {@link TimeDistributed}, optionally followed by {@code tensorModule}.
     *
     * <p>The Linear layer is initialised with {@code setInitMethod(Xavier, Zeros)}.
     *
     * @param i            first size argument to {@code Linear.apply} (presumably input size — TODO confirm)
     * @param i2           second size argument to {@code Linear.apply} (presumably output size — TODO confirm)
     * @param z            third argument to {@code Linear.apply} (presumably "with bias"); Scala default {@code true}
     * @param tensorModule module appended after the time-distributed Linear; skipped when {@code null}
     * @param regularizer  first regularizer forwarded to {@code Linear.apply}; Scala default {@code null}
     * @param regularizer2 second regularizer forwarded to {@code Linear.apply}; Scala default {@code null}
     * @param str          name for the Linear layer. The empty string leaves the name untouched.
     *                     {@code null} is forwarded to {@code setName}, mirroring Scala's {@code name != ""}.
     * @return the assembled Sequential
     */
    public <T> AbstractModule<Activity, Activity, T> dense(int i, int i2, boolean z, TensorModule<T> tensorModule, Regularizer<T> regularizer, Regularizer<T> regularizer2, String str, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Sequential sequential = new Sequential(classTag, tensorNumeric);
        Linear$ linear$ = Linear$.MODULE$;
        // Decompiler residue: the Scala default-argument getters are evaluated and their results
        // discarded. The corresponding arguments are passed as null below.
        Linear$.MODULE$.apply$default$6();
        Linear$.MODULE$.apply$default$7();
        Linear$.MODULE$.apply$default$8();
        Linear$.MODULE$.apply$default$9();
        Linear<T> apply = linear$.apply(i, i2, z, regularizer, regularizer2, null, null, null, null, classTag, tensorNumeric);
        apply.setInitMethod(Xavier$.MODULE$, Zeros$.MODULE$);
        // Equivalent to Scala's null-safe `str != ""`. It is true for any non-empty string and also
        // for null. This replaces the decompiler's ill-typed `"" != 0`.
        if (str == null || !str.equals("")) {
            apply.setName(str);
        }
        sequential.mo1324add(TimeDistributed$.MODULE$.apply(apply, TimeDistributed$.MODULE$.apply$default$2(), classTag, tensorNumeric));
        if (tensorModule != null) {
            sequential.mo1324add(tensorModule);
        }
        return sequential;
    }

    /** Scala default for {@code dense}'s 3rd parameter ({@code z}). */
    public <T> boolean dense$default$3() {
        return true;
    }

    /** Scala default for {@code dense}'s 4th parameter ({@code tensorModule}): none. */
    public <T> Null$ dense$default$4() {
        return null;
    }

    /** Scala default for {@code dense}'s 5th parameter ({@code regularizer}): none. */
    public <T> Null$ dense$default$5() {
        return null;
    }

    /** Scala default for {@code dense}'s 6th parameter ({@code regularizer2}): none. */
    public <T> Null$ dense$default$6() {
        return null;
    }

    /** Scala default for {@code dense}'s 7th parameter ({@code str}): empty, so no name is set. */
    public <T> String dense$default$7() {
        return "";
    }

    /**
     * Returns a module that transposes dimensions 2 and 4, applies {@link SoftMax}, then
     * transposes dimensions 2 and 4 back.
     *
     * <p>NOTE(review): presumably this makes SoftMax normalise over the input's dimension 4 of a
     * 4-D tensor. The axis SoftMax itself uses is not visible here, so confirm.
     *
     * @return the Sequential converted with {@code convModule} to an Activity-to-Activity module
     */
    public <T> AbstractModule<Activity, Activity, T> softMax(ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        SoftMax<T> apply = SoftMax$.MODULE$.apply(classTag, tensorNumeric);
        Sequential<T> apply2 = Sequential$.MODULE$.apply(classTag, tensorNumeric);
        apply2.mo1324add(Transpose$.MODULE$.apply(new Tuple2[]{new Tuple2.mcII.sp(2, 4)}, classTag, tensorNumeric));
        apply2.mo1324add(apply);
        apply2.mo1324add(Transpose$.MODULE$.apply(new Tuple2[]{new Tuple2.mcII.sp(2, 4)}, classTag, tensorNumeric));
        return package$.MODULE$.convModule(apply2);
    }

    /**
     * Builds an additive attention bias from a padding tensor.
     *
     * <p>It computes {@code getPadding(tensor, 0.0f)} and multiplies it by {@code -1.0E9}, the same
     * value as {@code maskValue}. It then calls {@code addSingletonDimension} at dimension 2 and
     * afterwards at dimension 3.
     *
     * <p>NOTE(review): if {@code apply1}/{@code mul} operate in place, the input {@code tensor} is
     * overwritten. That is not visible here, so confirm.
     */
    public <T> Tensor<T> getPaddingBias(Tensor<T> tensor, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Tensor<T> mul = getPadding(tensor, getPadding$default$2(), classTag, tensorNumeric).mul(tensorNumeric.mo2049fromType(BoxesRunTime.boxToDouble(-1.0E9d), ConvertableFrom$ConvertableFromDouble$.MODULE$));
        mul.addSingletonDimension(mul, 2);
        return mul.addSingletonDimension(mul, 3);
    }

    /**
     * Applies the element function {@code TransformerOperation$$anonfun$getPadding$1(f, ...)} to
     * {@code tensor} via {@code apply1} and returns the result.
     *
     * <p>NOTE(review): presumably this maps elements equal to {@code f} to one and all others to
     * zero, giving a padding indicator. The function body is not visible, so confirm.
     *
     * @param f padding value forwarded to the element function; Scala default {@code 0.0f}
     */
    public <T> Tensor<T> getPadding(Tensor<T> tensor, float f, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        return tensor.apply1(new TransformerOperation$$anonfun$getPadding$1(f, tensorNumeric));
    }

    /** Scala default for {@code getPadding}'s 2nd parameter ({@code f}). */
    public <T> float getPadding$default$2() {
        return 0.0f;
    }

    /**
     * Writes into {@code tensor2} a copy of {@code tensor} shifted one position to the right along
     * dimension 2.
     *
     * <p>{@code tensor2} is first resized to {@code tensor}'s shape and zeroed, so the first slot
     * along dimension 2 stays zero. The last slice of {@code tensor} along that dimension is
     * dropped.
     *
     * <p>NOTE(review): when {@code tensor.size(2) == 1} both {@code narrow} calls get length 0. It is
     * unclear whether {@code narrow} accepts that — confirm.
     *
     * @return {@code tensor2}
     */
    public <T> Tensor<T> shiftRight3D(Tensor<T> tensor, Tensor<T> tensor2, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        tensor2.resizeAs(tensor).zero();
        int size = tensor.size(2);
        tensor2.narrow(2, 2, size - 1).copy(tensor.narrow(2, 1, size - 1));
        return tensor2;
    }

    /**
     * Resizes {@code tensor} to a 1-D tensor of {@code i} elements. For each index 0..i-1 it then
     * runs {@code TransformerOperation$$anonfun$initRangeTensor$1} on the raw storage array.
     *
     * <p>NOTE(review): presumably element k is set to k. The function body is not visible, and it
     * receives the raw array, so a storage offset other than 1 may not be honoured — confirm.
     */
    public <T> void initRangeTensor(int i, Tensor<T> tensor, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        tensor.resize(new int[]{i}, tensor.resize$default$2());
        RichInt$.MODULE$.to$extension0(Predef$.MODULE$.intWrapper(0), i - 1).foreach$mVc$sp(new TransformerOperation$$anonfun$initRangeTensor$1(tensorNumeric, tensor.storage().array()));
    }

    /**
     * Computes position encodings into {@code tensor2}.
     *
     * <p>With {@code n = i2 / 2}, the scale for column {@code k} ({@code k = 0..n-1}) is
     * {@code f * exp(-k * ln(f2 / f) / max(n - 1, 1))}. The method then:
     * <ol>
     *   <li>resizes {@code tensor} in place to {@code (i, 1)};</li>
     *   <li>writes the outer product {@code tensor x scales} into columns {@code 1..n} of
     *       {@code tensor2} (dimension 2);</li>
     *   <li>copies that product into columns {@code n+1..2n};</li>
     *   <li>transforms the first half with {@code anonfun$getPositionEncode$1} and the second half
     *       with {@code anonfun$getPositionEncode$2}.</li>
     * </ol>
     * NOTE(review): the two transforms are presumably sin and cos — confirm.
     *
     * <p>{@code tensor2} is not resized here. It presumably must already have {@code i} rows and
     * at least {@code 2n} columns (TODO confirm). When {@code i2} is odd, the last column is not
     * written.
     *
     * @param i  number of positions; {@code tensor} is reshaped to {@code (i, 1)}
     * @param i2 number of output channels
     * @param f  presumably the minimum timescale; Scala default {@code 1.0f}
     * @param f2 presumably the maximum timescale; Scala default {@code 10000.0f}
     * @return {@code tensor2}
     */
    public <T> Tensor<T> getPositionEncode(int i, int i2, float f, float f2, Tensor<T> tensor, Tensor<T> tensor2, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        int i3 = i2 / 2;
        double log = scala.math.package$.MODULE$.log(f2 / f) / scala.math.package$.MODULE$.max(i3 - 1, 1);
        Tensor<T> apply = Tensor$.MODULE$.apply(1, i3, classTag, tensorNumeric);
        Object array = apply.storage().array();
        // Convert the 1-based storage offset into a 0-based array index.
        int storageOffset = apply.storageOffset() - 1;
        int i4 = 0;
        while (true) {
            int i5 = i4;
            if (i5 >= i3) {
                Tensor<T> narrow = tensor2.narrow(2, 1, i3);
                narrow.addmm(tensorNumeric.mo2060zero(), tensorNumeric.mo2061one(), tensor.resize(i, 1), apply);
                // Copy the raw product before the first half is transformed in place below.
                Tensor<T> copy = tensor2.narrow(2, i3 + 1, i3).copy(narrow);
                narrow.apply1(new TransformerOperation$$anonfun$getPositionEncode$1(tensorNumeric));
                copy.apply1(new TransformerOperation$$anonfun$getPositionEncode$2(tensorNumeric));
                return tensor2;
            }
            ScalaRunTime$.MODULE$.array_update(array, i5 + storageOffset, tensorNumeric.mo2049fromType(BoxesRunTime.boxToDouble(f * scala.math.package$.MODULE$.exp(i5 * (-log))), ConvertableFrom$ConvertableFromDouble$.MODULE$));
            i4 = i5 + 1;
        }
    }

    /** Scala default for {@code getPositionEncode}'s 3rd parameter ({@code f}). */
    public <T> float getPositionEncode$default$3() {
        return 1.0f;
    }

    /** Scala default for {@code getPositionEncode}'s 4th parameter ({@code f2}). */
    public <T> float getPositionEncode$default$4() {
        return 10000.0f;
    }

    /** Returns the mask value ({@code -1.0E9}) set in the constructor. */
    public double com$intel$analytics$bigdl$dllib$nn$TransformerOperation$$maskValue() {
        return this.com$intel$analytics$bigdl$dllib$nn$TransformerOperation$$maskValue;
    }

    /**
     * Fills {@code tensor} as an {@code i x i} attention bias and returns it viewed as
     * {@code (1, 1, i, i)}.
     *
     * <p>For each row 0..i-1, {@code TransformerOperation$$anonfun$attentionBiasLowerTriangle$1}
     * writes directly into the storage array. NOTE(review): presumably it writes {@code maskValue}
     * above the diagonal, forming a causal mask. Its body is not visible, and it may leave the
     * remaining entries untouched, so callers may need to pass a zeroed tensor — confirm.
     */
    public <T> Tensor<T> attentionBiasLowerTriangle(int i, Tensor<T> tensor, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        // Resize BEFORE taking the storage array, as initRangeTensor does. The previous version
        // captured the array first. If resize has to allocate larger storage, writes into the old
        // array would then be lost or would overflow it.
        Tensor<T> resized = tensor.resize(new int[]{1, 1, i, i}, tensor.resize$default$2());
        RichInt$.MODULE$.to$extension0(Predef$.MODULE$.intWrapper(0), i - 1).foreach$mVc$sp(new TransformerOperation$$anonfun$attentionBiasLowerTriangle$1(i, tensorNumeric, resized.storage().array()));
        return resized;
    }

    /** Creates the singleton, publishes it via {@code MODULE$} and sets the mask value. */
    private TransformerOperation$() {
        MODULE$ = this;
        this.com$intel$analytics$bigdl$dllib$nn$TransformerOperation$$maskValue = -1.0E9d;
    }
}
