package com.intel.analytics.bigdl.dllib.keras.autograd;

import com.intel.analytics.bigdl.dllib.keras.Model;
import com.intel.analytics.bigdl.dllib.keras.objectives.TensorLossFunction;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import com.intel.analytics.bigdl.dllib.utils.Shape;
import com.intel.analytics.bigdl.dllib.utils.Shape$;
import com.intel.analytics.bigdl.dllib.utils.T$;
import scala.Function2;
import scala.Predef$;
import scala.StringContext;
import scala.collection.Seq;
import scala.reflect.ClassTag;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: CustomLoss.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005\u0015w!B\u0001\u0003\u0011\u0003\t\u0012AC\"vgR|W\u000eT8tg*\u00111\u0001B\u0001\tCV$xn\u001a:bI*\u0011QAB\u0001\u0006W\u0016\u0014\u0018m\u001d\u0006\u0003\u000f!\tQ\u0001\u001a7mS\nT!!\u0003\u0006\u0002\u000b\tLw\r\u001a7\u000b\u0005-a\u0011!C1oC2LH/[2t\u0015\tia\"A\u0003j]R,GNC\u0001\u0010\u0003\r\u0019w.\\\u0002\u0001!\t\u00112#D\u0001\u0003\r\u0015!\"\u0001#\u0001\u0016\u0005)\u0019Uo\u001d;p[2{7o]\n\u0004'Ya\u0002CA\f\u001b\u001b\u0005A\"\"A\r\u0002\u000bM\u001c\u0017\r\\1\n\u0005mA\"AB!osJ+g\r\u0005\u0002\u0018;%\u0011a\u0004\u0007\u0002\r'\u0016\u0014\u0018.\u00197ju\u0006\u0014G.\u001a\u0005\u0006AM!\t!I\u0001\u0007y%t\u0017\u000e\u001e \u0015\u0003EAQaI\n\u0005\u0002\u0011\nQ!\u00199qYf,\"!J\u0018\u0015\u000b\u0019R&M\u001b7\u0015\u0007\u001dB\u0004\tE\u0002)W5j\u0011!\u000b\u0006\u0003U\u0011\t!b\u001c2kK\u000e$\u0018N^3t\u0013\ta\u0013F\u0001\nUK:\u001cxN\u001d'pgN4UO\\2uS>t\u0007C\u0001\u00180\u0019\u0001!Q\u0001\r\u0012C\u0002E\u0012\u0011\u0001V\t\u0003eU\u0002\"aF\u001a\n\u0005QB\"a\u0002(pi\"Lgn\u001a\t\u0003/YJ!a\u000e\r\u0003\u0007\u0005s\u0017\u0010C\u0004:E\u0005\u0005\t9\u0001\u001e\u0002\u0015\u00154\u0018\u000eZ3oG\u0016$\u0013\u0007E\u0002<}5j\u0011\u0001\u0010\u0006\u0003{a\tqA]3gY\u0016\u001cG/\u0003\u0002@y\tA1\t\\1tgR\u000bw\rC\u0003BE\u0001\u000f!)\u0001\u0002fmB\u00191iV\u0017\u000f\u0005\u0011#fBA#S\u001d\t1\u0015K\u0004\u0002H!:\u0011\u0001j\u0014\b\u0003\u0013:s!AS'\u000e\u0003-S!\u0001\u0014\t\u0002\rq\u0012xn\u001c;?\u0013\u0005y\u0011BA\u0007\u000f\u0013\tYA\"\u0003\u0002\n\u0015%\u0011q\u0001C\u0005\u0003'\u001a\ta\u0001^3og>\u0014\u0018BA+W\u0003E!VM\\:pe:+X.\u001a:jG6\u000bG\u000f\u001b\u0006\u0003'\u001aI!\u0001W-\u0003\u001bQ+gn]8s\u001dVlWM]5d\u0015\t)f\u000bC\u0003\\E\u0001\u0007A,\u0001\u0005m_N\u001ch)\u001e8d!\u00159RlX0`\u0013\tq\u0006DA\u0005Gk:\u001cG/[8oeA\u0019!\u0003Y\u0017\n\u0005\u0005\u0014!\u0001\u0003,be&\f'\r\\3\t\u000b\r\u0014\u0003\u0019\u00013\u0002\u0015e\u0004
&/\u001a3TQ\u0006\u0004X\r\u0005\u0002fQ6\taM\u0003\u0002h\r\u0005)Q\u000f^5mg&\u0011\u0011N\u001a\u0002\u0006'\"\f\u0007/\u001a\u0005\bW\n\u0002\n\u00111\u0001e\u0003)IHK];f'\"\f\u0007/\u001a\u0005\b[\n\u0002\n\u00111\u0001o\u0003-\u0019\u0018N_3Bm\u0016\u0014\u0018mZ3\u0011\u0005]y\u0017B\u00019\u0019\u0005\u001d\u0011un\u001c7fC:DqA]\n\u0012\u0002\u0013\u00051/A\bbaBd\u0017\u0010\n3fM\u0006,H\u000e\u001e\u00134+\t!x0F\u0001vU\t!goK\u0001x!\tAX0D\u0001z\u0015\tQ80A\u0005v]\u000eDWmY6fI*\u0011A\u0010G\u0001\u000bC:tw\u000e^1uS>t\u0017B\u0001@z\u0005E)hn\u00195fG.,GMV1sS\u0006t7-\u001a\u0003\u0006aE\u0014\r!\r\u0005\n\u0003\u0007\u0019\u0012\u0013!C\u0001\u0003\u000b\tq\"\u00199qYf$C-\u001a4bk2$H\u0005N\u000b\u0005\u0003\u000f\tY!\u0006\u0002\u0002\n)\u0012aN\u001e\u0003\u0007a\u0005\u0005!\u0019A\u0019\t\u0013\u0005=1#!A\u0005\n\u0005E\u0011a\u0003:fC\u0012\u0014Vm]8mm\u0016$\"!a\u0005\u0011\t\u0005U\u0011qD\u0007\u0003\u0003/QA!!\u0007\u0002\u001c\u0005!A.\u00198h\u0015\t\ti\"\u0001\u0003kCZ\f\u0017\u0002BA\u0011\u0003/\u0011aa\u00142kK\u000e$hA\u0002\u000b\u0003\u0003\u0003\t)#\u0006\u0003\u0002(\u000552\u0003BA\u0012\u0003S\u0001B\u0001K\u0016\u0002,A\u0019a&!\f\u0005\rA\n\u0019C1\u00012\u0011%i\u00171\u0005B\u0001B\u0003%a\u000eC\u0006\u00024\u0005\r\"1!Q\u0001\f\u0005U\u0012AC3wS\u0012,gnY3%gA!1HPA\u0016\u0011)\t\u00151\u0005B\u0001B\u0003-\u0011\u0011\b\t\u0005\u0007^\u000bY\u0003C\u0004!\u0003G!\t!!\u0010\u0015\t\u0005}\u0012q\t\u000b\u0007\u0003\u0003\n\u0019%!\u0012\u0011\u000bI\t\u0019#a\u000b\t\u0011\u0005M\u00121\ba\u0002\u0003kAq!QA\u001e\u0001\b\tI\u0004\u0003\u0004n\u0003w\u0001\rA\u001c\u0005\t\u0003\u0017\n\u0019C\"\u0005\u0002N\u0005IAm\\$fi2{7o\u001d\u000b\u0005\u0003\u001f\n)\u0007\u0005\u0006\u0002R\u0005m\u0013qLA0\u0003Wi!!a\u0015\u000b\t\u0005U\u0013qK\u0001\u000bC\n\u001cHO]1di:t'bAA-\r\u0005\u0011aN\\\u0005\u0005\u0003;\n\u0019F\u0001\bBEN$(/Y2u\u001b>$W\u000f\\3\u0011\t\u0005E\u0013\u0011M\u0005\u0005\u0003G\n\u0019F\u0001\u0005BGRLg/\u00
1b;z\u0011!\t9'!\u0013A\u0002\u0005%\u0014AB5oaV$8\u000fE\u0003\u0018\u0003W\ny'C\u0002\u0002na\u0011Q!\u0011:sCf\u0004BA\u00051\u0002,!A\u00111OA\u0012\r#\t)(\u0001\u0007hKRLe\u000e];u-\u0006\u00148\u000f\u0006\u0003\u0002j\u0005]\u0004\u0002CA=\u0003c\u0002\r!a\u001f\u0002\u0017%t\u0007/\u001e;TQ\u0006\u0004Xm\u001d\t\u0005/\u0005-D\r\u0003\u0005\u0002��\u0005\rBQAAA\u0003\u001d9W\r\u001e'pgN$B!a\u0014\u0002\u0004\"A\u0011\u0011PA?\u0001\u0004\tY\b\u0003\u0005\u0002\b\u0006\rBQAAE\u0003Q9WM\\3sCR,Gj\\:t\rJ|WNV1sgR1\u00111RAJ\u0003/\u0003b!!$\u0002\u0010\u0006-R\"\u0001\u0003\n\u0007\u0005EEAA\u0003N_\u0012,G\u000e\u0003\u0005\u0002\u0016\u0006\u0015\u0005\u0019AA5\u0003\u0019IgNV1sg\"A\u0011\u0011TAC\u0001\u0004\ty'\u0001\u0004pkR4\u0016M\u001d\u0005\t\u0003;\u000b\u0019\u0003\"\u0003\u0002 \u0006)B/\u001a8t_J$vNT8o\u0005\u0006$8\r[*iCB,Gc\u00013\u0002\"\"91+a'A\u0002\u0005\r\u0006CBAS\u0003O\u000bY#D\u0001W\u0013\r\tIK\u0016\u0002\u0007)\u0016t7o\u001c:\t\u0011\u00055\u00161\u0005C!\u0003_\u000bA\"\u001e9eCR,w*\u001e;qkR$b!a\u000b\u00022\u0006U\u0006\u0002CAZ\u0003W\u0003\r!a)\u0002\u000be\u0004&/\u001a3\t\u0011\u0005]\u00161\u0016a\u0001\u0003G\u000ba\u0001^1sO\u0016$\b\u0002CA^\u0003G!\t%!0\u0002\u001fU\u0004H-\u0019;f\u000fJ\fG-\u00138qkR$b!a)\u0002@\u0006\u0005\u0007\u0002CAZ\u0003s\u0003\r!a)\t\u0011\u0005\r\u0017\u0011\u0018a\u0001\u0003G\u000bQ!\u001f+sk\u0016\u0004")
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/keras/autograd/CustomLoss.class */
/**
 * Base class for loss functions whose computation is expressed as an autograd graph.
 *
 * <p>Subclasses supply {@code getInputVars(Shape[])}, which creates the input
 * {@code Variable}s for a set of shapes. They also supply {@code doGetLoss(Variable[])}, which
 * builds the loss module from those variables. {@code getLoss(Shape[])} composes the two. The
 * criterion methods below rebuild the loss module through {@code getLoss} on every call.
 *
 * <p>NOTE(review): this file is decompiled from Scala bytecode ({@code CustomLoss.scala}).
 * Names such as {@code mo1037updateOutput}, {@code mo1973value}, {@code mo2057one} and
 * {@code updateGradInput2} are decompiler renames. Prefer editing the Scala source.
 *
 * @param <T> numeric element type of the tensors handled by this loss
 */
