diff --git a/src/evaluate.cpp b/src/evaluate.cpp index 8edc9bb8..94581998 100644 --- a/src/evaluate.cpp +++ b/src/evaluate.cpp @@ -32,23 +32,32 @@ #include "thread.h" #include "uci.h" -#ifdef EVAL_LEARN -namespace Learner -{ - extern bool use_raw_nnue_eval; -} -#endif - namespace Eval { - bool useNNUE; + UseNNUEMode useNNUE; std::string eval_file_loaded="None"; + static UseNNUEMode nnue_mode_from_option(const std::string& mode) + { + if (mode == "false") + return UseNNUEMode::False; + else if (mode == "true") + return UseNNUEMode::True; + +#ifdef EVAL_LEARN + else if (mode == "pure") + return UseNNUEMode::Pure; +#endif + + return UseNNUEMode::False; + } + void init_NNUE() { - useNNUE = Options["Use NNUE"]; + useNNUE = nnue_mode_from_option(Options["Use NNUE"]); + std::string eval_file = std::string(Options["EvalFile"]); - if (useNNUE && eval_file_loaded != eval_file) + if (useNNUE != UseNNUEMode::False && eval_file_loaded != eval_file) if (Eval::NNUE::load_eval_file(eval_file)) eval_file_loaded = eval_file; } @@ -56,8 +65,7 @@ namespace Eval { void verify_NNUE() { std::string eval_file = std::string(Options["EvalFile"]); - if (useNNUE && eval_file_loaded != eval_file) - { + if (useNNUE != UseNNUEMode::False && eval_file_loaded != eval_file) { UCI::OptionsMap defaults; UCI::init(defaults); @@ -69,7 +77,7 @@ namespace Eval { std::exit(EXIT_FAILURE); } - if (useNNUE) + if (useNNUE != UseNNUEMode::False) sync_cout << "info string NNUE evaluation using " << eval_file << " enabled." << sync_endl; else sync_cout << "info string classical evaluation enabled." << sync_endl; @@ -948,17 +956,17 @@ make_v: Value Eval::evaluate(const Position& pos) { #ifdef EVAL_LEARN - if (Learner::use_raw_nnue_eval) { + if (useNNUE == UseNNUEMode::Pure) { return NNUE::evaluate(pos); } #endif - bool classical = !Eval::useNNUE - || abs(eg_value(pos.psq_score())) * 16 > NNUEThreshold1 * (16 + pos.rule50_count()); + bool classical = useNNUE == UseNNUEMode::False + || abs(eg_value(pos.psq_score())) * 16 > NNUEThreshold1 * (16 + pos.rule50_count()); Value v = classical ? Evaluation(pos).value() : NNUE::evaluate(pos) * 5 / 4 + Tempo; - if (classical && Eval::useNNUE && abs(v) * 16 < NNUEThreshold2 * (16 + pos.rule50_count())) + if (classical && useNNUE != UseNNUEMode::False && abs(v) * 16 < NNUEThreshold2 * (16 + pos.rule50_count())) v = NNUE::evaluate(pos) * 5 / 4 + Tempo; // Damp down the evaluation linearly when shuffling @@ -1015,7 +1023,7 @@ std::string Eval::trace(const Position& pos) { ss << "\nClassical evaluation: " << to_cp(v) << " (white side)\n"; - if (Eval::useNNUE) + if (useNNUE != UseNNUEMode::False) { v = NNUE::evaluate(pos); v = pos.side_to_move() == WHITE ? v : -v; diff --git a/src/evaluate.h b/src/evaluate.h index e808068d..61052e90 100644 --- a/src/evaluate.h +++ b/src/evaluate.h @@ -26,11 +26,20 @@ class Position; namespace Eval { + enum struct UseNNUEMode + { + False, + True + +#ifdef EVAL_LEARN + ,Pure +#endif + }; std::string trace(const Position& pos); Value evaluate(const Position& pos); - extern bool useNNUE; + extern UseNNUEMode useNNUE; extern std::string eval_file_loaded; void init_NNUE(); void verify_NNUE(); diff --git a/src/learn/gensfen.cpp b/src/learn/gensfen.cpp index 99a783bb..9088fd81 100644 --- a/src/learn/gensfen.cpp +++ b/src/learn/gensfen.cpp @@ -44,12 +44,6 @@ namespace Learner static bool detect_draw_by_consecutive_low_score = false; static bool detect_draw_by_insufficient_mating_material = false; - // Use raw NNUE eval value in the Eval::evaluate(). - // If hybrid eval is enabled, training data - // generation and training don't work well. - // https://discordapp.com/channels/435943710472011776/733545871911813221/748524079761326192 - extern bool use_raw_nnue_eval; - static SfenOutputType sfen_output_type = SfenOutputType::Bin; static bool ends_with(const std::string& lhs, const std::string& end) @@ -1111,8 +1105,6 @@ namespace Learner is >> detect_draw_by_consecutive_low_score; else if (token == "detect_draw_by_insufficient_mating_material") is >> detect_draw_by_insufficient_mating_material; - else if (token == "use_raw_nnue_eval") - is >> use_raw_nnue_eval; else if (token == "sfen_format") is >> sfen_format; else diff --git a/src/learn/learner.cpp b/src/learn/learner.cpp index 7cc04406..da093192 100644 --- a/src/learn/learner.cpp +++ b/src/learn/learner.cpp @@ -93,12 +93,6 @@ namespace Learner // data directly. In those cases, we set false to this variable. static bool convert_teacher_signal_to_winning_probability = true; - // Use raw NNUE eval value in the Eval::evaluate(). If hybrid eval is enabled, training data - // generation and training don't work well. - // https://discordapp.com/channels/435943710472011776/733545871911813221/748524079761326192 - // This CANNOT be static since it's used elsewhere. - bool use_raw_nnue_eval = false; - // Using stockfish's WDL with win rate model instead of sigmoid static bool use_wdl = false; @@ -1811,7 +1805,6 @@ namespace Learner else if (option == "dest_score_min_value") is >> dest_score_min_value; else if (option == "dest_score_max_value") is >> dest_score_max_value; else if (option == "convert_teacher_signal_to_winning_probability") is >> convert_teacher_signal_to_winning_probability; - else if (option == "use_raw_nnue_eval") is >> use_raw_nnue_eval; // Otherwise, it's a filename. else diff --git a/src/position.cpp b/src/position.cpp index fe89b753..5ac461bc 100644 --- a/src/position.cpp +++ b/src/position.cpp @@ -755,7 +755,7 @@ void Position::do_move(Move m, StateInfo& newSt, bool givesCheck) { else st->nonPawnMaterial[them] -= PieceValue[MG][captured]; - if (Eval::useNNUE) + if (Eval::useNNUE != Eval::UseNNUEMode::False) { dp.dirty_num = 2; // 1 piece moved, 1 piece captured dp.piece[1] = captured; @@ -799,7 +799,7 @@ void Position::do_move(Move m, StateInfo& newSt, bool givesCheck) { // Move the piece. The tricky Chess960 castling is handled earlier if (type_of(m) != CASTLING) { - if (Eval::useNNUE) + if (Eval::useNNUE != Eval::UseNNUEMode::False) { dp.piece[0] = pc; dp.from[0] = from; @@ -830,7 +830,7 @@ void Position::do_move(Move m, StateInfo& newSt, bool givesCheck) { remove_piece(to); put_piece(promotion, to); - if (Eval::useNNUE) + if (Eval::useNNUE != Eval::UseNNUEMode::False) { // Promoting pawn to SQ_NONE, promoted piece from SQ_NONE dp.to[0] = SQ_NONE; @@ -968,7 +968,7 @@ void Position::do_castling(Color us, Square from, Square& to, Square& rfrom, Squ rto = relative_square(us, kingSide ? SQ_F1 : SQ_D1); to = relative_square(us, kingSide ? SQ_G1 : SQ_C1); - if (Do && Eval::useNNUE) + if (Do && Eval::useNNUE != Eval::UseNNUEMode::False) { auto& dp = st->dirtyPiece; dp.piece[0] = make_piece(us, KING); @@ -997,7 +997,7 @@ void Position::do_null_move(StateInfo& newSt) { assert(!checkers()); assert(&newSt != st); - if (Eval::useNNUE) + if (Eval::useNNUE != Eval::UseNNUEMode::False) { std::memcpy(&newSt, st, sizeof(StateInfo)); st->accumulator.computed_score = false; diff --git a/src/ucioption.cpp b/src/ucioption.cpp index b24d8d78..61e47539 100644 --- a/src/ucioption.cpp +++ b/src/ucioption.cpp @@ -86,7 +86,11 @@ void init(OptionsMap& o) { o["SyzygyProbeDepth"] << Option(1, 1, 100); o["Syzygy50MoveRule"] << Option(true); o["SyzygyProbeLimit"] << Option(7, 0, 7); - o["Use NNUE"] << Option(true, on_use_NNUE); +#ifdef EVAL_LEARN + o["Use NNUE"] << Option("true var true var false var pure", "true", on_use_NNUE); +#else + o["Use NNUE"] << Option("true var true var false", "true", on_use_NNUE); +#endif // The default must follow the format nn-[SHA256 first 12 digits].nnue // for the build process (profile-build and fishtest) to work. o["EvalFile"] << Option("nn-82215d0fd0df.nnue", on_eval_file);