From a4b598060c0618e8093e71a9ff37f2ed86ca3521 Mon Sep 17 00:00:00 2001 From: Tomasz Sobczyk Date: Wed, 19 May 2021 12:55:14 +0200 Subject: [PATCH] Add stats: ply_discontinuities, material_imbalance, results --- src/tools/stats.cpp | 226 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 197 insertions(+), 29 deletions(-) diff --git a/src/tools/stats.cpp b/src/tools/stats.cpp index 1172f18b..d8274f65 100644 --- a/src/tools/stats.cpp +++ b/src/tools/stats.cpp @@ -57,6 +57,18 @@ namespace Stockfish::Tools::Stats return digits; } + [[nodiscard]] std::string left_pad_to_length(const std::string& str, char ch, int length) + { + if (str.size() < length) + { + return std::string(length - static_cast(str.size()), ch) + str; + } + else + { + return str; + } + } + [[nodiscard]] std::string indent_text(const std::string& text, Indentation indent) { std::string delimiter = "\n"; @@ -258,8 +270,7 @@ namespace Stockfish::Tools::Stats struct StatisticGathererBase { - virtual void on_position(const Position&) {} - virtual void on_move(const Position&, const Move&) {} + virtual void on_entry(const Position&, const Move&, const PackedSfenValue&) {} virtual void reset() = 0; [[nodiscard]] virtual const std::string& get_name() const = 0; [[nodiscard]] virtual StatisticOutput get_output() const = 0; @@ -309,19 +320,11 @@ namespace Stockfish::Tools::Stats } } - void on_position(const Position& position) override + void on_entry(const Position& pos, const Move& move, const PackedSfenValue& psv) override { for (auto& g : m_gatherers) { - g->on_position(position); - } - } - - void on_move(const Position& pos, const Move& move) override - { - for (auto& g : m_gatherers) - { - g->on_move(pos, move); + g->on_entry(pos, move, psv); } } @@ -458,7 +461,7 @@ namespace Stockfish::Tools::Stats { } - void on_position(const Position&) override + void on_entry(const Position&, const Move&, const PackedSfenValue&) override { m_num_positions += 1; } @@ -495,7 +498,7 @@ namespace Stockfish::Tools::Stats } - void on_position(const Position& pos) override + void on_entry(const Position& pos, const Move&, const PackedSfenValue&) override { m_white[pos.square(WHITE)] += 1; m_black[pos.square(BLACK)] += 1; @@ -537,7 +540,7 @@ namespace Stockfish::Tools::Stats } - void on_move(const Position& pos, const Move& move) override + void on_entry(const Position& pos, const Move& move, const PackedSfenValue&) override { if (pos.side_to_move() == WHITE) m_white[from_sq(move)] += 1; @@ -581,7 +584,7 @@ namespace Stockfish::Tools::Stats } - void on_move(const Position& pos, const Move& move) override + void on_entry(const Position& pos, const Move& move, const PackedSfenValue&) override { if (pos.side_to_move() == WHITE) m_white[to_sq(move)] += 1; @@ -629,7 +632,7 @@ namespace Stockfish::Tools::Stats } - void on_move(const Position& pos, const Move& move) override + void on_entry(const Position& pos, const Move& move, const PackedSfenValue&) override { m_total += 1; @@ -692,7 +695,7 @@ namespace Stockfish::Tools::Stats reset(); } - void on_position(const Position& pos) override + void on_entry(const Position& pos, const Move&, const PackedSfenValue&) override { m_piece_count_hist[popcount(pos.pieces())] += 1; } @@ -740,7 +743,7 @@ namespace Stockfish::Tools::Stats reset(); } - void on_move(const Position& pos, const Move& move) override + void on_entry(const Position& pos, const Move& move, const PackedSfenValue&) override { m_moved_piece_type_hist[type_of(pos.piece_on(from_sq(move)))] += 1; } @@ -773,6 +776,170 @@ namespace Stockfish::Tools::Stats std::uint64_t m_moved_piece_type_hist[PIECE_TYPE_NB]; }; + struct PlyDiscontinuitiesCounter : StatisticGathererBase + { + static inline std::string name = "PlyDiscontinuitiesCounter"; + + PlyDiscontinuitiesCounter() + { + reset(); + } + + void on_entry(const Position& pos, const Move&, const PackedSfenValue&) override + { + const int current_ply = pos.game_ply(); + if (m_prev_ply != -1) + { + const bool is_discontinuity = (current_ply != (m_prev_ply + 1)); + if (is_discontinuity) + { + m_num_discontinuities += 1; + } + } + m_prev_ply = current_ply; + } + + void reset() override + { + m_num_discontinuities = 0; + m_prev_ply = -1; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] StatisticOutput get_output() const override + { + StatisticOutput out; + out.emplace_node>("Number of ply discontinuities (usually games)", m_num_discontinuities); + return out; + } + + private: + std::uint64_t m_num_discontinuities; + int m_prev_ply; + }; + + struct MaterialImbalanceDistribution : StatisticGathererBase + { + static inline std::string name = "MaterialImbalanceDistribution"; + static constexpr int max_imbalance = 64; + + MaterialImbalanceDistribution() + { + reset(); + } + + void on_entry(const Position& pos, const Move&, const PackedSfenValue&) override + { + const int imbalance = get_simple_material(pos, WHITE) - get_simple_material(pos, BLACK); + const int imbalance_idx = std::clamp(imbalance, -max_imbalance, max_imbalance) + max_imbalance; + m_num_imbalances[imbalance_idx] += 1; + } + + void reset() override + { + for (auto& imb : m_num_imbalances) + imb = 0; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] StatisticOutput get_output() const override + { + StatisticOutput out; + auto& header = out.emplace_node("Number of \"simple eval\" imbalances for white's perspective:"); + const int key_length = get_num_base_10_digits(max_imbalance) + 1; + for (int i = -max_imbalance; i <= max_imbalance; ++i) + { + header.emplace_child>( + left_pad_to_length(std::to_string(i), ' ', key_length), + m_num_imbalances[i + max_imbalance] + ); + } + return out; + } + + private: + std::uint64_t m_num_imbalances[max_imbalance + 1 + max_imbalance]; + + [[nodiscard]] int get_simple_material(const Position& pos, Color c) + { + return + 9 * pos.count(c) + + 5 * pos.count(c) + + 3 * pos.count(c) + + 3 * pos.count(c) + + pos.count(c); + } + }; + + struct ResultDistribution : StatisticGathererBase + { + static inline std::string name = "ResultDistribution"; + + ResultDistribution() + { + reset(); + } + + void on_entry(const Position& pos, const Move&, const PackedSfenValue& psv) override + { + const Color stm = pos.side_to_move(); + if (psv.game_result == 0) + { + m_draws += 1; + } + else if (psv.game_result == 1) + { + m_stm_wins += 1; + m_wins[stm] += 1; + } + else + { + m_stm_loses += 1; + m_wins[~stm] += 1; + } + } + + void reset() override + { + m_wins[WHITE] = 0; + m_wins[BLACK] = 0; + m_draws = 0; + m_stm_wins = 0; + m_stm_loses = 0; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] StatisticOutput get_output() const override + { + StatisticOutput out; + auto& header = out.emplace_node("Distribution of results:"); + header.emplace_child>("White wins", m_wins[WHITE]); + header.emplace_child>("Black wins", m_wins[BLACK]); + header.emplace_child>("Draws", m_draws); + header.emplace_child>("Side to move wins", m_stm_wins); + header.emplace_child>("Side to move loses", m_stm_loses); + return out; + } + + private: + std::uint64_t m_wins[COLOR_NB]; + std::uint64_t m_draws; + std::uint64_t m_stm_wins; + std::uint64_t m_stm_loses; + }; + /* This function provides factories for all possible statistic gatherers. Each new statistic gatherer needs to be added there. @@ -791,6 +958,12 @@ namespace Stockfish::Tools::Stats reg.add("move", "move_type"); reg.add("move", "moved_piece_type"); + reg.add("ply_discontinuities"); + + reg.add("material_imbalance"); + + reg.add("results"); + reg.add("piece_count"); return reg; @@ -810,12 +983,8 @@ namespace Stockfish::Tools::Stats auto in = Tools::open_sfen_input_file(filename); - auto on_move = [&](const Position& position, const Move& move) { - statistic_gatherers.on_move(position, move); - }; - - auto on_position = [&](const Position& position) { - statistic_gatherers.on_position(position); + auto on_entry = [&](const Position& position, const Move& move, const PackedSfenValue& psv) { + statistic_gatherers.on_entry(position, move, psv); }; if (in == nullptr) @@ -831,12 +1000,11 @@ namespace Stockfish::Tools::Stats if (!v.has_value()) break; - auto& ps = v.value(); + auto& psv = v.value(); - pos.set_from_packed_sfen(ps.sfen, &si, th); + pos.set_from_packed_sfen(psv.sfen, &si, th); - on_position(pos); - on_move(pos, (Move)ps.move); + on_entry(pos, (Move)psv.move, psv); num_processed += 1; if (num_processed % 1'000'000 == 0)