package ai.djl.basicdataset.nlp;

import ai.djl.basicdataset.BasicDatasets;
import ai.djl.basicdataset.utils.TextData;
import ai.djl.engine.Engine;
import ai.djl.modality.nlp.Vocabulary;
import ai.djl.modality.nlp.embedding.EmbeddingException;
import ai.djl.modality.nlp.embedding.TextEmbedding;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/* loaded from: input_file:ai/djl/basicdataset/nlp/TextDataset.class */
public abstract class TextDataset extends RandomAccessDataset {
    protected TextData sourceTextData;
    protected TextData targetTextData;
    protected NDManager manager;
    protected Dataset.Usage usage;
    protected MRL mrl;
    protected boolean prepared;
    protected List<Sample> samples;

    /* loaded from: input_file:ai/djl/basicdataset/nlp/TextDataset$Builder.class */
    /**
     * Fluent base builder collecting the options needed to construct a {@link TextDataset}.
     *
     * <p>Every setter returns the result of {@code self()} so that calls can be chained on the
     * concrete builder type.
     *
     * @param <T> the concrete builder type
     */
    public abstract static class Builder<T extends Builder<T>> extends RandomAccessDataset.BaseBuilder<T> {
        /** Artifact id of the dataset; unset until {@link #optArtifactId(String)} is called. */
        protected String artifactId;

        /** Configuration applied on top of the defaults for the source text. */
        TextData.Configuration sourceConfiguration = new TextData.Configuration();

        /** Configuration applied on top of the defaults for the target text. */
        TextData.Configuration targetConfiguration = new TextData.Configuration();

        /** Manager handed to the dataset; defaults to a new base manager of the current engine. */
        NDManager manager = Engine.getInstance().newBaseManager();

        /** Repository to load the dataset from; defaults to {@link BasicDatasets#REPOSITORY}. */
        protected Repository repository = BasicDatasets.REPOSITORY;

        /** Group id of the dataset; defaults to {@link BasicDatasets#GROUP_ID}. */
        protected String groupId = BasicDatasets.GROUP_ID;

        /** Usage of the dataset; defaults to {@link Dataset.Usage#TRAIN}. */
        protected Dataset.Usage usage = Dataset.Usage.TRAIN;

        /**
         * Sets the configuration used for the source text.
         *
         * @param configuration the source text configuration
         * @return this builder
         */
        public T setSourceConfiguration(TextData.Configuration configuration) {
            this.sourceConfiguration = configuration;
            return (T) self();
        }

        /**
         * Sets the configuration used for the target text.
         *
         * @param configuration the target text configuration
         * @return this builder
         */
        public T setTargetConfiguration(TextData.Configuration configuration) {
            this.targetConfiguration = configuration;
            return (T) self();
        }

        /**
         * Replaces the builder's manager with a new sub-manager of the given manager.
         *
         * @param nDManager the parent manager
         * @return this builder
         */
        public T optManager(NDManager nDManager) {
            // NOTE(review): the manager previously held by this builder is dropped without being
            // closed here - confirm whether it needs to be released by someone.
            this.manager = nDManager.newSubManager();
            return (T) self();
        }

        /**
         * Sets the dataset usage.
         *
         * @param usage the usage
         * @return this builder
         */
        public T optUsage(Dataset.Usage usage) {
            this.usage = usage;
            return (T) self();
        }

        /**
         * Sets the repository that contains the dataset.
         *
         * @param repository the repository
         * @return this builder
         */
        public T optRepository(Repository repository) {
            this.repository = repository;
            return (T) self();
        }

        /**
         * Sets the group id of the dataset.
         *
         * @param str the group id
         * @return this builder
         */
        public T optGroupId(String str) {
            this.groupId = str;
            return (T) self();
        }

        /**
         * Sets the artifact id of the dataset.
         *
         * <p>When the argument contains a colon it is read as {@code groupId:artifactId}: the text
         * before the first colon becomes the group id and the text between the first and second
         * colon becomes the artifact id. Otherwise the whole argument is the artifact id and the
         * group id is left untouched.
         *
         * @param str the artifact id, optionally prefixed with {@code groupId:}
         * @return this builder
         * @throws IllegalArgumentException if the argument contains a colon but no artifact id
         *     follows it (for example {@code "group:"} or {@code ":"})
         */
        public T optArtifactId(String str) {
            if (str.contains(":")) {
                String[] tokens = str.split(":");
                // String.split drops trailing empty strings, so "group:" yields a single token and
                // ":" yields none. Validate before touching any field so a bad value cannot leave
                // the builder half-updated.
                if (tokens.length < 2) {
                    throw new IllegalArgumentException(
                            "Invalid artifactId \"" + str + "\", expected \"groupId:artifactId\"");
                }
                this.groupId = tokens[0];
                this.artifactId = tokens[1];
            } else {
                this.artifactId = str;
            }
            return (T) self();
        }
    }

