package com.intel.analytics.bigdl.dllib.optim;

import com.intel.analytics.bigdl.dllib.optim.OptimMethod$mcF$sp;
import com.intel.analytics.bigdl.dllib.tensor.ConvertableFrom$ConvertableFromDouble$;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.Table;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple3;
import scala.math.package$;
import scala.reflect.ClassTag;
import scala.runtime.BoxesRunTime;

/* compiled from: Adam.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/optim/Adam$mcF$sp.class */
/**
 * Float-element specialization of the {@code Adam} optimization method.
 *
 * <p>This appears to be the scalac-generated {@code @specialized} variant for {@code Float}.
 * The evidence is the {@code $mcF$sp} suffix, the {@code boxToFloat}/{@code unboxToFloat}
 * calls throughout, and the {@code float[]} result array. Most {@code OptimMethod} plumbing
 * delegates to {@code OptimMethod$mcF$sp.Cclass}. The actual Adam step is implemented inline
 * in {@link #optimize$mcF$sp(Function1, Tensor)}.
 *
 * <p>NOTE(review): this is decompiled output, not hand-written source. Several constructs are
 * not valid Java and should not be edited as if they were:
 * <ul>
 *   <li>the {@code (Tensor<Object>) BoxesRunTime.boxToFloat(...)} casts;</li>
 *   <li>references to {@code this.com$intel$analytics$bigdl$dllib$optim$Adam$$evidence$1},
 *       a name this class never declares;</li>
 *   <li>the relocated {@code super} call in the constructor.</li>
 * </ul>
 * The authoritative source is {@code Adam.scala}.
 */
public class Adam$mcF$sp extends Adam<Object> implements OptimMethod$mcF$sp {
    /** Numeric type-class instance used for float conversions ({@code fromType$mcF$sp}, {@code one$mcF$sp}). */
    public final TensorNumericMath.TensorNumeric<Object> ev$mcF$sp;
    /**
     * ClassTag passed to the constructor.
     * NOTE(review): the method bodies read {@code com$intel$analytics$bigdl$dllib$optim$Adam$$evidence$1}
     * (presumably the parent's copy) rather than this field. Confirm they are the same object.
     */
    private final ClassTag<Object> evidence$1;

    /**
     * Returns a clone of this optimizer.
     *
     * <p>Delegates to {@code OptimMethod$mcF$sp.Cclass.clone}. The {@code mo1609} prefix is a
     * decompiler rename of {@code clone} (see the marker comment below).
     *
     * @return the cloned optimization method
     */
    @Override // com.intel.analytics.bigdl.dllib.optim.Adam
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public OptimMethod<Object> mo1609clone() {
        return OptimMethod$mcF$sp.Cclass.clone(this);
    }

    /**
     * Specialized clone entry point. Delegates to {@code OptimMethod$mcF$sp.Cclass.clone$mcF$sp}.
     *
     * @return the cloned optimization method
     */
    @Override // com.intel.analytics.bigdl.dllib.optim.Adam, com.intel.analytics.bigdl.dllib.optim.OptimMethod
    public OptimMethod<Object> clone$mcF$sp() {
        return OptimMethod$mcF$sp.Cclass.clone$mcF$sp(this);
    }

    /**
     * Four-argument {@code optimize} overload. Delegates to {@code OptimMethod$mcF$sp.Cclass.optimize}.
     *
     * <p>NOTE(review): the role of {@code table} / {@code table2} is not visible here. They are
     * presumably config and state tables; confirm against {@code OptimMethod}.
     *
     * @param function1 evaluates the parameters and yields a (loss, gradient) pair
     * @param tensor    the parameter tensor
     * @param table     first Table argument (see note above)
     * @param table2    second Table argument (see note above)
     * @return the (parameters, losses) pair produced by the delegate
     */
    @Override // com.intel.analytics.bigdl.dllib.optim.Adam, com.intel.analytics.bigdl.dllib.optim.OptimMethod
    public Tuple2<Tensor<Object>, Object> optimize(Function1<Tensor<Object>, Tuple2<Object, Tensor<Object>>> function1, Tensor<Object> tensor, Table table, Table table2) {
        return OptimMethod$mcF$sp.Cclass.optimize(this, function1, tensor, table, table2);
    }