public abstract class CustomLoss<T> extends TensorLossFunction<T> {
    /**
     * When true, {@code generateLossFromVars} wraps the output variable in {@code AutoGrad.mean}
     * before building the graph; when false the output variable is used as-is.
     */
    private final boolean sizeAverage;
    /** {@code ClassTag} for {@code T}, forwarded to {@code AutoGrad} and {@code Tensor} factory calls. */
    private final ClassTag<T> evidence$3;
    /** Numeric operations for {@code T}, used for tensor conversion and autograd calls. */
    private final TensorNumericMath.TensorNumeric<T> ev;

    /**
     * Static forwarder to the Scala companion object's {@code apply} ({@code CustomLoss$.MODULE$}).
     *
     * <p>NOTE(review): the argument meanings are not visible in this file. Judging by the types:
     * {@code function2} maps two input variables to a loss variable, the two {@code Shape}s
     * describe those inputs, and {@code z} is a size-average flag. Confirm against
     * {@code CustomLoss$.apply}.
     *
     * @param function2 function from two input variables to the loss variable
     * @param shape first input shape (passed through unchanged)
     * @param shape2 second input shape (passed through unchanged)
     * @param z flag passed through unchanged (presumably sizeAverage; TODO confirm)
     * @param classTag {@code ClassTag} for {@code T}
     * @param tensorNumeric numeric operations for {@code T}
     * @return the loss function created by the companion object
     */
    public static <T> TensorLossFunction<T> apply(Function2<Variable<T>, Variable<T>, Variable<T>> function2, Shape shape, Shape shape2, boolean z, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        return CustomLoss$.MODULE$.apply(function2, shape, shape2, z, classTag, tensorNumeric);
    }