    /* loaded from: input_file:ai/djl/basicdataset/nlp/TextDataset$Sample.class */
    /**
     * Immutable pair of a sentence index and the length of that sentence.
     */
    public static final class Sample {
        /** Length of the sentence this sample refers to. */
        private final int sentenceLength;

        /** Index of the sentence this sample refers to. */
        private final long index;

        /**
         * Creates a sample.
         *
         * @param i the index of the sentence
         * @param i2 the length of the sentence
         */
        public Sample(int i, int i2) {
            this.index = i;
            this.sentenceLength = i2;
        }

        /**
         * Returns the sentence length given at construction.
         *
         * @return the sentence length
         */
        public int getSentenceLength() {
            return this.sentenceLength;
        }

        /**
         * Returns the sentence index given at construction.
         *
         * @return the sentence index
         */
        public long getIndex() {
            return this.index;
        }
    }

    public TextDataset(Builder<?> builder) {
        super(builder);
        this.sourceTextData = new TextData(TextData.getDefaultConfiguration().update(builder.sourceConfiguration));
        this.targetTextData = new TextData(TextData.getDefaultConfiguration().update(builder.targetConfiguration));
        this.manager = builder.manager;
        this.manager.setName("textDataset");
        this.usage = builder.usage;
    }

    /**
     * Returns the text embedding of the source or the target text data.
     *
     * @param z {@code true} to query the source text data, {@code false} for the target
     * @return the text embedding reported by the selected text data
     */
    public TextEmbedding getTextEmbedding(boolean z) {
        TextData selected = this.targetTextData;
        if (z) {
            selected = this.sourceTextData;
        }
        return selected.getTextEmbedding();
    }

    /**
     * Returns the vocabulary of the source or the target text data.
     *
     * @param z {@code true} to query the source text data, {@code false} for the target
     * @return the vocabulary reported by the selected text data
     */
    public Vocabulary getVocabulary(boolean z) {
        if (z) {
            return this.sourceTextData.getVocabulary();
        }
        return this.targetTextData.getVocabulary();
    }

    /**
     * Returns the raw text stored at the given index of the source or the target text data.
     *
     * @param j the index of the text to fetch
     * @param z {@code true} to query the source text data, {@code false} for the target
     * @return the raw text reported by the selected text data
     */
    public String getRawText(long j, boolean z) {
        TextData selected;
        if (z) {
            selected = this.sourceTextData;
        } else {
            selected = this.targetTextData;
        }
        return selected.getRawText(j);
    }

    /**
     * Returns the processed text stored at the given index of the source or the target text data.
     *
     * @param j the index of the text to fetch
     * @param z {@code true} to query the source text data, {@code false} for the target
     * @return the processed text reported by the selected text data
     */
    public List<String> getProcessedText(long j, boolean z) {
        TextData selected = z ? this.sourceTextData : this.targetTextData;
        List<String> processed = selected.getProcessedText(j);
        return processed;
    }

    /**
     * Returns one {@link Sample} per dataset entry, ordered by ascending source sentence length.
     *
     * <p>The list is computed on the first call and cached in {@code samples}; later calls return
     * the same list instance. Each sample's length is the size of the processed source text at
     * that index.
     *
     * @return the cached list of samples sorted by sentence length
     */
    public List<Sample> getSamples() {
        if (this.samples == null) {
            // Build into a local list and publish it only once it is complete and sorted, so a
            // failure part-way through cannot leave a partial, unsorted list cached.
            List<Sample> built = new ArrayList<>();
            for (int i = 0; i < size(); i++) {
                built.add(new Sample(i, getProcessedText(i, true).size()));
            }
            built.sort(Comparator.comparingInt(Sample::getSentenceLength));
            this.samples = built;
        }
        return this.samples;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Preprocesses the given text into the source or the target text data.
     *
     * <p>Only the leading part of the list is forwarded: its length is the smaller of the
     * inherited {@code limit} field and the list size.
     *
     * @param list the text entries to preprocess
     * @param z {@code true} to fill the source text data, {@code false} for the target
     * @throws EmbeddingException propagated from the text data's preprocessing
     */
    public void preprocess(List<String> list, boolean z) throws EmbeddingException {
        TextData destination;
        if (z) {
            destination = this.sourceTextData;
        } else {
            destination = this.targetTextData;
        }
        int cappedSize = (int) Math.min(this.limit, list.size());
        destination.preprocess(this.manager, list.subList(0, cappedSize));
    }
}
