package smile.classification;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import smile.math.DifferentiableMultivariateFunction;
import smile.math.Math;
import smile.util.MulticoreExecutor;

/* loaded from: input_file:smile/classification/LogisticRegression.class */
public class LogisticRegression implements Classifier<double[]> {
    private int p;
    private int k;
    private double L;
    private double[] w;
    private double[][] W;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:smile/classification/LogisticRegression$BinaryObjectiveFunction.class */
    /**
     * Objective function of two-class logistic regression: the negative
     * log-likelihood of the training data plus an optional L2 penalty on the
     * weights (the trailing bias term is not penalized).
     *
     * <p>The parameter vector passed to {@code f} holds one weight per feature
     * followed by the bias. When there are at least 1000 samples and the
     * executor reports two or more threads, the sample range is split into
     * tasks evaluated through {@link MulticoreExecutor}; if that evaluation
     * fails (or yields NaN) the value is recomputed sequentially.
     */
    public static class BinaryObjectiveFunction implements DifferentiableMultivariateFunction {
        /** Training samples, one row per instance. */
        double[][] x;
        /** Class labels, used directly as numeric factors (expected to be 0 or 1). */
        int[] y;
        /** L2 regularization factor; 0 disables the penalty. */
        double lambda;
        /** Tasks computing partial objective values, or null when running sequentially. */
        List<FTask> ftasks;
        /** Tasks computing partial objective values and gradients, or null when running sequentially. */
        List<GTask> gtasks;

        /** Computes the objective contribution of the samples in {@code [start, end)}. */
        class FTask implements Callable<Double> {
            /** Parameter vector to evaluate; assigned by the caller before each run. */
            double[] w;
            int start;
            int end;

            FTask(int i, int i2) {
                this.start = i;
                this.end = i2;
            }

            @Override
            public Double call() {
                double sum = 0.0;
                for (int i = start; i < end; i++) {
                    double wx = LogisticRegression.dot(x[i], w);
                    sum += LogisticRegression.log1pe(wx) - (y[i] * wx);
                }
                return Double.valueOf(sum);
            }
        }

        /**
         * Computes the gradient and objective contribution of the samples in
         * {@code [start, end)}. The returned array has {@code w.length + 1}
         * entries: the partial gradient followed by the partial objective value.
         */
        class GTask implements Callable<double[]> {
            /** Parameter vector to evaluate; assigned by the caller before each run. */
            double[] w;
            int start;
            int end;

            GTask(int i, int i2) {
                this.start = i;
                this.end = i2;
            }

            @Override
            public double[] call() {
                int p = w.length - 1;
                double[] out = new double[w.length + 1];
                double sum = 0.0;
                for (int i = start; i < end; i++) {
                    double wx = LogisticRegression.dot(x[i], w);
                    sum += LogisticRegression.log1pe(wx) - (y[i] * wx);
                    double err = y[i] - Math.logistic(wx);
                    for (int j = 0; j < p; j++) {
                        out[j] -= err * x[i][j];
                    }
                    // Bias term: its "feature" is the constant 1.
                    out[p] -= err;
                }
                // The extra trailing slot carries the partial objective value.
                out[w.length] = sum;
                return out;
            }
        }

        /**
         * @param dArr training samples
         * @param iArr class labels, one per sample
         * @param d    L2 regularization factor
         */
        BinaryObjectiveFunction(double[][] dArr, int[] iArr, double d) {
            this.ftasks = null;
            this.gtasks = null;
            this.x = dArr;
            this.y = iArr;
            this.lambda = d;

            int n = dArr.length;
            int threads = MulticoreExecutor.getThreadPoolSize();
            if (n < 1000 || threads < 2) {
                return;
            }

            this.ftasks = new ArrayList<FTask>(threads + 1);
            this.gtasks = new ArrayList<GTask>(threads + 1);
            int step = n / threads;
            if (step < 100) {
                step = 100;
            }
            // Hand out full chunks only while they end strictly before n; the
            // final task takes whatever remains. Because of the 100-row minimum,
            // (threads - 1) full chunks can exceed n, so the bound check is
            // required to keep every task inside the data.
            int start = 0;
            for (int t = 0; t < threads - 1 && start + step < n; t++) {
                this.ftasks.add(new FTask(start, start + step));
                this.gtasks.add(new GTask(start, start + step));
                start += step;
            }
            this.ftasks.add(new FTask(start, n));
            this.gtasks.add(new GTask(start, n));
        }

