package hivemall.optimizer;

import hivemall.model.IWeightValue;
import hivemall.model.WeightValue;
import hivemall.utils.lang.Primitives;
import hivemall.utils.math.MathUtils;
import java.util.HashMap;
import java.util.Map;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.concurrent.NotThreadSafe;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.lucene.analysis.wikipedia.WikipediaTokenizer;
import org.apache.lucene.util.packed.PackedInts;

/* loaded from: input_file:hivemall/optimizer/Optimizer.class */
public interface Optimizer {

    /**
     * AdaDelta optimizer.
     *
     * <p>Each weight carries two extra slots: a decayed average of squared gradients (stored
     * divided by {@code scale}) and a decayed average of squared updates. Recognized options:
     * {@code decay} (default 0.95), {@code eps} (default 1e-6) and {@code scale} (default 100).
     * Unless the caller sets them, the learning-rate schedule defaults to {@code eta=fixed},
     * {@code eta0=1.0}.
     */
    public static abstract class AdaDelta extends OptimizerBase {
        private final float decay;
        private final float eps;
        private final float scale;

        public AdaDelta(@Nonnull Map<String, String> options) {
            super(options);
            this.decay = Primitives.parseFloat(options.get("decay"), 0.95f);
            this.eps = Primitives.parseFloat(options.get("eps"), 1.0E-6f);
            this.scale = Primitives.parseFloat(options.get("scale"), 100.0f);
        }

        /**
         * Creates a weight whose two AdaDelta accumulators start at zero. Declared public
         * (widening the base class's protected declaration).
         */
        @Override
        public WeightValue.WeightValueParamsF2 newWeightValue(float initialWeight) {
            return new WeightValue.WeightValueParamsF2(initialWeight, 0.0f, 0.0f);
        }

        /**
         * Fills in {@code eta=fixed} and {@code eta0=1.0} when absent (mutating {@code options}),
         * then delegates. Invoked from the base constructor, so it must not read instance fields.
         */
        @Override
        protected final EtaEstimator getEtaEstimator(@Nonnull Map<String, String> options) {
            final boolean hasScheme = options.containsKey("eta");
            if (!hasScheme) {
                options.put("eta", "fixed");
            }
            final boolean hasInitialEta = options.containsKey("eta0");
            if (!hasInitialEta) {
                options.put("eta0", "1.0");
            }
            return super.getEtaEstimator(options);
        }

        /**
         * Computes the AdaDelta step for one weight and stores both updated accumulators on it.
         *
         * @param weight weight whose accumulators are read and overwritten
         * @param gradient gradient for this weight
         * @return {@code sqrt((E[dx^2] + eps) / (E[g^2] + eps)) * gradient}
         */
        @Override
        protected float computeDelta(@Nonnull IWeightValue weight, float gradient) {
            final float oldScaledSqGrad = weight.getSumOfSquaredGradients();
            final float oldSqDeltaX = weight.getSumOfSquaredDeltaX();
            // decayed average of squared gradients; kept divided by `scale`
            final float newScaledSqGrad =
                    (decay * oldScaledSqGrad) + ((1.0f - decay) * gradient * (gradient / scale));
            final float ratio = (oldSqDeltaX + eps) / ((newScaledSqGrad * scale) + eps);
            final float delta = ((float) Math.sqrt(ratio)) * gradient;
            final float newSqDeltaX = (decay * oldSqDeltaX) + ((1.0f - decay) * delta * delta);
            weight.setSumOfSquaredGradients(newScaledSqGrad);
            weight.setSumOfSquaredDeltaX(newSqDeltaX);
            return delta;
        }

        @Override
        public String getOptimizerName() {
            return "adadelta";
        }

        /** Base hyper-parameters plus {@code decay}, {@code eps} and {@code scale}. */
        @Override
        public Map<String, Object> getHyperParameters() {
            final Map<String, Object> params = super.getHyperParameters();
            params.put("decay", decay);
            params.put("eps", eps);
            params.put("scale", scale);
            return params;
        }
    }

    /**
     * AdaGrad optimizer.
     *
     * <p>Each weight keeps a running sum of squared gradients, stored divided by {@code scale}.
     * Recognized options: {@code eps} (default 1.0) and {@code scale} (default 100).
     */
    public static abstract class AdaGrad extends OptimizerBase {
        private final float eps;
        private final float scale;

        public AdaGrad(@Nonnull Map<String, String> options) {
            super(options);
            this.eps = Primitives.parseFloat(options.get("eps"), 1.0f);
            this.scale = Primitives.parseFloat(options.get("scale"), 100.0f);
        }

