package com.intel.analytics.bigdl.dllib.optim;

import com.intel.analytics.bigdl.dllib.optim.OptimMethod$mcF$sp;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.Engine$;
import com.intel.analytics.bigdl.dllib.utils.Table;
import com.intel.analytics.bigdl.dllib.utils.ThreadPool;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.immutable.IndexedSeq;
import scala.collection.immutable.IndexedSeq$;
import scala.reflect.ClassTag;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.RichInt$;

/* compiled from: ParallelAdam.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/optim/ParallelAdam$mcF$sp.class */
/**
 * Float-specialized variant of {@code ParallelAdam}, as emitted by the Scala compiler
 * (the {@code $mcF$sp} suffix is Scala's naming for a {@code Float} specialization) and
 * recovered here by a decompiler.
 *
 * <p>NOTE(review): this file is decompiler output, not hand-written source. Names such as
 * {@code mo1766clone} and {@code m2082default} are decompiler renames, and the actual
 * per-element Adam update lives in synthetic closure classes
 * ({@code ParallelAdam$mcF$sp$$anonfun$3}, {@code ParallelAdam$mcF$sp$$anonfun$optimize$mcF$sp$1})
 * that are not visible here. Changes should be made in {@code ParallelAdam.scala}, not here.
 */
public class ParallelAdam$mcF$sp extends ParallelAdam<Object> implements OptimMethod$mcF$sp {
    /** Numeric type-class instance passed to the constructor; used to build/fill tensors. */
    public final TensorNumericMath.TensorNumeric<Object> ev$mcF$sp;
    /**
     * ClassTag passed to the constructor. NOTE(review): not read anywhere in the visible code;
     * {@code optimize$mcF$sp} uses the parent's {@code com$intel$...$ParallelAdam$$evidence$1} instead.
     */
    private final ClassTag<Object> evidence$1;

    /**
     * Clones this optimizer by delegating to the specialized {@code OptimMethod} trait
     * implementation ({@code OptimMethod$mcF$sp.Cclass.clone}).
     */
    @Override // com.intel.analytics.bigdl.dllib.optim.ParallelAdam
    /* renamed from: clone */
    public OptimMethod<Object> mo1766clone() {
        return OptimMethod$mcF$sp.Cclass.clone(this);
    }

    /** Specialized clone entry point; delegates to {@code OptimMethod$mcF$sp.Cclass.clone$mcF$sp}. */
    @Override // com.intel.analytics.bigdl.dllib.optim.ParallelAdam, com.intel.analytics.bigdl.dllib.optim.OptimMethod
    public OptimMethod<Object> clone$mcF$sp() {
        return OptimMethod$mcF$sp.Cclass.clone$mcF$sp(this);
    }

    /**
     * Four-argument optimize overload; delegates entirely to the trait implementation
     * {@code OptimMethod$mcF$sp.Cclass.optimize}. The meaning of the two {@link Table}
     * arguments is defined there (presumably config/state tables — confirm).
     */
    @Override // com.intel.analytics.bigdl.dllib.optim.ParallelAdam, com.intel.analytics.bigdl.dllib.optim.OptimMethod
    public Tuple2<Tensor<Object>, Object> optimize(Function1<Tensor<Object>, Tuple2<Object, Tensor<Object>>> function1, Tensor<Object> tensor, Table table, Table table2) {
        return OptimMethod$mcF$sp.Cclass.optimize(this, function1, tensor, table, table2);
    }

    /** Specialized four-argument optimize; delegates to {@code OptimMethod$mcF$sp.Cclass.optimize$mcF$sp}. */
    @Override // com.intel.analytics.bigdl.dllib.optim.ParallelAdam, com.intel.analytics.bigdl.dllib.optim.OptimMethod
    public Tuple2<Tensor<Object>, float[]> optimize$mcF$sp(Function1<Tensor<Object>, Tuple2<Object, Tensor<Object>>> function1, Tensor<Object> tensor, Table table, Table table2) {
        return OptimMethod$mcF$sp.Cclass.optimize$mcF$sp(this, function1, tensor, table, table2);
    }

    /** Generic two-argument optimize; simply forwards to the Float-specialized implementation below. */
    @Override // com.intel.analytics.bigdl.dllib.optim.ParallelAdam, com.intel.analytics.bigdl.dllib.optim.OptimMethod
    public Tuple2<Tensor<Object>, Object> optimize(Function1<Tensor<Object>, Tuple2<Object, Tensor<Object>>> function1, Tensor<Object> tensor) {
        return optimize$mcF$sp(function1, tensor);
    }

