diff --git a/docs/stats.md b/docs/stats.md index d5a76b61..78fe9051 100644 --- a/docs/stats.md +++ b/docs/stats.md @@ -4,6 +4,14 @@ Simplest usage: `stockfish.exe gather_statistics all input_file a.binpack` +Any token that is neither a parameter name nor the value of a parameter is interpreted as a group name. + +## Parameters + +`input_file` - the path to the .bin or .binpack input file to read + +`max_count` - the maximum number of positions to process. Default: no limit. + ## Groups `all` @@ -13,3 +21,25 @@ Simplest usage: `stockfish.exe gather_statistics all input_file a.binpack` `position_count` - `struct PositionCounter` - the total number of positions in the file. + + + reg.add("king", "king_square_count"); + + reg.add("move", "move_from_count"); + reg.add("move", "move_to_count"); + reg.add("move", "move_type"); + reg.add("move", "moved_piece_type"); + + reg.add("piece_count"); + +`king`, `king_square_count` - the number of times a king was on each square. Output is laid out as a chessboard, with the 8th rank being the topmost. Separate values for white and black kings. + +`move`, `move_from_count` - same as `king_square_count` but for from_sq(move) + +`move`, `move_to_count` - same as `king_square_count` but for to_sq(move) + +`move`, `move_type` - the number of moves with each type. Includes normal, captures, castling, promotions, en passant. The groups are not disjoint. 
+ +`move`, `moved_piece_type` - the number of times a piece of each type was moved + +`piece_count` - the histogram of the number of pieces on the board diff --git a/src/learn/stats.cpp b/src/learn/stats.cpp index 9d9589c4..c0e2c0a1 100644 --- a/src/learn/stats.cpp +++ b/src/learn/stats.cpp @@ -11,12 +11,16 @@ #include "nnue/evaluate_nnue.h" +#include #include #include +#include #include #include #include #include +#include +#include #include #include #include @@ -26,13 +30,190 @@ namespace Learner::Stats struct StatisticGathererBase { virtual void on_position(const Position&) {} - virtual void on_move(const Move&) {} + virtual void on_move(const Position&, const Move&) {} virtual void reset() = 0; - [[nodiscard]] virtual std::map get_formatted_stats() const = 0; + [[nodiscard]] virtual const std::string& get_name() const = 0; + [[nodiscard]] virtual std::vector> get_formatted_stats() const = 0; }; + struct StatisticGathererFactoryBase + { + [[nodiscard]] virtual std::unique_ptr create() const = 0; + [[nodiscard]] virtual const std::string& get_name() const = 0; + }; + + template + struct StatisticGathererFactory : StatisticGathererFactoryBase + { + static inline std::string name = T::name; + + [[nodiscard]] std::unique_ptr create() const override + { + return std::make_unique(); + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + }; + + struct StatisticGathererSet : StatisticGathererBase + { + void add(const StatisticGathererFactoryBase& factory) + { + const std::string name = factory.get_name(); + if (m_gatherers_names.count(name) == 0) + { + m_gatherers_names.insert(name); + m_gatherers.emplace_back(factory.create()); + } + } + + void add(std::unique_ptr&& gatherer) + { + const std::string name = gatherer->get_name(); + if (m_gatherers_names.count(name) == 0) + { + m_gatherers_names.insert(name); + m_gatherers.emplace_back(std::move(gatherer)); + } + } + + void on_position(const Position& position) override + { + for 
(auto& g : m_gatherers) + { + g->on_position(position); + } + } + + void on_move(const Position& pos, const Move& move) override + { + for (auto& g : m_gatherers) + { + g->on_move(pos, move); + } + } + + void reset() override + { + for (auto& g : m_gatherers) + { + g->reset(); + } + } + + [[nodiscard]] virtual const std::string& get_name() const override + { + static std::string name = "SET"; + return name; + } + + [[nodiscard]] virtual std::vector> get_formatted_stats() const override + { + std::vector> parts; + for (auto&& s : m_gatherers) + { + auto part = s->get_formatted_stats(); + parts.insert(parts.end(), part.begin(), part.end()); + } + return parts; + } + + private: + std::vector> m_gatherers; + std::set m_gatherers_names; + }; + + struct StatisticGathererRegistry + { + void add_statistic_gatherers_by_group( + StatisticGathererSet& gatherers, + const std::string& group) const + { + auto it = m_gatherers_by_group.find(group); + if (it != m_gatherers_by_group.end()) + { + for (auto& factory : it->second) + { + gatherers.add(*factory); + } + } + } + + template + void add(const ArgsTs&... 
group) + { + auto dummy = {(add_single(group), 0)...}; + (void)dummy; + add_single("all"); + } + + private: + std::map>> m_gatherers_by_group; + std::map> m_gatherers_names_by_group; + + template + void add_single(const ArgT& group) + { + using FactoryT = StatisticGathererFactory; + + if (m_gatherers_names_by_group[group].count(FactoryT::name) == 0) + { + m_gatherers_by_group[group].emplace_back(std::make_unique()); + m_gatherers_names_by_group[group].insert(FactoryT::name); + } + } + }; + + /* + Statistic gatherer helpers + */ + + template + struct StatPerSquare + { + StatPerSquare() + { + for (int i = 0; i < SQUARE_NB; ++i) + m_squares[i] = 0; + } + + [[nodiscard]] T& operator[](Square sq) + { + return m_squares[sq]; + } + + [[nodiscard]] const T& operator[](Square sq) const + { + return m_squares[sq]; + } + + [[nodiscard]] std::string get_formatted_stats() const + { + std::stringstream ss; + for (int i = 0; i < SQUARE_NB; ++i) + { + ss << std::setw(8) << m_squares[i ^ (int)SQ_A8] << ' '; + if ((i + 1) % 8 == 0) + ss << '\n'; + } + return ss.str(); + } + + private: + std::array m_squares; + }; + + /* + Definitions for specific statistic gatherers follow: + */ + struct PositionCounter : StatisticGathererBase { + static inline std::string name = "PositionCounter"; + PositionCounter() : m_num_positions(0) { @@ -48,7 +229,12 @@ namespace Learner::Stats m_num_positions = 0; } - [[nodiscard]] std::map get_formatted_stats() const override + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] std::vector> get_formatted_stats() const override { return { { "Number of positions", std::to_string(m_num_positions) } @@ -59,49 +245,296 @@ namespace Learner::Stats std::uint64_t m_num_positions; }; - struct StatisticGathererFactoryBase + struct KingSquareCounter : StatisticGathererBase { - [[nodiscard]] virtual std::unique_ptr create() const = 0; - }; + static inline std::string name = "KingSquareCounter"; - template - struct 
StatisticGathererFactory : StatisticGathererFactoryBase - { - [[nodiscard]] std::unique_ptr create() const override + KingSquareCounter() : + m_white{}, + m_black{} { - return std::make_unique(); - } - }; - struct StatisticGathererRegistry - { - void add_statistic_gatherers_by_group( - std::vector>& gatherers, - const std::string& group) const - { - auto it = m_gatherers_by_group.find(group); - if (it != m_gatherers_by_group.end()) - { - for (auto& factory : it->second) - { - gatherers.emplace_back(factory->create()); - } - } } - template - void add(const std::string& group) + void on_position(const Position& pos) override { - m_gatherers_by_group[group].emplace_back(std::make_unique>()); + m_white[pos.square(WHITE)] += 1; + m_black[pos.square(BLACK)] += 1; + } - // Always add to the special group "all". - m_gatherers_by_group["all"].emplace_back(std::make_unique>()); + void reset() override + { + m_white = StatPerSquare{}; + m_black = StatPerSquare{}; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] std::vector> get_formatted_stats() const override + { + return { + { "White king squares", '\n' + m_white.get_formatted_stats() }, + { "Black king squares", '\n' + m_black.get_formatted_stats() } + }; } private: - std::map>> m_gatherers_by_group; + StatPerSquare m_white; + StatPerSquare m_black; }; + struct MoveFromCounter : StatisticGathererBase + { + static inline std::string name = "MoveFromCounter"; + + MoveFromCounter() : + m_white{}, + m_black{} + { + + } + + void on_move(const Position& pos, const Move& move) override + { + if (pos.side_to_move() == WHITE) + m_white[from_sq(move)] += 1; + else + m_black[from_sq(move)] += 1; + } + + void reset() override + { + m_white = StatPerSquare{}; + m_black = StatPerSquare{}; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] std::vector> get_formatted_stats() const override + { + return { + { "White move from 
squares", '\n' + m_white.get_formatted_stats() }, + { "Black move from squares", '\n' + m_black.get_formatted_stats() } + }; + } + + private: + StatPerSquare m_white; + StatPerSquare m_black; + }; + + struct MoveToCounter : StatisticGathererBase + { + static inline std::string name = "MoveToCounter"; + + MoveToCounter() : + m_white{}, + m_black{} + { + + } + + void on_move(const Position& pos, const Move& move) override + { + if (pos.side_to_move() == WHITE) + m_white[to_sq(move)] += 1; + else + m_black[to_sq(move)] += 1; + } + + void reset() override + { + m_white = StatPerSquare{}; + m_black = StatPerSquare{}; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] std::vector> get_formatted_stats() const override + { + return { + { "White move to squares", '\n' + m_white.get_formatted_stats() }, + { "Black move to squares", '\n' + m_black.get_formatted_stats() } + }; + } + + private: + StatPerSquare m_white; + StatPerSquare m_black; + }; + + struct MoveTypeCounter : StatisticGathererBase + { + static inline std::string name = "MoveTypeCounter"; + + MoveTypeCounter() : + m_total(0), + m_normal(0), + m_capture(0), + m_promotion(0), + m_castling(0), + m_enpassant(0) + { + + } + + void on_move(const Position& pos, const Move& move) override + { + m_total += 1; + + if (!pos.empty(to_sq(move))) + m_capture += 1; + + if (type_of(move) == CASTLING) + m_castling += 1; + else if (type_of(move) == PROMOTION) + m_promotion += 1; + else if (type_of(move) == ENPASSANT) + m_enpassant += 1; + else if (type_of(move) == NORMAL) + m_normal += 1; + } + + void reset() override + { + m_total = 0; + m_normal = 0; + m_capture = 0; + m_promotion = 0; + m_castling = 0; + m_enpassant = 0; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] std::vector> get_formatted_stats() const override + { + return { + { "Total moves", std::to_string(m_total) }, + { "Normal moves", 
std::to_string(m_normal) }, + { "Capture moves", std::to_string(m_capture) }, + { "Promotion moves", std::to_string(m_promotion) }, + { "Castling moves", std::to_string(m_castling) }, + { "En-passant moves", std::to_string(m_enpassant) } + }; + } + + private: + std::uint64_t m_total; + std::uint64_t m_normal; + std::uint64_t m_capture; + std::uint64_t m_promotion; + std::uint64_t m_castling; + std::uint64_t m_enpassant; + }; + + struct PieceCountCounter : StatisticGathererBase + { + static inline std::string name = "PieceCountCounter"; + + PieceCountCounter() + { + reset(); + } + + void on_position(const Position& pos) override + { + m_piece_count_hist[popcount(pos.pieces())] += 1; + } + + void reset() override + { + for (int i = 0; i < SQUARE_NB; ++i) + m_piece_count_hist[i] = 0; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] std::vector> get_formatted_stats() const override + { + std::vector> result; + bool do_write = false; + for (int i = SQUARE_NB - 1; i >= 0; --i) + { + if (m_piece_count_hist[i] != 0) + do_write = true; + + // Start writing when the first non-zero number pops up. 
+ if (do_write) + { + result.emplace_back( + std::string("Number of positions with ") + std::to_string(i) + " pieces", + std::to_string(m_piece_count_hist[i]) + ); + } + } + return result; + } + + private: + std::uint64_t m_piece_count_hist[SQUARE_NB]; + }; + + struct MovedPieceTypeCounter : StatisticGathererBase + { + static inline std::string name = "MovedPieceTypeCounter"; + + MovedPieceTypeCounter() + { + reset(); + } + + void on_move(const Position& pos, const Move& move) override + { + m_moved_piece_type_hist[type_of(pos.piece_on(from_sq(move)))] += 1; + } + + void reset() override + { + for (int i = 0; i < PIECE_TYPE_NB; ++i) + m_moved_piece_type_hist[i] = 0; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] std::vector> get_formatted_stats() const override + { + return { + { "Pawn moves", std::to_string(m_moved_piece_type_hist[PAWN]) }, + { "Knight moves", std::to_string(m_moved_piece_type_hist[KNIGHT]) }, + { "Bishop moves", std::to_string(m_moved_piece_type_hist[BISHOP]) }, + { "Rook moves", std::to_string(m_moved_piece_type_hist[ROOK]) }, + { "Queen moves", std::to_string(m_moved_piece_type_hist[QUEEN]) }, + { "King moves", std::to_string(m_moved_piece_type_hist[KING]) } + }; + } + + private: + std::uint64_t m_moved_piece_type_hist[PIECE_TYPE_NB]; + }; + + /* + This function provides factories for all possible statistic gatherers. + Each new statistic gatherer needs to be added there. 
+ */ const auto& get_statistics_gatherers_registry() { static StatisticGathererRegistry s_reg = [](){ @@ -109,6 +542,15 @@ namespace Learner::Stats reg.add("position_count"); + reg.add("king", "king_square_count"); + + reg.add("move", "move_from_count"); + reg.add("move", "move_to_count"); + reg.add("move", "move_type"); + reg.add("move", "moved_piece_type"); + + reg.add("piece_count"); + return reg; }(); @@ -117,7 +559,8 @@ namespace Learner::Stats void do_gather_statistics( const std::string& filename, - std::vector>& statistic_gatherers) + StatisticGathererSet& statistic_gatherers, + std::uint64_t max_count) { Thread* th = Threads.main(); Position& pos = th->rootPos; @@ -125,18 +568,12 @@ namespace Learner::Stats auto in = Learner::open_sfen_input_file(filename); - auto on_move = [&](Move move) { - for (auto&& s : statistic_gatherers) - { - s->on_move(move); - } + auto on_move = [&](const Position& position, const Move& move) { + statistic_gatherers.on_move(position, move); }; auto on_position = [&](const Position& position) { - for (auto&& s : statistic_gatherers) - { - s->on_position(position); - } + statistic_gatherers.on_position(position); }; if (in == nullptr) @@ -146,7 +583,7 @@ namespace Learner::Stats } uint64_t num_processed = 0; - for (;;) + while (num_processed < max_count) { auto v = in->next(); if (!v.has_value()) @@ -157,7 +594,7 @@ namespace Learner::Stats pos.set_from_packed_sfen(ps.sfen, &si, th); on_position(pos); - on_move((Move)ps.move); + on_move(pos, (Move)ps.move); num_processed += 1; if (num_processed % 1'000'000 == 0) @@ -169,13 +606,9 @@ namespace Learner::Stats std::cout << "Finished gathering statistics.\n\n"; std::cout << "Results:\n\n"; - for (auto&& s : statistic_gatherers) + for (auto&& [name, value] : statistic_gatherers.get_formatted_stats()) { - for (auto&& [name, value] : s->get_formatted_stats()) - { - std::cout << name << ": " << value << '\n'; - } - std::cout << '\n'; + std::cout << name << ": " << value << '\n'; } } @@ 
-185,9 +618,10 @@ namespace Learner::Stats auto& registry = get_statistics_gatherers_registry(); - std::vector> statistic_gatherers; + StatisticGathererSet statistic_gatherers; std::string input_file; + std::uint64_t max_count = std::numeric_limits::max(); while(true) { @@ -199,11 +633,13 @@ namespace Learner::Stats if (token == "input_file") is >> input_file; + else if (token == "max_count") + is >> max_count; else registry.add_statistic_gatherers_by_group(statistic_gatherers, token); } - do_gather_statistics(input_file, statistic_gatherers); + do_gather_statistics(input_file, statistic_gatherers, max_count); } }