package com.intel.analytics.bigdl.dllib.optim;

import com.intel.analytics.bigdl.dllib.optim.OptimMethod$mcD$sp;
import com.intel.analytics.bigdl.dllib.tensor.ConvertableFrom$ConvertableFromDouble$;
import com.intel.analytics.bigdl.dllib.tensor.ConvertableFrom$ConvertableFromInt$;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.Table;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.reflect.ClassTag;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: Adagrad.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/optim/Adagrad$mcD$sp.class */
/**
 * Decompiled subclass of {@link Adagrad} that implements {@code OptimMethod$mcD$sp}.
 *
 * <p>NOTE(review): the {@code $mcD$sp} suffix presumably marks the Scala-compiler-generated
 * variant specialized for primitive {@code double}; confirm against Adagrad.scala.
 *
 * <p>Most methods forward to static helpers in {@code OptimMethod$mcD$sp.Cclass}; the actual
 * update step lives in {@link #optimize$mcD$sp(Function1, Tensor)}.
 */
public class Adagrad$mcD$sp extends Adagrad<Object> implements OptimMethod$mcD$sp {
    // Numeric helper used below to convert scalar values via fromType$mcD$sp.
    public final TensorNumericMath.TensorNumeric<Object> ev$mcD$sp;
    // Class tag received by the constructor. It is stored here but not read anywhere in the
    // visible class body; the methods below use the mangled
    // com$intel$analytics$bigdl$dllib$optim$Adagrad$$evidence$1 member instead.
    private final ClassTag<Object> evidence$1;

    /**
     * Clones this optimizer by delegating to {@code OptimMethod$mcD$sp.Cclass.clone}.
     *
     * @return the value produced by the delegate
     */
    @Override // com.intel.analytics.bigdl.dllib.optim.Adagrad
    /* renamed from: clone */
    public OptimMethod<Object> mo1605clone() {
        return OptimMethod$mcD$sp.Cclass.clone(this);
    }

    /**
     * Specialized-name counterpart of the clone method; delegates to
     * {@code OptimMethod$mcD$sp.Cclass.clone$mcD$sp}.
     *
     * @return the value produced by the delegate
     */
    @Override // com.intel.analytics.bigdl.dllib.optim.Adagrad, com.intel.analytics.bigdl.dllib.optim.OptimMethod
    public OptimMethod<Object> clone$mcD$sp() {
        return OptimMethod$mcD$sp.Cclass.clone$mcD$sp(this);
    }

    /**
     * Four-argument optimize entry point; forwards all arguments unchanged to
     * {@code OptimMethod$mcD$sp.Cclass.optimize}.
     *
     * @param function1 function applied to a tensor, yielding a (value, tensor) pair
     * @param tensor    tensor handed to the delegate
     * @param table     first table handed to the delegate (meaning not visible here)
     * @param table2    second table handed to the delegate (meaning not visible here)
     * @return the pair produced by the delegate
     */
    @Override // com.intel.analytics.bigdl.dllib.optim.Adagrad, com.intel.analytics.bigdl.dllib.optim.OptimMethod
    public Tuple2<Tensor<Object>, Object> optimize(Function1<Tensor<Object>, Tuple2<Object, Tensor<Object>>> function1, Tensor<Object> tensor, Table table, Table table2) {
        return OptimMethod$mcD$sp.Cclass.optimize(this, function1, tensor, table, table2);
    }

    /**
     * Specialized-name counterpart of the four-argument optimize; forwards all arguments
     * unchanged to {@code OptimMethod$mcD$sp.Cclass.optimize$mcD$sp}.
     *
     * @param function1 function applied to a tensor, yielding a (value, tensor) pair
     * @param tensor    tensor handed to the delegate
     * @param table     first table handed to the delegate (meaning not visible here)
     * @param table2    second table handed to the delegate (meaning not visible here)
     * @return the pair produced by the delegate
     */
    @Override // com.intel.analytics.bigdl.dllib.optim.Adagrad, com.intel.analytics.bigdl.dllib.optim.OptimMethod
    public Tuple2<Tensor<Object>, double[]> optimize$mcD$sp(Function1<Tensor<Object>, Tuple2<Object, Tensor<Object>>> function1, Tensor<Object> tensor, Table table, Table table2) {
        return OptimMethod$mcD$sp.Cclass.optimize$mcD$sp(this, function1, tensor, table, table2);
    }

