package ca.aqtech.mctreesearch4j;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import kotlin.Metadata;
import kotlin._Assertions;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.Intrinsics;
import kotlin.random.Random;
import org.jetbrains.annotations.NotNull;

/* compiled from: StatefulSolver.kt */
@Metadata(mv = {1, 4, 0}, bv = {1, 0, 3}, k = 1, d1 = {"��6\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n��\n\u0002\u0010\u0006\n\u0002\b\u0002\n\u0002\u0010\u000b\n\u0002\b\r\n\u0002\u0010\u0002\n\u0002\b\u000b\b\u0016\u0018��*\u0004\b��\u0010\u0001*\u0004\b\u0001\u0010\u00022\u001a\u0012\u0004\u0012\u0002H\u0002\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u0002H\u0001\u0012\u0004\u0012\u0002H\u00020\u00040\u0003B9\u0012\u0012\u0010\u0005\u001a\u000e\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u00010\u0006\u0012\u0006\u0010\u0007\u001a\u00020\b\u0012\u0006\u0010\t\u001a\u00020\n\u0012\u0006\u0010\u000b\u001a\u00020\n\u0012\u0006\u0010\f\u001a\u00020\r¢\u0006\u0002\u0010\u000eJ$\u0010\u001a\u001a\u00020\u001b2\u0012\u0010\u001c\u001a\u000e\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u00010\u00042\u0006\u0010\u001d\u001a\u00020\nH\u0016JA\u0010\u001e\u001a\u000e\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u00010\u00042\u0014\u0010\u001f\u001a\u0010\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u0001\u0018\u00010\u00042\b\u0010 \u001a\u0004\u0018\u00018\u00012\u0006\u0010!\u001a\u00028��H\u0002¢\u0006\u0002\u0010\"J(\u0010#\u001a\u000e\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u00010\u00042\u0012\u0010\u001c\u001a\u000e\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u00010\u0004H\u0016J(\u0010$\u001a\u000e\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u00010\u00042\u0012\u0010\u001c\u001a\u000e\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u00010\u0004H\u0016J\u001c\u0010%\u001a\u00020\n2\u0012\u0010\u001c\u001a\u000e\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u00010\u0004H\u0016R 
\u0010\u0005\u001a\u000e\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u00010\u0006X\u0084\u0004¢\u0006\b\n��\u001a\u0004\b\u000f\u0010\u0010R\u0014\u0010\u000b\u001a\u00020\nX\u0084\u0004¢\u0006\b\n��\u001a\u0004\b\u0011\u0010\u0012R&\u0010\u0013\u001a\u000e\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u00010\u0004X\u0096\u000e¢\u0006\u000e\n��\u001a\u0004\b\u0014\u0010\u0015\"\u0004\b\u0016\u0010\u0017R\u0014\u0010\u0007\u001a\u00020\bX\u0084\u0004¢\u0006\b\n��\u001a\u0004\b\u0018\u0010\u0019¨\u0006&"}, d2 = {"Lca/aqtech/mctreesearch4j/StatefulSolver;", "StateType", "ActionType", "Lca/aqtech/mctreesearch4j/Solver;", "Lca/aqtech/mctreesearch4j/StateNode;", "mdp", "Lca/aqtech/mctreesearch4j/MDP;", "simulationDepthLimit", "", "explorationConstant", "", "rewardDiscountFactor", "verbose", "", "(Lca/aqtech/mctreesearch4j/MDP;IDDZ)V", "getMdp", "()Lca/aqtech/mctreesearch4j/MDP;", "getRewardDiscountFactor", "()D", "root", "getRoot", "()Lca/aqtech/mctreesearch4j/StateNode;", "setRoot", "(Lca/aqtech/mctreesearch4j/StateNode;)V", "getSimulationDepthLimit", "()I", "backpropagate", "", "node", "reward", "createNode", "parent", "inducingAction", "state", "(Lca/aqtech/mctreesearch4j/StateNode;Ljava/lang/Object;Ljava/lang/Object;)Lca/aqtech/mctreesearch4j/StateNode;", "expand", "select", "simulate", "mctreesearch4j"})
/* loaded from: input_file:ca/aqtech/mctreesearch4j/StatefulSolver.class */
/*
 * Monte Carlo tree search solver whose tree nodes each store a concrete MDP state.
 *
 * Every StateNode records the state it represents, the action that produced it, the actions
 * valid from that state and whether the state is terminal. The four MCTS phases (select,
 * expand, simulate, backpropagate) are implemented here against the supplied MDP; UCT scoring
 * (calculateUCT) and tracing (trace/traceln/getVerbose) are inherited from Solver.
 *
 * StateType  - type of the MDP states
 * ActionType - type of the MDP actions
 */
public class StatefulSolver<StateType, ActionType> extends Solver<ActionType, StateNode<StateType, ActionType>> {

    /** Current root of the search tree; initialised from {@code mdp.initialState()}. */
    @NotNull
    private StateNode<StateType, ActionType> root;

    /** Problem definition queried for actions, transitions, rewards and terminality. */
    @NotNull
    private final MDP<StateType, ActionType> mdp;

