diff --git a/src/evaluate.cpp b/src/evaluate.cpp index b3894fe8..0326a2f8 100644 --- a/src/evaluate.cpp +++ b/src/evaluate.cpp @@ -27,6 +27,8 @@ #include #include +#include "nnue/evaluate_nnue.h" + #include "bitboard.h" #include "evaluate.h" #include "material.h" @@ -37,88 +39,6 @@ #include "incbin/incbin.h" using namespace std; -using namespace Eval::NNUE; - -namespace Eval { - - UseNNUEMode useNNUE; - string eval_file_loaded = "None"; - - static UseNNUEMode nnue_mode_from_option(const UCI::Option& mode) - { - if (mode == "false") - return UseNNUEMode::False; - else if (mode == "true") - return UseNNUEMode::True; - else if (mode == "pure") - return UseNNUEMode::Pure; - - return UseNNUEMode::False; - } - - void NNUE::init() { - - useNNUE = nnue_mode_from_option(Options["Use NNUE"]); - if (useNNUE == UseNNUEMode::False) - return; - - string eval_file = string(Options["EvalFile"]); - - #if defined(DEFAULT_NNUE_DIRECTORY) - #define stringify2(x) #x - #define stringify(x) stringify2(x) - vector dirs = { "" , CommandLine::binaryDirectory , stringify(DEFAULT_NNUE_DIRECTORY) }; - #else - vector dirs = { "" , CommandLine::binaryDirectory }; - #endif - - for (string directory : dirs) - if (eval_file_loaded != eval_file) - { - ifstream stream(directory + eval_file, ios::binary); - if (load_eval(eval_file, stream)) - { - sync_cout << "info string Loaded eval file " << directory + eval_file << sync_endl; - eval_file_loaded = eval_file; - } - else - { - sync_cout << "info string ERROR: failed to load eval file " << directory + eval_file << sync_endl; - } - } - } - - /// NNUE::verify() verifies that the last net used was loaded successfully - void NNUE::verify() { - - string eval_file = string(Options["EvalFile"]); - - if (useNNUE != UseNNUEMode::False && eval_file_loaded != eval_file) - { - UCI::OptionsMap defaults; - UCI::init(defaults); - - string msg1 = "If the UCI option \"Use NNUE\" is set to true, network evaluation parameters compatible with the engine must be available."; - string msg2 = "The option is set to true, but the network file " + eval_file + " was not loaded successfully."; - string msg3 = "The UCI option EvalFile might need to specify the full path, including the directory name, to the network file."; - string msg4 = "The default net can be downloaded from: https://tests.stockfishchess.org/api/nn/" + string(defaults["EvalFile"]); - string msg5 = "The engine will be terminated now."; - - sync_cout << "info string ERROR: " << msg1 << sync_endl; - sync_cout << "info string ERROR: " << msg2 << sync_endl; - sync_cout << "info string ERROR: " << msg3 << sync_endl; - sync_cout << "info string ERROR: " << msg4 << sync_endl; - sync_cout << "info string ERROR: " << msg5 << sync_endl; - - exit(EXIT_FAILURE); - } - - if (useNNUE != UseNNUEMode::False) - sync_cout << "info string NNUE evaluation using " << eval_file << " enabled" << sync_endl; - else - sync_cout << "info string classical evaluation enabled" << sync_endl; - } -} namespace Trace { @@ -994,7 +914,7 @@ Value Eval::evaluate(const Position& pos) { Value v; - if (Eval::useNNUE == UseNNUEMode::Pure) { + if (NNUE::useNNUE == NNUE::UseNNUEMode::Pure) { v = NNUE::evaluate(pos); // Guarantee evaluation does not hit the tablebase range @@ -1002,7 +922,7 @@ Value Eval::evaluate(const Position& pos) { return v; } - else if (Eval::useNNUE == UseNNUEMode::False) + else if (NNUE::useNNUE == NNUE::UseNNUEMode::False) v = Evaluation(pos).value(); else { @@ -1085,7 +1005,7 @@ std::string Eval::trace(const Position& pos) { ss << "\nClassical evaluation: " << to_cp(v) << " (white side)\n"; - if (useNNUE != UseNNUEMode::False) + if (NNUE::useNNUE != NNUE::UseNNUEMode::False) { v = NNUE::evaluate(pos); v = pos.side_to_move() == WHITE ? v : -v; diff --git a/src/evaluate.h b/src/evaluate.h index bce5488d..fc626698 100644 --- a/src/evaluate.h +++ b/src/evaluate.h @@ -26,33 +26,14 @@ class Position; namespace Eval { - enum struct UseNNUEMode - { - False, - True, - Pure - }; - std::string trace(const Position& pos); Value evaluate(const Position& pos); - extern UseNNUEMode useNNUE; - extern std::string eval_file_loaded; - // The default net name MUST follow the format nn-[SHA256 first 12 digits].nnue // for the build process (profile-build and fishtest) to work. Do not change the // name of the macro, as it is used in the Makefile. #define EvalFileDefaultName "nn-98a7585c85e9.nnue" - namespace NNUE { - - Value evaluate(const Position& pos); - bool load_eval(std::string name, std::istream& stream); - void init(); - void verify(); - - } // namespace NNUE - } // namespace Eval #endif // #ifndef EVALUATE_H_INCLUDED diff --git a/src/learn/gensfen.cpp b/src/learn/gensfen.cpp index 5f7541f5..7c5b20be 100644 --- a/src/learn/gensfen.cpp +++ b/src/learn/gensfen.cpp @@ -12,6 +12,7 @@ #include "extra/nnue_data_binpack_format.h" +#include "nnue/evaluate_nnue.h" #include "nnue/evaluate_nnue_learner.h" #include "syzygy/tbprobe.h" diff --git a/src/learn/learn.cpp b/src/learn/learn.cpp index 3648a40f..b2ee5aa1 100644 --- a/src/learn/learn.cpp +++ b/src/learn/learn.cpp @@ -32,6 +32,7 @@ #include "extra/nnue_data_binpack_format.h" +#include "nnue/evaluate_nnue.h" #include "nnue/evaluate_nnue_learner.h" #include "syzygy/tbprobe.h" diff --git a/src/learn/multi_think.cpp b/src/learn/multi_think.cpp index 80bc72b5..daed3e96 100644 --- a/src/learn/multi_think.cpp +++ b/src/learn/multi_think.cpp @@ -1,5 +1,7 @@ #include "multi_think.h" +#include "nnue/evaluate_nnue.h" + #include "tt.h" #include "uci.h" #include "types.h" diff --git a/src/main.cpp b/src/main.cpp index e6dff918..1a13dc62 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -18,6 +18,8 @@ #include +#include "nnue/evaluate_nnue.h" + #include "bitboard.h" #include "endgame.h" #include "position.h" diff --git a/src/nnue/evaluate_nnue.cpp b/src/nnue/evaluate_nnue.cpp index 28c86feb..f7f9adcc 100644 --- a/src/nnue/evaluate_nnue.cpp +++ b/src/nnue/evaluate_nnue.cpp @@ -19,12 +19,14 @@ // Code for calculating NNUE evaluation function #include +#include +#include #include -#include "../evaluate.h" #include "../position.h" #include "../misc.h" #include "../uci.h" +#include "../types.h" #include "evaluate_nnue.h" @@ -69,6 +71,9 @@ namespace Eval::NNUE { ",Network=" + Network::GetStructureString(); } + UseNNUEMode useNNUE; + std::string eval_file_loaded = "None"; + namespace Detail { // Initialize the evaluation function parameters @@ -190,4 +195,82 @@ namespace Eval::NNUE { return ReadParameters(stream); } + static UseNNUEMode nnue_mode_from_option(const UCI::Option& mode) + { + if (mode == "false") + return UseNNUEMode::False; + else if (mode == "true") + return UseNNUEMode::True; + else if (mode == "pure") + return UseNNUEMode::Pure; + + return UseNNUEMode::False; + } + + void init() { + + useNNUE = nnue_mode_from_option(Options["Use NNUE"]); + if (useNNUE == UseNNUEMode::False) + return; + + std::string eval_file = std::string(Options["EvalFile"]); + + #if defined(DEFAULT_NNUE_DIRECTORY) + #define stringify2(x) #x + #define stringify(x) stringify2(x) + std::vector dirs = { "" , CommandLine::binaryDirectory , stringify(DEFAULT_NNUE_DIRECTORY) }; + #else + std::vector dirs = { "" , CommandLine::binaryDirectory }; + #endif + + for (std::string directory : dirs) + if (eval_file_loaded != eval_file) + { + std::ifstream stream(directory + eval_file, std::ios::binary); + if (load_eval(eval_file, stream)) + { + sync_cout << "info string Loaded eval file " << directory + eval_file << sync_endl; + eval_file_loaded = eval_file; + } + else + { + sync_cout << "info string ERROR: failed to load eval file " << directory + eval_file << sync_endl; + } + } + + #undef stringify2 + #undef stringify + } + + /// NNUE::verify() verifies that the last net used was loaded successfully + void verify() { + + std::string eval_file = std::string(Options["EvalFile"]); + + if (useNNUE != UseNNUEMode::False && eval_file_loaded != eval_file) + { + UCI::OptionsMap defaults; + UCI::init(defaults); + + std::string msg1 = "If the UCI option \"Use NNUE\" is set to true, network evaluation parameters compatible with the engine must be available."; + std::string msg2 = "The option is set to true, but the network file " + eval_file + " was not loaded successfully."; + std::string msg3 = "The UCI option EvalFile might need to specify the full path, including the directory name, to the network file."; + std::string msg4 = "The default net can be downloaded from: https://tests.stockfishchess.org/api/nn/" + std::string(defaults["EvalFile"]); + std::string msg5 = "The engine will be terminated now."; + + sync_cout << "info string ERROR: " << msg1 << sync_endl; + sync_cout << "info string ERROR: " << msg2 << sync_endl; + sync_cout << "info string ERROR: " << msg3 << sync_endl; + sync_cout << "info string ERROR: " << msg4 << sync_endl; + sync_cout << "info string ERROR: " << msg5 << sync_endl; + + std::exit(EXIT_FAILURE); + } + + if (useNNUE != UseNNUEMode::False) + sync_cout << "info string NNUE evaluation using " << eval_file << " enabled" << sync_endl; + else + sync_cout << "info string classical evaluation enabled" << sync_endl; + } + } // namespace Eval::NNUE diff --git a/src/nnue/evaluate_nnue.h b/src/nnue/evaluate_nnue.h index 68153cac..dcfa071d 100644 --- a/src/nnue/evaluate_nnue.h +++ b/src/nnue/evaluate_nnue.h @@ -27,6 +27,13 @@ namespace Eval::NNUE { + enum struct UseNNUEMode + { + False, + True, + Pure + }; + // Hash value of evaluation function structure constexpr std::uint32_t kHashValue = FeatureTransformer::GetHashValue() ^ Network::GetHashValue(); @@ -66,6 +73,9 @@ namespace Eval::NNUE { // Saved evaluation function file name extern std::string savedfileName; + extern UseNNUEMode useNNUE; + extern std::string eval_file_loaded; + // Get a string that represents the structure of the evaluation function std::string GetArchitectureString(); @@ -83,6 +93,11 @@ namespace Eval::NNUE { // write evaluation function parameters bool WriteParameters(std::ostream& stream); + Value evaluate(const Position& pos); + bool load_eval(std::string name, std::istream& stream); + void init(); + void verify(); + } // namespace Eval::NNUE #endif // #ifndef NNUE_EVALUATE_NNUE_H_INCLUDED diff --git a/src/nnue/nnue_common.h b/src/nnue/nnue_common.h index 319f005b..9975134c 100644 --- a/src/nnue/nnue_common.h +++ b/src/nnue/nnue_common.h @@ -24,6 +24,8 @@ #include #include +#include "../types.h" + #if defined(USE_AVX2) #include diff --git a/src/position.cpp b/src/position.cpp index 4e47f772..06a4e0b7 100644 --- a/src/position.cpp +++ b/src/position.cpp @@ -23,6 +23,8 @@ #include #include +#include "nnue/evaluate_nnue.h" + #include "bitboard.h" #include "misc.h" #include "movegen.h" @@ -757,7 +759,7 @@ void Position::do_move(Move m, StateInfo& newSt, bool givesCheck) { else st->nonPawnMaterial[them] -= PieceValue[MG][captured]; - if (Eval::useNNUE != Eval::UseNNUEMode::False) + if (Eval::NNUE::useNNUE != Eval::NNUE::UseNNUEMode::False) { dp.dirty_num = 2; // 1 piece moved, 1 piece captured dp.piece[1] = captured; @@ -801,7 +803,7 @@ void Position::do_move(Move m, StateInfo& newSt, bool givesCheck) { // Move the piece. The tricky Chess960 castling is handled earlier if (type_of(m) != CASTLING) { - if (Eval::useNNUE != Eval::UseNNUEMode::False) + if (Eval::NNUE::useNNUE != Eval::NNUE::UseNNUEMode::False) { dp.piece[0] = pc; dp.from[0] = from; @@ -832,7 +834,7 @@ void Position::do_move(Move m, StateInfo& newSt, bool givesCheck) { remove_piece(to); put_piece(promotion, to); - if (Eval::useNNUE != Eval::UseNNUEMode::False) + if (Eval::NNUE::useNNUE != Eval::NNUE::UseNNUEMode::False) { // Promoting pawn to SQ_NONE, promoted piece from SQ_NONE dp.to[0] = SQ_NONE; @@ -970,7 +972,7 @@ void Position::do_castling(Color us, Square from, Square& to, Square& rfrom, Squ rto = relative_square(us, kingSide ? SQ_F1 : SQ_D1); to = relative_square(us, kingSide ? SQ_G1 : SQ_C1); - if (Do && Eval::useNNUE != Eval::UseNNUEMode::False) + if (Do && Eval::NNUE::useNNUE != Eval::NNUE::UseNNUEMode::False) { auto& dp = st->dirtyPiece; dp.piece[0] = make_piece(us, KING); diff --git a/src/search.cpp b/src/search.cpp index 1623ff06..26a675d7 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -23,6 +23,8 @@ #include #include +#include "nnue/evaluate_nnue.h" + #include "evaluate.h" #include "misc.h" #include "movegen.h" diff --git a/src/uci.cpp b/src/uci.cpp index 166e437c..73ff0256 100644 --- a/src/uci.cpp +++ b/src/uci.cpp @@ -22,6 +22,7 @@ #include #include +#include "nnue/evaluate_nnue.h" #include "evaluate.h" #include "movegen.h" #include "nnue/nnue_test_command.h" diff --git a/src/ucioption.cpp b/src/ucioption.cpp index 099ca2ae..bdb1c6b1 100644 --- a/src/ucioption.cpp +++ b/src/ucioption.cpp @@ -21,6 +21,7 @@ #include #include +#include "nnue/evaluate_nnue.h" #include "evaluate.h" #include "misc.h" #include "search.h"