package edu.umn.biomedicus.acronym;

import edu.umn.biomedicus.common.collect.IndexMap;
import edu.umn.biomedicus.serialization.YamlSerialization;
import java.io.IOException;
import java.io.Writer;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.yaml.snakeyaml.Yaml;

/* loaded from: input_file:edu/umn/biomedicus/acronym/OrthographicAcronymModelTrainer.class */
/**
 * Trains two character-trigram log-probability tables -- one from a file of abbreviations and one
 * from a file of long forms -- and writes them, together with the lower-cased long forms and the
 * case-sensitivity flag, to a YAML file.
 *
 * <p>Each word is padded with the start symbol {@code '^'} and the end symbol {@code '$'}; digits
 * are collapsed to {@code '0'} and characters outside the configured character set to {@code '?'}.
 * Probabilities are estimated by subtracting a fixed discount from every non-zero count and
 * redistributing the freed mass through the next lower-order estimate.
 */
public class OrthographicAcronymModelTrainer {
    /** Fixed amount subtracted from every non-zero count before backing off. */
    private static final double DISCOUNTING = 0.9d;

    /** Symbol that pads the beginning of every word. */
    private static final char START_SYMBOL = '^';

    /** Symbol that pads the end of every word. */
    private static final char END_SYMBOL = '$';

    private final boolean caseSensitive;
    private final transient IndexMap<Character> symbols;
    private final transient int symbolsCount;
    private final transient Set<Character> chars;
    private final double[][][] longformProbs;
    private final double[][][] abbrevProbs;
    private Set<String> longformsLower;
    private Path abbrevPath;
    private Path longformsPath;

    /**
     * Creates a trainer using either the case-sensitive or the case-insensitive symbol tables
     * declared on {@code OrthographicAcronymModel}.
     *
     * @param caseSensitive whether character case is preserved while counting trigrams
     */
    public OrthographicAcronymModelTrainer(boolean caseSensitive) {
        this.caseSensitive = caseSensitive;
        this.symbols = caseSensitive
                ? OrthographicAcronymModel.CASE_SENS_SYMBOLS
                : OrthographicAcronymModel.CASE_INSENS_SYMBOLS;
        this.symbolsCount = this.symbols.size();
        this.chars = caseSensitive
                ? OrthographicAcronymModel.CASE_SENS_CHARS
                : OrthographicAcronymModel.CASE_INSENS_CHARS;
        this.longformProbs = new double[this.symbolsCount][this.symbolsCount][this.symbolsCount];
        this.abbrevProbs = new double[this.symbolsCount][this.symbolsCount][this.symbolsCount];
        this.longformsLower = new HashSet<>();
    }

    /**
     * Command-line entry point: trains a case-sensitive model and writes it out.
     *
     * @param strArr {@code [0]} abbreviations file, {@code [1]} long-forms file, {@code [2]} output
     *     YAML file
     * @throws IllegalArgumentException if fewer than three arguments are supplied
     */
    public static void main(String[] strArr) {
        if (strArr == null || strArr.length < 3) {
            throw new IllegalArgumentException(
                    "Expected 3 arguments: <abbreviations-file> <longforms-file> <output-file>");
        }
        OrthographicAcronymModelTrainer trainer = new OrthographicAcronymModelTrainer(true);
        trainer.setAbbrevPath(Paths.get(strArr[0]));
        trainer.setLongformsPath(Paths.get(strArr[1]));
        try {
            trainer.trainTrigramModel();
            trainer.write(Paths.get(strArr[2]));
        } catch (IOException e) {
            // Kept as in the original: report the I/O failure and return normally.
            e.printStackTrace();
        }
    }

    /**
     * Sets the file that lists the abbreviations, one per line.
     *
     * @param path location of the abbreviations file
     */
    public void setAbbrevPath(Path path) {
        this.abbrevPath = path;
    }

    /**
     * Sets the file that lists the long forms, one per line.
     *
     * @param path location of the long-forms file
     */
    public void setLongformsPath(Path path) {
        this.longformsPath = path;
    }

    /**
     * Reads both word lists and fills {@code longformProbs}, {@code abbrevProbs} and
     * {@code longformsLower}.
     *
     * @throws IOException if either file cannot be opened
     * @throws IllegalStateException if either path has not been set
     * @throws IllegalArgumentException if either file contains no lines
     */
    private void trainTrigramModel() throws IOException {
        if (this.abbrevPath == null || this.longformsPath == null) {
            throw new IllegalStateException(
                    "Both the abbreviation path and the long-form path must be set before training");
        }
        Set<String> abbrevs = readDistinctLines(this.abbrevPath);
        Set<String> longforms = readDistinctLines(this.longformsPath);
        this.longformsLower = longforms.stream()
                .map(String::toLowerCase)
                .collect(Collectors.toSet());
        wordsToLogProbs(longforms, this.longformProbs);
        wordsToLogProbs(abbrevs, this.abbrevProbs);
    }

