diff --git a/src/learn/gensfen.cpp b/src/learn/gensfen.cpp index 1a9187ae..22fddafb 100644 --- a/src/learn/gensfen.cpp +++ b/src/learn/gensfen.cpp @@ -1000,7 +1000,7 @@ namespace Learner << " detect_draw_by_insufficient_mating_material = " << detect_draw_by_insufficient_mating_material << endl; // Show if the training data generator uses NNUE. - Eval::NNUE::verify(); + Eval::NNUE::verify_eval_file_loaded(); Threads.main()->ponder = false; diff --git a/src/learn/learn.cpp b/src/learn/learn.cpp index 0fce5d95..a0a8ec07 100644 --- a/src/learn/learn.cpp +++ b/src/learn/learn.cpp @@ -1486,6 +1486,27 @@ namespace Learner std::cout << "..shuffle_on_memory done." << std::endl; } + static void set_learning_search_limits() + { + // About Search::Limits + // Be careful because this member variable is global and affects other threads. + auto& limits = Search::Limits; + + limits.startTime = now(); + + // Make the search equivalent to the "go infinite" command. (Because it is troublesome if time management is done) + limits.infinite = true; + + // Since PV is an obstacle when displayed, erase it. + limits.silent = true; + + // If you use this, it will be compared with the accumulated nodes of each thread. Therefore, do not use it. + limits.nodes = 0; + + // depth is also processed by the one passed as an argument of Learner::search(). + limits.depth = 0; + } + // Learning from the generated game record void learn(Position&, istringstream& is) { @@ -1837,30 +1858,9 @@ namespace Learner cout << "init.." << endl; - // Read evaluation function parameters - Eval::NNUE::init(); - Threads.main()->ponder = false; - // About Search::Limits - // Be careful because this member variable is global and affects other threads. - { - auto& limits = Search::Limits; - - limits.startTime = now(); - - // Make the search equivalent to the "go infinite" command. (Because it is troublesome if time management is done) - limits.infinite = true; - - // Since PV is an obstacle when displayed, erase it. - limits.silent = true; - - // If you use this, it will be compared with the accumulated nodes of each thread. Therefore, do not use it. - limits.nodes = 0; - - // depth is also processed by the one passed as an argument of Learner::search(). - limits.depth = 0; - } + set_learning_search_limits(); cout << "init_training.." << endl; Eval::NNUE::InitializeTraining(seed); @@ -1907,6 +1907,8 @@ namespace Learner sr.read_validation_set(validation_set_file_name, eval_limit); } + Eval::NNUE::verify_any_net_loaded(); + // Calculate rmse once at this point (timing of 0 sfen) // sr.calc_rmse(); diff --git a/src/learn/multi_think.cpp b/src/learn/multi_think.cpp index d2ae65eb..bf1ab29b 100644 --- a/src/learn/multi_think.cpp +++ b/src/learn/multi_think.cpp @@ -11,11 +11,6 @@ void MultiThink::go_think() { - // Read evaluation function, etc. - // In the case of the learn command, the value of the evaluation function may be corrected after reading the evaluation function, so - // Skip memory corruption check. - Eval::NNUE::init(); - // Call the derived class's init(). init(); diff --git a/src/nnue/evaluate_nnue.cpp b/src/nnue/evaluate_nnue.cpp index f7f9adcc..e3a7be63 100644 --- a/src/nnue/evaluate_nnue.cpp +++ b/src/nnue/evaluate_nnue.cpp @@ -235,6 +235,7 @@ namespace Eval::NNUE { else { sync_cout << "info string ERROR: failed to load eval file " << directory + eval_file << sync_endl; + eval_file_loaded.clear(); } } @@ -243,7 +244,7 @@ namespace Eval::NNUE { } /// NNUE::verify() verifies that the last net used was loaded successfully - void verify() { + void verify_eval_file_loaded() { std::string eval_file = std::string(Options["EvalFile"]); @@ -273,4 +274,31 @@ namespace Eval::NNUE { sync_cout << "info string classical evaluation enabled" << sync_endl; } + /// In training we override eval file so this is useful. + void verify_any_net_loaded() { + + if (useNNUE != UseNNUEMode::False && eval_file_loaded.empty()) + { + UCI::OptionsMap defaults; + UCI::init(defaults); + + std::string msg1 = "If the UCI option \"Use NNUE\" is set to true, network evaluation parameters compatible with the engine must be available."; + std::string msg2 = "The option is set to true, but the network file was not loaded successfully."; + std::string msg3 = "The UCI option EvalFile might need to specify the full path, including the directory name, to the network file."; + std::string msg5 = "The engine will be terminated now."; + + sync_cout << "info string ERROR: " << msg1 << sync_endl; + sync_cout << "info string ERROR: " << msg2 << sync_endl; + sync_cout << "info string ERROR: " << msg3 << sync_endl; + sync_cout << "info string ERROR: " << msg5 << sync_endl; + + std::exit(EXIT_FAILURE); + } + + if (useNNUE != UseNNUEMode::False) + sync_cout << "info string NNUE evaluation using " << eval_file_loaded << " enabled" << sync_endl; + else + sync_cout << "info string classical evaluation enabled" << sync_endl; + } + } // namespace Eval::NNUE diff --git a/src/nnue/evaluate_nnue.h b/src/nnue/evaluate_nnue.h index dcfa071d..5335713b 100644 --- a/src/nnue/evaluate_nnue.h +++ b/src/nnue/evaluate_nnue.h @@ -96,7 +96,8 @@ namespace Eval::NNUE { Value evaluate(const Position& pos); bool load_eval(std::string name, std::istream& stream); void init(); - void verify(); + void verify_eval_file_loaded(); + void verify_any_net_loaded(); } // namespace Eval::NNUE diff --git a/src/search.cpp b/src/search.cpp index 79848812..436e11fd 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -219,7 +219,7 @@ void MainThread::search() { Time.init(Limits, us, rootPos.game_ply()); TT.new_search(); - Eval::NNUE::verify(); + Eval::NNUE::verify_eval_file_loaded(); if (rootMoves.empty()) { diff --git a/src/uci.cpp b/src/uci.cpp index ff735b2e..896f6db8 100644 --- a/src/uci.cpp +++ b/src/uci.cpp @@ -101,7 +101,7 @@ namespace { Position p; p.set(pos.fen(), Options["UCI_Chess960"], &states->back(), Threads.main()); - Eval::NNUE::verify(); + Eval::NNUE::verify_eval_file_loaded(); sync_cout << "\n" << Eval::trace(p) << sync_endl; }