package com.intel.analytics.bigdl.dllib.nn;

import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.T$;
import com.intel.analytics.bigdl.dllib.utils.Table;
import scala.Predef$;
import scala.StringContext;
import scala.collection.Seq;
import scala.collection.immutable.Nil$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: DotProduct.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005Ud\u0001B\u0001\u0003\u0001=\u0011!\u0002R8u!J|G-^2u\u0015\t\u0019A!\u0001\u0002o]*\u0011QAB\u0001\u0006I2d\u0017N\u0019\u0006\u0003\u000f!\tQAY5hI2T!!\u0003\u0006\u0002\u0013\u0005t\u0017\r\\=uS\u000e\u001c(BA\u0006\r\u0003\u0015Ig\u000e^3m\u0015\u0005i\u0011aA2p[\u000e\u0001QC\u0001\t&'\t\u0001\u0011\u0003E\u0003\u0013+]i2%D\u0001\u0014\u0015\t!\"!\u0001\u0006bEN$(/Y2u]:L!AF\n\u0003\u001d\u0005\u00137\u000f\u001e:bGRlu\u000eZ;mKB\u0011\u0001dG\u0007\u00023)\u0011!\u0004B\u0001\u0006kRLGn]\u0005\u00039e\u0011Q\u0001V1cY\u0016\u00042AH\u0011$\u001b\u0005y\"B\u0001\u0011\u0005\u0003\u0019!XM\\:pe&\u0011!e\b\u0002\u0007)\u0016t7o\u001c:\u0011\u0005\u0011*C\u0002\u0001\u0003\u0006M\u0001\u0011\ra\n\u0002\u0002)F\u0011\u0001F\f\t\u0003S1j\u0011A\u000b\u0006\u0002W\u0005)1oY1mC&\u0011QF\u000b\u0002\b\u001d>$\b.\u001b8h!\tIs&\u0003\u00021U\t\u0019\u0011I\\=\t\u0011I\u0002!1!Q\u0001\fM\n!\"\u001a<jI\u0016t7-\u001a\u00132!\r!tgI\u0007\u0002k)\u0011aGK\u0001\be\u00164G.Z2u\u0013\tATG\u0001\u0005DY\u0006\u001c8\u000fV1h\u0011!Q\u0004A!A!\u0002\u0017Y\u0014AA3w!\radj\t\b\u0003{1s!AP&\u000f\u0005}ReB\u0001!J\u001d\t\t\u0005J\u0004\u0002C\u000f:\u00111IR\u0007\u0002\t*\u0011QID\u0001\u0007yI|w\u000e\u001e \n\u00035I!a\u0003\u0007\n\u0005%Q\u0011BA\u0004\t\u0013\t)a!\u0003\u0002!\t%\u0011QjH\u0001\u0012)\u0016t7o\u001c:Ok6,'/[2NCRD\u0017BA(Q\u00055!VM\\:pe:+X.\u001a:jG*\u0011Qj\b\u0005\u0006%\u0002!\taU\u0001\u0007y%t\u0017\u000e\u001e 
\u0015\u0003Q#2!V,Y!\r1\u0006aI\u0007\u0002\u0005!)!'\u0015a\u0002g!)!(\u0015a\u0002w!)!\f\u0001C!7\u0006aQ\u000f\u001d3bi\u0016|U\u000f\u001e9viR\u0011Q\u0004\u0018\u0005\u0006;f\u0003\raF\u0001\u0006S:\u0004X\u000f\u001e\u0005\u0006?\u0002!\t\u0005Y\u0001\u0010kB$\u0017\r^3He\u0006$\u0017J\u001c9viR\u0019q#\u00192\t\u000bus\u0006\u0019A\f\t\u000b\rt\u0006\u0019A\u000f\u0002\u0015\u001d\u0014\u0018\rZ(viB,H\u000fC\u0003f\u0001\u0011\u0005c-\u0001\u0005u_N#(/\u001b8h)\u00059\u0007C\u00015l\u001d\tI\u0013.\u0003\u0002kU\u00051\u0001K]3eK\u001aL!\u0001\\7\u0003\rM#(/\u001b8h\u0015\tQ'\u0006C\u0004p\u0001\u0001\u0007I\u0011\u00029\u0002\r\t,hMZ3s+\u0005i\u0002b\u0002:\u0001\u0001\u0004%Ia]\u0001\u000bEV4g-\u001a:`I\u0015\fHC\u0001;x!\tIS/\u0003\u0002wU\t!QK\\5u\u0011\u001dA\u0018/!AA\u0002u\t1\u0001\u001f\u00132\u0011\u0019Q\b\u0001)Q\u0005;\u00059!-\u001e4gKJ\u0004\u0003FA=}!\tIS0\u0003\u0002\u007fU\tIAO]1og&,g\u000e\u001e\u0015\b\u0001\u0005\u0005\u0011qAA\u0005!\rI\u00131A\u0005\u0004\u0003\u000bQ#\u0001E*fe&\fGNV3sg&|g.V%E\u0003\u00151\u0018\r\\;f=!\u0011S#GX\u0004Io?|aBA\u0007\u0005!\u0005\u0011qB\u0001\u000b\t>$\bK]8ek\u000e$\bc\u0001,\u0002\u0012\u00191\u0011A\u0001E\u0001\u0003'\u0019b!!\u0005\u0002\u0016\u0005m\u0001cA\u0015\u0002\u0018%\u0019\u0011\u0011\u0004\u0016\u0003\r\u0005s\u0017PU3g!\rI\u0013QD\u0005\u0004\u0003?Q#\u0001D*fe&\fG.\u001b>bE2,\u0007b\u0002*\u0002\u0012\u0011\u0005\u00111\u0005\u000b\u0003\u0003\u001fA\u0001\"a\n\u0002\u0012\u0011\u0005\u0011\u0011F\u0001\u0006CB\u0004H._\u000b\u0005\u0003W\t\u0019\u0004\u0006\u0002\u0002.Q1\u0011qFA+\u00037\u0002BA\u0016\u0001\u00022A\u0019A%a\r\u0005\u0015\u0019\n)\u0003)A\u0001\u0002\u000b\u0007q\u0005\u000b\u0005\u00024\u0005]\u0012QHA&!\rI\u0013\u0011H\u0005\u0004\u0003wQ#aC:qK\u000eL\u0017\r\\5{K\u0012\f\u0014bIA 
\u0003\u0003\n)%a\u0011\u000f\u0007%\n\t%C\u0002\u0002D)\nQA\u00127pCR\fd\u0001JA$\u0003\u0013ZcbA\"\u0002J%\t1&M\u0005$\u0003\u001b\ny%a\u0015\u0002R9\u0019\u0011&a\u0014\n\u0007\u0005E#&\u0001\u0004E_V\u0014G.Z\u0019\u0007I\u0005\u001d\u0013\u0011J\u0016\t\u0015\u0005]\u0013QEA\u0001\u0002\b\tI&\u0001\u0006fm&$WM\\2fII\u0002B\u0001N\u001c\u00022!9!(!\nA\u0004\u0005u\u0003\u0003\u0002\u001fO\u0003cA!\"!\u0019\u0002\u0012\u0005\u0005I\u0011BA2\u0003-\u0011X-\u00193SKN|GN^3\u0015\u0005\u0005\u0015\u0004\u0003BA4\u0003cj!!!\u001b\u000b\t\u0005-\u0014QN\u0001\u0005Y\u0006twM\u0003\u0002\u0002p\u0005!!.\u0019<b\u0013\u0011\t\u0019(!\u001b\u0003\r=\u0013'.Z2u\u0001")
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/nn/DotProduct.class */
/*
 * Module that takes a Table holding two tensors (keys 1 and 2) and outputs, for every row,
 * the dot product of the corresponding rows of the two tensors.
 *
 * A 1-D input pair is treated as a single row: it is viewed as 1 x n, so the output has one
 * element. NOTE(review): both inputs are assumed to share the same shape; this is not
 * checked here.
 */
public class DotProduct<T> extends AbstractModule<Table, Tensor<T>, T> {
    public static final long serialVersionUID = 2455897411271580599L;

    // These field names are part of the Java-serialized form (see serialVersionUID) and must not be renamed.
    private final ClassTag<T> evidence$1;
    private final TensorNumericMath.TensorNumeric<T> ev;
    // Scratch tensor for the element-wise product. It is transient, so it is re-created lazily in
    // updateOutput when it is null (for example after deserialization).
    private transient Tensor<T> buffer;

    private Tensor<T> buffer() {
        return this.buffer;
    }

    private void buffer_$eq(Tensor<T> value) {
        this.buffer = value;
    }

    /** Allocates a new empty tensor with this module's element type and numeric ops. */
    private Tensor<T> emptyTensor() {
        return Tensor$.MODULE$.apply(this.evidence$1, this.ev);
    }

    /** Returns a 1 x n view of {@code vector}, where n is the size of its first dimension. */
    private Tensor<T> asRow(Tensor<T> vector) {
        int length = vector.size(1);
        return vector.view((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{1, length}));
    }

    /** Puts an empty tensor into gradInput under {@code index} unless that key is already present. */
    private void ensureGradSlot(int index) {
        Object key = BoxesRunTime.boxToInteger(index);
        if (!gradInput().contains(key)) {
            gradInput().update(key, emptyTensor());
        }
    }

    /**
     * Forward pass: multiplies the two input tensors element-wise into the scratch buffer, sums
     * along dimension 2, and reshapes the output to one value per row of the first input.
     *
     * @param input table whose entries 1 and 2 are the two operand tensors
     * @return this module's output tensor, holding one dot product per row
     */
    @Override
    public Tensor<T> updateOutput(Table input) {
        Tensor<T> left = (Tensor) input.apply(BoxesRunTime.boxToInteger(1));
        Tensor<T> right = (Tensor) input.apply(BoxesRunTime.boxToInteger(2));
        if (left.dim() == 1) {
            // A single vector pair is handled as a batch with one row.
            left = asRow(left);
            right = asRow(right);
        }

        if (buffer() == null) {
            buffer_$eq(emptyTensor());
        }
        Tensor<T> products = buffer();
        products.resizeAs(left).cmul(left, right);

        Tensor<T> result = output();
        result.sum(products, 2);
        result.resize(left.size(1));
        return result;
    }

    /**
     * Backward pass. The gradient for input 1 is input 2 scaled row-wise by {@code gradOutput},
     * and the gradient for input 2 is input 1 scaled the same way.
     *
     * @param input      the same table passed to {@link #updateOutput(Table)}
     * @param gradOutput upstream gradient; presumably one entry per row (it is viewed as rows x 1)
     * @return the gradInput table, whose entries 1 and 2 hold the gradients for the two inputs
     */
    @Override
    public Table updateGradInput(Table input, Tensor<T> gradOutput) {
        Tensor<T> left = (Tensor) input.apply(BoxesRunTime.boxToInteger(1));
        Tensor<T> right = (Tensor) input.apply(BoxesRunTime.boxToInteger(2));

        if (gradInput().length() != 2) {
            ensureGradSlot(1);
            ensureGradSlot(2);
        }

        boolean unbatched = left.dim() == 1;
        if (unbatched) {
            left = asRow(left);
            right = asRow(right);
        }

        Tensor<T> gradLeft = (Tensor) gradInput().apply(BoxesRunTime.boxToInteger(1));
        Tensor<T> gradRight = (Tensor) gradInput().apply(BoxesRunTime.boxToInteger(2));
        // d/dleft of sum(left * right) is right, and vice versa.
        gradLeft.resizeAs(left).copy(right);
        gradRight.resizeAs(right).copy(left);

        // Broadcast each row's upstream gradient across that row's columns.
        int rows = gradOutput.size(1);
        Tensor<T> rowScale = gradOutput
                .view((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{rows, 1}))
                .expandAs(left);
        gradLeft.cmul(rowScale);
        gradRight.cmul(rowScale);

        if (unbatched) {
            // Turn the 1 x n gradients back into 1-D tensors to match the original 1-D inputs.
            gradLeft.set(gradLeft.select(1, 1));
            gradRight.set(gradRight.select(1, 1));
        }
        return gradInput();
    }

    @Override
    public String toString() {
        // Same text the original single-part s"nn.DotProduct" interpolation produced.
        return "nn.DotProduct";
    }

    /*
     * NOTE(review): the decompiler moved super(...) ahead of the field assignments. In the Scala
     * bytecode the fields may be assigned first, so confirm that the superclass constructor does
     * not rely on evidence$1 or ev.
     */
    public DotProduct(ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        super(ClassTag$.MODULE$.apply(Table.class), ClassTag$.MODULE$.apply(Tensor.class), classTag, tensorNumeric);
        this.evidence$1 = classTag;
        this.ev = tensorNumeric;
        // Pre-size gradInput with two empty tensors, one per input.
        Tensor<T> firstGrad = emptyTensor();
        Tensor<T> secondGrad = emptyTensor();
        gradInput_$eq(T$.MODULE$.apply(firstGrad, (Seq<Object>) Predef$.MODULE$.genericWrapArray(new Object[]{secondGrad})));
        this.buffer = null;
    }
}
