package ai.djl.training.loss;

import ai.djl.modality.cv.MultiBoxTarget;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/* loaded from: input_file:ai/djl/training/loss/SingleShotDetectionLoss.class */
/**
 * Loss for single-shot detection (SSD) models: the sum of a softmax cross-entropy loss over the
 * class predictions and an L1 loss over the masked bounding-box predictions.
 *
 * <p>Targets for both terms are derived from the ground-truth labels by a {@link MultiBoxTarget}
 * built with its default builder settings.
 */
public class SingleShotDetectionLoss extends Loss {
    // Assigned exactly once in the constructor and never reassigned, hence final.
    private final Loss softmaxLoss;
    private final Loss l1Loss;
    private final MultiBoxTarget multiBoxTarget;

    /**
     * Creates the loss with the given name.
     *
     * @param name the name of this loss, forwarded to the {@link Loss} constructor
     */
    public SingleShotDetectionLoss(String name) {
        super(name);
        this.softmaxLoss = Loss.softmaxCrossEntropyLoss();
        this.l1Loss = Loss.l1Loss();
        this.multiBoxTarget = new MultiBoxTarget.Builder().build();
    }

    /**
     * Computes the combined classification and bounding-box loss.
     *
     * @param labels ground-truth labels; only {@code labels.head()} is used, as the label input
     *     to {@link MultiBoxTarget#target}
     * @param predictions model outputs in the order (anchors, class predictions, bounding-box
     *     predictions); the class predictions must be rank 3, since axes 1 and 2 are swapped
     *     before target assignment
     * @return the softmax cross-entropy class loss plus the masked L1 bounding-box loss
     */
    @Override // ai.djl.training.loss.Loss
    public NDArray getLoss(NDList labels, NDList predictions) {
        NDArray anchors = predictions.get(0);
        NDArray classPredictions = predictions.get(1);
        NDArray boxPredictions = predictions.get(2);

        // MultiBoxTarget receives the class predictions with their last two axes swapped.
        // NOTE(review): presumably it expects (batch, classes, anchors) — confirm.
        NDList targets =
                multiBoxTarget.target(
                        new NDList(anchors, labels.head(), classPredictions.transpose(0, 2, 1)));
        NDArray boxTargets = targets.get(0);
        NDArray boxMasks = targets.get(1);
        NDArray classTargets = targets.get(2);

        NDArray classLoss =
                softmaxLoss.getLoss(new NDList(classTargets), new NDList(classPredictions));
        // Both sides are multiplied by the same mask: where the mask is zero, target and
        // prediction both become zero and so add nothing to the L1 term.
        NDArray boxLoss =
                l1Loss.getLoss(
                        new NDList(boxTargets.mul(boxMasks)),
                        new NDList(boxPredictions.mul(boxMasks)));
        return classLoss.add(boxLoss);
    }
}
