package edu.umn.biomedicus.acronym;

import edu.umn.biomedicus.sentences.Sentence;
import edu.umn.biomedicus.tokenization.Token;
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.regex.Pattern;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:edu/umn/biomedicus/acronym/WordVectorSpace.class */
/**
 * Builds sparse, distance-weighted bag-of-words vectors for the tokens surrounding a span
 * (in this package, presumably an acronym occurrence), and maintains the vocabulary and the
 * inverse document frequencies used to weight those vectors.
 *
 * <p>While {@link #getBuildingDictionary()} is true, previously unseen context words are added
 * to the dictionary. While {@link #getCountingDocuments()} is true, every call to
 * {@code vectorize} counts as one document; {@link #buildIdf()} turns the accumulated counts
 * into the IDF vector and stops counting.
 *
 * <p>{@code vectorize} may mutate the dictionary and the document counts, and this class performs
 * no synchronization.
 */
public class WordVectorSpace {

    /** Steepness of the logistic curve used to weight context words by distance. */
    private static final double SLOPE = 0.3d;

    /** Exponent applied to the log ratio in {@link #buildIdf()}. */
    private static final double IDF_POWER = 1.0d;

    /**
     * Distance weight below which a context word is no longer examined; together with
     * {@code maxDist} it determines {@code windowSize}.
     */
    private static final double THRESH_WEIGHT = 0.25d;

    private static final Logger LOGGER = LoggerFactory.getLogger(WordVectorSpace.class);

    /** Context words must fully match this pattern to be vectorized; other tokens are skipped. */
    private static final Pattern ALPHANUMERIC = Pattern.compile("[a-zA-Z0-9.&_]*");

    /** Distance (in tokens) at which the logistic distance weight equals 0.5. */
    private transient double maxDist;

    /** Number of tokens examined on each side of the span; derived from {@code maxDist}. */
    private transient double windowSize;

    private SparseVector idf;
    private Map<String, Integer> dictionary = new HashMap<>();
    private Map<Integer, Integer> documentsPerTerm = new HashMap<>();
    private long totalDocs = 0;
    private boolean countingDocuments = true;
    private boolean buildingDictionary = true;

    /** Creates an empty space whose {@code maxDist} is 9 tokens. */
    public WordVectorSpace() {
        setMaxDist(9.0d);
    }

    public double getMaxDist() {
        return this.maxDist;
    }

    /**
     * Sets the distance at which the weight equals 0.5 and recomputes the context window.
     *
     * @param maxDist distance, in tokens, at which a context word receives weight 0.5
     */
    public void setMaxDist(double maxDist) {
        this.maxDist = maxDist;
        // Solve 1 / (1 + exp(SLOPE * (x - maxDist))) = THRESH_WEIGHT for x: past this distance a
        // word contributes less than THRESH_WEIGHT. With THRESH_WEIGHT = 0.25 this is exactly
        // log(3) / SLOPE + maxDist.
        this.windowSize = Math.log(1.0d / THRESH_WEIGHT - 1.0d) / SLOPE + maxDist;
    }

    public SparseVector getIdf() {
        return this.idf;
    }

    public void setIdf(SparseVector sparseVector) {
        this.idf = sparseVector;
    }

    public Map<String, Integer> getDictionary() {
        return this.dictionary;
    }

    /** Replaces the dictionary and stops adding new words to it. */
    public void setDictionary(Map<String, Integer> map) {
        this.dictionary = map;
        this.buildingDictionary = false;
    }

    public Map<Integer, Integer> getDocumentsPerTerm() {
        return this.documentsPerTerm;
    }

    public void setDocumentsPerTerm(Map<Integer, Integer> map) {
        this.documentsPerTerm = map;
    }

    public long getTotalDocs() {
        return this.totalDocs;
    }

    public void setTotalDocs(long j) {
        this.totalDocs = j;
    }

    public boolean getBuildingDictionary() {
        return this.buildingDictionary;
    }

    public void setBuildingDictionary(boolean z) {
        this.buildingDictionary = z;
    }

    public boolean getCountingDocuments() {
        return this.countingDocuments;
    }

    public void setCountingDocuments(boolean z) {
        this.countingDocuments = z;
    }

    /**
     * Computes {@code idf[t] = ln((1 + totalDocs) / documentsPerTerm[t]) ^ IDF_POWER} for every
     * counted term, stores it as the IDF vector, and stops counting documents.
     */
    public void buildIdf() {
        Map<Integer, Double> idfValues = new HashMap<>();
        for (Map.Entry<Integer, Integer> entry : documentsPerTerm.entrySet()) {
            double ratio = (1.0d + totalDocs) / entry.getValue();
            idfValues.put(entry.getKey(), Math.pow(Math.log(ratio), IDF_POWER));
        }
        this.idf = new SparseVector(idfValues);
        this.countingDocuments = false;
    }

