package ai.djl.ndarray.index;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.Shape;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;

/* loaded from: input_file:ai/djl/ndarray/index/NDIndex.class */
/**
 * An index into an {@link NDArray}, made of one element per indexed axis: fixed indices, slices,
 * whole-axis markers, or boolean masks.
 *
 * <p>String indices are comma-separated items, each one of: {@code *} (the whole axis), an
 * integer (a fixed index), or a slice {@code [min]:[max][:step]} where every part is optional.
 */
public class NDIndex {

    /** One index item. Groups: 1 = {@code *}; 3/4/6 = slice min/max/step; 7 = fixed index. */
    private static final Pattern ITEM_PATTERN =
            Pattern.compile("(\\*)|((-?\\d+)?:(-?\\d+)?(:(-?\\d+))?)|(-?\\d+)");

    /** Number of array axes covered by this index (a boolean mask counts all of its axes). */
    private int rank;

    private final List<NDIndexElement> indices;

    /** Creates an empty index. */
    public NDIndex() {
        this.rank = 0;
        this.indices = new ArrayList<>();
    }

    /**
     * Creates an index from its string form, e.g. {@code "1, :, 2:5:2, *"}.
     *
     * @param str comma-separated index items
     * @throws IllegalArgumentException if any item is malformed
     */
    public NDIndex(String str) {
        this();
        addIndices(str);
    }

    /**
     * Creates an index of fixed indices, one per leading axis.
     *
     * @param jArr the fixed index for each axis
     */
    public NDIndex(long... jArr) {
        this();
        addIndices(jArr);
    }

    /**
     * Returns the number of axes covered by this index.
     *
     * @return the rank of this index
     */
    public int getRank() {
        return this.rank;
    }

    /**
     * Returns the element at position {@code i}.
     *
     * @param i the element position
     * @return the index element
     */
    public NDIndexElement get(int i) {
        return this.indices.get(i);
    }

    /**
     * Returns the elements of this index. The returned list is the live backing list, so
     * modifying it modifies this index (without updating {@link #getRank()}).
     *
     * @return the index elements
     */
    public List<NDIndexElement> getIndices() {
        return this.indices;
    }

    /**
     * Appends the comma-separated items of {@code str}. Every item is parsed before anything is
     * added, so a malformed item leaves this index unchanged.
     *
     * @param str comma-separated index items
     * @return this index
     * @throws IllegalArgumentException if any item is malformed
     */
    public final NDIndex addIndices(String str) {
        String[] items = str.split(",");
        List<NDIndexElement> parsed = new ArrayList<>(items.length);
        for (String item : items) {
            parsed.add(parseIndexItem(item));
        }
        this.rank += items.length;
        this.indices.addAll(parsed);
        return this;
    }

    /**
     * Appends one fixed index per value.
     *
     * @param jArr the fixed indices
     * @return this index
     */
    public final NDIndex addIndices(long... jArr) {
        this.rank += jArr.length;
        for (long j : jArr) {
            this.indices.add(new NDIndexFixed(j));
        }
        return this;
    }

    /**
     * Appends a boolean mask; it covers as many axes as the mask has dimensions.
     *
     * @param nDArray the boolean mask
     * @return this index
     */
    public NDIndex addBooleanIndex(NDArray nDArray) {
        this.rank += nDArray.getShape().dimension();
        this.indices.add(new NDIndexBooleans(nDArray));
        return this;
    }

    /**
     * Appends a slice with no explicit step ({@link #getAsFullSlice} treats it as 1).
     *
     * @param j the inclusive start
     * @param j2 the exclusive end
     * @return this index
     */
    public NDIndex addSliceDim(long j, long j2) {
        this.rank++;
        this.indices.add(new NDIndexSlice(j, j2, null));
        return this;
    }

    /**
     * Appends a slice with an explicit step.
     *
     * @param j the inclusive start
     * @param j2 the exclusive end
     * @param j3 the step
     * @return this index
     */
    public NDIndex addSliceDim(long j, long j2, long j3) {
        this.rank++;
        this.indices.add(new NDIndexSlice(j, j2, j3));
        return this;
    }

    /**
     * Returns a stream over the elements of this index.
     *
     * @return a stream of index elements
     */
    public Stream<NDIndexElement> stream() {
        return this.indices.stream();
    }