        /** Creates a weight whose squared-gradient accumulator starts at zero. */
        @Override
        public WeightValue.WeightValueParamsF1 newWeightValue(float initialWeight) {
            return new WeightValue.WeightValueParamsF1(initialWeight, 0.0f);
        }

        /**
         * Adds {@code gradient^2 / scale} to the weight's accumulator and returns the scaled
         * gradient {@code gradient / sqrt(eps + G * scale)}.
         *
         * <p>NOTE(review): {@code G} here is the accumulator value from before this call's
         * addition (the same convention as RMSprop below); confirm that is intended.
         */
        @Override
        protected float computeDelta(@Nonnull IWeightValue weight, float gradient) {
            final float previous = weight.getSumOfSquaredGradients();
            weight.setSumOfSquaredGradients(previous + gradient * (gradient / scale));
            final double denominator = Math.sqrt(eps + (previous * scale));
            return (float) (gradient / denominator);
        }

        @Override
        public String getOptimizerName() {
            return "adagrad";
        }

        /** Base hyper-parameters plus {@code eps} and {@code scale}. */
        @Override
        public Map<String, Object> getHyperParameters() {
            final Map<String, Object> params = super.getHyperParameters();
            params.put("eps", eps);
            params.put("scale", scale);
            return params;
        }
    }

    /**
     * AdaGrad with regularized dual averaging (RDA).
     *
     * <p>Delegates the per-coordinate scaling to a wrapped {@link AdaGrad}. Each weight also keeps
     * the running sum of its gradients. Recognized option: {@code lambda} (default 1e-6).
     */
    public static abstract class AdagradRDA extends OptimizerBase {

        @Nonnull
        private final AdaGrad optimizerImpl;
        private final float lambda;

        public AdagradRDA(@Nonnull AdaGrad adaGrad, @Nonnull Map<String, String> options) {
            super(options);
            this.optimizerImpl = adaGrad;
            this.lambda = Primitives.parseFloat(options.get("lambda"), 1.0E-6f);
        }

        /** Creates a weight whose two accumulators start at zero. */
        @Override
        public WeightValue.WeightValueParamsF2 newWeightValue(float initialWeight) {
            return new WeightValue.WeightValueParamsF2(initialWeight, 0.0f, 0.0f);
        }

        /**
         * Applies the RDA update directly, bypassing the base class's eta/regularization path.
         *
         * <p>When {@code |sum of gradients| / t - lambda} is negative, the weight and both of its
         * accumulators are reset to zero. Otherwise the weight becomes
         * {@code -sign * eta(t) * t * adagradDelta(|u|/t - lambda)}.
         *
         * @return the new weight value
         */
        @Override
        public float update(@Nonnull IWeightValue weight, float gradient) {
            final float newSumOfGradients = weight.getSumOfGradients() + gradient;
            final float sign = (newSumOfGradients > 0.0f) ? 1.0f : -1.0f;
            final float shrunkMean =
                    ((sign * newSumOfGradients) / ((float) _numStep)) - lambda;
            final float result;
            if (shrunkMean < 0.0f) {
                // truncated to zero: clear the weight and its history
                weight.set(0.0f);
                weight.setSumOfSquaredGradients(0.0f);
                weight.setSumOfGradients(0.0f);
                result = 0.0f;
            } else {
                result = (-1.0f) * sign * _eta.eta(_numStep) * ((float) _numStep)
                        * optimizerImpl.computeDelta(weight, shrunkMean);
                weight.set(result);
                weight.setSumOfGradients(newSumOfGradients);
            }
            return result;
        }

        @Override
        public String getOptimizerName() {
            return "adagrad_rda";
        }

        /** The wrapped AdaGrad's hyper-parameters, with this optimizer's name and {@code lambda}. */
        @Override
        public Map<String, Object> getHyperParameters() {
            final Map<String, Object> params = optimizerImpl.getHyperParameters();
            params.put("optimizer", getOptimizerName());
            params.put("lambda", lambda);
            return params;
        }
    }

    /**
     * Adam optimizer, with an optional AMSGrad variant (enabled by the presence of the
     * {@code amsgrad} option).
     *
     * <p>Each weight stores its first (M) and second (V) moment estimates. Recognized options:
     * {@code alpha} (default 1.0), {@code beta1} (0.9), {@code beta2} (0.999), {@code eps} (1e-8)
     * and {@code decay} (0, i.e. disabled).
     */
    public static abstract class Adam extends OptimizerBase {
        protected float alpha;
        protected final float beta1;
        protected final float beta2;
        protected final float eps;
        protected final float decay;
        protected final boolean amsgrad;
        // NOTE(review): a single maximum shared by every weight of this optimizer; the AMSGrad
        // formulation tracks the maximum per parameter — confirm the shared scalar is intended.
        protected float max_vhat = Float.MIN_VALUE;

