Move nnue evaluation stuff from evaluate.h to nnue/evaluate_nnue.h

This commit is contained in:
Tomasz Sobczyk
2020-10-14 19:06:47 +02:00
committed by nodchip
parent 4a340ad3b2
commit 0494adeb2c
13 changed files with 122 additions and 109 deletions

View File

@@ -27,6 +27,8 @@
#include <streambuf>
#include <vector>
#include "nnue/evaluate_nnue.h"
#include "bitboard.h"
#include "evaluate.h"
#include "material.h"
@@ -37,88 +39,6 @@
#include "incbin/incbin.h"
using namespace std;
using namespace Eval::NNUE;
namespace Eval {
UseNNUEMode useNNUE;
string eval_file_loaded = "None";
static UseNNUEMode nnue_mode_from_option(const UCI::Option& mode)
{
if (mode == "false")
return UseNNUEMode::False;
else if (mode == "true")
return UseNNUEMode::True;
else if (mode == "pure")
return UseNNUEMode::Pure;
return UseNNUEMode::False;
}
void NNUE::init() {
useNNUE = nnue_mode_from_option(Options["Use NNUE"]);
if (useNNUE == UseNNUEMode::False)
return;
string eval_file = string(Options["EvalFile"]);
#if defined(DEFAULT_NNUE_DIRECTORY)
#define stringify2(x) #x
#define stringify(x) stringify2(x)
vector<string> dirs = { "" , CommandLine::binaryDirectory , stringify(DEFAULT_NNUE_DIRECTORY) };
#else
vector<string> dirs = { "" , CommandLine::binaryDirectory };
#endif
for (string directory : dirs)
if (eval_file_loaded != eval_file)
{
ifstream stream(directory + eval_file, ios::binary);
if (load_eval(eval_file, stream))
{
sync_cout << "info string Loaded eval file " << directory + eval_file << sync_endl;
eval_file_loaded = eval_file;
}
else
{
sync_cout << "info string ERROR: failed to load eval file " << directory + eval_file << sync_endl;
}
}
}
/// NNUE::verify() verifies that the last net used was loaded successfully
void NNUE::verify() {
string eval_file = string(Options["EvalFile"]);
if (useNNUE != UseNNUEMode::False && eval_file_loaded != eval_file)
{
UCI::OptionsMap defaults;
UCI::init(defaults);
string msg1 = "If the UCI option \"Use NNUE\" is set to true, network evaluation parameters compatible with the engine must be available.";
string msg2 = "The option is set to true, but the network file " + eval_file + " was not loaded successfully.";
string msg3 = "The UCI option EvalFile might need to specify the full path, including the directory name, to the network file.";
string msg4 = "The default net can be downloaded from: https://tests.stockfishchess.org/api/nn/" + string(defaults["EvalFile"]);
string msg5 = "The engine will be terminated now.";
sync_cout << "info string ERROR: " << msg1 << sync_endl;
sync_cout << "info string ERROR: " << msg2 << sync_endl;
sync_cout << "info string ERROR: " << msg3 << sync_endl;
sync_cout << "info string ERROR: " << msg4 << sync_endl;
sync_cout << "info string ERROR: " << msg5 << sync_endl;
exit(EXIT_FAILURE);
}
if (useNNUE != UseNNUEMode::False)
sync_cout << "info string NNUE evaluation using " << eval_file << " enabled" << sync_endl;
else
sync_cout << "info string classical evaluation enabled" << sync_endl;
}
}
namespace Trace {
@@ -994,7 +914,7 @@ Value Eval::evaluate(const Position& pos) {
Value v;
if (Eval::useNNUE == UseNNUEMode::Pure) {
if (NNUE::useNNUE == NNUE::UseNNUEMode::Pure) {
v = NNUE::evaluate(pos);
// Guarantee evaluation does not hit the tablebase range
@@ -1002,7 +922,7 @@ Value Eval::evaluate(const Position& pos) {
return v;
}
else if (Eval::useNNUE == UseNNUEMode::False)
else if (NNUE::useNNUE == NNUE::UseNNUEMode::False)
v = Evaluation<NO_TRACE>(pos).value();
else
{
@@ -1085,7 +1005,7 @@ std::string Eval::trace(const Position& pos) {
ss << "\nClassical evaluation: " << to_cp(v) << " (white side)\n";
if (useNNUE != UseNNUEMode::False)
if (NNUE::useNNUE != NNUE::UseNNUEMode::False)
{
v = NNUE::evaluate(pos);
v = pos.side_to_move() == WHITE ? v : -v;