mirror of
https://github.com/HChaZZY/Stockfish.git
synced 2025-12-06 10:53:50 +08:00
Add transform minimize_binpack for minimizing existing .binpack datasets. (#4447)
Takes advantage of the sample skipping rules that are used during training (capture, check, or VALUE_NONE). Adds positions to keep continuity, which improves compression.
This commit is contained in:
@@ -5048,11 +5048,21 @@ namespace chess
|
|||||||
|
|
||||||
// Generates all pseudo legal moves for the position.
|
// Generates all pseudo legal moves for the position.
|
||||||
// `pos` must be a legal chess position
|
// `pos` must be a legal chess position
|
||||||
[[nodiscard]] std::vector<Move> generatePseudoLegalMoves(const Position& pos);
|
[[nodiscard]] inline std::vector<Move> generatePseudoLegalMoves(const Position& pos)
|
||||||
|
{
|
||||||
|
std::vector<Move> moves;
|
||||||
|
forEachPseudoLegalMove(pos, [&moves](Move move) { moves.emplace_back(move); });
|
||||||
|
return moves;
|
||||||
|
}
|
||||||
|
|
||||||
// Generates all legal moves for the position.
|
// Generates all legal moves for the position.
|
||||||
// `pos` must be a legal chess position
|
// `pos` must be a legal chess position
|
||||||
[[nodiscard]] std::vector<Move> generateLegalMoves(const Position& pos);
|
[[nodiscard]] inline std::vector<Move> generateLegalMoves(const Position& pos)
|
||||||
|
{
|
||||||
|
std::vector<Move> moves;
|
||||||
|
forEachLegalMove(pos, [&moves](Move move) { moves.emplace_back(move); });
|
||||||
|
return moves;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard]] inline bool Position::isCheck() const
|
[[nodiscard]] inline bool Position::isCheck() const
|
||||||
@@ -6835,6 +6845,17 @@ namespace binpack
|
|||||||
{
|
{
|
||||||
return pos.isMoveLegal(move);
|
return pos.isMoveLegal(move);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] bool isCapturingMove() const
|
||||||
|
{
|
||||||
|
return pos.pieceAt(move.to) != chess::Piece::none() &&
|
||||||
|
pos.pieceAt(move.to).color() != pos.pieceAt(move.from).color(); // Exclude castling
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] bool isInCheck() const
|
||||||
|
{
|
||||||
|
return pos.isCheck();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
[[nodiscard]] inline TrainingDataEntry packedSfenValueToTrainingDataEntry(const nodchip::PackedSfenValue& psv)
|
[[nodiscard]] inline TrainingDataEntry packedSfenValueToTrainingDataEntry(const nodchip::PackedSfenValue& psv)
|
||||||
@@ -7110,6 +7131,11 @@ namespace binpack
|
|||||||
std::uint16_t numPlies = 0;
|
std::uint16_t numPlies = 0;
|
||||||
std::vector<unsigned char> movetext;
|
std::vector<unsigned char> movetext;
|
||||||
|
|
||||||
|
[[nodiscard]] std::size_t numBytes() const
|
||||||
|
{
|
||||||
|
return movetext.size();
|
||||||
|
}
|
||||||
|
|
||||||
void clear(const TrainingDataEntry& e)
|
void clear(const TrainingDataEntry& e)
|
||||||
{
|
{
|
||||||
numPlies = 0;
|
numPlies = 0;
|
||||||
|
|||||||
@@ -11,6 +11,8 @@
|
|||||||
|
|
||||||
#include "nnue/evaluate_nnue.h"
|
#include "nnue/evaluate_nnue.h"
|
||||||
|
|
||||||
|
#include "extra/nnue_data_binpack_format.h"
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
@@ -725,12 +727,406 @@ namespace Stockfish::Tools
|
|||||||
do_filter_335a9b2d8a80(params);
|
do_filter_335a9b2d8a80(params);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct MinimizeBinpackParams
|
||||||
|
{
|
||||||
|
std::string input_filename = "in.binpack";
|
||||||
|
std::string output_filename = "out.binpack";
|
||||||
|
bool debug_print = false;
|
||||||
|
uint64_t chain_search_nodes = 1024 * 64;
|
||||||
|
|
||||||
|
void enforce_constraints()
|
||||||
|
{
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
[[nodiscard]] binpack::nodchip::PackedSfenValue packed_sfen_tools_to_lib(Stockfish::Tools::PackedSfenValue ps)
|
||||||
|
{
|
||||||
|
binpack::nodchip::PackedSfenValue ret;
|
||||||
|
static_assert(sizeof(ret) == sizeof(ps));
|
||||||
|
std::memcpy(&ret, &ps, sizeof(ret));
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] Stockfish::Tools::PackedSfenValue packed_sfen_lib_to_tools(binpack::nodchip::PackedSfenValue ps)
|
||||||
|
{
|
||||||
|
Stockfish::Tools::PackedSfenValue ret;
|
||||||
|
static_assert(sizeof(ret) == sizeof(ps));
|
||||||
|
std::memcpy(&ret, &ps, sizeof(ret));
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] bool find_move_chain_between_positions_impl(
|
||||||
|
const binpack::TrainingDataEntry& curr_entry,
|
||||||
|
const binpack::TrainingDataEntry& last_entry,
|
||||||
|
uint64_t max_nodes,
|
||||||
|
uint64_t& curr_nodes,
|
||||||
|
std::vector<chess::Move>& reverse_chain_moves
|
||||||
|
)
|
||||||
|
{
|
||||||
|
const chess::EnumArray<chess::Color, int> piece_count_diff = {
|
||||||
|
last_entry.pos.piecesBB(chess::Color::White).count() - curr_entry.pos.piecesBB(chess::Color::White).count(),
|
||||||
|
last_entry.pos.piecesBB(chess::Color::Black).count() - curr_entry.pos.piecesBB(chess::Color::Black).count()
|
||||||
|
};
|
||||||
|
|
||||||
|
const int ply_diff = last_entry.ply - curr_entry.ply;
|
||||||
|
|
||||||
|
// Last position is older than current.
|
||||||
|
if (ply_diff <= 0)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
// Not enough plies for that many captures.
|
||||||
|
if (piece_count_diff[chess::Color::White] + piece_count_diff[chess::Color::Black] > ply_diff)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
// Not enough plies for that many captures. For each side separately.
|
||||||
|
if ( piece_count_diff[chess::Color::White] > (ply_diff + 1) / 2
|
||||||
|
|| piece_count_diff[chess::Color::Black] > (ply_diff + 1) / 2)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
std::vector<std::pair<chess::Move, int>> legal_moves;
|
||||||
|
chess::movegen::forEachLegalMove(curr_entry.pos, [&](const chess::Move move) {
|
||||||
|
int score = 0;
|
||||||
|
|
||||||
|
// Moving a piece that's already on a correct square.
|
||||||
|
if (curr_entry.pos.pieceAt(move.from) == last_entry.pos.pieceAt(move.from))
|
||||||
|
score -= 10'000;
|
||||||
|
|
||||||
|
// Moving a piece to a correct square.
|
||||||
|
if (curr_entry.pos.pieceAt(move.from) == last_entry.pos.pieceAt(move.to))
|
||||||
|
score += 10'000;
|
||||||
|
|
||||||
|
// Not a capture move but needs to be a capture to fullfill piece difference.
|
||||||
|
if ( ( piece_count_diff[chess::Color::White] + piece_count_diff[chess::Color::Black] == ply_diff
|
||||||
|
|| (piece_count_diff[curr_entry.pos.sideToMove()] == (ply_diff + 1) / 2))
|
||||||
|
&& curr_entry.pos.pieceAt(move.to) == chess::Piece::none())
|
||||||
|
score -= 10'000'000;
|
||||||
|
|
||||||
|
legal_moves.emplace_back(move, score);
|
||||||
|
});
|
||||||
|
|
||||||
|
// A heuristic for searching the legal moves such that we hope to find the solution earlier.
|
||||||
|
std::sort(legal_moves.begin(), legal_moves.end(), [](const auto& lhs, const auto& rhs) {
|
||||||
|
return lhs.second > rhs.second;
|
||||||
|
});
|
||||||
|
|
||||||
|
for (const auto [move, score] : legal_moves)
|
||||||
|
{
|
||||||
|
auto next_entry = curr_entry;
|
||||||
|
next_entry.result = -next_entry.result;
|
||||||
|
next_entry.ply += 1;
|
||||||
|
next_entry.pos.doMove(move);
|
||||||
|
|
||||||
|
// We reached the destination position.
|
||||||
|
if ( next_entry.ply == last_entry.ply
|
||||||
|
&& next_entry.result == last_entry.result
|
||||||
|
&& next_entry.pos == last_entry.pos)
|
||||||
|
{
|
||||||
|
reverse_chain_moves.emplace_back(move);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reached the search limit, aborting.
|
||||||
|
if (++curr_nodes > max_nodes)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
// We reached the destination position somewhere later in the search.
|
||||||
|
if ( next_entry.ply < last_entry.ply
|
||||||
|
&& find_move_chain_between_positions_impl(next_entry, last_entry, max_nodes, curr_nodes, reverse_chain_moves))
|
||||||
|
{
|
||||||
|
reverse_chain_moves.emplace_back(move);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::vector<chess::Move> find_move_chain_between_positions(
|
||||||
|
const binpack::TrainingDataEntry& curr_entry,
|
||||||
|
const binpack::TrainingDataEntry& next_entry,
|
||||||
|
uint64_t max_nodes
|
||||||
|
)
|
||||||
|
{
|
||||||
|
constexpr int MAX_PLY_DISTANCE = 6;
|
||||||
|
if ( binpack::isContinuation(curr_entry, next_entry)
|
||||||
|
|| curr_entry.ply >= next_entry.ply
|
||||||
|
|| curr_entry.ply + MAX_PLY_DISTANCE < next_entry.ply)
|
||||||
|
return {};
|
||||||
|
|
||||||
|
std::vector<chess::Move> reverse_chain_moves;
|
||||||
|
uint64_t curr_nodes = 0;
|
||||||
|
if (find_move_chain_between_positions_impl(curr_entry, next_entry, max_nodes, curr_nodes, reverse_chain_moves))
|
||||||
|
{
|
||||||
|
std::reverse(reverse_chain_moves.begin(), reverse_chain_moves.end());
|
||||||
|
return reverse_chain_moves;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] bool discarded_during_training_based_on_move(const binpack::TrainingDataEntry& e)
|
||||||
|
{
|
||||||
|
return e.isCapturingMove() || e.isInCheck();
|
||||||
|
}
|
||||||
|
|
||||||
|
void do_minimize_binpack(MinimizeBinpackParams& params)
|
||||||
|
{
|
||||||
|
static constexpr int VALUE_NONE = 32002;
|
||||||
|
|
||||||
|
if (!ends_with(params.input_filename, ".binpack"))
|
||||||
|
{
|
||||||
|
std::cerr << "Invalid input file type. Must be .binpack.\n";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::atomic<std::uint64_t> num_positions_read = 0;
|
||||||
|
std::atomic<std::uint64_t> num_positions_intermediate = 0;
|
||||||
|
std::atomic<std::uint64_t> num_positions_filtered = 0;
|
||||||
|
|
||||||
|
auto in = Tools::open_sfen_input_file(params.input_filename);
|
||||||
|
auto readsome = [&in, mutex = std::mutex{}](int n) mutable -> std::vector<binpack::TrainingDataEntry> {
|
||||||
|
|
||||||
|
std::vector<binpack::TrainingDataEntry> psv;
|
||||||
|
psv.reserve(n);
|
||||||
|
|
||||||
|
std::unique_lock lock(mutex);
|
||||||
|
|
||||||
|
for (int i = 0; i < n; ++i)
|
||||||
|
{
|
||||||
|
auto ps_opt = in->next();
|
||||||
|
if (ps_opt.has_value())
|
||||||
|
{
|
||||||
|
psv.emplace_back(binpack::packedSfenValueToTrainingDataEntry(packed_sfen_tools_to_lib(*ps_opt)));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return psv;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto out = SfenWriter(
|
||||||
|
params.output_filename,
|
||||||
|
Threads.size(),
|
||||||
|
std::numeric_limits<std::uint64_t>::max(),
|
||||||
|
SfenOutputType::Binpack);
|
||||||
|
|
||||||
|
Threads.execute_with_workers([&](auto& th){
|
||||||
|
std::vector<binpack::TrainingDataEntry> intermediate_entries;
|
||||||
|
|
||||||
|
auto write_one_intermediate = [&](const binpack::TrainingDataEntry& e) {
|
||||||
|
intermediate_entries.emplace_back(e);
|
||||||
|
|
||||||
|
// If a position already present in the data would have been discarded anyway
|
||||||
|
// due to the move then we can set the score to something that takes less space.
|
||||||
|
if (discarded_during_training_based_on_move(e))
|
||||||
|
intermediate_entries.back().score = 0;
|
||||||
|
|
||||||
|
const auto pi = num_positions_intermediate.fetch_add(1) + 1;
|
||||||
|
if (pi % 10000 == 0)
|
||||||
|
{
|
||||||
|
const auto pr = num_positions_read.load();
|
||||||
|
const auto pf = num_positions_filtered.load();
|
||||||
|
std::cout << "Read: ~" << pr << ". Intermediate: " << pi << ". Write: ~" << pf << "\n";
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
auto flush_intermediate_entries = [&]() {
|
||||||
|
|
||||||
|
std::vector<binpack::TrainingDataEntry> filtered_entries;
|
||||||
|
filtered_entries.reserve(intermediate_entries.size());
|
||||||
|
|
||||||
|
// Remove positions that are always skipped and at the beginning of the chain.
|
||||||
|
// Remove positions that are always skipped and at the end of the chain.
|
||||||
|
// Remove chains of always skipped positions that are larger than starting a chain again.
|
||||||
|
|
||||||
|
auto is_skipped = [&](const binpack::TrainingDataEntry& e) {
|
||||||
|
return e.score == VALUE_NONE || discarded_during_training_based_on_move(e);
|
||||||
|
};
|
||||||
|
|
||||||
|
for (size_t i = 0; i < intermediate_entries.size();)
|
||||||
|
{
|
||||||
|
const auto& curr_entry = intermediate_entries[i];
|
||||||
|
|
||||||
|
if (is_skipped(curr_entry))
|
||||||
|
{
|
||||||
|
bool is_continuation = false;
|
||||||
|
if (!filtered_entries.empty())
|
||||||
|
{
|
||||||
|
const auto& prev_entry = filtered_entries.back();
|
||||||
|
is_continuation = binpack::isContinuation(prev_entry, curr_entry);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!is_continuation)
|
||||||
|
{
|
||||||
|
++i;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_tail = true;
|
||||||
|
size_t skip_run_end;
|
||||||
|
// Check if it's start of a tail.
|
||||||
|
for (skip_run_end = i + 1; skip_run_end < intermediate_entries.size(); ++skip_run_end)
|
||||||
|
{
|
||||||
|
// Go until we end the chain.
|
||||||
|
if (!binpack::isContinuation(intermediate_entries[skip_run_end - 1], intermediate_entries[skip_run_end]))
|
||||||
|
break;
|
||||||
|
|
||||||
|
// If we found a non-skippable position then this is not the tail and we cannot skip it entirely.
|
||||||
|
if (!is_skipped(intermediate_entries[skip_run_end]))
|
||||||
|
{
|
||||||
|
is_tail = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If tail then we don't save it at all.
|
||||||
|
if (is_tail)
|
||||||
|
{
|
||||||
|
// Remove (don't save) tail. Move to the next position to consider.
|
||||||
|
i = skip_run_end;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise check if the skip run can be removed with a space saving.
|
||||||
|
if (skip_run_end - i < 6) // don't unnecessarily check total cost if upper bound is ~below 32 bytes
|
||||||
|
{
|
||||||
|
binpack::PackedMoveScoreList encoding;
|
||||||
|
for (size_t j = i; j < skip_run_end; ++j)
|
||||||
|
{
|
||||||
|
const auto& e = intermediate_entries[j];
|
||||||
|
encoding.addMoveScore(e.pos, e.move, e.score);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Full new entry would have 32 bytes
|
||||||
|
if (encoding.numBytes() >= 32)
|
||||||
|
{
|
||||||
|
i = skip_run_end;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered_entries.emplace_back(curr_entry);
|
||||||
|
++i;
|
||||||
|
}
|
||||||
|
|
||||||
|
num_positions_filtered.fetch_add(filtered_entries.size());
|
||||||
|
for (auto& e : filtered_entries)
|
||||||
|
{
|
||||||
|
const auto ps = packed_sfen_lib_to_tools(binpack::trainingDataEntryToPackedSfenValue(e));
|
||||||
|
out.write(th.id(), ps);
|
||||||
|
}
|
||||||
|
|
||||||
|
intermediate_entries.clear();
|
||||||
|
};
|
||||||
|
|
||||||
|
for (;;)
|
||||||
|
{
|
||||||
|
auto psv = readsome(10000);
|
||||||
|
|
||||||
|
if (psv.empty())
|
||||||
|
break;
|
||||||
|
|
||||||
|
for(size_t i = 0; i + 1 < psv.size(); ++i)
|
||||||
|
{
|
||||||
|
num_positions_read.fetch_add(1);
|
||||||
|
|
||||||
|
binpack::TrainingDataEntry curr_entry = psv[i];
|
||||||
|
const binpack::TrainingDataEntry& next_entry = psv[i + 1];
|
||||||
|
|
||||||
|
auto move_chain = find_move_chain_between_positions(curr_entry, next_entry, params.chain_search_nodes);
|
||||||
|
if (move_chain.empty())
|
||||||
|
{
|
||||||
|
write_one_intermediate(curr_entry);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
binpack::TrainingDataEntry e = curr_entry;
|
||||||
|
|
||||||
|
int j = 0;
|
||||||
|
for (const auto& move : move_chain)
|
||||||
|
{
|
||||||
|
e.move = move;
|
||||||
|
// If the positions would have been discarded with the old move but we changed the move
|
||||||
|
// then we need to mark the positions with VALUE_NONE score so that it's skipped.
|
||||||
|
if (j == 0 && discarded_during_training_based_on_move(curr_entry))
|
||||||
|
e.score = VALUE_NONE;
|
||||||
|
write_one_intermediate(e);
|
||||||
|
|
||||||
|
e.ply += 1;
|
||||||
|
e.result = -e.result;
|
||||||
|
e.score = VALUE_NONE; // change subsequent scores to VALUE_NONE so these new positions are ignored
|
||||||
|
e.pos.doMove(move);
|
||||||
|
|
||||||
|
j += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
num_positions_read.fetch_add(1);
|
||||||
|
write_one_intermediate(psv.back());
|
||||||
|
|
||||||
|
flush_intermediate_entries();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
Threads.wait_for_workers_finished();
|
||||||
|
|
||||||
|
const auto pi = num_positions_intermediate.load();
|
||||||
|
const auto pr = num_positions_read.load();
|
||||||
|
const auto pf = num_positions_filtered.load();
|
||||||
|
std::cout << "Read: " << pr << ". Intermediate: " << pi << ". Write: " << pf << "\n";
|
||||||
|
std::cout << "Finished.\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
void minimize_binpack(std::istringstream& is)
|
||||||
|
{
|
||||||
|
MinimizeBinpackParams params{};
|
||||||
|
|
||||||
|
while(true)
|
||||||
|
{
|
||||||
|
std::string token;
|
||||||
|
is >> token;
|
||||||
|
|
||||||
|
if (token == "")
|
||||||
|
break;
|
||||||
|
|
||||||
|
else if (token == "input_file")
|
||||||
|
is >> params.input_filename;
|
||||||
|
else if (token == "output_file")
|
||||||
|
is >> params.output_filename;
|
||||||
|
else if (token == "debug_print")
|
||||||
|
is >> params.debug_print;
|
||||||
|
else if (token == "chain_search_nodes")
|
||||||
|
is >> params.chain_search_nodes;
|
||||||
|
else
|
||||||
|
{
|
||||||
|
std::cout << "ERROR: Unknown option " << token << ". Exiting...\n";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::cout << "Performing transform minimize_binpack with parameters:\n";
|
||||||
|
std::cout << "input_file : " << params.input_filename << '\n';
|
||||||
|
std::cout << "output_file : " << params.output_filename << '\n';
|
||||||
|
std::cout << "debug_print : " << params.debug_print << '\n';
|
||||||
|
std::cout << "chain_search_nodes : " << params.chain_search_nodes << '\n';
|
||||||
|
std::cout << '\n';
|
||||||
|
|
||||||
|
do_minimize_binpack(params);
|
||||||
|
}
|
||||||
|
|
||||||
void transform(std::istringstream& is)
|
void transform(std::istringstream& is)
|
||||||
{
|
{
|
||||||
const std::map<std::string, CommandFunc> subcommands = {
|
const std::map<std::string, CommandFunc> subcommands = {
|
||||||
{ "nudged_static", &nudged_static },
|
{ "nudged_static", &nudged_static },
|
||||||
{ "rescore", &rescore },
|
{ "rescore", &rescore },
|
||||||
{ "filter_335a9b2d8a80", &filter_335a9b2d8a80 }
|
{ "filter_335a9b2d8a80", &filter_335a9b2d8a80 },
|
||||||
|
{ "minimize_binpack", &minimize_binpack }
|
||||||
};
|
};
|
||||||
|
|
||||||
Eval::NNUE::init();
|
Eval::NNUE::init();
|
||||||
|
|||||||
Reference in New Issue
Block a user