        public Adam(@Nonnull Map<String, String> options) {
            super(options);
            this.alpha = Primitives.parseFloat(options.get("alpha"), 1.0f);
            this.beta1 = Primitives.parseFloat(options.get("beta1"), 0.9f);
            this.beta2 = Primitives.parseFloat(options.get("beta2"), 0.999f);
            this.eps = Primitives.parseFloat(options.get("eps"), 1.0E-8f);
            this.decay = Primitives.parseFloat(options.get("decay"), 0.0f);
            this.amsgrad = options.containsKey("amsgrad");
        }

        /** Creates a weight whose M and V moments start at zero. */
        @Override
        public WeightValue.WeightValueParamsF2 newWeightValue(float initialWeight) {
            return new WeightValue.WeightValueParamsF2(initialWeight, 0.0f, 0.0f);
        }

        /**
         * Learning rate at step {@code t}, scaled by {@code sqrt(1 - beta2^t) / (1 - beta1^t)}.
         *
         * <p>NOTE(review): {@link #alpha()} applies the same factor, and the base update multiplies
         * eta by (a regularized form of) the delta built from alpha(). The correction may therefore
         * be applied twice — confirm.
         */
        @Override
        protected float eta(long t) {
            final double biasFix1 = 1.0d - Math.pow(beta1, t);
            final double biasFix2 = 1.0d - Math.pow(beta2, t);
            return (float) (_eta.eta(t) * (Math.sqrt(biasFix2) / biasFix1));
        }

        /** Step size {@code alpha * sqrt(1 - beta2^t) / (1 - beta1^t)} at the current step. */
        protected double alpha() {
            final long t = _numStep;
            final double correction =
                    Math.sqrt(1.0d - Math.pow(beta2, t)) / (1.0d - Math.pow(beta1, t));
            return alpha * correction;
        }

        /**
         * Updates the weight's moment estimates and returns the Adam step.
         *
         * <p>When {@code decay} is non-zero, {@code decay * w} is added both to the gradient
         * before the moment updates and to the resulting step.
         */
        @Override
        protected float computeDelta(@Nonnull IWeightValue weight, float gradient) {
            final boolean weightDecay = decay != 0.0f;
            float g = gradient;
            if (weightDecay) {
                g += decay * weight.get();
            }
            final float m = (beta1 * weight.getM()) + ((1.0f - beta1) * g);
            final float v =
                    (beta2 * weight.getV()) + ((float) ((1.0f - beta2) * MathUtils.square(g)));
            float vHat = v;
            if (amsgrad) {
                if (v > max_vhat) {
                    max_vhat = v;
                }
                vHat = max_vhat;
            }
            final double deltaU = m / (Math.sqrt(vHat) + eps);
            float delta = (float) (alpha() * deltaU);
            if (weightDecay) {
                delta += decay * weight.get();
            }
            // the raw v (not vHat) is persisted
            weight.setM(m);
            weight.setV(v);
            return delta;
        }

        @Override
        public String getOptimizerName() {
            if (amsgrad) {
                return "adam-amsgrad";
            }
            return "adam";
        }

        /** Base hyper-parameters plus alpha, beta1, beta2, eps and decay. */
        @Override
        public Map<String, Object> getHyperParameters() {
            final Map<String, Object> params = super.getHyperParameters();
            params.put("alpha", alpha);
            params.put("beta1", beta1);
            params.put("beta2", beta2);
            params.put("eps", eps);
            params.put("decay", decay);
            return params;
        }
    }

    /**
     * Adam with multiplicative hypergradient descent on the step size {@code alpha}.
     *
     * <p>Recognized options (in addition to {@link Adam}'s): {@code alpha} (default here 0.02) and
     * {@code beta}, the hypergradient rate (default 1e-6). Unless the caller sets them, the
     * learning-rate schedule defaults to {@code eta=fixed}, {@code eta0=1.0}.
     */
    public static abstract class AdamHD extends Adam {
        private final float beta;
        // previous bias-corrected update direction, shared across all weights of this optimizer
        protected double deltaU = 0.0d;

        public AdamHD(@Nonnull Map<String, String> options) {
            super(options);
            this.alpha = Primitives.parseFloat(options.get("alpha"), 0.02f);
            this.beta = Primitives.parseFloat(options.get("beta"), 1.0E-6f);
        }