    /**
     * Two-argument optimize entry point; simply forwards to
     * {@link #optimize$mcD$sp(Function1, Tensor)}.
     *
     * @param function1 function applied to {@code tensor}, yielding a (value, tensor) pair
     * @param tensor    tensor that is updated and returned
     * @return the pair returned by {@code optimize$mcD$sp}
     */
    @Override // com.intel.analytics.bigdl.dllib.optim.Adagrad, com.intel.analytics.bigdl.dllib.optim.OptimMethod
    public Tuple2<Tensor<Object>, Object> optimize(Function1<Tensor<Object>, Tuple2<Object, Tensor<Object>>> function1, Tensor<Object> tensor) {
        return optimize$mcD$sp(function1, tensor);
    }

    /**
     * Performs one update step on {@code tensor} and records bookkeeping in {@code state()}.
     *
     * <p>Steps, in order: read the hyper-parameters and the "evalCounter" state entry; apply
     * {@code function1} to {@code tensor}; optionally apply weight decay to the returned tensor;
     * compute the decayed step size; fetch or create the "paramVariance"/"paramStd" tensors;
     * update them and {@code tensor}; finally write "evalCounter", "paramVariance" and
     * "paramStd" back into {@code state()}.
     *
     * @param function1 function applied to {@code tensor}; its result is unpacked as a
     *                  (double, Tensor) pair
     * @param tensor    tensor that is modified through {@code addcdiv} and returned
     * @return a pair of {@code tensor} and an array built from the single double that
     *         {@code function1} returned as its first component
     * @throws MatchError if {@code function1} returns {@code null}
     */
    @Override // com.intel.analytics.bigdl.dllib.optim.Adagrad, com.intel.analytics.bigdl.dllib.optim.OptimMethod
    public Tuple2<Tensor<Object>, double[]> optimize$mcD$sp(Function1<Tensor<Object>, Tuple2<Object, Tensor<Object>>> function1, Tensor<Object> tensor) {
        double learningRate = learningRate();
        double learningRateDecay = learningRateDecay();
        // "evalCounter" is read from state(); when absent, the default comes from the anonymous
        // function Adagrad$mcD$sp$$anonfun$2, whose value is not visible in this file.
        int unboxToInt = BoxesRunTime.unboxToInt(state().get("evalCounter").getOrElse(new Adagrad$mcD$sp$$anonfun$2(this)));
        double weightDecay = weightDecay();
        Tuple2 tuple2 = (Tuple2) function1.apply(tensor);
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        // Re-pack and unpack the result: _1$mcD$sp is the double component, tensor2 the tensor
        // component (presumably loss value and gradient — TODO confirm against Adagrad.scala).
        Tuple2 tuple22 = new Tuple2(BoxesRunTime.boxToDouble(tuple2._1$mcD$sp()), (Tensor) tuple2._2());
        double _1$mcD$sp = tuple22._1$mcD$sp();
        Tensor<?> tensor2 = (Tensor) tuple22._2();
        if (weightDecay != 0) {
            // Calls tensor2.add(<converted weightDecay>, tensor); presumably adds
            // weightDecay * tensor into tensor2 — TODO confirm Tensor.add semantics.
            // NOTE(review): casting the boxed double to Tensor looks like a decompiler artifact.
            tensor2.add((Tensor<?>) BoxesRunTime.boxToDouble(this.ev$mcD$sp.fromType$mcD$sp(BoxesRunTime.boxToDouble(weightDecay), ConvertableFrom$ConvertableFromDouble$.MODULE$)), (Tensor<Tensor<?>>) tensor);
        } else {
            // No-op branch: the assigned value is never used afterwards.
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        }
        // Step size shrinks as the evaluation counter grows.
        double d = learningRate / (1 + (unboxToInt * learningRateDecay));
        // Reuse the tensors stored in state() when "paramVariance" is defined (in that case
        // "paramStd" is fetched with get() without its own isDefined check); otherwise create two
        // tensors resized like tensor2, of which only the first is zeroed. The second one is
        // overwritten by the copy(...) call below.
        Tuple2 tuple23 = state().get("paramVariance").isDefined() ? new Tuple2(state().get("paramVariance").get(), state().get("paramStd").get()) : new Tuple2(Tensor$.MODULE$.apply$mDc$sp(this.com$intel$analytics$bigdl$dllib$optim$Adagrad$$evidence$1, this.ev$mcD$sp).resizeAs(tensor2).zero(), Tensor$.MODULE$.apply$mDc$sp(this.com$intel$analytics$bigdl$dllib$optim$Adagrad$$evidence$1, this.ev$mcD$sp).resizeAs(tensor2));
        if (tuple23 == null) {
            throw new MatchError(tuple23);
        }
        Tuple2 tuple24 = new Tuple2((Tensor) tuple23._1(), (Tensor) tuple23._2());
        // tensor3 is the "paramVariance" tensor, tensor4 the "paramStd" tensor (see the state
        // updates at the end of this method).
        Tensor<?> tensor3 = (Tensor) tuple24._1();
        Tensor tensor4 = (Tensor) tuple24._2();
        // addcmul with scalar 1 and tensor2 twice; presumably accumulates the element-wise
        // square of tensor2 into tensor3 — TODO confirm Tensor.addcmul semantics.
        tensor3.addcmul(BoxesRunTime.boxToDouble(this.ev$mcD$sp.fromType$mcD$sp(BoxesRunTime.boxToInteger(1), ConvertableFrom$ConvertableFromInt$.MODULE$)), tensor2, tensor2);
        // tensor4 is resized like tensor3, receives a copy of it, then sqrt() is applied.
        tensor4.resizeAs(tensor3).copy(tensor3).sqrt();
        // addcdiv with scalar -d, tensor2 and (tensor4 + 1.0E-10); presumably
        // tensor += -d * tensor2 / (tensor4 + 1.0E-10), with the small constant guarding against
        // division by zero — TODO confirm Tensor.addcdiv semantics.
        // NOTE(review): the cast of the boxed double to Tensor looks like a decompiler artifact.
        tensor.addcdiv(BoxesRunTime.boxToDouble(this.ev$mcD$sp.fromType$mcD$sp(BoxesRunTime.boxToDouble(-d), ConvertableFrom$ConvertableFromDouble$.MODULE$)), tensor2, tensor4.add((Tensor) BoxesRunTime.boxToDouble(this.ev$mcD$sp.fromType$mcD$sp(BoxesRunTime.boxToDouble(1.0E-10d), ConvertableFrom$ConvertableFromDouble$.MODULE$))));
        // Persist the incremented counter and both tensors for the next call.
        state().update("evalCounter", BoxesRunTime.boxToInteger(unboxToInt + 1));
        state().update("paramVariance", tensor3);
        state().update("paramStd", tensor4);
        // The second component is a one-element array holding the double returned by function1.
        return new Tuple2<>(tensor, Array$.MODULE$.apply(Predef$.MODULE$.genericWrapArray(new double[]{_1$mcD$sp}), this.com$intel$analytics$bigdl$dllib$optim$Adagrad$$evidence$1));
    }

    /**
     * Creates the optimizer: passes every argument to the {@link Adagrad} constructor, stores
     * the numeric helper and class tag locally, then runs
     * {@code OptimMethod$mcD$sp.Cclass.$init$}.
     *
     * <p>NOTE(review): {@code d}, {@code d2}, {@code d3} are presumably the learning rate,
     * learning-rate decay and weight decay read back in {@code optimize$mcD$sp}; their order is
     * not visible here — confirm against the {@code Adagrad} constructor.
     *
     * @param d             first double forwarded to the superclass constructor
     * @param d2            second double forwarded to the superclass constructor
     * @param d3            third double forwarded to the superclass constructor
     * @param classTag      class tag forwarded to the superclass and stored in {@code evidence$1}
     * @param tensorNumeric numeric helper forwarded to the superclass and stored in
     *                      {@code ev$mcD$sp}
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public Adagrad$mcD$sp(double d, double d2, double d3, ClassTag<Object> classTag, TensorNumericMath.TensorNumeric<Object> tensorNumeric) {
        super(d, d2, d3, classTag, tensorNumeric);
        this.ev$mcD$sp = tensorNumeric;
        this.evidence$1 = classTag;
        OptimMethod$mcD$sp.Cclass.$init$(this);
    }
}
