package weka.classifiers.functions;

import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import weka.classifiers.RandomizableClassifier;
import weka.classifiers.functions.supportVector.CachedKernel;
import weka.classifiers.functions.supportVector.Kernel;
import weka.classifiers.functions.supportVector.PolyKernel;
import weka.core.Capabilities;
import weka.core.ConjugateGradientOptimization;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Optimization;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.RemoveUseless;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.attribute.Standardize;

/* loaded from: input_file:weka/classifiers/functions/KernelLogisticRegression.class */
public class KernelLogisticRegression extends RandomizableClassifier {
    static final long serialVersionUID = 6332117032546553533L;
    protected double[] m_weights;
    protected Instances m_data;
    protected Kernel m_kernel = new PolyKernel();
    protected double m_lambda = 0.01d;
    protected boolean m_useCGD = false;
    protected Standardize m_standardize = new Standardize();
    protected ReplaceMissingValues m_replaceMissing = new ReplaceMissingValues();
    protected NominalToBinary m_nominalToBinary = new NominalToBinary();
    protected RemoveUseless m_removeUseless = new RemoveUseless();
    protected int m_numThreads = 1;
    protected int m_poolSize = 1;
    protected transient ExecutorService m_Pool = null;
    protected double[][] m_kernelMatrix = (double[][]) null;
    protected double[] m_classValues = null;

    /* loaded from: input_file:weka/classifiers/functions/KernelLogisticRegression$OptEng.class */
    protected class OptEng extends Optimization {
        protected OptEng() {
        }

        protected double objectiveFunction(double[] dArr) throws Exception {
            KernelLogisticRegression.this.m_weights = dArr;
            return KernelLogisticRegression.this.calculateLoss();
        }

        protected double[] evaluateGradient(double[] dArr) throws Exception {
            KernelLogisticRegression.this.m_weights = dArr;
            return KernelLogisticRegression.this.calculateGradient();
        }

        public String getRevision() {
            return RevisionUtils.extract("$Revision: 1345 $");
        }
    }

    /* loaded from: input_file:weka/classifiers/functions/KernelLogisticRegression$OptEngCGD.class */
    protected class OptEngCGD extends ConjugateGradientOptimization {
        protected OptEngCGD() {
        }

        protected double objectiveFunction(double[] dArr) throws Exception {
            KernelLogisticRegression.this.m_weights = dArr;
            return KernelLogisticRegression.this.calculateLoss();
        }

        protected double[] evaluateGradient(double[] dArr) throws Exception {
            KernelLogisticRegression.this.m_weights = dArr;
            return KernelLogisticRegression.this.calculateGradient();
        }

        public String getRevision() {
            return RevisionUtils.extract("$Revision: 9345 $");
        }
    }

