package edu.umn.biomedicus.acronym;

import com.google.inject.Inject;
import com.google.inject.ProvidedBy;
import com.google.inject.Provider;
import com.google.inject.Singleton;
import edu.umn.biomedicus.acronyms.ScoredSense;
import edu.umn.biomedicus.annotations.Setting;
import edu.umn.biomedicus.common.tuples.Pair;
import edu.umn.biomedicus.exc.BiomedicusException;
import edu.umn.biomedicus.framework.DataLoader;
import edu.umn.biomedicus.serialization.YamlSerialization;
import edu.umn.biomedicus.tokenization.Token;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.yaml.snakeyaml.Yaml;

/**
 * Acronym disambiguation model that scores the candidate expansions ("senses") of an acronym by
 * the dot product between a context vector, built from the tokens surrounding the acronym, and a
 * stored vector for each sense.
 *
 * <p>Instances are normally created through {@link Loader}, which reads the word vector space
 * from YAML and opens the sense vectors from a RocksDB-backed store.
 */
@ProvidedBy(AcronymVectorModel.Loader.class)
public class AcronymVectorModel implements AcronymModel {
    private static final Logger LOGGER = LoggerFactory.getLogger(AcronymVectorModel.class);

    /** Builds a context vector from the tokens around an acronym. */
    private final WordVectorSpace wordVectorSpace;

    /** Supplies candidate expansions for an acronym form. */
    private final AcronymExpansionsModel acronymExpansionsModel;

    /** Stored vector for each sense, looked up by the sense string. */
    private final SenseVectors senseVectors;

    /** Fallback source of expansions for unknown acronyms; {@code null} when alignment is disabled. */
    @Nullable
    private final AlignmentModel alignmentModel;

    /** Senses scoring below this value are dropped from {@link #findBestSense} results. */
    private final double cutoffScore;

    /**
     * Guice loader that builds an {@link AcronymVectorModel} from the configured data paths and
     * settings.
     */
    @Singleton
    static class Loader extends DataLoader<AcronymVectorModel> {

        @Nullable
        private final Provider<AlignmentModel> alignmentModel;
        private final Path vectorSpacePath;
        private final Path senseMapPath;
        private final boolean useAlignment;
        private final Boolean sensesInMemory;
        private final AcronymExpansionsModel expansionsModel;
        private final Double cutoffScore;

        @Inject
        public Loader(
                @Nullable Provider<AlignmentModel> alignmentModel,
                @Setting("acronym.useAlignment") Boolean useAlignment,
                @Setting("acronym.vector.model.asDataPath") Path vectorSpacePath,
                @Setting("acronym.senseMap.senseVectors.asDataPath") Path senseMapPath,
                @Setting("acronym.senseMap.inMemory") Boolean sensesInMemory,
                @Setting("acronym.cutoffScore") Double cutoffScore,
                AcronymExpansionsModel expansionsModel) {
            this.alignmentModel = alignmentModel;
            this.useAlignment = useAlignment.booleanValue();
            this.vectorSpacePath = vectorSpacePath;
            this.senseMapPath = senseMapPath;
            this.sensesInMemory = sensesInMemory;
            this.expansionsModel = expansionsModel;
            this.cutoffScore = cutoffScore;
        }