        /**
         * Fills in {@code eta=fixed} and {@code eta0=1.0} when absent (mutating {@code options}),
         * then delegates. Invoked from the base constructor, so it must not read instance fields.
         */
        @Override
        protected final EtaEstimator getEtaEstimator(@Nonnull Map<String, String> options) {
            final boolean hasScheme = options.containsKey("eta");
            if (!hasScheme) {
                options.put("eta", "fixed");
            }
            final boolean hasInitialEta = options.containsKey("eta0");
            if (!hasInitialEta) {
                options.put("eta0", "1.0");
            }
            return super.getEtaEstimator(options);
        }

        /**
         * Shrinks {@code alpha} by {@code (1 - beta)} when {@code gradient * prevDeltaU > 0},
         * grows it by {@code (1 + beta)} when negative, and returns the new value.
         */
        private float updateAlpha(float gradient, double prevDeltaU) {
            final double hypergradient = gradient * prevDeltaU;
            if (hypergradient > 0.0d) {
                alpha *= 1.0f - beta;
            } else if (hypergradient < 0.0d) {
                alpha *= 1.0f + beta;
            }
            return alpha;
        }

        /**
         * Updates the weight's moments, adapts {@code alpha}, and returns the step
         * {@code alpha_t * m_hat / (sqrt(v_hat) + eps)}. This uses its own bias correction rather
         * than {@link Adam#alpha()}.
         */
        @Override
        protected float computeDelta(@Nonnull IWeightValue weight, float gradient) {
            float g = gradient;
            if (decay != 0.0f) {
                g += decay * weight.get();
            }
            final float m = (beta1 * weight.getM()) + ((1.0f - beta1) * g);
            final float v =
                    (beta2 * weight.getV()) + ((float) ((1.0f - beta2) * MathUtils.square(g)));
            final long t = _numStep;
            final double mHat = m / (1.0d - Math.pow(beta1, t));
            final double vHat = v / (1.0d - Math.pow(beta2, t));
            final float alphaT = updateAlpha(g, deltaU);
            final double direction = mHat / (Math.sqrt(vHat) + eps);
            float delta = (float) (alphaT * direction);
            this.deltaU = direction;
            if (decay != 0.0f) {
                delta += decay * weight.get();
            }
            weight.setM(m);
            weight.setV(v);
            return delta;
        }

        @Override
        public String getOptimizerName() {
            return "adam_hd";
        }

        /** Adam's hyper-parameters plus {@code beta}. */
        @Override
        public Map<String, Object> getHyperParameters() {
            final Map<String, Object> params = super.getHyperParameters();
            params.put("beta", beta);
            return params;
        }
    }

    /**
     * Eve: Adam whose step size is divided by a smoothed, clipped relative change of the loss.
     *
     * <p>Recognized options (in addition to {@link Adam}'s): {@code beta3} (default 0.999) and
     * {@code c}, the clipping bound (default 10; the lower bound is {@code 1/c}).
     */
    public static abstract class Eve extends Adam {
        protected final float beta3;
        private float c = 10.0f;
        private float inv_c = 0.1f;
        private float currLoss;
        private float prevLoss = 0.0f;
        private double prevDt = 1.0d;

        public Eve(@Nonnull Map<String, String> options) {
            super(options);
            this.beta3 = Primitives.parseFloat(options.get("beta3"), 0.999f);
            this.c = Primitives.parseFloat(options.get("c"), 10.0f);
            this.inv_c = 1.0f / c;
        }

        /**
         * Adam's bias-corrected step size, divided by the feedback term
         * {@code d_t = beta3 * d_{t-1} + (1 - beta3) * clip(|L - L'| / min(L, L'), 1/c, c)}
         * once {@code t > 1} and the loss has changed. Updates the stored {@code d_{t-1}}.
         */
        @Override
        protected double alpha() {
            double alphaT = super.alpha();
            if (_numStep > 1 && currLoss != prevLoss) {
                final float relativeChange =
                        Math.abs(currLoss - prevLoss) / Math.min(currLoss, prevLoss);
                final double feedback = (beta3 * prevDt)
                        + ((1.0d - beta3) * MathUtils.clip(relativeChange, inv_c, c));
                prevDt = feedback;
                alphaT /= feedback;
            }
            return alphaT;
        }

        /**
         * Records {@code loss} as the current loss for {@link #alpha()}, performs the update, then
         * remembers it as the previous loss.
         *
         * <p>NOTE(review): because the previous loss is refreshed on every call, if callers pass
         * the same loss for every feature of one example, only the first call per example sees a
         * changed loss and gets the feedback term — confirm this is intended.
         */
        @Override
        public float update(Object feature, float weight, float loss, float gradient) {
            this.currLoss = loss;
            final float updated = update(feature, weight, gradient);
            this.prevLoss = loss;
            return updated;
        }

        @Override
        public String getOptimizerName() {
            return "eve";
        }

