diff --git a/docs/stats.md b/docs/stats.md
new file mode 100644
index 00000000..d5a76b61
--- /dev/null
+++ b/docs/stats.md
@@ -0,0 +1,15 @@
+# Stats
+
+`gather_statistics` command allows gathering various statistics from a .bin or a .binpack file. The syntax is `gather_statistics (GROUP)* input_file FILENAME`. There can be many groups specified. Any statistic gatherer that belongs to at least one of the specified groups will be used.
+
+Simplest usage: `stockfish.exe gather_statistics all input_file a.binpack`
+
+## Groups
+
+`all`
+
+ - A special group designating all statistics gatherers available.
+
+`position_count`
+
+ - `struct PositionCounter` - the total number of positions in the file.
diff --git a/src/Makefile b/src/Makefile
index 586656d3..a4ced5f0 100644
--- a/src/Makefile
+++ b/src/Makefile
@@ -66,7 +66,8 @@ SRCS = benchmark.cpp bitbase.cpp bitboard.cpp endgame.cpp evaluate.cpp main.cpp
 	learn/gensfen_nonpv.cpp \
 	learn/opening_book.cpp \
 	learn/convert.cpp \
-	learn/transform.cpp
+	learn/transform.cpp \
+	learn/stats.cpp
 
 OBJS = $(notdir $(SRCS:.cpp=.o))
 
