package edu.umass.cs.mallet.base.topics;

import edu.umass.cs.mallet.base.types.Alphabet;
import edu.umass.cs.mallet.base.types.FeatureSequence;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.util.Random;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.Arrays;
import org.codehaus.groovy.tools.shell.util.ANSI;
import pl.edu.icm.yadda.exports.zentralblatt.YElementToZentralBlattConverter;

/* loaded from: input_file:WEB-INF/lib/mallet-0.1.3.jar:edu/umass/cs/mallet/base/topics/LDA.class */
/**
 * Latent Dirichlet Allocation topic model trained by repeatedly resampling the
 * topic of every token (collapsed Gibbs sampling).
 *
 * <p>The model keeps one topic assignment per token plus three count tables
 * (per document, per word type, per topic) that are updated incrementally
 * while sampling. It can print the most probable words per topic, the topic
 * proportions per document, the raw per-token state, and can be saved with
 * Java serialization through {@link #write(File)}.
 *
 * <p>The class implements {@link Serializable} because {@link #write(File)}
 * passes {@code this} to {@code ObjectOutputStream.writeObject}; without the
 * marker interface that call always fails with NotSerializableException.
 * NOTE(review): saving also requires the InstanceList to be serializable --
 * confirm against that class.
 */
public class LDA implements Serializable {
    /** Number of topics. */
    int numTopics;
    /** Smoothing added to every document-topic count (constructor's alpha sum divided by numTopics). */
    double alpha;
    /** Smoothing added to every type-topic count. */
    double beta;
    /** alpha * numTopics; computed in estimate(). */
    double tAlpha;
    /** beta * numTypes; denominator smoothing used while sampling. */
    double vBeta;
    /** The training documents. */
    InstanceList ilist;
    /** topics[doc][position] is the topic currently assigned to that token. */
    int[][] topics;
    /** Size of the data alphabet, i.e. the number of distinct word types. */
    int numTypes;
    /** Total number of tokens over all documents. */
    int numTokens;
    /** docTopicCounts[doc][topic]: tokens of the document assigned to the topic. */
    int[][] docTopicCounts;
    /** typeTopicCounts[type][topic]: tokens of the word type assigned to the topic. */
    int[][] typeTopicCounts;
    /** tokensPerTopic[topic]: tokens assigned to the topic over the whole corpus. */
    int[] tokensPerTopic;
    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 0;
    private static final int NULL_INTEGER = -1;

    // NOTE(review): the two separators below come from unrelated libraries and look
    // like decompiler substitutions for plain string literals. Their values are not
    // visible in this file, so they are kept by reference to preserve the output.
    private static final String WORD_SEPARATOR = ANSI.Renderer.CODE_TEXT_SEPARATOR;
    private static final String LABEL_SEPARATOR = YElementToZentralBlattConverter.SUGGESTED_DICTIONARY_VALUE_SEPARATOR;

    /**
     * A word type paired with its probability in one topic. The natural ordering
     * is by descending probability, so sorting puts the most probable word first.
     */
    public class C1WordProb implements Comparable {
        int wi;
        double p;

        /**
         * @param lda the owning model; kept only for signature compatibility and not used
         * @param i   index of the word type in the data alphabet
         * @param d   probability of the word type in the topic
         */
        public C1WordProb(LDA lda, int i, double d) {
            this.wi = i;
            this.p = d;
        }

        /**
         * Orders entries by descending probability.
         *
         * @throws ClassCastException if {@code obj} is not a {@code C1WordProb}
         */
        @Override // java.lang.Comparable
        public final int compareTo(Object obj) {
            // Double.compare gives a total order even when a probability is NaN,
            // which the hand-written comparison did not.
            return Double.compare(((C1WordProb) obj).p, this.p);
        }
    }

    /**
     * Creates a model with an alpha sum of 50 and a beta of 0.01.
     *
     * @param i number of topics
     */
    public LDA(int i) {
        this(i, 50.0d, 0.01d);
    }