    /** Parses one textual index item into an element, without mutating this index. */
    private static NDIndexElement parseIndexItem(String item) {
        String trimmed = item.trim();
        Matcher matcher = ITEM_PATTERN.matcher(trimmed);
        if (!matcher.matches()) {
            throw new IllegalArgumentException("Invalid argument index: " + trimmed);
        }
        if (matcher.group(1) != null) {
            return new NDIndexAll();
        }
        String fixed = matcher.group(7);
        if (fixed != null) {
            return new NDIndexFixed(Long.parseLong(fixed));
        }
        Long min = parseOptionalLong(matcher.group(3));
        Long max = parseOptionalLong(matcher.group(4));
        Long step = parseOptionalLong(matcher.group(6));
        if (min == null && max == null && step == null) {
            // ":" and "::" select the whole axis, same as "*".
            return new NDIndexAll();
        }
        return new NDIndexSlice(min, max, step);
    }

    /** Parses {@code text} as a long, or returns {@code null} when the group was absent. */
    private static Long parseOptionalLong(String text) {
        return text == null ? null : Long.valueOf(Long.parseLong(text));
    }

    /**
     * Converts this index into explicit per-axis min/max/step bounds for {@code shape}.
     *
     * <p>Negative fixed indices and negative slice bounds are wrapped with {@link Math#floorMod}.
     * Axes beyond this index's rank select the whole axis.
     *
     * @param shape the shape of the array being indexed
     * @return the full slice, or empty if any element is not a whole-axis, fixed, or slice
     *     element (e.g. a boolean mask)
     * @throws IllegalArgumentException if this index has more axes than {@code shape}, or a slice
     *     has a step of zero
     */
    public Optional<NDIndexFullSlice> getAsFullSlice(Shape shape) {
        if (!stream().allMatch(
                e -> e instanceof NDIndexAll || e instanceof NDIndexFixed || e instanceof NDIndexSlice)) {
            return Optional.empty();
        }
        int indexRank = getRank();
        int dimension = shape.dimension();
        if (indexRank > dimension) {
            throw new IllegalArgumentException(
                    "The index has too many dimensions - "
                            + indexRank
                            + " dimensions for array with "
                            + dimension
                            + " dimensions");
        }
        long[] min = new long[dimension];
        long[] max = new long[dimension];
        long[] step = new long[dimension];
        List<Integer> toSqueeze = new ArrayList<>(dimension);
        long[] fullShape = new long[dimension];
        List<Long> squeezedShape = new ArrayList<>(dimension);
        for (int i = 0; i < dimension; i++) {
            NDIndexElement element = i < indexRank ? get(i) : null;
            if (element instanceof NDIndexFixed) {
                long index = wrapNegative(((NDIndexFixed) element).getIndex(), shape.get(i));
                min[i] = index;
                max[i] = index + 1;
                step[i] = 1;
                fullShape[i] = 1;
                // A fixed index removes the axis from the squeezed result.
                toSqueeze.add(i);
            } else if (element instanceof NDIndexSlice) {
                NDIndexSlice slice = (NDIndexSlice) element;
                long rawMin = Optional.ofNullable(slice.getMin()).orElse(0L);
                long rawMax = Optional.ofNullable(slice.getMax()).orElse(shape.size(i));
                long rawStep = Optional.ofNullable(slice.getStep()).orElse(1L);
                if (rawStep == 0) {
                    throw new IllegalArgumentException("The slice step cannot be zero");
                }
                min[i] = wrapNegative(rawMin, shape.get(i));
                max[i] = wrapNegative(rawMax, shape.get(i));
                step[i] = rawStep;
                fullShape[i] = sliceLength(min[i], max[i], step[i]);
                squeezedShape.add(fullShape[i]);
            } else {
                // NDIndexAll, or an axis past the end of this index: take the whole axis.
                min[i] = 0;
                max[i] = shape.size(i);
                step[i] = 1;
                fullShape[i] = shape.size(i);
                squeezedShape.add(shape.size(i));
            }
        }
        return Optional.of(
                new NDIndexFullSlice(
                        min, max, step, toSqueeze, new Shape(fullShape), new Shape(squeezedShape)));
    }

    /** Maps a negative position onto {@code [0, size)} with floorMod; non-negatives pass through. */
    private static long wrapNegative(long position, long size) {
        return position < 0 ? Math.floorMod(position, size) : position;
    }

    /**
     * Counts the positions visited from {@code min} (inclusive) towards {@code max} (exclusive)
     * in increments of {@code step} (non-zero). Returns 0 when the range is empty.
     */
    private static long sliceLength(long min, long max, long step) {
        if (step > 0) {
            return max > min ? (max - min - 1) / step + 1 : 0;
        }
        return min > max ? (min - max - 1) / -step + 1 : 0;
    }
}
