package com.intel.analytics.bigdl.orca.net;

import com.intel.analytics.bigdl.dllib.keras.models.InternalOptimizerUtil$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.orca.net.TorchOptim;
import java.util.Locale;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag;
import scala.runtime.BoxesRunTime;

/* compiled from: TorchOptim.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/orca/net/TorchOptim$.class */
/**
 * Java form of the Scala companion object {@code TorchOptim}.
 *
 * <p>Provides factory methods for {@link TorchOptim} and helpers that read entries from an
 * optimizer's state table. The single shared instance is {@link #MODULE$}. {@link #readResolve()}
 * returns it so that Java deserialization does not produce a second instance.
 */
public final class TorchOptim$ implements Serializable {
    /** The singleton instance, created once during class initialization. */
    public static final TorchOptim$ MODULE$ = new TorchOptim$();

    /**
     * Resolves a decay-type name to its {@link TorchOptim.DecayType} singleton.
     *
     * <p>Matching is case-insensitive and locale-independent.
     *
     * @param str decay-type name: {@code "EpochDecay"}, {@code "IterationDecay"} or
     *     {@code "EpochDecayByScore"}, in any letter case
     * @return the matching decay type
     * @throws IllegalArgumentException if {@code str} names none of the supported decay types
     * @throws NullPointerException if {@code str} is {@code null}
     */
    public TorchOptim.DecayType getDecayType(String str) {
        // Locale.ROOT: with a Turkish default locale, toLowerCase() maps 'I' to dotless 'ı',
        // so "IterationDecay" would never match "iterationdecay".
        switch (str.toLowerCase(Locale.ROOT)) {
            case "epochdecay":
                return TorchOptim$EpochDecay$.MODULE$;
            case "iterationdecay":
                return TorchOptim$IterationDecay$.MODULE$;
            case "epochdecaybyscore":
                return TorchOptim$EpochDecayByScore$.MODULE$;
            default:
                throw new IllegalArgumentException("unknown decay type: " + str
                        + ", expected: EpochDecay, IterationDecay, EpochDecayByScore");
        }
    }

    /**
     * Creates a {@link TorchOptim} whose decay type is given by name.
     *
     * @param bArr bytes handed unchanged to the {@link TorchOptim} constructor (presumably a
     *     serialized optimizer; TODO confirm against the constructor)
     * @param str decay-type name, resolved by {@link #getDecayType(String)}
     * @param classTag runtime class tag for the element type {@code T}
     * @param tensorNumeric numeric operations for {@code T}
     * @return a new optimizer wrapper
     * @throws IllegalArgumentException if {@code str} is not a supported decay type
     */
    public <T> TorchOptim<T> apply(byte[] bArr, String str, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        return apply(bArr, getDecayType(str), classTag, tensorNumeric);
    }

    /**
     * Creates a {@link TorchOptim} with an already-resolved decay type.
     *
     * @param bArr bytes handed unchanged to the {@link TorchOptim} constructor
     * @param decayType decay type to use
     * @param classTag runtime class tag for the element type {@code T}
     * @param tensorNumeric numeric operations for {@code T}
     * @return a new optimizer wrapper
     */
    public <T> TorchOptim<T> apply(byte[] bArr, TorchOptim.DecayType decayType, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        return new TorchOptim<>(bArr, decayType, classTag, tensorNumeric);
    }

    /**
     * Reads the {@code "epoch"} entry from the optimizer's state and returns it minus one.
     *
     * <p>NOTE(review): the decrement suggests the stored counter is 1-based and callers want a
     * 0-based epoch; confirm against the optimizer implementation.
     *
     * @param torchOptim optimizer whose state is read
     * @param classTag class tag for {@code T}; not used by this method
     * @return the stored epoch value minus one
     */
    public <T> int getEpoch(TorchOptim<T> torchOptim, ClassTag<T> classTag) {
        return BoxesRunTime.unboxToInt(InternalOptimizerUtil$.MODULE$.getStateFromOptiMethod(torchOptim).apply("epoch")) - 1;
    }

    /**
     * Reads the {@code "score"} entry from the optimizer's state as a {@code float}.
     *
     * @param torchOptim optimizer whose state is read
     * @param classTag class tag for {@code T}; not used by this method
     * @return the stored score
     */
    public <T> float getScore(TorchOptim<T> torchOptim, ClassTag<T> classTag) {
        return BoxesRunTime.unboxToFloat(InternalOptimizerUtil$.MODULE$.getStateFromOptiMethod(torchOptim).apply("score"));
    }

    /** Keeps the singleton property across Java deserialization. */
    private Object readResolve() {
        return MODULE$;
    }

    /** Only {@link #MODULE$} is ever constructed. */
    private TorchOptim$() {
    }
}