    /**
     * Creates a model.
     *
     * @param i  number of topics
     * @param d  alpha sum; it is divided by the number of topics to get the per-topic alpha
     * @param d2 beta, the smoothing added to every type-topic count
     */
    public LDA(int i, double d, double d2) {
        this.numTopics = i;
        this.alpha = d / this.numTopics;
        this.beta = d2;
    }

    /**
     * Trains the model: assigns a random topic to every token, then runs the
     * requested number of sampling sweeps over all documents, and finally
     * prints the elapsed wall-clock time to standard output.
     *
     * @param instanceList        documents; each instance's data must be a FeatureSequence
     * @param numIterations       number of sampling sweeps
     * @param showTopicsInterval  print the top 5 words of each topic every this many
     *                            iterations; 0 disables it
     * @param outputModelInterval serialize the model every this many iterations; 0 disables it
     * @param outputModelFilename prefix of the saved model files; ".&lt;iteration&gt;" is appended
     * @param random              source of randomness
     */
    public void estimate(InstanceList instanceList, int numIterations, int showTopicsInterval,
            int outputModelInterval, String outputModelFilename, Random random) {
        this.ilist = instanceList;
        this.numTypes = this.ilist.getDataAlphabet().size();
        int numDocs = this.ilist.size();
        // The inner arrays are allocated per document, once the document length is known.
        this.topics = new int[numDocs][];
        this.docTopicCounts = new int[numDocs][this.numTopics];
        this.typeTopicCounts = new int[this.numTypes][this.numTopics];
        this.tokensPerTopic = new int[this.numTopics];
        this.tAlpha = this.alpha * this.numTopics;
        this.vBeta = this.beta * this.numTypes;
        // Start from zero so that calling estimate() again does not inflate the count.
        this.numTokens = 0;

        long startMillis = System.currentTimeMillis();
        assignRandomTopics(random);

        for (int iteration = 0; iteration < numIterations; iteration++) {
            // Progress indicator: the iteration number every 10th sweep, a dot otherwise.
            if (iteration % 10 == 0) {
                System.out.print(iteration);
            } else {
                System.out.print(".");
            }
            System.out.flush();
            if (showTopicsInterval != 0 && iteration % showTopicsInterval == 0 && iteration > 0) {
                System.out.println();
                printTopWords(5, false);
            }
            if (outputModelInterval != 0 && iteration % outputModelInterval == 0 && iteration > 0) {
                write(new File(outputModelFilename + '.' + iteration));
            }
            sampleTopicsForAllDocs(random);
        }
        printElapsedTime(startMillis);
    }

    /** Gives every token a uniformly random topic and builds the count tables to match. */
    private void assignRandomTopics(Random random) {
        for (int doc = 0; doc < this.topics.length; doc++) {
            FeatureSequence tokens = (FeatureSequence) this.ilist.getInstance(doc).getData();
            int docLength = tokens.getLength();
            this.numTokens += docLength;
            this.topics[doc] = new int[docLength];
            for (int pos = 0; pos < docLength; pos++) {
                int topic = random.nextInt(this.numTopics);
                this.topics[doc][pos] = topic;
                this.docTopicCounts[doc][topic]++;
                this.typeTopicCounts[tokens.getIndexAtPosition(pos)][topic]++;
                this.tokensPerTopic[topic]++;
            }
        }
    }

    /** Prints the time elapsed since {@code startMillis}, omitting zero days, hours and minutes. */
    private static void printElapsedTime(long startMillis) {
        long totalSeconds = Math.round((System.currentTimeMillis() - startMillis) / 1000.0d);
        long totalMinutes = totalSeconds / 60;
        long seconds = totalSeconds % 60;
        long totalHours = totalMinutes / 60;
        long minutes = totalMinutes % 60;
        long days = totalHours / 24;
        long hours = totalHours % 24;
        System.out.print("\nTotal time: ");
        if (days != 0) {
            System.out.print(days);
            System.out.print(" days ");
        }
        if (hours != 0) {
            System.out.print(hours);
            System.out.print(" hours ");
        }
        if (minutes != 0) {
            System.out.print(minutes);
            System.out.print(" minutes ");
        }
        System.out.print(seconds);
        System.out.println(" seconds");
    }

