package com.cloudera.oryx.ml.param;

import com.cloudera.oryx.common.OryxTest;
import com.cloudera.oryx.common.settings.ConfigUtils;
import com.typesafe.config.Config;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.Test;

/**
 * Tests the {@link HyperParams} factory methods by checking the trial values that the
 * resulting {@link HyperParamValues} produce for a requested number of trials.
 */
public final class HyperParamsTest extends OryxTest {

    /** A fixed continuous value is expected to yield that single value for any trial count. */
    @Test
    public void testFixedContinuous() {
        doTestContinuous(HyperParams.fixed(3.0), 1, 3.0);
        doTestContinuous(HyperParams.fixed(3.0), 3, 3.0);
    }

    /**
     * A continuous range is expected to yield its midpoint for one trial, and otherwise
     * evenly spaced values that include both endpoints.
     */
    @Test
    public void testContinuousRange() {
        doTestContinuous(HyperParams.range(3.0, 5.0), 1, 4.0);
        doTestContinuous(HyperParams.range(3.0, 5.0), 2, 3.0, 5.0);
        doTestContinuous(HyperParams.range(3.0, 5.0), 4, 3.0, 3.6666666666666665, 4.333333333333333, 5.0);
        doTestContinuous(HyperParams.range(0.0, 1.0), 3, 0.0, 0.5, 1.0);
        doTestContinuous(HyperParams.range(-1.0, 1.0), 5, -1.0, -0.5, 0.0, 0.5, 1.0);
        doTestContinuous(HyperParams.range(-1.0, 1.0), 4, -1.0, -0.3333333333333333, 0.3333333333333333, 1.0);
    }

    /** Checks the continuous values expected around a center value with a given step. */
    @Test
    public void testAroundContinuous() {
        doTestContinuous(HyperParams.around(-3.0, 0.1), 1, -3.0);
        doTestContinuous(HyperParams.around(-3.0, 0.1), 2, -3.05, -2.95);
        doTestContinuous(HyperParams.around(-3.0, 0.1), 3, -3.1, -3.0, -2.9);
    }

    /** A fixed discrete value is expected to yield that single value for any trial count. */
    @Test
    public void testFixedDiscrete() {
        doTest(HyperParams.fixed(3), 1, Collections.singletonList(3));
        doTest(HyperParams.fixed(3), 3, Collections.singletonList(3));
    }

    /**
     * Checks the integer values expected from a discrete range, including the cases where
     * more trials are requested than the range has distinct values.
     */
    @Test
    public void testDiscreteRange() {
        doTest(HyperParams.range(3, 4), 1, Collections.singletonList(3));
        doTest(HyperParams.range(3, 5), 1, Collections.singletonList(4));
        doTest(HyperParams.range(3, 5), 2, Arrays.asList(3, 5));
        doTest(HyperParams.range(3, 5), 3, Arrays.asList(3, 4, 5));
        doTest(HyperParams.range(3, 5), 4, Arrays.asList(3, 4, 5));
        doTest(HyperParams.range(0, 1), 3, Arrays.asList(0, 1));
        doTest(HyperParams.range(-1, 1), 5, Arrays.asList(-1, 0, 1));
        doTest(HyperParams.range(0, 10), 3, Arrays.asList(0, 5, 10));
    }

    /** Checks the integer values expected around a center value with a given step. */
    @Test
    public void testAroundDiscrete() {
        doTest(HyperParams.around(-3, 1), 1, Collections.singletonList(-3));
        doTest(HyperParams.around(-3, 1), 2, Arrays.asList(-3, -2));
        doTest(HyperParams.around(-3, 1), 3, Arrays.asList(-4, -3, -2));
        doTest(HyperParams.around(-3, 10), 2, Arrays.asList(-8, 2));
        doTest(HyperParams.around(-3, 10), 3, Arrays.asList(-13, -3, 7));
    }

    /**
     * Unordered values are expected to come back in the order given, truncated to the
     * requested trial count and never exceeding the number of supplied values.
     */
    @Test
    public void testUnordered() {
        doTest(HyperParams.unorderedFromValues(Arrays.asList("foo", "bar")), 1, Collections.singletonList("foo"));
        doTest(HyperParams.unorderedFromValues(Arrays.asList("foo", "bar")), 2, Arrays.asList("foo", "bar"));
        doTest(HyperParams.unorderedFromValues(Arrays.asList("foo", "bar")), 3, Arrays.asList("foo", "bar"));
    }

    /**
     * Builds hyperparameter values from config keys holding an integer, a double, and
     * bracketed string forms of an integer pair and a double pair.
     */
    @Test
    public void testConfig() {
        Map<String, Object> overlay = new HashMap<>();
        overlay.put("a", 1);
        overlay.put("b", 2.7);
        overlay.put("c", "[3,4]");
        overlay.put("d", "[5.3,6.6]");
        Config config = ConfigUtils.overlayOn(overlay, ConfigUtils.getDefault());
        doTest(HyperParams.fromConfig(config, "a"), 1, Collections.singletonList(1));
        doTest(HyperParams.fromConfig(config, "b"), 1, Collections.singletonList(2.7));
        doTest(HyperParams.fromConfig(config, "c"), 2, Arrays.asList(3, 4));
        doTest(HyperParams.fromConfig(config, "d"), 2, Arrays.asList(5.3, 6.6));
    }

    /**
     * Asserts that the given hyperparameter values produce exactly the expected trial values,
     * and that their string form is non-null.
     *
     * @param hyperParams values under test
     * @param howMany number of trial values to request
     * @param expected trial values expected back, in order
     */
    private static void doTest(HyperParamValues<?> hyperParams, int howMany, List<?> expected) {
        assertEquals(expected, hyperParams.getTrialValues(howMany));
        assertNotNull(hyperParams.toString());
    }

    /**
     * Asserts that the given continuous hyperparameter values produce the expected trial
     * values, compared as {@code double} arrays, and that their string form is non-null.
     *
     * @param hyperParams values under test
     * @param howMany number of trial values to request
     * @param expected trial values expected back, in order
     */
    private static void doTestContinuous(HyperParamValues<Double> hyperParams, int howMany, double... expected) {
        List<Double> trials = hyperParams.getTrialValues(howMany);
        // Unbox into a primitive array so the comparison goes through the double[] overload.
        double[] actual = new double[trials.size()];
        for (int i = 0; i < actual.length; i++) {
            actual[i] = trials.get(i);
        }
        assertArrayEquals(expected, actual);
        assertNotNull(hyperParams.toString());
    }
}
