/* Stockfish, a UCI chess playing engine derived from Glaurung 2.1 Copyright (C) 2004-2020 The Stockfish developers (see AUTHORS file) Stockfish is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. Stockfish is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . */ // Code for calculating NNUE evaluation function #include #include #include #include #include "../position.h" #include "../misc.h" #include "../uci.h" #include "../types.h" #include "evaluate_nnue.h" namespace Eval::NNUE { const uint32_t kpp_board_index[PIECE_NB][COLOR_NB] = { // convention: W - us, B - them // viewed from other side, W and B are reversed { PS_NONE, PS_NONE }, { PS_W_PAWN, PS_B_PAWN }, { PS_W_KNIGHT, PS_B_KNIGHT }, { PS_W_BISHOP, PS_B_BISHOP }, { PS_W_ROOK, PS_B_ROOK }, { PS_W_QUEEN, PS_B_QUEEN }, { PS_W_KING, PS_B_KING }, { PS_NONE, PS_NONE }, { PS_NONE, PS_NONE }, { PS_B_PAWN, PS_W_PAWN }, { PS_B_KNIGHT, PS_W_KNIGHT }, { PS_B_BISHOP, PS_W_BISHOP }, { PS_B_ROOK, PS_W_ROOK }, { PS_B_QUEEN, PS_W_QUEEN }, { PS_B_KING, PS_W_KING }, { PS_NONE, PS_NONE } }; // Input feature converter LargePagePtr feature_transformer; // Evaluation function AlignedPtr network; // Evaluation function file name std::string fileName; // Saved evaluation function file name std::string savedfileName = "nn.bin"; // Get a string that represents the structure of the evaluation function std::string GetArchitectureString() { return "Features=" + FeatureTransformer::GetStructureString() + ",Network=" + Network::GetStructureString(); } UseNNUEMode useNNUE; std::string eval_file_loaded = "None"; namespace Detail { // Initialize the evaluation function parameters template void Initialize(AlignedPtr& pointer) { pointer.reset(reinterpret_cast(std_aligned_alloc(alignof(T), sizeof(T)))); std::memset(pointer.get(), 0, sizeof(T)); } template void Initialize(LargePagePtr& pointer) { static_assert(alignof(T) <= 4096, "aligned_large_pages_alloc() may fail for such a big alignment requirement of T"); pointer.reset(reinterpret_cast(aligned_large_pages_alloc(sizeof(T)))); std::memset(pointer.get(), 0, sizeof(T)); } // Read evaluation function parameters template bool ReadParameters(std::istream& stream, T& reference) { std::uint32_t header; header = read_little_endian(stream); if (!stream || header != T::GetHashValue()) return false; return reference.ReadParameters(stream); } // write evaluation function parameters template bool WriteParameters(std::ostream& stream, const AlignedPtr& pointer) { constexpr std::uint32_t header = T::GetHashValue(); stream.write(reinterpret_cast(&header), sizeof(header)); return pointer->WriteParameters(stream); } template bool WriteParameters(std::ostream& stream, const LargePagePtr& pointer) { constexpr std::uint32_t header = T::GetHashValue(); stream.write(reinterpret_cast(&header), sizeof(header)); return pointer->WriteParameters(stream); } } // namespace Detail // Initialize the evaluation function parameters void Initialize() { Detail::Initialize(feature_transformer); Detail::Initialize(network); } // Read network header bool ReadHeader(std::istream& stream, std::uint32_t* hash_value, std::string* architecture) { std::uint32_t version, size; version = read_little_endian(stream); *hash_value = read_little_endian(stream); size = read_little_endian(stream); if (!stream || version != kVersion) return false; architecture->resize(size); stream.read(&(*architecture)[0], size); return !stream.fail(); } // write the header bool WriteHeader(std::ostream& stream, std::uint32_t hash_value, const std::string& architecture) { stream.write(reinterpret_cast(&kVersion), sizeof(kVersion)); stream.write(reinterpret_cast(&hash_value), sizeof(hash_value)); const std::uint32_t size = static_cast(architecture.size()); stream.write(reinterpret_cast(&size), sizeof(size)); stream.write(architecture.data(), size); return !stream.fail(); } // Read network parameters bool ReadParameters(std::istream& stream) { std::uint32_t hash_value; std::string architecture; if (!ReadHeader(stream, &hash_value, &architecture)) return false; if (hash_value != kHashValue) return false; if (!Detail::ReadParameters(stream, *feature_transformer)) return false; if (!Detail::ReadParameters(stream, *network)) return false; return stream && stream.peek() == std::ios::traits_type::eof(); } // write evaluation function parameters bool WriteParameters(std::ostream& stream) { if (!WriteHeader(stream, kHashValue, GetArchitectureString())) return false; if (!Detail::WriteParameters(stream, feature_transformer)) return false; if (!Detail::WriteParameters(stream, network)) return false; return !stream.fail(); } // Evaluation function. Perform differential calculation. Value evaluate(const Position& pos) { alignas(kCacheLineSize) TransformedFeatureType transformed_features[FeatureTransformer::kBufferSize]; feature_transformer->Transform(pos, transformed_features); alignas(kCacheLineSize) char buffer[Network::kBufferSize]; const auto output = network->Propagate(transformed_features, buffer); return static_cast(output[0] / FV_SCALE); } // Load eval, from a file stream or a memory stream bool load_eval(std::string name, std::istream& stream) { Initialize(); if (Options["SkipLoadingEval"]) { std::cout << "info string SkipLoadingEval set to true, Net not loaded!" << std::endl; return true; } fileName = name; return ReadParameters(stream); } static UseNNUEMode nnue_mode_from_option(const UCI::Option& mode) { if (mode == "false") return UseNNUEMode::False; else if (mode == "true") return UseNNUEMode::True; else if (mode == "pure") return UseNNUEMode::Pure; return UseNNUEMode::False; } void init() { useNNUE = nnue_mode_from_option(Options["Use NNUE"]); if (useNNUE == UseNNUEMode::False) return; std::string eval_file = std::string(Options["EvalFile"]); #if defined(DEFAULT_NNUE_DIRECTORY) #define stringify2(x) #x #define stringify(x) stringify2(x) std::vector dirs = { "" , CommandLine::binaryDirectory , stringify(DEFAULT_NNUE_DIRECTORY) }; #else std::vector dirs = { "" , CommandLine::binaryDirectory }; #endif for (std::string directory : dirs) if (eval_file_loaded != eval_file) { std::ifstream stream(directory + eval_file, std::ios::binary); if (load_eval(eval_file, stream)) { sync_cout << "info string Loaded eval file " << directory + eval_file << sync_endl; eval_file_loaded = eval_file; } else { sync_cout << "info string ERROR: failed to load eval file " << directory + eval_file << sync_endl; } } #undef stringify2 #undef stringify } /// NNUE::verify() verifies that the last net used was loaded successfully void verify() { std::string eval_file = std::string(Options["EvalFile"]); if (useNNUE != UseNNUEMode::False && eval_file_loaded != eval_file) { UCI::OptionsMap defaults; UCI::init(defaults); std::string msg1 = "If the UCI option \"Use NNUE\" is set to true, network evaluation parameters compatible with the engine must be available."; std::string msg2 = "The option is set to true, but the network file " + eval_file + " was not loaded successfully."; std::string msg3 = "The UCI option EvalFile might need to specify the full path, including the directory name, to the network file."; std::string msg4 = "The default net can be downloaded from: https://tests.stockfishchess.org/api/nn/" + std::string(defaults["EvalFile"]); std::string msg5 = "The engine will be terminated now."; sync_cout << "info string ERROR: " << msg1 << sync_endl; sync_cout << "info string ERROR: " << msg2 << sync_endl; sync_cout << "info string ERROR: " << msg3 << sync_endl; sync_cout << "info string ERROR: " << msg4 << sync_endl; sync_cout << "info string ERROR: " << msg5 << sync_endl; std::exit(EXIT_FAILURE); } if (useNNUE != UseNNUEMode::False) sync_cout << "info string NNUE evaluation using " << eval_file << " enabled" << sync_endl; else sync_cout << "info string classical evaluation enabled" << sync_endl; } } // namespace Eval::NNUE