Add transform minimize_binpack for minimizing existing .binpack datasets. (#4447)
Takes advantage of the sample-skipping rules used during training (captures, positions in check, or a score of VALUE_NONE). Adds intermediate positions where needed to keep chain continuity, which improves compression.
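
For reference, a minimal usage sketch, assuming the transform is dispatched through the tools build's transform command like the other subcommands in the diff below (the file names are placeholders; chain_search_nodes defaults to 1024 * 64):

    transform minimize_binpack input_file in.binpack output_file out.binpack chain_search_nodes 65536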
@@ -5048,11 +5048,21 @@ namespace chess
        // Generates all pseudo legal moves for the position.
        // `pos` must be a legal chess position
        [[nodiscard]] inline std::vector<Move> generatePseudoLegalMoves(const Position& pos)
        {
            std::vector<Move> moves;
            forEachPseudoLegalMove(pos, [&moves](Move move) { moves.emplace_back(move); });
            return moves;
        }

        // Generates all legal moves for the position.
        // `pos` must be a legal chess position
        [[nodiscard]] inline std::vector<Move> generateLegalMoves(const Position& pos)
        {
            std::vector<Move> moves;
            forEachLegalMove(pos, [&moves](Move move) { moves.emplace_back(move); });
            return moves;
        }
    }

    [[nodiscard]] inline bool Position::isCheck() const

@@ -6835,6 +6845,17 @@ namespace binpack
        {
            return pos.isMoveLegal(move);
        }

        [[nodiscard]] bool isCapturingMove() const
        {
            return pos.pieceAt(move.to) != chess::Piece::none() &&
                   pos.pieceAt(move.to).color() != pos.pieceAt(move.from).color(); // Exclude castling (encoded as the king capturing its own rook).
        }

        [[nodiscard]] bool isInCheck() const
        {
            return pos.isCheck();
        }
    };

    [[nodiscard]] inline TrainingDataEntry packedSfenValueToTrainingDataEntry(const nodchip::PackedSfenValue& psv)

@@ -7110,6 +7131,11 @@ namespace binpack
        std::uint16_t numPlies = 0;
        std::vector<unsigned char> movetext;

        // Number of bytes taken by the encoded movetext.
        [[nodiscard]] std::size_t numBytes() const
        {
            return movetext.size();
        }

        void clear(const TrainingDataEntry& e)
        {
            numPlies = 0;

@@ -11,6 +11,8 @@

#include "nnue/evaluate_nnue.h"

#include "extra/nnue_data_binpack_format.h"

#include <string>
#include <map>
#include <iostream>

@@ -725,12 +727,406 @@ namespace Stockfish::Tools
        do_filter_335a9b2d8a80(params);
    }

    struct MinimizeBinpackParams
    {
        std::string input_filename = "in.binpack";
        std::string output_filename = "out.binpack";
        bool debug_print = false;
        uint64_t chain_search_nodes = 1024 * 64;

        void enforce_constraints()
        {
        }
    };

    [[nodiscard]] binpack::nodchip::PackedSfenValue packed_sfen_tools_to_lib(Stockfish::Tools::PackedSfenValue ps)
    {
        binpack::nodchip::PackedSfenValue ret;
        static_assert(sizeof(ret) == sizeof(ps));
        std::memcpy(&ret, &ps, sizeof(ret));
        return ret;
    }

    [[nodiscard]] Stockfish::Tools::PackedSfenValue packed_sfen_lib_to_tools(binpack::nodchip::PackedSfenValue ps)
    {
        Stockfish::Tools::PackedSfenValue ret;
        static_assert(sizeof(ret) == sizeof(ps));
        std::memcpy(&ret, &ps, sizeof(ret));
        return ret;
    }

    [[nodiscard]] bool find_move_chain_between_positions_impl(
        const binpack::TrainingDataEntry& curr_entry,
        const binpack::TrainingDataEntry& last_entry,
        uint64_t max_nodes,
        uint64_t& curr_nodes,
        std::vector<chess::Move>& reverse_chain_moves
    )
    {
        const chess::EnumArray<chess::Color, int> piece_count_diff = {
            last_entry.pos.piecesBB(chess::Color::White).count() - curr_entry.pos.piecesBB(chess::Color::White).count(),
            last_entry.pos.piecesBB(chess::Color::Black).count() - curr_entry.pos.piecesBB(chess::Color::Black).count()
        };

        const int ply_diff = last_entry.ply - curr_entry.ply;

        // Last position is older than current.
        if (ply_diff <= 0)
            return false;

        // Not enough plies for that many captures.
        if (piece_count_diff[chess::Color::White] + piece_count_diff[chess::Color::Black] > ply_diff)
            return false;

        // Not enough plies for that many captures, for each side separately.
        if (   piece_count_diff[chess::Color::White] > (ply_diff + 1) / 2
            || piece_count_diff[chess::Color::Black] > (ply_diff + 1) / 2)
            return false;

        std::vector<std::pair<chess::Move, int>> legal_moves;
        chess::movegen::forEachLegalMove(curr_entry.pos, [&](const chess::Move move) {
            int score = 0;

            // Moving a piece that's already on a correct square.
            if (curr_entry.pos.pieceAt(move.from) == last_entry.pos.pieceAt(move.from))
                score -= 10'000;

            // Moving a piece to a correct square.
            if (curr_entry.pos.pieceAt(move.from) == last_entry.pos.pieceAt(move.to))
                score += 10'000;

            // Not a capture move, but needs to be a capture to fulfill the piece difference.
            if (   (   piece_count_diff[chess::Color::White] + piece_count_diff[chess::Color::Black] == ply_diff
                    || piece_count_diff[curr_entry.pos.sideToMove()] == (ply_diff + 1) / 2)
                && curr_entry.pos.pieceAt(move.to) == chess::Piece::none())
                score -= 10'000'000;

            legal_moves.emplace_back(move, score);
        });

        // A heuristic for ordering the legal moves so that the solution is hopefully found earlier.
        std::sort(legal_moves.begin(), legal_moves.end(), [](const auto& lhs, const auto& rhs) {
            return lhs.second > rhs.second;
        });

        for (const auto [move, score] : legal_moves)
        {
            auto next_entry = curr_entry;
            next_entry.result = -next_entry.result;
            next_entry.ply += 1;
            next_entry.pos.doMove(move);

            // We reached the destination position.
            if (   next_entry.ply == last_entry.ply
                && next_entry.result == last_entry.result
                && next_entry.pos == last_entry.pos)
            {
                reverse_chain_moves.emplace_back(move);
                return true;
            }

            // Reached the search limit, aborting.
            if (++curr_nodes > max_nodes)
                return false;

            // We reached the destination position somewhere later in the search.
            if (   next_entry.ply < last_entry.ply
                && find_move_chain_between_positions_impl(next_entry, last_entry, max_nodes, curr_nodes, reverse_chain_moves))
            {
                reverse_chain_moves.emplace_back(move);
                return true;
            }
        }

        return false;
    }

    [[nodiscard]] std::vector<chess::Move> find_move_chain_between_positions(
        const binpack::TrainingDataEntry& curr_entry,
        const binpack::TrainingDataEntry& next_entry,
        uint64_t max_nodes
    )
    {
        constexpr int MAX_PLY_DISTANCE = 6;
        if (   binpack::isContinuation(curr_entry, next_entry)
            || curr_entry.ply >= next_entry.ply
            || curr_entry.ply + MAX_PLY_DISTANCE < next_entry.ply)
            return {};

        std::vector<chess::Move> reverse_chain_moves;
        uint64_t curr_nodes = 0;
        if (find_move_chain_between_positions_impl(curr_entry, next_entry, max_nodes, curr_nodes, reverse_chain_moves))
        {
            // The implementation collects moves while unwinding, so they arrive in reverse order.
            std::reverse(reverse_chain_moves.begin(), reverse_chain_moves.end());
            return reverse_chain_moves;
        }
        else
        {
            return {};
        }
    }

    // Mirrors the skip rules used during training: positions where the played
    // move is a capture, or where the side to move is in check, are discarded.
    [[nodiscard]] bool discarded_during_training_based_on_move(const binpack::TrainingDataEntry& e)
    {
        return e.isCapturingMove() || e.isInCheck();
    }

    void do_minimize_binpack(MinimizeBinpackParams& params)
    {
        static constexpr int VALUE_NONE = 32002;

        if (!ends_with(params.input_filename, ".binpack"))
        {
            std::cerr << "Invalid input file type. Must be .binpack.\n";
            return;
        }

        std::atomic<std::uint64_t> num_positions_read = 0;
        std::atomic<std::uint64_t> num_positions_intermediate = 0;
        std::atomic<std::uint64_t> num_positions_filtered = 0;

        auto in = Tools::open_sfen_input_file(params.input_filename);
        auto readsome = [&in, mutex = std::mutex{}](int n) mutable -> std::vector<binpack::TrainingDataEntry> {

            std::vector<binpack::TrainingDataEntry> psv;
            psv.reserve(n);

            std::unique_lock lock(mutex);

            for (int i = 0; i < n; ++i)
            {
                auto ps_opt = in->next();
                if (ps_opt.has_value())
                {
                    psv.emplace_back(binpack::packedSfenValueToTrainingDataEntry(packed_sfen_tools_to_lib(*ps_opt)));
                }
                else
                {
                    break;
                }
            }

            return psv;
        };

        auto out = SfenWriter(
            params.output_filename,
            Threads.size(),
            std::numeric_limits<std::uint64_t>::max(),
            SfenOutputType::Binpack);

        Threads.execute_with_workers([&](auto& th){
            std::vector<binpack::TrainingDataEntry> intermediate_entries;

            auto write_one_intermediate = [&](const binpack::TrainingDataEntry& e) {
                intermediate_entries.emplace_back(e);

                // If a position already present in the data would have been discarded anyway
                // due to the move, then we can set the score to something that takes less space.
                if (discarded_during_training_based_on_move(e))
                    intermediate_entries.back().score = 0;

                const auto pi = num_positions_intermediate.fetch_add(1) + 1;
                if (pi % 10000 == 0)
                {
                    const auto pr = num_positions_read.load();
                    const auto pf = num_positions_filtered.load();
                    std::cout << "Read: ~" << pr << ". Intermediate: " << pi << ". Write: ~" << pf << "\n";
                }
            };

            auto flush_intermediate_entries = [&]() {

                std::vector<binpack::TrainingDataEntry> filtered_entries;
                filtered_entries.reserve(intermediate_entries.size());

                // Remove positions that are always skipped and at the beginning of the chain.
                // Remove positions that are always skipped and at the end of the chain.
                // Remove runs of always-skipped positions that take more space than starting a new chain would.

                auto is_skipped = [&](const binpack::TrainingDataEntry& e) {
                    return e.score == VALUE_NONE || discarded_during_training_based_on_move(e);
                };

                for (size_t i = 0; i < intermediate_entries.size();)
                {
                    const auto& curr_entry = intermediate_entries[i];

                    if (is_skipped(curr_entry))
                    {
                        bool is_continuation = false;
                        if (!filtered_entries.empty())
                        {
                            const auto& prev_entry = filtered_entries.back();
                            is_continuation = binpack::isContinuation(prev_entry, curr_entry);
                        }

                        if (!is_continuation)
                        {
                            ++i;
                            continue;
                        }

                        bool is_tail = true;
                        size_t skip_run_end;
                        // Check whether this is the start of a tail.
                        for (skip_run_end = i + 1; skip_run_end < intermediate_entries.size(); ++skip_run_end)
                        {
                            // Go until we end the chain.
                            if (!binpack::isContinuation(intermediate_entries[skip_run_end - 1], intermediate_entries[skip_run_end]))
                                break;

                            // If we found a non-skippable position then this is not the tail and we cannot skip it entirely.
                            if (!is_skipped(intermediate_entries[skip_run_end]))
                            {
                                is_tail = false;
                                break;
                            }
                        }

                        // If it is a tail then we don't save it at all.
                        if (is_tail)
                        {
                            // Remove (don't save) the tail and move on to the next position to consider.
                            i = skip_run_end;
                            continue;
                        }

                        // Otherwise check if the skip run can be removed with a space saving.
                        if (skip_run_end - i < 6) // don't unnecessarily check the total cost when the upper bound is ~below 32 bytes
                        {
                            binpack::PackedMoveScoreList encoding;
                            for (size_t j = i; j < skip_run_end; ++j)
                            {
                                const auto& e = intermediate_entries[j];
                                encoding.addMoveScore(e.pos, e.move, e.score);
                            }

                            // A full new entry would take 32 bytes.
                            if (encoding.numBytes() >= 32)
                            {
                                i = skip_run_end;
                                continue;
                            }
                        }
                    }

                    filtered_entries.emplace_back(curr_entry);
                    ++i;
                }

                num_positions_filtered.fetch_add(filtered_entries.size());
                for (auto& e : filtered_entries)
                {
                    const auto ps = packed_sfen_lib_to_tools(binpack::trainingDataEntryToPackedSfenValue(e));
                    out.write(th.id(), ps);
                }

                intermediate_entries.clear();
            };

            for (;;)
            {
                auto psv = readsome(10000);

                if (psv.empty())
                    break;

                for (size_t i = 0; i + 1 < psv.size(); ++i)
                {
                    num_positions_read.fetch_add(1);

                    binpack::TrainingDataEntry curr_entry = psv[i];
                    const binpack::TrainingDataEntry& next_entry = psv[i + 1];

                    auto move_chain = find_move_chain_between_positions(curr_entry, next_entry, params.chain_search_nodes);
                    if (move_chain.empty())
                    {
                        write_one_intermediate(curr_entry);
                    }
                    else
                    {
                        binpack::TrainingDataEntry e = curr_entry;

                        int j = 0;
                        for (const auto& move : move_chain)
                        {
                            e.move = move;
                            // If the position would have been discarded with the old move, but we changed
                            // the move, then we need to mark the position with a VALUE_NONE score so that
                            // it is still skipped.
                            if (j == 0 && discarded_during_training_based_on_move(curr_entry))
                                e.score = VALUE_NONE;
                            write_one_intermediate(e);

                            e.ply += 1;
                            e.result = -e.result;
                            e.score = VALUE_NONE; // change subsequent scores to VALUE_NONE so these new positions are ignored
                            e.pos.doMove(move);

                            j += 1;
                        }
                    }
                }

                num_positions_read.fetch_add(1);
                write_one_intermediate(psv.back());

                flush_intermediate_entries();
            }
        });
        Threads.wait_for_workers_finished();

        const auto pi = num_positions_intermediate.load();
        const auto pr = num_positions_read.load();
        const auto pf = num_positions_filtered.load();
        std::cout << "Read: " << pr << ". Intermediate: " << pi << ". Write: " << pf << "\n";
        std::cout << "Finished.\n";
    }

    void minimize_binpack(std::istringstream& is)
    {
        MinimizeBinpackParams params{};

        while (true)
        {
            std::string token;
            is >> token;

            if (token == "")
                break;

            else if (token == "input_file")
                is >> params.input_filename;
            else if (token == "output_file")
                is >> params.output_filename;
            else if (token == "debug_print")
                is >> params.debug_print;
            else if (token == "chain_search_nodes")
                is >> params.chain_search_nodes;
            else
            {
                std::cout << "ERROR: Unknown option " << token << ". Exiting...\n";
                return;
            }
        }

        std::cout << "Performing transform minimize_binpack with parameters:\n";
        std::cout << "input_file         : " << params.input_filename << '\n';
        std::cout << "output_file        : " << params.output_filename << '\n';
        std::cout << "debug_print        : " << params.debug_print << '\n';
        std::cout << "chain_search_nodes : " << params.chain_search_nodes << '\n';
        std::cout << '\n';

        do_minimize_binpack(params);
    }

    void transform(std::istringstream& is)
    {
        const std::map<std::string, CommandFunc> subcommands = {
            { "nudged_static", &nudged_static },
            { "rescore", &rescore },
            { "filter_335a9b2d8a80", &filter_335a9b2d8a80 },
            { "minimize_binpack", &minimize_binpack }
        };

        Eval::NNUE::init();
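
A note on the 32-byte threshold in flush_intermediate_entries: per the comment in the diff itself, a full new binpack entry takes 32 bytes, while keeping a run of skipped positions costs only the movetext bytes reported by the new PackedMoveScoreList::numBytes(). The run is therefore dropped once its encoding reaches 32 bytes. A minimal sketch of that comparison; keep_skip_run and STEM_BYTES are hypothetical names, not part of the commit:

    #include <cstddef>

    // A full new chain entry costs a fixed 32 bytes (per the comment in the diff).
    constexpr std::size_t STEM_BYTES = 32;

    // Keeping the skip run is worthwhile only while encoding it as movetext is
    // cheaper than breaking the chain and paying for a new full entry.
    constexpr bool keep_skip_run(std::size_t movetext_bytes)
    {
        return movetext_bytes < STEM_BYTES;
    }

    static_assert(keep_skip_run(10));   // short run: cheaper to keep the chain going
    static_assert(!keep_skip_run(40));  // long run: cheaper to drop it and restart the chain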