    /** Runs one sampling sweep over every document, sharing one scratch weight buffer. */
    private void sampleTopicsForAllDocs(Random random) {
        double[] topicWeights = new double[this.numTopics];
        for (int doc = 0; doc < this.topics.length; doc++) {
            FeatureSequence tokens = (FeatureSequence) this.ilist.getInstance(doc).getData();
            sampleTopicsForOneDoc(tokens, this.topics[doc], this.docTopicCounts[doc], topicWeights, random);
        }
    }

    /**
     * Resamples the topic of every token in one document.
     *
     * @param featureSequence the document's tokens
     * @param iArr            the document's current topic assignments; updated in place
     * @param iArr2           the document's topic counts; updated in place
     * @param dArr            scratch buffer of length numTopics for the sampling weights
     * @param random          source of randomness
     */
    private void sampleTopicsForOneDoc(FeatureSequence featureSequence, int[] iArr, int[] iArr2, double[] dArr, Random random) {
        int docLength = featureSequence.getLength();
        for (int pos = 0; pos < docLength; pos++) {
            int type = featureSequence.getIndexAtPosition(pos);
            int[] typeCounts = this.typeTopicCounts[type];

            // Take the token out of all counts so that it does not influence its own resampling.
            int oldTopic = iArr[pos];
            iArr2[oldTopic]--;
            typeCounts[oldTopic]--;
            this.tokensPerTopic[oldTopic]--;

            // Unnormalized weight per topic; every slot of the buffer is overwritten here.
            double weightSum = 0.0d;
            for (int topic = 0; topic < this.numTopics; topic++) {
                double weight = ((typeCounts[topic] + this.beta) / (this.tokensPerTopic[topic] + this.vBeta))
                        * (iArr2[topic] + this.alpha);
                weightSum += weight;
                dArr[topic] = weight;
            }

            // presumably nextDiscrete draws index t with probability dArr[t] / weightSum -- confirm in Random
            int newTopic = random.nextDiscrete(dArr, weightSum);
            iArr[pos] = newTopic;
            iArr2[newTopic]++;
            typeCounts[newTopic]++;
            this.tokensPerTopic[newTopic]++;
        }
    }

    /**
     * Prints the most probable words of every topic to standard output.
     *
     * @param i maximum number of words per topic; capped at the alphabet size
     * @param z if true, print one word per line together with its probability;
     *          if false, print each topic's words on a single line
     */
    public void printTopWords(int i, boolean z) {
        Alphabet alphabet = this.ilist.getDataAlphabet();
        // Never index past the sorted array when fewer word types exist than requested.
        int shown = Math.min(i, this.numTypes);
        C1WordProb[] wordProbs = new C1WordProb[this.numTypes];
        for (int topic = 0; topic < this.numTopics; topic++) {
            int topicTotal = this.tokensPerTopic[topic];
            for (int type = 0; type < this.numTypes; type++) {
                // Floating-point division: int / int would truncate every probability to 0 or 1.
                // A topic without tokens gets probability 0 for every word.
                double prob = topicTotal == 0 ? 0.0d : (double) this.typeTopicCounts[type][topic] / topicTotal;
                wordProbs[type] = new C1WordProb(this, type, prob);
            }
            Arrays.sort(wordProbs);
            if (z) {
                System.out.println("\nTopic " + topic);
                for (int rank = 0; rank < shown; rank++) {
                    System.out.println(alphabet.lookupObject(wordProbs[rank].wi).toString()
                            + WORD_SEPARATOR + wordProbs[rank].p);
                }
            } else {
                System.out.print("Topic " + topic + LABEL_SEPARATOR);
                for (int rank = 0; rank < shown; rank++) {
                    System.out.print(alphabet.lookupObject(wordProbs[rank].wi).toString() + WORD_SEPARATOR);
                }
                System.out.println();
            }
        }
    }

