From b2a5bf4171c943a6037eeabb49710cf726861e72 Mon Sep 17 00:00:00 2001 From: Tomasz Sobczyk Date: Mon, 5 Apr 2021 16:36:27 +0200 Subject: [PATCH] Deduplicate statistic gatherers. Fix King square counter compilation errors. --- src/learn/stats.cpp | 124 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 106 insertions(+), 18 deletions(-) diff --git a/src/learn/stats.cpp b/src/learn/stats.cpp index 419108d9..0a93cc5e 100644 --- a/src/learn/stats.cpp +++ b/src/learn/stats.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -31,12 +32,14 @@ namespace Learner::Stats virtual void on_position(const Position&) {} virtual void on_move(const Move&) {} virtual void reset() = 0; + [[nodiscard]] virtual const std::string& get_name() const = 0; [[nodiscard]] virtual std::map get_formatted_stats() const = 0; }; struct StatisticGathererFactoryBase { [[nodiscard]] virtual std::unique_ptr create() const = 0; + [[nodiscard]] virtual const std::string& get_name() const = 0; }; template @@ -46,12 +49,84 @@ namespace Learner::Stats { return std::make_unique(); } + + [[nodiscard]] const std::string& get_name() const override + { + return T::name; + } + }; + + struct StatisticGathererSet : StatisticGathererBase + { + void add(const StatisticGathererFactoryBase& factory) + { + const std::string name = factory.get_name(); + if (m_gatherers_names.count(name) == 0) + { + m_gatherers_names.insert(name); + m_gatherers.emplace_back(factory.create()); + } + } + + void add(std::unique_ptr&& gatherer) + { + const std::string name = gatherer->get_name(); + if (m_gatherers_names.count(name) == 0) + { + m_gatherers_names.insert(name); + m_gatherers.emplace_back(std::move(gatherer)); + } + } + + void on_position(const Position& position) override + { + for (auto& g : m_gatherers) + { + g->on_position(position); + } + } + + void on_move(const Move& move) override + { + for (auto& g : m_gatherers) + { + g->on_move(move); + } + } + + void reset() override + { + for (auto& g : m_gatherers) + { + g->reset(); + } + } + + [[nodiscard]] virtual const std::string& get_name() const override + { + static std::string name = "SET"; + return name; + } + + [[nodiscard]] virtual std::map get_formatted_stats() const override + { + std::map parts; + for (auto&& s : m_gatherers) + { + parts.merge(s->get_formatted_stats()); + } + return parts; + } + + private: + std::vector> m_gatherers; + std::set m_gatherers_names; }; struct StatisticGathererRegistry { void add_statistic_gatherers_by_group( - std::vector>& gatherers, + StatisticGathererSet& gatherers, const std::string& group) const { auto it = m_gatherers_by_group.find(group); @@ -59,7 +134,7 @@ namespace Learner::Stats { for (auto& factory : it->second) { - gatherers.emplace_back(factory->create()); + gatherers.add(*factory); } } } @@ -109,6 +184,7 @@ namespace Learner::Stats if ((i + 1) % 8 == 0) ss << '\n'; } + return ss.str(); } private: @@ -121,6 +197,8 @@ namespace Learner::Stats struct PositionCounter : StatisticGathererBase { + static inline std::string name = "PositionCounter"; + PositionCounter() : m_num_positions(0) { @@ -136,6 +214,11 @@ namespace Learner::Stats m_num_positions = 0; } + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + [[nodiscard]] std::map get_formatted_stats() const override { return { @@ -149,6 +232,8 @@ namespace Learner::Stats struct KingSquareCounter : StatisticGathererBase { + static inline std::string name = "KingSquareCounter"; + KingSquareCounter() : m_white{}, m_black{} @@ -168,6 +253,19 @@ namespace Learner::Stats m_black = StatPerSquare{}; } + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] std::map get_formatted_stats() const override + { + return { + { "White king squares", m_white.get_formatted_stats() }, + { "Black king squares", m_black.get_formatted_stats() } + }; + } + private: StatPerSquare m_white; StatPerSquare m_black; @@ -195,7 +293,7 @@ namespace Learner::Stats void do_gather_statistics( const std::string& filename, - std::vector>& statistic_gatherers, + StatisticGathererSet& statistic_gatherers, std::uint64_t max_count) { Thread* th = Threads.main(); @@ -205,17 +303,11 @@ namespace Learner::Stats auto in = Learner::open_sfen_input_file(filename); auto on_move = [&](Move move) { - for (auto&& s : statistic_gatherers) - { - s->on_move(move); - } + statistic_gatherers.on_move(move); }; auto on_position = [&](const Position& position) { - for (auto&& s : statistic_gatherers) - { - s->on_position(position); - } + statistic_gatherers.on_position(position); }; if (in == nullptr) @@ -248,13 +340,9 @@ namespace Learner::Stats std::cout << "Finished gathering statistics.\n\n"; std::cout << "Results:\n\n"; - for (auto&& s : statistic_gatherers) + for (auto&& [name, value] : statistic_gatherers.get_formatted_stats()) { - for (auto&& [name, value] : s->get_formatted_stats()) - { - std::cout << name << ": " << value << '\n'; - } - std::cout << '\n'; + std::cout << name << ": " << value << '\n'; } } @@ -264,7 +352,7 @@ namespace Learner::Stats auto& registry = get_statistics_gatherers_registry(); - std::vector> statistic_gatherers; + StatisticGathererSet statistic_gatherers; std::string input_file; std::uint64_t max_count = std::numeric_limits::max();