diff --git a/src/learn/stats.cpp b/src/learn/stats.cpp index c77b61b9..3a8b4454 100644 --- a/src/learn/stats.cpp +++ b/src/learn/stats.cpp @@ -357,6 +357,165 @@ namespace Learner::Stats StatPerSquare m_black; }; + struct MoveTypeCounter : StatisticGathererBase + { + static inline std::string name = "MoveTypeCounter"; + + MoveTypeCounter() : + m_total(0), + m_normal(0), + m_capture(0), + m_promotion(0), + m_castling(0), + m_enpassant(0) + { + + } + + void on_move(const Position& pos, const Move& move) override + { + m_total += 1; + + if (!pos.empty(to_sq(move))) + m_capture += 1; + + if (type_of(move) == CASTLING) + m_castling += 1; + else if (type_of(move) == PROMOTION) + m_promotion += 1; + else if (type_of(move) == ENPASSANT) + m_enpassant += 1; + else if (type_of(move) == NORMAL) + m_normal += 1; + } + + void reset() override + { + m_total = 0; + m_normal = 0; + m_capture = 0; + m_promotion = 0; + m_castling = 0; + m_enpassant = 0; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] std::map get_formatted_stats() const override + { + return { + { "Total moves", std::to_string(m_total) }, + { "Normal moves", std::to_string(m_normal) }, + { "Capture moves", std::to_string(m_capture) }, + { "Promotion moves", std::to_string(m_promotion) }, + { "Castling moves", std::to_string(m_castling) }, + { "En-passant moves", std::to_string(m_enpassant) } + }; + } + + private: + std::uint64_t m_total; + std::uint64_t m_normal; + std::uint64_t m_capture; + std::uint64_t m_promotion; + std::uint64_t m_castling; + std::uint64_t m_enpassant; + }; + + struct PieceCountCounter : StatisticGathererBase + { + static inline std::string name = "PieceCountCounter"; + + PieceCountCounter() + { + reset(); + } + + void on_position(const Position& pos) override + { + m_piece_count_hist[popcount(pos.pieces())] += 1; + } + + void reset() override + { + for (int i = 0; i < SQUARE_NB; ++i) + m_num_pieces[i] = 0; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] std::map get_formatted_stats() const override + { + std::map result; + bool do_write = false; + for (int i = SQUARE_NB; i >= 0; --i) + { + if (m_piece_count_hist[i] != 0) + do_write = true; + + // Start writing when the first non-zero number pops up. + if (do_write) + { + result.try_emplace( + std::string("Number of positions with ") + std::to_string(i) + " pieces", + std::to_string(m_piece_count_hist[i]) + ); + } + } + return result; + } + + private: + std::uint64_t m_piece_count_hist[SQUARE_NB]; + }; + + struct MovedPieceTypeCounter : StatisticGathererBase + { + static inline std::string name = "MovedPieceTypeCounter"; + + MovedPieceTypeCounter() + { + reset(); + } + + void on_move(const Position& pos, const Move& move) override + { + m_moved_piece_type_hist[type_of(pos.piece_on(from_sq(move)))] += 1; + } + + void reset() override + { + for (int i = 0; i < PIECE_TYPE_NB; ++i) + m_moved_piece_type_hist[i] = 0; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] std::map get_formatted_stats() const override + { + return { + { "Pawn moves", std::to_string(m_moved_piece_type_hist[PAWN]) }, + { "Knight moves", std::to_string(m_moved_piece_type_hist[KNIGHT]) }, + { "Bishop moves", std::to_string(m_moved_piece_type_hist[BISHOP]) }, + { "Rook moves", std::to_string(m_moved_piece_type_hist[ROOK]) }, + { "Queen moves", std::to_string(m_moved_piece_type_hist[QUEEN]) }, + { "King moves", std::to_string(m_moved_piece_type_hist[KING]) } + }; + } + + private: + std::uint64_t m_moved_piece_type_hist[PIECE_TYPE_NB]; + }; + /* This function provides factories for all possible statistic gatherers. Each new statistic gatherer needs to be added there. @@ -374,6 +533,10 @@ namespace Learner::Stats reg.add("move"); reg.add("move_from_count"); reg.add("move_to_count"); + reg.add("move_type"); + reg.add("moved_piece_type"); + + reg.add("piece_count") return reg; }();