package hivemall.optimizer;

import hivemall.model.IWeightValue;
import hivemall.model.WeightValue;
import hivemall.optimizer.Optimizer;
import it.unimi.dsi.fastutil.objects.Object2ObjectMap;
import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap;
import java.util.Locale;
import java.util.Map;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.concurrent.NotThreadSafe;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.lucene.util.packed.PackedInts;

/* loaded from: input_file:hivemall/optimizer/SparseOptimizerFactory.class */
public final class SparseOptimizerFactory {
    private static final Log LOG = LogFactory.getLog(SparseOptimizerFactory.class);

    /* JADX INFO: Access modifiers changed from: package-private */
    @NotThreadSafe
    /* loaded from: input_file:hivemall/optimizer/SparseOptimizerFactory$AdaDelta.class */
    public static final class AdaDelta extends Optimizer.AdaDelta {

        @Nonnull
        private final Object2ObjectMap<Object, IWeightValue> auxWeights;

        public AdaDelta(@Nonnegative int i, @Nonnull Map<String, String> map) {
            super(map);
            this.auxWeights = new Object2ObjectOpenHashMap(i);
        }

        /* JADX WARN: Multi-variable type inference failed */
        /* JADX WARN: Type inference failed for: r0v3, types: [hivemall.model.IWeightValue] */
        @Override // hivemall.optimizer.Optimizer.OptimizerBase
        protected float update(@Nonnull Object obj, float f, float f2) {
            WeightValue.WeightValueParamsF2 weightValueParamsF2 = this.auxWeights.get(obj);
            if (weightValueParamsF2 == null) {
                weightValueParamsF2 = newWeightValue(f);
                this.auxWeights.put(obj, weightValueParamsF2);
            } else {
                weightValueParamsF2.set(f);
            }
            return update(weightValueParamsF2, f2);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    @NotThreadSafe
    /* loaded from: input_file:hivemall/optimizer/SparseOptimizerFactory$AdaGrad.class */
    public static final class AdaGrad extends Optimizer.AdaGrad {

        @Nonnull
        private final Object2ObjectMap<Object, IWeightValue> auxWeights;

        public AdaGrad(@Nonnegative int i, @Nonnull Map<String, String> map) {
            super(map);
            this.auxWeights = new Object2ObjectOpenHashMap(i);
        }

        /* JADX WARN: Multi-variable type inference failed */
        /* JADX WARN: Type inference failed for: r0v3, types: [hivemall.model.IWeightValue] */
        @Override // hivemall.optimizer.Optimizer.OptimizerBase
        protected float update(@Nonnull Object obj, float f, float f2) {
            WeightValue.WeightValueParamsF1 weightValueParamsF1 = this.auxWeights.get(obj);
            if (weightValueParamsF1 == null) {
                weightValueParamsF1 = newWeightValue(f);
                this.auxWeights.put(obj, weightValueParamsF1);
            } else {
                weightValueParamsF1.set(f);
            }
            return update(weightValueParamsF1, f2);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    @NotThreadSafe
    /* loaded from: input_file:hivemall/optimizer/SparseOptimizerFactory$AdagradRDA.class */
    public static final class AdagradRDA extends Optimizer.AdagradRDA {

        @Nonnull
        private final Object2ObjectMap<Object, IWeightValue> auxWeights;

        public AdagradRDA(@Nonnegative int i, @Nonnull Optimizer.AdaGrad adaGrad, @Nonnull Map<String, String> map) {
            super(adaGrad, map);
            this.auxWeights = new Object2ObjectOpenHashMap(i);
        }

        /* JADX WARN: Multi-variable type inference failed */
        /* JADX WARN: Type inference failed for: r0v3, types: [hivemall.model.IWeightValue] */
        @Override // hivemall.optimizer.Optimizer.OptimizerBase
        protected float update(@Nonnull Object obj, float f, float f2) {
            WeightValue.WeightValueParamsF2 weightValueParamsF2 = this.auxWeights.get(obj);
            if (weightValueParamsF2 == null) {
                weightValueParamsF2 = newWeightValue(f);
                this.auxWeights.put(obj, weightValueParamsF2);
            } else {
                weightValueParamsF2.set(f);
            }
            float update = update(weightValueParamsF2, f2);
            if (update == PackedInts.COMPACT) {
                this.auxWeights.remove(obj);
            }
            return update;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    @NotThreadSafe
    /* loaded from: input_file:hivemall/optimizer/SparseOptimizerFactory$Adam.class */
    public static final class Adam extends Optimizer.Adam {

        @Nonnull
        private final Object2ObjectMap<Object, IWeightValue> auxWeights;

        public Adam(@Nonnegative int i, @Nonnull Map<String, String> map) {
            super(map);
            this.auxWeights = new Object2ObjectOpenHashMap(i);
        }

        /* JADX WARN: Multi-variable type inference failed */
        /* JADX WARN: Type inference failed for: r0v3, types: [hivemall.model.IWeightValue] */
        @Override // hivemall.optimizer.Optimizer.OptimizerBase
        protected float update(@Nonnull Object obj, float f, float f2) {
            WeightValue.WeightValueParamsF2 weightValueParamsF2 = this.auxWeights.get(obj);
            if (weightValueParamsF2 == null) {
                weightValueParamsF2 = newWeightValue(f);
                this.auxWeights.put(obj, weightValueParamsF2);
            } else {
                weightValueParamsF2.set(f);
            }
            return update(weightValueParamsF2, f2);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    @NotThreadSafe
    /* loaded from: input_file:hivemall/optimizer/SparseOptimizerFactory$AdamHD.class */
    public static final class AdamHD extends Optimizer.AdamHD {

        @Nonnull
        private final Object2ObjectMap<Object, IWeightValue> auxWeights;

        public AdamHD(@Nonnegative int i, @Nonnull Map<String, String> map) {
            super(map);
            this.auxWeights = new Object2ObjectOpenHashMap(i);
        }

        /* JADX WARN: Multi-variable type inference failed */
        /* JADX WARN: Type inference failed for: r0v3, types: [hivemall.model.IWeightValue] */
        @Override // hivemall.optimizer.Optimizer.OptimizerBase
        protected float update(@Nonnull Object obj, float f, float f2) {
            WeightValue.WeightValueParamsF2 weightValueParamsF2 = this.auxWeights.get(obj);
            if (weightValueParamsF2 == null) {
                weightValueParamsF2 = newWeightValue(f);
                this.auxWeights.put(obj, weightValueParamsF2);
            } else {
                weightValueParamsF2.set(f);
            }
            return update(weightValueParamsF2, f2);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    @NotThreadSafe
    /* loaded from: input_file:hivemall/optimizer/SparseOptimizerFactory$Eve.class */
    public static final class Eve extends Optimizer.Eve {

        @Nonnull
        private final Object2ObjectMap<Object, IWeightValue> auxWeights;

        public Eve(@Nonnegative int i, @Nonnull Map<String, String> map) {
            super(map);
            this.auxWeights = new Object2ObjectOpenHashMap(i);
        }

        /* JADX WARN: Multi-variable type inference failed */
        /* JADX WARN: Type inference failed for: r0v3, types: [hivemall.model.IWeightValue] */
        @Override // hivemall.optimizer.Optimizer.OptimizerBase
        protected float update(@Nonnull Object obj, float f, float f2) {
            WeightValue.WeightValueParamsF2 weightValueParamsF2 = this.auxWeights.get(obj);
            if (weightValueParamsF2 == null) {
                weightValueParamsF2 = newWeightValue(f);
                this.auxWeights.put(obj, weightValueParamsF2);
            } else {
                weightValueParamsF2.set(f);
            }
            return update(weightValueParamsF2, f2);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    @NotThreadSafe
    /* loaded from: input_file:hivemall/optimizer/SparseOptimizerFactory$Momentum.class */
    public static final class Momentum extends Optimizer.Momentum {

        @Nonnull
        private final Object2ObjectMap<Object, IWeightValue> auxWeights;

        public Momentum(@Nonnegative int i, @Nonnull Map<String, String> map) {
            super(map);
            this.auxWeights = new Object2ObjectOpenHashMap(i);
        }

        /* JADX WARN: Multi-variable type inference failed */
        /* JADX WARN: Type inference failed for: r0v3, types: [hivemall.model.IWeightValue] */
        @Override // hivemall.optimizer.Optimizer.OptimizerBase
        protected float update(@Nonnull Object obj, float f, float f2) {
            WeightValue.WeightValueParamsF1 weightValueParamsF1 = this.auxWeights.get(obj);
            if (weightValueParamsF1 == null) {
                weightValueParamsF1 = newWeightValue(f);
                this.auxWeights.put(obj, weightValueParamsF1);
            } else {
                weightValueParamsF1.set(f);
            }
            return update(weightValueParamsF1, f2);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    @NotThreadSafe
    /* loaded from: input_file:hivemall/optimizer/SparseOptimizerFactory$Nadam.class */
    public static final class Nadam extends Optimizer.Nadam {

        @Nonnull
        private final Object2ObjectMap<Object, IWeightValue> auxWeights;

        public Nadam(@Nonnegative int i, @Nonnull Map<String, String> map) {
            super(map);
            this.auxWeights = new Object2ObjectOpenHashMap(i);
        }

        /* JADX WARN: Multi-variable type inference failed */
        /* JADX WARN: Type inference failed for: r0v3, types: [hivemall.model.IWeightValue] */
        @Override // hivemall.optimizer.Optimizer.OptimizerBase
        protected float update(@Nonnull Object obj, float f, float f2) {
            WeightValue.WeightValueParamsF2 weightValueParamsF2 = this.auxWeights.get(obj);
            if (weightValueParamsF2 == null) {
                weightValueParamsF2 = newWeightValue(f);
                this.auxWeights.put(obj, weightValueParamsF2);
            } else {
                weightValueParamsF2.set(f);
            }
            return update(weightValueParamsF2, f2);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    @NotThreadSafe
    /* loaded from: input_file:hivemall/optimizer/SparseOptimizerFactory$RMSprop.class */
    public static final class RMSprop extends Optimizer.RMSprop {

        @Nonnull
        private final Object2ObjectMap<Object, IWeightValue> auxWeights;

        public RMSprop(@Nonnegative int i, @Nonnull Map<String, String> map) {
            super(map);
            this.auxWeights = new Object2ObjectOpenHashMap(i);
        }

        /* JADX WARN: Multi-variable type inference failed */
        /* JADX WARN: Type inference failed for: r0v3, types: [hivemall.model.IWeightValue] */
        @Override // hivemall.optimizer.Optimizer.OptimizerBase
        protected float update(@Nonnull Object obj, float f, float f2) {
            WeightValue.WeightValueParamsF1 weightValueParamsF1 = this.auxWeights.get(obj);
            if (weightValueParamsF1 == null) {
                weightValueParamsF1 = newWeightValue(f);
                this.auxWeights.put(obj, weightValueParamsF1);
            } else {
                weightValueParamsF1.set(f);
            }
            return update(weightValueParamsF1, f2);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    @NotThreadSafe
    /* loaded from: input_file:hivemall/optimizer/SparseOptimizerFactory$RMSpropGraves.class */
    public static final class RMSpropGraves extends Optimizer.RMSpropGraves {

        @Nonnull
        private final Object2ObjectMap<Object, IWeightValue> auxWeights;

        public RMSpropGraves(@Nonnegative int i, @Nonnull Map<String, String> map) {
            super(map);
            this.auxWeights = new Object2ObjectOpenHashMap(i);
        }

        /* JADX WARN: Multi-variable type inference failed */
        /* JADX WARN: Type inference failed for: r0v3, types: [hivemall.model.IWeightValue] */
        @Override // hivemall.optimizer.Optimizer.OptimizerBase
        protected float update(@Nonnull Object obj, float f, float f2) {
            WeightValue.WeightValueParamsF3 weightValueParamsF3 = this.auxWeights.get(obj);
            if (weightValueParamsF3 == null) {
                weightValueParamsF3 = newWeightValue(f);
                this.auxWeights.put(obj, weightValueParamsF3);
            } else {
                weightValueParamsF3.set(f);
            }
            return update(weightValueParamsF3, f2);
        }
    }

    @Nonnull
    public static Optimizer create(@Nonnull int i, @Nonnull Map<String, String> map) {
        Optimizer.OptimizerBase rMSpropGraves;
        String str = map.get("optimizer");
        if (str == null) {
            throw new IllegalArgumentException("`optimizer` not defined");
        }
        String lowerCase = str.toLowerCase();
        if ("rda".equalsIgnoreCase(map.get("regularization")) && !"adagrad".equals(lowerCase)) {
            throw new IllegalArgumentException("`-regularization rda` is only supported for AdaGrad but `-optimizer " + str + "`. Please specify `-regularization l1` and so on.");
        }
        if ("sgd".equals(lowerCase)) {
            rMSpropGraves = new Optimizer.SGD(map);
        } else if ("momentum".equals(lowerCase)) {
            rMSpropGraves = new Momentum(i, map);
        } else if ("nesterov".equals(lowerCase)) {
            map.put("nesterov", "");
            rMSpropGraves = new Momentum(i, map);
        } else if ("adagrad".equals(lowerCase)) {
            rMSpropGraves = "rda".equalsIgnoreCase(map.get("regularization")) ? new AdagradRDA(i, new AdaGrad(i, map), map) : new AdaGrad(i, map);
        } else if ("rmsprop".equals(lowerCase)) {
            rMSpropGraves = new RMSprop(i, map);
        } else if ("rmspropgraves".equals(lowerCase) || "rmsprop_graves".equals(lowerCase)) {
            rMSpropGraves = new RMSpropGraves(i, map);
        } else if ("adadelta".equals(lowerCase)) {
            rMSpropGraves = new AdaDelta(i, map);
        } else if ("adam".equals(lowerCase)) {
            rMSpropGraves = new Adam(i, map);
        } else if ("nadam".equals(lowerCase)) {
            rMSpropGraves = new Nadam(i, map);
        } else if ("eve".equals(lowerCase)) {
            rMSpropGraves = new Eve(i, map);
        } else {
            if (!"adam_hd".equals(lowerCase) && !"adamhd".equals(lowerCase)) {
                throw new IllegalArgumentException("Unsupported optimizer name: " + str);
            }
            rMSpropGraves = new AdamHD(i, map);
        }
        if (LOG.isInfoEnabled()) {
            LOG.info("Configured " + rMSpropGraves.getOptimizerName() + " as the optimizer: " + map);
            LOG.info("ETA estimator: " + rMSpropGraves._eta);
        }
        return rMSpropGraves;
    }
}
