package edu.stanford.nlp.ie.crf;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.sequences.SeqClassifierFlags;

/* loaded from: input_file:WEB-INF/lib/stanford-corenlp-3.2.0.jar:edu/stanford/nlp/ie/crf/NonLinearSecondOrderCliquePotentialFunction.class */
public class NonLinearSecondOrderCliquePotentialFunction implements CliquePotentialFunction {
    // Weight matrices used when the clique size passed in is greater than 1 ("4Edge").
    double[][] inputLayerWeights4Edge;
    double[][] outputLayerWeights4Edge;
    // Weight matrices used when the clique size passed in is 1 (or less).
    double[][] inputLayerWeights;
    double[][] outputLayerWeights;
    // Reusable scratch buffers; hiddenLayerOutput writes into and returns these arrays.
    double[] layerOneCache;
    double[] hiddenLayerCache;
    double[] layerOneCache4Edge;
    double[] hiddenLayerCache4Edge;
    // Configuration consulted by computeCliquePotential.
    SeqClassifierFlags flags;

    /**
     * Creates a potential function over the given weight matrices.
     *
     * <p>The arrays are stored by reference, not copied.
     *
     * @param inputLayerWeights4Edge  input-layer weights used for cliques of size greater than 1
     * @param outputLayerWeights4Edge output-layer weights used for cliques of size greater than 1
     * @param inputLayerWeights       input-layer weights used for cliques of size 1
     * @param outputLayerWeights      output-layer weights used for cliques of size 1; only read
     *                                when {@code flags.useOutputLayer} is set
     * @param flags                   configuration controlling the hidden/output layers
     */
    public NonLinearSecondOrderCliquePotentialFunction(double[][] inputLayerWeights4Edge,
                                                       double[][] outputLayerWeights4Edge,
                                                       double[][] inputLayerWeights,
                                                       double[][] outputLayerWeights,
                                                       SeqClassifierFlags flags) {
        this.inputLayerWeights4Edge = inputLayerWeights4Edge;
        this.outputLayerWeights4Edge = outputLayerWeights4Edge;
        this.inputLayerWeights = inputLayerWeights;
        this.outputLayerWeights = outputLayerWeights;
        this.flags = flags;
    }

    /**
     * Computes the hidden-layer activations for one clique.
     *
     * <p>For every row {@code r} of {@code inputWeights}, the pre-activation is
     * {@code sum_k inputWeights[r][features[k]] * featureValues[k]}. When
     * {@code featureValues} is {@code null}, every feature value is treated as 1.
     * If {@code classifierFlags.useHiddenLayer} is false, the pre-activations are
     * returned directly. Otherwise each one is passed through a sigmoid (when
     * {@code useSigmoid} is set) or {@code tanh}.
     *
     * <p>The returned array is one of this instance's cache fields: the edge caches
     * when {@code cliqueSize > 1}, otherwise the node caches. It is overwritten by
     * the next call for the same clique kind. Callers must therefore not retain it,
     * and an instance must not be shared between threads.
     *
     * @param inputWeights    input-layer weight matrix, one row per hidden unit
     * @param features        feature indices present in the clique
     * @param classifierFlags flags selecting the hidden-layer nonlinearity
     * @param featureValues   per-feature values parallel to {@code features}, or {@code null}
     * @param cliqueSize      values greater than 1 select the edge caches
     * @return the (possibly transformed) hidden-layer activations
     */
    public double[] hiddenLayerOutput(double[][] inputWeights, int[] features,
                                      SeqClassifierFlags classifierFlags, double[] featureValues,
                                      int cliqueSize) {
        final int hiddenSize = inputWeights.length;
        final boolean isEdge = cliqueSize > 1;

        double[] layerOne;
        if (isEdge) {
            layerOneCache4Edge = reuseOrAllocate(layerOneCache4Edge, hiddenSize);
            layerOne = layerOneCache4Edge;
        } else {
            layerOneCache = reuseOrAllocate(layerOneCache, hiddenSize);
            layerOne = layerOneCache;
        }
        for (int unit = 0; unit < hiddenSize; unit++) {
            layerOne[unit] = weightedFeatureSum(inputWeights[unit], features, featureValues);
        }

        if (!classifierFlags.useHiddenLayer) {
            return layerOne;
        }

        double[] hidden;
        if (isEdge) {
            hiddenLayerCache4Edge = reuseOrAllocate(hiddenLayerCache4Edge, hiddenSize);
            hidden = hiddenLayerCache4Edge;
        } else {
            hiddenLayerCache = reuseOrAllocate(hiddenLayerCache, hiddenSize);
            hidden = hiddenLayerCache;
        }
        final boolean useSigmoid = classifierFlags.useSigmoid;
        for (int unit = 0; unit < hiddenSize; unit++) {
            hidden[unit] = useSigmoid ? sigmoid(layerOne[unit]) : Math.tanh(layerOne[unit]);
        }
        return hidden;
    }