        /** Adam's hyper-parameters plus {@code beta3} and {@code c}. */
        @Override
        public Map<String, Object> getHyperParameters() {
            final Map<String, Object> params = super.getHyperParameters();
            params.put("beta3", beta3);
            params.put("c", c);
            return params;
        }
    }

    /**
     * Classical momentum, or Nesterov momentum when the {@code nesterov} option is present.
     *
     * <p>Each weight stores its velocity in the delta slot. Recognized options: {@code alpha}
     * (default 1.0) and {@code momentum} (default 0.9).
     */
    public static abstract class Momentum extends OptimizerBase {

        // NOTE(review): not read anywhere in this class
        @Nonnull
        private final WeightValue.WeightValueParamsF1 weightValueReused;
        private final boolean nesterov;
        private final float alpha;
        private final float momentum;

        public Momentum(@Nonnull Map<String, String> options) {
            super(options);
            this.weightValueReused = newWeightValue(0.0f);
            this.nesterov = options.containsKey("nesterov");
            this.alpha = Primitives.parseFloat(options.get("alpha"), 1.0f);
            this.momentum = Primitives.parseFloat(options.get("momentum"), 0.9f);
        }

        /** Creates a weight whose velocity slot starts at zero. */
        @Override
        public WeightValue.WeightValueParamsF1 newWeightValue(float initialWeight) {
            return new WeightValue.WeightValueParamsF1(initialWeight, 0.0f);
        }

        /**
         * Updates the stored velocity {@code v = momentum * v + alpha * gradient}. Returns
         * {@code v} for classical momentum, or the Nesterov look-ahead
         * {@code momentum^2 * v + (1 + momentum) * alpha * gradient}.
         */
        @Override
        protected float computeDelta(@Nonnull IWeightValue weight, float gradient) {
            final float velocity = (momentum * weight.getDelta()) + (alpha * gradient);
            weight.setDelta(velocity);
            if (!nesterov) {
                return velocity;
            }
            return (momentum * momentum * velocity) + ((1.0f + momentum) * alpha * gradient);
        }

        @Override
        public String getOptimizerName() {
            if (nesterov) {
                return "nesterov";
            }
            return "momentum";
        }

        /** Base hyper-parameters plus nesterov, alpha and momentum. */
        @Override
        public Map<String, Object> getHyperParameters() {
            final Map<String, Object> params = super.getHyperParameters();
            params.put("nesterov", nesterov);
            params.put("alpha", alpha);
            params.put("momentum", momentum);
            return params;
        }
    }

    /**
     * Nadam (Adam with Nesterov momentum and a momentum schedule).
     *
     * <p>Each weight stores its first (M) and second (V) moment estimates. Recognized options:
     * {@code alpha} (default 1.0), {@code beta1} (0.9), {@code beta2} (0.999), {@code eps} (1e-8),
     * {@code decay} (0, disabled) and {@code scheduleDecay} (0.004).
     */
    public static abstract class Nadam extends OptimizerBase {
        protected float alpha;
        protected final float beta1;
        protected final float beta2;
        protected final float eps;
        protected final float decay;
        protected final float scheduleDecay;
        protected double mu_t;
        protected double mu_t_1;
        protected double mu_product = 1.0d;
        protected double mu_product_next = 1.0d;

        public Nadam(@Nonnull Map<String, String> options) {
            super(options);
            this.alpha = Primitives.parseFloat(options.get("alpha"), 1.0f);
            this.beta1 = Primitives.parseFloat(options.get("beta1"), 0.9f);
            this.beta2 = Primitives.parseFloat(options.get("beta2"), 0.999f);
            this.eps = Primitives.parseFloat(options.get("eps"), 1.0E-8f);
            this.decay = Primitives.parseFloat(options.get("decay"), 0.0f);
            this.scheduleDecay = Primitives.parseFloat(options.get("scheduleDecay"), 0.004f);
        }

        /** Creates a weight whose M and V moments start at zero. */
        @Override
        public WeightValue.WeightValueParamsF2 newWeightValue(float initialWeight) {
            return new WeightValue.WeightValueParamsF2(initialWeight, 0.0f, 0.0f);
        }

        /** Momentum schedule {@code beta1 * (1 - 0.5 * 0.96^(floor(scaledStep) + 1))}. */
        private double scheduledMomentum(double scaledStep) {
            return beta1 * (1.0d - (0.5d * Math.pow(0.96d, Math.floor(scaledStep) + 1.0d)));
        }

