package com.intel.analytics.bigdl.dllib.keras.layers.internal;

import com.intel.analytics.bigdl.dllib.nn.JoinTable;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.T$;
import com.intel.analytics.bigdl.dllib.utils.Table;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;

/* compiled from: InternalSplitTensor.scala */
@ScalaSignature(bytes = "\u0006\u0001q4A!\u0001\u0002\u0001'\t\u0019\u0012J\u001c;fe:\fGn\u00159mSR$VM\\:pe*\u00111\u0001B\u0001\tS:$XM\u001d8bY*\u0011QAB\u0001\u0007Y\u0006LXM]:\u000b\u0005\u001dA\u0011!B6fe\u0006\u001c(BA\u0005\u000b\u0003\u0015!G\u000e\\5c\u0015\tYA\"A\u0003cS\u001e$GN\u0003\u0002\u000e\u001d\u0005I\u0011M\\1msRL7m\u001d\u0006\u0003\u001fA\tQ!\u001b8uK2T\u0011!E\u0001\u0004G>l7\u0001A\u000b\u0003)\u0015\u001a\"\u0001A\u000b\u0011\u000bYYR$M\u0012\u000e\u0003]Q!\u0001G\r\u0002\u0015\u0005\u00147\u000f\u001e:bGRtgN\u0003\u0002\u001b\u0011\u0005\u0011aN\\\u0005\u00039]\u0011a\"\u00112tiJ\f7\r^'pIVdW\rE\u0002\u001fC\rj\u0011a\b\u0006\u0003A!\ta\u0001^3og>\u0014\u0018B\u0001\u0012 \u0005\u0019!VM\\:peB\u0011A%\n\u0007\u0001\t\u00151\u0003A1\u0001(\u0005\u0005!\u0016C\u0001\u0015/!\tIC&D\u0001+\u0015\u0005Y\u0013!B:dC2\f\u0017BA\u0017+\u0005\u001dqu\u000e\u001e5j]\u001e\u0004\"!K\u0018\n\u0005AR#aA!osB\u0011!'N\u0007\u0002g)\u0011A\u0007C\u0001\u0006kRLGn]\u0005\u0003mM\u0012Q\u0001V1cY\u0016D\u0001\u0002\u000f\u0001\u0003\u0002\u0003\u0006I!O\u0001\nI&lWM\\:j_:\u0004\"!\u000b\u001e\n\u0005mR#aA%oi\"AQ\b\u0001B\u0001B\u0003%\u0011(A\u0002ok6D\u0001b\u0010\u0001\u0003\u0004\u0003\u0006Y\u0001Q\u0001\u000bKZLG-\u001a8dK\u0012\n\u0004cA!EG5\t!I\u0003\u0002DU\u00059!/\u001a4mK\u000e$\u0018BA#C\u0005!\u0019E.Y:t)\u0006<\u0007\u0002C$\u0001\u0005\u0003\u0005\u000b1\u0002%\u0002\u0005\u00154\bcA%\\G9\u0011!*\u0017\b\u0003\u0017bs!\u0001T,\u000f\u000553fB\u0001(V\u001d\tyEK\u0004\u0002Q'6\t\u0011K\u0003\u0002S%\u00051AH]8pizJ\u0011!E\u0005\u0003\u001fAI!!\u0004\b\n\u0005-a\u0011BA\u0005\u000b\u0013\t\u0001\u0003\"\u0003\u0002[?\u0005\tB+\u001a8t_JtU/\\3sS\u000el\u0015\r\u001e5\n\u0005qk&!\u0004+f]N|'OT;nKJL7M\u0003\u0002[?!)q\f\u0001C\u0001A\u00061A(\u001b8jiz\"2!\u00194h)\r\u0011G-\u001a\t\u0004G\u0002\u0019S\"\u0001\u0002\t\u000b}r\u00069\u0001!\t\u000b\u001ds\u00069\u0001%\t\u000bar\u0006\u0019A\u001d\t\u000bur\u0006\u0019A\u001d\t\u000b%\u0004A\u0011\t6\u0
002\u0019U\u0004H-\u0019;f\u001fV$\b/\u001e;\u0015\u0005EZ\u0007\"\u00027i\u0001\u0004i\u0012!B5oaV$\bb\u00028\u0001\u0005\u0004%Ia\\\u0001\u000bS:tWM\u001d'bs\u0016\u0014X#\u00019\u0011\u0007E\u00148%D\u0001\u001a\u0013\t\u0019\u0018DA\u0005K_&tG+\u00192mK\"1Q\u000f\u0001Q\u0001\nA\f1\"\u001b8oKJd\u0015-_3sA!)q\u000f\u0001C!q\u0006yQ\u000f\u001d3bi\u0016<%/\u00193J]B,H\u000fF\u0002\u001esjDQ\u0001\u001c<A\u0002uAQa\u001f<A\u0002E\n!b\u001a:bI>+H\u000f];u\u0001")
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/keras/layers/internal/InternalSplitTensor.class */
/**
 * Splits an input tensor into chunks along one dimension and emits them as a {@link Table}.
 * The backward pass concatenates the per-chunk gradients back into a single tensor with a
 * {@link JoinTable} over the same dimension.
 *
 * <p>The chunk size is {@code input.size(dimension) / num} (integer division). When the
 * dimension size is not divisible by {@code num}, {@code Tensor.split} may yield a different
 * number of chunks than {@code num}. NOTE(review): confirm whether callers rely on that.
 *
 * <p>Dimension indexing follows whatever convention {@link Tensor#size(int)} and
 * {@link JoinTable} use (presumably 1-based in BigDL; confirm).
 *
 * @param <T> numeric element type of the tensors
 */
