package org.xxdc.oss.example.bot;

import java.lang.System;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.xxdc.oss.example.GameState;

/* loaded from: input_file:org/xxdc/oss/example/bot/MonteCarloTreeSearch.class */
/**
 * Bot strategy that picks a move with Monte Carlo Tree Search (MCTS).
 *
 * <p>Each iteration selects a node with UCB1 ({@link MCTSNode#select()}), expands one untried move,
 * plays a uniformly random rollout to a terminal state, and backpropagates per-player rewards up to
 * the root. The search runs until the configured time or iteration budget is exhausted. The move of
 * the most-visited root child is returned.
 */
public final class MonteCarloTreeSearch implements BotStrategy {
    private static final System.Logger log = System.getLogger(MonteCarloTreeSearch.class.getName());
    private final GameState initialState;
    private final BotStrategyConfig config;

    /** Reward assigned to every non-winning player when some player has won. */
    private static final double MIN_SCORE = -0.5d;
    /** Reward assigned to the winning player. */
    private static final double MAX_SCORE = 1.0d;
    /** Reward assigned to every player when nobody has a winning chain. */
    private static final double DRAW_SCORE = 0.0d;

    /**
     * A node in the search tree.
     *
     * <p>Each node holds the game state reached by {@code state.lastMove()}, its visit count, and the
     * cumulative reward per player index.
     *
     * <p>NOTE(review): the decompiler reported this class as package-private in the original source.
     * It is kept {@code public} here so existing callers keep compiling.
     */
    public static class MCTSNode {
        GameState state;
        MCTSNode parent;
        List<MCTSNode> children;
        int visits;
        /** Cumulative reward per player, indexed like {@code state.playerMarkers()}. */
        double[] scores;

        /** Creates a root node (no parent) for {@code gameState}. */
        public MCTSNode(GameState gameState) {
            this(gameState, null);
        }

        /**
         * Creates a node for {@code gameState} under {@code mCTSNode}.
         *
         * @param gameState the state this node represents
         * @param mCTSNode the parent node, or {@code null} for the root
         */
        public MCTSNode(GameState gameState, MCTSNode mCTSNode) {
            this.state = gameState;
            this.parent = mCTSNode;
            this.children = new ArrayList<>();
            this.visits = 0;
            this.scores = new double[gameState.playerMarkers().size()];
        }

        /**
         * Picks the child with the highest UCB1 value from the perspective of the player to move here.
         *
         * <p>Every child has been visited at least once, because it is backpropagated in the same
         * iteration it is expanded. The division by {@code child.visits} is therefore safe.
         *
         * @return the best child, or {@code null} if this node has no children
         */
        public MCTSNode select() {
            int player = this.state.currentPlayerIndex();
            // Hoisted: the parent's log-visit term is the same for every child.
            double logParentVisits = Math.log(this.visits);
            MCTSNode best = null;
            double bestValue = Double.NEGATIVE_INFINITY;
            for (MCTSNode child : this.children) {
                double exploitation = child.scores[player] / child.visits;
                double exploration = Math.sqrt((2.0d * logParentVisits) / child.visits);
                double value = exploitation + exploration;
                if (value > bestValue) {
                    best = child;
                    bestValue = value;
                }
            }
            return best;
        }

        /** Returns {@code true} once every available move from this state has a child node. */
        public boolean isFullyExpanded() {
            return this.children.size() == this.state.board().availableMoves().size();
        }

        @Override
        public String toString() {
            return toString(0);
        }

        /** Renders this subtree, indenting each level by two spaces per {@code i}. */
        String toString(int i) {
            String indent = " ".repeat(i * 2);
            StringBuilder sb = new StringBuilder();
            sb.append(indent);
            sb.append(this.parent == null
                    ? "Root"
                    : this.state.playerMarkers().get(this.state.lastPlayerIndex()) + " -> " + this.state.lastMove());
            sb.append(" (").append(this.visits).append(") => ");
            sb.append((this.parent == null || !this.state.lastPlayerHasChain()) ? this.state.availableMoves() : "WINNER");
            sb.append("\n").append(indent).append(" (");
            for (int p = 0; p < this.scores.length; p++) {
                sb.append(this.state.playerMarkers().get(p)).append(": ").append(this.scores[p]);
                if (p < this.scores.length - 1) {
                    sb.append(", ");
                }
            }
            sb.append(")");
            for (MCTSNode child : this.children) {
                sb.append("\n").append(child.toString(i + 1));
            }
            return sb.toString();
        }
    }

    /** Creates a search with a default budget of one second. */
    public MonteCarloTreeSearch(GameState gameState) {
        this(gameState, BotStrategyConfig.newBuilder().maxTimeMillis(TimeUnit.SECONDS, 1L).build());
    }

    /**
     * Creates a search for {@code gameState} bounded by {@code botStrategyConfig}.
     *
     * @param gameState the position to choose a move for
     * @param botStrategyConfig time/iteration budget for the search
     */
    public MonteCarloTreeSearch(GameState gameState, BotStrategyConfig botStrategyConfig) {
        this.initialState = gameState;
        this.config = botStrategyConfig;
    }