    /**
     * Writes the per-document topic proportions to a file.
     *
     * @param file destination file
     * @throws IOException if the file cannot be opened for writing
     */
    public void printDocumentTopics(File file) throws IOException {
        PrintWriter writer = new PrintWriter(new FileWriter(file));
        try {
            printDocumentTopics(writer);
        } finally {
            // Closing flushes the buffered output; without it the file can stay empty.
            writer.close();
        }
    }

    /**
     * Writes one line per document: the document index, the proportion of its
     * tokens assigned to each topic, and the instance's source, separated by
     * single spaces. Note that the header line names the source before the
     * proportions, while the rows print it last. The writer is flushed but
     * left open.
     *
     * @param printWriter destination
     */
    public void printDocumentTopics(PrintWriter printWriter) {
        printWriter.println("#doc source topic proportions");
        for (int doc = 0; doc < this.topics.length; doc++) {
            printWriter.print(doc);
            printWriter.print(' ');
            int docLength = this.topics[doc].length;
            for (int topic = 0; topic < this.numTopics; topic++) {
                // Floating-point division; an empty document reports 0 for every topic.
                double proportion = docLength == 0 ? 0.0d : (double) this.docTopicCounts[doc][topic] / docLength;
                printWriter.print(proportion);
                // Separate the values so they can be parsed back.
                printWriter.print(' ');
            }
            printWriter.println(this.ilist.getInstance(doc).getSource().toString());
        }
        printWriter.flush();
    }

    /**
     * Writes the per-token state to a file.
     *
     * @param file destination file
     * @throws IOException if the file cannot be opened for writing
     */
    public void printState(File file) throws IOException {
        PrintWriter writer = new PrintWriter(new FileWriter(file));
        try {
            printState(writer);
        } finally {
            // printState(PrintWriter) closes the writer on success; this also covers failures.
            writer.close();
        }
    }

    /**
     * Writes one line per token: document index, position, type index, the type
     * itself and the assigned topic. The writer is closed when done.
     *
     * @param printWriter destination; closed by this method
     */
    public void printState(PrintWriter printWriter) {
        Alphabet alphabet = this.ilist.getDataAlphabet();
        printWriter.println("#doc pos typeindex type topic");
        for (int doc = 0; doc < this.topics.length; doc++) {
            FeatureSequence tokens = (FeatureSequence) this.ilist.getInstance(doc).getData();
            for (int pos = 0; pos < this.topics[doc].length; pos++) {
                int type = tokens.getIndexAtPosition(pos);
                printWriter.print(doc);
                printWriter.print(' ');
                printWriter.print(pos);
                printWriter.print(' ');
                printWriter.print(type);
                printWriter.print(' ');
                printWriter.print(alphabet.lookupObject(type));
                printWriter.print(' ');
                printWriter.print(this.topics[doc][pos]);
                printWriter.println();
            }
        }
        printWriter.close();
    }

    /**
     * Serializes this model to a file. I/O failures are reported on standard
     * error and otherwise ignored, so that a failed checkpoint does not abort
     * training.
     *
     * @param file destination file
     */
    public void write(File file) {
        try {
            FileOutputStream fileStream = new FileOutputStream(file);
            try {
                ObjectOutputStream objectStream = new ObjectOutputStream(fileStream);
                objectStream.writeObject(this);
                objectStream.flush();
            } finally {
                // Release the file handle even when serialization fails part-way.
                fileStream.close();
            }
        } catch (IOException e) {
            System.err.println("Exception writing file " + file + LABEL_SEPARATOR + e);
        }
    }

