package com.intel.analytics.bigdl.orca.net;

import com.intel.analytics.bigdl.dllib.optim.OptimMethod;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.dllib.utils.Table;
import com.intel.analytics.bigdl.orca.net.TorchOptim;
import com.intel.analytics.bigdl.orca.utils.PythonInterpreter$;
import jep.NDArray;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.StringContext;
import scala.Tuple2;
import scala.reflect.ClassTag;
import scala.runtime.BoxesRunTime;

/* compiled from: TorchOptim.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/orca/net/TorchOptim$mcD$sp.class */
/**
 * Double-specialized variant of {@code TorchOptim}, as emitted by the Scala compiler for
 * {@code TorchOptim.scala} and recovered by a decompiler (see the JADX markers in this class).
 * The {@code $mcD$sp} suffix is Scala's naming for a class/method specialized on {@code Double};
 * element types appear as erased {@code Object} in these signatures.
 *
 * <p>The parameter update itself is not computed in Java: weights and gradients are handed to
 * {@code PythonInterpreter} as {@code jep.NDArray}s, {@code initCode()} / {@code stepCode()} are
 * executed there, and the updated weights are read back into the BigDL tensor.
 *
 * <p>NOTE(review): this is generated code (constructs such as {@code OptimMethod.mcD.sp.class}
 * are not valid hand-written Java). Change behavior in {@code TorchOptim.scala} and regenerate
 * rather than editing this file.
 */
public class TorchOptim$mcD$sp extends TorchOptim<Object> implements OptimMethod.mcD.sp {
    /** Numeric evidence for the specialized element type; used to convert the weights read back from Python. */
    public final TensorNumericMath.TensorNumeric<Object> ev$mcD$sp;
    /** Constructor argument {@code bArr}; assigned but not read in this class (presumably the serialized torch optimizer — confirm). */
    private final byte[] torchOptim;
    /** Constructor argument; assigned but not read in this class. */
    private final TorchOptim.DecayType decayType;
    /** Constructor {@code ClassTag}; assigned but not read here (the inherited {@code ...$$evidence$1} accessor is used instead). */
    private final ClassTag<Object> evidence$1;

    /**
     * Delegates to the specialized {@code OptimMethod} trait implementation of {@code clone}.
     * The decompiler renamed this method from {@code clone} to {@code mo120clone}.
     */
    @Override // com.intel.analytics.bigdl.orca.net.TorchOptim
    /* renamed from: clone */
    public OptimMethod<Object> mo120clone() {
        return OptimMethod.mcD.sp.class.clone(this);
    }

    /** Specialized entry point for {@code clone}; delegates to the {@code OptimMethod} trait implementation. */
    @Override // com.intel.analytics.bigdl.orca.net.TorchOptim
    public OptimMethod<Object> clone$mcD$sp() {
        return OptimMethod.mcD.sp.class.clone$mcD$sp(this);
    }

    /**
     * Four-argument overload (closure, parameter tensor, two {@code Table}s); delegates to the
     * {@code OptimMethod} trait implementation rather than to the Python-backed path below.
     */
    @Override // com.intel.analytics.bigdl.orca.net.TorchOptim
    public Tuple2<Tensor<Object>, Object> optimize(Function1<Tensor<Object>, Tuple2<Object, Tensor<Object>>> function1, Tensor<Object> tensor, Table table, Table table2) {
        return OptimMethod.mcD.sp.class.optimize(this, function1, tensor, table, table2);
    }

    /** Specialized form of the four-argument overload; delegates to the {@code OptimMethod} trait implementation. */
    @Override // com.intel.analytics.bigdl.orca.net.TorchOptim
    public Tuple2<Tensor<Object>, double[]> optimize$mcD$sp(Function1<Tensor<Object>, Tuple2<Object, Tensor<Object>>> function1, Tensor<Object> tensor, Table table, Table table2) {
        return OptimMethod.mcD.sp.class.optimize$mcD$sp(this, function1, tensor, table, table2);
    }

    /** Generic (boxed) two-argument entry point; forwards to the Double-specialized implementation. */
    @Override // com.intel.analytics.bigdl.orca.net.TorchOptim
    public Tuple2<Tensor<Object>, Object> optimize(Function1<Tensor<Object>, Tuple2<Object, Tensor<Object>>> function1, Tensor<Object> tensor) {
        return optimize$mcD$sp(function1, tensor);
    }