    /**
     * Returns {@code buffer} when it already has exactly {@code length} slots;
     * otherwise returns a freshly allocated array of that length.
     */
    private static double[] reuseOrAllocate(double[] buffer, int length) {
        return (buffer != null && buffer.length == length) ? buffer : new double[length];
    }

    /**
     * Sums {@code weights[features[k]]}, each term scaled by {@code featureValues[k]}
     * when {@code featureValues} is non-null, accumulating in feature order.
     */
    private static double weightedFeatureSum(double[] weights, int[] features, double[] featureValues) {
        double sum = 0.0d;
        for (int k = 0; k < features.length; k++) {
            double term = weights[features[k]];
            if (featureValues != null) {
                term *= featureValues[k];
            }
            sum += term;
        }
        return sum;
    }

    /** Logistic function {@code 1 / (1 + e^-x)}. */
    private static double sigmoid(double d) {
        return 1.0d / (1.0d + Math.exp(-d));
    }

    /**
     * Computes the potential of {@code labelIndex} for one clique.
     *
     * <p>The edge weight matrices are used when {@code cliqueSize > 1}; otherwise the
     * node matrices are used. Without an output layer ({@code flags.useOutputLayer}
     * false), the result is simply the hidden activation at {@code labelIndex}.
     * Otherwise the hidden activations are combined linearly with an output-weight
     * row:
     * <ul>
     *   <li>The row is {@code outputWeights[0]} when tied, else
     *       {@code outputWeights[labelIndex]}.</li>
     *   <li>The row is optionally softmax-normalised.</li>
     *   <li>In the sparse/tied case, only hidden units {@code i} with
     *       {@code i % outputLayerSize == labelIndex} contribute.</li>
     * </ul>
     *
     * @param cliqueSize     values greater than 1 select the edge weights
     * @param labelIndex     index of the label whose potential is requested
     * @param cliqueFeatures feature indices present in the clique
     * @param featureVal     per-feature values parallel to {@code cliqueFeatures}, or {@code null}
     * @return the clique potential
     */
    @Override // edu.stanford.nlp.ie.crf.CliquePotentialFunction
    public double computeCliquePotential(int cliqueSize, int labelIndex, int[] cliqueFeatures, double[] featureVal) {
        final double[][] inputWeights;
        final double[][] outputWeights;
        if (cliqueSize > 1) {
            inputWeights = this.inputLayerWeights4Edge;
            outputWeights = this.outputLayerWeights4Edge;
        } else {
            inputWeights = this.inputLayerWeights;
            outputWeights = this.outputLayerWeights;
        }
        double[] hidden = hiddenLayerOutput(inputWeights, cliqueFeatures, this.flags, featureVal, cliqueSize);

        if (!this.flags.useOutputLayer) {
            // Output-layer weights are not consulted at all on this path.
            return hidden[labelIndex];
        }

        double[] outputWs = this.flags.tieOutputLayer ? outputWeights[0] : outputWeights[labelIndex];
        if (this.flags.softmaxOutputLayer) {
            outputWs = ArrayMath.softmax(outputWs);
        }

        double potential = 0.0d;
        if (this.flags.sparseOutputLayer || this.flags.tieOutputLayer) {
            // Computed only here, where it is needed. The old code computed it eagerly,
            // which threw (NPE / divide-by-zero) even when the output layer was unused.
            int outputLayerSize = inputWeights.length / outputWeights[0].length;
            for (int i = 0; i < inputWeights.length; i++) {
                if (i % outputLayerSize == labelIndex) {
                    potential += outputWs[i / outputLayerSize] * hidden[i];
                }
            }
        } else {
            for (int i = 0; i < inputWeights.length; i++) {
                potential += outputWs[i] * hidden[i];
            }
        }
        return potential;
    }
}
