Add stats: ply_discontinuities, material_imbalance, results

This commit is contained in:
Tomasz Sobczyk
2021-05-19 12:55:14 +02:00
parent 640ec5706e
commit a4b598060c

View File

@@ -57,6 +57,18 @@ namespace Stockfish::Tools::Stats
return digits;
}
// Returns `str` left-padded with `ch` so that the result is at least `length`
// characters long.
//
// str    - the text to pad.
// ch     - the fill character that is prepended.
// length - the desired minimum length of the result. Values that are zero,
//          negative, or not greater than `str.size()` leave `str` unchanged.
//
// The returned string is never shorter than `str`; nothing is ever truncated.
[[nodiscard]] std::string left_pad_to_length(const std::string& str, char ch, int length)
{
    // A non-positive target can never require padding. Rejecting it up front
    // also keeps the comparison below free of a signed -> unsigned conversion,
    // which would turn a negative `length` into a huge value and request an
    // impossible amount of padding.
    if (length <= 0)
    {
        return str;
    }

    const std::size_t target_length = static_cast<std::size_t>(length);
    if (str.size() >= target_length)
    {
        return str;
    }

    // target_length > str.size() holds here, so the subtraction cannot wrap.
    return std::string(target_length - str.size(), ch) + str;
}
[[nodiscard]] std::string indent_text(const std::string& text, Indentation indent)
{
std::string delimiter = "\n";
@@ -258,8 +270,7 @@ namespace Stockfish::Tools::Stats
struct StatisticGathererBase
{
virtual void on_position(const Position&) {}
virtual void on_move(const Position&, const Move&) {}
virtual void on_entry(const Position&, const Move&, const PackedSfenValue&) {}
virtual void reset() = 0;
[[nodiscard]] virtual const std::string& get_name() const = 0;
[[nodiscard]] virtual StatisticOutput get_output() const = 0;
@@ -309,19 +320,11 @@ namespace Stockfish::Tools::Stats
}
}
void on_position(const Position& position) override
void on_entry(const Position& pos, const Move& move, const PackedSfenValue& psv) override
{
for (auto& g : m_gatherers)
{
g->on_position(position);
}
}
void on_move(const Position& pos, const Move& move) override
{
for (auto& g : m_gatherers)
{
g->on_move(pos, move);
g->on_entry(pos, move, psv);
}
}
@@ -458,7 +461,7 @@ namespace Stockfish::Tools::Stats
{
}
void on_position(const Position&) override
void on_entry(const Position&, const Move&, const PackedSfenValue&) override
{
m_num_positions += 1;
}
@@ -495,7 +498,7 @@ namespace Stockfish::Tools::Stats
}
void on_position(const Position& pos) override
void on_entry(const Position& pos, const Move&, const PackedSfenValue&) override
{
m_white[pos.square<KING>(WHITE)] += 1;
m_black[pos.square<KING>(BLACK)] += 1;
@@ -537,7 +540,7 @@ namespace Stockfish::Tools::Stats
}
void on_move(const Position& pos, const Move& move) override
void on_entry(const Position& pos, const Move& move, const PackedSfenValue&) override
{
if (pos.side_to_move() == WHITE)
m_white[from_sq(move)] += 1;
@@ -581,7 +584,7 @@ namespace Stockfish::Tools::Stats
}
void on_move(const Position& pos, const Move& move) override
void on_entry(const Position& pos, const Move& move, const PackedSfenValue&) override
{
if (pos.side_to_move() == WHITE)
m_white[to_sq(move)] += 1;
@@ -629,7 +632,7 @@ namespace Stockfish::Tools::Stats
}
void on_move(const Position& pos, const Move& move) override
void on_entry(const Position& pos, const Move& move, const PackedSfenValue&) override
{
m_total += 1;
@@ -692,7 +695,7 @@ namespace Stockfish::Tools::Stats
reset();
}
void on_position(const Position& pos) override
void on_entry(const Position& pos, const Move&, const PackedSfenValue&) override
{
m_piece_count_hist[popcount(pos.pieces())] += 1;
}
@@ -740,7 +743,7 @@ namespace Stockfish::Tools::Stats
reset();
}
void on_move(const Position& pos, const Move& move) override
void on_entry(const Position& pos, const Move& move, const PackedSfenValue&) override
{
m_moved_piece_type_hist[type_of(pos.piece_on(from_sq(move)))] += 1;
}
@@ -773,6 +776,170 @@ namespace Stockfish::Tools::Stats
std::uint64_t m_moved_piece_type_hist[PIECE_TYPE_NB];
};
// Counts how many times the game ply of an entry is not exactly one greater
// than the game ply of the entry seen immediately before it.
struct PlyDiscontinuitiesCounter : StatisticGathererBase
{
    static inline std::string name = "PlyDiscontinuitiesCounter";