    /**
     * Specialized four-argument overload. Delegates to {@code OptimMethod$mcF$sp.Cclass.optimize$mcF$sp}.
     *
     * @param function1 evaluates the parameters and yields a (loss, gradient) pair
     * @param tensor    the parameter tensor
     * @param table     first Table argument
     * @param table2    second Table argument
     * @return the (parameters, float losses) pair produced by the delegate
     */
    @Override // com.intel.analytics.bigdl.dllib.optim.Adam, com.intel.analytics.bigdl.dllib.optim.OptimMethod
    public Tuple2<Tensor<Object>, float[]> optimize$mcF$sp(Function1<Tensor<Object>, Tuple2<Object, Tensor<Object>>> function1, Tensor<Object> tensor, Table table, Table table2) {
        return OptimMethod$mcF$sp.Cclass.optimize$mcF$sp(this, function1, tensor, table, table2);
    }

    /**
     * Generic two-argument {@code optimize}. Forwards to the float-specialized
     * {@link #optimize$mcF$sp(Function1, Tensor)}, which performs the Adam step.
     *
     * @param function1 evaluates the parameters and yields a (loss, gradient) pair
     * @param tensor    the parameter tensor, updated in place
     * @return the (parameters, losses) pair
     */
    @Override // com.intel.analytics.bigdl.dllib.optim.Adam, com.intel.analytics.bigdl.dllib.optim.OptimMethod
    public Tuple2<Tensor<Object>, Object> optimize(Function1<Tensor<Object>, Tuple2<Object, Tensor<Object>>> function1, Tensor<Object> tensor) {
        return optimize$mcF$sp(function1, tensor);
    }

