package com.amazon.randomcutforest;

import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.profilers.ObjectGraphSizeProfiler;
import com.amazon.randomcutforest.profilers.OutputSizeProfiler;
import com.amazon.randomcutforest.state.RandomCutForestMapper;
import com.amazon.randomcutforest.state.RandomCutForestState;
import com.amazon.randomcutforest.testutils.NormalMixtureTestData;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.protostuff.LinkedBuffer;
import io.protostuff.ProtostuffIOUtil;
import io.protostuff.Schema;
import io.protostuff.runtime.RuntimeSchema;
import java.nio.charset.StandardCharsets;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.OperationsPerInvocation;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;

/**
 * JMH benchmarks comparing the cost of round-tripping a {@link RandomCutForest} through its
 * {@link RandomCutForestState} representation: directly in memory, as a JSON string (Jackson), and
 * as a protostuff byte array.
 *
 * <p>Each benchmark method performs {@link #NUM_TEST_SAMPLES} round trips per invocation. In each
 * round trip the forest is rebuilt from the current state, scores and absorbs one test point, and is
 * converted back to state (and serialized, where applicable) for the next round trip.
 */
@Warmup(iterations = 2)
@Measurement(iterations = 5)
@State(Scope.Benchmark)
@Fork(1)
public class StateMapperBenchmark {
    /** Number of points used to train the forest before each invocation. */
    public static final int NUM_TRAIN_SAMPLES = 2048;

    /** Number of round trips (one per test point) performed by each benchmark invocation. */
    public static final int NUM_TEST_SAMPLES = 50;

    /** Initial size of the protostuff {@link LinkedBuffer} used for serialization. */
    private static final int PROTOSTUFF_BUFFER_SIZE = 512;

    /**
     * Shared Jackson mapper. An ObjectMapper is thread-safe once configured and caches its
     * (de)serializers per instance. Reusing one instance keeps Jackson's one-time introspection cost
     * out of the measured round trips, instead of re-paying it for every sample.
     */
    private static final ObjectMapper JSON_MAPPER = new ObjectMapper();

    /** Protostuff runtime schema for the forest state, looked up once. */
    private static final Schema<RandomCutForestState> STATE_SCHEMA =
            RuntimeSchema.getSchema(RandomCutForestState.class);

    /** Forest produced by the most recent round trip; reported to the object-graph size profiler. */
    private RandomCutForest forest;

    /** Serialized output of the most recent round trip; reported to the output size profiler. */
    private byte[] bytes;

    /** Per-thread inputs: training/test data and the initial forest state in every format. */
    @State(Scope.Thread)
    public static class BenchmarkState {

        @Param({"10"})
        int dimensions;

        @Param({"50"})
        int numberOfTrees;

        @Param({"256"})
        int sampleSize;

        @Param({"false", "true"})
        boolean saveTreeState;

        @Param({"FLOAT_32", "FLOAT_64"})
        Precision precision;

        double[][] trainingData;
        double[][] testData;
        RandomCutForestState forestState;
        String json;
        byte[] protostuff;

        /** Generates the training and test points once per trial. */
        @Setup(Level.Trial)
        public void setUpData() {
            NormalMixtureTestData generator = new NormalMixtureTestData();
            trainingData = generator.generateTestData(NUM_TRAIN_SAMPLES, dimensions);
            testData = generator.generateTestData(NUM_TEST_SAMPLES, dimensions);
        }

        /**
         * Trains a freshly built forest and captures its state as an object, as JSON, and as
         * protostuff bytes, so every invocation starts from a newly trained forest.
         *
         * @throws JsonProcessingException if the state cannot be written as JSON
         */
        @Setup(Level.Invocation)
        public void setUpForest() throws JsonProcessingException {
            RandomCutForest trained = RandomCutForest.builder()
                    .compact(true)
                    .dimensions(dimensions)
                    .numberOfTrees(numberOfTrees)
                    .sampleSize(sampleSize)
                    .precision(precision)
                    .boundingBoxCacheFraction(0.0)
                    .build();
            for (int i = 0; i < NUM_TRAIN_SAMPLES; i++) {
                trained.update(trainingData[i]);
            }
            forestState = newMapper(saveTreeState).toState(trained);
            json = JSON_MAPPER.writeValueAsString(forestState);
            protostuff = toProtostuff(forestState);
        }
    }

