package ai.djl.training.metrics;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;

/* loaded from: input_file:ai/djl/training/metrics/Accuracy.class */
/**
 * A {@link TrainingMetric} that reports the fraction of predictions matching their labels.
 *
 * <p>Correct and total counts accumulate across calls to {@code update} until {@link #reset()}
 * is called. {@link #getValue()} returns {@link Float#NaN} when nothing has been counted yet.
 */
public class Accuracy extends TrainingMetric {
    /** Running count of predictions equal to their label. */
    private long correctInstances;
    /** Running count of label elements seen. */
    private long totalInstances;
    /**
     * Axis passed to {@code argMax} on the predictions when their shape differs from the labels'.
     */
    protected int axis;
    /** Position of the label/prediction pair to evaluate within the {@link NDList} inputs. */
    protected int index;

    /**
     * Creates an accuracy metric.
     *
     * @param name the metric name, forwarded to {@link TrainingMetric}
     * @param index the position within the label and prediction {@link NDList}s to evaluate
     * @param axis the axis reduced with {@code argMax} when predictions and labels differ in shape
     */
    public Accuracy(String name, int index, int axis) {
        super(name);
        this.axis = axis;
        this.index = index;
    }

    /** Creates an accuracy metric named {@code "Accuracy"} using index 0 and axis 1. */
    public Accuracy() {
        this("Accuracy", 0, 1);
    }

    /**
     * Creates an accuracy metric that reduces predictions along axis 1.
     *
     * @param name the metric name, forwarded to {@link TrainingMetric}
     * @param index the position within the label and prediction {@link NDList}s to evaluate
     */
    public Accuracy(String name, int index) {
        this(name, index, 1);
    }

    /** Clears the accumulated correct and total counts. */
    @Override
    public void reset() {
        this.correctInstances = 0L;
        this.totalInstances = 0L;
    }

    /**
     * Accumulates accuracy statistics for one batch.
     *
     * <p>If {@code predictions} has a different shape than {@code labels}, it is reduced with
     * {@code argMax(axis)}; otherwise it is compared as-is. Both sides are cast to
     * {@link DataType#INT64} before the element-wise equality test.
     *
     * @param labels the ground-truth labels
     * @param predictions the model output (class scores, or already-reduced class ids)
     */
    public void update(NDArray labels, NDArray predictions) {
        // Validation is delegated to TrainingMetric (not visible here).
        checkLabelShapes(labels, predictions);
        // Differing shapes are presumed to mean per-class scores that must be reduced to ids.
        NDArray predictedIds =
                labels.getShape().equals(predictions.getShape())
                        ? predictions
                        : predictions.argMax(this.axis);
        long correct =
                labels.asType(DataType.INT64, false)
                        .eq(predictedIds.asType(DataType.INT64, false))
                        .countNonzero()
                        .getLong();
        addCorrectInstances(correct);
        addTotalInstances(labels.size());
    }

    /**
     * Accumulates accuracy statistics for the pair at {@link #index} of the given lists.
     *
     * @param labels the ground-truth label arrays
     * @param predictions the prediction arrays
     * @throws IllegalArgumentException if the two lists differ in size
     */
    @Override
    public void update(NDList labels, NDList predictions) {
        if (labels.size() != predictions.size()) {
            throw new IllegalArgumentException("labels and prediction length does not match.");
        }
        update(labels.get(this.index), predictions.get(this.index));
    }

    /**
     * Returns the accumulated accuracy.
     *
     * @return {@code correct / total}, or {@link Float#NaN} if no instances have been counted
     */
    @Override
    public float getValue() {
        if (this.totalInstances == 0) {
            return Float.NaN;
        }
        // Divide in double and narrow once. Converting each long count to float first loses
        // exactness above 2^24 (e.g. 16777217/16777218 would round to exactly 1.0f).
        return (float) ((double) this.correctInstances / this.totalInstances);
    }

    /**
     * Adds to the running count of correct predictions.
     *
     * @param count the number of additional correct predictions
     */
    public void addCorrectInstances(long count) {
        this.correctInstances += count;
    }

    /**
     * Adds to the running count of evaluated instances.
     *
     * @param count the number of additional instances
     */
    public void addTotalInstances(long count) {
        this.totalInstances += count;
    }
}
