Transform search output to engine callbacks

Part 2 of the Split UCI into UCIEngine and Engine refactor.
This creates function callbacks for search to use when an update should occur.
The benching in uci.cpp for example does this to extract the total nodes
searched.

No functional change
This commit is contained in:
Disservin
2024-03-23 10:22:20 +01:00
parent 299707d2c2
commit 9032c6cbe7
12 changed files with 372 additions and 104 deletions

View File

@@ -19,15 +19,14 @@
#include "uci.h"
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <deque>
#include <memory>
#include <optional>
#include <sstream>
#include <string_view>
#include <utility>
#include <vector>
@@ -35,10 +34,8 @@
#include "engine.h"
#include "evaluate.h"
#include "movegen.h"
#include "nnue/network.h"
#include "nnue/nnue_common.h"
#include "perft.h"
#include "position.h"
#include "score.h"
#include "search.h"
#include "syzygy/tbprobe.h"
#include "types.h"
@@ -49,6 +46,13 @@ namespace Stockfish {
constexpr auto StartFEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1";
constexpr int MaxHashMB = Is64Bit ? 33554432 : 2048;
template<typename... Ts>
struct overload: Ts... {
using Ts::operator()...;
};
template<typename... Ts>
overload(Ts...) -> overload<Ts...>;
UCIEngine::UCIEngine(int argc, char** argv) :
engine(argv[0]),
@@ -81,6 +85,12 @@ UCIEngine::UCIEngine(int argc, char** argv) :
options["EvalFileSmall"] << Option(EvalFileDefaultNameSmall,
[this](const Option& o) { engine.load_small_network(o); });
engine.set_on_iter([](const auto& i) { on_iter(i); });
engine.set_on_update_no_moves([](const auto& i) { on_update_no_moves(i); });
engine.set_on_update_full([&](const auto& i) { on_update_full(i, options["UCI_ShowWDL"]); });
engine.set_on_bestmove([](const auto& bm, const auto& p) { on_bestmove(bm, p); });
engine.load_networks();
engine.resize_threads();
engine.search_clear(); // After threads are up
@@ -221,6 +231,13 @@ void UCIEngine::go(Position& pos, std::istringstream& is) {
void UCIEngine::bench(Position& pos, std::istream& args) {
std::string token;
uint64_t num, nodes = 0, cnt = 1;
uint64_t nodesSearched = 0;
const auto& options = engine.get_options();
engine.set_on_update_full([&](const auto& i) {
nodesSearched = i.nodes;
on_update_full(i, options["UCI_ShowWDL"]);
});
std::vector<std::string> list = setup_bench(pos, args);
@@ -242,7 +259,8 @@ void UCIEngine::bench(Position& pos, std::istream& args) {
{
go(pos, is);
engine.wait_for_search_finished();
nodes += engine.nodes_searched();
nodes += nodesSearched;
nodesSearched = 0;
}
else
engine.trace_eval();
@@ -265,6 +283,9 @@ void UCIEngine::bench(Position& pos, std::istream& args) {
std::cerr << "\n==========================="
<< "\nTotal time (ms) : " << elapsed << "\nNodes searched : " << nodes
<< "\nNodes/second : " << 1000 * nodes / elapsed << std::endl;
// reset callback, to not capture a dangling reference to nodesSearched
engine.set_on_update_full([&](const auto& i) { on_update_full(i, options["UCI_ShowWDL"]); });
}
@@ -335,22 +356,22 @@ int win_rate_model(Value v, const Position& pos) {
}
}
std::string UCIEngine::to_score(Value v, const Position& pos) {
assert(-VALUE_INFINITE < v && v < VALUE_INFINITE);
std::string UCIEngine::format_score(const Score& s) {
constexpr int TB_CP = 20000;
const auto format =
overload{[](Score::Mate mate) -> std::string {
auto m = (mate.plies > 0 ? (mate.plies + 1) : -mate.plies) / 2;
return std::string("mate ") + std::to_string(m);
},
[](Score::TBWin tb) -> std::string {
return std::string("cp ")
+ std::to_string((tb.plies > 0 ? TB_CP - tb.plies : -TB_CP + tb.plies));
},
[](Score::InternalUnits units) -> std::string {
return std::string("cp ") + std::to_string(units.value);
}};
std::stringstream ss;
if (std::abs(v) < VALUE_TB_WIN_IN_MAX_PLY)
ss << "cp " << to_cp(v, pos);
else if (std::abs(v) <= VALUE_TB)
{
const int ply = VALUE_TB - std::abs(v); // recompute ss->ply
ss << "cp " << (v > 0 ? 20000 - ply : -20000 + ply);
}
else
ss << "mate " << (v > 0 ? VALUE_MATE - v + 1 : -VALUE_MATE - v) / 2;
return ss.str();
return s.visit(format);
}
// Turns a Value to an integer centipawn number,
@@ -414,4 +435,51 @@ Move UCIEngine::to_move(const Position& pos, std::string str) {
return Move::none();
}
void UCIEngine::on_update_no_moves(const Engine::InfoShort& info) {
sync_cout << "info depth" << info.depth << " score " << format_score(info.score) << sync_endl;
}
void UCIEngine::on_update_full(const Engine::InfoFull& info, bool showWDL) {
std::stringstream ss;
ss << "info";
ss << " depth " << info.depth //
<< " seldepth " << info.selDepth //
<< " multipv " << info.multiPV //
<< " score " << format_score(info.score); //
if (showWDL)
ss << " wdl " << info.wdl;
if (!info.bound.empty())
ss << " " << info.bound;
ss << " nodes " << info.nodes //
<< " nps " << info.nps //
<< " hashfull " << info.hashfull //
<< " tbhits " << info.tbHits //
<< " time " << info.timeMs //
<< " pv " << info.pv; //
sync_cout << ss.str() << sync_endl;
}
void UCIEngine::on_iter(const Engine::InfoIter& info) {
std::stringstream ss;
ss << "info";
ss << " depth " << info.depth //
<< " currmove " << info.currmove //
<< " currmovenumber " << info.currmovenumber; //
sync_cout << ss.str() << sync_endl;
}
void UCIEngine::on_bestmove(std::string_view bestmove, std::string_view ponder) {
sync_cout << "bestmove " << bestmove;
if (!ponder.empty())
std::cout << " ponder " << ponder;
std::cout << sync_endl;
}
} // namespace Stockfish