From 9a4c7cf4e311f8d9526b79295b80c4d0464c07cf Mon Sep 17 00:00:00 2001 From: Tomasz Sobczyk Date: Tue, 25 Apr 2023 19:21:29 +0200 Subject: [PATCH] Add transform `minimize_binpack` for minimizing existing .binpack datasets. (#4447) Takes advantage of the sample skipping rules that are used during training (capture, check, or VALUE_NONE). Adds positions to keep continuity, which improves compression. --- src/extra/nnue_data_binpack_format.h | 30 +- src/tools/transform.cpp | 398 ++++++++++++++++++++++++++- 2 files changed, 425 insertions(+), 3 deletions(-) diff --git a/src/extra/nnue_data_binpack_format.h b/src/extra/nnue_data_binpack_format.h index 0b1ac0aa..18602a9b 100644 --- a/src/extra/nnue_data_binpack_format.h +++ b/src/extra/nnue_data_binpack_format.h @@ -5048,11 +5048,21 @@ namespace chess // Generates all pseudo legal moves for the position. // `pos` must be a legal chess position - [[nodiscard]] std::vector generatePseudoLegalMoves(const Position& pos); + [[nodiscard]] inline std::vector generatePseudoLegalMoves(const Position& pos) + { + std::vector moves; + forEachPseudoLegalMove(pos, [&moves](Move move) { moves.emplace_back(move); }); + return moves; + } // Generates all legal moves for the position. // `pos` must be a legal chess position - [[nodiscard]] std::vector generateLegalMoves(const Position& pos); + [[nodiscard]] inline std::vector generateLegalMoves(const Position& pos) + { + std::vector moves; + forEachLegalMove(pos, [&moves](Move move) { moves.emplace_back(move); }); + return moves; + } } [[nodiscard]] inline bool Position::isCheck() const @@ -6835,6 +6845,17 @@ namespace binpack { return pos.isMoveLegal(move); } + + [[nodiscard]] bool isCapturingMove() const + { + return pos.pieceAt(move.to) != chess::Piece::none() && + pos.pieceAt(move.to).color() != pos.pieceAt(move.from).color(); // Exclude castling + } + + [[nodiscard]] bool isInCheck() const + { + return pos.isCheck(); + } }; [[nodiscard]] inline TrainingDataEntry packedSfenValueToTrainingDataEntry(const nodchip::PackedSfenValue& psv) @@ -7110,6 +7131,11 @@ namespace binpack std::uint16_t numPlies = 0; std::vector movetext; + [[nodiscard]] std::size_t numBytes() const + { + return movetext.size(); + } + void clear(const TrainingDataEntry& e) { numPlies = 0; diff --git a/src/tools/transform.cpp b/src/tools/transform.cpp index 5abf24b8..c67fd701 100644 --- a/src/tools/transform.cpp +++ b/src/tools/transform.cpp @@ -11,6 +11,8 @@ #include "nnue/evaluate_nnue.h" +#include "extra/nnue_data_binpack_format.h" + #include #include #include @@ -725,12 +727,406 @@ namespace Stockfish::Tools do_filter_335a9b2d8a80(params); } + struct MinimizeBinpackParams + { + std::string input_filename = "in.binpack"; + std::string output_filename = "out.binpack"; + bool debug_print = false; + uint64_t chain_search_nodes = 1024 * 64; + + void enforce_constraints() + { + } + }; + + [[nodiscard]] binpack::nodchip::PackedSfenValue packed_sfen_tools_to_lib(Stockfish::Tools::PackedSfenValue ps) + { + binpack::nodchip::PackedSfenValue ret; + static_assert(sizeof(ret) == sizeof(ps)); + std::memcpy(&ret, &ps, sizeof(ret)); + return ret; + } + + [[nodiscard]] Stockfish::Tools::PackedSfenValue packed_sfen_lib_to_tools(binpack::nodchip::PackedSfenValue ps) + { + Stockfish::Tools::PackedSfenValue ret; + static_assert(sizeof(ret) == sizeof(ps)); + std::memcpy(&ret, &ps, sizeof(ret)); + return ret; + } + + [[nodiscard]] bool find_move_chain_between_positions_impl( + const binpack::TrainingDataEntry& curr_entry, + const binpack::TrainingDataEntry& last_entry, + uint64_t max_nodes, + uint64_t& curr_nodes, + std::vector& reverse_chain_moves + ) + { + const chess::EnumArray piece_count_diff = { + last_entry.pos.piecesBB(chess::Color::White).count() - curr_entry.pos.piecesBB(chess::Color::White).count(), + last_entry.pos.piecesBB(chess::Color::Black).count() - curr_entry.pos.piecesBB(chess::Color::Black).count() + }; + + const int ply_diff = last_entry.ply - curr_entry.ply; + + // Last position is older than current. + if (ply_diff <= 0) + return false; + + // Not enough plies for that many captures. + if (piece_count_diff[chess::Color::White] + piece_count_diff[chess::Color::Black] > ply_diff) + return false; + + // Not enough plies for that many captures. For each side separately. + if ( piece_count_diff[chess::Color::White] > (ply_diff + 1) / 2 + || piece_count_diff[chess::Color::Black] > (ply_diff + 1) / 2) + return false; + + std::vector> legal_moves; + chess::movegen::forEachLegalMove(curr_entry.pos, [&](const chess::Move move) { + int score = 0; + + // Moving a piece that's already on a correct square. + if (curr_entry.pos.pieceAt(move.from) == last_entry.pos.pieceAt(move.from)) + score -= 10'000; + + // Moving a piece to a correct square. + if (curr_entry.pos.pieceAt(move.from) == last_entry.pos.pieceAt(move.to)) + score += 10'000; + + // Not a capture move but needs to be a capture to fullfill piece difference. + if ( ( piece_count_diff[chess::Color::White] + piece_count_diff[chess::Color::Black] == ply_diff + || (piece_count_diff[curr_entry.pos.sideToMove()] == (ply_diff + 1) / 2)) + && curr_entry.pos.pieceAt(move.to) == chess::Piece::none()) + score -= 10'000'000; + + legal_moves.emplace_back(move, score); + }); + + // A heuristic for searching the legal moves such that we hope to find the solution earlier. + std::sort(legal_moves.begin(), legal_moves.end(), [](const auto& lhs, const auto& rhs) { + return lhs.second > rhs.second; + }); + + for (const auto [move, score] : legal_moves) + { + auto next_entry = curr_entry; + next_entry.result = -next_entry.result; + next_entry.ply += 1; + next_entry.pos.doMove(move); + + // We reached the destination position. + if ( next_entry.ply == last_entry.ply + && next_entry.result == last_entry.result + && next_entry.pos == last_entry.pos) + { + reverse_chain_moves.emplace_back(move); + return true; + } + + // Reached the search limit, aborting. + if (++curr_nodes > max_nodes) + return false; + + // We reached the destination position somewhere later in the search. + if ( next_entry.ply < last_entry.ply + && find_move_chain_between_positions_impl(next_entry, last_entry, max_nodes, curr_nodes, reverse_chain_moves)) + { + reverse_chain_moves.emplace_back(move); + return true; + } + } + + return false; + } + + [[nodiscard]] std::vector find_move_chain_between_positions( + const binpack::TrainingDataEntry& curr_entry, + const binpack::TrainingDataEntry& next_entry, + uint64_t max_nodes + ) + { + constexpr int MAX_PLY_DISTANCE = 6; + if ( binpack::isContinuation(curr_entry, next_entry) + || curr_entry.ply >= next_entry.ply + || curr_entry.ply + MAX_PLY_DISTANCE < next_entry.ply) + return {}; + + std::vector reverse_chain_moves; + uint64_t curr_nodes = 0; + if (find_move_chain_between_positions_impl(curr_entry, next_entry, max_nodes, curr_nodes, reverse_chain_moves)) + { + std::reverse(reverse_chain_moves.begin(), reverse_chain_moves.end()); + return reverse_chain_moves; + } + else + { + return {}; + } + } + + [[nodiscard]] bool discarded_during_training_based_on_move(const binpack::TrainingDataEntry& e) + { + return e.isCapturingMove() || e.isInCheck(); + } + + void do_minimize_binpack(MinimizeBinpackParams& params) + { + static constexpr int VALUE_NONE = 32002; + + if (!ends_with(params.input_filename, ".binpack")) + { + std::cerr << "Invalid input file type. Must be .binpack.\n"; + return; + } + + std::atomic num_positions_read = 0; + std::atomic num_positions_intermediate = 0; + std::atomic num_positions_filtered = 0; + + auto in = Tools::open_sfen_input_file(params.input_filename); + auto readsome = [&in, mutex = std::mutex{}](int n) mutable -> std::vector { + + std::vector psv; + psv.reserve(n); + + std::unique_lock lock(mutex); + + for (int i = 0; i < n; ++i) + { + auto ps_opt = in->next(); + if (ps_opt.has_value()) + { + psv.emplace_back(binpack::packedSfenValueToTrainingDataEntry(packed_sfen_tools_to_lib(*ps_opt))); + } + else + { + break; + } + } + + return psv; + }; + + auto out = SfenWriter( + params.output_filename, + Threads.size(), + std::numeric_limits::max(), + SfenOutputType::Binpack); + + Threads.execute_with_workers([&](auto& th){ + std::vector intermediate_entries; + + auto write_one_intermediate = [&](const binpack::TrainingDataEntry& e) { + intermediate_entries.emplace_back(e); + + // If a position already present in the data would have been discarded anyway + // due to the move then we can set the score to something that takes less space. + if (discarded_during_training_based_on_move(e)) + intermediate_entries.back().score = 0; + + const auto pi = num_positions_intermediate.fetch_add(1) + 1; + if (pi % 10000 == 0) + { + const auto pr = num_positions_read.load(); + const auto pf = num_positions_filtered.load(); + std::cout << "Read: ~" << pr << ". Intermediate: " << pi << ". Write: ~" << pf << "\n"; + } + }; + + auto flush_intermediate_entries = [&]() { + + std::vector filtered_entries; + filtered_entries.reserve(intermediate_entries.size()); + + // Remove positions that are always skipped and at the beginning of the chain. + // Remove positions that are always skipped and at the end of the chain. + // Remove chains of always skipped positions that are larger than starting a chain again. + + auto is_skipped = [&](const binpack::TrainingDataEntry& e) { + return e.score == VALUE_NONE || discarded_during_training_based_on_move(e); + }; + + for (size_t i = 0; i < intermediate_entries.size();) + { + const auto& curr_entry = intermediate_entries[i]; + + if (is_skipped(curr_entry)) + { + bool is_continuation = false; + if (!filtered_entries.empty()) + { + const auto& prev_entry = filtered_entries.back(); + is_continuation = binpack::isContinuation(prev_entry, curr_entry); + } + + if (!is_continuation) + { + ++i; + continue; + } + + bool is_tail = true; + size_t skip_run_end; + // Check if it's start of a tail. + for (skip_run_end = i + 1; skip_run_end < intermediate_entries.size(); ++skip_run_end) + { + // Go until we end the chain. + if (!binpack::isContinuation(intermediate_entries[skip_run_end - 1], intermediate_entries[skip_run_end])) + break; + + // If we found a non-skippable position then this is not the tail and we cannot skip it entirely. + if (!is_skipped(intermediate_entries[skip_run_end])) + { + is_tail = false; + break; + } + } + + // If tail then we don't save it at all. + if (is_tail) + { + // Remove (don't save) tail. Move to the next position to consider. + i = skip_run_end; + continue; + } + + // Otherwise check if the skip run can be removed with a space saving. + if (skip_run_end - i < 6) // don't unnecessarily check total cost if upper bound is ~below 32 bytes + { + binpack::PackedMoveScoreList encoding; + for (size_t j = i; j < skip_run_end; ++j) + { + const auto& e = intermediate_entries[j]; + encoding.addMoveScore(e.pos, e.move, e.score); + } + + // Full new entry would have 32 bytes + if (encoding.numBytes() >= 32) + { + i = skip_run_end; + continue; + } + } + } + + filtered_entries.emplace_back(curr_entry); + ++i; + } + + num_positions_filtered.fetch_add(filtered_entries.size()); + for (auto& e : filtered_entries) + { + const auto ps = packed_sfen_lib_to_tools(binpack::trainingDataEntryToPackedSfenValue(e)); + out.write(th.id(), ps); + } + + intermediate_entries.clear(); + }; + + for (;;) + { + auto psv = readsome(10000); + + if (psv.empty()) + break; + + for(size_t i = 0; i + 1 < psv.size(); ++i) + { + num_positions_read.fetch_add(1); + + binpack::TrainingDataEntry curr_entry = psv[i]; + const binpack::TrainingDataEntry& next_entry = psv[i + 1]; + + auto move_chain = find_move_chain_between_positions(curr_entry, next_entry, params.chain_search_nodes); + if (move_chain.empty()) + { + write_one_intermediate(curr_entry); + } + else + { + binpack::TrainingDataEntry e = curr_entry; + + int j = 0; + for (const auto& move : move_chain) + { + e.move = move; + // If the positions would have been discarded with the old move but we changed the move + // then we need to mark the positions with VALUE_NONE score so that it's skipped. + if (j == 0 && discarded_during_training_based_on_move(curr_entry)) + e.score = VALUE_NONE; + write_one_intermediate(e); + + e.ply += 1; + e.result = -e.result; + e.score = VALUE_NONE; // change subsequent scores to VALUE_NONE so these new positions are ignored + e.pos.doMove(move); + + j += 1; + } + } + } + + num_positions_read.fetch_add(1); + write_one_intermediate(psv.back()); + + flush_intermediate_entries(); + } + }); + Threads.wait_for_workers_finished(); + + const auto pi = num_positions_intermediate.load(); + const auto pr = num_positions_read.load(); + const auto pf = num_positions_filtered.load(); + std::cout << "Read: " << pr << ". Intermediate: " << pi << ". Write: " << pf << "\n"; + std::cout << "Finished.\n"; + } + + void minimize_binpack(std::istringstream& is) + { + MinimizeBinpackParams params{}; + + while(true) + { + std::string token; + is >> token; + + if (token == "") + break; + + else if (token == "input_file") + is >> params.input_filename; + else if (token == "output_file") + is >> params.output_filename; + else if (token == "debug_print") + is >> params.debug_print; + else if (token == "chain_search_nodes") + is >> params.chain_search_nodes; + else + { + std::cout << "ERROR: Unknown option " << token << ". Exiting...\n"; + return; + } + } + + std::cout << "Performing transform minimize_binpack with parameters:\n"; + std::cout << "input_file : " << params.input_filename << '\n'; + std::cout << "output_file : " << params.output_filename << '\n'; + std::cout << "debug_print : " << params.debug_print << '\n'; + std::cout << "chain_search_nodes : " << params.chain_search_nodes << '\n'; + std::cout << '\n'; + + do_minimize_binpack(params); + } + void transform(std::istringstream& is) { const std::map subcommands = { { "nudged_static", &nudged_static }, { "rescore", &rescore }, - { "filter_335a9b2d8a80", &filter_335a9b2d8a80 } + { "filter_335a9b2d8a80", &filter_335a9b2d8a80 }, + { "minimize_binpack", &minimize_binpack } }; Eval::NNUE::init();