        /**
         * Advances the step counter and recomputes the momentum schedule ({@code mu_t},
         * {@code mu_t_1}) and the running products of momentum values.
         */
        @Override
        public void proceedStep() {
            final long t = ++this._numStep;
            final double prevProduct = mu_product;
            // the first product is deliberately formed in float, the second in double
            final double muCurrent = scheduledMomentum(((float) t) * scheduleDecay);
            final double muNext = scheduledMomentum((t + 1.0d) * scheduleDecay);
            this.mu_t = muCurrent;
            this.mu_t_1 = muNext;
            this.mu_product = prevProduct * muCurrent;
            this.mu_product_next = prevProduct * muCurrent * muNext;
        }

        /** Learning rate at step {@code t}, scaled by {@code sqrt(1 - beta2^t) / (1 - beta1^t)}. */
        @Override
        protected float eta(long t) {
            final double biasFix1 = 1.0d - Math.pow(beta1, t);
            final double biasFix2 = 1.0d - Math.pow(beta2, t);
            return (float) (_eta.eta(t) * (Math.sqrt(biasFix2) / biasFix1));
        }

        /** Step size {@code alpha * sqrt(1 - beta2^t) / (1 - beta1^t)} at the current step. */
        protected double alpha() {
            final long t = _numStep;
            final double correction =
                    Math.sqrt(1.0d - Math.pow(beta2, t)) / (1.0d - Math.pow(beta1, t));
            return alpha * correction;
        }

        /**
         * Updates the weight's moments and returns the Nadam step
         * {@code alpha() * m_bar / (sqrt(v_hat) + eps)}. When {@code decay} is non-zero,
         * {@code decay * w} is added both to the gradient and to the resulting step.
         */
        @Override
        protected float computeDelta(@Nonnull IWeightValue weight, float gradient) {
            float g = gradient;
            if (decay != 0.0f) {
                g += decay * weight.get();
            }
            final float m = (beta1 * weight.getM()) + ((1.0f - beta1) * g);
            final double mHat = m / (1.0d - mu_product_next);
            final float v =
                    (beta2 * weight.getV()) + ((float) ((1.0d - beta2) * MathUtils.square(g)));
            final double gHat = g / (1.0d - mu_product);
            final double mBar = ((1.0d - mu_t) * gHat) + (mu_t_1 * mHat);
            final double vHat = v / (1.0d - Math.pow(beta2, _numStep));
            final double direction = mBar / (Math.sqrt(vHat) + eps);
            float delta = (float) (alpha() * direction);
            if (decay != 0.0f) {
                delta += decay * weight.get();
            }
            weight.setM(m);
            weight.setV(v);
            return delta;
        }

        @Override
        public String getOptimizerName() {
            return "nadam";
        }

        /** Base hyper-parameters plus alpha, beta1, beta2, eps, decay and scheduleDecay. */
        @Override
        public Map<String, Object> getHyperParameters() {
            final Map<String, Object> params = super.getHyperParameters();
            params.put("alpha", alpha);
            params.put("beta1", beta1);
            params.put("beta2", beta2);
            params.put("eps", eps);
            params.put("decay", decay);
            params.put("scheduleDecay", scheduleDecay);
            return params;
        }
    }

    /**
     * Common base for the optimizers in this file.
     *
     * <p>Holds the learning-rate schedule, the regularizer and the step counter. Subclasses
     * supply the raw update direction via {@link #computeDelta(IWeightValue, float)} and may
     * override the step size via {@link #eta(long)}. The generic
     * {@link #update(IWeightValue, float)} then sets {@code w = w - eta * regularize(w, delta)}.
     *
     * <p>Instances hold mutable state (the step counter) and are not thread-safe.
     */
    @NotThreadSafe
    public static abstract class OptimizerBase implements Optimizer {

        @Nonnull
        protected final EtaEstimator _eta;

        @Nonnull
        protected final Regularization _reg;

        @Nonnegative
        protected long _numStep = 0;

        /**
         * @param options optimizer options; overriding {@link #getEtaEstimator(Map)}
         *        implementations may add default entries to it
         */
        public OptimizerBase(@Nonnull Map<String, String> options) {
            // getEtaEstimator is overridable and runs before subclass fields are initialized
            this._eta = getEtaEstimator(options);
            this._reg = Regularization.get(options);
        }

        /** Creates the per-feature weight holder used by this optimizer. */
        @Nonnull
        protected abstract IWeightValue newWeightValue(float initialWeight);

        /** Builds the learning-rate schedule from {@code options}. */
        @Nonnull
        protected EtaEstimator getEtaEstimator(@Nonnull Map<String, String> options) {
            return EtaEstimator.get(options);
        }

        /** Advances the step counter by one. */
        @Override
        public void proceedStep() {
            this._numStep++;
        }

        /** Ignores {@code loss} and delegates to {@link #update(Object, float, float)}. */
        @Override
        public float update(@Nonnull Object feature, float weight, float loss, float gradient) {
            return update(feature, weight, gradient);
        }

