diff --git a/src/eval/nnue/evaluate_nnue.cpp b/src/eval/nnue/evaluate_nnue.cpp index 6b3f0b2f..a573c9c7 100644 --- a/src/eval/nnue/evaluate_nnue.cpp +++ b/src/eval/nnue/evaluate_nnue.cpp @@ -233,31 +233,27 @@ void prefetch_evalhash(const Key key) { // Save and restore Options with bench command etc., so EvalDir is changed at this time, // This function may be called twice to flag that the evaluation function needs to be reloaded. void load_eval() { + + if (Options["SkipLoadingEval"]) + { + std::cout << "info string SkipLoadingEval set to true, Net not loaded!" << std::endl; + return; + } + NNUE::Initialize(); - if (!Options["SkipLoadingEval"]) - { - const std::string dir_name = Options["EvalDir"]; - const std::string file_name = Path::Combine(dir_name, NNUE::kFileName); - //{ - // std::ofstream stream(file_name, std::ios::binary); - // NNUE::WriteParameters(stream); - //} - std::ifstream stream(file_name, std::ios::binary); - const bool result = NNUE::ReadParameters(stream); + const std::string dir_name = Options["EvalDir"]; + const std::string file_name = Path::Combine(dir_name, NNUE::kFileName); + + std::ifstream stream(file_name, std::ios::binary); + const bool result = NNUE::ReadParameters(stream); + + if (!result) + // It's a problem if it doesn't finish when there is a read error. + std::cout << "Error! " << NNUE::kFileName << " not found or wrong format" << std::endl; -// ASSERT(result); - if (!result) - { - // It's a problem if it doesn't finish when there is a read error. - std::cout << "Error! " << NNUE::kFileName << " not found or wrong format" << std::endl; - //my_exit(); - } - else - std::cout << "info string NNUE " << NNUE::kFileName << " found & loaded" << std::endl; - } else - std::cout << "info string NNUE " << NNUE::kFileName << " not loaded" << std::endl; + std::cout << "info string NNUE " << NNUE::kFileName << " found & loaded" << std::endl; } // Initialization diff --git a/src/learn/learner.cpp b/src/learn/learner.cpp index 94991948..58719821 100644 --- a/src/learn/learner.cpp +++ b/src/learn/learner.cpp @@ -3092,7 +3092,7 @@ void learn(Position&, istringstream& is) //} if (use_convert_bin) { - is_ready(true); + init_nnue(true); cout << "convert_bin.." << endl; convert_bin(filenames,output_file_name, ply_minimum, ply_maximum, interpolate_eval); return; @@ -3100,7 +3100,7 @@ void learn(Position&, istringstream& is) } if (use_convert_bin_from_pgn_extract) { - is_ready(true); + init_nnue(true); cout << "convert_bin_from_pgn-extract.." << endl; convert_bin_from_pgn_extract(filenames, output_file_name, pgn_eval_side_to_move); return; @@ -3166,7 +3166,7 @@ void learn(Position&, istringstream& is) cout << "init.." << endl; // Read evaluation function parameters - is_ready(true); + init_nnue(true); #if !defined(EVAL_NNUE) cout << "init_grad.." << endl; diff --git a/src/learn/multi_think.cpp b/src/learn/multi_think.cpp index d511c277..ba2c47d4 100644 --- a/src/learn/multi_think.cpp +++ b/src/learn/multi_think.cpp @@ -20,7 +20,7 @@ void MultiThink::go_think() // Read evaluation function, etc. // In the case of the learn command, the value of the evaluation function may be corrected after reading the evaluation function, so // Skip memory corruption check. - is_ready(true); + init_nnue(true); // Call the derived class's init(). init(); diff --git a/src/uci.cpp b/src/uci.cpp index a95a629d..6d86ebca 100644 --- a/src/uci.cpp +++ b/src/uci.cpp @@ -73,7 +73,7 @@ namespace Learner void test_cmd(Position& pos, istringstream& is) { // Initialize as it may be searched. - is_ready(); + init_nnue(); std::string param; is >> param; @@ -209,7 +209,14 @@ namespace { } else if (token == "setoption") setoption(is); else if (token == "position") position(pos, is, states); - else if (token == "ucinewgame") { Search::clear(); elapsed = now(); } // Search::clear() may take some while + else if (token == "ucinewgame") + { +#if defined(EVAL_NNUE) + init_nnue(); +#endif + Search::clear(); + elapsed = now(); // Search::clear() may take some while + } } elapsed = now() - elapsed + 1; // Ensure positivity to avoid a 'divide by zero' @@ -250,7 +257,7 @@ namespace { // Make is_ready_cmd() callable from outside. (Because I want to call it from the bench command etc.) // Note that the phase is not initialized. -void is_ready(bool skipCorruptCheck) +void init_nnue(bool skipCorruptCheck) { #if defined(EVAL_NNUE) // After receiving "isready", modify so that a line feed is sent every 5 seconds until "readyok" is returned. (keep alive processing) @@ -260,59 +267,29 @@ void is_ready(bool skipCorruptCheck) // -Shogi GUI already does so, so MyShogi will follow along. //-Also, the engine side of Yaneura King modifies it so that after "isready" is received, a line feed is sent every 5 seconds until "readyok" is returned. - auto ended = false; - auto th = std::thread([&ended] { - int count = 0; - while (!ended) - { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - if (++count >= 50 /* 5 seconds */) - { - count = 0; - sync_cout << sync_endl; // Send a line break. - } - } - }); - // Perform processing that may take time, such as reading the evaluation function, at this timing. // If you do a time-consuming process at startup, Shogi place will make a timeout judgment and retire the recognition as a thinking engine. if (!UCI::load_eval_finished) { - // Read evaluation function - Eval::load_eval(); + // Read evaluation function + Eval::load_eval(); - // Calculate and save checksum (to check for subsequent memory corruption) - eval_sum = Eval::calc_check_sum(); + // Calculate and save checksum (to check for subsequent memory corruption) + eval_sum = Eval::calc_check_sum(); - // display soft name - Eval::print_softname(eval_sum); - - UCI::load_eval_finished = true; + // display soft name + Eval::print_softname(eval_sum); + UCI::load_eval_finished = true; } else { - // Check the checksum every time to see if the memory has been corrupted. - // It seems that the time is a little wasteful, but it is good because it is about 0.1 seconds. - if (!skipCorruptCheck && eval_sum != Eval::calc_check_sum()) - sync_cout << "Error! : EVAL memory is corrupted" << sync_endl; + // Check the checksum every time to see if the memory has been corrupted. + // It seems that the time is a little wasteful, but it is good because it is about 0.1 seconds. + if (!skipCorruptCheck && eval_sum != Eval::calc_check_sum()) + sync_cout << "Error! : EVAL memory is corrupted" << sync_endl; } - - // For isready, it is promised that the next command will not come until it returns readyok. - // Initialize various variables at this timing. - - TT.resize(Options["Hash"]); - Search::clear(); - Time.availableNodes = 0; - - Threads.stop = false; - - // Terminate the thread created to send keep alive and wait. - ended = true; - th.join(); #endif // defined(EVAL_NNUE) - - sync_cout << "readyok" << sync_endl; } @@ -399,8 +376,14 @@ void UCI::loop(int argc, char* argv[]) { else if (token == "setoption") setoption(is); else if (token == "go") go(pos, is, states); else if (token == "position") position(pos, is, states); - else if (token == "ucinewgame") Search::clear(); - else if (token == "isready") is_ready(); + else if (token == "ucinewgame") + { +#if defined(EVAL_NNUE) + init_nnue(); +#endif + Search::clear(); + } + else if (token == "isready") sync_cout << "readyok" << sync_endl; // Additional custom non-UCI commands, mainly for debugging. // Do not use these commands during a search! diff --git a/src/uci.h b/src/uci.h index 5073262e..6529f90c 100644 --- a/src/uci.h +++ b/src/uci.h @@ -87,7 +87,7 @@ extern UCI::OptionsMap Options; // If skipCorruptCheck == true, skip memory corruption check by check sum when reading the evaluation function a second time. // * This function is inconvenient if it is not available in Stockfish, so add it. -void is_ready(bool skipCorruptCheck = false); +void init_nnue(bool skipCorruptCheck = false); extern const char* StartFEN; diff --git a/src/ucioption.cpp b/src/ucioption.cpp index d63caa9f..f067a875 100644 --- a/src/ucioption.cpp +++ b/src/ucioption.cpp @@ -42,7 +42,7 @@ void on_hash_size(const Option& o) { TT.resize(size_t(o)); } void on_logger(const Option& o) { start_logger(o); } void on_threads(const Option& o) { Threads.set(size_t(o)); } void on_tb_path(const Option& o) { Tablebases::init(o); } -void on_eval_dir(const Option& o) { load_eval_finished = false; } +void on_eval_dir(const Option& o) { load_eval_finished = false; init_nnue(); } /// Our case insensitive less() function as required by UCI protocol