diff --git a/src/learn/stats.cpp b/src/learn/stats.cpp index 0a93cc5e..c77b61b9 100644 --- a/src/learn/stats.cpp +++ b/src/learn/stats.cpp @@ -30,7 +30,7 @@ namespace Learner::Stats struct StatisticGathererBase { virtual void on_position(const Position&) {} - virtual void on_move(const Move&) {} + virtual void on_move(const Position&, const Move&) {} virtual void reset() = 0; [[nodiscard]] virtual const std::string& get_name() const = 0; [[nodiscard]] virtual std::map get_formatted_stats() const = 0; @@ -86,11 +86,11 @@ namespace Learner::Stats } } - void on_move(const Move& move) override + void on_move(const Position& pos, const Move& move) override { for (auto& g : m_gatherers) { - g->on_move(move); + g->on_move(pos, move); } } @@ -261,8 +261,94 @@ namespace Learner::Stats [[nodiscard]] std::map get_formatted_stats() const override { return { - { "White king squares", m_white.get_formatted_stats() }, - { "Black king squares", m_black.get_formatted_stats() } + { "White king squares", '\n' + m_white.get_formatted_stats() }, + { "Black king squares", '\n' + m_black.get_formatted_stats() } + }; + } + + private: + StatPerSquare m_white; + StatPerSquare m_black; + }; + + struct MoveFromCounter : StatisticGathererBase + { + static inline std::string name = "MoveFromCounter"; + + MoveFromCounter() : + m_white{}, + m_black{} + { + + } + + void on_move(const Position& pos, const Move& move) override + { + if (pos.side_to_move() == WHITE) + m_white[from_sq(move)] += 1; + else + m_black[from_sq(move)] += 1; + } + + void reset() override + { + m_white = StatPerSquare{}; + m_black = StatPerSquare{}; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] std::map get_formatted_stats() const override + { + return { + { "White move from squares", '\n' + m_white.get_formatted_stats() }, + { "Black move from squares", '\n' + m_black.get_formatted_stats() } + }; + } + + private: + StatPerSquare m_white; + StatPerSquare m_black; + }; + + struct MoveToCounter : StatisticGathererBase + { + static inline std::string name = "MoveToCounter"; + + MoveToCounter() : + m_white{}, + m_black{} + { + + } + + void on_move(const Position& pos, const Move& move) override + { + if (pos.side_to_move() == WHITE) + m_white[to_sq(move)] += 1; + else + m_black[to_sq(move)] += 1; + } + + void reset() override + { + m_white = StatPerSquare{}; + m_black = StatPerSquare{}; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] std::map get_formatted_stats() const override + { + return { + { "White move to squares", '\n' + m_white.get_formatted_stats() }, + { "Black move to squares", '\n' + m_black.get_formatted_stats() } }; } @@ -285,6 +371,10 @@ namespace Learner::Stats reg.add("king"); reg.add("king_square_count"); + reg.add("move"); + reg.add("move_from_count"); + reg.add("move_to_count"); + return reg; }(); @@ -302,8 +392,8 @@ namespace Learner::Stats auto in = Learner::open_sfen_input_file(filename); - auto on_move = [&](Move move) { - statistic_gatherers.on_move(move); + auto on_move = [&](const Position& position, const Move& move) { + statistic_gatherers.on_move(position, move); }; auto on_position = [&](const Position& position) { @@ -328,7 +418,7 @@ namespace Learner::Stats pos.set_from_packed_sfen(ps.sfen, &si, th); on_position(pos); - on_move((Move)ps.move); + on_move(pos, (Move)ps.move); num_processed += 1; if (num_processed % 1'000'000 == 0)