    /**
     * Performs one Adam update step on {@code tensor}, in place.
     *
     * <p>With {@code t} = the stored {@code "evalCounter"} (default 0) and {@code g} = the gradient
     * returned by {@code function1}, the step is:
     * <pre>
     *   clr   = learningRate / (1 + t * learningRateDecay)
     *   s     = beta1 * s + (1 - beta1) * g
     *   r     = beta2 * r + (1 - beta2) * g .* g
     *   denom = sqrt(r) + Epsilon
     *   step  = clr * sqrt(1 - beta2^(t+1)) / (1 - beta1^(t+1))
     *   x     = x - step * s ./ denom
     * </pre>
     * The pseudo-code assumes the BigDL conventions {@code Tensor.add(value, t)} == "+= value * t"
     * and {@code addcdiv(v, a, b)} == "+= v * a ./ b". NOTE(review): confirm against the
     * {@code Tensor} API.
     *
     * <p>Side effects: writes {@code "evalCounter"} (= t+1), {@code "s"}, {@code "r"} and
     * {@code "denom"} back into {@code state()}. Lazily allocates and then reuses the shared
     * scratch {@code buffer}.
     *
     * @param function1 evaluated once at {@code tensor}; must yield (loss, gradient)
     * @param tensor    the parameter tensor, updated in place
     * @return {@code tensor} paired with a one-element array holding the loss
     * @throws MatchError if {@code function1} returns {@code null}
     */
    @Override // com.intel.analytics.bigdl.dllib.optim.Adam, com.intel.analytics.bigdl.dllib.optim.OptimMethod
    public Tuple2<Tensor<Object>, float[]> optimize$mcF$sp(Function1<Tensor<Object>, Tuple2<Object, Tensor<Object>>> function1, Tensor<Object> tensor) {
        // Lazily create the scratch tensor reused below for g .* g and for the ones-fill.
        if (com$intel$analytics$bigdl$dllib$optim$Adam$$buffer() == null) {
            com$intel$analytics$bigdl$dllib$optim$Adam$$buffer_$eq(Tensor$.MODULE$.apply$mFc$sp(this.com$intel$analytics$bigdl$dllib$optim$Adam$$evidence$1, this.ev$mcF$sp));
        }
        // Snapshot the hyperparameters for this step.
        double learningRate = learningRate();
        double learningRateDecay = learningRateDecay();
        double beta1 = beta1();
        double beta2 = beta2();
        double Epsilon = Epsilon();
        // Evaluate the objective at the current parameters: (loss, gradient).
        Tuple2 tuple2 = (Tuple2) function1.apply(tensor);
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        // Decompiled Scala pattern match: unpack loss (float) and gradient (tensor2).
        Tuple2 tuple22 = new Tuple2(BoxesRunTime.boxToFloat(BoxesRunTime.unboxToFloat(tuple2._1())), (Tensor) tuple2._2());
        float unboxToFloat = BoxesRunTime.unboxToFloat(tuple22._1());
        Tensor<?> tensor2 = (Tensor) tuple22._2();
        // Number of steps taken so far; 0 on the first call.
        int unboxToInt = BoxesRunTime.unboxToInt(state().getOrElse("evalCounter", BoxesRunTime.boxToInteger(0)));
        // Reuse (s, r, denom) from state when present. Only "denom" is resized to the gradient's shape.
        // Otherwise allocate three zeroed tensors shaped like the gradient.
        Tuple3 tuple3 = state().get("s").isDefined() ? new Tuple3(state().get("s").get(), state().get("r").get(), ((Tensor) state().get("denom").get()).resizeAs(tensor2)) : new Tuple3(Tensor$.MODULE$.apply$mFc$sp(this.com$intel$analytics$bigdl$dllib$optim$Adam$$evidence$1, this.ev$mcF$sp).resizeAs(tensor2).zero(), Tensor$.MODULE$.apply$mFc$sp(this.com$intel$analytics$bigdl$dllib$optim$Adam$$evidence$1, this.ev$mcF$sp).resizeAs(tensor2).zero(), Tensor$.MODULE$.apply$mFc$sp(this.com$intel$analytics$bigdl$dllib$optim$Adam$$evidence$1, this.ev$mcF$sp).resizeAs(tensor2).zero());
        if (tuple3 == null) {
            throw new MatchError(tuple3);
        }
        // tensor3 = s (first moment), tensor4 = r (second moment), tensor5 = denom.
        Tuple3 tuple32 = new Tuple3((Tensor) tuple3._1(), (Tensor) tuple3._2(), (Tensor) tuple3._3());
        Tensor<Object> tensor3 = (Tensor) tuple32._1();
        Tensor<Object> tensor4 = (Tensor) tuple32._2();
        Tensor<Object> tensor5 = (Tensor) tuple32._3();
        // Decayed learning rate uses the counter *before* increment.
        double d = learningRate / (1 + (unboxToInt * learningRateDecay));
        // Bias correction below uses the counter *after* increment (1-based step index).
        int i = unboxToInt + 1;
        // s = beta1 * s + (1 - beta1) * g. The (Tensor) casts on boxed floats are decompiler artifacts.
        tensor3.mul(BoxesRunTime.boxToFloat(this.ev$mcF$sp.fromType$mcF$sp(BoxesRunTime.boxToDouble(beta1), ConvertableFrom$ConvertableFromDouble$.MODULE$))).add((Tensor<Object>) BoxesRunTime.boxToFloat(this.ev$mcF$sp.fromType$mcF$sp(BoxesRunTime.boxToDouble(1 - beta1), ConvertableFrom$ConvertableFromDouble$.MODULE$)), (Tensor<Tensor<Object>>) tensor2);
        // buffer = g .* g (elementwise square of the gradient).
        com$intel$analytics$bigdl$dllib$optim$Adam$$buffer().resizeAs(tensor2).cmul(tensor2, tensor2);
        // r = beta2 * r + (1 - beta2) * g .* g
        tensor4.mul(BoxesRunTime.boxToFloat(this.ev$mcF$sp.fromType$mcF$sp(BoxesRunTime.boxToDouble(beta2), ConvertableFrom$ConvertableFromDouble$.MODULE$))).add((Tensor<Object>) BoxesRunTime.boxToFloat(this.ev$mcF$sp.fromType$mcF$sp(BoxesRunTime.boxToDouble(1 - beta2), ConvertableFrom$ConvertableFromDouble$.MODULE$)), (Tensor<Tensor<Object>>) com$intel$analytics$bigdl$dllib$optim$Adam$$buffer());
        // denom = sqrt(r)
        tensor5.sqrt(tensor4);
        // Reuse the buffer as a tensor of ones so Epsilon can be added elementwise via add(value, tensor).
        com$intel$analytics$bigdl$dllib$optim$Adam$$buffer().fill(BoxesRunTime.boxToFloat(this.ev$mcF$sp.one$mcF$sp()));
        // denom = sqrt(r) + Epsilon
        tensor5.add((Tensor<Object>) BoxesRunTime.boxToFloat(this.ev$mcF$sp.fromType$mcF$sp(BoxesRunTime.boxToDouble(Epsilon), ConvertableFrom$ConvertableFromDouble$.MODULE$)), (Tensor<Tensor<Object>>) com$intel$analytics$bigdl$dllib$optim$Adam$$buffer());
        // x += -(clr * sqrt(1 - beta2^i) / (1 - beta1^i)) * s ./ denom
        // The bias correction is folded into the scalar step size.
        tensor.addcdiv(BoxesRunTime.boxToFloat(this.ev$mcF$sp.fromType$mcF$sp(BoxesRunTime.boxToDouble(-((d * package$.MODULE$.sqrt(1 - package$.MODULE$.pow(beta2, i))) / (1 - package$.MODULE$.pow(beta1, i)))), ConvertableFrom$ConvertableFromDouble$.MODULE$)), tensor3, tensor5);
        // Persist the step counter and moment tensors for the next call.
        state().update("evalCounter", BoxesRunTime.boxToInteger(i));
        state().update("s", tensor3);
        state().update("r", tensor4);
        state().update("denom", tensor5);
        // Return the (in-place updated) parameters and a single-element loss array.
        return new Tuple2<>(tensor, Array$.MODULE$.apply(Predef$.MODULE$.genericWrapArray(new float[]{unboxToFloat}), this.com$intel$analytics$bigdl$dllib$optim$Adam$$evidence$1));
    }

    /**
     * Creates the Float-specialized Adam optimizer.
     *
     * <p>The five doubles are forwarded to the parent {@code Adam} constructor in order.
     * NOTE(review): they are presumably learningRate, learningRateDecay, beta1, beta2 and Epsilon
     * (the accessors read in {@code optimize$mcF$sp}). Confirm the order against {@code Adam.scala}.
     *
     * @param d            first hyperparameter (see note)
     * @param d2           second hyperparameter (see note)
     * @param d3           third hyperparameter (see note)
     * @param d4           fourth hyperparameter (see note)
     * @param d5           fifth hyperparameter (see note)
     * @param classTag     element ClassTag, stored in {@code evidence$1}
     * @param tensorNumeric numeric type class, stored in {@code ev$mcF$sp}
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public Adam$mcF$sp(double d, double d2, double d3, double d4, double d5, ClassTag<Object> classTag, TensorNumericMath.TensorNumeric<Object> tensorNumeric) {
        super(d, d2, d3, d4, d5, classTag, tensorNumeric);
        this.ev$mcF$sp = tensorNumeric;
        this.evidence$1 = classTag;
        // Run the specialized OptimMethod trait's initializer.
        OptimMethod$mcF$sp.Cclass.$init$(this);
    }
}