    /** Reads the distinct lines of a file, closing the underlying file handle afterwards. */
    private static Set<String> readDistinctLines(Path path) throws IOException {
        try (Stream<String> lines = Files.lines(path)) {
            return lines.collect(Collectors.toSet());
        }
    }

    /**
     * Counts the trigrams of every word and stores the log-probability of each trigram in
     * {@code dArr}.
     *
     * @param set the training words
     * @param dArr destination table, indexed by the three symbol indices of a trigram
     * @throws IllegalArgumentException if {@code set} is empty, since no estimate can be made
     */
    private void wordsToLogProbs(Set<String> set, double[][][] dArr) {
        int[][][] counts = new int[this.symbolsCount][this.symbolsCount][this.symbolsCount];
        for (String word : set) {
            addTrigramsFromWord(word, counts);
        }
        // The grand total does not change while the table is filled, so compute it once
        // instead of rescanning the whole tensor for every cell.
        int total = tensorSum(counts);
        if (total == 0) {
            throw new IllegalArgumentException(
                    "Cannot estimate trigram probabilities from an empty word set");
        }
        for (int first = 0; first < this.symbolsCount; first++) {
            for (int second = 0; second < this.symbolsCount; second++) {
                for (int third = 0; third < this.symbolsCount; third++) {
                    // NOTE(review): the float cast narrows the precision of the stored value;
                    // kept as in the original so the written model does not change precision.
                    dArr[first][second][third] =
                            (float) getTrigramLogProbability(first, second, third, counts, total);
                }
            }
        }
    }

    /**
     * Returns the natural log of the discounted probability of {@code third} following the context
     * {@code (first, second)}. Falls back to the bigram estimate when the context was never seen,
     * and to {@code 1 / total} when the resulting estimate is not positive.
     */
    private double getTrigramLogProbability(int first, int second, int third, int[][][] counts,
            int total) {
        double probability;
        int contextCount = tensorSum(counts[first][second]);
        if (contextCount == 0) {
            probability = getBigramProbability(second, third, counts, total);
        } else {
            double discounted = 0.0d;
            int trigramCount = counts[first][second][third];
            if (trigramCount > 0) {
                discounted = (trigramCount - DISCOUNTING) / contextCount;
            }
            // Mass freed by discounting: one discount per distinct continuation of the context.
            double backoffWeight =
                    (DISCOUNTING * tensorSum(counts[first][second], true)) / contextCount;
            probability = discounted
                    + backoffWeight * getBigramProbability(second, third, counts, total);
        }
        if (probability <= 0.0d) {
            probability = 1.0d / total;
        }
        return Math.log(probability);
    }

    /**
     * Returns the discounted probability of {@code second} following {@code first}, backing off to
     * the unigram estimate of {@code second}.
     */
    private double getBigramProbability(int first, int second, int[][][] counts, int total) {
        int contextCount = tensorSum(counts[first]);
        if (contextCount == 0) {
            return getUnigramProbability(second, counts, total);
        }
        double discounted = 0.0d;
        int bigramCount = tensorSum(counts[first][second]);
        if (bigramCount > 0) {
            discounted = (bigramCount - DISCOUNTING) / contextCount;
        }
        double backoffWeight = (DISCOUNTING * tensorSum(counts[first], true)) / contextCount;
        return discounted + backoffWeight * getUnigramProbability(second, counts, total);
    }

    /**
     * Returns the fraction of all counted trigrams whose first symbol is {@code symbol}.
     *
     * <p>The division is carried out in floating point; an integer division here would truncate
     * almost every result to zero.
     */
    private double getUnigramProbability(int symbol, int[][][] counts, int total) {
        if (total == 0) {
            return 0.0d;
        }
        return (double) tensorSum(counts[symbol]) / total;
    }

    /**
     * Sums a vector of counts.
     *
     * @param iArr the counts
     * @param z when {@code true}, counts the number of positive entries instead of adding them
     */
    private int tensorSum(int[] iArr, boolean z) {
        int sum = 0;
        for (int value : iArr) {
            if (!z) {
                sum += value;
            } else if (value > 0) {
                sum++;
            }
        }
        return sum;
    }