public class InternalSplitTensor<T> extends AbstractModule<Tensor<T>, Table, T> {
    /** Dimension along which the input is split and along which gradients are re-joined. */
    private final int dimension;
    /** Requested number of chunks; validated to be positive. */
    private final int num;
    /** Numeric operations for {@code T}; used to convert the joined gradient to a tensor. */
    private final TensorNumericMath.TensorNumeric<T> ev;
    /** Concatenates the chunk gradients along {@link #dimension} in the backward pass. */
    private final JoinTable<T> innerLayer;

    /**
     * Creates the split layer.
     *
     * @param dimension dimension to split along (and to re-join gradients along)
     * @param num number of chunks to split into; must be positive
     * @param classTag class tag for the element type {@code T}
     * @param tensorNumeric numeric operations for {@code T}
     * @throws IllegalArgumentException if {@code num <= 0}
     */
    public InternalSplitTensor(int dimension, int num, ClassTag<T> classTag,
            TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        super(ClassTag$.MODULE$.apply(Tensor.class), ClassTag$.MODULE$.apply(Table.class),
                classTag, tensorNumeric);
        // Fail fast: a non-positive count would otherwise only surface at forward time
        // (division by zero or a negative chunk size).
        if (num <= 0) {
            throw new IllegalArgumentException("num must be positive, got " + num);
        }
        this.dimension = dimension;
        this.num = num;
        this.ev = tensorNumeric;
        // -1 is passed as JoinTable's second constructor argument, as in the original code.
        this.innerLayer = new JoinTable<>(dimension, -1, classTag, tensorNumeric);
    }

    /**
     * Splits {@code input} along {@link #dimension} into chunks of size
     * {@code input.size(dimension) / num} and stores them, as a table, in {@code output}.
     *
     * @param input tensor to split
     * @return the table of chunks (also stored as this module's output)
     * @throws IllegalArgumentException if the dimension is smaller than {@code num}, which
     *     would yield a chunk size of zero
     */
    @Override
    public Table updateOutput(Tensor<T> input) {
        int dimSize = input.size(this.dimension);
        int chunkSize = dimSize / this.num;
        if (chunkSize <= 0) {
            throw new IllegalArgumentException("cannot split dimension " + this.dimension
                    + " of size " + dimSize + " into " + this.num + " chunks");
        }
        output_$eq(T$.MODULE$.array(input.split(chunkSize, this.dimension)));
        return output();
    }

    /** Returns the layer used to re-join chunk gradients. */
    private JoinTable<T> innerLayer() {
        return this.innerLayer;
    }

    /**
     * Concatenates the per-chunk gradients in {@code gradOutput} along {@link #dimension} and
     * stores the result in {@code gradInput}.
     *
     * @param input the forward input (unused; the join depends only on {@code gradOutput})
     * @param gradOutput table of gradients, one entry per chunk produced by the forward pass
     * @return the joined gradient (also stored as this module's gradInput)
     */
    @Override
    public Tensor<T> updateGradInput(Tensor<T> input, Table gradOutput) {
        gradInput_$eq(innerLayer().forward(gradOutput).toTensor(this.ev));
        return gradInput();
    }

    /**
     * Decompiler-generated alias kept for backward compatibility. It does not override
     * anything in {@link AbstractModule}, so the framework never calls it.
     *
     * @deprecated use {@link #updateGradInput(Tensor, Table)}
     */
    @Deprecated
    public Tensor<T> updateGradInput2(Tensor<T> input, Table gradOutput) {
        return updateGradInput(input, gradOutput);
    }
}