        /**
         * Evaluates the objective.
         *
         * @param dArr parameter vector: feature weights followed by the bias
         * @return the (penalized) negative log-likelihood
         */
        @Override
        public double f(double[] dArr) {
            int p = dArr.length - 1;
            double value = Double.NaN;

            if (this.ftasks != null) {
                for (FTask task : this.ftasks) {
                    task.w = dArr;
                }
                try {
                    double sum = 0.0;
                    for (Double part : MulticoreExecutor.run(this.ftasks)) {
                        sum += part.doubleValue();
                    }
                    value = sum;
                } catch (Exception e) {
                    // Best effort: report and fall back to the sequential loop below.
                    System.err.println(e);
                    value = Double.NaN;
                }
            }

            if (Double.isNaN(value)) {
                value = 0.0;
                for (int i = 0; i < this.x.length; i++) {
                    double wx = LogisticRegression.dot(this.x[i], dArr);
                    value += LogisticRegression.log1pe(wx) - (this.y[i] * wx);
                }
            }

            if (this.lambda != 0.0) {
                double norm = 0.0;
                for (int j = 0; j < p; j++) {
                    norm += dArr[j] * dArr[j];
                }
                value += 0.5d * this.lambda * norm;
            }
            return value;
        }

        /**
         * Evaluates the objective and its gradient.
         *
         * @param dArr  parameter vector: feature weights followed by the bias
         * @param dArr2 output buffer overwritten with the gradient
         * @return the (penalized) negative log-likelihood
         */
        @Override
        public double f(double[] dArr, double[] dArr2) {
            int p = dArr.length - 1;
            double value = Double.NaN;
            Arrays.fill(dArr2, 0.0);

            if (this.gtasks != null) {
                for (GTask task : this.gtasks) {
                    task.w = dArr;
                }
                try {
                    double sum = 0.0;
                    for (double[] part : MulticoreExecutor.run(this.gtasks)) {
                        sum += part[dArr.length];
                        for (int j = 0; j < dArr.length; j++) {
                            dArr2[j] += part[j];
                        }
                    }
                    value = sum;
                } catch (Exception e) {
                    // Best effort: report, discard any partial sums and fall back
                    // to the sequential loop below.
                    System.err.println(e);
                    Arrays.fill(dArr2, 0.0);
                    value = Double.NaN;
                }
            }

            if (Double.isNaN(value)) {
                value = 0.0;
                for (int i = 0; i < this.x.length; i++) {
                    double wx = LogisticRegression.dot(this.x[i], dArr);
                    value += LogisticRegression.log1pe(wx) - (this.y[i] * wx);
                    double err = this.y[i] - Math.logistic(wx);
                    for (int j = 0; j < p; j++) {
                        dArr2[j] -= err * this.x[i][j];
                    }
                    dArr2[p] -= err;
                }
            }

            if (this.lambda != 0.0) {
                double norm = 0.0;
                for (int j = 0; j < p; j++) {
                    norm += dArr[j] * dArr[j];
                }
                value += 0.5d * this.lambda * norm;
                for (int j = 0; j < p; j++) {
                    dArr2[j] += this.lambda * dArr[j];
                }
            }
            return value;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:smile/classification/LogisticRegression$MultiClassObjectiveFunction.class */
    /**
     * Objective function of multi-class (softmax) logistic regression: the
     * negative log-likelihood of the training data plus an optional L2 penalty
     * on the feature weights (bias terms are not penalized).
     *
     * <p>The parameter vector is laid out class by class; each class owns
     * {@code x[0].length + 1} consecutive entries (feature weights followed by
     * the bias). When there are at least 1000 samples and the executor reports
     * two or more threads, the sample range is split into tasks evaluated
     * through {@link MulticoreExecutor}; if that evaluation fails (or yields
     * NaN) the value is recomputed sequentially.
     */
    public static class MultiClassObjectiveFunction implements DifferentiableMultivariateFunction {
        /** Training samples, one row per instance. */
        double[][] x;
        /** Class labels, used as indices into the per-class probability vector. */
        int[] y;
        /** Number of classes. */
        int k;
        /** L2 regularization factor; 0 disables the penalty. */
        double lambda;
        /** Tasks computing partial objective values, or null when running sequentially. */
        List<FTask> ftasks;
        /** Tasks computing partial objective values and gradients, or null when running sequentially. */
        List<GTask> gtasks;

        /** Computes the objective contribution of the samples in {@code [start, end)}. */
        class FTask implements Callable<Double> {
            /** Parameter vector to evaluate; assigned by the caller before each run. */
            double[] w;
            int start;
            int end;

            FTask(int i, int i2) {
                this.start = i;
                this.end = i2;
            }

            @Override
            public Double call() {
                int p = x[0].length;
                double[] prob = new double[k];
                double sum = 0.0;
                for (int i = start; i < end; i++) {
                    for (int c = 0; c < k; c++) {
                        prob[c] = LogisticRegression.dot(x[i], w, c * (p + 1));
                    }
                    LogisticRegression.softmax(prob);
                    sum -= LogisticRegression.log(prob[y[i]]);
                }
                return Double.valueOf(sum);
            }
        }

        /**
         * Computes the gradient and objective contribution of the samples in
         * {@code [start, end)}. The returned array has {@code w.length + 1}
         * entries: the partial gradient followed by the partial objective value.
         */
        class GTask implements Callable<double[]> {
            /** Parameter vector to evaluate; assigned by the caller before each run. */
            double[] w;
            int start;
            int end;

            GTask(int i, int i2) {
                this.start = i;
                this.end = i2;
            }

            @Override
            public double[] call() {
                int p = x[0].length;
                double[] out = new double[w.length + 1];
                double[] prob = new double[k];
                double sum = 0.0;
                for (int i = start; i < end; i++) {
                    for (int c = 0; c < k; c++) {
                        prob[c] = LogisticRegression.dot(x[i], w, c * (p + 1));
                    }
                    LogisticRegression.softmax(prob);
                    sum -= LogisticRegression.log(prob[y[i]]);
                    for (int c = 0; c < k; c++) {
                        // Indicator of the true class minus the predicted probability.
                        double err = (y[i] == c ? 1.0 : 0.0) - prob[c];
                        int base = c * (p + 1);
                        for (int j = 0; j < p; j++) {
                            out[base + j] -= err * x[i][j];
                        }
                        out[base + p] -= err;
                    }
                }
                // The extra trailing slot carries the partial objective value.
                out[w.length] = sum;
                return out;
            }
        }

        /**
         * @param dArr training samples
         * @param iArr class labels, one per sample
         * @param i    number of classes
         * @param d    L2 regularization factor
         */
        MultiClassObjectiveFunction(double[][] dArr, int[] iArr, int i, double d) {
            this.ftasks = null;
            this.gtasks = null;
            this.x = dArr;
            this.y = iArr;
            this.k = i;
            this.lambda = d;

            int n = dArr.length;
            int threads = MulticoreExecutor.getThreadPoolSize();
            if (n < 1000 || threads < 2) {
                return;
            }

            this.ftasks = new ArrayList<FTask>(threads + 1);
            this.gtasks = new ArrayList<GTask>(threads + 1);
            int step = n / threads;
            if (step < 100) {
                step = 100;
            }
            // Hand out full chunks only while they end strictly before n; the
            // final task takes whatever remains. Because of the 100-row minimum,
            // (threads - 1) full chunks can exceed n, so the bound check is
            // required to keep every task inside the data.
            int start = 0;
            for (int t = 0; t < threads - 1 && start + step < n; t++) {
                this.ftasks.add(new FTask(start, start + step));
                this.gtasks.add(new GTask(start, start + step));
                start += step;
            }
            this.ftasks.add(new FTask(start, n));
            this.gtasks.add(new GTask(start, n));
        }

        /**
         * Evaluates the objective.
         *
         * @param dArr parameter vector, {@code k} consecutive groups of feature weights plus bias
         * @return the (penalized) negative log-likelihood
         */
        @Override
        public double f(double[] dArr) {
            int p = this.x[0].length;
            double value = Double.NaN;

            if (this.ftasks != null) {
                for (FTask task : this.ftasks) {
                    task.w = dArr;
                }
                try {
                    double sum = 0.0;
                    for (Double part : MulticoreExecutor.run(this.ftasks)) {
                        sum += part.doubleValue();
                    }
                    value = sum;
                } catch (Exception e) {
                    // Best effort: report and fall back to the sequential loop below.
                    System.err.println(e);
                    value = Double.NaN;
                }
            }

            if (Double.isNaN(value)) {
                double[] prob = new double[this.k];
                value = 0.0;
                for (int i = 0; i < this.x.length; i++) {
                    for (int c = 0; c < this.k; c++) {
                        prob[c] = LogisticRegression.dot(this.x[i], dArr, c * (p + 1));
                    }
                    LogisticRegression.softmax(prob);
                    value -= LogisticRegression.log(prob[this.y[i]]);
                }
            }

            if (this.lambda != 0.0) {
                double norm = 0.0;
                for (int c = 0; c < this.k; c++) {
                    for (int j = 0; j < p; j++) {
                        double wj = dArr[(c * (p + 1)) + j];
                        norm += wj * wj;
                    }
                }
                value += 0.5d * this.lambda * norm;
            }
            return value;
        }

        /**
         * Evaluates the objective and its gradient.
         *
         * @param dArr  parameter vector, {@code k} consecutive groups of feature weights plus bias
         * @param dArr2 output buffer overwritten with the gradient
         * @return the (penalized) negative log-likelihood
         */
        @Override
        public double f(double[] dArr, double[] dArr2) {
            int p = this.x[0].length;
            double value = Double.NaN;
            Arrays.fill(dArr2, 0.0);

            if (this.gtasks != null) {
                for (GTask task : this.gtasks) {
                    task.w = dArr;
                }
                try {
                    double sum = 0.0;
                    for (double[] part : MulticoreExecutor.run(this.gtasks)) {
                        sum += part[dArr.length];
                        for (int j = 0; j < dArr.length; j++) {
                            dArr2[j] += part[j];
                        }
                    }
                    value = sum;
                } catch (Exception e) {
                    // Best effort: report, discard any partial sums and fall back
                    // to the sequential loop below.
                    System.err.println(e);
                    Arrays.fill(dArr2, 0.0);
                    value = Double.NaN;
                }
            }

            if (Double.isNaN(value)) {
                double[] prob = new double[this.k];
                value = 0.0;
                for (int i = 0; i < this.x.length; i++) {
                    for (int c = 0; c < this.k; c++) {
                        prob[c] = LogisticRegression.dot(this.x[i], dArr, c * (p + 1));
                    }
                    LogisticRegression.softmax(prob);
                    value -= LogisticRegression.log(prob[this.y[i]]);
                    for (int c = 0; c < this.k; c++) {
                        double err = (this.y[i] == c ? 1.0 : 0.0) - prob[c];
                        int base = c * (p + 1);
                        for (int j = 0; j < p; j++) {
                            dArr2[base + j] -= err * this.x[i][j];
                        }
                        dArr2[base + p] -= err;
                    }
                }
            }

            if (this.lambda != 0.0) {
                double norm = 0.0;
                for (int c = 0; c < this.k; c++) {
                    for (int j = 0; j < p; j++) {
                        int idx = (c * (p + 1)) + j;
                        norm += dArr[idx] * dArr[idx];
                        dArr2[idx] += this.lambda * dArr[idx];
                    }
                }
                value += 0.5d * this.lambda * norm;
            }
            return value;
        }
    }

    /* loaded from: input_file:smile/classification/LogisticRegression$Trainer.class */
    /**
     * Fluent trainer that collects the hyper-parameters of a logistic
     * regression and builds the classifier from them.
     */
    public static class Trainer extends ClassifierTrainer<double[]> {
        /** L2 regularization factor; defaults to 0 (no penalty). */
        private double lambda = 0.0;
        /** Convergence tolerance handed to the optimizer; defaults to 1E-5. */
        private double tol = 1.0E-5d;
        /** Maximum number of optimizer iterations; defaults to 500. */
        private int maxIter = 500;

        /**
         * Sets the regularization factor.
         *
         * @param d the regularization factor, must not be negative
         * @return this trainer
         * @throws IllegalArgumentException if {@code d} is negative
         */
        public Trainer setRegularizationFactor(double d) {
            // Same check and message as the classifier constructor, applied
            // here so a bad value is reported where it is set.
            if (d < 0.0) {
                throw new IllegalArgumentException("Invalid regularization factor: " + d);
            }
            this.lambda = d;
            return this;
        }

        /**
         * Sets the convergence tolerance.
         *
         * @param d the tolerance, must be positive
         * @return this trainer
         * @throws IllegalArgumentException if {@code d} is not positive
         */
        public Trainer setTolerance(double d) {
            if (d <= 0.0) {
                throw new IllegalArgumentException("Invalid tolerance: " + d);
            }
            this.tol = d;
            return this;
        }

        /**
         * Sets the maximum number of iterations.
         *
         * @param i the iteration limit, must be positive
         * @return this trainer
         * @throws IllegalArgumentException if {@code i} is not positive
         */
        public Trainer setMaxNumIteration(int i) {
            if (i <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of iterations: " + i);
            }
            this.maxIter = i;
            return this;
        }

        /**
         * Trains a logistic regression with the configured settings.
         *
         * @param dArr training samples
         * @param iArr class labels, one per sample
         * @return the trained classifier
         */
        @Override
        public LogisticRegression train(double[][] dArr, int[] iArr) {
            return new LogisticRegression(dArr, iArr, this.lambda, this.tol, this.maxIter);
        }
    }

    /**
     * Trains a logistic regression without regularization (factor 0).
     *
     * @param dArr training samples
     * @param iArr class labels, one per sample
     */
    public LogisticRegression(double[][] dArr, int[] iArr) {
        this(dArr, iArr, 0.0d);
    }

    /**
     * Trains a logistic regression with the given regularization factor,
     * a tolerance of 1E-5 and at most 500 iterations.
     *
     * @param dArr training samples
     * @param iArr class labels, one per sample
     * @param d    regularization factor
     */
    public LogisticRegression(double[][] dArr, int[] iArr, double d) {
        this(dArr, iArr, d, 1e-5, 500);
    }

    /**
     * Trains a logistic regression.
     *
     * <p>With two classes a single weight vector is fitted; with more classes
     * one weight vector per class is fitted and stored row-wise.
     *
     * @param dArr training samples
     * @param iArr class labels, one per sample; the distinct labels must be
     *             exactly {@code 0, 1, ..., k-1}
     * @param d    regularization factor, must not be negative
     * @param d2   convergence tolerance, must be positive
     * @param i    maximum number of iterations, must be positive
     * @throws IllegalArgumentException if the array sizes differ, a parameter
     *         is out of range, a label is negative, a class index is missing,
     *         or fewer than two classes are present
     */
    public LogisticRegression(double[][] dArr, int[] iArr, double d, double d2, int i) {
        if (dArr.length != iArr.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", dArr.length, iArr.length));
        }
        if (d < 0.0) {
            throw new IllegalArgumentException("Invalid regularization factor: " + d);
        }
        if (d2 <= 0.0) {
            throw new IllegalArgumentException("Invalid tolerance: " + d2);
        }
        if (i <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + i);
        }

        int[] labels = Math.unique(iArr);
        Arrays.sort(labels);
        for (int c = 0; c < labels.length; c++) {
            if (labels[c] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[c]);
            }
            // Sorted distinct non-negative labels are valid only when they are
            // exactly 0..k-1; the first position that deviates names the first
            // missing class. This also rejects label sets that do not start at
            // 0, which the objective functions cannot handle (labels are used
            // as 0/1 factors or as array indices).
            if (labels[c] != c) {
                throw new IllegalArgumentException("Missing class: " + c);
            }
        }

