package com.intel.analytics.bigdl.dllib.nn.ops;

import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import com.intel.analytics.bigdl.dllib.utils.Table;
import scala.Predef$;
import scala.StringContext;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: SegmentSum.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005\u0015a\u0001B\u0001\u0003\u0001E\u0011!bU3h[\u0016tGoU;n\u0015\t\u0019A!A\u0002paNT!!\u0002\u0004\u0002\u00059t'BA\u0004\t\u0003\u0015!G\u000e\\5c\u0015\tI!\"A\u0003cS\u001e$GN\u0003\u0002\f\u0019\u0005I\u0011M\\1msRL7m\u001d\u0006\u0003\u001b9\tQ!\u001b8uK2T\u0011aD\u0001\u0004G>l7\u0001A\u000b\u0003%\u0015\u001a\"\u0001A\n\u0011\u000bQ)r#H\u0012\u000e\u0003\tI!A\u0006\u0002\u0003\u0013=\u0003XM]1uS>t\u0007C\u0001\r\u001c\u001b\u0005I\"B\u0001\u000e\u0007\u0003\u0015)H/\u001b7t\u0013\ta\u0012DA\u0003UC\ndW\rE\u0002\u001fC\rj\u0011a\b\u0006\u0003A\u0019\ta\u0001^3og>\u0014\u0018B\u0001\u0012 \u0005\u0019!VM\\:peB\u0011A%\n\u0007\u0001\t\u00151\u0003A1\u0001(\u0005\u0005!\u0016C\u0001\u0015/!\tIC&D\u0001+\u0015\u0005Y\u0013!B:dC2\f\u0017BA\u0017+\u0005\u001dqu\u000e\u001e5j]\u001e\u0004\"!K\u0018\n\u0005AR#aA!os\"A!\u0007\u0001B\u0002B\u0003-1'\u0001\u0006fm&$WM\\2fIE\u00022\u0001N\u001c$\u001b\u0005)$B\u0001\u001c+\u0003\u001d\u0011XM\u001a7fGRL!\u0001O\u001b\u0003\u0011\rc\u0017m]:UC\u001eD\u0001B\u000f\u0001\u0003\u0002\u0003\u0006YaO\u0001\u0003KZ\u00042\u0001\u0010($\u001d\tiDJ\u0004\u0002?\u0017:\u0011qH\u0013\b\u0003\u0001&s!!\u0011%\u000f\u0005\t;eBA\"G\u001b\u0005!%BA#\u0011\u0003\u0019a$o\\8u}%\tq\"\u0003\u0002\u000e\u001d%\u00111\u0002D\u0005\u0003\u0013)I!a\u0002\u0005\n\u0005\u00012\u0011BA' \u0003E!VM\\:pe:+X.\u001a:jG6\u000bG\u000f[\u0005\u0003\u001fB\u0013Q\u0002V3og>\u0014h*^7fe&\u001c'BA' 
\u0011\u0015\u0011\u0006\u0001\"\u0001T\u0003\u0019a\u0014N\\5u}Q\tA\u000bF\u0002V-^\u00032\u0001\u0006\u0001$\u0011\u0015\u0011\u0014\u000bq\u00014\u0011\u0015Q\u0014\u000bq\u0001<\u0011\u0015I\u0006\u0001\"\u0001[\u00031)\b\u000fZ1uK>+H\u000f];u)\ti2\fC\u0003]1\u0002\u0007q#\u0001\u0004j]B,Ho]\u0004\u0006=\nA\taX\u0001\u000b'\u0016<W.\u001a8u'Vl\u0007C\u0001\u000ba\r\u0015\t!\u0001#\u0001b'\r\u0001'-\u001a\t\u0003S\rL!\u0001\u001a\u0016\u0003\r\u0005s\u0017PU3g!\tIc-\u0003\u0002hU\ta1+\u001a:jC2L'0\u00192mK\")!\u000b\u0019C\u0001SR\tq\fC\u0003lA\u0012\u0005A.A\u0003baBd\u00170\u0006\u0002ncR\ta\u000eF\u0002peV\u00042\u0001\u0006\u0001q!\t!\u0013\u000fB\u0003'U\n\u0007q\u0005C\u0004tU\u0006\u0005\t9\u0001;\u0002\u0015\u00154\u0018\u000eZ3oG\u0016$#\u0007E\u00025oADQA\u000f6A\u0004Y\u00042\u0001\u0010(q\u0011\u001dA\b-!A\u0005\ne\f1B]3bIJ+7o\u001c7wKR\t!\u0010E\u0002|\u0003\u0003i\u0011\u0001 \u0006\u0003{z\fA\u0001\\1oO*\tq0\u0001\u0003kCZ\f\u0017bAA\u0002y\n1qJ\u00196fGR\u0004")
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/nn/ops/SegmentSum.class */
/**
 * Sums slices of a data tensor along its first dimension, grouped by segment id.
 *
 * <p>The input {@link Table} holds the data tensor at key 1 and a 1-D tensor of integer
 * segment ids at key 2. Slice {@code i} of the data (taken along the first dimension) is
 * added into the output slice selected by the {@code i}-th segment id.
 *
 * <p>NOTE(review): the output's first dimension is sized from the <em>last</em> segment id
 * only, so the ids are presumably expected to be non-negative and sorted in ascending
 * order; this is not validated here.
 *
 * @param <T> numeric element type of the data and output tensors
 */
