package ai.djl.modality.nlp.preprocess;

import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:ai/djl/modality/nlp/preprocess/SentenceLengthNormalizer.class */
/**
 * A {@link TextProcessor} that forces every token list to a fixed length.
 *
 * <p>Inputs that are too long are truncated and inputs that are too short are padded with a
 * padding token. Optionally a begin-of-sentence token is prepended and an end-of-sentence token
 * is appended before the length is normalized.
 *
 * <p>The number of non-padding tokens produced by the most recent {@link #preprocess(List)} call
 * is kept in an unsynchronized instance field and exposed through {@link #getLastValidLength()},
 * so an instance that is shared between threads cannot reliably pair that value with a call.
 */
public class SentenceLengthNormalizer implements TextProcessor {
    private static final int DEFAULT_SENTENCE_LENGTH = 10;
    private static final String DEFAULT_PADDING_TOKEN = "<pad>";
    private static final String DEFAULT_EOS_TOKEN = "<eos>";
    private static final String DEFAULT_BOS_TOKEN = "<bos>";

    private final int sentenceLength;
    private final boolean addEosBosTokens;
    private final String paddingToken;
    private final String eosToken;
    private final String bosToken;

    /** Valid (non-padding) length of the last processed sentence; -1 until the first call. */
    private int lastValidLength = -1;

    /**
     * Creates a normalizer with a sentence length of 10 that does not add begin/end-of-sentence
     * tokens and pads with {@code "<pad>"}.
     */
    public SentenceLengthNormalizer() {
        this(DEFAULT_SENTENCE_LENGTH, false);
    }

    /**
     * Creates a normalizer that uses the default {@code "<pad>"}, {@code "<eos>"} and
     * {@code "<bos>"} tokens.
     *
     * @param i the length every processed sentence is normalized to
     * @param z {@code true} to wrap each sentence in begin/end-of-sentence tokens
     */
    public SentenceLengthNormalizer(int i, boolean z) {
        this(i, z, DEFAULT_PADDING_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_BOS_TOKEN);
    }

    /**
     * Creates a fully configured normalizer.
     *
     * @param i the length every processed sentence is normalized to
     * @param z {@code true} to wrap each sentence in begin/end-of-sentence tokens
     * @param str the token appended to sentences that are shorter than {@code i}
     * @param str2 the end-of-sentence token
     * @param str3 the begin-of-sentence token
     */
    public SentenceLengthNormalizer(int i, boolean z, String str, String str2, String str3) {
        this.sentenceLength = i;
        this.addEosBosTokens = z;
        this.paddingToken = str;
        this.eosToken = str2;
        this.bosToken = str3;
    }

    /**
     * Truncates or pads the given tokens to the configured sentence length.
     *
     * <p>When begin/end-of-sentence tokens are enabled they count towards the length, and a
     * truncated sentence still ends with the end-of-sentence token (it replaces the last kept
     * token). The input list is not modified; a new list is returned.
     *
     * @param list the tokens of one sentence
     * @return a new list holding exactly the configured number of tokens
     * @throws NullPointerException if {@code list} is {@code null}
     * @throws IllegalArgumentException if the configured sentence length is negative
     */
    @Override
    public List<String> preprocess(List<String> list) {
        List<String> tokens = new ArrayList<>(this.sentenceLength);
        if (this.addEosBosTokens) {
            tokens.add(this.bosToken);
        }
        tokens.addAll(list);
        if (this.addEosBosTokens) {
            tokens.add(this.eosToken);
        }

        int size = tokens.size();
        if (size > this.sentenceLength) {
            this.lastValidLength = this.sentenceLength;
            // Drop the overflow in place so the result does not pin a larger backing list.
            tokens.subList(this.sentenceLength, size).clear();
            // A zero-length target leaves no slot to hold the end-of-sentence token.
            if (this.addEosBosTokens && this.sentenceLength > 0) {
                tokens.set(this.sentenceLength - 1, this.eosToken);
            }
            return tokens;
        }

        this.lastValidLength = size;
        for (int k = size; k < this.sentenceLength; k++) {
            tokens.add(this.paddingToken);
        }
        return tokens;
    }

    /**
     * Returns the number of non-padding tokens in the result of the most recent
     * {@link #preprocess(List)} call.
     *
     * @return the valid length of the last processed sentence, or -1 if nothing has been
     *     processed yet
     */
    public int getLastValidLength() {
        return this.lastValidLength;
    }
}