        this.k = labels.length;
        if (this.k < 2) {
            throw new IllegalArgumentException("Only one class.");
        }
        this.p = dArr[0].length;

        if (this.k == 2) {
            BinaryObjectiveFunction objective = new BinaryObjectiveFunction(dArr, iArr, d);
            this.w = new double[this.p + 1];
            this.L = 0.0;
            try {
                // NOTE(review): the overload with the extra argument 5 is
                // presumably the limited-memory optimizer - confirm in Math.min.
                this.L = -Math.min(objective, 5, this.w, d2, i);
            } catch (Exception ignored) {
                // Deliberate fallback: retry with the other optimizer overload,
                // continuing from whatever is currently stored in w.
                this.L = -Math.min(objective, this.w, d2, i);
            }
            return;
        }

        MultiClassObjectiveFunction objective = new MultiClassObjectiveFunction(dArr, iArr, this.k, d);
        this.w = new double[this.k * (this.p + 1)];
        this.L = 0.0;
        try {
            this.L = -Math.min(objective, 5, this.w, d2, i);
        } catch (Exception ignored) {
            // Deliberate fallback: retry with the other optimizer overload,
            // continuing from whatever is currently stored in w.
            this.L = -Math.min(objective, this.w, d2, i);
        }

        // Unpack the flat parameter vector into one row (weights + bias) per class.
        this.W = new double[this.k][this.p + 1];
        int flat = 0;
        for (int c = 0; c < this.k; c++) {
            for (int j = 0; j <= this.p; j++) {
                this.W[c][j] = this.w[flat++];
            }
        }
        this.w = null;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Computes {@code log(1 + exp(d))}.
     *
     * <p>For {@code d > 15} the argument itself is returned instead of
     * evaluating the expression.
     *
     * @param d the exponent
     * @return {@code log(1 + exp(d))}, or {@code d} when {@code d > 15}
     */
    public static double log1pe(double d) {
        if (d > 15.0d) {
            return d;
        }
        return Math.log1p(Math.exp(d));
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Natural logarithm with a floor for tiny arguments.
     *
     * @param d the argument
     * @return {@code -690.7755} (about {@code ln(1E-300)}) when
     *         {@code d < 1E-300}, otherwise the natural logarithm of {@code d}
     */
    public static double log(double d) {
        if (d < 1.0E-300d) {
            return -690.7755d;
        }
        return Math.log(d);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Replaces the entries of the array, in place, by their softmax.
     *
     * <p>The largest entry is subtracted before exponentiating, then every
     * exponential is divided by their sum.
     *
     * @param dArr the scores; overwritten with the normalized values
     */
    public static void softmax(double[] dArr) {
        double peak = Double.NEGATIVE_INFINITY;
        for (double score : dArr) {
            if (score > peak) {
                peak = score;
            }
        }

        double total = 0.0d;
        for (int j = 0; j < dArr.length; j++) {
            dArr[j] = Math.exp(dArr[j] - peak);
            total += dArr[j];
        }

        for (int j = 0; j < dArr.length; j++) {
            dArr[j] /= total;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Dot product of a sample with a weight vector, plus the bias.
     *
     * @param dArr  the sample
     * @param dArr2 the weights; the entry at index {@code dArr.length} is
     *              added as the bias
     * @return the sum of {@code dArr[j] * dArr2[j]} over all entries of
     *         {@code dArr}, plus {@code dArr2[dArr.length]}
     */
    public static double dot(double[] dArr, double[] dArr2) {
        final int n = dArr.length;
        double sum = 0.0d;
        for (int j = 0; j < n; j++) {
            sum += dArr[j] * dArr2[j];
        }
        return sum + dArr2[n];
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Dot product of a sample with a slice of a flat weight vector, plus the
     * bias that follows the slice.
     *
     * @param dArr  the sample
     * @param dArr2 the flat weight vector
     * @param i     offset in {@code dArr2} where the slice starts
     * @return the sum of {@code dArr[j] * dArr2[i + j]} over all entries of
     *         {@code dArr}, plus {@code dArr2[i + dArr.length]}
     */
    public static double dot(double[] dArr, double[] dArr2, int i) {
        final int n = dArr.length;
        double sum = 0.0d;
        for (int j = 0; j < n; j++) {
            sum += dArr[j] * dArr2[i + j];
        }
        return sum + dArr2[i + n];
    }

    /**
     * Returns the value stored at training time: the negated minimum of the
     * objective function found by the optimizer.
     *
     * @return the log-likelihood of the fitted model
     */
    public double loglikelihood() {
        return L;
    }

    /**
     * Predicts the class label of a sample without reporting posteriors.
     *
     * @param dArr the sample
     * @return the predicted class label
     */
    @Override
    public int predict(double[] dArr) {
        final double[] noPosteriori = null;
        return predict(dArr, noPosteriori);
    }

    /**
     * Predicts the class label of a sample and optionally its posterior
     * probabilities.
     *
     * @param dArr  the sample; its length must equal the training dimension
     * @param dArr2 optional output buffer of length {@code k}; when non-null,
     *              entry {@code c} receives the posterior probability of class {@code c}
     * @return the predicted class label
     * @throws IllegalArgumentException if either array has the wrong length
     */
    @Override
    public int predict(double[] dArr, double[] dArr2) {
        if (dArr.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", dArr.length, this.p));
        }
        if (dArr2 != null && dArr2.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", dArr2.length, this.k));
        }

        if (this.k == 2) {
            // The binary objective models P(y = 1 | x) as logistic(w.x + b),
            // so this value is the posterior of class 1, not of class 0.
            double positive = 1.0d / (1.0d + Math.exp(-dot(dArr, this.w)));
            if (dArr2 != null) {
                dArr2[0] = 1.0d - positive;
                dArr2[1] = positive;
            }
            return positive < 0.5d ? 0 : 1;
        }

        int best = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < this.k; c++) {
            double score = dot(dArr, this.W[c]);
            if (score > bestScore) {
                bestScore = score;
                best = c;
            }
            if (dArr2 != null) {
                dArr2[c] = score;
            }
        }

        if (dArr2 != null) {
            // Softmax over the raw scores, shifted by the largest score.
            double total = 0.0d;
            for (int c = 0; c < this.k; c++) {
                dArr2[c] = Math.exp(dArr2[c] - bestScore);
                total += dArr2[c];
            }
            for (int c = 0; c < this.k; c++) {
                dArr2[c] /= total;
            }
        }
        return best;
    }
}