    @Override
    public int bestMove() {
        return monteCarloTreeSearch(this.initialState);
    }

    /**
     * Runs MCTS iterations from {@code gameState} until the budget is exhausted.
     *
     * @return the move of the most-visited root child
     * @throws NoSuchElementException if the root has no children, e.g. the state is terminal or no
     *     iteration completed within the budget
     */
    private int monteCarloTreeSearch(GameState gameState) {
        MCTSNode root = new MCTSNode(gameState);
        // Monotonic clock: wall-clock time can jump and would distort the time budget.
        long startNanos = System.nanoTime();
        int iterations = 0;
        while (!this.config.exceedsMaxTimeMillis(elapsedMillisSince(startNanos))
                && !this.config.exceedsMaxIterations(iterations++)) {
            MCTSNode leaf = treePolicy(root);
            backpropagate(leaf, defaultPolicy(leaf.state));
        }
        boolean debug = log.isLoggable(System.Logger.Level.DEBUG);
        if (debug) {
            // Logged before choosing, so the tree is still visible if selection fails.
            log.log(System.Logger.Level.DEBUG, "MCTS: \n" + String.valueOf(root));
        }
        MCTSNode best = bestChild(root);
        if (debug) {
            log.log(System.Logger.Level.DEBUG, "MCTS (Selected): \n" + best.state.lastMove());
        }
        return best.state.lastMove();
    }

    /** Milliseconds elapsed since {@code startNanos}, a value obtained from {@link System#nanoTime()}. */
    private static long elapsedMillisSince(long startNanos) {
        return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos);
    }

    /** Descends via UCB1 until a terminal node or a node with an untried move, then expands it. */
    private MCTSNode treePolicy(MCTSNode mCTSNode) {
        MCTSNode current = mCTSNode;
        while (!current.state.isTerminal()) {
            if (!current.isFullyExpanded()) {
                return expand(current);
            }
            current = current.select();
        }
        return current;
    }

    /** Adds one child for a uniformly random move that {@code mCTSNode} has not tried yet. */
    private MCTSNode expand(MCTSNode mCTSNode) {
        var tried = mCTSNode.children.stream()
                .map(child -> child.state.lastMove())
                .collect(Collectors.toSet());
        List<Integer> untried = new ArrayList<>(mCTSNode.state.board().availableMoves());
        untried.removeAll(tried);
        int move = untried.get(ThreadLocalRandom.current().nextInt(untried.size()));
        MCTSNode child = new MCTSNode(mCTSNode.state.afterPlayerMoves(move), mCTSNode);
        mCTSNode.children.add(child);
        return child;
    }

    /**
     * Plays uniformly random moves from a copy of {@code gameState} until the game ends.
     *
     * @return per-player rewards of the terminal state (see {@link #defaultReward})
     */
    private double[] defaultPolicy(GameState gameState) {
        GameState rollout = new GameState(gameState);
        while (!rollout.isTerminal()) {
            List<Integer> moves = rollout.board().availableMoves();
            rollout = rollout.afterPlayerMoves(moves.get(ThreadLocalRandom.current().nextInt(moves.size())));
        }
        return defaultReward(rollout);
    }

    /**
     * Scores a terminal state.
     *
     * <p>The first player (by marker order) whose marker forms a chain gets {@link #MAX_SCORE} and
     * every other player gets {@link #MIN_SCORE}. If no player has a chain, everyone gets
     * {@link #DRAW_SCORE}.
     */
    private double[] defaultReward(GameState gameState) {
        var markers = gameState.playerMarkers();
        int winner = -1;
        for (int p = 0; p < markers.size() && winner == -1; p++) {
            if (gameState.board().hasChain(markers.get(p))) {
                winner = p;
            }
        }
        double[] rewards = new double[markers.size()];
        for (int p = 0; p < rewards.length; p++) {
            if (winner == -1) {
                rewards[p] = DRAW_SCORE;
            } else {
                rewards[p] = (p == winner) ? MAX_SCORE : MIN_SCORE;
            }
        }
        return rewards;
    }

    /** Adds one visit and {@code dArr} to the per-player scores of every node from leaf to root. */
    private void backpropagate(MCTSNode mCTSNode, double[] dArr) {
        int players = this.initialState.playerMarkers().size();
        for (MCTSNode node = mCTSNode; node != null; node = node.parent) {
            node.visits++;
            for (int p = 0; p < players; p++) {
                node.scores[p] += dArr[p];
            }
        }
    }

    /**
     * Returns the most-visited child of {@code mCTSNode}; on ties, the earliest-expanded child wins.
     *
     * @throws NoSuchElementException if {@code mCTSNode} has no children
     */
    private MCTSNode bestChild(MCTSNode mCTSNode) {
        return mCTSNode.children.stream()
                .max(Comparator.comparingInt(child -> child.visits))
                .orElseThrow(() -> new NoSuchElementException(
                        "MCTS produced no candidate moves (terminal state or no iterations completed)"));
    }
}