        /**
         * Computes the new value for {@code weight} given {@code gradient}.
         *
         * @return the new weight value
         */
        protected abstract float update(@Nonnull Object feature, float weight, float gradient);

        /**
         * Applies one step {@code w = w - eta(t) * regularize(w, computeDelta(w, gradient))},
         * stores the result in {@code weight}, and returns it. Declared public (the decompiled
         * signature); subclasses such as AdagradRDA override it publicly.
         */
        public float update(@Nonnull IWeightValue weight, float gradient) {
            final float oldWeight = weight.get();
            // eta() is evaluated before computeDelta(), which mutates weight/optimizer state
            final float eta = eta(_numStep);
            final float delta = computeDelta(weight, gradient);
            final float newWeight = oldWeight - (eta * _reg.regularize(oldWeight, delta));
            weight.set(newWeight);
            return newWeight;
        }

        /**
         * Returns the learning rate for step {@code t}. Uses the argument rather than the current
         * {@code _numStep}, so callers can query any step; all call sites in this file pass
         * {@code _numStep}.
         */
        protected float eta(long t) {
            return _eta.eta(t);
        }

        /** Returns the update direction for {@code weight}; plain SGD uses the gradient as-is. */
        protected float computeDelta(@Nonnull IWeightValue weight, float gradient) {
            return gradient;
        }

        /**
         * Returns a fresh mutable map holding the optimizer name, plus the eta and regularization
         * parameters.
         */
        @Override
        public Map<String, Object> getHyperParameters() {
            final Map<String, Object> params = new HashMap<String, Object>();
            params.put("optimizer", getOptimizerName());
            _eta.getHyperParameters(params);
            _reg.getHyperParameters(params);
            return params;
        }
    }

    /**
     * RMSprop optimizer.
     *
     * <p>Each weight keeps a decayed average of squared gradients, stored divided by
     * {@code scale}. Recognized options: {@code decay} (default 0.95), {@code eps} (default 1.0)
     * and {@code scale} (default 100).
     */
    public static abstract class RMSprop extends OptimizerBase {
        private final float decay;
        private final float eps;
        private final float scale;

        public RMSprop(@Nonnull Map<String, String> options) {
            super(options);
            this.decay = Primitives.parseFloat(options.get("decay"), 0.95f);
            this.eps = Primitives.parseFloat(options.get("eps"), 1.0f);
            this.scale = Primitives.parseFloat(options.get("scale"), 100.0f);
        }

        /** Creates a weight whose squared-gradient average starts at zero. */
        @Override
        public WeightValue.WeightValueParamsF1 newWeightValue(float initialWeight) {
            return new WeightValue.WeightValueParamsF1(initialWeight, 0.0f);
        }

        /**
         * Decays the weight's squared-gradient average, folds in {@code gradient^2 / scale}, and
         * returns {@code gradient / sqrt(eps + G * scale)}.
         *
         * <p>NOTE(review): {@code G} is the average from before this call's update; confirm that
         * is intended.
         */
        @Override
        protected float computeDelta(@Nonnull IWeightValue weight, float gradient) {
            final float previous = weight.getSumOfSquaredGradients();
            final float decayed = decay * previous;
            final float fresh = (1.0f - decay) * gradient * (gradient / scale);
            weight.setSumOfSquaredGradients(decayed + fresh);
            final double denominator = Math.sqrt(eps + (previous * scale));
            return (float) (gradient / denominator);
        }

        @Override
        public String getOptimizerName() {
            return "rmsprop";
        }

        /** Base hyper-parameters plus decay, eps and scale. */
        @Override
        public Map<String, Object> getHyperParameters() {
            final Map<String, Object> params = super.getHyperParameters();
            params.put("decay", decay);
            params.put("eps", eps);
            params.put("scale", scale);
            return params;
        }
    }

    /**
     * RMSprop variant of Graves (2013): normalizes by a running estimate of the gradient's
     * variance and adds momentum.
     *
     * <p>Each weight keeps a decayed average of squared gradients and a decayed average of
     * gradients (both stored divided by {@code scale}), plus its momentum delta. Recognized
     * options: {@code decay} (0.95), {@code alpha} (1.0), {@code momentum} (0.9), {@code eps}
     * (1.0) and {@code scale} (100).
     */
    public static abstract class RMSpropGraves extends OptimizerBase {
        private final float decay;
        private final float alpha;
        private final float momentum;
        private final float eps;
        private final float scale;