    /**
     * Vectorizes the context around the half-open token span {@code [startIndex, endIndex)}.
     *
     * <p>Tokens within {@code windowSize} of the span (excluding the span itself) whose standard
     * context form matches {@link #ALPHANUMERIC} contribute a logistic distance weight to their
     * dictionary index. Words missing from the dictionary are added while building it and are
     * ignored otherwise. While counting documents, each distinct term seen in this context
     * increments its document count once, and {@code totalDocs} increments once.
     *
     * @param context the token list containing the span
     * @param startIndex first index of the span
     * @param endIndex index one past the last token of the span
     * @return map from dictionary index to summed distance weight
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public SparseVector vectorize(List<? extends Token> context, int startIndex, int endIndex) {
        Map<Integer, Double> vector = new HashMap<>();
        Set<Integer> termsInContext = new HashSet<>();

        // Clamp so that a very small maxDist cannot produce a negative window reaching into the span.
        int window = Math.max((int) windowSize, 0);
        int contextStart = Math.max(startIndex - window, 0);
        int contextEnd = Math.min(endIndex + window, context.size());

        // Tokens before the span: the token adjacent to startIndex has distance 1.
        for (int pos = contextStart; pos < Math.min(startIndex, contextEnd); pos++) {
            addContextWord(context.get(pos), startIndex - pos, vector, termsInContext);
        }
        // Tokens after the span: distance is measured from the exclusive endIndex, so the adjacent
        // token has distance 0.
        // NOTE(review): asymmetric with the left side (adjacent = 1); preserved to keep trained
        // models consistent.
        for (int pos = Math.max(endIndex, contextStart); pos < contextEnd; pos++) {
            addContextWord(context.get(pos), pos - endIndex, vector, termsInContext);
        }

        if (countingDocuments) {
            // Document frequency: each term counts at most once per vectorized context.
            for (Integer term : termsInContext) {
                documentsPerTerm.merge(term, 1, Integer::sum);
            }
            totalDocs++;
        }
        return new SparseVector(vector);
    }

    /** Vectorizes the context around the single token at {@code index}. */
    public SparseVector vectorize(List<? extends Token> list, int i) {
        return vectorize(list, i, i + 1);
    }

    /**
     * Adds one context token to {@code vector} and records its term index in
     * {@code termsInContext}; skips non-matching words and unknown words when the dictionary is
     * frozen.
     */
    private void addContextWord(Token token, int distance, Map<Integer, Double> vector,
            Set<Integer> termsInContext) {
        String word = Acronyms.standardContextForm(token);
        if (!ALPHANUMERIC.matcher(word).matches()) {
            return;
        }
        Integer index = dictionary.get(word);
        if (index == null) {
            if (!buildingDictionary) {
                return;
            }
            // NOTE(review): after removeWord(), dictionary.size() can equal an index that is
            // still in use; only safe while no words have been removed.
            index = dictionary.size();
            dictionary.put(word, index);
        }
        termsInContext.add(index);
        vector.merge(index, distanceWeight(distance, maxDist), Double::sum);
    }

    /**
     * Logistic weight of a word {@code distance} tokens from the span: 0.5 at {@code maxDist},
     * approaching 1 for nearby words and 0 for distant ones.
     */
    private static double distanceWeight(int distance, double maxDist) {
        return 1.0d / (1.0d + Math.exp(SLOPE * (Math.abs(distance) - maxDist)));
    }

    /**
     * Removes {@code str} from the dictionary, zeroes its IDF entry (if an IDF vector exists) and
     * drops its document count.
     *
     * @return the removed word's index, or {@code null} if it was not in the dictionary
     */
    @Nullable
    public Integer removeWord(String str) {
        LOGGER.info("removing word {}", str);
        Integer index = this.dictionary.remove(str);
        if (index != null) {
            // idf is only present after buildIdf() or setIdf().
            if (this.idf != null) {
                this.idf.set(index, 0.0d);
            }
            this.documentsPerTerm.remove(index);
        }
        return index;
    }

    /**
     * Removes every dictionary word whose standard context form is not in {@code set}.
     *
     * @return the dictionary indices that were removed
     */
    public Set<Integer> removeWordsExcept(Set<String> set) {
        LOGGER.info("dictionary size before de-ID: {}", this.dictionary.size());
        Set<Integer> removedIndices = new HashSet<>();
        // Iterate over a snapshot because removeWord mutates the dictionary.
        for (String word : new HashSet<>(this.dictionary.keySet())) {
            if (!set.contains(Acronyms.standardContextForm(word))) {
                removedIndices.add(removeWord(word));
            }
        }
        LOGGER.info("{} indices removed", removedIndices.size());
        LOGGER.info("dictionary size after de-ID: {}", this.dictionary.size());
        return removedIndices;
    }
}
