package ws.palladian.classification.evaluation.roc;

import java.awt.Color;
import java.awt.Dimension;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.ChartUtilities;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.axis.NumberTickUnit;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import org.jfree.ui.ApplicationFrame;
import org.jfree.ui.RefineryUtilities;
import ws.palladian.classification.evaluation.roc.RocCurves;

/* loaded from: input_file:ws/palladian/classification/evaluation/roc/RocCurvesPainter.class */
/**
 * Paints one or more {@link RocCurves} into a single JFreeChart XY line chart, either in a Swing
 * window ({@link #showCurves()}) or as a PNG file ({@link #saveCurves(File)}).
 *
 * <p>Besides the added curves, the chart always contains a dark gray "Random" diagonal from
 * (0, 0) to (1, 1) as a baseline. Each added curve is plotted as false positive rate
 * (1 - specificity) against true positive rate (sensitivity), and its legend entry shows the
 * given name together with the area under the curve.
 *
 * <p>Example: {@code new RocCurvesPainter().add(a, "model A").add(b, "model B").showCurves();}
 */
public class RocCurvesPainter {
    /** Palette cycled through for the added curves; the baseline is drawn separately in dark gray. */
    private static final Color[] COLORS = {Color.red, Color.pink, Color.orange, Color.yellow, Color.green, Color.magenta, Color.cyan, Color.blue};

    /** Width in pixels of the saved PNG and the preferred width of the window. */
    private static final int CHART_WIDTH = 800;

    /** Height in pixels of the saved PNG and the preferred height of the window. */
    private static final int CHART_HEIGHT = 600;

    /** Curves to paint, in insertion order; this order determines series index and color. */
    private final List<NamedRocCurve> curves = new ArrayList<>();

    /** A ROC curve paired with the name shown for it in the chart legend. */
    private static final class NamedRocCurve {
        private final RocCurves curve;
        private final String name;

        NamedRocCurve(RocCurves curve, String name) {
            this.curve = curve;
            this.name = name;
        }
    }

    /**
     * Adds a curve to be painted.
     *
     * @param rocCurves the curve to paint; must not be null
     * @param name the legend name for the curve; must not be null
     * @return this painter, to allow call chaining
     * @throws NullPointerException if either argument is null
     */
    public RocCurvesPainter add(RocCurves rocCurves, String name) {
        Objects.requireNonNull(rocCurves, "curves must not be null");
        Objects.requireNonNull(name, "name must not be null");
        this.curves.add(new NamedRocCurve(rocCurves, name));
        return this;
    }

    /** Builds the complete chart: dataset, white styling, per-series renderer and [0, 1] axes. */
    private JFreeChart createChart() {
        JFreeChart chart = ChartFactory.createXYLineChart("ROC Curves", "False Positive Rate (1 – Specificity)", "True Positive Rate (Sensitivity)", createDataset(), PlotOrientation.VERTICAL, true, true, false);
        chart.setBackgroundPaint(Color.white);
        XYPlot plot = chart.getXYPlot();
        plot.setBackgroundPaint(Color.white);
        // Gridlines use the background color, so they are effectively hidden.
        plot.setDomainGridlinePaint(Color.white);
        plot.setRangeGridlinePaint(Color.white);
        plot.setRenderer(createRenderer());
        // XYPlot returns ValueAxis; createXYLineChart installs NumberAxis instances, so the casts
        // are safe and needed for setTickUnit(NumberTickUnit).
        configureUnitAxis((NumberAxis) plot.getRangeAxis());
        configureUnitAxis((NumberAxis) plot.getDomainAxis());
        return chart;
    }

    /** Creates the dataset: the "Random" baseline as series 0, then one series per added curve. */
    private XYSeriesCollection createDataset() {
        XYSeriesCollection dataset = new XYSeriesCollection();
        XYSeries baseline = new XYSeries("Random");
        baseline.add(0.0d, 0.0d);
        baseline.add(1.0d, 1.0d);
        dataset.addSeries(baseline);
        for (NamedRocCurve namedCurve : this.curves) {
            dataset.addSeries(toSeries(namedCurve));
        }
        return dataset;
    }

    /** Converts a named curve into an XY series of (1 - specificity, sensitivity) points. */
    private static XYSeries toSeries(NamedRocCurve namedCurve) {
        String label = namedCurve.name + " [AUC = " + RocCurves.format(namedCurve.curve.getAreaUnderCurve()) + "]";
        XYSeries series = new XYSeries(label);
        Iterator<RocCurves.EvaluationPoint> points = namedCurve.curve.iterator();
        while (points.hasNext()) {
            RocCurves.EvaluationPoint point = points.next();
            series.add(1.0d - point.getSpecificity(), point.getSensitivity());
        }
        return series;
    }

    /** Creates a lines-only renderer: dark gray baseline, palette colors for the added curves. */
    private XYLineAndShapeRenderer createRenderer() {
        XYLineAndShapeRenderer renderer = new XYLineAndShapeRenderer();
        styleSeries(renderer, 0, Color.darkGray);
        for (int i = 0; i < this.curves.size(); i++) {
            // Offset by one because series 0 is the baseline; colors wrap around the palette.
            styleSeries(renderer, i + 1, COLORS[i % COLORS.length]);
        }
        return renderer;
    }

    /** Makes the given series a visible line without point shapes, drawn in the given color. */
    private static void styleSeries(XYLineAndShapeRenderer renderer, int series, Color color) {
        renderer.setSeriesLinesVisible(series, true);
        renderer.setSeriesShapesVisible(series, false);
        renderer.setSeriesPaint(series, color);
    }

    /** Fixes an axis to the range [0, 1] with ticks every 0.1. */
    private static void configureUnitAxis(NumberAxis axis) {
        axis.setRange(0.0d, 1.0d);
        axis.setTickUnit(new NumberTickUnit(0.1d));
    }

    /**
     * Shows the chart in a new, centered {@link ApplicationFrame} titled "ROC".
     *
     * <p>NOTE(review): the Swing components are created on the calling thread. Swing's threading
     * rules expect this on the event dispatch thread; consider wrapping in
     * {@code EventQueue.invokeLater} if this is called from non-EDT code.
     */
    public void showCurves() {
        ChartPanel chartPanel = new ChartPanel(createChart());
        chartPanel.setPreferredSize(new Dimension(CHART_WIDTH, CHART_HEIGHT));
        ApplicationFrame applicationFrame = new ApplicationFrame("ROC");
        applicationFrame.setContentPane(chartPanel);
        applicationFrame.pack();
        RefineryUtilities.centerFrameOnScreen(applicationFrame);
        applicationFrame.setVisible(true);
    }

    /**
     * Saves the chart as a PNG image of {@value #CHART_WIDTH} x {@value #CHART_HEIGHT} pixels.
     *
     * @param file the destination file
     * @throws IOException if writing the file fails
     */
    public void saveCurves(File file) throws IOException {
        ChartUtilities.saveChartAsPNG(file, createChart(), CHART_WIDTH, CHART_HEIGHT);
    }
}