    /**
     * Custom serialization: a version number, the instance list, the scalar
     * parameters, then the topic assignments and the three count tables as raw
     * ints. Array sizes are not written; readObject derives them from the
     * instance list.
     */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.writeInt(CURRENT_SERIAL_VERSION);
        objectOutputStream.writeObject(this.ilist);
        objectOutputStream.writeInt(this.numTopics);
        objectOutputStream.writeDouble(this.alpha);
        objectOutputStream.writeDouble(this.beta);
        objectOutputStream.writeDouble(this.tAlpha);
        objectOutputStream.writeDouble(this.vBeta);
        for (int doc = 0; doc < this.topics.length; doc++) {
            for (int pos = 0; pos < this.topics[doc].length; pos++) {
                objectOutputStream.writeInt(this.topics[doc][pos]);
            }
        }
        for (int doc = 0; doc < this.topics.length; doc++) {
            for (int topic = 0; topic < this.numTopics; topic++) {
                objectOutputStream.writeInt(this.docTopicCounts[doc][topic]);
            }
        }
        for (int type = 0; type < this.numTypes; type++) {
            for (int topic = 0; topic < this.numTopics; topic++) {
                objectOutputStream.writeInt(this.typeTopicCounts[type][topic]);
            }
        }
        for (int topic = 0; topic < this.numTopics; topic++) {
            objectOutputStream.writeInt(this.tokensPerTopic[topic]);
        }
    }

    /**
     * Custom deserialization mirroring writeObject. Also restores numTypes and
     * numTokens, which are not part of the stream, so that the printing methods
     * work on a loaded model.
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        // Serialization version; read to advance the stream, the value is not inspected.
        objectInputStream.readInt();
        this.ilist = (InstanceList) objectInputStream.readObject();
        this.numTopics = objectInputStream.readInt();
        this.alpha = objectInputStream.readDouble();
        this.beta = objectInputStream.readDouble();
        this.tAlpha = objectInputStream.readDouble();
        this.vBeta = objectInputStream.readDouble();

        int numDocs = this.ilist.size();
        this.topics = new int[numDocs][];
        this.numTokens = 0;
        for (int doc = 0; doc < numDocs; doc++) {
            int docLength = ((FeatureSequence) this.ilist.getInstance(doc).getData()).getLength();
            this.numTokens += docLength;
            this.topics[doc] = new int[docLength];
            for (int pos = 0; pos < docLength; pos++) {
                this.topics[doc][pos] = objectInputStream.readInt();
            }
        }

        this.docTopicCounts = new int[numDocs][this.numTopics];
        for (int doc = 0; doc < numDocs; doc++) {
            for (int topic = 0; topic < this.numTopics; topic++) {
                this.docTopicCounts[doc][topic] = objectInputStream.readInt();
            }
        }

        this.numTypes = this.ilist.getDataAlphabet().size();
        this.typeTopicCounts = new int[this.numTypes][this.numTopics];
        for (int type = 0; type < this.numTypes; type++) {
            for (int topic = 0; topic < this.numTopics; topic++) {
                this.typeTopicCounts[type][topic] = objectInputStream.readInt();
            }
        }

        this.tokensPerTopic = new int[this.numTopics];
        for (int topic = 0; topic < this.numTopics; topic++) {
            this.tokensPerTopic[topic] = objectInputStream.readInt();
        }
    }

    /**
     * Command-line entry point: trains a 10-topic model, showing the top words
     * every 50 iterations, then prints the top words and writes the document
     * topic proportions to {@code <instance-list-file>.lda}.
     *
     * @param strArr {@code [0]} instance list file (required), {@code [1]} number
     *               of iterations (default 1000), {@code [2]} number of top words
     *               to print per topic (default 20)
     * @throws IOException if the document-topics file cannot be opened
     */
    public static void main(String[] strArr) throws IOException {
        if (strArr.length < 1) {
            System.err.println("Usage: LDA <instance-list-file> [numIterations] [numTopWords]");
            return;
        }
        InstanceList instances = InstanceList.load(new File(strArr[0]));
        int numIterations = strArr.length > 1 ? Integer.parseInt(strArr[1]) : 1000;
        int numTopWords = strArr.length > 2 ? Integer.parseInt(strArr[2]) : 20;
        System.out.println("Data loaded.");
        LDA lda = new LDA(10);
        lda.estimate(instances, numIterations, 50, 0, null, new Random());
        lda.printTopWords(numTopWords, true);
        lda.printDocumentTopics(new File(strArr[0] + ".lda"));
    }
}