    /**
     * Hands the latest round-trip outputs to the custom profilers after each measurement iteration.
     *
     * <p>NOTE(review): {@link #roundTripFromState} never assigns {@code bytes}. For that benchmark the
     * output size profiler therefore receives whatever the field already holds (presumably null).
     * Confirm the profiler tolerates this.
     */
    @TearDown(Level.Iteration)
    public void tearDown() {
        OutputSizeProfiler.setTestArray(bytes);
        ObjectGraphSizeProfiler.setObject(forest);
    }

    /**
     * Round trips through {@link RandomCutForestState} objects only, with no serialization.
     *
     * @return the state after the final round trip
     */
    @Benchmark
    @OperationsPerInvocation(NUM_TEST_SAMPLES)
    public RandomCutForestState roundTripFromState(BenchmarkState benchmarkState, Blackhole blackhole) {
        RandomCutForestState forestState = benchmarkState.forestState;
        double[][] testData = benchmarkState.testData;
        for (int i = 0; i < NUM_TEST_SAMPLES; i++) {
            RandomCutForestMapper mapper = newMapper(benchmarkState.saveTreeState);
            forest = mapper.toModel(forestState);
            blackhole.consume(forest.getAnomalyScore(testData[i]));
            forest.update(testData[i]);
            forestState = mapper.toState(forest);
        }
        return forestState;
    }

    /**
     * Round trips through JSON: each step parses the previous JSON, updates the forest, and
     * re-serializes it.
     *
     * @return the JSON produced by the final round trip
     * @throws JsonProcessingException if Jackson fails to read or write the state
     */
    @Benchmark
    @OperationsPerInvocation(NUM_TEST_SAMPLES)
    public String roundTripFromJson(BenchmarkState benchmarkState, Blackhole blackhole)
            throws JsonProcessingException {
        String json = benchmarkState.json;
        double[][] testData = benchmarkState.testData;
        for (int i = 0; i < NUM_TEST_SAMPLES; i++) {
            RandomCutForestState forestState = JSON_MAPPER.readValue(json, RandomCutForestState.class);
            RandomCutForestMapper mapper = newMapper(benchmarkState.saveTreeState);
            forest = mapper.toModel(forestState);
            blackhole.consume(forest.getAnomalyScore(testData[i]));
            forest.update(testData[i]);
            json = JSON_MAPPER.writeValueAsString(mapper.toState(forest));
        }
        // Explicit charset: the platform default would make the reported output size host-dependent.
        bytes = json.getBytes(StandardCharsets.UTF_8);
        return json;
    }

    /**
     * Round trips through protostuff: each step decodes the previous bytes, updates the forest, and
     * re-encodes it.
     *
     * @return the bytes produced by the final round trip
     */
    @Benchmark
    @OperationsPerInvocation(NUM_TEST_SAMPLES)
    public byte[] roundTripFromProtostuff(BenchmarkState benchmarkState, Blackhole blackhole) {
        bytes = benchmarkState.protostuff;
        double[][] testData = benchmarkState.testData;
        for (int i = 0; i < NUM_TEST_SAMPLES; i++) {
            RandomCutForestState forestState = fromProtostuff(bytes);
            RandomCutForestMapper mapper = newMapper(benchmarkState.saveTreeState);
            forest = mapper.toModel(forestState);
            blackhole.consume(forest.getAnomalyScore(testData[i]));
            forest.update(testData[i]);
            bytes = toProtostuff(mapper.toState(forest));
        }
        return bytes;
    }

    /** Creates a forest mapper configured identically for every benchmark path. */
    private static RandomCutForestMapper newMapper(boolean saveTreeState) {
        RandomCutForestMapper mapper = new RandomCutForestMapper();
        mapper.setSaveExecutorContextEnabled(true);
        mapper.setSaveTreeStateEnabled(saveTreeState);
        return mapper;
    }

    /** Serializes a forest state to protostuff bytes, always clearing the scratch buffer. */
    private static byte[] toProtostuff(RandomCutForestState state) {
        LinkedBuffer buffer = LinkedBuffer.allocate(PROTOSTUFF_BUFFER_SIZE);
        try {
            return ProtostuffIOUtil.toByteArray(state, STATE_SCHEMA, buffer);
        } finally {
            buffer.clear();
        }
    }

    /** Decodes protostuff bytes produced by {@link #toProtostuff} back into a forest state. */
    private static RandomCutForestState fromProtostuff(byte[] data) {
        RandomCutForestState state = STATE_SCHEMA.newMessage();
        ProtostuffIOUtil.mergeFrom(data, state, STATE_SCHEMA);
        return state;
    }
}