        public RMSpropGraves(@Nonnull Map<String, String> options) {
            super(options);
            this.decay = Primitives.parseFloat(options.get("decay"), 0.95f);
            this.alpha = Primitives.parseFloat(options.get("alpha"), 1.0f);
            this.momentum = Primitives.parseFloat(options.get("momentum"), 0.9f);
            this.eps = Primitives.parseFloat(options.get("eps"), 1.0f);
            this.scale = Primitives.parseFloat(options.get("scale"), 100.0f);
        }

        /** Creates a weight whose three auxiliary slots start at zero. */
        @Override
        public WeightValue.WeightValueParamsF3 newWeightValue(float initialWeight) {
            return new WeightValue.WeightValueParamsF3(initialWeight, 0.0f, 0.0f, 0.0f);
        }

        /**
         * Updates both running averages and the momentum delta, and returns
         * {@code delta = momentum * delta + alpha * gradient / sqrt(n - g^2 + eps)}.
         *
         * <p>Both {@code n} (mean of squared gradients) and {@code g} (mean gradient) are the
         * values updated with the current gradient. With both averages started from zero and
         * sharing one decay, {@code n >= g^2} by Cauchy-Schwarz. Mixing the pre-update {@code n}
         * with the post-update {@code g} breaks that bound: e.g. on the first step with default
         * options {@code n - g^2 + eps} is negative for {@code |gradient| > 20}, yielding NaN.
         * The variance term is also clamped at 0 to absorb rounding error.
         */
        @Override
        protected float computeDelta(@Nonnull IWeightValue weight, float gradient) {
            final float oldScaledN = weight.getSumOfSquaredGradients();
            final float newScaledN =
                    (decay * oldScaledN) + ((1.0f - decay) * gradient * (gradient / scale));
            weight.setSumOfSquaredGradients(newScaledN);
            final float newScaledG =
                    (decay * weight.getSumOfGradients()) + (((1.0f - decay) * gradient) / scale);
            weight.setSumOfGradients(newScaledG);
            final double n = ((double) newScaledN) * scale;
            final double g = ((double) newScaledG) * scale;
            final double variance = Math.max(n - (g * g), 0.0d);
            final float normalized = (float) (gradient / Math.sqrt(variance + eps));
            final float delta = (momentum * weight.getDelta()) + (alpha * normalized);
            weight.setDelta(delta);
            return delta;
        }

        @Override
        public String getOptimizerName() {
            return "rmsprop_graves";
        }

        /** Base hyper-parameters plus decay, alpha, momentum, eps and scale. */
        @Override
        public Map<String, Object> getHyperParameters() {
            final Map<String, Object> params = super.getHyperParameters();
            params.put("decay", Float.valueOf(decay));
            params.put("alpha", Float.valueOf(alpha));
            params.put("momentum", Float.valueOf(momentum));
            params.put("eps", Float.valueOf(eps));
            params.put("scale", Float.valueOf(scale));
            return params;
        }
    }

    /**
     * Plain stochastic gradient descent: the update direction is the gradient itself, scaled by
     * the learning-rate schedule and passed through the configured regularizer.
     */
    public static final class SGD extends OptimizerBase {
        // single scratch holder reused across calls; see OptimizerBase's thread-safety note
        private final IWeightValue weightValueReused;

        public SGD(@Nonnull Map<String, String> options) {
            super(options);
            this.weightValueReused = newWeightValue(0.0f);
        }

        /** Creates a plain weight holder with no auxiliary state. */
        @Override
        public WeightValue newWeightValue(float initialWeight) {
            return new WeightValue(initialWeight);
        }

        /**
         * Loads {@code weight} into the scratch holder, applies one step with {@code gradient},
         * and returns the resulting weight. {@code feature} is not read.
         */
        @Override
        protected float update(@Nonnull Object feature, float weight, float gradient) {
            final IWeightValue scratch = weightValueReused;
            scratch.set(weight);
            update(scratch, gradient);
            return scratch.get();
        }

        @Override
        public String getOptimizerName() {
            return "sgd";
        }
    }

    /**
     * Computes the updated value of a single weight.
     *
     * @param feature identifier of the weight being updated (SGD does not read it; presumably
     *        the feature key — confirm with callers)
     * @param weight current weight value
     * @param loss loss of the current example (the Eve optimizer reads it; others ignore it)
     * @param gradient gradient for this weight
     * @return the new weight value
     */
    float update(@Nonnull Object feature, float weight, float loss, float gradient);

    /** Advances the optimizer's internal step counter (and any per-step schedule). */
    void proceedStep();

    /** Returns the short name identifying this optimizer, e.g. {@code "sgd"}. */
    @Nonnull
    String getOptimizerName();

    /** Returns the optimizer's name and hyper-parameters as a mutable map. */
    @Nonnull
    Map<String, Object> getHyperParameters();
}
