Add transform minimize_binpack for minimizing existing .binpack datasets. (#4447)

Takes advantage of the sample skipping rules that are used during training (capture, check, or VALUE_NONE).
Adds positions to keep continuity, which improves compression.
This commit is contained in:
Tomasz Sobczyk
2023-04-25 19:21:29 +02:00
committed by GitHub
parent 8e16592430
commit 9a4c7cf4e3
2 changed files with 425 additions and 3 deletions

View File

@@ -5048,11 +5048,21 @@ namespace chess
// Generates all pseudo legal moves for the position.
// `pos` must be a legal chess position
[[nodiscard]] std::vector<Move> generatePseudoLegalMoves(const Position& pos);
[[nodiscard]] inline std::vector<Move> generatePseudoLegalMoves(const Position& pos)
{
std::vector<Move> moves;
forEachPseudoLegalMove(pos, [&moves](Move move) { moves.emplace_back(move); });
return moves;
}
// Generates all legal moves for the position.
// `pos` must be a legal chess position
[[nodiscard]] std::vector<Move> generateLegalMoves(const Position& pos);
[[nodiscard]] inline std::vector<Move> generateLegalMoves(const Position& pos)
{
std::vector<Move> moves;
forEachLegalMove(pos, [&moves](Move move) { moves.emplace_back(move); });
return moves;
}
}
[[nodiscard]] inline bool Position::isCheck() const
@@ -6835,6 +6845,17 @@ namespace binpack
{
return pos.isMoveLegal(move);
}
// True when `move` lands on an occupied square whose occupant has a different
// color than the piece being moved.
[[nodiscard]] bool isCapturingMove() const
{
    const auto occupant = pos.pieceAt(move.to);
    if (occupant == chess::Piece::none())
        return false;

    // A same-colored occupant is not a capture (this excludes castling).
    const auto mover = pos.pieceAt(move.from);
    return occupant.color() != mover.color();
}
// Forwards to Position::isCheck() for the stored position.
[[nodiscard]] bool isInCheck() const
{
    const bool in_check = pos.isCheck();
    return in_check;
}
};
[[nodiscard]] inline TrainingDataEntry packedSfenValueToTrainingDataEntry(const nodchip::PackedSfenValue& psv)
@@ -7110,6 +7131,11 @@ namespace binpack
std::uint16_t numPlies = 0;
std::vector<unsigned char> movetext;
// Size, in bytes, of the movetext buffer accumulated so far.
[[nodiscard]] std::size_t numBytes() const
{
    const std::size_t byte_count = movetext.size();
    return byte_count;
}
void clear(const TrainingDataEntry& e)
{
numPlies = 0;

View File

@@ -11,6 +11,8 @@
#include "nnue/evaluate_nnue.h"
#include "extra/nnue_data_binpack_format.h"
#include <string>
#include <map>
#include <iostream>
@@ -725,12 +727,406 @@ namespace Stockfish::Tools
do_filter_335a9b2d8a80(params);
}
// Options for the minimize_binpack transform.
struct MinimizeBinpackParams
{
    // Dataset to read; do_minimize_binpack rejects names not ending in ".binpack".
    std::string input_filename{"in.binpack"};

    // Destination file for the minimized dataset.
    std::string output_filename{"out.binpack"};

    // Parsed from the command line; not read by the code visible in this file section.
    bool debug_print{false};

    // Node budget handed to the move-chain search between two consecutive entries.
    std::uint64_t chain_search_nodes{64 * 1024};

