package ws.palladian.extraction.text.vector;

import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import ws.palladian.core.FeatureVector;
import ws.palladian.core.InstanceBuilder;
import ws.palladian.core.dataset.AbstractDatasetFeatureVectorTransformer;
import ws.palladian.core.dataset.FeatureInformation;
import ws.palladian.core.dataset.FeatureInformationBuilder;
import ws.palladian.core.value.NominalValue;
import ws.palladian.core.value.TextValue;
import ws.palladian.core.value.Value;
import ws.palladian.core.value.ValueDefinitions;
import ws.palladian.extraction.token.Tokenizer;

/* loaded from: input_file:ws/palladian/extraction/text/vector/EmbeddingTextVectorizer.class */
/**
 * Turns a text (or nominal) feature into a fixed-size dense vector by averaging the word vectors
 * of its tokens, as looked up in a {@link WordVectorDictionary}.
 * <p>
 * The produced feature vector holds one float feature per embedding dimension, named
 * {@code embedding-0} .. {@code embedding-(n-1)}, where {@code n} is
 * {@link WordVectorDictionary#vectorSize()}. Tokens that are not found in the dictionary add
 * nothing to the sum, but they are still counted in the averaging denominator.
 */
public class EmbeddingTextVectorizer extends AbstractDatasetFeatureVectorTransformer implements ITextVectorizer {
    /** Prefix of every generated feature name; the dimension index is appended to it. */
    private static final String FEATURE_NAME_PREFIX = "embedding-";

    /** Name of the input feature whose text is embedded. */
    private final String inputFeatureName;

    /** Source of the per-token vectors; also defines the output dimensionality. */
    private final WordVectorDictionary dictionary;

    /**
     * Creates a vectorizer for the given input feature.
     *
     * @param inputFeatureName name of the feature holding the text; must be a nominal or text value
     * @param dictionary dictionary used to look up token vectors
     * @throws NullPointerException if either argument is {@code null}
     */
    public EmbeddingTextVectorizer(String inputFeatureName, WordVectorDictionary dictionary) {
        this.inputFeatureName = Objects.requireNonNull(inputFeatureName, "inputFeatureName must not be null");
        this.dictionary = Objects.requireNonNull(dictionary, "dictionary must not be null");
    }

    /**
     * Describes the output features: one float feature per embedding dimension. The supplied
     * input feature information is not consulted.
     */
    @Override // ws.palladian.core.dataset.AbstractDatasetFeatureVectorTransformer, ws.palladian.core.dataset.DatasetTransformer
    public FeatureInformation getFeatureInformation(FeatureInformation featureInformation) {
        FeatureInformationBuilder builder = new FeatureInformationBuilder();
        int size = dictionary.vectorSize();
        for (int i = 0; i < size; i++) {
            builder.set(FEATURE_NAME_PREFIX + i, ValueDefinitions.floatValue());
        }
        return builder.m80create();
    }

    /**
     * Computes the mean embedding of the input text.
     * <p>
     * The text is lower-cased only if the dictionary is case-insensitive. The text is then
     * tokenized, and the vectors of all known tokens are summed. The sum is divided by the total
     * number of tokens. Empty text yields an all-zero vector.
     *
     * @param featureVector the instance containing {@link #inputFeatureName}
     * @return a feature vector containing only the {@code embedding-i} features
     * @throws IllegalArgumentException if the input feature is missing or is not text/nominal
     */
    @Override // ws.palladian.core.dataset.AbstractDatasetFeatureVectorTransformer
    public FeatureVector compute(FeatureVector featureVector) {
        String text = getTextValue(featureVector);
        if (!dictionary.isCaseSensitive()) {
            text = text.toLowerCase();
        }
        // Tokenize the (possibly normalized) text as-is. Lower-casing unconditionally here would
        // make upper-case entries of a case-sensitive dictionary unreachable.
        List<String> tokens = Tokenizer.tokenize(text);
        float[] mean = new float[dictionary.vectorSize()];
        for (String token : tokens) {
            float[] vector = dictionary.getVector(token);
            if (vector != null) {
                mean = FloatVectorUtil.add(mean, vector);
            }
        }
        if (!tokens.isEmpty()) {
            mean = FloatVectorUtil.scalar(mean, 1.0f / tokens.size());
        }
        InstanceBuilder builder = new InstanceBuilder();
        for (int i = 0; i < mean.length; i++) {
            builder.set(FEATURE_NAME_PREFIX + i, mean[i]);
        }
        return builder.create();
    }

    /**
     * Extracts the raw string of the input feature.
     *
     * @throws IllegalArgumentException if the feature is absent or is neither nominal nor text
     */
    private String getTextValue(FeatureVector featureVector) {
        Value value = (Value) featureVector.get(inputFeatureName);
        if (value == null) {
            throw new IllegalArgumentException("Missing feature: " + inputFeatureName);
        }
        if (value instanceof NominalValue) {
            return ((NominalValue) value).getString();
        }
        if (value instanceof TextValue) {
            return ((TextValue) value).getText();
        }
        throw new IllegalArgumentException("Invalid type: " + value.getClass().getName());
    }

    @Override
    public String toString() {
        return String.format("%s [%s]", getClass().getSimpleName(), dictionary);
    }
}