    /** Rollouts stop after {@code simulationDepthLimit + 1} transitions (always at least one). */
    private final int simulationDepthLimit;

    /** Per-step discount applied during rollouts and per tree level during backpropagation. */
    private final double rewardDiscountFactor;

    @Override // ca.aqtech.mctreesearch4j.Solver
    @NotNull
    public StateNode<StateType, ActionType> getRoot() {
        return this.root;
    }

    @Override // ca.aqtech.mctreesearch4j.Solver
    public void setRoot(@NotNull StateNode<StateType, ActionType> stateNode) {
        Intrinsics.checkNotNullParameter(stateNode, "<set-?>");
        this.root = stateNode;
    }

    /**
     * Descends from {@code node} to the node that should be expanded or evaluated next.
     *
     * <p>The descent stops at the first node whose state is terminal, or that reports more valid
     * actions than explored actions. Otherwise it moves to the child with the highest UCT score;
     * ties keep the first child encountered.
     *
     * @param node node to start the descent from
     * @return the selected node
     * @throws IllegalStateException if a non-terminal node with no unexplored actions has no children
     */
    @Override // ca.aqtech.mctreesearch4j.Solver
    @NotNull
    public StateNode<StateType, ActionType> select(@NotNull StateNode<StateType, ActionType> node) {
        Intrinsics.checkNotNullParameter(node, "node");
        StateNode<StateType, ActionType> current = node;
        while (true) {
            if (this.mdp.isTerminal(current.getState())) {
                return current;
            }
            int exploredCount = current.exploredActions().size();
            int validCount = current.getValidActions().size();
            // Mirrors the Kotlin `assert`: a node can never have explored more actions than are valid.
            if (_Assertions.ENABLED && validCount < exploredCount) {
                throw new AssertionError("Assertion failed");
            }
            if (validCount > exploredCount) {
                // Still has unexplored actions: this node is the expansion candidate.
                return current;
            }
            StateNode<StateType, ActionType> best = bestChildByUct(current);
            if (best == null) {
                // Unchecked so the override compiles without a `throws` clause; still caught by catch (Exception).
                throw new IllegalStateException("There were no children for explored node");
            }
            current = best;
        }
    }

    /**
     * Returns the child of {@code parent} with the highest UCT score, or {@code null} if it has
     * no children. Scores are compared with {@link Double#compare}; ties keep the earlier child.
     */
    private StateNode<StateType, ActionType> bestChildByUct(StateNode<StateType, ActionType> parent) {
        StateNode<StateType, ActionType> best = null;
        double bestScore = 0.0;
        for (Object child : Node.getChildren$default(parent, null, 1, null)) {
            @SuppressWarnings("unchecked") // children of a StateNode are StateNodes of the same parameterisation
            StateNode<StateType, ActionType> candidate = (StateNode<StateType, ActionType>) child;
            double score = calculateUCT(candidate);
            if (best == null || Double.compare(bestScore, score) < 0) {
                best = candidate;
                bestScore = score;
            }
        }
        return best;
    }

    /**
     * Adds one new child to {@code node}. The child is reached by an action chosen uniformly at
     * random among the node's distinct valid actions that no existing child was induced by.
     *
     * @param node node to expand
     * @return {@code node} itself if it is terminal, otherwise the newly created child
     * @throws IllegalStateException if {@code node} has no unexplored actions left, or the chosen
     *     action is {@code null}
     */
    @Override // ca.aqtech.mctreesearch4j.Solver
    @NotNull
    public StateNode<StateType, ActionType> expand(@NotNull StateNode<StateType, ActionType> node) {
        Intrinsics.checkNotNullParameter(node, "node");
        if (node.isTerminal()) {
            return node;
        }
        Collection<ActionType> validActions = node.getValidActions();
        Collection<?> children = Node.getChildren$default(node, null, 1, null);
        ArrayList<ActionType> triedActions = new ArrayList<>(children.size());
        for (Object child : children) {
            @SuppressWarnings("unchecked") // children of a StateNode are StateNodes of the same parameterisation
            StateNode<StateType, ActionType> childNode = (StateNode<StateType, ActionType>) child;
            triedActions.add(childNode.getInducingAction());
        }
        // distinct() keeps duplicated valid actions from being over-weighted in the random draw.
        Collection<ActionType> untriedActions = CollectionsKt.distinct(CollectionsKt.minus(validActions, triedActions));
        // Check explicitly: CollectionsKt.random would otherwise throw NoSuchElementException
        // ("Collection is empty.") and hide the intended message.
        if (untriedActions.isEmpty()) {
            throw new IllegalStateException("No unexplored actions available");
        }
        ActionType action = CollectionsKt.random(untriedActions, Random.Default);
        if (action == null) {
            // Kept from the original: a null action is rejected (null is the root's "no action" marker).
            throw new IllegalStateException("No unexplored actions available");
        }
        return createNode(node, action, this.mdp.transition(node.getState(), action));
    }