    /**
     * Performs one parallel Adam step on {@code tensor}, updating it in place.
     *
     * <p>Steps visible here: evaluate {@code function1} on the parameters to get
     * {@code (loss, gradient)}, compute a decayed learning rate from the {@code "evalCounter"}
     * entry of {@code state()}, split the parameter vector into {@code parallelNum()} chunks,
     * run one task per chunk on {@code Engine.default} and wait for all of them, then store the
     * incremented counter back into {@code state()}.
     *
     * @param function1 evaluation closure mapping the parameter tensor to a (loss, gradient) pair
     * @param tensor    parameter tensor; presumably modified in place by the per-chunk tasks
     *                  (the update itself is in {@code $$anonfun$3}, not visible here)
     * @return a pair of the same {@code tensor} instance and a one-element array holding the loss
     *         computed by {@code function1} before the update
     * @throws MatchError if {@code function1} returns {@code null}
     */
    @Override // com.intel.analytics.bigdl.dllib.optim.ParallelAdam, com.intel.analytics.bigdl.dllib.optim.OptimMethod
    public Tuple2<Tensor<Object>, float[]> optimize$mcF$sp(Function1<Tensor<Object>, Tuple2<Object, Tensor<Object>>> function1, Tensor<Object> tensor) {
        // Snapshot the hyper-parameters once so every parallel task sees the same values.
        double learningRate = learningRate();
        double learningRateDecay = learningRateDecay();
        double beta1 = beta1();
        double beta2 = beta2();
        double Epsilon = Epsilon();
        // Evaluate loss and gradient at the current parameters (Scala tuple pattern match;
        // a null result surfaces as MatchError, mirroring the compiled Scala code).
        Tuple2 tuple2 = (Tuple2) function1.apply(tensor);
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        Tuple2 tuple22 = new Tuple2(BoxesRunTime.boxToFloat(BoxesRunTime.unboxToFloat(tuple2._1())), (Tensor) tuple2._2());
        float unboxToFloat = BoxesRunTime.unboxToFloat(tuple22._1());
        // Gradient tensor returned by function1.
        Tensor tensor2 = (Tensor) tuple22._2();
        // Step counter; IntRef so the closures below can capture it by reference.
        IntRef create = IntRef.create(BoxesRunTime.unboxToInt(state().getOrElse("evalCounter", BoxesRunTime.boxToInteger(0))));
        // Decayed learning rate: lr / (1 + evalCounter * lrDecay), computed before the increment.
        double d = learningRate / (1 + (create.elem * learningRateDecay));
        create.elem++;
        int nElement = tensor.nElement();
        // NOTE: this local `parallelNum` shadows the parallelNum() accessor and holds the
        // per-task chunk size; `parallelNum2` is the remainder left over after even splitting.
        // NOTE(review): both divisions assume parallelNum() > 0.
        int parallelNum = nElement / parallelNum();
        int parallelNum2 = nElement % parallelNum();
        // Lazily (re)allocate the shared `ones` buffer so it holds at least chunkSize + 1
        // elements, all set to one; it is only grown, never shrunk.
        if (com$intel$analytics$bigdl$dllib$optim$ParallelAdam$$ones() == null || com$intel$analytics$bigdl$dllib$optim$ParallelAdam$$ones().nElement() < parallelNum + 1) {
            com$intel$analytics$bigdl$dllib$optim$ParallelAdam$$ones_$eq(Tensor$.MODULE$.apply$mFc$sp(this.com$intel$analytics$bigdl$dllib$optim$ParallelAdam$$evidence$1, this.ev$mcF$sp).resize(parallelNum + 1).fill(BoxesRunTime.boxToFloat(this.ev$mcF$sp.one$mcF$sp())));
        }
        // Runs a closure once per task index 0 until parallelNum(). Its body is not visible here;
        // presumably it prepares per-task state before the update — confirm in ParallelAdam.scala.
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), parallelNum()).foreach(new ParallelAdam$mcF$sp$$anonfun$optimize$mcF$sp$1(this));
        // Build one task per index (capturing parameters, hyper-parameters, gradient, counter,
        // decayed lr, chunk size and remainder) and block until all of them finish on the
        // engine's default thread pool.
        ThreadPool m2082default = Engine$.MODULE$.m2082default();
        m2082default.invokeAndWait((IndexedSeq) RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), parallelNum()).map(new ParallelAdam$mcF$sp$$anonfun$3(this, tensor, beta1, beta2, Epsilon, tensor2, create, d, parallelNum, parallelNum2), IndexedSeq$.MODULE$.canBuildFrom()), m2082default.invokeAndWait$default$2());
        // Persist the incremented step counter for the next call.
        state().update("evalCounter", BoxesRunTime.boxToInteger(create.elem));
        return new Tuple2<>(tensor, Array$.MODULE$.apply(Predef$.MODULE$.genericWrapArray(new float[]{unboxToFloat}), this.com$intel$analytics$bigdl$dllib$optim$ParallelAdam$$evidence$1));
    }

    /**
     * Creates the Float-specialized optimizer. The first six arguments are forwarded unchanged to
     * the {@code ParallelAdam} constructor. NOTE(review): presumably they are learningRate,
     * learningRateDecay, beta1, beta2, Epsilon and parallelNum in that order — confirm against
     * ParallelAdam.scala.
     *
     * @param classTag      ClassTag for the element type, also stored in {@link #evidence$1}
     * @param tensorNumeric numeric type-class instance, also stored in {@link #ev$mcF$sp}
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public ParallelAdam$mcF$sp(double d, double d2, double d3, double d4, double d5, int i, ClassTag<Object> classTag, TensorNumericMath.TensorNumeric<Object> tensorNumeric) {
        super(d, d2, d3, d4, d5, i, classTag, tensorNumeric);
        this.ev$mcF$sp = tensorNumeric;
        this.evidence$1 = classTag;
        // Runs the specialized OptimMethod trait initializer.
        OptimMethod$mcF$sp.Cclass.$init$(this);
    }
}