    /**
     * Builds the loss module from the input variables produced by {@code getInputVars}.
     *
     * @param variableArr the graph input variables
     * @return the loss module
     */
    public abstract AbstractModule<Activity, Activity, T> doGetLoss(Variable<T>[] variableArr);

    /**
     * Creates the graph input variables for the given shapes.
     *
     * @param shapeArr one shape per graph input (the callers in this class pass two identical
     *     non-batch shapes)
     * @return the input variables, passed on to {@code doGetLoss}
     */
    public abstract Variable<T>[] getInputVars(Shape[] shapeArr);

    /**
     * Builds the loss module for the given input shapes: {@code doGetLoss(getInputVars(shapeArr))}.
     *
     * @param shapeArr one shape per graph input
     * @return a newly built loss module
     */
    public final AbstractModule<Activity, Activity, T> getLoss(Shape[] shapeArr) {
        return doGetLoss(getInputVars(shapeArr));
    }

    /**
     * Turns the input variables and the loss variable into a graph {@code Model}.
     *
     * <p>When {@code sizeAverage} is set, the loss variable is first passed through
     * {@code AutoGrad.mean} with {@code 0} as its second argument. That argument is presumably
     * the axis, i.e. averaging over the batch dimension; TODO confirm.
     *
     * @param variableArr the graph input variables
     * @param variable the variable holding the (per-sample) loss
     * @return a model from {@code variableArr} to the (optionally averaged) loss
     */
    public final Model<T> generateLossFromVars(Variable<T>[] variableArr, Variable<T> variable) {
        return this.sizeAverage ? AutoGrad$.MODULE$.mean(variable, 0, AutoGrad$.MODULE$.mean$default$3(), this.evidence$3, this.ev).toGraph(variableArr) : variable.toGraph(variableArr);
    }