public class SegmentSum<T> extends Operation<Table, Tensor<T>, T> {
    /**
     * Computes the per-segment sums and stores them in this operation's output tensor.
     *
     * @param table input table: key 1 is the data tensor, key 2 is the 1-D segment id tensor
     *              whose length must equal the size of the data tensor's first dimension
     * @return the output tensor, whose first dimension is {@code lastSegmentId + 1} and whose
     *         remaining dimensions match the data tensor
     */
    @Override // com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule
    public Tensor<T> updateOutput(Table table) {
        Tensor data = (Tensor) table.apply(BoxesRunTime.boxToInteger(1));
        Tensor segmentIds = (Tensor) table.apply(BoxesRunTime.boxToInteger(2));

        Log4Error$.MODULE$.invalidInputError(
                segmentIds.nDimension() == 1,
                "segment ids should be 1D tensor",
                Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(
                segmentIds.size(1) == data.size(1),
                "segment ids should be the same size as first dimension of input, expected "
                        + data.size(1) + ", but got " + segmentIds.size(1),
                Log4Error$.MODULE$.invalidInputError$default$3());

        // The output keeps the data's shape except for the first dimension, which becomes
        // (last segment id + 1); the ids are treated as zero-based.
        final int idCount = segmentIds.nElement();
        int[] outputSize = data.size();
        outputSize[0] = BoxesRunTime.unboxToInt(segmentIds.mo1972valueAt(idCount)) + 1;
        Tensor result = (Tensor) output();
        result.resize(outputSize, result.resize$default$2()).zero();

        // Tensor positions are addressed starting at 1 here, hence the "+ 1" applied to
        // each segment id before selecting the destination slice.
        for (int row = 1; row <= idCount; row++) {
            int targetSlice = BoxesRunTime.unboxToInt(segmentIds.mo1972valueAt(row)) + 1;
            ((Tensor) output()).select(1, targetSlice).add((Tensor) data.select(1, row));
        }
        return (Tensor) output();
    }

    /**
     * Creates the operation, forwarding class tags for the {@link Table} input and
     * {@link Tensor} output together with the element-type evidence to the superclass.
     *
     * @param classTag      class tag of the element type {@code T}
     * @param tensorNumeric numeric operations for the element type {@code T}
     */
    public SegmentSum(ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        super(ClassTag$.MODULE$.apply(Table.class), ClassTag$.MODULE$.apply(Tensor.class), classTag, tensorNumeric);
    }
}
