package com.amazon.randomcutforest;

import com.amazon.randomcutforest.returntypes.DensityOutput;
import com.amazon.randomcutforest.returntypes.DiVector;
import com.amazon.randomcutforest.testutils.NormalMixtureTestData;
import java.util.List;
import java.util.Random;
import org.github.jamm.MemoryMeter;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.OperationsPerInvocation;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;

@Warmup(iterations = 2)
@Measurement(iterations = 5)
@State(Scope.Thread)
@Fork(1)
/* loaded from: input_file:com/amazon/randomcutforest/RandomCutForestBenchmark.class */
public class RandomCutForestBenchmark {
    public static final int DATA_SIZE = 50000;
    public static final int INITIAL_DATA_SIZE = 25000;
    private RandomCutForest forest;

    @State(Scope.Benchmark)
    /* loaded from: input_file:com/amazon/randomcutforest/RandomCutForestBenchmark$BenchmarkState.class */
    public static class BenchmarkState {

        @Param({"40"})
        int baseDimensions;

        @Param({"1"})
        int shingleSize;

        @Param({"30"})
        int numberOfTrees;

        @Param({"1.0", "0.9", "0.8", "0.7", "0.6", "0.5", "0.4", "0.3", "0.2", "0.1", "0.0"})
        double boundingBoxCacheFraction;

        @Param({"false", "true"})
        boolean parallel;
        double[][] data;
        RandomCutForest forest;

        @Setup(Level.Trial)
        public void setUpData() {
            this.data = new NormalMixtureTestData().generateTestData(75000, this.baseDimensions * this.shingleSize);
        }

        @Setup(Level.Invocation)
        public void setUpForest() {
            this.forest = RandomCutForest.builder().numberOfTrees(this.numberOfTrees).dimensions(this.baseDimensions * this.shingleSize).internalShinglingEnabled(true).shingleSize(this.shingleSize).parallelExecutionEnabled(this.parallel).boundingBoxCacheFraction(this.boundingBoxCacheFraction).randomSeed(99L).build();
            for (int i = 0; i < 25000; i++) {
                this.forest.update(this.data[i]);
            }
        }
    }

    @Benchmark
    @OperationsPerInvocation(50000)
    public RandomCutForest updateOnly(BenchmarkState benchmarkState) {
        double[][] dArr = benchmarkState.data;
        this.forest = benchmarkState.forest;
        for (int i = 25000; i < dArr.length; i++) {
            this.forest.update(dArr[i]);
        }
        return this.forest;
    }

    @Benchmark
    @OperationsPerInvocation(50000)
    public RandomCutForest scoreOnly(BenchmarkState benchmarkState, Blackhole blackhole) {
        double[][] dArr = benchmarkState.data;
        this.forest = benchmarkState.forest;
        double d = 0.0d;
        Random random = new Random(0L);
        for (int i = 25000; i < dArr.length; i++) {
            d += this.forest.getAnomalyScore(dArr[i]);
            if (random.nextDouble() < 0.01d) {
                this.forest.update(dArr[i]);
            }
        }
        blackhole.consume(d);
        return this.forest;
    }

    @Benchmark
    @OperationsPerInvocation(50000)
    public RandomCutForest scoreAndUpdate(BenchmarkState benchmarkState, Blackhole blackhole) {
        double[][] dArr = benchmarkState.data;
        this.forest = benchmarkState.forest;
        double d = 0.0d;
        for (int i = 25000; i < dArr.length; i++) {
            d = this.forest.getAnomalyScore(dArr[i]);
            this.forest.update(dArr[i]);
        }
        blackhole.consume(d);
        if (!this.forest.parallelExecutionEnabled) {
            System.out.println(" forest size " + new MemoryMeter().measureDeep(this.forest));
        }
        return this.forest;
    }

    @Benchmark
    @OperationsPerInvocation(50000)
    public RandomCutForest attributionAndUpdate(BenchmarkState benchmarkState, Blackhole blackhole) {
        double[][] dArr = benchmarkState.data;
        this.forest = benchmarkState.forest;
        DiVector diVector = new DiVector(this.forest.getDimensions());
        for (int i = 25000; i < dArr.length; i++) {
            diVector = this.forest.getAnomalyAttribution(dArr[i]);
            this.forest.update(dArr[i]);
        }
        blackhole.consume(diVector);
        return this.forest;
    }

    @Benchmark
    @OperationsPerInvocation(50000)
    public RandomCutForest basicDensityAndUpdate(BenchmarkState benchmarkState, Blackhole blackhole) {
        double[][] dArr = benchmarkState.data;
        this.forest = benchmarkState.forest;
        DensityOutput densityOutput = new DensityOutput(this.forest.getDimensions(), this.forest.getSampleSize());
        for (int i = 25000; i < dArr.length; i++) {
            densityOutput = this.forest.getSimpleDensity(dArr[i]);
            this.forest.update(dArr[i]);
        }
        blackhole.consume(densityOutput);
        return this.forest;
    }

    @Benchmark
    @OperationsPerInvocation(50000)
    public RandomCutForest basicNeighborAndUpdate(BenchmarkState benchmarkState, Blackhole blackhole) {
        double[][] dArr = benchmarkState.data;
        this.forest = benchmarkState.forest;
        List list = null;
        for (int i = 25000; i < dArr.length; i++) {
            list = this.forest.getNearNeighborsInSample(dArr[i]);
            this.forest.update(dArr[i]);
        }
        blackhole.consume(list);
        return this.forest;
    }

    @Benchmark
    @OperationsPerInvocation(50000)
    public RandomCutForest imputeAndUpdate(BenchmarkState benchmarkState, Blackhole blackhole) {
        double[][] dArr = benchmarkState.data;
        this.forest = benchmarkState.forest;
        double[] dArr2 = null;
        for (int i = 25000; i < dArr.length; i++) {
            dArr2 = this.forest.imputeMissingValues(dArr[i], 1, new int[]{this.forest.dimensions - 1});
            this.forest.update(dArr[i]);
        }
        blackhole.consume(dArr2);
        return this.forest;
    }
}