    public String globalInfo() {
        return "This classifier generates a two-class kernel logistic regression model. The model is fit by minimizing the negative log-likelihood with a quadratic penalty using BFGS optimization, as implemented in the Optimization class. Alternatively, conjugate gradient optimization can be applied. The user can specify the kernel function and the value of lambda, the multiplier for the quadractic penalty. Using a linear kernel (the default) this method should give the same result as ridge logistic regression implemented in Logistic, assuming the ridge parameter is set to the same value as lambda, and not too small. By replacing the kernel function, we can learn non-linear decision boundaries.\n\nNote that the data is filtered using ReplaceMissingValues, RemoveUseless, NominalToBinary, and Standardize (in that order).\n\nIf a CachedKernel is used, this class will overwrite the manually specified cache size and use a full cache instead.\n\nTo apply this classifier to multi-class problems, use the MultiClassClassifier.\n\nThis implementation stores the full kernel matrix at training time for speed reasons.";
    }

    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAll();
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.MISSING_VALUES);
        capabilities.enable(Capabilities.Capability.BINARY_CLASS);
        capabilities.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return capabilities;
    }

    public Enumeration listOptions() {
        Vector vector = new Vector();
        Enumeration listOptions = super.listOptions();
        while (listOptions.hasMoreElements()) {
            vector.addElement((Option) listOptions.nextElement());
        }
        vector.addElement(new Option("\tThe Kernel to use.\n\t(default: weka.classifiers.functions.supportVector.PolyKernel)", "K", 1, "-K <classname and parameters>"));
        vector.addElement(new Option("\tThe lambda penalty parameter. (default 0.01)", "L", 1, "-L <double>"));
        vector.addElement(new Option("\tUse conjugate gradient descent instead of BFGS.\n", "G", 0, "-G"));
        vector.addElement(new Option("\t" + poolSizeTipText() + " (default 1)\n", "P", 1, "-P <int>"));
        vector.addElement(new Option("\t" + numThreadsTipText() + " (default 1)\n", "E", 1, "-E <int>"));
        vector.addElement(new Option("", "", 0, "\nOptions specific to kernel " + getKernel().getClass().getName() + ":"));
        Enumeration listOptions2 = getKernel().listOptions();
        while (listOptions2.hasMoreElements()) {
            vector.addElement((Option) listOptions2.nextElement());
        }
        return vector.elements();
    }

    public void setOptions(String[] strArr) throws Exception {
        String option = Utils.getOption('L', strArr);
        if (option.length() != 0) {
            setLambda(Double.parseDouble(option));
        } else {
            setLambda(0.01d);
        }
        this.m_useCGD = Utils.getFlag('G', strArr);
        String[] splitOptions = Utils.splitOptions(Utils.getOption('K', strArr));
        if (splitOptions.length != 0) {
            String str = splitOptions[0];
            splitOptions[0] = "";
            setKernel(Kernel.forName(str, splitOptions));
        }
        String option2 = Utils.getOption('P', strArr);
        if (option2.length() != 0) {
            setPoolSize(Integer.parseInt(option2));
        } else {
            setPoolSize(1);
        }
        String option3 = Utils.getOption('E', strArr);
        if (option3.length() != 0) {
            setNumThreads(Integer.parseInt(option3));
        } else {
            setNumThreads(1);
        }
        super.setOptions(strArr);
    }

    public String[] getOptions() {
        Vector vector = new Vector();
        for (String str : super.getOptions()) {
            vector.add(str);
        }
        vector.add("-K");
        vector.add("" + getKernel().getClass().getName() + " " + Utils.joinOptions(getKernel().getOptions()));
        vector.add("-L");
        vector.add("" + getLambda());
        if (this.m_useCGD) {
            vector.add("-G");
        }
        vector.add("-P");
        vector.add("" + getPoolSize());
        vector.add("-E");
        vector.add("" + getNumThreads());
        return (String[]) vector.toArray(new String[vector.size()]);
    }

    public String numThreadsTipText() {
        return "The number of threads to use, which should be >= size of thread pool.";
    }

    public int getNumThreads() {
        return this.m_numThreads;
    }

    public void setNumThreads(int i) {
        this.m_numThreads = i;
    }

    public String poolSizeTipText() {
        return "The size of the thread pool, for example, the number of cores in the CPU.";
    }

    public int getPoolSize() {
        return this.m_poolSize;
    }

    public void setPoolSize(int i) {
        this.m_poolSize = i;
    }

    public String lambdaTipText() {
        return "The penalty parameter lambda.";
    }

    public double getLambda() {
        return this.m_lambda;
    }

    public void setLambda(double d) {
        this.m_lambda = d;
    }

    public String kernelTipText() {
        return "The kernel to use.";
    }

    public void setKernel(Kernel kernel) {
        this.m_kernel = kernel;
    }

    public Kernel getKernel() {
        return this.m_kernel;
    }

    public String useCGDTipText() {
        return "Whether to use conjugate gradient descent (potentially useful for many parameters).";
    }

    public boolean getUseCGD() {
        return this.m_useCGD;
    }

    public void setUseCGD(boolean z) {
        this.m_useCGD = z;
    }

    protected double calculateLoss() throws Exception {
        final int length = this.m_classValues.length;
        int i = length / this.m_numThreads;
        HashSet hashSet = new HashSet();
        int i2 = 0;
        while (i2 < this.m_numThreads) {
            final int i3 = i2 * i;
            final int i4 = i2 < this.m_numThreads - 1 ? i3 + i : length;
            hashSet.add(this.m_Pool.submit(new Callable<Double>() { // from class: weka.classifiers.functions.KernelLogisticRegression.1
                /* JADX WARN: Can't rename method to resolve collision */
                @Override // java.util.concurrent.Callable
                public Double call() throws Exception {
                    double d = 0.0d;
                    for (int i5 = i3; i5 < i4; i5++) {
                        double d2 = 0.0d;
                        for (int i6 = 0; i6 < length; i6++) {
                            d2 += KernelLogisticRegression.this.m_weights[i6] * KernelLogisticRegression.this.m_kernelMatrix[i5][i6];
                        }
                        d = d + (KernelLogisticRegression.this.m_lambda * KernelLogisticRegression.this.m_weights[i5] * d2) + Math.log(1.0d + Math.exp((-KernelLogisticRegression.this.m_classValues[i5]) * (d2 + KernelLogisticRegression.this.m_weights[length])));
                    }
                    return Double.valueOf(d);
                }
            }));
            i2++;
        }
        double d = 0.0d;
        try {
            Iterator it = hashSet.iterator();
            while (it.hasNext()) {
                d += ((Double) ((Future) it.next()).get()).doubleValue();
            }
        } catch (Exception e) {
            System.out.println("Loss could not be calculated.");
            e.printStackTrace();
        }
        return d;
    }

    protected double[] calculateGradient() throws Exception {
        final int length = this.m_classValues.length;
        int i = length / this.m_numThreads;
        HashSet hashSet = new HashSet();
        int i2 = 0;
        while (i2 < this.m_numThreads) {
            final int i3 = i2 * i;
            final int i4 = i2 < this.m_numThreads - 1 ? i3 + i : length;
            hashSet.add(this.m_Pool.submit(new Callable<double[]>() { // from class: weka.classifiers.functions.KernelLogisticRegression.2
                /* JADX WARN: Can't rename method to resolve collision */
                @Override // java.util.concurrent.Callable
                public double[] call() throws Exception {
                    double[] dArr = new double[length + 1];
                    for (int i5 = i3; i5 < i4; i5++) {
                        double d = 0.0d;
                        for (int i6 = 0; i6 < length; i6++) {
                            d += KernelLogisticRegression.this.m_weights[i6] * KernelLogisticRegression.this.m_kernelMatrix[i5][i6];
                        }
                        int i7 = i5;
                        dArr[i7] = dArr[i7] + (2.0d * KernelLogisticRegression.this.m_lambda * d);
                        double exp = (-KernelLogisticRegression.this.m_classValues[i5]) * (1.0d / (1.0d + Math.exp(KernelLogisticRegression.this.m_classValues[i5] * (d + KernelLogisticRegression.this.m_weights[length]))));
                        for (int i8 = 0; i8 < length; i8++) {
                            int i9 = i8;
                            dArr[i9] = dArr[i9] + (exp * KernelLogisticRegression.this.m_kernelMatrix[i5][i8]);
                        }
                        int i10 = length;
                        dArr[i10] = dArr[i10] + exp;
                    }
                    return dArr;
                }
            }));
            i2++;
        }
        double[] dArr = new double[length + 1];
        try {
            Iterator it = hashSet.iterator();
            while (it.hasNext()) {
                double[] dArr2 = (double[]) ((Future) it.next()).get();
                for (int i5 = 0; i5 < dArr2.length; i5++) {
                    int i6 = i5;
                    dArr[i6] = dArr[i6] + dArr2[i5];
                }
            }
        } catch (Exception e) {
            System.out.println("Gradient could not be calculated.");
        }
        return dArr;
    }

    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);
        this.m_data = new Instances(instances);
        this.m_data.deleteWithMissingClass();
        this.m_data.randomize(this.m_data.getRandomNumberGenerator(getSeed()));
        this.m_replaceMissing = new ReplaceMissingValues();
        this.m_replaceMissing.setInputFormat(this.m_data);
        this.m_data = Filter.useFilter(this.m_data, this.m_replaceMissing);
        this.m_removeUseless = new RemoveUseless();
        this.m_removeUseless.setInputFormat(this.m_data);
        this.m_data = Filter.useFilter(this.m_data, this.m_removeUseless);
        this.m_nominalToBinary = new NominalToBinary();
        this.m_nominalToBinary.setInputFormat(this.m_data);
        this.m_data = Filter.useFilter(this.m_data, this.m_nominalToBinary);
        this.m_standardize = new Standardize();
        this.m_standardize.setInputFormat(this.m_data);
        this.m_data = Filter.useFilter(this.m_data, this.m_standardize);
        if (this.m_kernel instanceof CachedKernel) {
            this.m_kernel.setCacheSize(-1);
        }
        this.m_kernel.buildKernel(this.m_data);
        this.m_kernelMatrix = new double[this.m_data.numInstances()][this.m_data.numInstances()];
        this.m_classValues = new double[this.m_data.numInstances()];
        this.m_Pool = Executors.newFixedThreadPool(this.m_poolSize);
        int length = this.m_classValues.length;
        int i = length / this.m_numThreads;
        HashSet hashSet = new HashSet();
        int i2 = 0;
        while (i2 < this.m_numThreads) {
            final int i3 = i2 * i;
            final int i4 = i2 < this.m_numThreads - 1 ? i3 + i : length;
            hashSet.add(this.m_Pool.submit(new Callable<Void>() { // from class: weka.classifiers.functions.KernelLogisticRegression.3
                /* JADX WARN: Can't rename method to resolve collision */
                @Override // java.util.concurrent.Callable
                public Void call() throws Exception {
                    for (int i5 = i3; i5 < i4; i5++) {
                        for (int i6 = 0; i6 < KernelLogisticRegression.this.m_data.numInstances(); i6++) {
                            if (i6 >= i5 || KernelLogisticRegression.this.m_numThreads > 1) {
                                KernelLogisticRegression.this.m_kernelMatrix[i5][i6] = KernelLogisticRegression.this.m_kernel.eval(-1, i5, KernelLogisticRegression.this.m_data.instance(i6));
                            } else {
                                KernelLogisticRegression.this.m_kernelMatrix[i5][i6] = KernelLogisticRegression.this.m_kernelMatrix[i6][i5];
                            }
                        }
                        KernelLogisticRegression.this.m_classValues[i5] = (2.0d * KernelLogisticRegression.this.m_data.instance(i5).classValue()) - 1.0d;
                    }
                    return null;
                }
            }));
            i2++;
        }
        try {
            Iterator it = hashSet.iterator();
            while (it.hasNext()) {
                ((Future) it.next()).get();
            }
        } catch (Exception e) {
            System.out.println("Kernel matrix could not be calculated.");
        }
        this.m_weights = new double[this.m_data.numInstances() + 1];
        double numInstances = 1.0d / this.m_data.numInstances();
        for (int i5 = 0; i5 < this.m_data.numInstances(); i5++) {
            this.m_weights[i5] = this.m_data.instance(i5).classValue() == 0.0d ? -numInstances : numInstances;
        }
        double[] dArr = new double[2];
        for (int i6 = 0; i6 < this.m_data.numInstances(); i6++) {
            int classValue = (int) this.m_data.instance(i6).classValue();
            dArr[classValue] = dArr[classValue] + 1.0d;
        }
        this.m_weights[this.m_data.numInstances()] = Math.log(dArr[1] + 1.0d) - Math.log(dArr[0] + 1.0d);
        double[][] dArr2 = new double[2][this.m_weights.length];
        for (int i7 = 0; i7 < this.m_weights.length; i7++) {
            dArr2[0][i7] = Double.NaN;
            dArr2[1][i7] = Double.NaN;
        }
        ConjugateGradientOptimization optEngCGD = this.m_useCGD ? new OptEngCGD() : new OptEng();
        optEngCGD.setDebug(this.m_Debug);
        this.m_weights = optEngCGD.findArgmin(this.m_weights, dArr2);
        while (this.m_weights == null) {
            this.m_weights = optEngCGD.getVarbValues();
            if (this.m_Debug) {
                System.out.println("First set of iterations finished, not enough!");
            }
            this.m_weights = optEngCGD.findArgmin(this.m_weights, dArr2);
        }
        this.m_kernelMatrix = (double[][]) null;
        this.m_Pool.shutdown();
        if (this.m_kernel instanceof CachedKernel) {
            this.m_kernel = Kernel.makeCopy(this.m_kernel);
            this.m_kernel.setCacheSize(-1);
            this.m_kernel.buildKernel(this.m_data);
        }
    }

    public double[] distributionForInstance(Instance instance) throws Exception {
        this.m_replaceMissing.input(instance);
        this.m_removeUseless.input(this.m_replaceMissing.output());
        this.m_nominalToBinary.input(this.m_removeUseless.output());
        this.m_standardize.input(this.m_nominalToBinary.output());
        Instance output = this.m_standardize.output();
        double d = this.m_weights[this.m_data.numInstances()];
        for (int i = 0; i < this.m_data.numInstances(); i++) {
            d += this.m_weights[i] * this.m_kernel.eval(-1, i, output);
        }
        double[] dArr = {1.0d - dArr[1], 1.0d / (1.0d + Math.exp(-d))};
        return dArr;
    }

    public String toString() {
        if (this.m_data == null) {
            return "Classifier not built yet.";
        }
        String str = "\nlog(p / (1 - p))\t=\n";
        int i = 0;
        while (i < this.m_data.numInstances()) {
            str = (i > 0 ? str + "\t+  " : str + "\t   ") + Utils.doubleToString(this.m_weights[i], 4) + "   \t* " + ("(standardized) X" + (i + 1)) + "\n";
            i++;
        }
        return str + "\t+  " + Utils.doubleToString(this.m_weights[this.m_data.numInstances()], 4) + "\n";
    }

    public String getRevision() {
        return RevisionUtils.extract("$Revision: ???? $");
    }

    public static void main(String[] strArr) {
        runClassifier(new KernelLogisticRegression(), strArr);
    }
}