diff --git a/src/learn/stats.cpp b/src/learn/stats.cpp
new file mode 100644
index 00000000..9d9589c4
--- /dev/null
+++ b/src/learn/stats.cpp
@@ -0,0 +1,228 @@
+#include "stats.h"
+
+#include "sfen_stream.h"
+#include "packed_sfen.h"
+#include "sfen_writer.h"
+
+#include "thread.h"
+#include "position.h"
+#include "evaluate.h"
+#include "search.h"
+
+#include "nnue/evaluate_nnue.h"
+
+#include <cstdint>
+#include <iostream>
+#include <map>
+#include <memory>
+#include <optional>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace Learner::Stats
+{
+    // Interface implemented by every statistic gatherer. The driver feeds it
+    // each position and move read from the input and prints the name/value
+    // pairs returned by get_formatted_stats().
+    struct StatisticGathererBase
+    {
+        // Gatherers are owned through std::unique_ptr to this base type, so
+        // the destructor must be virtual for deletion to be well-defined.
+        virtual ~StatisticGathererBase() = default;
+
+        virtual void on_position(const Position&) {}
+        virtual void on_move(const Move&) {}
+        virtual void reset() = 0;
+        [[nodiscard]] virtual std::map<std::string, std::string> get_formatted_stats() const = 0;
+    };
+
+    // Counts the positions passed to on_position().
+    struct PositionCounter : StatisticGathererBase
+    {
+        PositionCounter() :
+            m_num_positions(0)
+        {
+        }
+
+        void on_position(const Position&) override
+        {
+            m_num_positions += 1;
+        }
+
+        void reset() override
+        {
+            m_num_positions = 0;
+        }
+
+        [[nodiscard]] std::map<std::string, std::string> get_formatted_stats() const override
+        {
+            return {
+                { "Number of positions", std::to_string(m_num_positions) }
+            };
+        }
+
+    private:
+        std::uint64_t m_num_positions;
+    };
+
+    // Type-erased factory that produces a new gatherer on each create() call.
+    struct StatisticGathererFactoryBase
+    {
+        // Factories are owned through std::unique_ptr to this base type.
+        virtual ~StatisticGathererFactoryBase() = default;
+
+        [[nodiscard]] virtual std::unique_ptr<StatisticGathererBase> create() const = 0;
+    };
+
+    // Factory that default-constructs a gatherer of type T.
+    template <typename T>
+    struct StatisticGathererFactory : StatisticGathererFactoryBase
+    {
+        [[nodiscard]] std::unique_ptr<StatisticGathererBase> create() const override
+        {
+            return std::make_unique<T>();
+        }
+    };
+
+    // Maps group names to the factories of the gatherers in each group.
+    struct StatisticGathererRegistry
+    {
+        // Appends one new gatherer per factory registered under `group`.
+        // A group name with no registered factories adds nothing.
+        void add_statistic_gatherers_by_group(
+            std::vector<std::unique_ptr<StatisticGathererBase>>& gatherers,
+            const std::string& group) const
+        {
+            auto it = m_gatherers_by_group.find(group);
+            if (it != m_gatherers_by_group.end())
+            {
+                for (auto& factory : it->second)
+                {
+                    gatherers.emplace_back(factory->create());
+                }
+            }
+        }
+
+        // Registers gatherer type T under `group` and under "all".
+        template <typename T>
+        void add(const std::string& group)
+        {
+            // When `group` is itself "all", register T only once.
+            if (group != "all")
+                m_gatherers_by_group[group].emplace_back(std::make_unique<StatisticGathererFactory<T>>());
+
+            // Always add to the special group "all".
+            m_gatherers_by_group["all"].emplace_back(std::make_unique<StatisticGathererFactory<T>>());
+        }
+
+    private:
+        std::map<std::string, std::vector<std::unique_ptr<StatisticGathererFactoryBase>>> m_gatherers_by_group;
+    };
+
+    // Returns the registry of all known gatherers, built on first call.
+    const auto& get_statistics_gatherers_registry()
+    {
+        static const StatisticGathererRegistry s_reg = [](){
+            StatisticGathererRegistry reg;
+
+            reg.add<PositionCounter>("position_count");
+
+            return reg;
+        }();
+
+        return s_reg;
+    }
+
+    // Feeds every entry read from `filename` to the given gatherers, then
+    // prints each gatherer's formatted statistics to stdout.
+    void do_gather_statistics(
+        const std::string& filename,
+        std::vector<std::unique_ptr<StatisticGathererBase>>& statistic_gatherers)
+    {
+        Thread* th = Threads.main();
+        Position& pos = th->rootPos;
+        StateInfo si;
+
+        auto in = Learner::open_sfen_input_file(filename);
+        if (in == nullptr)
+        {
+            std::cerr << "Invalid input file type.\n";
+            return;
+        }
+
+        auto on_move = [&](Move move) {
+            for (auto&& s : statistic_gatherers)
+            {
+                s->on_move(move);
+            }
+        };
+
+        auto on_position = [&](const Position& position) {
+            for (auto&& s : statistic_gatherers)
+            {
+                s->on_position(position);
+            }
+        };
+
+        uint64_t num_processed = 0;
+        for (;;)
+        {
+            auto v = in->next();
+            if (!v.has_value())
+                break;
+
+            auto& ps = v.value();
+
+            pos.set_from_packed_sfen(ps.sfen, &si, th);
+
+            on_position(pos);
+            on_move(static_cast<Move>(ps.move));
+
+            num_processed += 1;
+            if (num_processed % 1'000'000 == 0)
+            {
+                std::cout << "Processed " << num_processed << " positions.\n";
+            }
+        }
+
+        std::cout << "Finished gathering statistics.\n\n";
+        std::cout << "Results:\n\n";
+
+        for (auto&& s : statistic_gatherers)
+        {
+            for (auto&& [name, value] : s->get_formatted_stats())
+            {
+                std::cout << name << ": " << value << '\n';
+            }
+            std::cout << '\n';
+        }
+    }
+
+    // Handles the `gather_statistics` command: reads group names and the
+    // `input_file FILENAME` pair from `is`, then gathers and prints the stats.
+    void gather_statistics(std::istringstream& is)
+    {
+        Eval::NNUE::init();
+
+        auto& registry = get_statistics_gatherers_registry();
+
+        std::vector<std::unique_ptr<StatisticGathererBase>> statistic_gatherers;
+
+        std::string input_file;
+
+        // Extraction fails once the input is exhausted, which ends the loop.
+        std::string token;
+        while (is >> token)
+        {
+            if (token == "input_file")
+                is >> input_file;
+            else
+                registry.add_statistic_gatherers_by_group(statistic_gatherers, token);
+        }
+
+        do_gather_statistics(input_file, statistic_gatherers);
+    }
+
+}
diff --git a/src/learn/stats.h b/src/learn/stats.h
new file mode 100644
index 00000000..c9a71e5a
--- /dev/null
+++ b/src/learn/stats.h
@@ -0,0 +1,13 @@
+#ifndef _STATS_H_
+#define _STATS_H_
+
+#include <sstream>
+
+namespace Learner::Stats {
+
+    // Runs the `gather_statistics` command using the arguments left in `is`.
+    void gather_statistics(std::istringstream& is);
+
+}
+
+#endif
diff --git a/src/uci.cpp b/src/uci.cpp
index 55fccea7..7da2881f 100644
--- a/src/uci.cpp
+++ b/src/uci.cpp
@@ -40,6 +40,7 @@
 #include "learn/learn.h"
 #include "learn/convert.h"
 #include "learn/transform.h"
+#include "learn/stats.h"
 
 using namespace std;
 
@@ -349,6 +350,7 @@ void UCI::loop(int argc, char* argv[]) {
       else if (token == "convert_plain") Learner::convert_plain(is);
       else if (token == "convert_bin_from_pgn_extract") Learner::convert_bin_from_pgn_extract(is);
       else if (token == "transform") Learner::transform(is);
+      else if (token == "gather_statistics") Learner::Stats::gather_statistics(is);
 
       // Command to call qsearch(),search() directly for testing
       else if (token == "qsearch") qsearch_cmd(pos);