From eac1d430b40015734a7ad92cf8520af4b92db76b Mon Sep 17 00:00:00 2001 From: Tomasz Sobczyk Date: Mon, 24 May 2021 19:43:07 +0200 Subject: [PATCH] Add dedicated command for training data validation. --- docs/validate_training_data.md | 12 +++ src/Makefile | 1 + src/extra/nnue_data_binpack_format.h | 150 +++++++++++++++++++++++++++ src/tools/validate_training_data.cpp | 122 ++++++++++++++++++++++ src/tools/validate_training_data.h | 12 +++ src/uci.cpp | 2 + 6 files changed, 299 insertions(+) create mode 100644 docs/validate_training_data.md create mode 100644 src/tools/validate_training_data.cpp create mode 100644 src/tools/validate_training_data.h diff --git a/docs/validate_training_data.md b/docs/validate_training_data.md new file mode 100644 index 00000000..e2bfc30c --- /dev/null +++ b/docs/validate_training_data.md @@ -0,0 +1,12 @@ +# validate_training_data + +`validate_training_data` allows validation of training data of types `.plain`, `.bin`, and `.binpack`. + +As all commands in stockfish `validate_training_data` can be invoked either from command line (as `stockfish.exe validate_training_data ...`) or in the interactive prompt. + +The syntax of this command is as follows: +``` +validate_training_data in_path +``` + +`in_path` is the path to the file to validate. The type of the data is deduced based on its extension (one of `.plain`, `.bin`, `.binpack`). \ No newline at end of file diff --git a/src/Makefile b/src/Makefile index 4661e494..d3cea8de 100644 --- a/src/Makefile +++ b/src/Makefile @@ -51,6 +51,7 @@ SRCS = benchmark.cpp bitbase.cpp bitboard.cpp endgame.cpp evaluate.cpp main.cpp search.cpp thread.cpp timeman.cpp tt.cpp uci.cpp ucioption.cpp tune.cpp syzygy/tbprobe.cpp \ nnue/evaluate_nnue.cpp \ nnue/features/half_ka_v2.cpp \ + tools/validate_training_data.cpp \ tools/sfen_packer.cpp \ tools/training_data_generator.cpp \ tools/training_data_generator_nonpv.cpp \ diff --git a/src/extra/nnue_data_binpack_format.h b/src/extra/nnue_data_binpack_format.h index dce53b83..a6366d81 100644 --- a/src/extra/nnue_data_binpack_format.h +++ b/src/extra/nnue_data_binpack_format.h @@ -7831,4 +7831,154 @@ namespace binpack std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; } + + inline void validatePlain(std::string inputPath) + { + constexpr std::size_t reportSize = 1000000; + + std::cout << "Validating " << inputPath << '\n'; + + TrainingDataEntry e; + + std::string key; + std::string value; + std::string move; + + std::ifstream inputFile(inputPath); + const auto base = inputFile.tellg(); + std::size_t numProcessedPositions = 0; + std::size_t numProcessedPositionsBatch = 0; + + for(;;) + { + inputFile >> key; + if (!inputFile) + { + break; + } + + if (key == "e"sv) + { + e.move = chess::uci::uciToMove(e.pos, move); + if (!e.isValid()) + { + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; + return; + } + + ++numProcessedPositions; + ++numProcessedPositionsBatch; + + if (numProcessedPositionsBatch >= reportSize) + { + numProcessedPositionsBatch -= reportSize; + const auto cur = inputFile.tellg(); + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; + } + + continue; + } + + inputFile >> std::ws; + std::getline(inputFile, value, '\n'); + + if (key == "fen"sv) e.pos = chess::Position::fromFen(value.c_str()); + if (key == "move"sv) move = value; + if (key == "score"sv) e.score = std::stoi(value); + if (key == "ply"sv) e.ply = std::stoi(value); + if (key == "result"sv) e.result = std::stoi(value); + } + + if (numProcessedPositionsBatch) + { + const auto cur = inputFile.tellg(); + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; + } + + std::cout << "Finished. Validated " << numProcessedPositions << " positions.\n"; + } + + inline void validateBin(std::string inputPath) + { + constexpr std::size_t reportSize = 1000000; + + std::cout << "Validating " << inputPath << '\n'; + + std::ifstream inputFile(inputPath, std::ios_base::binary); + const auto base = inputFile.tellg(); + std::size_t numProcessedPositions = 0; + std::size_t numProcessedPositionsBatch = 0; + + nodchip::PackedSfenValue psv; + for(;;) + { + inputFile.read(reinterpret_cast(&psv), sizeof(psv)); + if (inputFile.gcount() != 40) + { + break; + } + + auto e = packedSfenValueToTrainingDataEntry(psv); + if (!e.isValid()) + { + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; + return; + } + + ++numProcessedPositions; + ++numProcessedPositionsBatch; + + if (numProcessedPositionsBatch >= reportSize) + { + numProcessedPositionsBatch -= reportSize; + const auto cur = inputFile.tellg(); + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; + } + } + + if (numProcessedPositionsBatch) + { + const auto cur = inputFile.tellg(); + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; + } + + std::cout << "Finished. Validated " << numProcessedPositions << " positions.\n"; + } + + inline void validateBinpack(std::string inputPath) + { + constexpr std::size_t reportSize = 1000000; + + std::cout << "Validating " << inputPath << '\n'; + + CompressedTrainingDataEntryReader reader(inputPath); + std::size_t numProcessedPositions = 0; + std::size_t numProcessedPositionsBatch = 0; + + while(reader.hasNext()) + { + auto e = reader.next(); + if (!e.isValid()) + { + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; + return; + } + + ++numProcessedPositions; + ++numProcessedPositionsBatch; + + if (numProcessedPositionsBatch >= reportSize) + { + numProcessedPositionsBatch -= reportSize; + std::cout << "Processed " << numProcessedPositions << " positions.\n"; + } + } + + if (numProcessedPositionsBatch) + { + std::cout << "Processed " << numProcessedPositions << " positions.\n"; + } + + std::cout << "Finished. Validated " << numProcessedPositions << " positions.\n"; + } } diff --git a/src/tools/validate_training_data.cpp b/src/tools/validate_training_data.cpp new file mode 100644 index 00000000..18cae456 --- /dev/null +++ b/src/tools/validate_training_data.cpp @@ -0,0 +1,122 @@ +#include "validate_training_data.h" + +#include "uci.h" +#include "misc.h" +#include "thread.h" +#include "position.h" +#include "tt.h" + +#include "extra/nnue_data_binpack_format.h" + +#include "nnue/evaluate_nnue.h" + +#include "syzygy/tbprobe.h" + +#include +#include +#include +#include +#include +#include // std::exp(),std::pow(),std::log() +#include // memcpy() +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +namespace sys = std::filesystem; + +namespace Stockfish::Tools +{ + static inline const std::string plain_extension = ".plain"; + static inline const std::string bin_extension = ".bin"; + static inline const std::string binpack_extension = ".binpack"; + + static bool file_exists(const std::string& name) + { + std::ifstream f(name); + return f.good(); + } + + static bool ends_with(const std::string& lhs, const std::string& end) + { + if (end.size() > lhs.size()) return false; + + return std::equal(end.rbegin(), end.rend(), lhs.rbegin()); + } + + static bool is_validation_of_type( + const std::string& input_path, + const std::string& expected_input_extension) + { + return ends_with(input_path, expected_input_extension); + } + + using ValidateFunctionType = void(std::string inputPath); + + static ValidateFunctionType* get_validate_function(const std::string& input_path) + { + if (is_validation_of_type(input_path, plain_extension)) + return binpack::validatePlain; + + if (is_validation_of_type(input_path, bin_extension)) + return binpack::validateBin; + + if (is_validation_of_type(input_path, binpack_extension)) + return binpack::validateBinpack; + + return nullptr; + } + + static void validate_training_data(const std::string& input_path) + { + if(!file_exists(input_path)) + { + std::cerr << "Input file does not exist.\n"; + return; + } + + auto func = get_validate_function(input_path); + if (func != nullptr) + { + func(input_path); + } + else + { + std::cerr << "Validation of files of this type is not supported.\n"; + } + } + + static void validate_training_data(const std::vector& args) + { + if (args.size() != 1) + { + std::cerr << "Invalid arguments.\n"; + std::cerr << "Usage: validate in_path\n"; + return; + } + + validate_training_data(args[0]); + } + + void validate_training_data(istringstream& is) + { + std::vector args; + + while (true) + { + std::string token = ""; + is >> token; + if (token == "") + break; + + args.push_back(token); + } + + validate_training_data(args); + } +} diff --git a/src/tools/validate_training_data.h b/src/tools/validate_training_data.h new file mode 100644 index 00000000..0c62ab50 --- /dev/null +++ b/src/tools/validate_training_data.h @@ -0,0 +1,12 @@ +#ifndef _VALIDATE_TRAINING_DATA_H_ +#define _VALIDATE_TRAINING_DATA_H_ + +#include +#include +#include + +namespace Stockfish::Tools { + void validate_training_data(std::istringstream& is); +} + +#endif diff --git a/src/uci.cpp b/src/uci.cpp index 2fa7a186..5e0bb11b 100644 --- a/src/uci.cpp +++ b/src/uci.cpp @@ -33,6 +33,7 @@ #include "tt.h" #include "uci.h" +#include "tools/validate_training_data.h" #include "tools/training_data_generator.h" #include "tools/training_data_generator_nonpv.h" #include "tools/convert.h" @@ -330,6 +331,7 @@ void UCI::loop(int argc, char* argv[]) { else if (token == "generate_training_data") Tools::generate_training_data(is); else if (token == "generate_training_data") Tools::generate_training_data_nonpv(is); else if (token == "convert") Tools::convert(is); + else if (token == "validate_training_data") Tools::validate_training_data(is); else if (token == "convert_bin") Tools::convert_bin(is); else if (token == "convert_plain") Tools::convert_plain(is); else if (token == "convert_bin_from_pgn_extract") Tools::convert_bin_from_pgn_extract(is);