        /**
         * Loads the word vector space from YAML, opens the sense vectors and, when alignment is
         * enabled, obtains the alignment model.
         *
         * @return the assembled model
         * @throws BiomedicusException if reading the vector space file fails with an I/O error
         * @throws IllegalStateException if alignment is enabled but no alignment provider exists
         */
        @Override
        public AcronymVectorModel loadModel() throws BiomedicusException {
            // Fail fast with a clear message instead of an NPE after the data has been loaded.
            if (useAlignment && alignmentModel == null) {
                throw new IllegalStateException(
                        "acronym.useAlignment is enabled but no AlignmentModel provider is available");
            }
            Yaml yaml = YamlSerialization.createYaml();
            try {
                LOGGER.info("Loading acronym vector space: {}", vectorSpacePath);
                WordVectorSpace wordVectorSpace;
                // NOTE(review): YAML deserialization of a configured model file; assumed trusted data.
                try (BufferedReader reader = Files.newBufferedReader(vectorSpacePath)) {
                    wordVectorSpace = (WordVectorSpace) yaml.load(reader);
                }
                LOGGER.info("Loading acronym sense map: {}. inMemory = {}", senseMapPath, sensesInMemory);
                SenseVectors senseVectors =
                        new RocksDBSenseVectors(senseMapPath, false).inMemory(sensesInMemory);
                AlignmentModel alignment = useAlignment ? alignmentModel.get() : null;
                return new AcronymVectorModel(wordVectorSpace, senseVectors, expansionsModel,
                        alignment, cutoffScore.doubleValue());
            } catch (IOException e) {
                throw new BiomedicusException(e);
            }
        }
    }

    /**
     * Creates a model from already-loaded components.
     *
     * @param wordVectorSpace builds context vectors from tokens
     * @param senseVectors stored vectors for each sense
     * @param acronymExpansionsModel source of candidate expansions
     * @param alignmentModel optional fallback for acronyms without known expansions
     * @param cutoffScore minimum score a sense must reach to be returned
     */
    public AcronymVectorModel(WordVectorSpace wordVectorSpace, SenseVectors senseVectors,
            AcronymExpansionsModel acronymExpansionsModel, @Nullable AlignmentModel alignmentModel,
            double cutoffScore) {
        this.acronymExpansionsModel = acronymExpansionsModel;
        this.senseVectors = senseVectors;
        this.wordVectorSpace = wordVectorSpace;
        this.alignmentModel = alignmentModel;
        this.cutoffScore = cutoffScore;
    }

    /**
     * Returns the known expansions for the token's standard acronym form.
     *
     * @param token the acronym token
     * @return the expansions, or an empty collection if none are known
     */
    public Collection<String> getExpansions(Token token) {
        Collection<String> expansions =
                acronymExpansionsModel.getExpansions(Acronyms.standardAcronymForm(token));
        return expansions != null ? expansions : Collections.emptyList();
    }

    /** Returns whether the expansions model has expansions for the token's standard acronym form. */
    @Override
    public boolean hasAcronym(Token token) {
        return acronymExpansionsModel.hasExpansions(Acronyms.standardAcronymForm(token));
    }

    /**
     * Scores the candidate senses of the acronym at {@code forTokenIndex}.
     *
     * <p>A single candidate is returned with score 1.0. Otherwise each candidate that has a stored
     * sense vector is scored by its dot product with the context vector. Results below the cutoff
     * score are dropped, and the rest are sorted from highest to lowest score.
     *
     * @param context the tokens surrounding (and including) the acronym
     * @param forTokenIndex index of the acronym token within {@code context}
     * @return scored senses, best first; empty if no candidate could be found or scored
     */
    @Override
    public List<ScoredSense> findBestSense(List<? extends Token> context, int forTokenIndex) {
        String acronym = Acronyms.standardAcronymForm(context.get(forTokenIndex));
        Collection<String> senses = lookupExpansions(acronym);
        if (senses == null || senses.isEmpty()) {
            return Collections.emptyList();
        }
        if (senses.size() == 1) {
            return Collections.singletonList(new ScoredSense(senses.iterator().next(), 1.0d));
        }

        List<Pair<String, SparseVector>> usableSenses = sensesWithVectors(senses);
        // No scorable senses: fall back to the upper-case acronym's expansions. The previous code
        // re-scanned the same collection here, which could never find anything new.
        String upperCaseAcronym = acronym.toUpperCase();
        if (usableSenses.isEmpty() && acronymExpansionsModel.hasExpansions(upperCaseAcronym)) {
            Collection<String> upperCaseSenses = acronymExpansionsModel.getExpansions(upperCaseAcronym);
            if (upperCaseSenses != null) {
                usableSenses = sensesWithVectors(upperCaseSenses);
            }
        }
        if (usableSenses.isEmpty()) {
            return Collections.emptyList();
        }

        SparseVector contextVector = wordVectorSpace.vectorize(context, forTokenIndex);
        List<ScoredSense> scored = new ArrayList<>(usableSenses.size());
        for (Pair<String, SparseVector> sense : usableSenses) {
            double score = contextVector.dot(sense.getSecond());
            if (score >= cutoffScore) {
                scored.add(new ScoredSense(sense.first(), score));
            }
        }
        // List.sort is stable, so equal scores keep their candidate order.
        scored.sort(Comparator.comparingDouble(ScoredSense::getScore).reversed());
        return scored;
    }

