From a4605860c69513b7f0d84dfbda7eda8ecc8b121b Mon Sep 17 00:00:00 2001 From: Tomasz Sobczyk Date: Mon, 24 May 2021 11:45:21 +0200 Subject: [PATCH] Post-merge fixes. --- src/evaluate.cpp | 20 ++++++++++---------- src/search.cpp | 2 +- src/thread.h | 16 +++++++++++++++- src/tools/training_data_generator.cpp | 4 ++-- src/tools/training_data_generator_nonpv.cpp | 2 +- src/tools/transform.cpp | 2 +- src/uci.cpp | 2 +- 7 files changed, 31 insertions(+), 17 deletions(-) diff --git a/src/evaluate.cpp b/src/evaluate.cpp index b58ff624..ccb7436b 100644 --- a/src/evaluate.cpp +++ b/src/evaluate.cpp @@ -63,18 +63,18 @@ namespace Eval { namespace NNUE { string eval_file_loaded = "None"; UseNNUEMode useNNUE; - } - static UseNNUEMode NNUE::nnue_mode_from_option(const UCI::Option& mode) - { - if (mode == "false") + static UseNNUEMode nnue_mode_from_option(const UCI::Option& mode) + { + if (mode == "false") + return UseNNUEMode::False; + else if (mode == "true") + return UseNNUEMode::True; + else if (mode == "pure") + return UseNNUEMode::Pure; + return UseNNUEMode::False; - else if (mode == "true") - return UseNNUEMode::True; - else if (mode == "pure") - return UseNNUEMode::Pure; - - return UseNNUEMode::False; + } } /// NNUE::init() tries to load a NNUE network at startup time, or when the engine diff --git a/src/search.cpp b/src/search.cpp index 73c7f856..f0289b45 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -374,7 +374,7 @@ void Thread::search() { // Start with a small aspiration window and, in the case of a fail // high/low, re-search with a bigger window until we don't fail // high/low anymore. - int failedHighCnt = 0; + failedHighCnt = 0; while (true) { Depth adjustedDepth = std::max(1, rootDepth - failedHighCnt - searchAgainCounter); diff --git a/src/thread.h b/src/thread.h index fec68e05..0989f4ba 100644 --- a/src/thread.h +++ b/src/thread.h @@ -76,6 +76,15 @@ public: void wait_for_search_finished(); size_t id() const { return idx; } + void wait_for_worker_finished(); + + template + void set_eval_callback(FuncT&& f) { on_eval_callback = std::forward(f); } + + void clear_eval_callback() { on_eval_callback = nullptr; } + + void on_eval() { if (on_eval_callback) on_eval_callback(rootPos); } + Pawns::Table pawnsTable; Material::Table materialTable; size_t pvIdx, pvLast; @@ -94,6 +103,11 @@ public: CapturePieceToHistory captureHistory; ContinuationHistory continuationHistory[2][2]; Score contempt; + int failedHighCnt; + bool rootInTB; + int Cardinality; + bool UseRule50; + Depth ProbeDepth; }; @@ -166,7 +180,7 @@ struct ThreadPool : public std::vector { execute_with_workers( [chunk_size, end, func](Thread& th) mutable { - const IndexT thread_id = th.thread_idx(); + const IndexT thread_id = th.id(); const IndexT offset = chunk_size * thread_id; if (offset >= end) return; diff --git a/src/tools/training_data_generator.cpp b/src/tools/training_data_generator.cpp index 24498917..6495d566 100644 --- a/src/tools/training_data_generator.cpp +++ b/src/tools/training_data_generator.cpp @@ -257,7 +257,7 @@ namespace Stockfish::Tools StateInfo si; - auto& prng = prngs[th.thread_idx()]; + auto& prng = prngs[th.id()]; // end flag bool quit = false; @@ -693,7 +693,7 @@ namespace Stockfish::Tools maybe_report(iter + 1); // Write out one sfen. - sfen_writer.write(th.thread_idx(), sfen); + sfen_writer.write(th.id(), sfen); } return false; diff --git a/src/tools/training_data_generator_nonpv.cpp b/src/tools/training_data_generator_nonpv.cpp index e8df9c50..278259c6 100644 --- a/src/tools/training_data_generator_nonpv.cpp +++ b/src/tools/training_data_generator_nonpv.cpp @@ -341,7 +341,7 @@ namespace Stockfish::Tools maybe_report(iter + 1); // Write out one sfen. - sfen_writer.write(th.thread_idx(), sfen); + sfen_writer.write(th.id(), sfen); } return false; diff --git a/src/tools/transform.cpp b/src/tools/transform.cpp index ab7a3db8..f657b410 100644 --- a/src/tools/transform.cpp +++ b/src/tools/transform.cpp @@ -426,7 +426,7 @@ namespace Stockfish::Tools ps.move = search_pv[0]; ps.padding = 0; - out.write(th.thread_idx(), ps); + out.write(th.id(), ps); auto p = num_processed.fetch_add(1) + 1; if (p % 10000 == 0) diff --git a/src/uci.cpp b/src/uci.cpp index b1b39bc4..2fa7a186 100644 --- a/src/uci.cpp +++ b/src/uci.cpp @@ -342,7 +342,7 @@ void UCI::loop(int argc, char* argv[]) { else if (token == "tasktest") { Threads.execute_with_workers([](auto& th) { - std::cout << th.thread_idx() << '\n'; + std::cout << th.id() << '\n'; }); } else if (!token.empty() && token[0] != '#')