    /**
     * Performs one optimization step by running the optimizer code in the embedded Python interpreter.
     *
     * <p>On the first call (while {@code init()} is false) the current weights are pushed to Python
     * under {@code weightName()}, {@code initCode()} is executed, {@code lastEpoch} is recorded and
     * {@code init} is set. On later calls {@code updateHyperParameter()} runs instead. Every call then
     * pushes the gradient under {@code gradientName()}, executes {@code stepCode()}, and copies
     * {@code <weightName>.data.numpy()} back into {@code tensor} in place.
     *
     * @param function1 closure evaluated on {@code tensor}; returns {@code (loss, gradient)}
     * @param tensor    the parameter (weight) tensor; overwritten in place with the updated weights
     * @return {@code (tensor, [loss])}: the same tensor instance and a one-element array holding the loss
     * @throws MatchError if {@code function1} returns {@code null}
     */
    @Override // com.intel.analytics.bigdl.orca.net.TorchOptim
    public Tuple2<Tensor<Object>, double[]> optimize$mcD$sp(Function1<Tensor<Object>, Tuple2<Object, Tensor<Object>>> function1, Tensor<Object> tensor) {
        // Result is discarded; presumably called for its side effect (e.g. forcing a lazy val) — TODO confirm in TorchOptim.scala.
        optimType();
        int epoch = TorchOptim$.MODULE$.getEpoch(this, this.com$intel$analytics$bigdl$orca$net$TorchOptim$$evidence$1);
        // Evaluate the closure and destructure (loss, gradient); this is the decompiled form of a Scala tuple pattern match.
        Tuple2 tuple2 = (Tuple2) function1.apply(tensor);
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        Tuple2 tuple22 = new Tuple2(BoxesRunTime.boxToDouble(tuple2._1$mcD$sp()), (Tensor) tuple2._2());
        double _1$mcD$sp = tuple22._1$mcD$sp();
        Tensor tensor2 = (Tensor) tuple22._2();
        if (init()) {
            updateHyperParameter();
        } else {
            // First call: record the epoch, seed Python with the current weights, and run the optimizer's init code once.
            lastEpoch_$eq(epoch);
            // NOTE(review): storage().array() hands over the whole backing array, ignoring storage offset/strides;
            // assumes a contiguous tensor with offset 0. Also confirm the element type Python receives for a Double tensor.
            PythonInterpreter$.MODULE$.set(weightName(), new NDArray(tensor.toTensor(TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$).storage().array()));
            PythonInterpreter$.MODULE$.exec(initCode());
            init_$eq(true);
        }
        // Same storage().array() caveat as above applies to the gradient.
        PythonInterpreter$.MODULE$.set(gradientName(), new NDArray(tensor2.toTensor(TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$).storage().array()));
        PythonInterpreter$.MODULE$.exec(stepCode());
        // Read "<weightName>.data.numpy()" back from Python and copy it into the caller's tensor in place.
        tensor.copy(PythonFeatureSet$.MODULE$.ndArrayToTensor((NDArray) PythonInterpreter$.MODULE$.getValue(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"", ".data.numpy()"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{weightName()})))).toTensor(this.ev$mcD$sp));
        // Decompiled form of Scala `(weight, Array(loss))`.
        return new Tuple2<>(tensor, Array$.MODULE$.apply(Predef$.MODULE$.genericWrapArray(new double[]{_1$mcD$sp}), this.com$intel$analytics$bigdl$orca$net$TorchOptim$$evidence$1));
    }

    /**
     * Creates the Double-specialized optimizer.
     *
     * @param bArr         stored as {@code torchOptim} and passed to the superclass
     * @param decayType    stored and passed to the superclass
     * @param classTag     stored as {@code evidence$1} and passed to the superclass
     * @param tensorNumeric numeric evidence, stored as {@code ev$mcD$sp} and passed to the superclass
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public TorchOptim$mcD$sp(byte[] bArr, TorchOptim.DecayType decayType, ClassTag<Object> classTag, TensorNumericMath.TensorNumeric<Object> tensorNumeric) {
        super(bArr, decayType, classTag, tensorNumeric);
        this.ev$mcD$sp = tensorNumeric;
        this.torchOptim = bArr;
        this.decayType = decayType;
        this.evidence$1 = classTag;
        // Runs the specialized OptimMethod trait initializer (Scala trait $init$).
        OptimMethod.mcD.sp.class.$init$(this);
    }
}