    PlyDiscontinuitiesCounter()
    {
        reset();
    }

    // Compares the ply of `pos` against the previously recorded ply (if there
    // is one) and then remembers the current ply for the next entry.
    void on_entry(const Position& pos, const Move&, const PackedSfenValue&) override
    {
        const int ply = pos.game_ply();

        // The very first entry after a reset has nothing to be compared with.
        const bool has_previous = (m_prev_ply != no_previous_ply);
        if (has_previous && ply != m_prev_ply + 1)
        {
            ++m_num_discontinuities;
        }

        m_prev_ply = ply;
    }

    // Clears the counter and forgets the previously seen ply.
    void reset() override
    {
        m_prev_ply = no_previous_ply;
        m_num_discontinuities = 0;
    }

    [[nodiscard]] const std::string& get_name() const override
    {
        return name;
    }

    // Reports the number of discontinuities as a single value entry.
    [[nodiscard]] StatisticOutput get_output() const override
    {
        StatisticOutput result;
        result.emplace_node<StatisticOutputEntryValue<std::uint64_t>>(
            "Number of ply discontinuities (usually games)",
            m_num_discontinuities
        );
        return result;
    }

private:
    // Sentinel stored in m_prev_ply while no entry has been seen yet.
    static constexpr int no_previous_ply = -1;

    std::uint64_t m_num_discontinuities;
    int m_prev_ply;
};
// Histogram of the material difference (white minus black) over all entries,
// using fixed piece weights: queen 9, rook 5, bishop 3, knight 3, pawn 1.
// Differences outside [-max_imbalance, max_imbalance] are counted in the
// outermost buckets.
struct MaterialImbalanceDistribution : StatisticGathererBase
{
    static inline std::string name = "MaterialImbalanceDistribution";
    static constexpr int max_imbalance = 64;

    MaterialImbalanceDistribution()
    {
        reset();
    }

    // Computes the imbalance of `pos` and increments the matching bucket.
    void on_entry(const Position& pos, const Move&, const PackedSfenValue&) override
    {
        const int white_material = get_simple_material(pos, WHITE);
        const int black_material = get_simple_material(pos, BLACK);
        const int clamped = std::clamp(white_material - black_material, -max_imbalance, max_imbalance);

        // Shift by max_imbalance so that the lowest imbalance maps to index 0.
        m_num_imbalances[clamped + max_imbalance] += 1;
    }

    // Zeroes every bucket.
    void reset() override
    {
        std::fill(std::begin(m_num_imbalances), std::end(m_num_imbalances), std::uint64_t{0});
    }

    [[nodiscard]] const std::string& get_name() const override
    {
        return name;
    }

    // Emits one value entry per bucket, ordered from -max_imbalance up to
    // +max_imbalance, under a single header.
    [[nodiscard]] StatisticOutput get_output() const override
    {
        StatisticOutput out;
        auto& header = out.emplace_node<StatisticOutputEntryHeader>("Number of \"simple eval\" imbalances for white's perspective:");

        // One extra column on top of the digit count leaves room for a minus sign.
        const int label_width = get_num_base_10_digits(max_imbalance) + 1;

        for (int idx = 0; idx < bucket_count; ++idx)
        {
            const std::string label = std::to_string(idx - max_imbalance);
            header.emplace_child<StatisticOutputEntryValue<std::uint64_t>>(
                left_pad_to_length(label, ' ', label_width),
                m_num_imbalances[idx]
            );
        }

        return out;
    }

private:
    // Buckets for every imbalance in [-max_imbalance, max_imbalance].
    static constexpr int bucket_count = max_imbalance + 1 + max_imbalance;

    std::uint64_t m_num_imbalances[bucket_count];