    // No constraints to enforce at the moment.
    void enforce_constraints() {}
};
// Reinterprets a Tools-side PackedSfenValue as the binpack library's type by
// copying its bytes. The two types are required to have the same size.
[[nodiscard]] binpack::nodchip::PackedSfenValue packed_sfen_tools_to_lib(Stockfish::Tools::PackedSfenValue ps)
{
    using LibValue = binpack::nodchip::PackedSfenValue;
    static_assert(sizeof(LibValue) == sizeof(Stockfish::Tools::PackedSfenValue));

    LibValue converted;
    std::memcpy(&converted, &ps, sizeof(LibValue));
    return converted;
}
// Reinterprets a binpack-library PackedSfenValue as the Tools-side type by
// copying its bytes. The two types are required to have the same size.
[[nodiscard]] Stockfish::Tools::PackedSfenValue packed_sfen_lib_to_tools(binpack::nodchip::PackedSfenValue ps)
{
    using ToolsValue = Stockfish::Tools::PackedSfenValue;
    static_assert(sizeof(ToolsValue) == sizeof(binpack::nodchip::PackedSfenValue));

    ToolsValue converted;
    std::memcpy(&converted, &ps, sizeof(ToolsValue));
    return converted;
}
// Depth-first search for a sequence of legal moves that transforms
// `curr_entry` into `last_entry` (same ply, same result, same position).
//
// curr_entry          - entry to start from.
// last_entry          - entry that has to be reached.
// max_nodes           - budget of visited nodes shared by the whole search.
// curr_nodes          - running node counter, updated in place.
// reverse_chain_moves - on success the moves of the chain are appended here,
//                       last move first (each recursion level appends its move
//                       after the deeper levels did).
//
// Returns true if a chain was found. Returning false means either that no
// chain exists or that the node budget ran out.
[[nodiscard]] bool find_move_chain_between_positions_impl(
    const binpack::TrainingDataEntry& curr_entry,
    const binpack::TrainingDataEntry& last_entry,
    uint64_t max_nodes,
    uint64_t& curr_nodes,
    std::vector<chess::Move>& reverse_chain_moves
)
{
    // Number of pieces of each color that must disappear on the way from the
    // current position to the target. It is computed as current - target so
    // that captures yield positive numbers; the reverse order would make every
    // bound below impossible to trigger.
    const chess::EnumArray<chess::Color, int> captures_needed = {
        curr_entry.pos.piecesBB(chess::Color::White).count() - last_entry.pos.piecesBB(chess::Color::White).count(),
        curr_entry.pos.piecesBB(chess::Color::Black).count() - last_entry.pos.piecesBB(chess::Color::Black).count()
    };

    const int ply_diff = last_entry.ply - curr_entry.ply;

    // The target is not in the future of the current entry.
    if (ply_diff <= 0)
        return false;

    const int white_captured = captures_needed[chess::Color::White];
    const int black_captured = captures_needed[chess::Color::Black];
    const int total_captured = white_captured + black_captured;

    // Moves never add pieces to the board, so a target with more pieces
    // of either color cannot be reached.
    if (white_captured < 0 || black_captured < 0)
        return false;

    // Not enough plies for that many captures (at most one capture per ply).
    if (total_captured > ply_diff)
        return false;

    // Not enough plies for that many captures, for each side separately.
    // A side makes at most ceil(ply_diff / 2) of the remaining moves.
    const int max_moves_per_side = (ply_diff + 1) / 2;
    if (   white_captured > max_moves_per_side
        || black_captured > max_moves_per_side)
        return false;

    // The side to move is the one that removes the opponent's pieces, and it
    // makes exactly ceil(ply_diff / 2) of the remaining moves.
    const auto us = curr_entry.pos.sideToMove();
    const auto them = (us == chess::Color::White) ? chess::Color::Black : chess::Color::White;

    // True when the very next move has to be a capture to fulfill the piece
    // difference. Loop invariant, so it is computed once for all moves.
    const bool next_move_must_capture =
           total_captured == ply_diff
        || captures_needed[them] == max_moves_per_side;

    std::vector<std::pair<chess::Move, int>> legal_moves;
    chess::movegen::forEachLegalMove(curr_entry.pos, [&](const chess::Move move) {
        int score = 0;

        // Moving a piece that's already on a correct square.
        if (curr_entry.pos.pieceAt(move.from) == last_entry.pos.pieceAt(move.from))
            score -= 10'000;

        // Moving a piece to a correct square.
        if (curr_entry.pos.pieceAt(move.from) == last_entry.pos.pieceAt(move.to))
            score += 10'000;

        // Not a capture move but needs to be a capture to fulfill the piece difference.
        // This only lowers the ordering score, the move is still searched, so
        // captures that land on an empty square are merely tried late.
        if (   next_move_must_capture
            && curr_entry.pos.pieceAt(move.to) == chess::Piece::none())
            score -= 10'000'000;

        legal_moves.emplace_back(move, score);
    });

    // A heuristic for searching the legal moves such that we hope to find the solution earlier.
    std::sort(legal_moves.begin(), legal_moves.end(), [](const auto& lhs, const auto& rhs) {
        return lhs.second > rhs.second;
    });

    for (const auto& scored_move : legal_moves)
    {
        const chess::Move move = scored_move.first;

        auto next_entry = curr_entry;
        next_entry.result = -next_entry.result;
        next_entry.ply += 1;
        next_entry.pos.doMove(move);

        // We reached the destination position.
        if (   next_entry.ply == last_entry.ply
            && next_entry.result == last_entry.result
            && next_entry.pos == last_entry.pos)
        {
            reverse_chain_moves.emplace_back(move);
            return true;
        }

        // Reached the search limit, aborting.
        if (++curr_nodes > max_nodes)
            return false;

        // We reached the destination position somewhere later in the search.
        if (   next_entry.ply < last_entry.ply
            && find_move_chain_between_positions_impl(next_entry, last_entry, max_nodes, curr_nodes, reverse_chain_moves))
        {
            reverse_chain_moves.emplace_back(move);
            return true;
        }
    }

    return false;
}
// Looks for a sequence of legal moves that leads from `curr_entry` to
// `next_entry`, visiting at most `max_nodes` nodes.
//
// Returns the moves in playing order, or an empty vector when
//  - `next_entry` already is a continuation of `curr_entry`,
//  - `next_entry` is not later than `curr_entry`,
//  - the entries are more than MAX_PLY_DISTANCE plies apart,
//  - or the search did not find a chain.
[[nodiscard]] std::vector<chess::Move> find_move_chain_between_positions(
    const binpack::TrainingDataEntry& curr_entry,
    const binpack::TrainingDataEntry& next_entry,
    uint64_t max_nodes
)
{
    constexpr int MAX_PLY_DISTANCE = 6;

    if (binpack::isContinuation(curr_entry, next_entry))
        return {};

    if (curr_entry.ply >= next_entry.ply)
        return {};

    if (curr_entry.ply + MAX_PLY_DISTANCE < next_entry.ply)
        return {};

    // The search appends the moves last-to-first.
    std::vector<chess::Move> backwards;
    uint64_t visited_nodes = 0;
    const bool found = find_move_chain_between_positions_impl(
        curr_entry, next_entry, max_nodes, visited_nodes, backwards);

    if (!found)
        return {};

    return std::vector<chess::Move>(backwards.rbegin(), backwards.rend());
}
// True when the entry's move is a capture or the entry's position is in check.
[[nodiscard]] bool discarded_during_training_based_on_move(const binpack::TrainingDataEntry& e)
{
    if (e.isCapturingMove())
        return true;

    return e.isInCheck();
}
// Rewrites the .binpack dataset `params.input_filename` into
// `params.output_filename`:
//  1. Entries are read in batches of 10000. Between two consecutive entries
//     of a batch that are not already a continuation, a connecting chain of
//     moves is searched (bounded by `params.chain_search_nodes`). The
//     positions of a found chain are inserted with score VALUE_NONE.
//  2. Each batch is then filtered: runs of entries that satisfy `is_skipped`
//     are dropped when they lead, trail, or cost more to store than starting
//     a new chain.
// Progress and final counters are printed to std::cout.
void do_minimize_binpack(MinimizeBinpackParams& params)
{
    static constexpr int VALUE_NONE = 32002;

    // Size, in bytes, assumed for an entry that starts a new chain.
    static constexpr std::size_t NEW_CHAIN_ENTRY_BYTES = 32;

    // Runs shorter than this are assumed to always encode in fewer than
    // NEW_CHAIN_ENTRY_BYTES, so their cost is not computed.
    static constexpr std::size_t MIN_SKIP_RUN_WORTH_COSTING = 6;

    if (!ends_with(params.input_filename, ".binpack"))
    {
        std::cerr << "Invalid input file type. Must be .binpack.\n";
        return;
    }

    std::atomic<std::uint64_t> num_positions_read = 0;
    std::atomic<std::uint64_t> num_positions_intermediate = 0;
    std::atomic<std::uint64_t> num_positions_filtered = 0;

    auto in = Tools::open_sfen_input_file(params.input_filename);

    // Reads up to `n` entries from the shared input stream. The lambda-owned
    // mutex is held while reading, so each call gets consecutive entries.
    auto readsome = [&in, mutex = std::mutex{}](int n) mutable -> std::vector<binpack::TrainingDataEntry> {
        std::vector<binpack::TrainingDataEntry> psv;
        psv.reserve(n);

        std::unique_lock lock(mutex);

        for (int i = 0; i < n; ++i)
        {
            auto ps_opt = in->next();
            if (ps_opt.has_value())
            {
                psv.emplace_back(binpack::packedSfenValueToTrainingDataEntry(packed_sfen_tools_to_lib(*ps_opt)));
            }
            else
            {
                break;
            }
        }

        return psv;
    };

    auto out = SfenWriter(
        params.output_filename,
        Threads.size(),
        std::numeric_limits<std::uint64_t>::max(),
        SfenOutputType::Binpack);

    // NOTE(review): presumably invoked once per worker thread; each invocation
    // keeps its own intermediate buffer.
    Threads.execute_with_workers([&](auto& th){
        std::vector<binpack::TrainingDataEntry> intermediate_entries;

        auto write_one_intermediate = [&](const binpack::TrainingDataEntry& e) {
            intermediate_entries.emplace_back(e);

            // If a position already present in the data would have been discarded anyway
            // due to the move then we can set the score to something that takes less space.
            if (discarded_during_training_based_on_move(e))
                intermediate_entries.back().score = 0;

            const auto pi = num_positions_intermediate.fetch_add(1) + 1;
            if (pi % 10000 == 0)
            {
                const auto pr = num_positions_read.load();
                const auto pf = num_positions_filtered.load();
                std::cout << "Read: ~" << pr << ". Intermediate: " << pi << ". Write: ~" << pf << "\n";
            }
        };

        auto flush_intermediate_entries = [&]() {
            std::vector<binpack::TrainingDataEntry> filtered_entries;
            filtered_entries.reserve(intermediate_entries.size());

            // Remove positions that are always skipped and at the beginning of the chain.
            // Remove positions that are always skipped and at the end of the chain.
            // Remove chains of always skipped positions that are larger than starting a chain again.

            auto is_skipped = [&](const binpack::TrainingDataEntry& e) {
                return e.score == VALUE_NONE || discarded_during_training_based_on_move(e);
            };

            for (size_t i = 0; i < intermediate_entries.size();)
            {
                const auto& curr_entry = intermediate_entries[i];

                if (is_skipped(curr_entry))
                {
                    // A skipped entry that does not continue the last kept
                    // entry would start a chain on its own: drop it.
                    const bool is_continuation =
                           !filtered_entries.empty()
                        && binpack::isContinuation(filtered_entries.back(), curr_entry);

                    if (!is_continuation)
                    {
                        ++i;
                        continue;
                    }

                    bool is_tail = true;
                    size_t skip_run_end;

                    // Check if it's start of a tail.
                    for (skip_run_end = i + 1; skip_run_end < intermediate_entries.size(); ++skip_run_end)
                    {
                        // Go until we end the chain.
                        if (!binpack::isContinuation(intermediate_entries[skip_run_end - 1], intermediate_entries[skip_run_end]))
                            break;

                        // If we found a non-skippable position then this is not the tail and we cannot skip it entirely.
                        if (!is_skipped(intermediate_entries[skip_run_end]))
                        {
                            is_tail = false;
                            break;
                        }
                    }

                    // If tail then we don't save it at all.
                    if (is_tail)
                    {
                        // Remove (don't save) tail. Move to the next position to consider.
                        i = skip_run_end;
                        continue;
                    }

                    // Otherwise check if the skip run can be removed with a space saving.
                    // Only runs long enough to possibly reach the size of a new
                    // chain entry are costed; for shorter runs the check could
                    // never succeed, which is why the comparison is >= and not <.
                    if (skip_run_end - i >= MIN_SKIP_RUN_WORTH_COSTING)
                    {
                        binpack::PackedMoveScoreList encoding;
                        for (size_t j = i; j < skip_run_end; ++j)
                        {
                            const auto& e = intermediate_entries[j];
                            encoding.addMoveScore(e.pos, e.move, e.score);
                        }

                        // Keeping the run costs at least as much as starting a new chain.
                        if (encoding.numBytes() >= NEW_CHAIN_ENTRY_BYTES)
                        {
                            i = skip_run_end;
                            continue;
                        }
                    }
                }

                filtered_entries.emplace_back(curr_entry);
                ++i;
            }

            num_positions_filtered.fetch_add(filtered_entries.size());

            for (auto& e : filtered_entries)
            {
                const auto ps = packed_sfen_lib_to_tools(binpack::trainingDataEntryToPackedSfenValue(e));
                out.write(th.id(), ps);
            }

            intermediate_entries.clear();
        };

        for (;;)
        {
            auto psv = readsome(10000);
            if (psv.empty())
                break;

            for (size_t i = 0; i + 1 < psv.size(); ++i)
            {
                num_positions_read.fetch_add(1);

                const binpack::TrainingDataEntry& curr_entry = psv[i];
                const binpack::TrainingDataEntry& next_entry = psv[i + 1];

                const auto move_chain = find_move_chain_between_positions(curr_entry, next_entry, params.chain_search_nodes);
                if (move_chain.empty())
                {
                    write_one_intermediate(curr_entry);
                    continue;
                }

                binpack::TrainingDataEntry e = curr_entry;
                bool is_first_in_chain = true;
                for (const auto& move : move_chain)
                {
                    e.move = move;

                    // If the positions would have been discarded with the old move but we changed the move
                    // then we need to mark the positions with VALUE_NONE score so that it's skipped.
                    if (is_first_in_chain && discarded_during_training_based_on_move(curr_entry))
                        e.score = VALUE_NONE;

                    write_one_intermediate(e);

                    e.ply += 1;
                    e.result = -e.result;
                    e.score = VALUE_NONE; // change subsequent scores to VALUE_NONE so these new positions are ignored
                    e.pos.doMove(move);

                    is_first_in_chain = false;
                }
            }

            // The last entry of a batch has no successor inside the batch,
            // so it is written as it is.
            num_positions_read.fetch_add(1);
            write_one_intermediate(psv.back());

            flush_intermediate_entries();
        }
    });
    Threads.wait_for_workers_finished();

    const auto pi = num_positions_intermediate.load();
    const auto pr = num_positions_read.load();
    const auto pf = num_positions_filtered.load();
    std::cout << "Read: " << pr << ". Intermediate: " << pi << ". Write: " << pf << "\n";

    std::cout << "Finished.\n";
}
// Command entry point for "minimize_binpack": reads `name value` option pairs
// from `is`, prints the resulting parameters and runs do_minimize_binpack.
// An unknown option name prints an error and aborts without running anything.
void minimize_binpack(std::istringstream& is)
{
    MinimizeBinpackParams params{};

    for (;;)
    {
        // A fresh string every round: a failed extraction leaves it empty,
        // which is what ends the loop.
        std::string option;
        is >> option;

        if (option.empty())
            break;

        if (option == "input_file")
        {
            is >> params.input_filename;
            continue;
        }

        if (option == "output_file")
        {
            is >> params.output_filename;
            continue;
        }

        if (option == "debug_print")
        {
            is >> params.debug_print;
            continue;
        }

        if (option == "chain_search_nodes")
        {
            is >> params.chain_search_nodes;
            continue;
        }

        std::cout << "ERROR: Unknown option " << option << ". Exiting...\n";
        return;
    }

    std::cout << "Performing transform minimize_binpack with parameters:\n"
              << "input_file : " << params.input_filename << '\n'
              << "output_file : " << params.output_filename << '\n'
              << "debug_print : " << params.debug_print << '\n'
              << "chain_search_nodes : " << params.chain_search_nodes << '\n'
              << '\n';

    do_minimize_binpack(params);
}
void transform(std::istringstream& is)
{
const std::map<std::string, CommandFunc> subcommands = {
{ "nudged_static", &nudged_static },
{ "rescore", &rescore },
{ "filter_335a9b2d8a80", &filter_335a9b2d8a80 }
{ "filter_335a9b2d8a80", &filter_335a9b2d8a80 },
{ "minimize_binpack", &minimize_binpack }
};
Eval::NNUE::init();