diff --git a/script/README.md b/script/README.md index feb57ca2..f9c6bc7b 100644 --- a/script/README.md +++ b/script/README.md @@ -12,7 +12,7 @@ You can download data from [here](http://rebel13.nl/index.html) ## Convert pgn files -**Important : convert text will be superheavy (approx 200 byte / position)** +**Important : convert text will be superheavy (approx 200 byte / position)** python pgn_to_plain.py --pgn "pgn/*.pgn" --start_ply 1 --output converted_pgn.txt diff --git a/script/pgn_to_plain.py b/script/pgn_to_plain.py index c551c136..86671261 100644 --- a/script/pgn_to_plain.py +++ b/script/pgn_to_plain.py @@ -61,7 +61,7 @@ def parse_comment_for_score(comment_str: str, board: chess.Board) -> int: score = 0 return score - + def parse_game(game: chess.pgn.Game, writer, start_play: int=1)->None: board: chess.Board = game.board() if not game_sanity_check(game): @@ -105,6 +105,6 @@ def main(): break parse_game(game, f, args.start_ply) f.close() - + if __name__=="__main__": main() diff --git a/src/misc.h b/src/misc.h index 8b332efb..99b8c3bb 100644 --- a/src/misc.h +++ b/src/misc.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -52,6 +53,20 @@ void dbg_hit_on(bool c, bool b); void dbg_mean_of(int v); void dbg_print(); +/// Debug macro to write to std::err if NDEBUG flag is set, and do nothing otherwise +#if defined(NDEBUG) +#define debug 1 && std::cerr +#else +#define debug 0 && std::cerr +#endif + +inline void hit_any_key() { +#ifndef NDEBUG + debug << "Hit any key to continue..." << std::endl << std::flush; + system("read"); // on Windows, should be system("pause"); +#endif +} + typedef std::chrono::milliseconds::rep TimePoint; // A value in milliseconds static_assert(sizeof(TimePoint) == sizeof(int64_t), "TimePoint should be 64 bits"); inline TimePoint now() { diff --git a/src/search.cpp b/src/search.cpp index f39d9d18..bcf85ec7 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -47,6 +47,7 @@ namespace Search { using std::string; using Eval::evaluate; using namespace Search; +using namespace std; bool Search::prune_at_shallow_depth = true; @@ -1965,7 +1966,10 @@ namespace Search // Zero initialization of the number of search nodes th->nodes = 0; - // Clear all history types. This initialization takes a little time, and the accuracy of the search is rather low, so the good and bad are not well understood. + // Clear all history types. This initialization takes a little time, and + // the accuracy of the search is rather low, so the good and bad are + // not well understood. + // th->clear(); // Evaluation score is from the white point of view @@ -2183,6 +2187,944 @@ namespace Search return ValueAndPV(bestValue, pvs); } + + + // This implementation of the MCTS is heavily based on Stephane Nicolet's work here + // https://github.com/snicolet/Stockfish/commit/28501872a1e7ce84dd1f38ab9e59c5adb0d24b41 + // and the adjusted implementation of it in ShashChess https://github.com/amchess/ShashChess + + namespace MCTS + { + static constexpr float sigmoidScale = 600.0f; + + static inline float fast_sigmoid(float x) { + bool negative = x < 0.0f; + if (negative) + x = -x; + const float xx = x*x; + const float v = 1.0f / (1.0f + 1.0f / (1.0f + x + xx*(0.555f + xx*0.143f))); + if (negative) + return 1.0f - v; + else + return v; + } + + static inline Value reward_to_value(float r) { + if (r > 0.99f) return VALUE_KNOWN_WIN; + if (r < 0.01f) return -VALUE_KNOWN_WIN; + + return Value(-sigmoidScale * std::log(1.0f/r - 1.0f)); + } + + static inline float value_to_reward(Value v) { + return fast_sigmoid(static_cast(v) * (1.0f / sigmoidScale)); + } + + // struct MCTSNode : store info at one node of the MCTS algorithm + + struct MCTSNode { + + Key posKey = 0; // for consistency checks + MCTSNode* parent = nullptr; // only nullptr for the root node + unique_ptr children = nullptr; // only nullptr for nodes that have not been expanded + uint64_t numVisits = 0; // the number of playouts for this node and all descendants + Value leafSearchEval = VALUE_NONE; // the evaluation from AB playout + float prior = 0.0f; // the policy, currently a rough estimation based on the playout of the parent + float actionValue = 0.0f; // the accumulated rewards + float actionValueWeight = 0.0f; // the maximum value for the accumulater rewards + Move prevMove = MOVE_NONE; // the move on the edge from the parent + int numChildren = 0; // the number of legal moves, filled on expansion + int childId = 0; // the index of this node in the parent's children array + Depth leafSearchDepth = DEPTH_NONE; // the depth with which the AB playout was done + bool isTerminal = false; // whether the node is terminal. Terminal nodes are always "expanded" immediately. + + + // ucb_value() calculates the upper confidence bound of a child. + // When searching for the node to expand/playout we take one with the highest ucb. + + float ucb_value(MCTSNode& child, float explorationFactor, bool flipPerspective = false) const + { + + assert(explorationFactor >= 0.0f); + assert(child.actionValue >= 0.0f); + assert(child.actionValueWeight >= 0.0f); + assert(child.actionValue <= child.actionValueWeight); + assert(child.prior >= 0.0f); + assert(child.prior <= 1.0f); + + // For the nodes which have not been played-out we use the prior. + // Otherwise we have some averaged score or the eval already. + float reward = child.numVisits == 0 ? child.prior + : child.actionValue / child.actionValueWeight; + + if (flipPerspective) + reward = 1.0f - reward; + + // The exploration factor. + // In theory unplayed nodes should have priority, but we + // add 1 to avoid div by 0 so they might not always be prioritized. + + if (explorationFactor != 0.0f) + reward += + explorationFactor + * std::sqrt(std::log(1.0 + numVisits) / (1.0 + child.numVisits)); + + assert(!std::isnan(reward)); + assert(reward >= 0.0f); + + return reward; + } + + + // get_best_child() returns a const reference to the best child node, + // according to the UCB value. + + const MCTSNode& get_best_child(float explorationFactor) const + { + + assert(!is_leaf()); + assert(numChildren > 0); + + if (numChildren == 1) + { + assert(children[0].childId == 0); + return children[0]; + } + + int bestIdx = -1; + float bestValue = std::numeric_limits::lowest(); + for (int i = 0 ; i < numChildren ; ++i) + { + MCTSNode& child = children[i]; + // The "best" is the one with the best UCB. + // Child values are with opposite signs. + const float r = ucb_value(child, explorationFactor, true); + if (r > bestValue) + { + bestIdx = i; + bestValue = r; + } + } + + assert(bestIdx >= 0); + assert(bestIdx < numChildren); + assert(children[bestIdx].childId == bestIdx); + + return children[bestIdx]; + } + + + // get_best_child() : like the previous one, but does not return a const reference + + MCTSNode& get_best_child(float explorationFactor) { + return const_cast(static_cast(this)->get_best_child(explorationFactor)); + } + + + // get_best_move() returns a pair (move,value) leading to the best child, + // according to the action value heuristic. + + std::pair get_best_move() const { + + assert(!is_leaf()); + assert(numChildren > 0); + + int bestIdx = -1; + float bestValue = std::numeric_limits::lowest(); + for (int i = 0; i < numChildren ; ++i) + { + MCTSNode& child = children[i]; + // The "best" is the one with the best action value. + // Child values are with opposite signs. + const float r = 1.0f - (child.actionValue / child.actionValueWeight); + if (r > bestValue) + { + bestIdx = i; + bestValue = r; + } + } + + assert(bestIdx >= 0); + assert(bestIdx < numChildren); + assert(children[bestIdx].childId == bestIdx); + + return { children[bestIdx].prevMove, reward_to_value(bestValue) }; + } + + + // get_child_by_move() finds a child, given the move that leads to it + + const MCTSNode* get_child_by_move(Move move) const { + for (int i = 0; i < numChildren ; ++i) + { + MCTSNode& child = children[i]; + if (child.prevMove == move) + return &child; + } + + return nullptr; + } + + + // get_child_by_move() : like the previous one, but does not return a const + + MCTSNode* get_child_by_move(Move move) { + return const_cast(static_cast(this)->get_child_by_move(move)); + } + + + // is_root() returns true when node is the root + + bool is_root() const { + return parent == nullptr; + } + + // is_leaf() returns true when node is a leaf + + bool is_leaf() const { + return children == nullptr; + } + }; + + + // struct BackpropValues is a structure to manipulate the kind of stuff + // that needs to be back-propagated down and up the tree by MCTS. + + struct BackpropValues { + + uint64_t numVisits = 0; + float actionValue = 0.0f; + float actionValueWeight = 0.0f; + + // We always keep everything for the side to move perspective. + // When changing the side the score flips. + void flip_side() { + assert(actionValueWeight >= actionValue); + assert(actionValue >= 0.0f); + assert(actionValueWeight >= 0.0f); + + actionValue = actionValueWeight - actionValue; + } + }; + + + // struct MonteCarloTreeSearch implements the methods for the MCTS algorithm + + struct MonteCarloTreeSearch { + + // IMPORTANT: + // The position is stateful so we always have one. + // It has to match certain expectations in different functions. + // For example when looking for the node to expand the pos must correspond + // to the root mcts node. When expanding the node it must correspond to the + // node being expanded, etc. + + static constexpr Depth terminalEvalDepth = Depth(255); + + // We add a lot of stuff to the actionValue, but the weights differ. + // The prior is currently bad so low weight, + static constexpr float priorWeight = 0.01f; + static constexpr float terminalWeight = 1.0f; // could be increased? Different for wins/draws? + static constexpr float normalWeight = 1.0f; + + static_assert(priorWeight > 0.0f); + static_assert(terminalWeight > 0.0f); + static_assert(normalWeight > 0.0f); + + MonteCarloTreeSearch() {} + MonteCarloTreeSearch(const MonteCarloTreeSearch&) = delete; + + + // search_new() : let's start the search ! + + ValueAndPV search_new( + Position& pos, + std::uint64_t maxPlayouts, + Depth leafDepth, + float explorationFactor = 0.25f) { + + init_for_mcts_search(pos); + return search_continue(pos, maxPlayouts, leafDepth, explorationFactor); + } + + + // search_continue_after_move() : continue after a move and reuse the relevant + // part of the tree. The prevMove is the move that lead to position 'pos'. + // + // TODO: make the node limit be the total. + + ValueAndPV search_continue_after_move( Position& pos, + Move prevMove, + std::uint64_t maxPlayouts, + Depth leafDepth, + float explorationFactor = 0.25f) { + do_move_at_root(pos, prevMove); + return search_continue(pos, maxPlayouts, leafDepth, explorationFactor); + } + + + // get_all_continuations() is missing description + + std::vector get_all_continuations() const { + + std::vector continuations; + continuations.resize(rootNode.numChildren); + + for (int i = 0; i < rootNode.numChildren; ++i) + { + MCTSNode& child = rootNode.children[i]; + + auto& cont = continuations[i]; + + cont.numVisits = child.numVisits; + cont.value = reward_to_value(cont.actionValue); + cont.pv = get_pv(child); + cont.actionValue = 1.0f - (child.actionValue / child.actionValueWeight); // child value is with opposite sign + } + + std::stable_sort( continuations.begin(), + continuations.end(), + [](const auto& lhs, const auto& rhs) { return lhs.value > rhs.value; } + ); + + return continuations; + } + + + // search_continue() : continues with the same tree + + ValueAndPV search_continue( Position& pos, + std::uint64_t maxPlayouts, + Depth leafDepth, + float explorationFactor = 0.25f) { + + if (rootNode.leafSearchDepth == DEPTH_NONE) + do_playout(pos, rootNode, leafDepth); + + while (numPlayouts < maxPlayouts) + { + debug << "Starting iteration " << numPlayouts << endl; + do_search_iteration(pos, leafDepth, explorationFactor); + } + + if (rootNode.is_leaf()) + return {}; + else + return { rootNode.get_best_move().second, get_pv() }; + } + + + Stack stackBuffer [MAX_PLY + 10]; + StateInfo statesBuffer[MAX_PLY + 10]; + + Stack* stack = stackBuffer + 7; + StateInfo* states = statesBuffer + 7; + + MCTSNode rootNode; + + int ply = 1; + int maximumPly = ply; // Effectively the selective depth. + std::uint64_t numPlayouts = 0; + + private : + + // reset_stats(), recalculate_stats() and accumulate_stats_recursively() + // are used to recalculate the number of playouts in our MCTS tree. Note + // that at the moment we call recalculate_stats() each time we play a move + // at root, to recalculate the stats in the subtree. + + void reset_stats() { + numPlayouts = 0; + } + + void accumulate_stats_recursively(MCTSNode& node) { + + if (node.leafSearchDepth != DEPTH_NONE) + numPlayouts += 1; + + if (!node.is_leaf()) + for (int i = 0; i < node.numChildren; ++i) + accumulate_stats_recursively(node.children[i]); + } + + void recalculate_stats() { + reset_stats(); + accumulate_stats_recursively(rootNode); + } + + + // do_move_at_root() is missing description + // Tree reuse (?) + // pos is the position after move. + + void do_move_at_root(Position& pos, Move move) { + + MCTSNode* child = rootNode.get_child_by_move(move); + if (child == nullptr) + create_new_root(pos); + else + { + rootNode = std::move(*child); + rootNode.parent = nullptr; + rootNode.childId = 0; + // keep rootNode.prevMove for move ordering heuristics + } + + recalculate_stats(); + + assert(rootNode.posKey == pos.key()); + } + + + // do_search_iteration() does one iteration of the search. + // + // Basically: + // 1. find a node to expand/playout + // 2. if the node is a terminal then we just get the stuff and backprop + // 3. if we only have prior for the node then do a playout + // 4. otherwise we expand the children and do at least one playout from the best child (chosen by prior) + // 4.1. a terminal node counts as a playout. All terminal nodes are played out. + // 5. Backpropagate all changes down the tree. + + void do_search_iteration(Position& pos, Depth leafDepth, float explorationFactor) { + + MCTSNode& node = find_node_to_expand_or_playout(pos, explorationFactor); + BackpropValues backprops{}; + if (node.isTerminal) + { + debug << "Root is terminal" << endl; + backprops.numVisits = 1; + backprops.actionValue += node.actionValue; + backprops.actionValueWeight += node.actionValueWeight; + + numPlayouts += 1; + } + else if (node.leafSearchDepth == DEPTH_NONE) + { + // The node is considered the best but it only has a prior value. + // We don't really want to expand nodes based just on the prior, so + // first do a playout to get a better estimate, and expand only in the + // next iteration. + // Normally we playout immediately the move with the best prior, but that + // playout can put it below another move. + backprops = do_playout(pos, node, leafDepth); + } + else + { + // We have done leaf evaluation with AB search so we know that + // this node is *actually good* and not just *prior good*, so we + // can now expand it and do an immediate playout for the node with the best prior. + backprops = expand_node_and_do_playout(pos, node, leafDepth, explorationFactor); + } + + backpropagate(pos, node, backprops); + } + + + // Backpropagates() is the function we use to back-propagate the changes + // after an expand/playout, all the way to the root. The position 'pos' + // is expected to be at the node from which we start backpropagating. + + void backpropagate(Position& pos, MCTSNode& node, BackpropValues backprops) { + + assert(node.posKey == pos.key()); + assert(ply >= 1); + + debug << "Backpropagating: " << pos.fen() << endl; + + MCTSNode* currentNode = &node; + while (!currentNode->is_root()) + { + // On each descent we switch the side to move + + undo_move(pos); + currentNode = currentNode->parent; + backprops.flip_side(); + + debug << "Backprop step: " << pos.fen() << endl; + assert(currentNode->posKey == pos.key()); + + currentNode->numVisits += backprops.numVisits; + currentNode->actionValue += backprops.actionValue; + currentNode->actionValueWeight += backprops.actionValueWeight; + } + + // At the end we must be at the root + + assert(currentNode == &rootNode); + assert(rootNode.posKey == pos.key()); + } + + + // find_node_to_expand_or_playout() navigates from pos to the node to expand/playout, + // according to the get_best_child() heuristics. + + MCTSNode& find_node_to_expand_or_playout(Position& pos, float explorationFactor) { + assert(rootNode.posKey == pos.key()); + + // Find a node that has not yet been expanded + MCTSNode* currentNode = &rootNode; + while (!currentNode->is_leaf()) + { + MCTSNode& bestChild = currentNode->get_best_child(explorationFactor); + + do_move(pos, *currentNode, bestChild); + + currentNode = &bestChild; + } + + return *currentNode; + } + + + // generate_moves_unordered() generates moves in a random order + + int generate_moves_unordered(Position& pos, Move* out) const { + int moveCount = 0; + for (auto move : MoveList(pos)) + out[moveCount++] = move; + + return moveCount; + } + + + // generate_moves_ordered() generates moves with some reasonable ordering. + // Using this function, we can assume some reasonable priors. + + int generate_moves_ordered(Position& pos, MCTSNode& node, Depth leafDepth, Move* out) const { + assert(ply >= 1); + + debug << "Generating moves: " << pos.fen() << endl; + + Thread* const thread = pos.this_thread(); + const Square prevSq = to_sq(node.prevMove); + const Move countermove = thread->counterMoves[pos.piece_on(prevSq)][prevSq]; + const Move ttMove = MOVE_NONE; // TODO: retrieve tt move + const Move* const killers = stack[ply].killers; + const Depth depth = leafDepth + 1; + + const PieceToHistory* contHist[] = { + stack[ply-1].continuationHistory, stack[ply-2].continuationHistory, + nullptr , stack[ply-4].continuationHistory, + nullptr , stack[ply-6].continuationHistory + }; + + assert(contHist[0] != nullptr); + assert(contHist[1] != nullptr); + assert(contHist[3] != nullptr); + assert(contHist[5] != nullptr); + + MovePicker mp( + pos, + ttMove, + depth, + &(thread->mainHistory), + &(thread->lowPlyHistory), + &(thread->captureHistory), + contHist, + countermove, + killers, + ply + ); + + int moveCount = 0; + while (true) + { + const Move move = mp.next_move(); + debug << "Generated move " << UCI::move(move, false) << ": " << pos.fen() << endl; + + if (move == MOVE_NONE) + break; + + if (pos.legal(move)) + out[moveCount++] = move; + } + + debug << "Generated " << moveCount << " legal moves: " << pos.fen() << endl; + + return moveCount; + } + + + // init_for_leaf_search() prepares some global variables in the thread of the + // given position, for compatibility with the normal AB search of Stockfish. + // This allows us to use that AB search to get an estimated value of the leaf, + // if necessary. + + void init_for_leaf_search(Position& pos) { + + auto th = pos.this_thread(); + + th->completedDepth = 0; + th->selDepth = 0; + th->rootDepth = 0; + th->nmpMinPly = th->bestMoveChanges = th->failedHighCnt = 0; + th->ttHitAverage = TtHitAverageWindow * TtHitAverageResolution / 2; + th->nodes = 0; + } + + + // terminal_value() checks whether the position is terminal. We return + // the right value if position is terminal, otherwise we return VALUE_NONE. + + Value terminal_value(Position& pos) const { + + if (MoveList(pos).size() == 0) + return pos.checkers() ? VALUE_MATE : -VALUE_MATE;; + + if (ply >= MAX_PLY - 2 || pos.is_draw(ply - 1)) + return VALUE_DRAW; + + return VALUE_NONE; + } + + + // evaluate_leaf() does AB search on the position to get its value + + Value evaluate_leaf(Position& pos, MCTSNode& node, Depth leafDepth) { + + assert(node.posKey == pos.key()); + assert(node.leafSearchDepth == DEPTH_NONE); + assert(node.leafSearchEval == VALUE_NONE); + + debug << "Evaluating leaf: " << pos.fen() << endl; + + init_for_leaf_search(pos); + + Move pv[MAX_PLY + 1]; + stack[ply].pv = pv; + stack[ply].currentMove = MOVE_NONE; + stack[ply].excludedMove = MOVE_NONE; + + if (!node.is_root() && node.parent->leafSearchEval != VALUE_NONE) + { + // If we have some parent score then use an aspiration window. + // We know what to expect. + Value delta = Value(18); + Value alpha = std::max(node.parent->leafSearchEval - delta, -VALUE_INFINITE); + Value beta = std::min(node.parent->leafSearchEval + delta, VALUE_INFINITE); + while (true) + { + const Value value = Stockfish::search(pos, stack + ply, alpha, beta, leafDepth, false); + if (value <= alpha) + { + beta = (alpha + beta) / 2; + alpha = std::max(value - delta, -VALUE_INFINITE); + } + else + if (value >= beta) + beta = std::min(value + delta, VALUE_INFINITE); + else + return value; + + delta += delta / 4 + 5; + } + } + + else + // If no parent score then do infinite aspiration window. + return Stockfish::search(pos, stack + ply, -VALUE_INFINITE, VALUE_INFINITE, leafDepth, false); + } + + + // get_pv(node) tries to get a pv, starting from the given node + + std::vector get_pv(const MCTSNode& node) const { + std::vector pv; + + const MCTSNode* currentNode = &node; + if (!currentNode->is_root()) + pv.emplace_back(currentNode->prevMove); + + while (!currentNode->is_leaf()) + { + // No exploration factor for choosing the PV. + const MCTSNode& bestChild = currentNode->get_best_child(0.0f); + pv.emplace_back(bestChild.prevMove); + currentNode = &bestChild; + } + + return pv; + } + + + // get_pv() tries to get the pv, starting from the root + + std::vector get_pv() const { + return get_pv(rootNode); + } + + + // do_playout() does a single playout and returns what is needed to backprop + + BackpropValues do_playout(Position& pos, MCTSNode& node, Depth leafDepth) { + + assert(node.posKey == pos.key()); + assert(node.numVisits == 0); + assert(node.is_leaf()); + assert(node.numChildren == 0); + assert(node.leafSearchDepth == DEPTH_NONE); + assert(!node.isTerminal); + + debug << "Doing playout " << numPlayouts << ": " << pos.fen() << endl; + + numPlayouts += 1; + + const Value v = evaluate_leaf(pos, node, leafDepth); + + BackpropValues backprops{}; + backprops.numVisits = 1; // playout counts as a visit + backprops.actionValue += value_to_reward(v); + backprops.actionValueWeight += normalWeight; + + // Bookkeeping for raw eval + node.leafSearchEval = v; + node.leafSearchDepth = leafDepth; + + // Local backprop because normal backprop handles only the + // nodes starting from the parent of this one. + node.numVisits = backprops.numVisits; + node.actionValue += backprops.actionValue; + node.actionValueWeight += backprops.actionValueWeight; + + return backprops; + } + + + // expand_node_and_do_playout() : expand a node and do at least one playout. + // May do more "playouts" if there are terminals as those are "played out" immediately. + // Returns what needs to be backpropagated. + + BackpropValues expand_node_and_do_playout( Position& pos, + MCTSNode& node, + Depth leafDepth, + float explorationFactor) + { + assert(node.posKey == pos.key()); // node must match the position + assert(node.is_leaf()); // otherwise already expanded + assert(node.numChildren == 0); // leafs have no children + assert(!node.isTerminal); // terminals cannot be expanded + assert(node.numVisits == 1); // we expect it to have the "playout visit". Fake visit for the root. + assert(node.leafSearchDepth != DEPTH_NONE); + assert(node.leafSearchEval != VALUE_NONE); + + debug << "Expanding and playing out: " << pos.fen() << endl; + + Move moves[MAX_MOVES]; + const int moveCount = generate_moves_ordered(pos, node, leafDepth, moves); + + assert(moveCount > 0); + + node.children = std::make_unique(moveCount); + node.numChildren = moveCount; + + int numTerminals = 0; + BackpropValues backprops{}; + + float prior = value_to_reward(node.leafSearchEval); + + // Note that prior is attenuated for later moves - we rely on move ordering. + // Attenuate more at higher plies, where we have better move ordering. + + const float priorAttenuation = 1.0f - std::min((ply - 1) / 100.0f, 0.05f); + for (int i = 0; i < moveCount; ++i) + { + // Setup the child + MCTSNode& child = node.children[i]; + child.prevMove = moves[i]; + child.childId = i; + child.parent = &node; + + debug << "Expanding move " << i+1 << " out of " << moveCount << ": " << pos.fen() << endl; + + // We enter the child's position + do_move(pos, node, child); + child.posKey = pos.key(); + + const Value terminalValue = terminal_value(pos); + if (terminalValue != VALUE_NONE) + { + // if it's a terminal then "play it out" + child.isTerminal = true; + child.prior = value_to_reward(terminalValue); + child.numVisits = 1; + child.actionValue = child.prior * terminalWeight; + child.actionValueWeight = terminalWeight; + child.leafSearchEval = terminalValue; + child.leafSearchDepth = terminalEvalDepth; + + numTerminals += 1; + numPlayouts += 1; + } + else + { + // Otherwise we just note the prior (policy) + child.prior = 1.0f - prior; + child.actionValue = child.prior * priorWeight; + child.actionValueWeight = priorWeight; + } + + undo_move(pos); + + // Accumulate the policies to backprop + backprops.actionValue += child.actionValue; + backprops.actionValueWeight += child.actionValueWeight; + + // Reduce the prior for the next move + prior *= priorAttenuation; + } + + if (numTerminals == 0) + { + // If no terminals then we do one playout on the best child + MCTSNode& bestChild = node.get_best_child(explorationFactor); + do_move(pos, node, bestChild); + + backprops.numVisits += 1; + + auto playoutBackprops = do_playout(pos, bestChild, leafDepth); + backprops.numVisits += playoutBackprops.numVisits; + backprops.actionValue += playoutBackprops.actionValue; + backprops.actionValueWeight += playoutBackprops.actionValueWeight; + + undo_move(pos); + } + else + { + // If there are any terminals we don't do more playouts + backprops.numVisits += numTerminals; + } + + // Local backprop because normal backprop handles only the + // nodes starting from the parent of this one + backprops.flip_side(); + + node.actionValue = backprops.actionValue; + node.actionValueWeight = backprops.actionValueWeight; + node.numVisits = backprops.numVisits; + + return backprops; + } + + + // do_move() does a move and updates the stack + + void do_move(Position& pos, MCTSNode& parentNode, MCTSNode& childNode) { + + assert(ply < MAX_PLY); + assert(!parentNode.is_leaf()); + assert(&parentNode.children[childNode.childId] == &childNode); + assert(parentNode.posKey == pos.key()); + + const Move move = childNode.prevMove; + + stack[ply].currentMove = move; + stack[ply].inCheck = pos.checkers(); + stack[ply].continuationHistory = + &( + pos.this_thread()->continuationHistory + [stack[ply].inCheck] + [pos.capture_or_promotion(move)] + [pos.moved_piece(move)] + [to_sq(move)] + ); + stack[ply].staticEval = parentNode.leafSearchEval; + stack[ply].moveCount = childNode.childId + 1; + + pos.do_move(move, states[ply]); + + // The first time around we don't have posKey set yet, + // because we need to do the move first. + assert(childNode.posKey == 0 || childNode.posKey == pos.key()); + + ply += 1; + + if (ply > maximumPly) + maximumPly = ply; + } + + + // undo_move() undoes a move and pops the stack + + void undo_move(Position& pos) { + assert(ply > 1); + + ply -= 1; + + pos.undo_move(stack[ply].currentMove); + } + + + // create_new_root() inits a root from the given position + + void create_new_root(Position& pos) { + rootNode = MCTSNode{}; + rootNode.posKey = pos.key(); + rootNode.isTerminal = MoveList(pos).size() == 0; + } + + + void init_for_mcts_search(Position& pos) { + std::memset(stack - 7, 0, 10 * sizeof(Stack)); + + auto th = pos.this_thread(); + + // stack + 0 also needs to be initialized because we start from ply = 1 + for (int i = 7; i >= 0; --i) + (stack - i)->continuationHistory = &th->continuationHistory[0][0][NO_PIECE][0]; // Use as a sentinel + + for (int i = 1; i <= MAX_PLY; ++i) + (stack + i)->ply = i; + + int ct = int(Options["Contempt"]) * PawnValueEg / 100; // From centipawns + Color us = pos.side_to_move(); + + // In analysis mode, adjust contempt in accordance with user preference + if (Limits.infinite || Options["UCI_AnalyseMode"]) + ct = Options["Analysis Contempt"] == "Off" ? 0 + : Options["Analysis Contempt"] == "Both" ? ct + : Options["Analysis Contempt"] == "White" && us == BLACK ? -ct + : Options["Analysis Contempt"] == "Black" && us == WHITE ? -ct + : ct; + + // Evaluation score is from the white point of view + th->contempt = (us == WHITE ? make_score(ct, ct / 2) + : -make_score(ct, ct / 2)); + + create_new_root(pos); + + ply = 1; + maximumPly = ply; + numPlayouts = 0; + } + }; + + + // search_mcts() : this is the main function of the MonteCarloTreeSearch class + + ValueAndPV search_mcts( Position& pos, + uint64_t numPlayouts, + Depth leafDepth, + float explorationFactor) + { + MonteCarloTreeSearch mcts{}; + return mcts.search_new(pos, numPlayouts, leafDepth, explorationFactor); + } + + + // search_mcts_multipv() : use this for multiPV + + std::vector search_mcts_multipv( Position& pos, + uint64_t numPlayouts, + Depth leafDepth, + float explorationFactor) + { + MonteCarloTreeSearch mcts{}; + mcts.search_new(pos, numPlayouts, leafDepth, explorationFactor); + + return mcts.get_all_continuations(); + } + } } } // namespace Stockfish diff --git a/src/search.h b/src/search.h index 73a3bd86..36bcb18b 100644 --- a/src/search.h +++ b/src/search.h @@ -119,6 +119,28 @@ using ValueAndPV = std::pair>; ValueAndPV qsearch(Position& pos); ValueAndPV search(Position& pos, int depth_, size_t multiPV = 1, uint64_t nodesLimit = 0); +namespace MCTS { + + struct MctsContinuation { + std::uint64_t numVisits; + Value value; + float actionValue; + std::vector pv; + }; + + ValueAndPV search_mcts( + Position& pos, + std::uint64_t nodes, + Depth leafDepth, + float explorationFactor); + + std::vector search_mcts_multipv( + Position& pos, + std::uint64_t numPlayouts, + Depth leafDepth, + float explorationFactor); +} + } } // namespace Stockfish diff --git a/src/tools/training_data_generator.cpp b/src/tools/training_data_generator.cpp index a69031c4..0c4f8d82 100644 --- a/src/tools/training_data_generator.cpp +++ b/src/tools/training_data_generator.cpp @@ -767,7 +767,7 @@ namespace Stockfish::Tools else if (token == "min_depth") is >> params.search_depth_min; else if (token == "max_depth") - is >> params.search_depth_min; + is >> params.search_depth_max; else if (token == "nodes") is >> params.nodes; else if (token == "count") diff --git a/src/uci.cpp b/src/uci.cpp index 7d471d1f..b1d385d0 100644 --- a/src/uci.cpp +++ b/src/uci.cpp @@ -269,6 +269,34 @@ void search_cmd(Position& pos, istringstream& is) cout << endl; } +void search_mcts_cmd(Position& pos, istringstream& is) +{ + string token; + int nodes = 1000; + int leafDepth = 3; + float explorationFactor = 0.25f; + while (is >> token) + { + if (token == "nodes") + is >> nodes; + if (token == "leaf_depth") + is >> leafDepth; + if (token == "exploration_factor") + is >> explorationFactor; + } + + cout << "search nodes = " << nodes << " , leaf_depth = " << leafDepth << " :\n"; + auto continuations = Search::MCTS::search_mcts_multipv(pos, nodes, leafDepth, explorationFactor); + for (auto&& [numVisits, value, actionValue, pv] : continuations) + { + cout << "NumVisits = " << numVisits << " , Value = " << UCI::value(value) << " , ActionValue = " << actionValue << " , PV = "; + for (auto m : pv) + cout << UCI::move(m, false) << " "; + cout << endl; + } + cout << endl; +} + /// UCI::loop() waits for a command from stdin, parses it and calls the appropriate /// function. Also intercepts EOF from stdin to ensure gracefully exiting if the /// GUI dies unexpectedly. When called with some command line arguments, e.g. to @@ -344,6 +372,7 @@ void UCI::loop(int argc, char* argv[]) { // Command to call qsearch(),search() directly for testing else if (token == "qsearch") qsearch_cmd(pos); + else if (token == "search_mcts") search_mcts_cmd(pos, is); else if (token == "search") search_cmd(pos, is); else if (token == "tasktest") {