package ai.djl.modality.cv;

import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Joints;
import ai.djl.modality.cv.output.Mask;
import ai.djl.modality.cv.output.Point;
import ai.djl.modality.cv.util.BufferedImageUtils;
import ai.djl.util.RandomUtils;
import java.awt.BasicStroke;
import java.awt.Color;
import java.awt.FontMetrics;
import java.awt.Graphics2D;
import java.awt.Rectangle;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;
import java.awt.image.ImageObserver;

/**
 * Static helpers that render model outputs (detected objects, masks, joints) directly onto a
 * {@link BufferedImage}.
 *
 * <p>All coordinates read from the output objects are multiplied by the image width/height, so
 * they are expected to be fractions of the image size.
 */
public final class ImageVisualization {

    /** Stroke width, in pixels, for outlines; labels are also offset by half of it. */
    private static final int STROKE_WIDTH = 2;

    /** Horizontal padding, in pixels, on each side of a label's text. */
    private static final int LABEL_PADDING = 4;

    /** Width and height, in pixels, of the filled circle drawn for each joint. */
    private static final int JOINT_DIAMETER = 10;

    /** Mask probabilities below this value are rendered fully transparent. */
    private static final double MASK_MIN_PROBABILITY = 0.1;

    /** Upper bound on mask alpha, so the image under a mask is never fully hidden. */
    private static final float MASK_MAX_OPACITY = 0.8f;

    private ImageVisualization() {
        // static utility class; not instantiable
    }

    /**
     * Draws every detected object onto {@code bufferedImage}, modifying it in place.
     *
     * <p>For each object this picks a random darkened color, draws the bounding box via {@link
     * BoundingBox#draw}, draws a filled label containing the class name at the box's {@code
     * getPoint()} location, and, if the box is a {@link Mask}, overlays the mask.
     *
     * @param bufferedImage the image to draw on; it is modified
     * @param detectedObjects the detections to render
     */
    public static void drawBoundingBoxes(BufferedImage bufferedImage, DetectedObjects detectedObjects) {
        // createGraphics() returns Graphics2D directly; getGraphics() is typed as plain Graphics.
        Graphics2D graphics = bufferedImage.createGraphics();
        try {
            graphics.setStroke(new BasicStroke(STROKE_WIDTH));
            graphics.setRenderingHint(
                    RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
            int width = bufferedImage.getWidth();
            int height = bufferedImage.getHeight();
            for (DetectedObjects.DetectedObject detectedObject : detectedObjects.items()) {
                String className = detectedObject.getClassName();
                BoundingBox boundingBox = detectedObject.getBoundingBox();
                graphics.setPaint(BufferedImageUtils.randomColor().darker());
                boundingBox.draw(graphics, width, height);
                Point point = boundingBox.getPoint();
                int labelX = (int) (point.getX() * width);
                int labelY = (int) (point.getY() * height);
                drawText(graphics, className, labelX, labelY, STROKE_WIDTH, LABEL_PADDING);
                if (boundingBox instanceof Mask) {
                    drawMask(bufferedImage, (Mask) boundingBox);
                }
            }
        } finally {
            graphics.dispose();
        }
    }

    /**
     * Overlays {@code mask} on {@code bufferedImage} in one random color whose per-pixel alpha is
     * the mask probability. Probabilities below 0.1 become fully transparent, and alpha is capped
     * at 0.8.
     *
     * <p>The probability grid is drawn unscaled, one image pixel per cell. {@code probDist[i][j]}
     * maps to column {@code i} and row {@code j} relative to the mask origin. A negative origin is
     * clamped to 0. An empty grid draws nothing.
     *
     * @param bufferedImage the image to draw on; it is modified
     * @param mask the mask whose origin and probability grid are drawn
     */
    private static void drawMask(BufferedImage bufferedImage, Mask mask) {
        float red = RandomUtils.nextFloat();
        float green = RandomUtils.nextFloat();
        float blue = RandomUtils.nextFloat();

        float[][] probDist = mask.getProbDist();
        if (probDist.length == 0 || probDist[0].length == 0) {
            // BufferedImage rejects zero-sized dimensions, and there is nothing to draw anyway.
            return;
        }

        int width = bufferedImage.getWidth();
        int height = bufferedImage.getHeight();
        // Correct coordinates that fall outside the image on the top/left side.
        int x = Math.max(0, (int) (mask.getX() * width));
        int y = Math.max(0, (int) (mask.getY() * height));

        BufferedImage maskImage =
                new BufferedImage(probDist.length, probDist[0].length, BufferedImage.TYPE_INT_ARGB);
        for (int i = 0; i < probDist.length; i++) {
            for (int j = 0; j < probDist[i].length; j++) {
                float opacity = probDist[i][j];
                if (opacity < MASK_MIN_PROBABILITY) {
                    opacity = 0f;
                } else if (opacity > MASK_MAX_OPACITY) {
                    opacity = MASK_MAX_OPACITY;
                }
                maskImage.setRGB(i, j, new Color(red, green, blue, opacity).darker().getRGB());
            }
        }

        Graphics2D graphics = bufferedImage.createGraphics();
        try {
            graphics.drawImage(maskImage, x, y, null);
        } finally {
            graphics.dispose();
        }
    }

    /**
     * Draws {@code text} in white on a background rectangle filled with the graphics' current
     * paint.
     *
     * @param graphics the target; its paint is left set to {@link Color#WHITE} afterwards
     * @param text the label text
     * @param x left edge of the label before the stroke offset
     * @param y top edge of the label before the stroke offset
     * @param stroke outline stroke width; the label is shifted right/down by half of it
     *     (presumably so it does not sit on top of the box outline)
     * @param padding horizontal padding added on each side of the text
     */
    private static void drawText(
            Graphics2D graphics, String text, int x, int y, int stroke, int padding) {
        FontMetrics metrics = graphics.getFontMetrics();
        int left = x + stroke / 2;
        int top = y + stroke / 2;
        int backgroundWidth = metrics.stringWidth(text) + padding * 2 - stroke / 2;
        int backgroundHeight = metrics.getHeight() + metrics.getDescent();
        graphics.fill(new Rectangle(left, top, backgroundWidth, backgroundHeight));
        graphics.setPaint(Color.WHITE);
        graphics.drawString(text, left + padding, top + metrics.getAscent());
    }

    /**
     * Draws each joint as a filled circle in a random darkened color onto {@code bufferedImage}.
     *
     * @param bufferedImage the image to draw on; it is modified
     * @param joints the joints to render
     */
    public static void drawJoints(BufferedImage bufferedImage, Joints joints) {
        Graphics2D graphics = bufferedImage.createGraphics();
        try {
            graphics.setStroke(new BasicStroke(STROKE_WIDTH));
            int width = bufferedImage.getWidth();
            int height = bufferedImage.getHeight();
            for (Joints.Joint joint : joints.getJoints()) {
                graphics.setPaint(BufferedImageUtils.randomColor().darker());
                // fillOval takes the top-left corner of the circle's bounding square, so the
                // marker starts at the joint position rather than being centered on it.
                graphics.fillOval(
                        (int) (joint.getX() * width),
                        (int) (joint.getY() * height),
                        JOINT_DIAMETER,
                        JOINT_DIAMETER);
            }
        } finally {
            graphics.dispose();
        }
    }
}
