package ai.djl.tensorflow.zoo.cv.objectdetction;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.SingleShotDetectionTranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:ai/djl/tensorflow/zoo/cv/objectdetction/TfSsdTranslator.class */
public class TfSsdTranslator extends SingleShotDetectionTranslator {
    private int maxBoxes;
    private float threshHold;

    /* loaded from: input_file:ai/djl/tensorflow/zoo/cv/objectdetction/TfSsdTranslator$Builder.class */
    public static class Builder extends SingleShotDetectionTranslator.Builder {
        private int maxBoxes = 10;

        public Builder optMaxBoxes(int i) {
            this.maxBoxes = i;
            return this;
        }

        /* renamed from: build, reason: merged with bridge method [inline-methods] */
        public TfSsdTranslator m4build() {
            validate();
            return new TfSsdTranslator(this);
        }
    }

    protected TfSsdTranslator(Builder builder) {
        super(builder);
        this.maxBoxes = builder.maxBoxes;
        this.threshHold = builder.getThreshold();
    }

    public NDList processInput(TranslatorContext translatorContext, Image image) {
        return new NDList(new NDArray[]{((NDArray) super.processInput(translatorContext, image).get(0)).expandDims(0)});
    }

    public Batchifier getBatchifier() {
        return null;
    }

    /* renamed from: processOutput, reason: merged with bridge method [inline-methods] */
    public DetectedObjects m3processOutput(TranslatorContext translatorContext, NDList nDList) {
        int i = (int) ((NDArray) nDList.get(0)).getShape().get(0);
        float[] fArr = new float[i];
        long[] jArr = new long[i];
        NDArray nDArray = (NDArray) nDList.get(0);
        Iterator it = nDList.iterator();
        while (it.hasNext()) {
            NDArray nDArray2 = (NDArray) it.next();
            DataType dataType = nDArray2.getDataType();
            int dimension = nDArray2.getShape().dimension();
            if (dataType == DataType.FLOAT32 && dimension == 1) {
                fArr = nDArray2.toFloatArray();
            } else if (dataType == DataType.FLOAT32 && dimension == 2) {
                nDArray = nDArray2;
            } else {
                if (dataType != DataType.INT64 || dimension != 1) {
                    throw new IllegalStateException("Unexpected result NDArray type:" + dataType + ", and dim: " + dimension);
                }
                jArr = nDArray2.toLongArray();
            }
        }
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        ArrayList arrayList3 = new ArrayList();
        for (int i2 = 0; i2 < Math.min(jArr.length, this.maxBoxes); i2++) {
            long j = jArr[i2];
            double d = fArr[i2];
            if (j >= 0 && d > this.threshHold) {
                if (j >= this.classes.size()) {
                    throw new AssertionError("Unexpected index: " + j);
                }
                String str = (String) this.classes.get(((int) j) - 1);
                float[] floatArray = nDArray.get(new long[]{i2}).toFloatArray();
                float f = floatArray[0];
                Rectangle rectangle = new Rectangle(floatArray[1], f, floatArray[3] - r0, floatArray[2] - f);
                arrayList.add(str);
                arrayList2.add(Double.valueOf(d));
                arrayList3.add(rectangle);
            }
        }
        return new DetectedObjects(arrayList, arrayList2, arrayList3);
    }

    public static Builder builder() {
        return new Builder();
    }
}