    /** Applies {@link #tensorSum(int[], boolean)} to every row of a matrix and adds the results. */
    private int tensorSum(int[][] iArr, boolean z) {
        int sum = 0;
        for (int[] row : iArr) {
            sum += tensorSum(row, z);
        }
        return sum;
    }

    /** Applies {@link #tensorSum(int[][], boolean)} to every slice and adds the results. */
    private int tensorSum(int[][][] iArr, boolean z) {
        int sum = 0;
        for (int[][] slice : iArr) {
            sum += tensorSum(slice, z);
        }
        return sum;
    }

    /** Adds up all entries of a vector. */
    private int tensorSum(int[] iArr) {
        return tensorSum(iArr, false);
    }

    /** Adds up all entries of a matrix. */
    private int tensorSum(int[][] iArr) {
        return tensorSum(iArr, false);
    }

    /** Adds up all entries of a three-dimensional tensor. */
    private int tensorSum(int[][][] iArr) {
        return tensorSum(iArr, false);
    }

    /**
     * Increments the count of every trigram of {@code str}, padding the word with two start
     * symbols in front and end symbols behind.
     */
    private void addTrigramsFromWord(String str, int[][][] iArr) {
        char minus2 = START_SYMBOL;
        char minus1 = START_SYMBOL;
        char current = START_SYMBOL;
        for (int i = 0; i < str.length(); i++) {
            current = fixChar(str.charAt(i));
            iArr[indexOf(minus2)][indexOf(minus1)][indexOf(current)]++;
            minus2 = minus1;
            minus1 = current;
        }
        int end = indexOf(END_SYMBOL);
        // NOTE(review): after the loop minus1 and current hold the same (last) character, so this
        // counts (last, last, '$') rather than (second-to-last, last, '$'). Left unchanged because
        // the code that scores words with this model must use the same trigrams -- confirm
        // against OrthographicAcronymModel before changing either side.
        iArr[indexOf(minus1)][indexOf(current)][end]++;
        iArr[indexOf(current)][end][end]++;
    }

    /** Looks up the table index of a symbol. */
    private int indexOf(char symbol) {
        return this.symbols.indexOf(Character.valueOf(symbol)).intValue();
    }

    /**
     * Normalises a character to the model's alphabet: lower-cases it when the model is
     * case-insensitive, maps every digit to {@code '0'} and every character outside the configured
     * character set to {@code '?'}.
     */
    private char fixChar(char c) {
        char normalised = this.caseSensitive ? c : Character.toLowerCase(c);
        if (Character.isDigit(normalised)) {
            return '0';
        }
        if (!this.chars.contains(Character.valueOf(normalised))) {
            return '?';
        }
        return normalised;
    }

    /**
     * Writes the trained model to {@code path} as YAML with the keys {@code abbrevProbs},
     * {@code longformProbs}, {@code longformsLower} and {@code caseSensitive}.
     *
     * @param path destination file
     * @throws IOException if the file cannot be opened, written or closed
     */
    private void write(Path path) throws IOException {
        Yaml yaml = YamlSerialization.createYaml();
        Map<String, Object> model = new TreeMap<>();
        model.put("abbrevProbs", collapseProbs(this.abbrevProbs));
        model.put("longformProbs", collapseProbs(this.longformProbs));
        model.put("longformsLower", this.longformsLower.stream().collect(Collectors.toList()));
        model.put("caseSensitive", Boolean.valueOf(this.caseSensitive));
        // Closing the writer flushes buffered output and releases the file handle.
        try (Writer writer = Files.newBufferedWriter(path)) {
            yaml.dump(model, writer);
        }
    }

    /**
     * Flattens a probability table into a sorted map keyed by the three-character trigram string.
     * Entries equal to {@code 0.0} are omitted.
     */
    private Map<String, Double> collapseProbs(double[][][] dArr) {
        Map<String, Double> collapsed = new TreeMap<>();
        for (int i = 0; i < dArr.length; i++) {
            for (int i2 = 0; i2 < dArr[i].length; i2++) {
                for (int i3 = 0; i3 < dArr[i][i2].length; i3++) {
                    double value = dArr[i][i2][i3];
                    if (value != 0.0d) {
                        String trigram = "" + this.symbols.forIndex(i) + this.symbols.forIndex(i2)
                                + this.symbols.forIndex(i3);
                        collapsed.put(trigram, Double.valueOf(value));
                    }
                }
            }
        }
        return collapsed;
    }
}