    /**
     * Estimates the value of {@code node} with a uniformly random rollout.
     *
     * <p>If the node's state is terminal, the undiscounted reward of the transition that produced
     * it is returned. The root has no parent, so {@code null} is passed as the previous state.
     * Otherwise random actions are applied until a terminal state is reached or
     * {@code simulationDepthLimit + 1} transitions have been made (always at least one). Only the
     * reward of the final transition {@code (state, action, next)} is returned. It is multiplied
     * by {@code rewardDiscountFactor^k}, where {@code k} is the number of transitions taken.
     *
     * @param node node to simulate from
     * @return the (discounted) reward of the final rollout transition
     */
    @Override // ca.aqtech.mctreesearch4j.Solver
    public double simulate(@NotNull StateNode<StateType, ActionType> node) {
        Intrinsics.checkNotNullParameter(node, "node");
        traceln("Simulation:");
        if (node.isTerminal()) {
            traceln("Terminal state reached");
            StateNode<StateType, ActionType> parent = node.getParent();
            return this.mdp.reward(parent != null ? parent.getState() : null, node.getInducingAction(), node.getState());
        }
        StateType state = node.getState();
        double discount = this.rewardDiscountFactor;
        for (int depth = 0; ; depth++) {
            ActionType action = CollectionsKt.random(this.mdp.actions(state), Random.Default);
            StateType next = this.mdp.transition(state, action);
            if (getVerbose()) {
                trace("-> " + action + ' ');
                trace("-> " + next + ' ');
            }
            boolean terminal = this.mdp.isTerminal(next);
            if (terminal || depth >= this.simulationDepthLimit) {
                // Reward the transition just taken, using the pre-transition state and the discount
                // in effect for this step. Both must be read before `state`/`discount` advance.
                double reward = this.mdp.reward(state, action, next) * discount;
                if (getVerbose()) {
                    traceln(terminal
                            ? "-> Terminal state reached : " + reward
                            : "-> Depth limit reached: " + reward);
                }
                return reward;
            }
            state = next;
            discount *= this.rewardDiscountFactor;
        }
    }

    /**
     * Propagates a rollout reward from {@code node} up to the root.
     *
     * <p>For every node on the path, the visit count {@code n} is incremented, the reward is added
     * to the accumulated reward, and {@code maxReward} is raised if needed. The reward is multiplied
     * by {@code rewardDiscountFactor} once per level climbed.
     *
     * @param node   node where the rollout started
     * @param reward reward credited to {@code node} itself
     */
    @Override // ca.aqtech.mctreesearch4j.Solver
    public void backpropagate(@NotNull StateNode<StateType, ActionType> node, double reward) {
        Intrinsics.checkNotNullParameter(node, "node");
        double discounted = reward;
        for (StateNode<StateType, ActionType> current = node; current != null; current = current.getParent()) {
            current.setMaxReward(Math.max(discounted, current.getMaxReward()));
            current.setReward(current.getReward() + discounted);
            current.setN(current.getN() + 1);
            discounted *= this.rewardDiscountFactor;
        }
    }

    /**
     * Builds a node for {@code state}, pre-computing its valid actions and terminality from the
     * MDP, and attaches it to {@code parent} when one is given.
     *
     * @param parent         parent node, or {@code null} for the root
     * @param inducingAction action that led from {@code parent} to {@code state} ({@code null} for the root)
     * @param state          state represented by the new node
     * @return the new node
     */
    private final StateNode<StateType, ActionType> createNode(StateNode<StateType, ActionType> parent, ActionType inducingAction, StateType state) {
        StateNode<StateType, ActionType> created = new StateNode<>(parent, inducingAction, state, CollectionsKt.toList(this.mdp.actions(state)), this.mdp.isTerminal(state));
        if (parent != null) {
            parent.addChild(created);
        }
        return created;
    }

    /** Returns the MDP this solver searches. */
    @NotNull
    protected final MDP<StateType, ActionType> getMdp() {
        return this.mdp;
    }

    /** Returns the configured rollout depth limit. */
    protected final int getSimulationDepthLimit() {
        return this.simulationDepthLimit;
    }

    /** Returns the per-step reward discount factor. */
    protected final double getRewardDiscountFactor() {
        return this.rewardDiscountFactor;
    }

    /**
     * Creates a solver whose tree root is a fresh node for {@code mdp.initialState()}.
     *
     * @param mdp                  problem to solve
     * @param simulationDepthLimit rollout depth limit (see {@link #simulate})
     * @param explorationConstant  UCT exploration constant, forwarded to {@link Solver}
     * @param rewardDiscountFactor per-step reward discount
     * @param verbose              whether tracing output is enabled, forwarded to {@link Solver}
     */
    public StatefulSolver(@NotNull MDP<StateType, ActionType> mdp, int simulationDepthLimit, double explorationConstant, double rewardDiscountFactor, boolean verbose) {
        // Java requires super(...) first; the Kotlin original null-checked mdp before this call.
        super(verbose, explorationConstant);
        Intrinsics.checkNotNullParameter(mdp, "mdp");
        this.mdp = mdp;
        this.simulationDepthLimit = simulationDepthLimit;
        this.rewardDiscountFactor = rewardDiscountFactor;
        this.root = createNode(null, null, mdp.initialState());
    }
}
