Deduplicate statistic gatherers. Fix King square counter compilation errors.

This commit is contained in:
Tomasz Sobczyk
2021-04-05 16:36:27 +02:00
parent eda51f19a2
commit b2a5bf4171

View File

@@ -14,6 +14,7 @@
#include <array>
#include <string>
#include <map>
#include <set>
#include <iostream>
#include <cmath>
#include <algorithm>
@@ -31,12 +32,14 @@ namespace Learner::Stats
virtual void on_position(const Position&) {}
virtual void on_move(const Move&) {}
virtual void reset() = 0;
[[nodiscard]] virtual const std::string& get_name() const = 0;
[[nodiscard]] virtual std::map<std::string, std::string> get_formatted_stats() const = 0;
};
struct StatisticGathererFactoryBase
{
[[nodiscard]] virtual std::unique_ptr<StatisticGathererBase> create() const = 0;
[[nodiscard]] virtual const std::string& get_name() const = 0;
};
template <typename T>
@@ -46,12 +49,84 @@ namespace Learner::Stats
{
return std::make_unique<T>();
}
[[nodiscard]] const std::string& get_name() const override
{
return T::name;
}
};
struct StatisticGathererSet : StatisticGathererBase
{
void add(const StatisticGathererFactoryBase& factory)
{
const std::string name = factory.get_name();
if (m_gatherers_names.count(name) == 0)
{
m_gatherers_names.insert(name);
m_gatherers.emplace_back(factory.create());
}
}
void add(std::unique_ptr<StatisticGathererBase>&& gatherer)
{
const std::string name = gatherer->get_name();
if (m_gatherers_names.count(name) == 0)
{
m_gatherers_names.insert(name);
m_gatherers.emplace_back(std::move(gatherer));
}
}
void on_position(const Position& position) override
{
for (auto& g : m_gatherers)
{
g->on_position(position);
}
}
void on_move(const Move& move) override
{
for (auto& g : m_gatherers)
{
g->on_move(move);
}
}
void reset() override
{
for (auto& g : m_gatherers)
{
g->reset();
}
}
[[nodiscard]] virtual const std::string& get_name() const override
{
static std::string name = "SET";
return name;
}
[[nodiscard]] virtual std::map<std::string, std::string> get_formatted_stats() const override
{
std::map<std::string, std::string> parts;
for (auto&& s : m_gatherers)
{
parts.merge(s->get_formatted_stats());
}
return parts;
}
private:
std::vector<std::unique_ptr<StatisticGathererBase>> m_gatherers;
std::set<std::string> m_gatherers_names;
};
struct StatisticGathererRegistry
{
void add_statistic_gatherers_by_group(
std::vector<std::unique_ptr<StatisticGathererBase>>& gatherers,
StatisticGathererSet& gatherers,
const std::string& group) const
{
auto it = m_gatherers_by_group.find(group);
@@ -59,7 +134,7 @@ namespace Learner::Stats
{
for (auto& factory : it->second)
{
gatherers.emplace_back(factory->create());
gatherers.add(*factory);
}
}
}
@@ -109,6 +184,7 @@ namespace Learner::Stats
if ((i + 1) % 8 == 0)
ss << '\n';
}
return ss.str();
}
private:
@@ -121,6 +197,8 @@ namespace Learner::Stats
struct PositionCounter : StatisticGathererBase
{
static inline std::string name = "PositionCounter";
PositionCounter() :
m_num_positions(0)
{
@@ -136,6 +214,11 @@ namespace Learner::Stats
m_num_positions = 0;
}
[[nodiscard]] const std::string& get_name() const override
{
return name;
}
[[nodiscard]] std::map<std::string, std::string> get_formatted_stats() const override
{
return {
@@ -149,6 +232,8 @@ namespace Learner::Stats
struct KingSquareCounter : StatisticGathererBase
{
static inline std::string name = "KingSquareCounter";
KingSquareCounter() :
m_white{},
m_black{}
@@ -168,6 +253,19 @@ namespace Learner::Stats
m_black = StatPerSquare<std::uint64_t>{};
}
[[nodiscard]] const std::string& get_name() const override
{
return name;
}
[[nodiscard]] std::map<std::string, std::string> get_formatted_stats() const override
{
return {
{ "White king squares", m_white.get_formatted_stats() },
{ "Black king squares", m_black.get_formatted_stats() }
};
}
private:
StatPerSquare<std::uint64_t> m_white;
StatPerSquare<std::uint64_t> m_black;
@@ -195,7 +293,7 @@ namespace Learner::Stats
void do_gather_statistics(
const std::string& filename,
std::vector<std::unique_ptr<StatisticGathererBase>>& statistic_gatherers,
StatisticGathererSet& statistic_gatherers,
std::uint64_t max_count)
{
Thread* th = Threads.main();
@@ -205,17 +303,11 @@ namespace Learner::Stats
auto in = Learner::open_sfen_input_file(filename);
auto on_move = [&](Move move) {
for (auto&& s : statistic_gatherers)
{
s->on_move(move);
}
statistic_gatherers.on_move(move);
};
auto on_position = [&](const Position& position) {
for (auto&& s : statistic_gatherers)
{
s->on_position(position);
}
statistic_gatherers.on_position(position);
};
if (in == nullptr)
@@ -248,13 +340,9 @@ namespace Learner::Stats
std::cout << "Finished gathering statistics.\n\n";
std::cout << "Results:\n\n";
for (auto&& s : statistic_gatherers)
for (auto&& [name, value] : statistic_gatherers.get_formatted_stats())
{
for (auto&& [name, value] : s->get_formatted_stats())
{
std::cout << name << ": " << value << '\n';
}
std::cout << '\n';
std::cout << name << ": " << value << '\n';
}
}
@@ -264,7 +352,7 @@ namespace Learner::Stats
auto& registry = get_statistics_gatherers_registry();
std::vector<std::unique_ptr<StatisticGathererBase>> statistic_gatherers;
StatisticGathererSet statistic_gatherers;
std::string input_file;
std::uint64_t max_count = std::numeric_limits<std::uint64_t>::max();