diff --git a/script/README.md b/script/README.md index feb57ca2..f9c6bc7b 100644 --- a/script/README.md +++ b/script/README.md @@ -12,7 +12,7 @@ You can download data from [here](http://rebel13.nl/index.html) ## Convert pgn files -**Important : convert text will be superheavy (approx 200 byte / position)** +**Important : convert text will be superheavy (approx 200 byte / position)** python pgn_to_plain.py --pgn "pgn/*.pgn" --start_ply 1 --output converted_pgn.txt diff --git a/script/pgn_to_plain.py b/script/pgn_to_plain.py index c551c136..86671261 100644 --- a/script/pgn_to_plain.py +++ b/script/pgn_to_plain.py @@ -61,7 +61,7 @@ def parse_comment_for_score(comment_str: str, board: chess.Board) -> int: score = 0 return score - + def parse_game(game: chess.pgn.Game, writer, start_play: int=1)->None: board: chess.Board = game.board() if not game_sanity_check(game): @@ -105,6 +105,6 @@ def main(): break parse_game(game, f, args.start_ply) f.close() - + if __name__=="__main__": main() diff --git a/src/misc.h b/src/misc.h index de9006a5..1a574c58 100644 --- a/src/misc.h +++ b/src/misc.h @@ -53,16 +53,20 @@ void dbg_hit_on(bool c, bool b); void dbg_mean_of(int v); void dbg_print(); +/// Debug macro to write to std::err if NDEBUG flag is set, and do nothing otherwise #if defined(NDEBUG) -template -void debug_print(const Ts&...) {} +#define debug 1 && std::cerr #else -template -void debug_print(const Ts&... v) { - ((std::cerr << v), ...); -} +#define debug 0 && std::cerr #endif +inline void hit_any_key() { +#ifndef NDEBUG + debug << "Hit any key to continue..." << std::endl << std::flush; + system("read"); // on Windows, should be system("pause"); +#endif +} + typedef std::chrono::milliseconds::rep TimePoint; // A value in milliseconds static_assert(sizeof(TimePoint) == sizeof(int64_t), "TimePoint should be 64 bits"); inline TimePoint now() { diff --git a/src/search.cpp b/src/search.cpp index 91d3ee35..be137f33 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -47,6 +47,7 @@ namespace Search { using std::string; using Eval::evaluate; using namespace Search; +using namespace std; bool Search::prune_at_shallow_depth = true; @@ -1970,7 +1971,10 @@ namespace Search // Zero initialization of the number of search nodes th->nodes = 0; - // Clear all history types. This initialization takes a little time, and the accuracy of the search is rather low, so the good and bad are not well understood. + // Clear all history types. This initialization takes a little time, and + // the accuracy of the search is rather low, so the good and bad are + // not well understood. + // th->clear(); int ct = int(Options["Contempt"]) * PawnValueEg / 100; // From centipawns @@ -2200,12 +2204,14 @@ namespace Search return ValueAndPV(bestValue, pvs); } - namespace MCTS { - /* - The implementation of the MCTS is heavily based on Stephane Nicolet's work here - https://github.com/snicolet/Stockfish/commit/28501872a1e7ce84dd1f38ab9e59c5adb0d24b41 - and the adjusted implementation of it in ShashChess https://github.com/amchess/ShashChess - */ + + + // This implementation of the MCTS is heavily based on Stephane Nicolet's work here + // https://github.com/snicolet/Stockfish/commit/28501872a1e7ce84dd1f38ab9e59c5adb0d24b41 + // and the adjusted implementation of it in ShashChess https://github.com/amchess/ShashChess + + namespace MCTS + { static constexpr float sigmoidScale = 600.0f; static inline float fast_sigmoid(float x) { @@ -2231,24 +2237,31 @@ namespace Search return fast_sigmoid(static_cast(v) * (1.0f / sigmoidScale)); } - struct MctsNode { - Key posKey = 0; // for consistency checks. - MctsNode* parent = nullptr; // only nullptr for the root node - std::unique_ptr children = nullptr; // only nullptr for nodes that have not been expanded - std::uint64_t numVisits = 0; // the number of playouts for this node and all descendants - Value leafSearchEval = VALUE_NONE; // the evaluation from AB playout - float prior = 0.0f; // the policy, currently a rough estimation based on the playout of the parent - float actionValue = 0.0f; // the accumulated rewards - float actionValueWeight = 0.0f; // the maximum value for the accumulater rewards - Move prevMove = MOVE_NONE; // the move on the edge from the parent - std::uint16_t numChildren = 0; // the number of legal moves, filled on expansion - std::uint16_t childId = 0; // the index of this node in the parent's children array - Depth leafSearchDepth = DEPTH_NONE; // the depth with which the AB playout was done - bool isTerminal = false; // whether the node is terminal. Terminal nodes are always "expanded" immediately. + // struct MCTSNode : store info at one node of the MCTS algorithm + + struct MCTSNode { + + Key posKey = 0; // for consistency checks + MCTSNode* parent = nullptr; // only nullptr for the root node + unique_ptr children = nullptr; // only nullptr for nodes that have not been expanded + uint64_t numVisits = 0; // the number of playouts for this node and all descendants + Value leafSearchEval = VALUE_NONE; // the evaluation from AB playout + float prior = 0.0f; // the policy, currently a rough estimation based on the playout of the parent + float actionValue = 0.0f; // the accumulated rewards + float actionValueWeight = 0.0f; // the maximum value for the accumulater rewards + Move prevMove = MOVE_NONE; // the move on the edge from the parent + int numChildren = 0; // the number of legal moves, filled on expansion + int childId = 0; // the index of this node in the parent's children array + Depth leafSearchDepth = DEPTH_NONE; // the depth with which the AB playout was done + bool isTerminal = false; // whether the node is terminal. Terminal nodes are always "expanded" immediately. + + + // ucb_value() calculates the upper confidence bound of a child. + // When searching for the node to expand/playout we take one with the highest ucb. + + float ucb_value(MCTSNode& child, float explorationFactor, bool flipPerspective = false) const + { - // The upper confidence bound. When searching for the node to expand/playout we take - // one with the highest ucb. - float ucb_value(MctsNode& child, float explorationFactor, bool flipPerspective = false) const { assert(explorationFactor >= 0.0f); assert(child.actionValue >= 0.0f); assert(child.actionValueWeight >= 0.0f); @@ -2258,10 +2271,8 @@ namespace Search // For the nodes which have not been played-out we use the prior. // Otherwise we have some averaged score or the eval already. - float reward = - child.numVisits == 0 - ? child.prior - : child.actionValue / child.actionValueWeight; + float reward = child.numVisits == 0 ? child.prior + : child.actionValue / child.actionValueWeight; if (flipPerspective) reward = 1.0f - reward; @@ -2269,11 +2280,11 @@ namespace Search // The exploration factor. // In theory unplayed nodes should have priority, but we // add 1 to avoid div by 0 so they might not always be prioritized. + if (explorationFactor != 0.0f) reward += explorationFactor - * std::sqrt(std::log(static_cast(1 + numVisits)) - / static_cast(1 + child.numVisits)); + * std::sqrt(std::log(1.0 + numVisits) / (1.0 + child.numVisits)); assert(!std::isnan(reward)); assert(reward >= 0.0f); @@ -2281,24 +2292,32 @@ namespace Search return reward; } - // Returns the reference to the best child node. - const MctsNode& get_best_child(float explorationFactor) const { + + // get_best_child() returns a const reference to the best child node, + // according to the UCB value. + + const MCTSNode& get_best_child(float explorationFactor) const + { + assert(!is_leaf()); assert(numChildren > 0); - if (numChildren == 1) { + if (numChildren == 1) + { assert(children[0].childId == 0); return children[0]; } int bestIdx = -1; float bestValue = std::numeric_limits::lowest(); - for (int i = 0; i < static_cast(numChildren); ++i) { - MctsNode& child = children[i]; + for (int i = 0 ; i < numChildren ; ++i) + { + MCTSNode& child = children[i]; // The "best" is the one with the best UCB. // Child values are with opposite signs. const float r = ucb_value(child, explorationFactor, true); - if (r > bestValue) { + if (r > bestValue) + { bestIdx = i; bestValue = r; } @@ -2311,22 +2330,32 @@ namespace Search return children[bestIdx]; } - MctsNode& get_best_child(float explorationFactor) { - return const_cast(static_cast(this)->get_best_child(explorationFactor)); + + // get_best_child() : like the previous one, but does not return a const reference + + MCTSNode& get_best_child(float explorationFactor) { + return const_cast(static_cast(this)->get_best_child(explorationFactor)); } + + // get_best_move() returns a pair (move,value) leading to the best child, + // according to the action value heuristic. + std::pair get_best_move() const { + assert(!is_leaf()); assert(numChildren > 0); int bestIdx = -1; float bestValue = std::numeric_limits::lowest(); - for (int i = 0; i < static_cast(numChildren); ++i) { - MctsNode& child = children[i]; - // The "best" is the one with the best UCB. + for (int i = 0; i < numChildren ; ++i) + { + MCTSNode& child = children[i]; + // The "best" is the one with the best action value. // Child values are with opposite signs. const float r = 1.0f - (child.actionValue / child.actionValueWeight); - if (r > bestValue) { + if (r > bestValue) + { bestIdx = i; bestValue = r; } @@ -2339,9 +2368,13 @@ namespace Search return { children[bestIdx].prevMove, reward_to_value(bestValue) }; } - const MctsNode* get_child_by_move(Move move) const { - for (int i = 0; i < static_cast(numChildren); ++i) { - MctsNode& child = children[i]; + + // get_child_by_move() finds a child, given the move that leads to it + + const MCTSNode* get_child_by_move(Move move) const { + for (int i = 0; i < numChildren ; ++i) + { + MCTSNode& child = children[i]; if (child.prevMove == move) return &child; } @@ -2349,22 +2382,34 @@ namespace Search return nullptr; } - MctsNode* get_child_by_move(Move move) { - return const_cast(static_cast(this)->get_child_by_move(move)); + + // get_child_by_move() : like the previous one, but does not return a const + + MCTSNode* get_child_by_move(Move move) { + return const_cast(static_cast(this)->get_child_by_move(move)); } + + // is_root() returns true when node is the root + bool is_root() const { return parent == nullptr; } + // is_leaf() returns true when node is a leaf + bool is_leaf() const { return children == nullptr; } }; - // The stuff that needs to be backpropagated down the tree. + + // struct BackpropValues is a structure to manipulate the kind of stuff + // that needs to be back-propagated down and up the tree by MCTS. + struct BackpropValues { - std::uint64_t numVisits = 0; + + uint64_t numVisits = 0; float actionValue = 0.0f; float actionValueWeight = 0.0f; @@ -2379,23 +2424,36 @@ namespace Search } }; + + // struct MonteCarloTreeSearch implements the methods for the MCTS algorithm + struct MonteCarloTreeSearch { + + // IMPORTANT: + // The position is stateful so we always have one. + // It has to match certain expectations in different functions. + // For example when looking for the node to expand the pos must correspond + // to the root mcts node. When expanding the node it must correspond to the + // node being expanded, etc. + static constexpr Depth terminalEvalDepth = Depth(255); // We add a lot of stuff to the actionValue, but the weights differ. // The prior is currently bad so low weight, - static constexpr float priorWeight = 0.01f; - static constexpr float terminalWeight = 1.0f; // could be increased? Different for wins/draws? - static constexpr float normalWeight = 1.0f; + static constexpr float priorWeight = 0.01f; + static constexpr float terminalWeight = 1.0f; // could be increased? Different for wins/draws? + static constexpr float normalWeight = 1.0f; - static_assert(priorWeight > 0.0f); + static_assert(priorWeight > 0.0f); static_assert(terminalWeight > 0.0f); - static_assert(normalWeight > 0.0f); + static_assert(normalWeight > 0.0f); MonteCarloTreeSearch() {} - MonteCarloTreeSearch(const MonteCarloTreeSearch&) = delete; + + // search_new() : let's start the search ! + ValueAndPV search_new( Position& pos, std::uint64_t maxPlayouts, @@ -2406,91 +2464,98 @@ namespace Search return search_continue(pos, maxPlayouts, leafDepth, explorationFactor); } - // Continue after a move and reuse the relevant part of the tree. - // The prevMove is the move that lead to pos - // TODO: Make the node limit be the total. - ValueAndPV search_continue_after_move( - Position& pos, - Move prevMove, - std::uint64_t maxPlayouts, - Depth leafDepth, - float explorationFactor = 0.25f) { + // search_continue_after_move() : continue after a move and reuse the relevant + // part of the tree. The prevMove is the move that lead to position 'pos'. + // + // TODO: make the node limit be the total. + + ValueAndPV search_continue_after_move( Position& pos, + Move prevMove, + std::uint64_t maxPlayouts, + Depth leafDepth, + float explorationFactor = 0.25f) { do_move_at_root(pos, prevMove); return search_continue(pos, maxPlayouts, leafDepth, explorationFactor); } + + // get_all_continuations() is missing description + std::vector get_all_continuations() const { + std::vector continuations; continuations.resize(rootNode.numChildren); - for (int i = 0; i < rootNode.numChildren; ++i) { - MctsNode& child = rootNode.children[i]; + for (int i = 0; i < rootNode.numChildren; ++i) + { + MCTSNode& child = rootNode.children[i]; - auto& cont = continuations[i]; - cont.numVisits = child.numVisits; - // child value is with opposite sign - cont.actionValue = 1.0f - (child.actionValue / child.actionValueWeight); - cont.value = reward_to_value(cont.actionValue); - cont.pv = get_pv(child); + auto& cont = continuations[i]; + + cont.numVisits = child.numVisits; + cont.value = reward_to_value(cont.actionValue); + cont.pv = get_pv(child); + cont.actionValue = 1.0f - (child.actionValue / child.actionValueWeight); // child value is with opposite sign } - std::stable_sort( - continuations.begin(), - continuations.end(), - [](const auto& lhs, const auto& rhs) { return lhs.value > rhs.value; } + std::stable_sort( continuations.begin(), + continuations.end(), + [](const auto& lhs, const auto& rhs) { return lhs.value > rhs.value; } ); return continuations; } - // Continues with the same tree. - ValueAndPV search_continue( - Position& pos, - std::uint64_t maxPlayouts, - Depth leafDepth, - float explorationFactor = 0.25f) { + + // search_continue() : continues with the same tree + + ValueAndPV search_continue( Position& pos, + std::uint64_t maxPlayouts, + Depth leafDepth, + float explorationFactor = 0.25f) { if (rootNode.leafSearchDepth == DEPTH_NONE) do_playout(pos, rootNode, leafDepth); - while (numPlayouts < maxPlayouts) { - debug_print("Starting iteration ", numPlayouts, '\n'); + while (numPlayouts < maxPlayouts) + { + debug << "Starting iteration " << numPlayouts << endl; do_search_iteration(pos, leafDepth, explorationFactor); } if (rootNode.is_leaf()) return {}; - else { + else return { rootNode.get_best_move().second, get_pv() }; - } } - Stack stackBuffer[MAX_PLY + 10]; + + Stack stackBuffer [MAX_PLY + 10]; StateInfo statesBuffer[MAX_PLY + 10]; - Stack* stack = stackBuffer + 7; + Stack* stack = stackBuffer + 7; StateInfo* states = statesBuffer + 7; - MctsNode rootNode; + MCTSNode rootNode; int ply = 1; int maximumPly = ply; // Effectively the selective depth. std::uint64_t numPlayouts = 0; - private: - // IMPORTANT: - // The position is stateful so we always have one. - // It has to match certain expectations in different functions. - // For example when looking for the node to expand the pos must correspond - // to the root mcts node. When expanding the node it must correspond to the - // node being expanded, etc. + private : + + // reset_stats(), recalculate_stats() and accumulate_stats_recursively() + // are used to recalculate the number of playouts in our MCTS tree. Note + // that at the moment we call recalculate_stats() each time we play a move + // at root, to recalculate the stats in the subtree. void reset_stats() { numPlayouts = 0; } - void accumulate_stats_recursively(MctsNode& node) { + void accumulate_stats_recursively(MCTSNode& node) { + if (node.leafSearchDepth != DEPTH_NONE) numPlayouts += 1; @@ -2504,13 +2569,18 @@ namespace Search accumulate_stats_recursively(rootNode); } - // Tree reuse. + + // do_move_at_root() is missing description + // Tree reuse (?) // pos is the position after move. + void do_move_at_root(Position& pos, Move move) { - MctsNode* child = rootNode.get_child_by_move(move); + + MCTSNode* child = rootNode.get_child_by_move(move); if (child == nullptr) create_new_root(pos); - else { + else + { rootNode = std::move(*child); rootNode.parent = nullptr; rootNode.childId = 0; @@ -2522,25 +2592,32 @@ namespace Search assert(rootNode.posKey == pos.key()); } - // One iteration of the search. Basically: + + // do_search_iteration() does one iteration of the search. + // + // Basically: // 1. find a node to expand/playout // 2. if the node is a terminal then we just get the stuff and backprop // 3. if we only have prior for the node then do a playout // 4. otherwise we expand the children and do at least one playout from the best child (chosen by prior) // 4.1. a terminal node counts as a playout. All terminal nodes are played out. // 5. Backpropagate all changes down the tree. + void do_search_iteration(Position& pos, Depth leafDepth, float explorationFactor) { - MctsNode& node = find_node_to_expand_or_playout(pos, explorationFactor); + + MCTSNode& node = find_node_to_expand_or_playout(pos, explorationFactor); BackpropValues backprops{}; - if (node.isTerminal) { - debug_print("Root is terminal\n"); + if (node.isTerminal) + { + debug << "Root is terminal" << endl; backprops.numVisits = 1; backprops.actionValue += node.actionValue; backprops.actionValueWeight += node.actionValueWeight; numPlayouts += 1; } - else if (node.leafSearchDepth == DEPTH_NONE) { + else if (node.leafSearchDepth == DEPTH_NONE) + { // The node is considered the best but it only has a prior value. // We don't really want to expand nodes based just on the prior, so // first do a playout to get a better estimate, and expand only in the @@ -2549,7 +2626,8 @@ namespace Search // playout can put it below another move. backprops = do_playout(pos, node, leafDepth); } - else { + else + { // We have done leaf evaluation with AB search so we know that // this node is *actually good* and not just *prior good*, so we // can now expand it and do an immediate playout for the node with the best prior. @@ -2559,22 +2637,28 @@ namespace Search backpropagate(pos, node, backprops); } - // Backpropagates the changes after an expand/playout all the way to the root. - // pos is expected to be at the node from which we start backpropagating. - void backpropagate(Position& pos, MctsNode& node, BackpropValues backprops) { + + // Backpropagates() is the function we use to back-propagate the changes + // after an expand/playout, all the way to the root. The position 'pos' + // is expected to be at the node from which we start backpropagating. + + void backpropagate(Position& pos, MCTSNode& node, BackpropValues backprops) { + assert(node.posKey == pos.key()); assert(ply >= 1); - debug_print("Backpropagating: ", pos.fen(), '\n'); + debug << "Backpropagating: " << pos.fen() << endl; + + MCTSNode* currentNode = &node; + while (!currentNode->is_root()) + { + // On each descent we switch the side to move - MctsNode* currentNode = &node; - while (!currentNode->is_root()) { - // On each descent we switch the side to move. undo_move(pos); currentNode = currentNode->parent; backprops.flip_side(); - debug_print("Backprop step: ", pos.fen(), '\n'); + debug << "Backprop step: " << pos.fen() << endl; assert(currentNode->posKey == pos.key()); currentNode->numVisits += backprops.numVisits; @@ -2582,19 +2666,24 @@ namespace Search currentNode->actionValueWeight += backprops.actionValueWeight; } - // At the end we must be at the root. + // At the end we must be at the root + assert(currentNode == &rootNode); assert(rootNode.posKey == pos.key()); } - // Navigate with pos to the node expand/playout - MctsNode& find_node_to_expand_or_playout(Position& pos, float explorationFactor) { + + // find_node_to_expand_or_playout() navigates from pos to the node to expand/playout, + // according to the get_best_child() heuristics. + + MCTSNode& find_node_to_expand_or_playout(Position& pos, float explorationFactor) { assert(rootNode.posKey == pos.key()); // Find a node that has not yet been expanded - MctsNode* currentNode = &rootNode; - while (!currentNode->is_leaf()) { - MctsNode& bestChild = currentNode->get_best_child(explorationFactor); + MCTSNode* currentNode = &rootNode; + while (!currentNode->is_leaf()) + { + MCTSNode& bestChild = currentNode->get_best_child(explorationFactor); do_move(pos, *currentNode, bestChild); @@ -2604,6 +2693,9 @@ namespace Search return *currentNode; } + + // generate_moves_unordered() generates moves in a random order + int generate_moves_unordered(Position& pos, Move* out) const { int moveCount = 0; for (auto move : MoveList(pos)) @@ -2612,11 +2704,14 @@ namespace Search return moveCount; } - // Generate moves with some reasonable ordering. Using this we can assume some reasonable priors. - int generate_moves_ordered(Position& pos, MctsNode& node, Depth leafDepth, Move* out) const { + + // generate_moves_ordered() generates moves with some reasonable ordering. + // Using this function, we can assume some reasonable priors. + + int generate_moves_ordered(Position& pos, MCTSNode& node, Depth leafDepth, Move* out) const { assert(ply >= 1); - debug_print("Generating moves: ", pos.fen(), '\n'); + debug << "Generating moves: " << pos.fen() << endl; Thread* const thread = pos.this_thread(); const Square prevSq = to_sq(node.prevMove); @@ -2650,9 +2745,10 @@ namespace Search ); int moveCount = 0; - for (;;) { + while (true) + { const Move move = mp.next_move(); - debug_print("Generated move ", UCI::move(move, false), ": ", pos.fen(), '\n'); + debug << "Generated move " << UCI::move(move, false) << ": " << pos.fen() << endl; if (move == MOVE_NONE) break; @@ -2661,12 +2757,19 @@ namespace Search out[moveCount++] = move; } - debug_print("Generated ", moveCount, " legal moves: ", pos.fen(), '\n'); + debug << "Generated " << moveCount << " legal moves: " << pos.fen() << endl; return moveCount; } + + // init_for_leaf_search() prepares some global variables in the thread of the + // given position, for compatibility with the normal AB search of Stockfish. + // This allows us to use that AB search to get an estimated value of the leaf, + // if necessary. + void init_for_leaf_search(Position& pos) { + auto th = pos.this_thread(); th->completedDepth = 0; @@ -2677,9 +2780,12 @@ namespace Search th->nodes = 0; } - // Checks whether the position is terminal and return the right value if it is. - // Otherwise returns VALUE_NONE + + // terminal_value() checks whether the position is terminal. We return + // the right value if position is terminal, otherwise we return VALUE_NONE. + Value terminal_value(Position& pos) const { + if (MoveList(pos).size() == 0) return pos.checkers() ? VALUE_MATE : -VALUE_MATE;; @@ -2689,13 +2795,16 @@ namespace Search return VALUE_NONE; } - // Does AB search on the position to get the value. - Value evaluate_leaf(Position& pos, MctsNode& node, Depth leafDepth) { + + // evaluate_leaf() does AB search on the position to get its value + + Value evaluate_leaf(Position& pos, MCTSNode& node, Depth leafDepth) { + assert(node.posKey == pos.key()); assert(node.leafSearchDepth == DEPTH_NONE); assert(node.leafSearchEval == VALUE_NONE); - debug_print("Evaluating leaf: ", pos.fen(), '\n'); + debug << "Evaluating leaf: " << pos.fen() << endl; init_for_leaf_search(pos); @@ -2704,19 +2813,23 @@ namespace Search stack[ply].currentMove = MOVE_NONE; stack[ply].excludedMove = MOVE_NONE; - if (!node.is_root() && node.parent->leafSearchEval != VALUE_NONE) { + if (!node.is_root() && node.parent->leafSearchEval != VALUE_NONE) + { // If we have some parent score then use an aspiration window. // We know what to expect. Value delta = Value(18); Value alpha = std::max(node.parent->leafSearchEval - delta, -VALUE_INFINITE); Value beta = std::min(node.parent->leafSearchEval + delta, VALUE_INFINITE); - for (;;) { + while (true) + { const Value value = Stockfish::search(pos, stack + ply, alpha, beta, leafDepth, false); - if (value <= alpha) { + if (value <= alpha) + { beta = (alpha + beta) / 2; alpha = std::max(value - delta, -VALUE_INFINITE); } - else if (value >= beta) + else + if (value >= beta) beta = std::min(value + delta, VALUE_INFINITE); else return value; @@ -2724,22 +2837,26 @@ namespace Search delta += delta / 4 + 5; } } - else { + + else // If no parent score then do infinite aspiration window. return Stockfish::search(pos, stack + ply, -VALUE_INFINITE, VALUE_INFINITE, leafDepth, false); - } } - std::vector get_pv(const MctsNode& node) const { + + // get_pv(node) tries to get a pv, starting from the given node + + std::vector get_pv(const MCTSNode& node) const { std::vector pv; - const MctsNode* currentNode = &node; + const MCTSNode* currentNode = &node; if (!currentNode->is_root()) pv.emplace_back(currentNode->prevMove); - while (!currentNode->is_leaf()) { + while (!currentNode->is_leaf()) + { // No exploration factor for choosing the PV. - const MctsNode& bestChild = currentNode->get_best_child(0.0f); + const MCTSNode& bestChild = currentNode->get_best_child(0.0f); pv.emplace_back(bestChild.prevMove); currentNode = &bestChild; } @@ -2747,12 +2864,18 @@ namespace Search return pv; } + + // get_pv() tries to get the pv, starting from the root + std::vector get_pv() const { return get_pv(rootNode); } - // Does a single playout and returns what is needed to backprop. - BackpropValues do_playout(Position& pos, MctsNode& node, Depth leafDepth) { + + // do_playout() does a single playout and returns what is needed to backprop + + BackpropValues do_playout(Position& pos, MCTSNode& node, Depth leafDepth) { + assert(node.posKey == pos.key()); assert(node.numVisits == 0); assert(node.is_leaf()); @@ -2760,80 +2883,84 @@ namespace Search assert(node.leafSearchDepth == DEPTH_NONE); assert(!node.isTerminal); - debug_print("Doing playout ", numPlayouts, ": ", pos.fen(), '\n'); + debug << "Doing playout " << numPlayouts << ": " << pos.fen() << endl; numPlayouts += 1; const Value v = evaluate_leaf(pos, node, leafDepth); BackpropValues backprops{}; - backprops.numVisits = 1; // playout counts as a visit + backprops.numVisits = 1; // playout counts as a visit backprops.actionValue += value_to_reward(v); backprops.actionValueWeight += normalWeight; - // Bookkeep raw eval. + // Bookkeeping for raw eval node.leafSearchEval = v; node.leafSearchDepth = leafDepth; // Local backprop because normal backprop handles only the - // nodes starting from the parent of this one - node.numVisits = backprops.numVisits; - node.actionValue += backprops.actionValue; + // nodes starting from the parent of this one. + node.numVisits = backprops.numVisits; + node.actionValue += backprops.actionValue; node.actionValueWeight += backprops.actionValueWeight; return backprops; } - // Expand a node and do at least one playout. + + // expand_node_and_do_playout() : expand a node and do at least one playout. // May do more "playouts" if there are terminals as those are "played out" immediately. // Returns what needs to be backpropagated. - BackpropValues expand_node_and_do_playout( - Position& pos, - MctsNode& node, - Depth leafDepth, - float explorationFactor) { - assert(node.posKey == pos.key()); // node must match the position - assert(node.is_leaf()); // otherwise already expanded - assert(node.numChildren == 0); - assert(!node.isTerminal); // terminals cannot be expanded - assert(node.numVisits == 1); // we expect it to have the "playout visit". Fake visit for the root. + BackpropValues expand_node_and_do_playout( Position& pos, + MCTSNode& node, + Depth leafDepth, + float explorationFactor) + { + assert(node.posKey == pos.key()); // node must match the position + assert(node.is_leaf()); // otherwise already expanded + assert(node.numChildren == 0); // leafs have no children + assert(!node.isTerminal); // terminals cannot be expanded + assert(node.numVisits == 1); // we expect it to have the "playout visit". Fake visit for the root. assert(node.leafSearchDepth != DEPTH_NONE); assert(node.leafSearchEval != VALUE_NONE); - debug_print("Expanding and playing out: ", pos.fen(), '\n'); + debug << "Expanding and playing out: " << pos.fen() << endl; Move moves[MAX_MOVES]; const int moveCount = generate_moves_ordered(pos, node, leafDepth, moves); assert(moveCount > 0); - node.children = std::make_unique(moveCount); + node.children = std::make_unique(moveCount); node.numChildren = moveCount; int numTerminals = 0; BackpropValues backprops{}; float prior = value_to_reward(node.leafSearchEval); - // Prior is attenuated for later moves - we rely on move ordering. - // Attenuate more at higher plies where we have more move ordering. + + // Note that prior is attenuated for later moves - we rely on move ordering. + // Attenuate more at higher plies, where we have better move ordering. + const float priorAttenuation = 1.0f - std::min((ply - 1) / 100.0f, 0.05f); - for (int i = 0; i < moveCount; ++i) { - // setup the child - MctsNode& child = node.children[i]; + for (int i = 0; i < moveCount; ++i) + { + // Setup the child + MCTSNode& child = node.children[i]; child.prevMove = moves[i]; child.childId = i; child.parent = &node; + debug << "Expanding move " << i+1 << " out of " << moveCount << ": " << pos.fen() << endl; + + // We enter the child's position + do_move(pos, node, child); + child.posKey = pos.key(); + + const Value terminalValue = terminal_value(pos); + if (terminalValue != VALUE_NONE) { - debug_print("Expanding move ", i+1, " out of ", moveCount, ": ", pos.fen(), '\n'); - - // we enter the child's position - do_move(pos, node, child); - child.posKey = pos.key(); - - const Value terminalValue = terminal_value(pos); - if (terminalValue != VALUE_NONE) { // if it's a terminal then "play it out" child.isTerminal = true; child.prior = value_to_reward(terminalValue); @@ -2845,28 +2972,29 @@ namespace Search numTerminals += 1; numPlayouts += 1; - } - else { - // otherwise we just note the prior (policy) + } + else + { + // Otherwise we just note the prior (policy) child.prior = 1.0f - prior; child.actionValue = child.prior * priorWeight; child.actionValueWeight = priorWeight; - } - - undo_move(pos); } - // accumulate the policies to backprop. + undo_move(pos); + + // Accumulate the policies to backprop backprops.actionValue += child.actionValue; backprops.actionValueWeight += child.actionValueWeight; - // Reduce the prior for the next move. + // Reduce the prior for the next move prior *= priorAttenuation; } - if (numTerminals == 0) { - // If no terminals then we do one playout on the best child. - MctsNode& bestChild = node.get_best_child(explorationFactor); + if (numTerminals == 0) + { + // If no terminals then we do one playout on the best child + MCTSNode& bestChild = node.get_best_child(explorationFactor); do_move(pos, node, bestChild); backprops.numVisits += 1; @@ -2878,7 +3006,8 @@ namespace Search undo_move(pos); } - else { + else + { // If there are any terminals we don't do more playouts backprops.numVisits += numTerminals; } @@ -2894,8 +3023,11 @@ namespace Search return backprops; } - // Does a move and updates the stack. - void do_move(Position& pos, MctsNode& parentNode, MctsNode& childNode) { + + // do_move() does a move and updates the stack + + void do_move(Position& pos, MCTSNode& parentNode, MCTSNode& childNode) { + assert(ply < MAX_PLY); assert(!parentNode.is_leaf()); assert(&parentNode.children[childNode.childId] == &childNode); @@ -2923,11 +3055,14 @@ namespace Search assert(childNode.posKey == 0 || childNode.posKey == pos.key()); ply += 1; + if (ply > maximumPly) maximumPly = ply; } - // Undoes a move, pops the stack. + + // undo_move() undoes a move and pops the stack + void undo_move(Position& pos) { assert(ply > 1); @@ -2936,18 +3071,22 @@ namespace Search pos.undo_move(stack[ply].currentMove); } + + // create_new_root() inits a root from the given position + void create_new_root(Position& pos) { - rootNode = MctsNode{}; + rootNode = MCTSNode{}; rootNode.posKey = pos.key(); rootNode.isTerminal = MoveList(pos).size() == 0; } + void init_for_mcts_search(Position& pos) { std::memset(stack - 7, 0, 10 * sizeof(Stack)); auto th = pos.this_thread(); - // stack + 0 also needs to be initialized because we start from ply=1 + // stack + 0 also needs to be initialized because we start from ply = 1 for (int i = 7; i >= 0; --i) (stack - i)->continuationHistory = &th->continuationHistory[0][0][NO_PIECE][0]; // Use as a sentinel @@ -2959,17 +3098,15 @@ namespace Search // In analysis mode, adjust contempt in accordance with user preference if (Limits.infinite || Options["UCI_AnalyseMode"]) - ct = Options["Analysis Contempt"] == "Off" ? 0 - : Options["Analysis Contempt"] == "Both" ? ct - : Options["Analysis Contempt"] == "White" && us == BLACK ? -ct - : Options["Analysis Contempt"] == "Black" && us == WHITE ? -ct - : ct; + ct = Options["Analysis Contempt"] == "Off" ? 0 + : Options["Analysis Contempt"] == "Both" ? ct + : Options["Analysis Contempt"] == "White" && us == BLACK ? -ct + : Options["Analysis Contempt"] == "Black" && us == WHITE ? -ct + : ct; // Evaluation score is from the white point of view - th->contempt = - us == WHITE - ? make_score(ct, ct / 2) - : -make_score(ct, ct / 2); + th->contempt = (us == WHITE ? make_score(ct, ct / 2) + : -make_score(ct, ct / 2)); create_new_root(pos); @@ -2979,24 +3116,29 @@ namespace Search } }; - ValueAndPV search_mcts( - Position& pos, - std::uint64_t numPlayouts, - Depth leafDepth, - float explorationFactor) { + // search_mcts() : this is the main function of the MonteCarloTreeSearch class + + ValueAndPV search_mcts( Position& pos, + uint64_t numPlayouts, + Depth leafDepth, + float explorationFactor) + { MonteCarloTreeSearch mcts{}; return mcts.search_new(pos, numPlayouts, leafDepth, explorationFactor); } - std::vector search_mcts_multipv( - Position& pos, - std::uint64_t numPlayouts, - Depth leafDepth, - float explorationFactor) { + // search_mcts_multipv() : use this for multiPV + + std::vector search_mcts_multipv( Position& pos, + uint64_t numPlayouts, + Depth leafDepth, + float explorationFactor) + { MonteCarloTreeSearch mcts{}; mcts.search_new(pos, numPlayouts, leafDepth, explorationFactor); + return mcts.get_all_continuations(); } }