    // Weighted piece count of side `c` in `pos` (kings are not counted).
    [[nodiscard]] int get_simple_material(const Position& pos, Color c)
    {
        const int queens = pos.count<QUEEN>(c);
        const int rooks = pos.count<ROOK>(c);
        const int bishops = pos.count<BISHOP>(c);
        const int knights = pos.count<KNIGHT>(c);
        const int pawns = pos.count<PAWN>(c);

        return 9 * queens + 5 * rooks + 3 * bishops + 3 * knights + pawns;
    }
};
// Tallies the game result stored with each entry, both per colour and
// relative to the side to move.
struct ResultDistribution : StatisticGathererBase
{
    static inline std::string name = "ResultDistribution";

    ResultDistribution()
    {
        reset();
    }

    // Classifies the entry by `psv.game_result`: 0 counts as a draw, 1 as a
    // win for the side to move, and any other value as a loss for the side
    // to move (i.e. a win for the opposite colour).
    void on_entry(const Position& pos, const Move&, const PackedSfenValue& psv) override
    {
        const Color stm = pos.side_to_move();

        if (psv.game_result == 0)
        {
            m_draws += 1;
            return;
        }

        if (psv.game_result == 1)
        {
            m_stm_wins += 1;
            m_wins[stm] += 1;
            return;
        }

        m_stm_loses += 1;
        m_wins[~stm] += 1;
    }

    // Sets every tally back to zero.
    void reset() override
    {
        m_draws = 0;
        m_stm_wins = 0;
        m_stm_loses = 0;
        m_wins[WHITE] = 0;
        m_wins[BLACK] = 0;
    }

    [[nodiscard]] const std::string& get_name() const override
    {
        return name;
    }

    // Emits the five tallies as value entries under a single header.
    [[nodiscard]] StatisticOutput get_output() const override
    {
        using CountEntry = StatisticOutputEntryValue<std::uint64_t>;

        StatisticOutput out;
        auto& header = out.emplace_node<StatisticOutputEntryHeader>("Distribution of results:");

        header.emplace_child<CountEntry>("White wins", m_wins[WHITE]);
        header.emplace_child<CountEntry>("Black wins", m_wins[BLACK]);
        header.emplace_child<CountEntry>("Draws", m_draws);
        header.emplace_child<CountEntry>("Side to move wins", m_stm_wins);
        header.emplace_child<CountEntry>("Side to move loses", m_stm_loses);

        return out;
    }

private:
    std::uint64_t m_wins[COLOR_NB];
    std::uint64_t m_draws;
    std::uint64_t m_stm_wins;
    std::uint64_t m_stm_loses;
};
/*
This function provides factories for all possible statistic gatherers.
Each new statistic gatherer needs to be added here.
@@ -791,6 +958,12 @@ namespace Stockfish::Tools::Stats
reg.add<MoveTypeCounter>("move", "move_type");
reg.add<MovedPieceTypeCounter>("move", "moved_piece_type");
reg.add<PlyDiscontinuitiesCounter>("ply_discontinuities");
reg.add<MaterialImbalanceDistribution>("material_imbalance");
reg.add<ResultDistribution>("results");
reg.add<PieceCountCounter>("piece_count");
return reg;
@@ -810,12 +983,8 @@ namespace Stockfish::Tools::Stats
auto in = Tools::open_sfen_input_file(filename);
auto on_move = [&](const Position& position, const Move& move) {
statistic_gatherers.on_move(position, move);
};
auto on_position = [&](const Position& position) {
statistic_gatherers.on_position(position);
auto on_entry = [&](const Position& position, const Move& move, const PackedSfenValue& psv) {
statistic_gatherers.on_entry(position, move, psv);
};
if (in == nullptr)
@@ -831,12 +1000,11 @@ namespace Stockfish::Tools::Stats
if (!v.has_value())
break;
auto& ps = v.value();
auto& psv = v.value();
pos.set_from_packed_sfen(ps.sfen, &si, th);
pos.set_from_packed_sfen(psv.sfen, &si, th);
on_position(pos);
on_move(pos, (Move)ps.move);
on_entry(pos, (Move)psv.move, psv);
num_processed += 1;
if (num_processed % 1'000'000 == 0)