    /**
     * Returns the shape of {@code tensor} with its first dimension dropped (presumably the batch
     * dimension).
     *
     * @param tensor the tensor whose size is inspected
     * @return a {@code Shape} built from {@code tensor.size()[1..]}
     */
    private Shape tensorToNonBatchShape(Tensor<T> tensor) {
        int[] size = tensor.size();
        return Shape$.MODULE$.apply((int[]) Predef$.MODULE$.intArrayOps(size).slice(1, size.length));
    }

    /**
     * Computes the scalar loss for one batch.
     *
     * <p>Both arguments are converted with {@code toTensor}. The non-batch shape of
     * {@code tensor} is used for both graph inputs. A fresh loss module is built via
     * {@code getLoss} and is fed the table {@code T(tensor2, tensor)}, i.e. the second argument
     * comes first.
     *
     * <p>If the module's output is not a scalar, {@code Log4Error.invalidOperationError} is
     * invoked with the offending shape in the message.
     *
     * @param tensor first criterion input (presumably the prediction; TODO confirm)
     * @param tensor2 second criterion input (presumably the target; TODO confirm)
     * @return the single value of the loss module's scalar output
     */
    @Override // com.intel.analytics.bigdl.dllib.keras.objectives.LossFunction, com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractCriterion
    /* renamed from: updateOutput, reason: merged with bridge method [inline-methods] */
    public T mo1037updateOutput(Tensor<T> tensor, Tensor<T> tensor2) {
        Object tensor3 = tensor.toTensor(this.ev);
        Object tensor4 = tensor2.toTensor(this.ev);
        Shape tensorToNonBatchShape = tensorToNonBatchShape(tensor);
        // Graph inputs are ordered (tensor2, tensor). NOTE(review): the loss graph is rebuilt on
        // every call via getLoss; consider caching it if this is on a hot path.
        Tensor tensor5 = getLoss(new Shape[]{tensorToNonBatchShape, tensorToNonBatchShape}).forward(T$.MODULE$.apply(tensor4, (Seq<Object>) Predef$.MODULE$.genericWrapArray(new Object[]{tensor3}))).toTensor(this.ev);
        // The loss graph must reduce to a scalar; report the actual shape otherwise.
        Log4Error$.MODULE$.invalidOperationError(tensor5.isScalar(), new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"The loss should be scalar, but got result with shape: [", "]"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{Predef$.MODULE$.intArrayOps(tensor5.size()).mkString(", ")})), Log4Error$.MODULE$.invalidOperationError$default$3(), Log4Error$.MODULE$.invalidOperationError$default$4());
        return (T) tensor5.mo1973value();
    }

    /**
     * Computes the gradient of the loss with respect to {@code tensor}.
     *
     * <p>The method builds a fresh loss module via {@code getLoss}, using the non-batch shape of
     * {@code tensor} for both inputs. It calls the module's {@code updateGradInput} with input
     * table {@code T(tensor2, tensor)` and a 1-element gradOutput tensor filled with
     * {@code ev.one}. It then returns entry {@code 2} of the resulting table.
     *
     * <p>NOTE(review): this calls {@code updateGradInput} on a newly built module that has not
     * run {@code forward} on this instance. Confirm that the graph implementation does not rely
     * on cached forward outputs.
     *
     * <p>NOTE(review): unlike {@code mo1037updateOutput}, the inputs are not passed through
     * {@code toTensor} here, and the result is only returned, not stored. Confirm whether the
     * base class expects a {@code gradInput} field to be set.
     *
     * @param tensor first criterion input (presumably the prediction; TODO confirm)
     * @param tensor2 second criterion input (presumably the target; TODO confirm)
     * @return entry 2 of the module's gradInput table (presumably the gradient for
     *     {@code tensor}, assuming 1-based table indexing)
     */
    @Override // com.intel.analytics.bigdl.dllib.keras.objectives.LossFunction, com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractCriterion
    public Tensor<T> updateGradInput(Tensor<T> tensor, Tensor<T> tensor2) {
        Shape tensorToNonBatchShape = tensorToNonBatchShape(tensor);
        // gradOutput is a 1-element tensor of ones; entry 2 of the table matches the graph's
        // second input, which is `tensor`.
        return (Tensor) getLoss(new Shape[]{tensorToNonBatchShape, tensorToNonBatchShape}).updateGradInput2(T$.MODULE$.apply(tensor2, (Seq<Object>) Predef$.MODULE$.genericWrapArray(new Object[]{tensor})), Tensor$.MODULE$.apply(1, this.evidence$3, this.ev).fill(this.ev.mo2057one())).toTable().get(BoxesRunTime.boxToInteger(2)).get();
    }

    /**
     * Creates the loss and records its configuration.
     *
     * @param z whether {@code generateLossFromVars} averages the loss via {@code AutoGrad.mean}
     * @param classTag {@code ClassTag} for {@code T}
     * @param tensorNumeric numeric operations for {@code T}
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public CustomLoss(boolean z, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        super(classTag, tensorNumeric);
        this.sizeAverage = z;
        this.evidence$3 = classTag;
        this.ev = tensorNumeric;
    }
}