    /**
     * Finds candidate expansions for an acronym. It tries, in order: the exact form, upper case,
     * the form with periods removed, lower case, and finally the alignment model, if one exists.
     *
     * @return the first non-null result, or {@code null} if none was found
     */
    @Nullable
    private Collection<String> lookupExpansions(String acronym) {
        Collection<String> expansions = acronymExpansionsModel.getExpansions(acronym);
        if (expansions == null) {
            expansions = acronymExpansionsModel.getExpansions(acronym.toUpperCase());
        }
        if (expansions == null) {
            expansions = acronymExpansionsModel.getExpansions(acronym.replace(".", ""));
        }
        if (expansions == null) {
            expansions = acronymExpansionsModel.getExpansions(acronym.toLowerCase());
        }
        if (expansions == null && alignmentModel != null) {
            expansions = alignmentModel.findBestLongforms(acronym);
        }
        return expansions;
    }

    /** Pairs each sense with its stored vector, skipping senses that have no vector. */
    private List<Pair<String, SparseVector>> sensesWithVectors(Collection<String> senses) {
        List<Pair<String, SparseVector>> result = new ArrayList<>();
        for (String sense : senses) {
            SparseVector vector = senseVectors.get(sense);
            if (vector != null) {
                result.add(Pair.of(sense, vector));
            }
        }
        return result;
    }

    /**
     * Removes a word from the vector space. If the vector space reports an index for it, that
     * index is also removed from the sense vectors.
     */
    @Override
    public void removeWord(String word) {
        Integer removedIndex = wordVectorSpace.removeWord(word);
        if (removedIndex != null) {
            senseVectors.removeWord(removedIndex.intValue());
        }
    }

    /**
     * Removes all words except {@code words} from the vector space, and the removed indices from
     * the sense vectors.
     */
    @Override
    public void removeWordsExcept(Set<String> words) {
        Set<Integer> removedIndices = wordVectorSpace.removeWordsExcept(words);
        removedIndices.remove(null);
        senseVectors.removeWords(removedIndices);
    }

    /**
     * Writes the model into {@code outputDir}. The output is {@code alignment.yml} (only when an
     * alignment model is present), {@code vectorSpace.yml}, and, when {@code senseMap} is
     * non-null, a {@code senseVectors} RocksDB store.
     *
     * @param outputDir directory to write into
     * @param senseMap sense vectors to store, or {@code null} to skip writing them
     * @throws IOException if writing any of the files fails
     */
    public void writeToDirectory(Path outputDir, @Nullable Map<String, SparseVector> senseMap)
            throws IOException {
        Yaml yaml = YamlSerialization.createYaml();
        if (alignmentModel != null) {
            try (BufferedWriter writer = Files.newBufferedWriter(outputDir.resolve("alignment.yml"))) {
                yaml.dump(alignmentModel, writer);
            }
        }
        try (BufferedWriter writer = Files.newBufferedWriter(outputDir.resolve("vectorSpace.yml"))) {
            yaml.dump(wordVectorSpace, writer);
        }
        if (senseMap != null) {
            RocksDBSenseVectors senseVectorStore =
                    new RocksDBSenseVectors(outputDir.resolve("senseVectors"), true);
            try {
                senseVectorStore.putAll(senseMap);
            } finally {
                // Close the store even if putAll fails.
                senseVectorStore.close();
            }
        }
    }
}
