mirror of
https://github.com/HChaZZY/Stockfish.git
synced 2025-12-25 19:46:55 +08:00
learn -> tools
This commit is contained in:
813
src/tools/convert.cpp
Normal file
813
src/tools/convert.cpp
Normal file
@@ -0,0 +1,813 @@
|
||||
#include "convert.h"
|
||||
|
||||
#include "uci.h"
|
||||
#include "misc.h"
|
||||
#include "thread.h"
|
||||
#include "position.h"
|
||||
#include "tt.h"
|
||||
|
||||
#include "extra/nnue_data_binpack_format.h"
|
||||
|
||||
#include "nnue/evaluate_nnue.h"
|
||||
|
||||
#include "syzygy/tbprobe.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
#include <unordered_set>
|
||||
#include <iomanip>
|
||||
#include <list>
|
||||
#include <cmath> // std::exp(),std::pow(),std::log()
|
||||
#include <cstring> // memcpy()
|
||||
#include <memory>
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
#include <chrono>
|
||||
#include <random>
|
||||
#include <regex>
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Tools
|
||||
{
|
||||
bool fen_is_ok(Position& pos, std::string input_fen) {
|
||||
std::string pos_fen = pos.fen();
|
||||
std::istringstream ss_input(input_fen);
|
||||
std::istringstream ss_pos(pos_fen);
|
||||
|
||||
// example : "2r4r/4kpp1/nb1np3/p2p3p/B2P1BP1/PP6/4NPKP/2R1R3 w - h6 0 24"
|
||||
// --> "2r4r/4kpp1/nb1np3/p2p3p/B2P1BP1/PP6/4NPKP/2R1R3"
|
||||
std::string str_input, str_pos;
|
||||
ss_input >> str_input;
|
||||
ss_pos >> str_pos;
|
||||
|
||||
// Only compare "Piece placement field" between input_fen and pos.fen().
|
||||
return str_input == str_pos;
|
||||
}
|
||||
|
||||
void convert_bin(
|
||||
const vector<string>& filenames,
|
||||
const string& output_file_name,
|
||||
const int ply_minimum,
|
||||
const int ply_maximum,
|
||||
const int interpolate_eval,
|
||||
const int src_score_min_value,
|
||||
const int src_score_max_value,
|
||||
const int dest_score_min_value,
|
||||
const int dest_score_max_value,
|
||||
const bool check_invalid_fen,
|
||||
const bool check_illegal_move)
|
||||
{
|
||||
std::cout << "check_invalid_fen=" << check_invalid_fen << std::endl;
|
||||
std::cout << "check_illegal_move=" << check_illegal_move << std::endl;
|
||||
|
||||
std::fstream fs;
|
||||
uint64_t data_size = 0;
|
||||
uint64_t filtered_size = 0;
|
||||
uint64_t filtered_size_fen = 0;
|
||||
uint64_t filtered_size_move = 0;
|
||||
uint64_t filtered_size_ply = 0;
|
||||
auto th = Threads.main();
|
||||
auto& tpos = th->rootPos;
|
||||
// convert plain rag to packed sfenvalue for Yaneura king
|
||||
fs.open(output_file_name, ios::app | ios::binary);
|
||||
StateListPtr states;
|
||||
for (auto filename : filenames) {
|
||||
std::cout << "convert " << filename << " ... ";
|
||||
std::string line;
|
||||
ifstream ifs;
|
||||
ifs.open(filename);
|
||||
PackedSfenValue p;
|
||||
data_size = 0;
|
||||
filtered_size = 0;
|
||||
filtered_size_fen = 0;
|
||||
filtered_size_move = 0;
|
||||
filtered_size_ply = 0;
|
||||
p.gamePly = 1; // Not included in apery format. Should be initialized
|
||||
bool ignore_flag_fen = false;
|
||||
bool ignore_flag_move = false;
|
||||
bool ignore_flag_ply = false;
|
||||
while (std::getline(ifs, line)) {
|
||||
std::stringstream ss(line);
|
||||
std::string token;
|
||||
std::string value;
|
||||
ss >> token;
|
||||
if (token == "fen") {
|
||||
states = StateListPtr(new std::deque<StateInfo>(1)); // Drop old and create a new one
|
||||
std::string input_fen = line.substr(4);
|
||||
tpos.set(input_fen, false, &states->back(), Threads.main());
|
||||
if (check_invalid_fen && !fen_is_ok(tpos, input_fen)) {
|
||||
ignore_flag_fen = true;
|
||||
filtered_size_fen++;
|
||||
}
|
||||
else {
|
||||
tpos.sfen_pack(p.sfen);
|
||||
}
|
||||
}
|
||||
else if (token == "move") {
|
||||
ss >> value;
|
||||
Move move = UCI::to_move(tpos, value);
|
||||
if (check_illegal_move && move == MOVE_NONE) {
|
||||
ignore_flag_move = true;
|
||||
filtered_size_move++;
|
||||
}
|
||||
else {
|
||||
p.move = move;
|
||||
}
|
||||
}
|
||||
else if (token == "score") {
|
||||
double score;
|
||||
ss >> score;
|
||||
// Training Formula ?Issue #71 ?nodchip/Stockfish https://github.com/nodchip/Stockfish/issues/71
|
||||
// Normalize to [0.0, 1.0].
|
||||
score = (score - src_score_min_value) / (src_score_max_value - src_score_min_value);
|
||||
// Scale to [dest_score_min_value, dest_score_max_value].
|
||||
score = score * (dest_score_max_value - dest_score_min_value) + dest_score_min_value;
|
||||
p.score = Math::clamp((int32_t)std::round(score), -(int32_t)VALUE_MATE, (int32_t)VALUE_MATE);
|
||||
}
|
||||
else if (token == "ply") {
|
||||
int temp;
|
||||
ss >> temp;
|
||||
if (temp < ply_minimum || temp > ply_maximum) {
|
||||
ignore_flag_ply = true;
|
||||
filtered_size_ply++;
|
||||
}
|
||||
p.gamePly = uint16_t(temp); // No cast here?
|
||||
if (interpolate_eval != 0) {
|
||||
p.score = min(3000, interpolate_eval * temp);
|
||||
}
|
||||
}
|
||||
else if (token == "result") {
|
||||
int temp;
|
||||
ss >> temp;
|
||||
p.game_result = int8_t(temp); // Do you need a cast here?
|
||||
if (interpolate_eval) {
|
||||
p.score = p.score * p.game_result;
|
||||
}
|
||||
}
|
||||
else if (token == "e") {
|
||||
if (!(ignore_flag_fen || ignore_flag_move || ignore_flag_ply)) {
|
||||
fs.write((char*)&p, sizeof(PackedSfenValue));
|
||||
data_size += 1;
|
||||
// debug
|
||||
// std::cout<<tpos<<std::endl;
|
||||
// std::cout<<p.score<<","<<int(p.gamePly)<<","<<int(p.game_result)<<std::endl;
|
||||
}
|
||||
else {
|
||||
filtered_size++;
|
||||
}
|
||||
ignore_flag_fen = false;
|
||||
ignore_flag_move = false;
|
||||
ignore_flag_ply = false;
|
||||
}
|
||||
}
|
||||
std::cout << "done " << data_size << " parsed " << filtered_size << " is filtered"
|
||||
<< " (invalid fen:" << filtered_size_fen << ", illegal move:" << filtered_size_move << ", invalid ply:" << filtered_size_ply << ")" << std::endl;
|
||||
ifs.close();
|
||||
}
|
||||
std::cout << "all done" << std::endl;
|
||||
fs.close();
|
||||
}
|
||||
|
||||
static inline void ltrim(std::string& s) {
|
||||
s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) {
|
||||
return !std::isspace(ch);
|
||||
}));
|
||||
}
|
||||
|
||||
static inline void rtrim(std::string& s) {
|
||||
s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) {
|
||||
return !std::isspace(ch);
|
||||
}).base(), s.end());
|
||||
}
|
||||
|
||||
static inline void trim(std::string& s) {
|
||||
ltrim(s);
|
||||
rtrim(s);
|
||||
}
|
||||
|
||||
int parse_game_result_from_pgn_extract(std::string result) {
|
||||
// White Win
|
||||
if (result == "\"1-0\"") {
|
||||
return 1;
|
||||
}
|
||||
// Black Win
|
||||
else if (result == "\"0-1\"") {
|
||||
return -1;
|
||||
}
|
||||
// Draw
|
||||
else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// 0.25 --> 0.25 * PawnValueEg
|
||||
// #-4 --> -mate_in(4)
|
||||
// #3 --> mate_in(3)
|
||||
// -M4 --> -mate_in(4)
|
||||
// +M3 --> mate_in(3)
|
||||
Value parse_score_from_pgn_extract(std::string eval, bool& success) {
|
||||
success = true;
|
||||
|
||||
if (eval.substr(0, 1) == "#") {
|
||||
if (eval.substr(1, 1) == "-") {
|
||||
return -mate_in(stoi(eval.substr(2, eval.length() - 2)));
|
||||
}
|
||||
else {
|
||||
return mate_in(stoi(eval.substr(1, eval.length() - 1)));
|
||||
}
|
||||
}
|
||||
else if (eval.substr(0, 2) == "-M") {
|
||||
//std::cout << "eval=" << eval << std::endl;
|
||||
return -mate_in(stoi(eval.substr(2, eval.length() - 2)));
|
||||
}
|
||||
else if (eval.substr(0, 2) == "+M") {
|
||||
//std::cout << "eval=" << eval << std::endl;
|
||||
return mate_in(stoi(eval.substr(2, eval.length() - 2)));
|
||||
}
|
||||
else {
|
||||
char* endptr;
|
||||
double value = strtod(eval.c_str(), &endptr);
|
||||
|
||||
if (*endptr != '\0') {
|
||||
success = false;
|
||||
return VALUE_ZERO;
|
||||
}
|
||||
else {
|
||||
return Value(value * static_cast<double>(PawnValueEg));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// for Debug
|
||||
//#define DEBUG_CONVERT_BIN_FROM_PGN_EXTRACT
|
||||
|
||||
bool is_like_fen(std::string fen) {
|
||||
int count_space = std::count(fen.cbegin(), fen.cend(), ' ');
|
||||
int count_slash = std::count(fen.cbegin(), fen.cend(), '/');
|
||||
|
||||
#if defined(DEBUG_CONVERT_BIN_FROM_PGN_EXTRACT)
|
||||
//std::cout << "count_space=" << count_space << std::endl;
|
||||
//std::cout << "count_slash=" << count_slash << std::endl;
|
||||
#endif
|
||||
|
||||
return count_space == 5 && count_slash == 7;
|
||||
}
|
||||
|
||||
void convert_bin_from_pgn_extract(
|
||||
const vector<string>& filenames,
|
||||
const string& output_file_name,
|
||||
const bool pgn_eval_side_to_move,
|
||||
const bool convert_no_eval_fens_as_score_zero)
|
||||
{
|
||||
std::cout << "pgn_eval_side_to_move=" << pgn_eval_side_to_move << std::endl;
|
||||
std::cout << "convert_no_eval_fens_as_score_zero=" << convert_no_eval_fens_as_score_zero << std::endl;
|
||||
|
||||
auto th = Threads.main();
|
||||
auto& pos = th->rootPos;
|
||||
|
||||
std::fstream ofs;
|
||||
ofs.open(output_file_name, ios::out | ios::binary);
|
||||
|
||||
int game_count = 0;
|
||||
int fen_count = 0;
|
||||
|
||||
for (auto filename : filenames) {
|
||||
std::cout << now_string() << " convert " << filename << std::endl;
|
||||
ifstream ifs;
|
||||
ifs.open(filename);
|
||||
|
||||
int game_result = 0;
|
||||
|
||||
std::string line;
|
||||
while (std::getline(ifs, line)) {
|
||||
|
||||
if (line.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
else if (line.substr(0, 1) == "[") {
|
||||
std::regex pattern_result(R"(\[Result (.+?)\])");
|
||||
std::smatch match;
|
||||
|
||||
// example: [Result "1-0"]
|
||||
if (std::regex_search(line, match, pattern_result)) {
|
||||
game_result = parse_game_result_from_pgn_extract(match.str(1));
|
||||
#if defined(DEBUG_CONVERT_BIN_FROM_PGN_EXTRACT)
|
||||
std::cout << "game_result=" << game_result << std::endl;
|
||||
#endif
|
||||
game_count++;
|
||||
if (game_count % 10000 == 0) {
|
||||
std::cout << now_string() << " game_count=" << game_count << ", fen_count=" << fen_count << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
else {
|
||||
int gamePly = 1;
|
||||
auto itr = line.cbegin();
|
||||
|
||||
while (true) {
|
||||
gamePly++;
|
||||
|
||||
PackedSfenValue psv;
|
||||
memset((char*)&psv, 0, sizeof(PackedSfenValue));
|
||||
|
||||
// fen
|
||||
{
|
||||
bool fen_found = false;
|
||||
|
||||
while (!fen_found) {
|
||||
std::regex pattern_bracket(R"(\{(.+?)\})");
|
||||
std::smatch match;
|
||||
if (!std::regex_search(itr, line.cend(), match, pattern_bracket)) {
|
||||
break;
|
||||
}
|
||||
|
||||
itr += match.position(0) + match.length(0) - 1;
|
||||
std::string str_fen = match.str(1);
|
||||
trim(str_fen);
|
||||
|
||||
if (is_like_fen(str_fen)) {
|
||||
fen_found = true;
|
||||
|
||||
StateInfo si;
|
||||
pos.set(str_fen, false, &si, th);
|
||||
pos.sfen_pack(psv.sfen);
|
||||
}
|
||||
|
||||
#if defined(DEBUG_CONVERT_BIN_FROM_PGN_EXTRACT)
|
||||
std::cout << "str_fen=" << str_fen << std::endl;
|
||||
std::cout << "fen_found=" << fen_found << std::endl;
|
||||
#endif
|
||||
}
|
||||
|
||||
if (!fen_found) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// move
|
||||
{
|
||||
std::regex pattern_move(R"(\}(.+?)\{)");
|
||||
std::smatch match;
|
||||
if (!std::regex_search(itr, line.cend(), match, pattern_move)) {
|
||||
break;
|
||||
}
|
||||
|
||||
itr += match.position(0) + match.length(0) - 1;
|
||||
std::string str_move = match.str(1);
|
||||
trim(str_move);
|
||||
#if defined(DEBUG_CONVERT_BIN_FROM_PGN_EXTRACT)
|
||||
std::cout << "str_move=" << str_move << std::endl;
|
||||
#endif
|
||||
psv.move = UCI::to_move(pos, str_move);
|
||||
}
|
||||
|
||||
// eval
|
||||
bool eval_found = false;
|
||||
{
|
||||
std::regex pattern_bracket(R"(\{(.+?)\})");
|
||||
std::smatch match;
|
||||
if (!std::regex_search(itr, line.cend(), match, pattern_bracket)) {
|
||||
break;
|
||||
}
|
||||
|
||||
std::string str_eval_clk = match.str(1);
|
||||
trim(str_eval_clk);
|
||||
#if defined(DEBUG_CONVERT_BIN_FROM_PGN_EXTRACT)
|
||||
std::cout << "str_eval_clk=" << str_eval_clk << std::endl;
|
||||
#endif
|
||||
|
||||
// example: { [%eval 0.25] [%clk 0:10:00] }
|
||||
// example: { [%eval #-4] [%clk 0:10:00] }
|
||||
// example: { [%eval #3] [%clk 0:10:00] }
|
||||
// example: { +0.71/22 1.2s }
|
||||
// example: { -M4/7 0.003s }
|
||||
// example: { M3/245 0.017s }
|
||||
// example: { +M1/245 0.010s, White mates }
|
||||
// example: { 0.60 }
|
||||
// example: { book }
|
||||
// example: { rnbqkb1r/pp3ppp/2p1pn2/3p4/2PP4/2N2N2/PP2PPPP/R1BQKB1R w KQkq - 0 5 }
|
||||
|
||||
// Considering the absence of eval
|
||||
if (!is_like_fen(str_eval_clk)) {
|
||||
itr += match.position(0) + match.length(0) - 1;
|
||||
|
||||
if (str_eval_clk != "book") {
|
||||
std::regex pattern_eval1(R"(\[\%eval (.+?)\])");
|
||||
std::regex pattern_eval2(R"((.+?)\/)");
|
||||
|
||||
std::string str_eval;
|
||||
if (std::regex_search(str_eval_clk, match, pattern_eval1) ||
|
||||
std::regex_search(str_eval_clk, match, pattern_eval2)) {
|
||||
str_eval = match.str(1);
|
||||
trim(str_eval);
|
||||
}
|
||||
else {
|
||||
str_eval = str_eval_clk;
|
||||
}
|
||||
|
||||
bool success = false;
|
||||
Value value = parse_score_from_pgn_extract(str_eval, success);
|
||||
if (success) {
|
||||
eval_found = true;
|
||||
psv.score = Math::clamp(value, -VALUE_MATE, VALUE_MATE);
|
||||
}
|
||||
|
||||
#if defined(DEBUG_CONVERT_BIN_FROM_PGN_EXTRACT)
|
||||
std::cout << "str_eval=" << str_eval << std::endl;
|
||||
std::cout << "success=" << success << ", psv.score=" << psv.score << std::endl;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// write
|
||||
if (eval_found || convert_no_eval_fens_as_score_zero) {
|
||||
if (!eval_found && convert_no_eval_fens_as_score_zero) {
|
||||
psv.score = 0;
|
||||
}
|
||||
|
||||
psv.gamePly = gamePly;
|
||||
psv.game_result = game_result;
|
||||
|
||||
if (pos.side_to_move() == BLACK) {
|
||||
if (!pgn_eval_side_to_move) {
|
||||
psv.score *= -1;
|
||||
}
|
||||
psv.game_result *= -1;
|
||||
}
|
||||
|
||||
ofs.write((char*)&psv, sizeof(PackedSfenValue));
|
||||
|
||||
fen_count++;
|
||||
}
|
||||
}
|
||||
|
||||
game_result = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << now_string() << " game_count=" << game_count << ", fen_count=" << fen_count << std::endl;
|
||||
std::cout << now_string() << " all done" << std::endl;
|
||||
ofs.close();
|
||||
}
|
||||
|
||||
void convert_plain(
|
||||
const vector<string>& filenames,
|
||||
const string& output_file_name)
|
||||
{
|
||||
Position tpos;
|
||||
std::ofstream ofs;
|
||||
ofs.open(output_file_name, ios::app);
|
||||
auto th = Threads.main();
|
||||
for (auto filename : filenames) {
|
||||
std::cout << "convert " << filename << " ... ";
|
||||
|
||||
// Just convert packedsfenvalue to text
|
||||
std::fstream fs;
|
||||
fs.open(filename, ios::in | ios::binary);
|
||||
PackedSfenValue p;
|
||||
while (true)
|
||||
{
|
||||
if (fs.read((char*)&p, sizeof(PackedSfenValue))) {
|
||||
StateInfo si;
|
||||
tpos.set_from_packed_sfen(p.sfen, &si, th);
|
||||
|
||||
// write as plain text
|
||||
ofs << "fen " << tpos.fen() << std::endl;
|
||||
ofs << "move " << UCI::move(Move(p.move), false) << std::endl;
|
||||
ofs << "score " << p.score << std::endl;
|
||||
ofs << "ply " << int(p.gamePly) << std::endl;
|
||||
ofs << "result " << int(p.game_result) << std::endl;
|
||||
ofs << "e" << std::endl;
|
||||
}
|
||||
else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
fs.close();
|
||||
std::cout << "done" << std::endl;
|
||||
}
|
||||
ofs.close();
|
||||
std::cout << "all done" << std::endl;
|
||||
}
|
||||
|
||||
static inline const std::string plain_extension = ".plain";
|
||||
static inline const std::string bin_extension = ".bin";
|
||||
static inline const std::string binpack_extension = ".binpack";
|
||||
|
||||
static bool file_exists(const std::string& name)
|
||||
{
|
||||
std::ifstream f(name);
|
||||
return f.good();
|
||||
}
|
||||
|
||||
static bool ends_with(const std::string& lhs, const std::string& end)
|
||||
{
|
||||
if (end.size() > lhs.size()) return false;
|
||||
|
||||
return std::equal(end.rbegin(), end.rend(), lhs.rbegin());
|
||||
}
|
||||
|
||||
static bool is_convert_of_type(
|
||||
const std::string& input_path,
|
||||
const std::string& output_path,
|
||||
const std::string& expected_input_extension,
|
||||
const std::string& expected_output_extension)
|
||||
{
|
||||
return ends_with(input_path, expected_input_extension)
|
||||
&& ends_with(output_path, expected_output_extension);
|
||||
}
|
||||
|
||||
using ConvertFunctionType = void(std::string inputPath, std::string outputPath, std::ios_base::openmode om, bool validate);
|
||||
|
||||
static ConvertFunctionType* get_convert_function(const std::string& input_path, const std::string& output_path)
|
||||
{
|
||||
if (is_convert_of_type(input_path, output_path, plain_extension, bin_extension))
|
||||
return binpack::convertPlainToBin;
|
||||
if (is_convert_of_type(input_path, output_path, plain_extension, binpack_extension))
|
||||
return binpack::convertPlainToBinpack;
|
||||
|
||||
if (is_convert_of_type(input_path, output_path, bin_extension, plain_extension))
|
||||
return binpack::convertBinToPlain;
|
||||
if (is_convert_of_type(input_path, output_path, bin_extension, binpack_extension))
|
||||
return binpack::convertBinToBinpack;
|
||||
|
||||
if (is_convert_of_type(input_path, output_path, binpack_extension, plain_extension))
|
||||
return binpack::convertBinpackToPlain;
|
||||
if (is_convert_of_type(input_path, output_path, binpack_extension, bin_extension))
|
||||
return binpack::convertBinpackToBin;
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static void convert(const std::string& input_path, const std::string& output_path, std::ios_base::openmode om, bool validate)
|
||||
{
|
||||
if(!file_exists(input_path))
|
||||
{
|
||||
std::cerr << "Input file does not exist.\n";
|
||||
return;
|
||||
}
|
||||
|
||||
auto func = get_convert_function(input_path, output_path);
|
||||
if (func != nullptr)
|
||||
{
|
||||
func(input_path, output_path, om, validate);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Conversion between files of these types is not supported.\n";
|
||||
}
|
||||
}
|
||||
|
||||
static void convert(const std::vector<std::string>& args)
|
||||
{
|
||||
if (args.size() < 2 || args.size() > 4)
|
||||
{
|
||||
std::cerr << "Invalid arguments.\n";
|
||||
std::cerr << "Usage: convert from_path to_path [append] [validate]\n";
|
||||
return;
|
||||
}
|
||||
|
||||
const bool append = std::find(args.begin() + 2, args.end(), "append") != args.end();
|
||||
const bool validate = std::find(args.begin() + 2, args.end(), "validate") != args.end();
|
||||
|
||||
const std::ios_base::openmode openmode =
|
||||
append
|
||||
? std::ios_base::app
|
||||
: std::ios_base::trunc;
|
||||
|
||||
convert(args[0], args[1], openmode, validate);
|
||||
}
|
||||
|
||||
void convert(istringstream& is)
|
||||
{
|
||||
std::vector<std::string> args;
|
||||
|
||||
while (true)
|
||||
{
|
||||
std::string token = "";
|
||||
is >> token;
|
||||
if (token == "")
|
||||
break;
|
||||
|
||||
args.push_back(token);
|
||||
}
|
||||
|
||||
convert(args);
|
||||
}
|
||||
|
||||
static void append_files_from_dir(
|
||||
std::vector<std::string>& filenames,
|
||||
const std::string& base_dir,
|
||||
const std::string& target_dir)
|
||||
{
|
||||
string kif_base_dir = Path::combine(base_dir, target_dir);
|
||||
|
||||
sys::path p(kif_base_dir); // Origin of enumeration
|
||||
std::for_each(sys::directory_iterator(p), sys::directory_iterator(),
|
||||
[&](const sys::path& path) {
|
||||
if (sys::is_regular_file(path))
|
||||
filenames.push_back(Path::combine(target_dir, path.filename().generic_string()));
|
||||
});
|
||||
}
|
||||
|
||||
static void rebase_files(
|
||||
std::vector<std::string>& filenames,
|
||||
const std::string& base_dir)
|
||||
{
|
||||
for (auto& file : filenames)
|
||||
{
|
||||
file = Path::combine(base_dir, file);
|
||||
}
|
||||
}
|
||||
|
||||
void convert_bin_from_pgn_extract(std::istringstream& is)
|
||||
{
|
||||
std::vector<std::string> filenames;
|
||||
|
||||
string base_dir;
|
||||
string target_dir;
|
||||
|
||||
bool pgn_eval_side_to_move = false;
|
||||
bool convert_no_eval_fens_as_score_zero = false;
|
||||
|
||||
string output_file_name = "shuffled_sfen.bin";
|
||||
|
||||
while (true)
|
||||
{
|
||||
string option;
|
||||
is >> option;
|
||||
|
||||
if (option == "")
|
||||
break;
|
||||
|
||||
if (option == "targetdir") is >> target_dir;
|
||||
else if (option == "targetfile")
|
||||
{
|
||||
std::string filename;
|
||||
is >> filename;
|
||||
filenames.push_back(filename);
|
||||
}
|
||||
|
||||
else if (option == "basedir") is >> base_dir;
|
||||
|
||||
else if (option == "pgn_eval_side_to_move") is >> pgn_eval_side_to_move;
|
||||
else if (option == "convert_no_eval_fens_as_score_zero") is >> convert_no_eval_fens_as_score_zero;
|
||||
else if (option == "output_file_name") is >> output_file_name;
|
||||
else
|
||||
{
|
||||
cout << "Unknown option: " << option << ". Ignoring.\n";
|
||||
}
|
||||
}
|
||||
|
||||
if (!target_dir.empty())
|
||||
{
|
||||
append_files_from_dir(filenames, base_dir, target_dir);
|
||||
}
|
||||
rebase_files(filenames, base_dir);
|
||||
|
||||
Eval::NNUE::init();
|
||||
|
||||
cout << "convert_bin_from_pgn-extract.." << endl;
|
||||
convert_bin_from_pgn_extract(
|
||||
filenames,
|
||||
output_file_name,
|
||||
pgn_eval_side_to_move,
|
||||
convert_no_eval_fens_as_score_zero);
|
||||
}
|
||||
|
||||
void convert_bin(std::istringstream& is)
|
||||
{
|
||||
std::vector<std::string> filenames;
|
||||
|
||||
string base_dir;
|
||||
string target_dir;
|
||||
|
||||
int ply_minimum = 0;
|
||||
int ply_maximum = 114514;
|
||||
bool interpolate_eval = 0;
|
||||
bool check_invalid_fen = false;
|
||||
bool check_illegal_move = false;
|
||||
|
||||
bool pgn_eval_side_to_move = false;
|
||||
bool convert_no_eval_fens_as_score_zero = false;
|
||||
|
||||
double src_score_min_value = 0.0;
|
||||
double src_score_max_value = 1.0;
|
||||
double dest_score_min_value = 0.0;
|
||||
double dest_score_max_value = 1.0;
|
||||
|
||||
string output_file_name = "shuffled_sfen.bin";
|
||||
|
||||
while (true)
|
||||
{
|
||||
string option;
|
||||
is >> option;
|
||||
|
||||
if (option == "")
|
||||
break;
|
||||
|
||||
if (option == "targetdir") is >> target_dir;
|
||||
else if (option == "targetfile")
|
||||
{
|
||||
std::string filename;
|
||||
is >> filename;
|
||||
filenames.push_back(filename);
|
||||
}
|
||||
|
||||
else if (option == "basedir") is >> base_dir;
|
||||
|
||||
else if (option == "ply_minimum") is >> ply_minimum;
|
||||
else if (option == "ply_maximum") is >> ply_maximum;
|
||||
else if (option == "interpolate_eval") is >> interpolate_eval;
|
||||
else if (option == "check_invalid_fen") is >> check_invalid_fen;
|
||||
else if (option == "check_illegal_move") is >> check_illegal_move;
|
||||
else if (option == "pgn_eval_side_to_move") is >> pgn_eval_side_to_move;
|
||||
else if (option == "convert_no_eval_fens_as_score_zero") is >> convert_no_eval_fens_as_score_zero;
|
||||
else if (option == "src_score_min_value") is >> src_score_min_value;
|
||||
else if (option == "src_score_max_value") is >> src_score_max_value;
|
||||
else if (option == "dest_score_min_value") is >> dest_score_min_value;
|
||||
else if (option == "dest_score_max_value") is >> dest_score_max_value;
|
||||
else if (option == "output_file_name") is >> output_file_name;
|
||||
else
|
||||
{
|
||||
cout << "Unknown option: " << option << ". Ignoring.\n";
|
||||
}
|
||||
}
|
||||
|
||||
if (!target_dir.empty())
|
||||
{
|
||||
append_files_from_dir(filenames, base_dir, target_dir);
|
||||
}
|
||||
rebase_files(filenames, base_dir);
|
||||
|
||||
Eval::NNUE::init();
|
||||
|
||||
cout << "convert_bin.." << endl;
|
||||
convert_bin(
|
||||
filenames,
|
||||
output_file_name,
|
||||
ply_minimum,
|
||||
ply_maximum,
|
||||
interpolate_eval,
|
||||
src_score_min_value,
|
||||
src_score_max_value,
|
||||
dest_score_min_value,
|
||||
dest_score_max_value,
|
||||
check_invalid_fen,
|
||||
check_illegal_move
|
||||
);
|
||||
}
|
||||
|
||||
void convert_plain(std::istringstream& is)
|
||||
{
|
||||
std::vector<std::string> filenames;
|
||||
|
||||
string base_dir;
|
||||
string target_dir;
|
||||
|
||||
string output_file_name = "shuffled_sfen.bin";
|
||||
|
||||
while (true)
|
||||
{
|
||||
string option;
|
||||
is >> option;
|
||||
|
||||
if (option == "")
|
||||
break;
|
||||
|
||||
if (option == "targetdir") is >> target_dir;
|
||||
else if (option == "targetfile")
|
||||
{
|
||||
std::string filename;
|
||||
is >> filename;
|
||||
filenames.push_back(filename);
|
||||
}
|
||||
|
||||
else if (option == "basedir") is >> base_dir;
|
||||
|
||||
else if (option == "output_file_name") is >> output_file_name;
|
||||
else
|
||||
{
|
||||
cout << "Unknown option: " << option << ". Ignoring.\n";
|
||||
}
|
||||
}
|
||||
|
||||
if (!target_dir.empty())
|
||||
{
|
||||
append_files_from_dir(filenames, base_dir, target_dir);
|
||||
}
|
||||
rebase_files(filenames, base_dir);
|
||||
|
||||
Eval::NNUE::init();
|
||||
|
||||
cout << "convert_plain.." << endl;
|
||||
convert_plain(filenames, output_file_name);
|
||||
}
|
||||
}
|
||||
18
src/tools/convert.h
Normal file
18
src/tools/convert.h
Normal file
@@ -0,0 +1,18 @@
|
||||
#ifndef _CONVERT_H_
|
||||
#define _CONVERT_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
|
||||
namespace Tools {
|
||||
void convert(std::istringstream& is);
|
||||
|
||||
void convert_bin_from_pgn_extract(std::istringstream& is);
|
||||
|
||||
void convert_bin(std::istringstream& is);
|
||||
|
||||
void convert_plain(std::istringstream& is);
|
||||
}
|
||||
|
||||
#endif
|
||||
974
src/tools/gensfen.cpp
Normal file
974
src/tools/gensfen.cpp
Normal file
@@ -0,0 +1,974 @@
|
||||
#include "gensfen.h"
|
||||
|
||||
#include "sfen_writer.h"
|
||||
#include "packed_sfen.h"
|
||||
#include "opening_book.h"
|
||||
|
||||
#include "misc.h"
|
||||
#include "position.h"
|
||||
#include "thread.h"
|
||||
#include "tt.h"
|
||||
#include "uci.h"
|
||||
|
||||
#include "extra/nnue_data_binpack_format.h"
|
||||
|
||||
#include "nnue/evaluate_nnue.h"
|
||||
|
||||
#include "syzygy/tbprobe.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <limits>
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <random>
|
||||
#include <shared_mutex>
|
||||
#include <sstream>
|
||||
#include <unordered_set>
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Tools
|
||||
{
|
||||
// Class to generate sfen with multiple threads
|
||||
struct Gensfen
|
||||
{
|
||||
struct Params
|
||||
{
|
||||
// Min and max depths for search during gensfen
|
||||
int search_depth_min = 3;
|
||||
int search_depth_max = -1;
|
||||
|
||||
// Number of the nodes to be searched.
|
||||
// 0 represents no limits.
|
||||
uint64_t nodes = 0;
|
||||
|
||||
// Upper limit of evaluation value of generated situation
|
||||
int eval_limit = 3000;
|
||||
|
||||
// minimum ply with random move
|
||||
// maximum ply with random move
|
||||
// Number of random moves in one station
|
||||
int random_move_minply = 1;
|
||||
int random_move_maxply = 24;
|
||||
int random_move_count = 5;
|
||||
|
||||
// Move kings with a probability of 1/N when randomly moving like Apery software.
|
||||
// When you move the king again, there is a 1/N chance that it will randomly moved
|
||||
// once in the opponent's turn.
|
||||
// Apery has N=2. Specifying 0 here disables this function.
|
||||
int random_move_like_apery = 0;
|
||||
|
||||
// For when using multi pv instead of random move.
|
||||
// random_multi_pv is the number of candidates for MultiPV.
|
||||
// When adopting the move of the candidate move, the difference
|
||||
// between the evaluation value of the move of the 1st place
|
||||
// and the evaluation value of the move of the Nth place is.
|
||||
// Must be in the range random_multi_pv_diff.
|
||||
// random_multi_pv_depth is the search depth for MultiPV.
|
||||
int random_multi_pv = 0;
|
||||
int random_multi_pv_diff = 32000;
|
||||
int random_multi_pv_depth = -1;
|
||||
|
||||
// The minimum and maximum ply (number of steps from
|
||||
// the initial phase) of the sfens to write out.
|
||||
int write_minply = 16;
|
||||
int write_maxply = 400;
|
||||
|
||||
uint64_t save_every = std::numeric_limits<uint64_t>::max();
|
||||
|
||||
std::string output_file_name = "generated_kifu";
|
||||
|
||||
SfenOutputType sfen_format = SfenOutputType::Binpack;
|
||||
|
||||
std::string seed;
|
||||
|
||||
bool write_out_draw_game_in_training_data_generation = true;
|
||||
bool detect_draw_by_consecutive_low_score = true;
|
||||
bool detect_draw_by_insufficient_mating_material = true;
|
||||
|
||||
bool ensure_quiet = false;
|
||||
|
||||
uint64_t num_threads;
|
||||
|
||||
std::string book;
|
||||
|
||||
void enforce_constraints()
|
||||
{
|
||||
search_depth_max = std::max(search_depth_min, search_depth_max);
|
||||
|
||||
// Limit the maximum to a one-stop score. (Otherwise you might not end the loop)
|
||||
eval_limit = std::min(eval_limit, (int)mate_in(2));
|
||||
|
||||
save_every = std::max(save_every, REPORT_STATS_EVERY);
|
||||
|
||||
num_threads = Options["Threads"];
|
||||
|
||||
random_multi_pv_depth = std::max(search_depth_max, random_multi_pv_depth);
|
||||
}
|
||||
};
|
||||
|
||||
// Hash to limit the export of identical sfens
|
||||
static constexpr uint64_t GENSFEN_HASH_SIZE = 64 * 1024 * 1024;
|
||||
// It must be 2**N because it will be used as the mask to calculate hash_index.
|
||||
static_assert((GENSFEN_HASH_SIZE& (GENSFEN_HASH_SIZE - 1)) == 0);
|
||||
|
||||
static constexpr uint64_t REPORT_DOT_EVERY = 5000;
|
||||
static constexpr uint64_t REPORT_STATS_EVERY = 200000;
|
||||
static_assert(REPORT_STATS_EVERY % REPORT_DOT_EVERY == 0);
|
||||
|
||||
Gensfen(
|
||||
const Params& prm
|
||||
) :
|
||||
params(prm),
|
||||
sfen_writer(prm.output_file_name, prm.num_threads, prm.save_every, prm.sfen_format)
|
||||
{
|
||||
hash.resize(GENSFEN_HASH_SIZE);
|
||||
prngs.reserve(prm.num_threads);
|
||||
auto seed = prm.seed;
|
||||
for (uint64_t i = 0; i < prm.num_threads; ++i)
|
||||
{
|
||||
prngs.emplace_back(seed);
|
||||
seed = prngs.back().next_random_seed();
|
||||
}
|
||||
|
||||
if (!prm.book.empty())
|
||||
{
|
||||
opening_book = open_opening_book(prm.book, prngs[0]);
|
||||
if (opening_book == nullptr)
|
||||
{
|
||||
std::cout << "WARNING: Failed to open opening book " << prm.book << ". Falling back to startpos.\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Output seed to veryfy by the user if it's not identical by chance.
|
||||
std::cout << prngs[0] << std::endl;
|
||||
}
|
||||
|
||||
void generate(uint64_t limit);
|
||||
|
||||
private:
|
||||
Params params;
|
||||
|
||||
std::vector<PRNG> prngs;
|
||||
|
||||
std::mutex stats_mutex;
|
||||
TimePoint last_stats_report_time;
|
||||
|
||||
// sfen exporter
|
||||
SfenWriter sfen_writer;
|
||||
|
||||
SynchronizedRegionLogger::Region out;
|
||||
|
||||
vector<Key> hash; // 64MB*sizeof(HASH_KEY) = 512MB
|
||||
|
||||
std::unique_ptr<OpeningBook> opening_book;
|
||||
|
||||
static void set_gensfen_search_limits();
|
||||
|
||||
void generate_worker(
|
||||
Thread& th,
|
||||
std::atomic<uint64_t>& counter,
|
||||
uint64_t limit);
|
||||
|
||||
bool was_seen_before(const Position& pos);
|
||||
|
||||
optional<int8_t> get_current_game_result(
|
||||
Position& pos,
|
||||
const vector<int>& move_hist_scores) const;
|
||||
|
||||
vector<uint8_t> generate_random_move_flags(PRNG& prng);
|
||||
|
||||
optional<Move> choose_random_move(
|
||||
PRNG& prng,
|
||||
Position& pos,
|
||||
std::vector<uint8_t>& random_move_flag,
|
||||
int ply,
|
||||
int& random_move_c);
|
||||
|
||||
bool commit_psv(
|
||||
Thread& th,
|
||||
PSVector& sfens,
|
||||
int8_t lastTurnIsWin,
|
||||
std::atomic<uint64_t>& counter,
|
||||
uint64_t limit,
|
||||
Color result_color);
|
||||
|
||||
void report(uint64_t done, uint64_t new_done);
|
||||
|
||||
void maybe_report(uint64_t done);
|
||||
};
|
||||
|
||||
void Gensfen::set_gensfen_search_limits()
|
||||
{
|
||||
// About Search::Limits
|
||||
// Be careful because this member variable is global and affects other threads.
|
||||
auto& limits = Search::Limits;
|
||||
|
||||
// Make the search equivalent to the "go infinite" command. (Because it is troublesome if time management is done)
|
||||
limits.infinite = true;
|
||||
|
||||
// Since PV is an obstacle when displayed, erase it.
|
||||
limits.silent = true;
|
||||
|
||||
// If you use this, it will be compared with the accumulated nodes of each thread. Therefore, do not use it.
|
||||
limits.nodes = 0;
|
||||
|
||||
// depth is also processed by the one passed as an argument of Tools::search().
|
||||
limits.depth = 0;
|
||||
}
|
||||
|
||||
void Gensfen::generate(uint64_t limit)
|
||||
{
|
||||
last_stats_report_time = 0;
|
||||
|
||||
set_gensfen_search_limits();
|
||||
|
||||
std::atomic<uint64_t> counter{0};
|
||||
Threads.execute_with_workers([&counter, limit, this](Thread& th) {
|
||||
generate_worker(th, counter, limit);
|
||||
});
|
||||
Threads.wait_for_workers_finished();
|
||||
|
||||
sfen_writer.flush();
|
||||
|
||||
if (limit % REPORT_STATS_EVERY != 0)
|
||||
{
|
||||
report(limit, limit % REPORT_STATS_EVERY);
|
||||
}
|
||||
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
void Gensfen::generate_worker(
|
||||
Thread& th,
|
||||
std::atomic<uint64_t>& counter,
|
||||
uint64_t limit)
|
||||
{
|
||||
// For the time being, it will be treated as a draw
|
||||
// at the maximum number of steps to write.
|
||||
// Maximum StateInfo + Search PV to advance to leaf buffer
|
||||
std::vector<StateInfo, AlignedAllocator<StateInfo>> states(
|
||||
params.write_maxply + MAX_PLY /* == search_depth_min + α */);
|
||||
|
||||
StateInfo si;
|
||||
|
||||
auto& prng = prngs[th.thread_idx()];
|
||||
|
||||
// end flag
|
||||
bool quit = false;
|
||||
|
||||
// repeat until the specified number of times
|
||||
while (!quit)
|
||||
{
|
||||
// It is necessary to set a dependent thread for Position.
|
||||
// When parallelizing, Threads (since this is a vector<Thread*>,
|
||||
// Do the same for up to Threads[0]...Threads[thread_num-1].
|
||||
auto& pos = th.rootPos;
|
||||
if (opening_book != nullptr)
|
||||
{
|
||||
auto& fen = opening_book->next_fen();
|
||||
pos.set(fen, false, &si, &th);
|
||||
}
|
||||
else
|
||||
{
|
||||
pos.set(StartFEN, false, &si, &th);
|
||||
}
|
||||
|
||||
int resign_counter = 0;
|
||||
bool should_resign = prng.rand(10) > 1;
|
||||
// Vector for holding the sfens in the current simulated game.
|
||||
PSVector packed_sfens;
|
||||
packed_sfens.reserve(params.write_maxply + MAX_PLY);
|
||||
|
||||
// Precomputed flags. Used internally by choose_random_move.
|
||||
vector<uint8_t> random_move_flag = generate_random_move_flags(prng);
|
||||
|
||||
// A counter that keeps track of the number of random moves
|
||||
// When random_move_minply == -1, random moves are
|
||||
// performed continuously, so use it at this time.
|
||||
// Used internally by choose_random_move.
|
||||
int actual_random_move_count = 0;
|
||||
|
||||
// Save history of move scores for adjudication
|
||||
vector<int> move_hist_scores;
|
||||
|
||||
auto flush_psv = [&](int8_t result) {
|
||||
quit = commit_psv(th, packed_sfens, result, counter, limit, pos.side_to_move());
|
||||
};
|
||||
|
||||
for (int ply = 0; ; ++ply)
|
||||
{
|
||||
// Current search depth
|
||||
const int depth = params.search_depth_min + (int)prng.rand(params.search_depth_max - params.search_depth_min + 1);
|
||||
|
||||
// Starting search calls init_for_search
|
||||
auto [search_value, search_pv] = Search::search(pos, depth, 1, params.nodes);
|
||||
|
||||
// This has to be performed after search because it needs to know
|
||||
// rootMoves which are filled in init_for_search.
|
||||
const auto result = get_current_game_result(pos, move_hist_scores);
|
||||
if (result.has_value())
|
||||
{
|
||||
flush_psv(result.value());
|
||||
break;
|
||||
}
|
||||
|
||||
// Always adjudivate by eval limit.
|
||||
// Also because of this we don't have to check for TB/MATE scores
|
||||
if (abs(search_value) >= params.eval_limit)
|
||||
{
|
||||
resign_counter++;
|
||||
if ((should_resign && resign_counter >= 4) || abs(search_value) >= VALUE_KNOWN_WIN) {
|
||||
flush_psv((search_value >= params.eval_limit) ? 1 : -1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
resign_counter = 0;
|
||||
}
|
||||
|
||||
// In case there is no PV and the game was not ended here
|
||||
// there is nothing we can do, we can't continue the game,
|
||||
// we don't know the result, so discard this game.
|
||||
if (search_pv.empty())
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
// Save the move score for adjudication.
|
||||
move_hist_scores.push_back(search_value);
|
||||
|
||||
// Discard stuff before write_minply is reached
|
||||
// because it can harm training due to overfitting.
|
||||
// Initial positions would be too common.
|
||||
if (ply >= params.write_minply)
|
||||
{
|
||||
packed_sfens.emplace_back(PackedSfenValue());
|
||||
|
||||
auto& psv = packed_sfens.back();
|
||||
|
||||
if (params.ensure_quiet)
|
||||
{
|
||||
auto [qsearch_value, qsearch_pv] = Search::qsearch(pos);
|
||||
if (qsearch_pv.empty())
|
||||
{
|
||||
// Here we only write the position data.
|
||||
// Result is added after the whole game is done.
|
||||
pos.sfen_pack(psv.sfen);
|
||||
|
||||
// Already a quiet position
|
||||
psv.score = search_value;
|
||||
psv.move = search_pv[0];
|
||||
psv.gamePly = ply;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Navigate to a quiet
|
||||
int old_ply = ply;
|
||||
for (auto m : qsearch_pv)
|
||||
{
|
||||
pos.do_move(m, states[ply++]);
|
||||
}
|
||||
|
||||
if (was_seen_before(pos))
|
||||
{
|
||||
// Just skip the move.
|
||||
packed_sfens.pop_back();
|
||||
}
|
||||
else
|
||||
{
|
||||
// Reevaluate
|
||||
auto [quiet_search_value, quiet_search_pv] = Search::search(pos, depth, 1, params.nodes);
|
||||
if (quiet_search_pv.empty())
|
||||
{
|
||||
// Just skip the move.
|
||||
packed_sfens.pop_back();
|
||||
}
|
||||
else
|
||||
{
|
||||
// Here we only write the position data.
|
||||
// Result is added after the whole game is done.
|
||||
pos.sfen_pack(psv.sfen);
|
||||
|
||||
psv.score = quiet_search_value;
|
||||
psv.move = quiet_search_pv[0];
|
||||
psv.gamePly = ply;
|
||||
}
|
||||
}
|
||||
|
||||
// Get back to the game
|
||||
for (auto it = qsearch_pv.rbegin(); it != qsearch_pv.rend(); ++it)
|
||||
{
|
||||
pos.undo_move(*it);
|
||||
}
|
||||
ply = old_ply;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (was_seen_before(pos))
|
||||
{
|
||||
packed_sfens.pop_back();
|
||||
}
|
||||
else
|
||||
{
|
||||
// Here we only write the position data.
|
||||
// Result is added after the whole game is done.
|
||||
pos.sfen_pack(psv.sfen);
|
||||
|
||||
psv.score = search_value;
|
||||
psv.move = search_pv[0];
|
||||
psv.gamePly = ply;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update the next move according to best search result or random move.
|
||||
auto random_move = choose_random_move(prng, pos, random_move_flag, ply, actual_random_move_count);
|
||||
const Move next_move = random_move.has_value() ? *random_move : search_pv[0];
|
||||
|
||||
// We don't have the whole game yet, but it ended,
|
||||
// so the writing process ends and the next game starts.
|
||||
// This shouldn't really happen.
|
||||
if (!is_ok(next_move))
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
// Do move.
|
||||
pos.do_move(next_move, states[ply]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool Gensfen::was_seen_before(const Position& pos)
|
||||
{
|
||||
// Look into the position hashtable to see if the same
|
||||
// position was seen before.
|
||||
// This is a good heuristic to exlude already seen
|
||||
// positions without many false positives.
|
||||
auto key = pos.key();
|
||||
auto hash_index = (size_t)(key & (GENSFEN_HASH_SIZE - 1));
|
||||
auto old_key = hash[hash_index];
|
||||
if (key == old_key)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Replace with the current key.
|
||||
hash[hash_index] = key;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
optional<int8_t> Gensfen::get_current_game_result(
|
||||
Position& pos,
|
||||
const vector<int>& move_hist_scores) const
|
||||
{
|
||||
// Variables for draw adjudication.
|
||||
// Todo: Make this as an option.
|
||||
|
||||
// start the adjudication when ply reaches this value
|
||||
constexpr int adj_draw_ply = 80;
|
||||
|
||||
// 4 move scores for each side have to be checked
|
||||
constexpr int adj_draw_cnt = 8;
|
||||
|
||||
// move score in CP
|
||||
constexpr int adj_draw_score = 0;
|
||||
|
||||
// For the time being, it will be treated as a
|
||||
// draw at the maximum number of steps to write.
|
||||
const int ply = move_hist_scores.size();
|
||||
|
||||
// has it reached the max length or is a draw by fifty-move rule
|
||||
// or by 3-fold repetition
|
||||
if (ply >= params.write_maxply
|
||||
|| pos.is_fifty_move_draw()
|
||||
|| pos.is_three_fold_repetition())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
if(pos.this_thread()->rootMoves.empty())
|
||||
{
|
||||
// If there is no legal move
|
||||
return pos.checkers()
|
||||
? -1 /* mate */
|
||||
: 0 /* stalemate */;
|
||||
}
|
||||
|
||||
// Adjudicate game to a draw if the last 4 scores of each engine is 0.
|
||||
if (params.detect_draw_by_consecutive_low_score)
|
||||
{
|
||||
if (ply >= adj_draw_ply)
|
||||
{
|
||||
int num_cons_plies_within_draw_score = 0;
|
||||
bool is_adj_draw = false;
|
||||
|
||||
for (auto it = move_hist_scores.rbegin();
|
||||
it != move_hist_scores.rend(); ++it)
|
||||
{
|
||||
if (abs(*it) <= adj_draw_score)
|
||||
{
|
||||
num_cons_plies_within_draw_score++;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Draw scores must happen on consecutive plies
|
||||
break;
|
||||
}
|
||||
|
||||
if (num_cons_plies_within_draw_score >= adj_draw_cnt)
|
||||
{
|
||||
is_adj_draw = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_adj_draw)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Draw by insufficient mating material
|
||||
if (params.detect_draw_by_insufficient_mating_material)
|
||||
{
|
||||
if (pos.count<ALL_PIECES>() <= 4)
|
||||
{
|
||||
int num_pieces = pos.count<ALL_PIECES>();
|
||||
|
||||
// (1) KvK
|
||||
if (num_pieces == 2)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
// (2) KvK + 1 minor piece
|
||||
if (num_pieces == 3)
|
||||
{
|
||||
int minor_pc = pos.count<BISHOP>(WHITE) + pos.count<KNIGHT>(WHITE) +
|
||||
pos.count<BISHOP>(BLACK) + pos.count<KNIGHT>(BLACK);
|
||||
if (minor_pc == 1)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// (3) KBvKB, bishops of the same color
|
||||
else if (num_pieces == 4)
|
||||
{
|
||||
if (pos.count<BISHOP>(WHITE) == 1 && pos.count<BISHOP>(BLACK) == 1)
|
||||
{
|
||||
// Color of bishops is black.
|
||||
if ((pos.pieces(WHITE, BISHOP) & DarkSquares)
|
||||
&& (pos.pieces(BLACK, BISHOP) & DarkSquares))
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
// Color of bishops is white.
|
||||
if ((pos.pieces(WHITE, BISHOP) & ~DarkSquares)
|
||||
&& (pos.pieces(BLACK, BISHOP) & ~DarkSquares))
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nullopt;
|
||||
}
|
||||
|
||||
vector<uint8_t> Gensfen::generate_random_move_flags(PRNG& prng)
|
||||
{
|
||||
vector<uint8_t> random_move_flag;
|
||||
|
||||
// Depending on random move selection parameters setup
|
||||
// the array of flags that indicates whether a random move
|
||||
// be taken at a given ply.
|
||||
|
||||
// Make an array like a[0] = 0 ,a[1] = 1, ...
|
||||
// Fisher-Yates shuffle and take out the first N items.
|
||||
// Actually, I only want N pieces, so I only need
|
||||
// to shuffle the first N pieces with Fisher-Yates.
|
||||
|
||||
vector<int> a;
|
||||
a.reserve((size_t)params.random_move_maxply);
|
||||
|
||||
// random_move_minply ,random_move_maxply is specified by 1 origin,
|
||||
// Note that we are handling 0 origin here.
|
||||
for (int i = std::max(params.random_move_minply - 1, 0); i < params.random_move_maxply; ++i)
|
||||
{
|
||||
a.push_back(i);
|
||||
}
|
||||
|
||||
// In case of Apery random move, insert() may be called random_move_count times.
|
||||
// Reserve only the size considering it.
|
||||
random_move_flag.resize((size_t)params.random_move_maxply + params.random_move_count);
|
||||
|
||||
// A random move that exceeds the size() of a[] cannot be applied, so limit it.
|
||||
for (int i = 0; i < std::min(params.random_move_count, (int)a.size()); ++i)
|
||||
{
|
||||
swap(a[i], a[prng.rand((uint64_t)a.size() - i) + i]);
|
||||
random_move_flag[a[i]] = true;
|
||||
}
|
||||
|
||||
return random_move_flag;
|
||||
}
|
||||
|
||||
optional<Move> Gensfen::choose_random_move(
|
||||
PRNG& prng,
|
||||
Position& pos,
|
||||
std::vector<uint8_t>& random_move_flag,
|
||||
int ply,
|
||||
int& random_move_c)
|
||||
{
|
||||
optional<Move> random_move;
|
||||
|
||||
// Randomly choose one from legal move
|
||||
if (
|
||||
// 1. Random move of random_move_count times from random_move_minply to random_move_maxply
|
||||
(params.random_move_minply != -1 && ply < (int)random_move_flag.size() && random_move_flag[ply]) ||
|
||||
// 2. A mode to perform random move of random_move_count times after leaving the startpos
|
||||
(params.random_move_minply == -1 && random_move_c < params.random_move_count))
|
||||
{
|
||||
++random_move_c;
|
||||
|
||||
// It's not a mate, so there should be one legal move...
|
||||
if (params.random_multi_pv == 0)
|
||||
{
|
||||
// Normal random move
|
||||
MoveList<LEGAL> list(pos);
|
||||
|
||||
// I don't really know the goodness and badness of making this the Apery method.
|
||||
if (params.random_move_like_apery == 0
|
||||
|| prng.rand(params.random_move_like_apery) != 0)
|
||||
{
|
||||
// Normally one move from legal move
|
||||
random_move = list.at((size_t)prng.rand((uint64_t)list.size()));
|
||||
}
|
||||
else
|
||||
{
|
||||
// if you can move the king, move the king
|
||||
Move moves[8]; // Near 8
|
||||
Move* p = &moves[0];
|
||||
for (auto& m : list)
|
||||
{
|
||||
if (type_of(pos.moved_piece(m)) == KING)
|
||||
{
|
||||
*(p++) = m;
|
||||
}
|
||||
}
|
||||
|
||||
size_t n = p - &moves[0];
|
||||
if (n != 0)
|
||||
{
|
||||
// move to move the king
|
||||
random_move = moves[prng.rand(n)];
|
||||
|
||||
// In Apery method, at this time there is a 1/2 chance
|
||||
// that the opponent will also move randomly
|
||||
if (prng.rand(2) == 0)
|
||||
{
|
||||
// Is it a simple hack to add a "1" next to random_move_flag[ply]?
|
||||
random_move_flag.insert(random_move_flag.begin() + ply + 1, 1, true);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Normally one move from legal move
|
||||
random_move = list.at((size_t)prng.rand((uint64_t)list.size()));
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
Search::search(pos, params.random_multi_pv_depth, params.random_multi_pv);
|
||||
|
||||
// Select one from the top N hands of root Moves
|
||||
auto& rm = pos.this_thread()->rootMoves;
|
||||
|
||||
uint64_t s = min((uint64_t)rm.size(), (uint64_t)params.random_multi_pv);
|
||||
for (uint64_t i = 1; i < s; ++i)
|
||||
{
|
||||
// The difference from the evaluation value of rm[0] must
|
||||
// be within the range of random_multi_pv_diff.
|
||||
// It can be assumed that rm[x].score is arranged in descending order.
|
||||
if (rm[0].score > rm[i].score + params.random_multi_pv_diff)
|
||||
{
|
||||
s = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
random_move = rm[prng.rand(s)].pv[0];
|
||||
}
|
||||
}
|
||||
|
||||
return random_move;
|
||||
}
|
||||
|
||||
// Write out the phases loaded in sfens to a file.
|
||||
// result: win/loss in the next phase after the final phase in sfens
|
||||
// 1 when winning. -1 when losing. Pass 0 for a draw.
|
||||
// Return value: true if the specified number of
|
||||
// sfens has already been reached and the process ends.
|
||||
bool Gensfen::commit_psv(
|
||||
Thread& th,
|
||||
PSVector& sfens,
|
||||
int8_t result,
|
||||
std::atomic<uint64_t>& counter,
|
||||
uint64_t limit,
|
||||
Color result_color)
|
||||
{
|
||||
if (!params.write_out_draw_game_in_training_data_generation && result == 0)
|
||||
{
|
||||
// We didn't write anything so why quit.
|
||||
return false;
|
||||
}
|
||||
|
||||
auto side_to_move_from_sfen = [](auto& sfen){
|
||||
return (Color)(sfen.sfen.data[0] & 1);
|
||||
};
|
||||
|
||||
// From the final stage (one step before) to the first stage, give information on the outcome of the game for each stage.
|
||||
// The phases stored in sfens are assumed to be continuous (in order).
|
||||
for (auto it = sfens.rbegin(); it != sfens.rend(); ++it)
|
||||
{
|
||||
// The side to move is packed as the lowest bit of the first byte
|
||||
const Color side_to_move = side_to_move_from_sfen(*it);
|
||||
it->game_result = side_to_move == result_color ? result : -result;
|
||||
}
|
||||
|
||||
// Write sfens in move order to make potential compression easier
|
||||
for (auto& sfen : sfens)
|
||||
{
|
||||
// Return true if there is already enough data generated.
|
||||
const auto iter = counter.fetch_add(1);
|
||||
if (iter >= limit)
|
||||
return true;
|
||||
|
||||
// because `iter` was done, now we do one more
|
||||
maybe_report(iter + 1);
|
||||
|
||||
// Write out one sfen.
|
||||
sfen_writer.write(th.thread_idx(), sfen);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void Gensfen::report(uint64_t done, uint64_t new_done)
|
||||
{
|
||||
const auto now_time = now();
|
||||
const TimePoint elapsed = now_time - last_stats_report_time + 1;
|
||||
|
||||
out
|
||||
<< endl
|
||||
<< done << " sfens, "
|
||||
<< new_done * 1000 / elapsed << " sfens/second, "
|
||||
<< "at " << now_string() << sync_endl;
|
||||
|
||||
last_stats_report_time = now_time;
|
||||
|
||||
out = sync_region_cout.new_region();
|
||||
}
|
||||
|
||||
void Gensfen::maybe_report(uint64_t done)
|
||||
{
|
||||
if (done % REPORT_DOT_EVERY == 0)
|
||||
{
|
||||
std::lock_guard lock(stats_mutex);
|
||||
|
||||
if (last_stats_report_time == 0)
|
||||
{
|
||||
last_stats_report_time = now();
|
||||
out = sync_region_cout.new_region();
|
||||
}
|
||||
|
||||
if (done != 0)
|
||||
{
|
||||
out << '.';
|
||||
|
||||
if (done % REPORT_STATS_EVERY == 0)
|
||||
{
|
||||
report(done, REPORT_STATS_EVERY);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Command to generate a game record
|
||||
void gensfen(istringstream& is)
|
||||
{
|
||||
// Number of generated game records default = 8 billion phases (Ponanza specification)
|
||||
uint64_t loop_max = 8000000000UL;
|
||||
|
||||
Gensfen::Params params;
|
||||
|
||||
// Add a random number to the end of the file name.
|
||||
bool random_file_name = false;
|
||||
std::string sfen_format = "binpack";
|
||||
|
||||
string token;
|
||||
while (true)
|
||||
{
|
||||
token = "";
|
||||
is >> token;
|
||||
if (token == "")
|
||||
break;
|
||||
|
||||
if (token == "depth")
|
||||
is >> params.search_depth_min;
|
||||
else if (token == "depth2")
|
||||
is >> params.search_depth_max;
|
||||
else if (token == "nodes")
|
||||
is >> params.nodes;
|
||||
else if (token == "loop")
|
||||
is >> loop_max;
|
||||
else if (token == "output_file_name")
|
||||
is >> params.output_file_name;
|
||||
else if (token == "eval_limit")
|
||||
is >> params.eval_limit;
|
||||
else if (token == "random_move_minply")
|
||||
is >> params.random_move_minply;
|
||||
else if (token == "random_move_maxply")
|
||||
is >> params.random_move_maxply;
|
||||
else if (token == "random_move_count")
|
||||
is >> params.random_move_count;
|
||||
else if (token == "random_move_like_apery")
|
||||
is >> params.random_move_like_apery;
|
||||
else if (token == "random_multi_pv")
|
||||
is >> params.random_multi_pv;
|
||||
else if (token == "random_multi_pv_diff")
|
||||
is >> params.random_multi_pv_diff;
|
||||
else if (token == "random_multi_pv_depth")
|
||||
is >> params.random_multi_pv_depth;
|
||||
else if (token == "write_minply")
|
||||
is >> params.write_minply;
|
||||
else if (token == "write_maxply")
|
||||
is >> params.write_maxply;
|
||||
else if (token == "save_every")
|
||||
is >> params.save_every;
|
||||
else if (token == "book")
|
||||
is >> params.book;
|
||||
else if (token == "random_file_name")
|
||||
is >> random_file_name;
|
||||
// Accept also the old option name.
|
||||
else if (token == "use_draw_in_training_data_generation" || token == "write_out_draw_game_in_training_data_generation")
|
||||
is >> params.write_out_draw_game_in_training_data_generation;
|
||||
// Accept also the old option name.
|
||||
else if (token == "use_game_draw_adjudication" || token == "detect_draw_by_consecutive_low_score")
|
||||
is >> params.detect_draw_by_consecutive_low_score;
|
||||
else if (token == "detect_draw_by_insufficient_mating_material")
|
||||
is >> params.detect_draw_by_insufficient_mating_material;
|
||||
else if (token == "sfen_format")
|
||||
is >> sfen_format;
|
||||
else if (token == "seed")
|
||||
is >> params.seed;
|
||||
else if (token == "set_recommended_uci_options")
|
||||
{
|
||||
UCI::setoption("Contempt", "0");
|
||||
UCI::setoption("Skill Level", "20");
|
||||
UCI::setoption("UCI_Chess960", "false");
|
||||
UCI::setoption("UCI_AnalyseMode", "false");
|
||||
UCI::setoption("UCI_LimitStrength", "false");
|
||||
UCI::setoption("PruneAtShallowDepth", "false");
|
||||
UCI::setoption("EnableTranspositionTable", "true");
|
||||
}
|
||||
else if (token == "ensure_quiet")
|
||||
{
|
||||
params.ensure_quiet = true;
|
||||
}
|
||||
else
|
||||
cout << "ERROR: Ignoring unknown option " << token << endl;
|
||||
}
|
||||
|
||||
if (!sfen_format.empty())
|
||||
{
|
||||
if (sfen_format == "bin")
|
||||
params.sfen_format = SfenOutputType::Bin;
|
||||
else if (sfen_format == "binpack")
|
||||
params.sfen_format = SfenOutputType::Binpack;
|
||||
else
|
||||
cout << "WARNING: Unknown sfen format `" << sfen_format << "`. Using bin\n";
|
||||
}
|
||||
|
||||
if (params.ensure_quiet)
|
||||
{
|
||||
// Otherwise we can't ensure quiet positions...
|
||||
UCI::setoption("EnableTranspositionTable", "false");
|
||||
}
|
||||
|
||||
if (random_file_name)
|
||||
{
|
||||
// Give a random number to output_file_name at this point.
|
||||
// Do not use std::random_device(). Because it always the same integers on MinGW.
|
||||
PRNG r(params.seed);
|
||||
|
||||
// Just in case, reassign the random numbers.
|
||||
for (int i = 0; i < 10; ++i)
|
||||
r.rand(1);
|
||||
|
||||
auto to_hex = [](uint64_t u) {
|
||||
std::stringstream ss;
|
||||
ss << std::hex << u;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
// I don't want to wear 64bit numbers by accident, so I'next_move going to make a 64bit number 2 just in case.
|
||||
params.output_file_name += "_" + to_hex(r.rand<uint64_t>()) + to_hex(r.rand<uint64_t>());
|
||||
}
|
||||
|
||||
params.enforce_constraints();
|
||||
|
||||
std::cout << "INFO: Executing gensfen command\n";
|
||||
|
||||
std::cout << "INFO: Parameters:\n";
|
||||
std::cout
|
||||
<< " - search_depth_min = " << params.search_depth_min << endl
|
||||
<< " - search_depth_max = " << params.search_depth_max << endl
|
||||
<< " - nodes = " << params.nodes << endl
|
||||
<< " - num sfens to generate = " << loop_max << endl
|
||||
<< " - eval_limit = " << params.eval_limit << endl
|
||||
<< " - num threads (UCI) = " << params.num_threads << endl
|
||||
<< " - random_move_minply = " << params.random_move_minply << endl
|
||||
<< " - random_move_maxply = " << params.random_move_maxply << endl
|
||||
<< " - random_move_count = " << params.random_move_count << endl
|
||||
<< " - random_move_like_apery = " << params.random_move_like_apery << endl
|
||||
<< " - random_multi_pv = " << params.random_multi_pv << endl
|
||||
<< " - random_multi_pv_diff = " << params.random_multi_pv_diff << endl
|
||||
<< " - random_multi_pv_depth = " << params.random_multi_pv_depth << endl
|
||||
<< " - write_minply = " << params.write_minply << endl
|
||||
<< " - write_maxply = " << params.write_maxply << endl
|
||||
<< " - book = " << params.book << endl
|
||||
<< " - output_file_name = " << params.output_file_name << endl
|
||||
<< " - save_every = " << params.save_every << endl
|
||||
<< " - random_file_name = " << random_file_name << endl
|
||||
<< " - write_drawn_games = " << params.write_out_draw_game_in_training_data_generation << endl
|
||||
<< " - draw by low score = " << params.detect_draw_by_consecutive_low_score << endl
|
||||
<< " - draw by insuff. mat. = " << params.detect_draw_by_insufficient_mating_material << endl;
|
||||
|
||||
// Show if the training data generator uses NNUE.
|
||||
Eval::NNUE::verify_eval_file_loaded();
|
||||
|
||||
Threads.main()->ponder = false;
|
||||
|
||||
Gensfen gensfen(params);
|
||||
gensfen.generate(loop_max);
|
||||
|
||||
std::cout << "INFO: Gensfen finished." << endl;
|
||||
}
|
||||
}
|
||||
14
src/tools/gensfen.h
Normal file
14
src/tools/gensfen.h
Normal file
@@ -0,0 +1,14 @@
|
||||
#ifndef _GENSFEN_H_
|
||||
#define _GENSFEN_H_
|
||||
|
||||
#include "position.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
namespace Tools {
|
||||
|
||||
// Automatic generation of teacher position
|
||||
void gensfen(std::istringstream& is);
|
||||
}
|
||||
|
||||
#endif
|
||||
488
src/tools/gensfen_nonpv.cpp
Normal file
488
src/tools/gensfen_nonpv.cpp
Normal file
@@ -0,0 +1,488 @@
|
||||
#include "gensfen_nonpv.h"
|
||||
|
||||
#include "sfen_writer.h"
|
||||
#include "packed_sfen.h"
|
||||
#include "opening_book.h"
|
||||
|
||||
#include "misc.h"
|
||||
#include "position.h"
|
||||
#include "thread.h"
|
||||
#include "tt.h"
|
||||
#include "uci.h"
|
||||
|
||||
#include "extra/nnue_data_binpack_format.h"
|
||||
|
||||
#include "nnue/evaluate_nnue.h"
|
||||
|
||||
#include "syzygy/tbprobe.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <limits>
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <random>
|
||||
#include <shared_mutex>
|
||||
#include <sstream>
|
||||
#include <unordered_set>
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Tools
|
||||
{
|
||||
// Class to generate sfen with multiple threads
|
||||
struct GensfenNonPv
|
||||
{
|
||||
struct Params
|
||||
{
|
||||
// The depth for search on the fens gathered during exploration
|
||||
int search_depth = 3;
|
||||
|
||||
// the min/max number of nodes to use for exploration per ply
|
||||
int exploration_min_nodes = 5000;
|
||||
int exploration_max_nodes = 15000;
|
||||
|
||||
// The pct of positions explored that are saved for rescoring
|
||||
float exploration_save_rate = 0.01;
|
||||
|
||||
// Upper limit of evaluation value of generated situation
|
||||
int eval_limit = 4000;
|
||||
|
||||
// the upper limit on evaluation during exploration selfplay
|
||||
int exploration_eval_limit = 4000;
|
||||
|
||||
int exploration_max_ply = 200;
|
||||
|
||||
int exploration_min_pieces = 8;
|
||||
|
||||
std::string output_file_name = "generated_gensfen_nonpv";
|
||||
|
||||
SfenOutputType sfen_format = SfenOutputType::Binpack;
|
||||
|
||||
std::string seed;
|
||||
|
||||
int num_threads;
|
||||
|
||||
std::string book;
|
||||
|
||||
bool smart_fen_skipping = false;
|
||||
|
||||
void enforce_constraints()
|
||||
{
|
||||
// Limit the maximum to a one-stop score. (Otherwise you might not end the loop)
|
||||
eval_limit = std::min(eval_limit, (int)mate_in(2));
|
||||
exploration_eval_limit = std::min(eval_limit, (int)mate_in(2));
|
||||
exploration_min_nodes = std::max(100, exploration_min_nodes);
|
||||
exploration_max_nodes = std::max(exploration_min_nodes, exploration_max_nodes);
|
||||
|
||||
num_threads = Options["Threads"];
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr uint64_t REPORT_DOT_EVERY = 5000;
|
||||
static constexpr uint64_t REPORT_STATS_EVERY = 200000;
|
||||
static_assert(REPORT_STATS_EVERY % REPORT_DOT_EVERY == 0);
|
||||
|
||||
GensfenNonPv(
|
||||
const Params& prm
|
||||
) :
|
||||
params(prm),
|
||||
prng(prm.seed),
|
||||
sfen_writer(prm.output_file_name, prm.num_threads, std::numeric_limits<uint64_t>::max(), prm.sfen_format)
|
||||
{
|
||||
if (!prm.book.empty())
|
||||
{
|
||||
opening_book = open_opening_book(prm.book, prng);
|
||||
if (opening_book == nullptr)
|
||||
{
|
||||
std::cout << "WARNING: Failed to open opening book " << prm.book << ". Falling back to startpos.\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Output seed to veryfy by the user if it's not identical by chance.
|
||||
std::cout << prng << std::endl;
|
||||
}
|
||||
|
||||
void generate(uint64_t limit);
|
||||
|
||||
private:
|
||||
Params params;
|
||||
|
||||
PRNG prng;
|
||||
|
||||
std::mutex stats_mutex;
|
||||
TimePoint last_stats_report_time;
|
||||
|
||||
// sfen exporter
|
||||
SfenWriter sfen_writer;
|
||||
|
||||
SynchronizedRegionLogger::Region out;
|
||||
|
||||
std::unique_ptr<OpeningBook> opening_book;
|
||||
|
||||
static void set_gensfen_search_limits();
|
||||
|
||||
void generate_worker(
|
||||
Thread& th,
|
||||
std::atomic<uint64_t>& counter,
|
||||
uint64_t limit);
|
||||
|
||||
bool commit_psv(
|
||||
Thread& th,
|
||||
PSVector& sfens,
|
||||
std::atomic<uint64_t>& counter,
|
||||
uint64_t limit);
|
||||
|
||||
PSVector do_exploration(
|
||||
Thread& th,
|
||||
int count);
|
||||
|
||||
void report(uint64_t done, uint64_t new_done);
|
||||
|
||||
void maybe_report(uint64_t done);
|
||||
};
|
||||
|
||||
void GensfenNonPv::set_gensfen_search_limits()
|
||||
{
|
||||
// About Search::Limits
|
||||
// Be careful because this member variable is global and affects other threads.
|
||||
auto& limits = Search::Limits;
|
||||
|
||||
// Make the search equivalent to the "go infinite" command. (Because it is troublesome if time management is done)
|
||||
limits.infinite = true;
|
||||
|
||||
// Since PV is an obstacle when displayed, erase it.
|
||||
limits.silent = true;
|
||||
|
||||
// If you use this, it will be compared with the accumulated nodes of each thread. Therefore, do not use it.
|
||||
limits.nodes = 0;
|
||||
|
||||
// depth is also processed by the one passed as an argument of Tools::search().
|
||||
limits.depth = 0;
|
||||
}
|
||||
|
||||
void GensfenNonPv::generate(uint64_t limit)
|
||||
{
|
||||
last_stats_report_time = 0;
|
||||
|
||||
set_gensfen_search_limits();
|
||||
|
||||
std::atomic<uint64_t> counter{0};
|
||||
Threads.execute_with_workers([&counter, limit, this](Thread& th) {
|
||||
generate_worker(th, counter, limit);
|
||||
});
|
||||
Threads.wait_for_workers_finished();
|
||||
|
||||
sfen_writer.flush();
|
||||
|
||||
if (limit % REPORT_STATS_EVERY != 0)
|
||||
{
|
||||
report(limit, limit % REPORT_STATS_EVERY);
|
||||
}
|
||||
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
PSVector GensfenNonPv::do_exploration(
|
||||
Thread& th,
|
||||
int count)
|
||||
{
|
||||
constexpr int max_depth = 30;
|
||||
|
||||
PSVector psv;
|
||||
|
||||
std::vector<StateInfo, AlignedAllocator<StateInfo>> states(
|
||||
max_depth + MAX_PLY /* == search_depth_min + α */);
|
||||
|
||||
th.set_eval_callback([this, &psv](Position& pos) {
|
||||
if ((double)prng.rand<uint64_t>() / std::numeric_limits<uint64_t>::max() < params.exploration_save_rate)
|
||||
{
|
||||
psv.emplace_back();
|
||||
pos.sfen_pack(psv.back().sfen);
|
||||
}
|
||||
});
|
||||
|
||||
auto& pos = th.rootPos;
|
||||
StateInfo si;
|
||||
|
||||
for (int i = 0; i < count; ++i)
|
||||
{
|
||||
if (opening_book != nullptr)
|
||||
{
|
||||
auto& fen = opening_book->next_fen();
|
||||
pos.set(fen, false, &si, &th);
|
||||
}
|
||||
else
|
||||
{
|
||||
pos.set(StartFEN, false, &si, &th);
|
||||
}
|
||||
|
||||
for(int ply = 0; ply < params.exploration_max_ply; ++ply)
|
||||
{
|
||||
auto nodes = prng.rand(params.exploration_max_nodes - params.exploration_min_nodes + 1) + params.exploration_min_nodes;
|
||||
|
||||
auto [search_value, search_pv] = Search::search(pos, max_depth, 1, nodes);
|
||||
|
||||
if (search_pv.empty())
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
if (std::abs(search_value) > params.exploration_eval_limit)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
pos.do_move(search_pv[0], states[ply]);
|
||||
|
||||
if (popcount(pos.pieces()) < params.exploration_min_pieces)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
th.clear_eval_callback();
|
||||
|
||||
return psv;
|
||||
}
|
||||
|
||||
void GensfenNonPv::generate_worker(
|
||||
Thread& th,
|
||||
std::atomic<uint64_t>& counter,
|
||||
uint64_t limit)
|
||||
{
|
||||
constexpr int exploration_batch_size = 1;
|
||||
|
||||
StateInfo si;
|
||||
|
||||
PSVector psv;
|
||||
|
||||
// end flag
|
||||
bool quit = false;
|
||||
|
||||
// repeat until the specified number of times
|
||||
while (!quit)
|
||||
{
|
||||
// It is necessary to set a dependent thread for Position.
|
||||
// When parallelizing, Threads (since this is a vector<Thread*>,
|
||||
// Do the same for up to Threads[0]...Threads[thread_num-1].
|
||||
auto& pos = th.rootPos;
|
||||
|
||||
auto packed_sfens = do_exploration(th, exploration_batch_size);
|
||||
psv.clear();
|
||||
|
||||
for (auto& ps : packed_sfens)
|
||||
{
|
||||
pos.set_from_packed_sfen(ps.sfen, &si, &th);
|
||||
pos.state()->rule50 = 0;
|
||||
|
||||
if (params.smart_fen_skipping && pos.checkers())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
auto [search_value, search_pv] = Search::search(pos, params.search_depth, 1);
|
||||
|
||||
if (search_pv.empty())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (std::abs(search_value) > params.eval_limit)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (params.smart_fen_skipping && pos.capture_or_promotion(search_pv[0]))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
auto& new_ps = psv.emplace_back();
|
||||
pos.sfen_pack(new_ps.sfen);
|
||||
new_ps.score = search_value;
|
||||
new_ps.move = search_pv[0];
|
||||
new_ps.gamePly = 1;
|
||||
new_ps.game_result = 0;
|
||||
new_ps.padding = 0;
|
||||
}
|
||||
|
||||
quit = commit_psv(th, psv, counter, limit);
|
||||
}
|
||||
}
|
||||
|
||||
// Write out the phases loaded in sfens to a file.
|
||||
// result: win/loss in the next phase after the final phase in sfens
|
||||
// 1 when winning. -1 when losing. Pass 0 for a draw.
|
||||
// Return value: true if the specified number of
|
||||
// sfens has already been reached and the process ends.
|
||||
bool GensfenNonPv::commit_psv(
|
||||
Thread& th,
|
||||
PSVector& sfens,
|
||||
std::atomic<uint64_t>& counter,
|
||||
uint64_t limit)
|
||||
{
|
||||
// Write sfens in move order to make potential compression easier
|
||||
for (auto& sfen : sfens)
|
||||
{
|
||||
// Return true if there is already enough data generated.
|
||||
const auto iter = counter.fetch_add(1);
|
||||
if (iter >= limit)
|
||||
return true;
|
||||
|
||||
// because `iter` was done, now we do one more
|
||||
maybe_report(iter + 1);
|
||||
|
||||
// Write out one sfen.
|
||||
sfen_writer.write(th.thread_idx(), sfen);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void GensfenNonPv::report(uint64_t done, uint64_t new_done)
|
||||
{
|
||||
const auto now_time = now();
|
||||
const TimePoint elapsed = now_time - last_stats_report_time + 1;
|
||||
|
||||
out
|
||||
<< endl
|
||||
<< done << " sfens, "
|
||||
<< new_done * 1000 / elapsed << " sfens/second, "
|
||||
<< "at " << now_string() << sync_endl;
|
||||
|
||||
last_stats_report_time = now_time;
|
||||
|
||||
out = sync_region_cout.new_region();
|
||||
}
|
||||
|
||||
void GensfenNonPv::maybe_report(uint64_t done)
|
||||
{
|
||||
if (done % REPORT_DOT_EVERY == 0)
|
||||
{
|
||||
std::lock_guard lock(stats_mutex);
|
||||
|
||||
if (last_stats_report_time == 0)
|
||||
{
|
||||
last_stats_report_time = now();
|
||||
out = sync_region_cout.new_region();
|
||||
}
|
||||
|
||||
if (done != 0)
|
||||
{
|
||||
out << '.';
|
||||
|
||||
if (done % REPORT_STATS_EVERY == 0)
|
||||
{
|
||||
report(done, REPORT_STATS_EVERY);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Command to generate a game record
|
||||
void gensfen_nonpv(istringstream& is)
|
||||
{
|
||||
// Number of generated game records default = 8 billion phases (Ponanza specification)
|
||||
GensfenNonPv::Params params;
|
||||
|
||||
uint64_t count = 1'000'000;
|
||||
|
||||
// Add a random number to the end of the file name.
|
||||
std::string sfen_format = "binpack";
|
||||
|
||||
string token;
|
||||
while (true)
|
||||
{
|
||||
token = "";
|
||||
is >> token;
|
||||
if (token == "")
|
||||
break;
|
||||
|
||||
if (token == "depth")
|
||||
is >> params.search_depth;
|
||||
else if (token == "count")
|
||||
is >> count;
|
||||
else if (token == "output_file")
|
||||
is >> params.output_file_name;
|
||||
else if (token == "exploration_eval_limit")
|
||||
is >> params.exploration_eval_limit;
|
||||
else if (token == "eval_limit")
|
||||
is >> params.eval_limit;
|
||||
else if (token == "exploration_min_nodes")
|
||||
is >> params.exploration_min_nodes;
|
||||
else if (token == "exploration_max_nodes")
|
||||
is >> params.exploration_max_nodes;
|
||||
else if (token == "exploration_min_pieces")
|
||||
is >> params.exploration_min_pieces;
|
||||
else if (token == "exploration_save_rate")
|
||||
is >> params.exploration_save_rate;
|
||||
else if (token == "book")
|
||||
is >> params.book;
|
||||
else if (token == "sfen_format")
|
||||
is >> sfen_format;
|
||||
else if (token == "seed")
|
||||
is >> params.seed;
|
||||
else if (token == "smart_fen_skipping")
|
||||
params.smart_fen_skipping = true;
|
||||
else if (token == "set_recommended_uci_options")
|
||||
{
|
||||
UCI::setoption("Contempt", "0");
|
||||
UCI::setoption("Skill Level", "20");
|
||||
UCI::setoption("UCI_Chess960", "false");
|
||||
UCI::setoption("UCI_AnalyseMode", "false");
|
||||
UCI::setoption("UCI_LimitStrength", "false");
|
||||
UCI::setoption("PruneAtShallowDepth", "false");
|
||||
UCI::setoption("EnableTranspositionTable", "true");
|
||||
}
|
||||
else
|
||||
cout << "ERROR: Ignoring unknown option " << token << endl;
|
||||
}
|
||||
|
||||
if (!sfen_format.empty())
|
||||
{
|
||||
if (sfen_format == "bin")
|
||||
params.sfen_format = SfenOutputType::Bin;
|
||||
else if (sfen_format == "binpack")
|
||||
params.sfen_format = SfenOutputType::Binpack;
|
||||
else
|
||||
cout << "WARNING: Unknown sfen format `" << sfen_format << "`. Using bin\n";
|
||||
}
|
||||
|
||||
params.enforce_constraints();
|
||||
|
||||
std::cout << "INFO: Executing gensfen_nonpv command\n";
|
||||
|
||||
std::cout << "INFO: Parameters:\n";
|
||||
std::cout
|
||||
<< " - search_depth = " << params.search_depth << endl
|
||||
<< " - output_file = " << params.output_file_name << endl
|
||||
<< " - exploration_eval_limit = " << params.exploration_eval_limit << endl
|
||||
<< " - eval_limit = " << params.eval_limit << endl
|
||||
<< " - exploration_min_nodes = " << params.exploration_min_nodes << endl
|
||||
<< " - exploration_max_nodes = " << params.exploration_max_nodes << endl
|
||||
<< " - exploration_min_pieces = " << params.exploration_min_pieces << endl
|
||||
<< " - exploration_save_rate = " << params.exploration_save_rate << endl
|
||||
<< " - book = " << params.book << endl
|
||||
<< " - sfen_format = " << sfen_format << endl
|
||||
<< " - seed = " << params.seed << endl
|
||||
<< " - count = " << count << endl;
|
||||
|
||||
// Show if the training data generator uses NNUE.
|
||||
Eval::NNUE::verify_eval_file_loaded();
|
||||
|
||||
Threads.main()->ponder = false;
|
||||
|
||||
GensfenNonPv gensfen(params);
|
||||
gensfen.generate(count);
|
||||
|
||||
std::cout << "INFO: gensfen_nonpv finished." << endl;
|
||||
}
|
||||
}
|
||||
12
src/tools/gensfen_nonpv.h
Normal file
12
src/tools/gensfen_nonpv.h
Normal file
@@ -0,0 +1,12 @@
|
||||
#ifndef _GENSFEN_NONPV_H_
|
||||
#define _GENSFEN_NONPV_H_
|
||||
|
||||
#include <sstream>
|
||||
|
||||
namespace Tools {
|
||||
|
||||
// Automatic generation of teacher position
|
||||
void gensfen_nonpv(std::istringstream& is);
|
||||
}
|
||||
|
||||
#endif
|
||||
43
src/tools/opening_book.cpp
Normal file
43
src/tools/opening_book.cpp
Normal file
@@ -0,0 +1,43 @@
|
||||
#include "opening_book.h"
|
||||
|
||||
#include <fstream>
|
||||
|
||||
namespace Tools {
|
||||
|
||||
EpdOpeningBook::EpdOpeningBook(const std::string& file, PRNG& prng) :
|
||||
OpeningBook(file)
|
||||
{
|
||||
std::ifstream in(file);
|
||||
if (!in)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
std::string line;
|
||||
while (std::getline(in, line))
|
||||
{
|
||||
if (line.empty())
|
||||
continue;
|
||||
|
||||
fens.emplace_back(line);
|
||||
}
|
||||
|
||||
Algo::shuffle(fens, prng);
|
||||
}
|
||||
|
||||
static bool ends_with(const std::string& lhs, const std::string& end)
|
||||
{
|
||||
if (end.size() > lhs.size()) return false;
|
||||
|
||||
return std::equal(end.rbegin(), end.rend(), lhs.rbegin());
|
||||
}
|
||||
|
||||
std::unique_ptr<OpeningBook> open_opening_book(const std::string& filename, PRNG& prng)
|
||||
{
|
||||
if (ends_with(filename, ".epd"))
|
||||
return std::make_unique<EpdOpeningBook>(filename, prng);
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
}
|
||||
60
src/tools/opening_book.h
Normal file
60
src/tools/opening_book.h
Normal file
@@ -0,0 +1,60 @@
|
||||
#ifndef LEARN_OPENING_BOOK_H
|
||||
#define LEARN_OPENING_BOOK_H
|
||||
|
||||
#include "misc.h"
|
||||
#include "position.h"
|
||||
#include "thread.h"
|
||||
|
||||
#include <vector>
|
||||
#include <random>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
|
||||
namespace Tools {
|
||||
|
||||
struct OpeningBook {
|
||||
|
||||
const std::string& next_fen()
|
||||
{
|
||||
assert(fens.size() > 0);
|
||||
|
||||
std::unique_lock lock(mutex);
|
||||
|
||||
auto& fen = fens[current_index++];
|
||||
if (current_index >= fens.size())
|
||||
current_index = 0;
|
||||
|
||||
return fen;
|
||||
}
|
||||
|
||||
std::size_t size() const { return fens.size(); }
|
||||
|
||||
const std::string& get_filename() const { return filename; }
|
||||
|
||||
protected:
|
||||
OpeningBook(const std::string& file) :
|
||||
filename(file),
|
||||
current_index(0)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
std::mutex mutex;
|
||||
std::string filename;
|
||||
std::vector<std::string> fens;
|
||||
std::size_t current_index;
|
||||
};
|
||||
|
||||
struct EpdOpeningBook : OpeningBook {
|
||||
|
||||
EpdOpeningBook(const std::string& file, PRNG& prng);
|
||||
};
|
||||
|
||||
std::unique_ptr<OpeningBook> open_opening_book(const std::string& filename, PRNG& prng);
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
46
src/tools/packed_sfen.h
Normal file
46
src/tools/packed_sfen.h
Normal file
@@ -0,0 +1,46 @@
|
||||
#ifndef _PACKED_SFEN_H_
|
||||
#define _PACKED_SFEN_H_
|
||||
|
||||
#include <vector>
|
||||
#include <cstdint>
|
||||
|
||||
namespace Tools {
|
||||
|
||||
// packed sfen
|
||||
struct PackedSfen { std::uint8_t data[32]; };
|
||||
|
||||
// Structure in which PackedSfen and evaluation value are integrated
|
||||
// If you write different contents for each option, it will be a problem when reusing the teacher game
|
||||
// For the time being, write all the following members regardless of the options.
|
||||
struct PackedSfenValue
|
||||
{
|
||||
// phase
|
||||
PackedSfen sfen;
|
||||
|
||||
// Evaluation value returned from Tools::search()
|
||||
std::int16_t score;
|
||||
|
||||
// PV first move
|
||||
// Used when finding the match rate with the teacher
|
||||
std::uint16_t move;
|
||||
|
||||
// Trouble of the phase from the initial phase.
|
||||
std::uint16_t gamePly;
|
||||
|
||||
// 1 if the player on this side ultimately wins the game. -1 if you are losing.
|
||||
// 0 if a draw is reached.
|
||||
// The draw is in the teacher position generation command gensfen,
|
||||
// Only write if LEARN_GENSFEN_DRAW_RESULT is enabled.
|
||||
std::int8_t game_result;
|
||||
|
||||
// When exchanging the file that wrote the teacher aspect with other people
|
||||
//Because this structure size is not fixed, pad it so that it is 40 bytes in any environment.
|
||||
std::uint8_t padding;
|
||||
|
||||
// 32 + 2 + 2 + 2 + 1 + 1 = 40bytes
|
||||
};
|
||||
|
||||
// Phase array: PSVector stands for packed sfen vector.
|
||||
using PSVector = std::vector<PackedSfenValue>;
|
||||
}
|
||||
#endif
|
||||
386
src/tools/sfen_packer.cpp
Normal file
386
src/tools/sfen_packer.cpp
Normal file
@@ -0,0 +1,386 @@
|
||||
#include "sfen_packer.h"
|
||||
|
||||
#include "packed_sfen.h"
|
||||
|
||||
#include "misc.h"
|
||||
#include "position.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
#include <cstring> // std::memset()
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Tools {
|
||||
|
||||
// Class that handles bitstream
|
||||
// useful when doing aspect encoding
|
||||
struct BitStream
|
||||
{
|
||||
// Set the memory to store the data in advance.
|
||||
// Assume that memory is cleared to 0.
|
||||
void set_data(std::uint8_t* data_) { data = data_; reset(); }
|
||||
|
||||
// Get the pointer passed in set_data().
|
||||
uint8_t* get_data() const { return data; }
|
||||
|
||||
// Get the cursor.
|
||||
int get_cursor() const { return bit_cursor; }
|
||||
|
||||
// reset the cursor
|
||||
void reset() { bit_cursor = 0; }
|
||||
|
||||
// Write 1bit to the stream.
|
||||
// If b is non-zero, write out 1. If 0, write 0.
|
||||
void write_one_bit(int b)
|
||||
{
|
||||
if (b)
|
||||
data[bit_cursor / 8] |= 1 << (bit_cursor & 7);
|
||||
|
||||
++bit_cursor;
|
||||
}
|
||||
|
||||
// Get 1 bit from the stream.
|
||||
int read_one_bit()
|
||||
{
|
||||
int b = (data[bit_cursor / 8] >> (bit_cursor & 7)) & 1;
|
||||
++bit_cursor;
|
||||
|
||||
return b;
|
||||
}
|
||||
|
||||
// write n bits of data
|
||||
// Data shall be written out from the lower order of d.
|
||||
void write_n_bit(int d, int n)
|
||||
{
|
||||
for (int i = 0; i <n; ++i)
|
||||
write_one_bit(d & (1 << i));
|
||||
}
|
||||
|
||||
// read n bits of data
|
||||
// Reverse conversion of write_n_bit().
|
||||
int read_n_bit(int n)
|
||||
{
|
||||
int result = 0;
|
||||
for (int i = 0; i < n; ++i)
|
||||
result |= read_one_bit() ? (1 << i) : 0;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
// Next bit position to read/write.
|
||||
int bit_cursor;
|
||||
|
||||
// data entity
|
||||
std::uint8_t* data;
|
||||
};
|
||||
|
||||
// Class for compressing/decompressing sfen
|
||||
// sfen can be packed to 256bit (32bytes) by Huffman coding.
|
||||
// This is proven by mini. The above is Huffman coding.
|
||||
//
|
||||
// Internal format = 1-bit turn + 7-bit king position *2 + piece on board (Huffman coding) + hand piece (Huffman coding)
|
||||
// Side to move (White = 0, Black = 1) (1bit)
|
||||
// White King Position (6 bits)
|
||||
// Black King Position (6 bits)
|
||||
// Huffman Encoding of the board
|
||||
// Castling availability (1 bit x 4)
|
||||
// En passant square (1 or 1 + 6 bits)
|
||||
// Rule 50 (6 bits)
|
||||
// Game play (8 bits)
|
||||
//
|
||||
// TODO(someone): Rename SFEN to FEN.
|
||||
//
|
||||
struct SfenPacker
|
||||
{
|
||||
void pack(const Position& pos);
|
||||
|
||||
// sfen packed by pack() (256bit = 32bytes)
|
||||
// Or sfen to decode with unpack()
|
||||
uint8_t *data; // uint8_t[32];
|
||||
|
||||
BitStream stream;
|
||||
|
||||
// Output the board pieces to stream.
|
||||
void write_board_piece_to_stream(Piece pc);
|
||||
|
||||
// Read one board piece from stream
|
||||
Piece read_board_piece_from_stream();
|
||||
};
|
||||
|
||||
|
||||
// Huffman coding
|
||||
// * is simplified from mini encoding to make conversion easier.
|
||||
//
|
||||
// Huffman Encoding
|
||||
//
|
||||
// Empty xxxxxxx0
|
||||
// Pawn xxxxx001 + 1 bit (Color)
|
||||
// Knight xxxxx011 + 1 bit (Color)
|
||||
// Bishop xxxxx101 + 1 bit (Color)
|
||||
// Rook xxxxx111 + 1 bit (Color)
|
||||
// Queen xxxx1001 + 1 bit (Color)
|
||||
//
|
||||
// Worst case:
|
||||
// - 32 empty squares 32 bits
|
||||
// - 30 pieces 150 bits
|
||||
// - 2 kings 12 bits
|
||||
// - castling rights 4 bits
|
||||
// - ep square 7 bits
|
||||
// - rule50 7 bits
|
||||
// - game ply 16 bits
|
||||
// - TOTAL 228 bits < 256 bits
|
||||
|
||||
struct HuffmanedPiece
|
||||
{
|
||||
int code; // how it will be coded
|
||||
int bits; // How many bits do you have
|
||||
};
|
||||
|
||||
constexpr HuffmanedPiece huffman_table[] =
|
||||
{
|
||||
{0b0000,1}, // NO_PIECE
|
||||
{0b0001,4}, // PAWN
|
||||
{0b0011,4}, // KNIGHT
|
||||
{0b0101,4}, // BISHOP
|
||||
{0b0111,4}, // ROOK
|
||||
{0b1001,4}, // QUEEN
|
||||
};
|
||||
|
||||
// Pack sfen and store in data[32].
|
||||
void SfenPacker::pack(const Position& pos)
|
||||
{
|
||||
memset(data, 0, 32 /* 256bit */);
|
||||
stream.set_data(data);
|
||||
|
||||
// turn
|
||||
// Side to move.
|
||||
stream.write_one_bit((int)(pos.side_to_move()));
|
||||
|
||||
// 7-bit positions for leading and trailing balls
|
||||
// White king and black king, 6 bits for each.
|
||||
for(auto c: Colors)
|
||||
stream.write_n_bit(pos.king_square(c), 6);
|
||||
|
||||
// Write the pieces on the board other than the kings.
|
||||
for (Rank r = RANK_8; r >= RANK_1; --r)
|
||||
{
|
||||
for (File f = FILE_A; f <= FILE_H; ++f)
|
||||
{
|
||||
Piece pc = pos.piece_on(make_square(f, r));
|
||||
if (type_of(pc) == KING)
|
||||
continue;
|
||||
write_board_piece_to_stream(pc);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(someone): Support chess960.
|
||||
stream.write_one_bit(pos.can_castle(WHITE_OO));
|
||||
stream.write_one_bit(pos.can_castle(WHITE_OOO));
|
||||
stream.write_one_bit(pos.can_castle(BLACK_OO));
|
||||
stream.write_one_bit(pos.can_castle(BLACK_OOO));
|
||||
|
||||
if (pos.ep_square() == SQ_NONE) {
|
||||
stream.write_one_bit(0);
|
||||
}
|
||||
else {
|
||||
stream.write_one_bit(1);
|
||||
stream.write_n_bit(static_cast<int>(pos.ep_square()), 6);
|
||||
}
|
||||
|
||||
stream.write_n_bit(pos.state()->rule50, 6);
|
||||
|
||||
const int fm = 1 + (pos.game_ply()-(pos.side_to_move() == BLACK)) / 2;
|
||||
stream.write_n_bit(fm, 8);
|
||||
|
||||
// Write high bits of half move. This is a fix for the
|
||||
// limited range of half move counter.
|
||||
// This is backwards compatibile.
|
||||
stream.write_n_bit(fm >> 8, 8);
|
||||
|
||||
// Write the highest bit of rule50 at the end. This is a backwards
|
||||
// compatibile fix for rule50 having only 6 bits stored.
|
||||
// This bit is just ignored by the old parsers.
|
||||
stream.write_n_bit(pos.state()->rule50 >> 6, 1);
|
||||
|
||||
assert(stream.get_cursor() <= 256);
|
||||
}
|
||||
|
||||
// Output the board pieces to stream.
|
||||
void SfenPacker::write_board_piece_to_stream(Piece pc)
|
||||
{
|
||||
// piece type
|
||||
PieceType pr = type_of(pc);
|
||||
auto c = huffman_table[pr];
|
||||
stream.write_n_bit(c.code, c.bits);
|
||||
|
||||
if (pc == NO_PIECE)
|
||||
return;
|
||||
|
||||
// first and second flag
|
||||
stream.write_one_bit(color_of(pc));
|
||||
}
|
||||
|
||||
// Read one board piece from stream
|
||||
Piece SfenPacker::read_board_piece_from_stream()
|
||||
{
|
||||
PieceType pr = NO_PIECE_TYPE;
|
||||
int code = 0, bits = 0;
|
||||
while (true)
|
||||
{
|
||||
code |= stream.read_one_bit() << bits;
|
||||
++bits;
|
||||
|
||||
assert(bits <= 6);
|
||||
|
||||
for (pr = NO_PIECE_TYPE; pr <KING; ++pr)
|
||||
if (huffman_table[pr].code == code
|
||||
&& huffman_table[pr].bits == bits)
|
||||
goto Found;
|
||||
}
|
||||
Found:;
|
||||
if (pr == NO_PIECE_TYPE)
|
||||
return NO_PIECE;
|
||||
|
||||
// first and second flag
|
||||
Color c = (Color)stream.read_one_bit();
|
||||
|
||||
return make_piece(c, pr);
|
||||
}
|
||||
|
||||
int set_from_packed_sfen(Position& pos, const PackedSfen& sfen, StateInfo* si, Thread* th)
|
||||
{
|
||||
SfenPacker packer;
|
||||
auto& stream = packer.stream;
|
||||
|
||||
// TODO: separate streams for writing and reading. Here we actually have to
|
||||
// const_cast which is not safe in the long run.
|
||||
stream.set_data(const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(&sfen)));
|
||||
|
||||
pos.clear();
|
||||
std::memset(si, 0, sizeof(StateInfo));
|
||||
std::fill_n(&pos.pieceList[0][0], sizeof(pos.pieceList) / sizeof(Square), SQ_NONE);
|
||||
pos.st = si;
|
||||
|
||||
// Active color
|
||||
pos.sideToMove = (Color)stream.read_one_bit();
|
||||
|
||||
pos.pieceList[W_KING][0] = SQUARE_NB;
|
||||
pos.pieceList[B_KING][0] = SQUARE_NB;
|
||||
|
||||
// First the position of the ball
|
||||
for (auto c : Colors)
|
||||
pos.board[stream.read_n_bit(6)] = make_piece(c, KING);
|
||||
|
||||
// Piece placement
|
||||
for (Rank r = RANK_8; r >= RANK_1; --r)
|
||||
{
|
||||
for (File f = FILE_A; f <= FILE_H; ++f)
|
||||
{
|
||||
auto sq = make_square(f, r);
|
||||
|
||||
// it seems there are already balls
|
||||
Piece pc;
|
||||
if (type_of(pos.board[sq]) != KING)
|
||||
{
|
||||
assert(pos.board[sq] == NO_PIECE);
|
||||
pc = packer.read_board_piece_from_stream();
|
||||
}
|
||||
else
|
||||
{
|
||||
pc = pos.board[sq];
|
||||
// put_piece() will catch ASSERT unless you remove it all.
|
||||
pos.board[sq] = NO_PIECE;
|
||||
}
|
||||
|
||||
// There may be no pieces, so skip in that case.
|
||||
if (pc == NO_PIECE)
|
||||
continue;
|
||||
|
||||
pos.put_piece(Piece(pc), sq);
|
||||
|
||||
if (stream.get_cursor()> 256)
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Castling availability.
|
||||
// TODO(someone): Support chess960.
|
||||
pos.st->castlingRights = 0;
|
||||
if (stream.read_one_bit()) {
|
||||
Square rsq;
|
||||
for (rsq = relative_square(WHITE, SQ_H1); pos.piece_on(rsq) != W_ROOK; --rsq) {}
|
||||
pos.set_castling_right(WHITE, rsq);
|
||||
}
|
||||
if (stream.read_one_bit()) {
|
||||
Square rsq;
|
||||
for (rsq = relative_square(WHITE, SQ_A1); pos.piece_on(rsq) != W_ROOK; ++rsq) {}
|
||||
pos.set_castling_right(WHITE, rsq);
|
||||
}
|
||||
if (stream.read_one_bit()) {
|
||||
Square rsq;
|
||||
for (rsq = relative_square(BLACK, SQ_H1); pos.piece_on(rsq) != B_ROOK; --rsq) {}
|
||||
pos.set_castling_right(BLACK, rsq);
|
||||
}
|
||||
if (stream.read_one_bit()) {
|
||||
Square rsq;
|
||||
for (rsq = relative_square(BLACK, SQ_A1); pos.piece_on(rsq) != B_ROOK; ++rsq) {}
|
||||
pos.set_castling_right(BLACK, rsq);
|
||||
}
|
||||
|
||||
// En passant square. Ignore if no pawn capture is possible
|
||||
if (stream.read_one_bit()) {
|
||||
Square ep_square = static_cast<Square>(stream.read_n_bit(6));
|
||||
pos.st->epSquare = ep_square;
|
||||
|
||||
if (!(pos.attackers_to(pos.st->epSquare) & pos.pieces(pos.sideToMove, PAWN))
|
||||
|| !(pos.pieces(~pos.sideToMove, PAWN) & (pos.st->epSquare + pawn_push(~pos.sideToMove))))
|
||||
pos.st->epSquare = SQ_NONE;
|
||||
}
|
||||
else {
|
||||
pos.st->epSquare = SQ_NONE;
|
||||
}
|
||||
|
||||
// Halfmove clock
|
||||
pos.st->rule50 = stream.read_n_bit(6);
|
||||
|
||||
// Fullmove number
|
||||
pos.gamePly = stream.read_n_bit(8);
|
||||
|
||||
// Read the highest bit of rule50. This was added as a fix for rule50
|
||||
// counter having only 6 bits stored.
|
||||
// In older entries this will just be a zero bit.
|
||||
pos.gamePly |= stream.read_n_bit(8) << 8;
|
||||
|
||||
// Read the highest bit of rule50. This was added as a fix for rule50
|
||||
// counter having only 6 bits stored.
|
||||
// In older entries this will just be a zero bit.
|
||||
pos.st->rule50 |= stream.read_n_bit(1) << 6;
|
||||
|
||||
// Convert from fullmove starting from 1 to gamePly starting from 0,
|
||||
// handle also common incorrect FEN with fullmove = 0.
|
||||
pos.gamePly = std::max(2 * (pos.gamePly - 1), 0) + (pos.sideToMove == BLACK);
|
||||
|
||||
assert(stream.get_cursor() <= 256);
|
||||
|
||||
pos.chess960 = false;
|
||||
pos.thisThread = th;
|
||||
pos.set_state(pos.st);
|
||||
|
||||
assert(pos.pos_is_ok());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
PackedSfen sfen_pack(Position& pos)
|
||||
{
|
||||
PackedSfen sfen;
|
||||
|
||||
SfenPacker sp;
|
||||
sp.data = (uint8_t*)&sfen;
|
||||
sp.pack(pos);
|
||||
|
||||
return sfen;
|
||||
}
|
||||
}
|
||||
20
src/tools/sfen_packer.h
Normal file
20
src/tools/sfen_packer.h
Normal file
@@ -0,0 +1,20 @@
|
||||
#ifndef _SFEN_PACKER_H_
|
||||
#define _SFEN_PACKER_H_
|
||||
|
||||
#include "types.h"
|
||||
|
||||
#include "packed_sfen.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
class Position;
|
||||
struct StateInfo;
|
||||
class Thread;
|
||||
|
||||
namespace Tools {
|
||||
|
||||
int set_from_packed_sfen(Position& pos, const PackedSfen& sfen, StateInfo* si, Thread* th);
|
||||
PackedSfen sfen_pack(Position& pos);
|
||||
}
|
||||
|
||||
#endif
|
||||
352
src/tools/sfen_reader.h
Normal file
352
src/tools/sfen_reader.h
Normal file
@@ -0,0 +1,352 @@
|
||||
#include "sfen_stream.h"
|
||||
|
||||
#include "packed_sfen.h"
|
||||
|
||||
#include "misc.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <list>
|
||||
#include <atomic>
|
||||
#include <optional>
|
||||
#include <iostream>
|
||||
#include <cstdint>
|
||||
#include <thread>
|
||||
#include <functional>
|
||||
|
||||
namespace Tools{
|
||||
|
||||
enum struct SfenReaderMode
|
||||
{
|
||||
Sequential,
|
||||
Cyclic
|
||||
};
|
||||
|
||||
// Sfen reader
|
||||
struct SfenReader
|
||||
{
|
||||
// Number of phases buffered by each thread 0.1M phases. 4M phase at 40HT
|
||||
static constexpr size_t DEFAULT_THREAD_BUFFER_SIZE = 10 * 1000;
|
||||
|
||||
// Buffer for reading files (If this is made larger,
|
||||
// the shuffle becomes larger and the phases may vary.
|
||||
// If it is too large, the memory consumption will increase.
|
||||
// SFEN_READ_SIZE is a multiple of THREAD_BUFFER_SIZE.
|
||||
static constexpr const size_t DEFAULT_SFEN_READ_SIZE = 1000 * 1000 * 10;
|
||||
|
||||
// Do not use std::random_device().
|
||||
// Because it always the same integers on MinGW.
|
||||
SfenReader(
|
||||
const std::vector<std::string>& filenames_,
|
||||
bool do_shuffle,
|
||||
SfenReaderMode mode_,
|
||||
int thread_num,
|
||||
const std::string& seed,
|
||||
size_t read_size = DEFAULT_SFEN_READ_SIZE,
|
||||
size_t buffer_size = DEFAULT_THREAD_BUFFER_SIZE
|
||||
) :
|
||||
filenames(filenames_.begin(), filenames_.end()),
|
||||
mode(mode_),
|
||||
// Due to the implementation of waiting for buffer empty a bit
|
||||
// the read size must be at least twice the buffer size.
|
||||
sfen_read_size(std::max(read_size, buffer_size * 2)),
|
||||
thread_buffer_size(buffer_size),
|
||||
prng(seed)
|
||||
{
|
||||
packed_sfens.resize(thread_num);
|
||||
total_read = 0;
|
||||
end_of_files = false;
|
||||
shuffle = do_shuffle;
|
||||
stop_flag = false;
|
||||
num_buffers_in_pool.store(0);
|
||||
|
||||
file_worker_thread = std::thread([&] {
|
||||
this->file_read_worker();
|
||||
});
|
||||
}
|
||||
|
||||
~SfenReader()
|
||||
{
|
||||
stop_flag = true;
|
||||
|
||||
if (file_worker_thread.joinable())
|
||||
file_worker_thread.join();
|
||||
}
|
||||
|
||||
// Load the phase for calculation such as mse.
|
||||
PSVector read_some(uint64_t count, uint64_t count_tries, std::function<bool(const PackedSfenValue&)> do_take)
|
||||
{
|
||||
PSVector psv;
|
||||
psv.reserve(count);
|
||||
|
||||
for (uint64_t i = 0; i < count_tries; ++i)
|
||||
{
|
||||
PackedSfenValue ps;
|
||||
if (!read_to_thread_buffer(0, ps))
|
||||
{
|
||||
std::cout << "ERROR (sfen_reader): Reading failed." << std::endl;
|
||||
return psv;
|
||||
}
|
||||
|
||||
if (do_take(ps))
|
||||
{
|
||||
psv.push_back(ps);
|
||||
|
||||
if (psv.size() >= count)
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return psv;
|
||||
}
|
||||
|
||||
// [ASYNC] Thread returns one aspect. Otherwise returns false.
|
||||
bool read_to_thread_buffer(size_t thread_id, PackedSfenValue& ps)
|
||||
{
|
||||
// If there are any positions left in the thread buffer
|
||||
// then retrieve one and return it.
|
||||
auto& thread_ps = packed_sfens[thread_id];
|
||||
|
||||
// Fill the read buffer if there is no remaining buffer,
|
||||
// but if it doesn't even exist, finish.
|
||||
// If the buffer is empty, fill it.
|
||||
if ((thread_ps == nullptr || thread_ps->empty())
|
||||
&& !read_to_thread_buffer_impl(thread_id))
|
||||
return false;
|
||||
|
||||
// read_to_thread_buffer_impl() returned true,
|
||||
// Since the filling of the thread buffer with the
|
||||
// phase has been completed successfully
|
||||
// thread_ps->rbegin() is alive.
|
||||
|
||||
ps = thread_ps->back();
|
||||
thread_ps->pop_back();
|
||||
|
||||
// If you've run out of buffers, call delete yourself to free this buffer.
|
||||
if (thread_ps->empty())
|
||||
{
|
||||
thread_ps.reset();
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// [ASYNC] Read some aspects into thread buffer.
|
||||
bool read_to_thread_buffer_impl(size_t thread_id)
|
||||
{
|
||||
while (true)
|
||||
{
|
||||
{
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
// If you can fill from the file buffer, that's fine.
|
||||
if (packed_sfens_pool.size() != 0)
|
||||
{
|
||||
// It seems that filling is possible, so fill and finish.
|
||||
|
||||
packed_sfens[thread_id] = std::move(packed_sfens_pool.front());
|
||||
packed_sfens_pool.pop_front();
|
||||
num_buffers_in_pool.fetch_sub(1);
|
||||
|
||||
total_read += thread_buffer_size;
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// The file to read is already gone. No more use.
|
||||
if (end_of_files)
|
||||
return false;
|
||||
|
||||
// Waiting for file worker to fill packed_sfens_pool.
|
||||
// The mutex isn't locked, so it should fill up soon.
|
||||
// Poor man's condition variable.
|
||||
sleep(1);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void file_read_worker()
|
||||
{
|
||||
std::string currentFilename;
|
||||
uint64_t numEntriesReadFromCurrentFile = 0;
|
||||
|
||||
auto open_next_file = [&]() {
|
||||
// no more
|
||||
for(;;)
|
||||
{
|
||||
sfen_input_stream.reset();
|
||||
|
||||
if (filenames.empty())
|
||||
return false;
|
||||
|
||||
// Get the next file name.
|
||||
currentFilename = filenames.front();
|
||||
filenames.pop_front();
|
||||
|
||||
numEntriesReadFromCurrentFile = 0;
|
||||
|
||||
sfen_input_stream = open_sfen_input_file(currentFilename);
|
||||
|
||||
auto out = sync_region_cout.new_region();
|
||||
if (sfen_input_stream == nullptr)
|
||||
{
|
||||
out << "INFO (sfen_reader): File does not exist: " << currentFilename << '\n';
|
||||
}
|
||||
else
|
||||
{
|
||||
out << "INFO (sfen_reader): Opened file for reading: " << currentFilename << '\n';
|
||||
|
||||
// in case the file is empty or was deleted.
|
||||
if (sfen_input_stream->eof())
|
||||
{
|
||||
out << " - File empty, nothing to read.\n";
|
||||
}
|
||||
else
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if (sfen_input_stream == nullptr && !open_next_file())
|
||||
{
|
||||
auto out = sync_region_cout.new_region();
|
||||
out << "INFO (sfen_reader): End of files." << std::endl;
|
||||
end_of_files = true;
|
||||
return;
|
||||
}
|
||||
|
||||
// We want to set the `end_of_files` only after we read everything AND copy to the buffer pool.
|
||||
bool local_end_of_files = false;
|
||||
while (!local_end_of_files)
|
||||
{
|
||||
// Wait for the buffer to run out.
|
||||
// This size() is read only, so you don't need to lock it.
|
||||
while (!stop_flag && num_buffers_in_pool.load() >= sfen_read_size / thread_buffer_size)
|
||||
sleep(100);
|
||||
|
||||
if (stop_flag)
|
||||
return;
|
||||
|
||||
PSVector sfens;
|
||||
sfens.reserve(sfen_read_size);
|
||||
|
||||
// Read from the file into the file buffer.
|
||||
while (sfens.size() < sfen_read_size)
|
||||
{
|
||||
std::optional<PackedSfenValue> p = sfen_input_stream->next();
|
||||
if (p.has_value())
|
||||
{
|
||||
sfens.push_back(*p);
|
||||
++numEntriesReadFromCurrentFile;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (mode == SfenReaderMode::Cyclic
|
||||
&& numEntriesReadFromCurrentFile > 0)
|
||||
{
|
||||
// The file contained data so we add it again to the end of the queue.
|
||||
filenames.emplace_back(currentFilename);
|
||||
}
|
||||
|
||||
if(!open_next_file())
|
||||
{
|
||||
// There was no next file. Abort.
|
||||
auto out = sync_region_cout.new_region();
|
||||
out << "INFO (sfen_reader): End of files." << std::endl;
|
||||
local_end_of_files = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Shuffle the read phase data.
|
||||
if (shuffle)
|
||||
{
|
||||
Algo::shuffle(sfens, prng);
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<PSVector>> buffers;
|
||||
for (size_t offset = 0; offset < sfens.size(); offset += thread_buffer_size)
|
||||
{
|
||||
const size_t count =
|
||||
offset + thread_buffer_size > sfens.size()
|
||||
? sfens.size() - offset
|
||||
: thread_buffer_size;
|
||||
|
||||
// Delete this pointer on the receiving side.
|
||||
auto buf = std::make_unique<PSVector>();
|
||||
buf->resize(count);
|
||||
memcpy(
|
||||
buf->data(),
|
||||
&sfens[offset],
|
||||
sizeof(PackedSfenValue) * count);
|
||||
|
||||
buffers.emplace_back(std::move(buf));
|
||||
}
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
|
||||
// The mutex lock is required because the%
|
||||
// contents of packed_sfens_pool are changed.
|
||||
|
||||
for (auto& buf : buffers)
|
||||
{
|
||||
num_buffers_in_pool.fetch_add(1);
|
||||
packed_sfens_pool.emplace_back(std::move(buf));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
end_of_files = true;
|
||||
}
|
||||
|
||||
protected:
|
||||
|
||||
// worker thread reading file in background
|
||||
std::thread file_worker_thread;
|
||||
|
||||
// sfen files
|
||||
std::deque<std::string> filenames;
|
||||
|
||||
std::atomic<bool> stop_flag;
|
||||
|
||||
// number of phases read (file to memory buffer)
|
||||
std::atomic<uint64_t> total_read;
|
||||
|
||||
// Do not shuffle when reading the phase.
|
||||
bool shuffle;
|
||||
|
||||
SfenReaderMode mode;
|
||||
|
||||
size_t sfen_read_size;
|
||||
size_t thread_buffer_size;
|
||||
|
||||
// Random number to shuffle when reading the phase
|
||||
PRNG prng;
|
||||
|
||||
// Did you read the files and reached the end?
|
||||
std::atomic<bool> end_of_files;
|
||||
|
||||
// handle of sfen file
|
||||
std::unique_ptr<BasicSfenInputStream> sfen_input_stream;
|
||||
|
||||
// sfen for each thread
|
||||
// (When the thread is used up, the thread should call delete to release it.)
|
||||
std::vector<std::unique_ptr<PSVector>> packed_sfens;
|
||||
|
||||
// Mutex when accessing packed_sfens_pool
|
||||
std::mutex mutex;
|
||||
|
||||
// pool of sfen. The worker thread read from the file is added here.
|
||||
// Each worker thread fills its own packed_sfens[thread_id] from here.
|
||||
// * Lock and access the mutex.
|
||||
std::list<std::unique_ptr<PSVector>> packed_sfens_pool;
|
||||
std::atomic<size_t> num_buffers_in_pool;
|
||||
};
|
||||
}
|
||||
222
src/tools/sfen_stream.h
Normal file
222
src/tools/sfen_stream.h
Normal file
@@ -0,0 +1,222 @@
|
||||
#ifndef _SFEN_STREAM_H_
|
||||
#define _SFEN_STREAM_H_
|
||||
|
||||
#include "packed_sfen.h"
|
||||
|
||||
#include "extra/nnue_data_binpack_format.h"
|
||||
|
||||
#include <optional>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
namespace Tools {
|
||||
|
||||
enum struct SfenOutputType
|
||||
{
|
||||
Bin,
|
||||
Binpack
|
||||
};
|
||||
|
||||
static bool ends_with(const std::string& lhs, const std::string& end)
|
||||
{
|
||||
if (end.size() > lhs.size()) return false;
|
||||
|
||||
return std::equal(end.rbegin(), end.rend(), lhs.rbegin());
|
||||
}
|
||||
|
||||
static bool has_extension(const std::string& filename, const std::string& extension)
|
||||
{
|
||||
return ends_with(filename, "." + extension);
|
||||
}
|
||||
|
||||
static std::string filename_with_extension(const std::string& filename, const std::string& ext)
|
||||
{
|
||||
if (ends_with(filename, ext))
|
||||
{
|
||||
return filename;
|
||||
}
|
||||
else
|
||||
{
|
||||
return filename + "." + ext;
|
||||
}
|
||||
}
|
||||
|
||||
struct BasicSfenInputStream
|
||||
{
|
||||
virtual std::optional<PackedSfenValue> next() = 0;
|
||||
virtual bool eof() const = 0;
|
||||
virtual ~BasicSfenInputStream() {}
|
||||
};
|
||||
|
||||
struct BinSfenInputStream : BasicSfenInputStream
|
||||
{
|
||||
static constexpr auto openmode = std::ios::in | std::ios::binary;
|
||||
static inline const std::string extension = "bin";
|
||||
|
||||
BinSfenInputStream(std::string filename) :
|
||||
m_stream(filename, openmode),
|
||||
m_eof(!m_stream)
|
||||
{
|
||||
}
|
||||
|
||||
std::optional<PackedSfenValue> next() override
|
||||
{
|
||||
PackedSfenValue e;
|
||||
if(m_stream.read(reinterpret_cast<char*>(&e), sizeof(PackedSfenValue)))
|
||||
{
|
||||
return e;
|
||||
}
|
||||
else
|
||||
{
|
||||
m_eof = true;
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
bool eof() const override
|
||||
{
|
||||
return m_eof;
|
||||
}
|
||||
|
||||
~BinSfenInputStream() override {}
|
||||
|
||||
private:
|
||||
std::fstream m_stream;
|
||||
bool m_eof;
|
||||
};
|
||||
|
||||
struct BinpackSfenInputStream : BasicSfenInputStream
|
||||
{
|
||||
static constexpr auto openmode = std::ios::in | std::ios::binary;
|
||||
static inline const std::string extension = "binpack";
|
||||
|
||||
BinpackSfenInputStream(std::string filename) :
|
||||
m_stream(filename, openmode),
|
||||
m_eof(!m_stream.hasNext())
|
||||
{
|
||||
}
|
||||
|
||||
std::optional<PackedSfenValue> next() override
|
||||
{
|
||||
static_assert(sizeof(binpack::nodchip::PackedSfenValue) == sizeof(PackedSfenValue));
|
||||
|
||||
if (!m_stream.hasNext())
|
||||
{
|
||||
m_eof = true;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto training_data_entry = m_stream.next();
|
||||
auto v = binpack::trainingDataEntryToPackedSfenValue(training_data_entry);
|
||||
PackedSfenValue psv;
|
||||
// same layout, different types. One is from generic library.
|
||||
std::memcpy(&psv, &v, sizeof(PackedSfenValue));
|
||||
|
||||
return psv;
|
||||
}
|
||||
|
||||
bool eof() const override
|
||||
{
|
||||
return m_eof;
|
||||
}
|
||||
|
||||
~BinpackSfenInputStream() override {}
|
||||
|
||||
private:
|
||||
binpack::CompressedTrainingDataEntryReader m_stream;
|
||||
bool m_eof;
|
||||
};
|
||||
|
||||
struct BasicSfenOutputStream
|
||||
{
|
||||
virtual void write(const PSVector& sfens) = 0;
|
||||
virtual ~BasicSfenOutputStream() {}
|
||||
};
|
||||
|
||||
struct BinSfenOutputStream : BasicSfenOutputStream
|
||||
{
|
||||
static constexpr auto openmode = std::ios::out | std::ios::binary | std::ios::app;
|
||||
static inline const std::string extension = "bin";
|
||||
|
||||
BinSfenOutputStream(std::string filename) :
|
||||
m_stream(filename_with_extension(filename, extension), openmode)
|
||||
{
|
||||
}
|
||||
|
||||
void write(const PSVector& sfens) override
|
||||
{
|
||||
m_stream.write(reinterpret_cast<const char*>(sfens.data()), sizeof(PackedSfenValue) * sfens.size());
|
||||
}
|
||||
|
||||
~BinSfenOutputStream() override {}
|
||||
|
||||
private:
|
||||
std::fstream m_stream;
|
||||
};
|
||||
|
||||
struct BinpackSfenOutputStream : BasicSfenOutputStream
|
||||
{
|
||||
static constexpr auto openmode = std::ios::out | std::ios::binary | std::ios::app;
|
||||
static inline const std::string extension = "binpack";
|
||||
|
||||
BinpackSfenOutputStream(std::string filename) :
|
||||
m_stream(filename_with_extension(filename, extension), openmode)
|
||||
{
|
||||
}
|
||||
|
||||
void write(const PSVector& sfens) override
|
||||
{
|
||||
static_assert(sizeof(binpack::nodchip::PackedSfenValue) == sizeof(PackedSfenValue));
|
||||
|
||||
for(auto& sfen : sfens)
|
||||
{
|
||||
// The library uses a type that's different but layout-compatibile.
|
||||
binpack::nodchip::PackedSfenValue e;
|
||||
std::memcpy(&e, &sfen, sizeof(binpack::nodchip::PackedSfenValue));
|
||||
m_stream.addTrainingDataEntry(binpack::packedSfenValueToTrainingDataEntry(e));
|
||||
}
|
||||
}
|
||||
|
||||
~BinpackSfenOutputStream() override {}
|
||||
|
||||
private:
|
||||
binpack::CompressedTrainingDataEntryWriter m_stream;
|
||||
};
|
||||
|
||||
inline std::unique_ptr<BasicSfenInputStream> open_sfen_input_file(const std::string& filename)
|
||||
{
|
||||
if (has_extension(filename, BinSfenInputStream::extension))
|
||||
return std::make_unique<BinSfenInputStream>(filename);
|
||||
else if (has_extension(filename, BinpackSfenInputStream::extension))
|
||||
return std::make_unique<BinpackSfenInputStream>(filename);
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
inline std::unique_ptr<BasicSfenOutputStream> create_new_sfen_output(const std::string& filename, SfenOutputType sfen_output_type)
|
||||
{
|
||||
switch(sfen_output_type)
|
||||
{
|
||||
case SfenOutputType::Bin:
|
||||
return std::make_unique<BinSfenOutputStream>(filename);
|
||||
case SfenOutputType::Binpack:
|
||||
return std::make_unique<BinpackSfenOutputStream>(filename);
|
||||
}
|
||||
|
||||
assert(false);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
inline std::unique_ptr<BasicSfenOutputStream> create_new_sfen_output(const std::string& filename)
|
||||
{
|
||||
if (has_extension(filename, BinSfenOutputStream::extension))
|
||||
return std::make_unique<BinSfenOutputStream>(filename);
|
||||
else if (has_extension(filename, BinpackSfenOutputStream::extension))
|
||||
return std::make_unique<BinpackSfenOutputStream>(filename);
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
205
src/tools/sfen_writer.h
Normal file
205
src/tools/sfen_writer.h
Normal file
@@ -0,0 +1,205 @@
|
||||
#include "packed_sfen.h"
|
||||
#include "sfen_stream.h"
|
||||
|
||||
#include "misc.h"
|
||||
|
||||
#include "extra/nnue_data_binpack_format.h"
|
||||
|
||||
#include "syzygy/tbprobe.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <limits>
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <shared_mutex>
|
||||
#include <thread>
|
||||
#include <atomic>
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Tools {
|
||||
|
||||
// Helper class for exporting Sfen
|
||||
struct SfenWriter
|
||||
{
|
||||
// Amount of sfens required to flush the buffer.
|
||||
static constexpr size_t SFEN_WRITE_SIZE = 5000;
|
||||
|
||||
// File name to write and number of threads to create
|
||||
SfenWriter(string filename_, int thread_num, uint64_t save_count, SfenOutputType sfen_output_type)
|
||||
{
|
||||
sfen_buffers_pool.reserve((size_t)thread_num * 10);
|
||||
sfen_buffers.resize(thread_num);
|
||||
|
||||
auto out = sync_region_cout.new_region();
|
||||
out << "INFO (sfen_writer): Creating new data file at " << filename_ << endl;
|
||||
|
||||
sfen_format = sfen_output_type;
|
||||
output_file_stream = create_new_sfen_output(filename_, sfen_format);
|
||||
filename = filename_;
|
||||
save_every = save_count;
|
||||
|
||||
finished = false;
|
||||
|
||||
file_worker_thread = std::thread([&] { this->file_write_worker(); });
|
||||
}
|
||||
|
||||
~SfenWriter()
|
||||
{
|
||||
flush();
|
||||
|
||||
finished = true;
|
||||
file_worker_thread.join();
|
||||
output_file_stream.reset();
|
||||
|
||||
#if !defined(NDEBUG)
|
||||
{
|
||||
// All buffers should be empty since file_worker_thread
|
||||
// should have written everything before exiting.
|
||||
for (const auto& p : sfen_buffers) { assert(p == nullptr); (void)p ; }
|
||||
assert(sfen_buffers_pool.empty());
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void write(size_t thread_id, const PackedSfenValue& psv)
|
||||
{
|
||||
// We have a buffer for each thread and add it there.
|
||||
// If the buffer overflows, write it to a file.
|
||||
|
||||
// This buffer is prepared for each thread.
|
||||
auto& buf = sfen_buffers[thread_id];
|
||||
|
||||
// Secure since there is no buf at the first time
|
||||
// and immediately after writing the thread buffer.
|
||||
if (!buf)
|
||||
{
|
||||
buf = std::make_unique<PSVector>();
|
||||
buf->reserve(SFEN_WRITE_SIZE);
|
||||
}
|
||||
|
||||
// Buffer is exclusive to this thread.
|
||||
// There is no need for a critical section.
|
||||
buf->push_back(psv);
|
||||
|
||||
if (buf->size() >= SFEN_WRITE_SIZE)
|
||||
{
|
||||
// If you load it in sfen_buffers_pool, the worker will do the rest.
|
||||
|
||||
// Critical section since sfen_buffers_pool is shared among threads.
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
sfen_buffers_pool.emplace_back(std::move(buf));
|
||||
}
|
||||
}
|
||||
|
||||
void flush()
|
||||
{
|
||||
for (size_t i = 0; i < sfen_buffers.size(); ++i)
|
||||
{
|
||||
flush(i);
|
||||
}
|
||||
}
|
||||
|
||||
// Move what remains in the buffer for your thread to a buffer for writing to a file.
|
||||
void flush(size_t thread_id)
|
||||
{
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
|
||||
auto& buf = sfen_buffers[thread_id];
|
||||
|
||||
// There is a case that buf==nullptr, so that check is necessary.
|
||||
if (buf && buf->size() != 0)
|
||||
{
|
||||
sfen_buffers_pool.emplace_back(std::move(buf));
|
||||
}
|
||||
}
|
||||
|
||||
// Dedicated thread to write to file
|
||||
void file_write_worker()
|
||||
{
|
||||
while (!finished || sfen_buffers_pool.size())
|
||||
{
|
||||
vector<std::unique_ptr<PSVector>> buffers;
|
||||
{
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
|
||||
// Atomically swap take the filled buffers and
|
||||
// create a new buffer pool for threads to fill.
|
||||
buffers = std::move(sfen_buffers_pool);
|
||||
sfen_buffers_pool = std::vector<std::unique_ptr<PSVector>>();
|
||||
}
|
||||
|
||||
if (!buffers.size())
|
||||
{
|
||||
// Poor man's condition variable.
|
||||
sleep(100);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (auto& buf : buffers)
|
||||
{
|
||||
output_file_stream->write(*buf);
|
||||
|
||||
sfen_write_count += buf->size();
|
||||
|
||||
// Add the processed number here, and if it exceeds save_every,
|
||||
// change the file name and reset this counter.
|
||||
sfen_write_count_current_file += buf->size();
|
||||
if (sfen_write_count_current_file >= save_every)
|
||||
{
|
||||
sfen_write_count_current_file = 0;
|
||||
|
||||
// Sequential number attached to the file
|
||||
int n = (int)(sfen_write_count / save_every);
|
||||
|
||||
// Rename the file and open it again.
|
||||
// Add ios::app in consideration of overwriting.
|
||||
// (Depending on the operation, it may not be necessary.)
|
||||
string new_filename = filename + "_" + std::to_string(n);
|
||||
output_file_stream = create_new_sfen_output(new_filename, sfen_format);
|
||||
|
||||
auto out = sync_region_cout.new_region();
|
||||
out << "INFO (sfen_writer): Creating new data file at " << new_filename << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
std::unique_ptr<BasicSfenOutputStream> output_file_stream;
|
||||
|
||||
// A new net is saved after every save_every sfens are processed.
|
||||
uint64_t save_every = std::numeric_limits<uint64_t>::max();
|
||||
|
||||
// File name passed in the constructor
|
||||
std::string filename;
|
||||
|
||||
// Thread to write to the file
|
||||
std::thread file_worker_thread;
|
||||
|
||||
// Flag that all threads have finished
|
||||
atomic<bool> finished;
|
||||
|
||||
SfenOutputType sfen_format;
|
||||
|
||||
// buffer before writing to file
|
||||
// sfen_buffers is the buffer for each thread
|
||||
// sfen_buffers_pool is a buffer for writing.
|
||||
// After loading the phase in the former buffer by SFEN_WRITE_SIZE,
|
||||
// transfer it to the latter.
|
||||
std::vector<std::unique_ptr<PSVector>> sfen_buffers;
|
||||
std::vector<std::unique_ptr<PSVector>> sfen_buffers_pool;
|
||||
|
||||
// Mutex required to access sfen_buffers_pool
|
||||
std::mutex mutex;
|
||||
|
||||
// Number of sfens written in total, and the
|
||||
// number of sfens written in the current file.
|
||||
uint64_t sfen_write_count = 0;
|
||||
uint64_t sfen_write_count_current_file = 0;
|
||||
};
|
||||
}
|
||||
645
src/tools/stats.cpp
Normal file
645
src/tools/stats.cpp
Normal file
@@ -0,0 +1,645 @@
|
||||
#include "stats.h"
|
||||
|
||||
#include "sfen_stream.h"
|
||||
#include "packed_sfen.h"
|
||||
#include "sfen_writer.h"
|
||||
|
||||
#include "thread.h"
|
||||
#include "position.h"
|
||||
#include "evaluate.h"
|
||||
#include "search.h"
|
||||
|
||||
#include "nnue/evaluate_nnue.h"
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <iostream>
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
#include <limits>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
|
||||
namespace Tools::Stats
|
||||
{
|
||||
struct StatisticGathererBase
|
||||
{
|
||||
virtual void on_position(const Position&) {}
|
||||
virtual void on_move(const Position&, const Move&) {}
|
||||
virtual void reset() = 0;
|
||||
[[nodiscard]] virtual const std::string& get_name() const = 0;
|
||||
[[nodiscard]] virtual std::vector<std::pair<std::string, std::string>> get_formatted_stats() const = 0;
|
||||
};
|
||||
|
||||
struct StatisticGathererFactoryBase
|
||||
{
|
||||
[[nodiscard]] virtual std::unique_ptr<StatisticGathererBase> create() const = 0;
|
||||
[[nodiscard]] virtual const std::string& get_name() const = 0;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct StatisticGathererFactory : StatisticGathererFactoryBase
|
||||
{
|
||||
static inline std::string name = T::name;
|
||||
|
||||
[[nodiscard]] std::unique_ptr<StatisticGathererBase> create() const override
|
||||
{
|
||||
return std::make_unique<T>();
|
||||
}
|
||||
|
||||
[[nodiscard]] const std::string& get_name() const override
|
||||
{
|
||||
return name;
|
||||
}
|
||||
};
|
||||
|
||||
struct StatisticGathererSet : StatisticGathererBase
|
||||
{
|
||||
void add(const StatisticGathererFactoryBase& factory)
|
||||
{
|
||||
const std::string name = factory.get_name();
|
||||
if (m_gatherers_names.count(name) == 0)
|
||||
{
|
||||
m_gatherers_names.insert(name);
|
||||
m_gatherers.emplace_back(factory.create());
|
||||
}
|
||||
}
|
||||
|
||||
void add(std::unique_ptr<StatisticGathererBase>&& gatherer)
|
||||
{
|
||||
const std::string name = gatherer->get_name();
|
||||
if (m_gatherers_names.count(name) == 0)
|
||||
{
|
||||
m_gatherers_names.insert(name);
|
||||
m_gatherers.emplace_back(std::move(gatherer));
|
||||
}
|
||||
}
|
||||
|
||||
void on_position(const Position& position) override
|
||||
{
|
||||
for (auto& g : m_gatherers)
|
||||
{
|
||||
g->on_position(position);
|
||||
}
|
||||
}
|
||||
|
||||
void on_move(const Position& pos, const Move& move) override
|
||||
{
|
||||
for (auto& g : m_gatherers)
|
||||
{
|
||||
g->on_move(pos, move);
|
||||
}
|
||||
}
|
||||
|
||||
void reset() override
|
||||
{
|
||||
for (auto& g : m_gatherers)
|
||||
{
|
||||
g->reset();
|
||||
}
|
||||
}
|
||||
|
||||
[[nodiscard]] virtual const std::string& get_name() const override
|
||||
{
|
||||
static std::string name = "SET";
|
||||
return name;
|
||||
}
|
||||
|
||||
[[nodiscard]] virtual std::vector<std::pair<std::string, std::string>> get_formatted_stats() const override
|
||||
{
|
||||
std::vector<std::pair<std::string, std::string>> parts;
|
||||
for (auto&& s : m_gatherers)
|
||||
{
|
||||
auto part = s->get_formatted_stats();
|
||||
parts.insert(parts.end(), part.begin(), part.end());
|
||||
}
|
||||
return parts;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<StatisticGathererBase>> m_gatherers;
|
||||
std::set<std::string> m_gatherers_names;
|
||||
};
|
||||
|
||||
struct StatisticGathererRegistry
|
||||
{
|
||||
void add_statistic_gatherers_by_group(
|
||||
StatisticGathererSet& gatherers,
|
||||
const std::string& group) const
|
||||
{
|
||||
auto it = m_gatherers_by_group.find(group);
|
||||
if (it != m_gatherers_by_group.end())
|
||||
{
|
||||
for (auto& factory : it->second)
|
||||
{
|
||||
gatherers.add(*factory);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename... ArgsTs>
|
||||
void add(const ArgsTs&... group)
|
||||
{
|
||||
auto dummy = {(add_single<T>(group), 0)...};
|
||||
(void)dummy;
|
||||
add_single<T>("all");
|
||||
}
|
||||
|
||||
private:
|
||||
std::map<std::string, std::vector<std::unique_ptr<StatisticGathererFactoryBase>>> m_gatherers_by_group;
|
||||
std::map<std::string, std::set<std::string>> m_gatherers_names_by_group;
|
||||
|
||||
template <typename T, typename ArgT>
|
||||
void add_single(const ArgT& group)
|
||||
{
|
||||
using FactoryT = StatisticGathererFactory<T>;
|
||||
|
||||
if (m_gatherers_names_by_group[group].count(FactoryT::name) == 0)
|
||||
{
|
||||
m_gatherers_by_group[group].emplace_back(std::make_unique<FactoryT>());
|
||||
m_gatherers_names_by_group[group].insert(FactoryT::name);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
Statistic gatherer helpers
|
||||
*/
|
||||
|
||||
template <typename T>
|
||||
struct StatPerSquare
|
||||
{
|
||||
StatPerSquare()
|
||||
{
|
||||
for (int i = 0; i < SQUARE_NB; ++i)
|
||||
m_squares[i] = 0;
|
||||
}
|
||||
|
||||
[[nodiscard]] T& operator[](Square sq)
|
||||
{
|
||||
return m_squares[sq];
|
||||
}
|
||||
|
||||
[[nodiscard]] const T& operator[](Square sq) const
|
||||
{
|
||||
return m_squares[sq];
|
||||
}
|
||||
|
||||
[[nodiscard]] std::string get_formatted_stats() const
|
||||
{
|
||||
std::stringstream ss;
|
||||
for (int i = 0; i < SQUARE_NB; ++i)
|
||||
{
|
||||
ss << std::setw(8) << m_squares[i ^ (int)SQ_A8] << ' ';
|
||||
if ((i + 1) % 8 == 0)
|
||||
ss << '\n';
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
private:
|
||||
std::array<T, SQUARE_NB> m_squares;
|
||||
};
|
||||
|
||||
/*
|
||||
Definitions for specific statistic gatherers follow:
|
||||
*/
|
||||
|
||||
struct PositionCounter : StatisticGathererBase
|
||||
{
|
||||
static inline std::string name = "PositionCounter";
|
||||
|
||||
PositionCounter() :
|
||||
m_num_positions(0)
|
||||
{
|
||||
}
|
||||
|
||||
void on_position(const Position&) override
|
||||
{
|
||||
m_num_positions += 1;
|
||||
}
|
||||
|
||||
void reset() override
|
||||
{
|
||||
m_num_positions = 0;
|
||||
}
|
||||
|
||||
[[nodiscard]] const std::string& get_name() const override
|
||||
{
|
||||
return name;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::vector<std::pair<std::string, std::string>> get_formatted_stats() const override
|
||||
{
|
||||
return {
|
||||
{ "Number of positions", std::to_string(m_num_positions) }
|
||||
};
|
||||
}
|
||||
|
||||
private:
|
||||
std::uint64_t m_num_positions;
|
||||
};
|
||||
|
||||
struct KingSquareCounter : StatisticGathererBase
|
||||
{
|
||||
static inline std::string name = "KingSquareCounter";
|
||||
|
||||
KingSquareCounter() :
|
||||
m_white{},
|
||||
m_black{}
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
void on_position(const Position& pos) override
|
||||
{
|
||||
m_white[pos.square<KING>(WHITE)] += 1;
|
||||
m_black[pos.square<KING>(BLACK)] += 1;
|
||||
}
|
||||
|
||||
void reset() override
|
||||
{
|
||||
m_white = StatPerSquare<std::uint64_t>{};
|
||||
m_black = StatPerSquare<std::uint64_t>{};
|
||||
}
|
||||
|
||||
[[nodiscard]] const std::string& get_name() const override
|
||||
{
|
||||
return name;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::vector<std::pair<std::string, std::string>> get_formatted_stats() const override
|
||||
{
|
||||
return {
|
||||
{ "White king squares", '\n' + m_white.get_formatted_stats() },
|
||||
{ "Black king squares", '\n' + m_black.get_formatted_stats() }
|
||||
};
|
||||
}
|
||||
|
||||
private:
|
||||
StatPerSquare<std::uint64_t> m_white;
|
||||
StatPerSquare<std::uint64_t> m_black;
|
||||
};
|
||||
|
||||
struct MoveFromCounter : StatisticGathererBase
|
||||
{
|
||||
static inline std::string name = "MoveFromCounter";
|
||||
|
||||
MoveFromCounter() :
|
||||
m_white{},
|
||||
m_black{}
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
void on_move(const Position& pos, const Move& move) override
|
||||
{
|
||||
if (pos.side_to_move() == WHITE)
|
||||
m_white[from_sq(move)] += 1;
|
||||
else
|
||||
m_black[from_sq(move)] += 1;
|
||||
}
|
||||
|
||||
void reset() override
|
||||
{
|
||||
m_white = StatPerSquare<std::uint64_t>{};
|
||||
m_black = StatPerSquare<std::uint64_t>{};
|
||||
}
|
||||
|
||||
[[nodiscard]] const std::string& get_name() const override
|
||||
{
|
||||
return name;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::vector<std::pair<std::string, std::string>> get_formatted_stats() const override
|
||||
{
|
||||
return {
|
||||
{ "White move from squares", '\n' + m_white.get_formatted_stats() },
|
||||
{ "Black move from squares", '\n' + m_black.get_formatted_stats() }
|
||||
};
|
||||
}
|
||||
|
||||
private:
|
||||
StatPerSquare<std::uint64_t> m_white;
|
||||
StatPerSquare<std::uint64_t> m_black;
|
||||
};
|
||||
|
||||
struct MoveToCounter : StatisticGathererBase
|
||||
{
|
||||
static inline std::string name = "MoveToCounter";
|
||||
|
||||
MoveToCounter() :
|
||||
m_white{},
|
||||
m_black{}
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
void on_move(const Position& pos, const Move& move) override
|
||||
{
|
||||
if (pos.side_to_move() == WHITE)
|
||||
m_white[to_sq(move)] += 1;
|
||||
else
|
||||
m_black[to_sq(move)] += 1;
|
||||
}
|
||||
|
||||
void reset() override
|
||||
{
|
||||
m_white = StatPerSquare<std::uint64_t>{};
|
||||
m_black = StatPerSquare<std::uint64_t>{};
|
||||
}
|
||||
|
||||
[[nodiscard]] const std::string& get_name() const override
|
||||
{
|
||||
return name;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::vector<std::pair<std::string, std::string>> get_formatted_stats() const override
|
||||
{
|
||||
return {
|
||||
{ "White move to squares", '\n' + m_white.get_formatted_stats() },
|
||||
{ "Black move to squares", '\n' + m_black.get_formatted_stats() }
|
||||
};
|
||||
}
|
||||
|
||||
private:
|
||||
StatPerSquare<std::uint64_t> m_white;
|
||||
StatPerSquare<std::uint64_t> m_black;
|
||||
};
|
||||
|
||||
struct MoveTypeCounter : StatisticGathererBase
|
||||
{
|
||||
static inline std::string name = "MoveTypeCounter";
|
||||
|
||||
MoveTypeCounter() :
|
||||
m_total(0),
|
||||
m_normal(0),
|
||||
m_capture(0),
|
||||
m_promotion(0),
|
||||
m_castling(0),
|
||||
m_enpassant(0)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
void on_move(const Position& pos, const Move& move) override
|
||||
{
|
||||
m_total += 1;
|
||||
|
||||
if (!pos.empty(to_sq(move)))
|
||||
m_capture += 1;
|
||||
|
||||
if (type_of(move) == CASTLING)
|
||||
m_castling += 1;
|
||||
else if (type_of(move) == PROMOTION)
|
||||
m_promotion += 1;
|
||||
else if (type_of(move) == ENPASSANT)
|
||||
m_enpassant += 1;
|
||||
else if (type_of(move) == NORMAL)
|
||||
m_normal += 1;
|
||||
}
|
||||
|
||||
void reset() override
|
||||
{
|
||||
m_total = 0;
|
||||
m_normal = 0;
|
||||
m_capture = 0;
|
||||
m_promotion = 0;
|
||||
m_castling = 0;
|
||||
m_enpassant = 0;
|
||||
}
|
||||
|
||||
[[nodiscard]] const std::string& get_name() const override
|
||||
{
|
||||
return name;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::vector<std::pair<std::string, std::string>> get_formatted_stats() const override
|
||||
{
|
||||
return {
|
||||
{ "Total moves", std::to_string(m_total) },
|
||||
{ "Normal moves", std::to_string(m_normal) },
|
||||
{ "Capture moves", std::to_string(m_capture) },
|
||||
{ "Promotion moves", std::to_string(m_promotion) },
|
||||
{ "Castling moves", std::to_string(m_castling) },
|
||||
{ "En-passant moves", std::to_string(m_enpassant) }
|
||||
};
|
||||
}
|
||||
|
||||
private:
|
||||
std::uint64_t m_total;
|
||||
std::uint64_t m_normal;
|
||||
std::uint64_t m_capture;
|
||||
std::uint64_t m_promotion;
|
||||
std::uint64_t m_castling;
|
||||
std::uint64_t m_enpassant;
|
||||
};
|
||||
|
||||
struct PieceCountCounter : StatisticGathererBase
|
||||
{
|
||||
static inline std::string name = "PieceCountCounter";
|
||||
|
||||
PieceCountCounter()
|
||||
{
|
||||
reset();
|
||||
}
|
||||
|
||||
void on_position(const Position& pos) override
|
||||
{
|
||||
m_piece_count_hist[popcount(pos.pieces())] += 1;
|
||||
}
|
||||
|
||||
void reset() override
|
||||
{
|
||||
for (int i = 0; i < SQUARE_NB; ++i)
|
||||
m_piece_count_hist[i] = 0;
|
||||
}
|
||||
|
||||
[[nodiscard]] const std::string& get_name() const override
|
||||
{
|
||||
return name;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::vector<std::pair<std::string, std::string>> get_formatted_stats() const override
|
||||
{
|
||||
std::vector<std::pair<std::string, std::string>> result;
|
||||
bool do_write = false;
|
||||
for (int i = SQUARE_NB - 1; i >= 0; --i)
|
||||
{
|
||||
if (m_piece_count_hist[i] != 0)
|
||||
do_write = true;
|
||||
|
||||
// Start writing when the first non-zero number pops up.
|
||||
if (do_write)
|
||||
{
|
||||
result.emplace_back(
|
||||
std::string("Number of positions with ") + std::to_string(i) + " pieces",
|
||||
std::to_string(m_piece_count_hist[i])
|
||||
);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
std::uint64_t m_piece_count_hist[SQUARE_NB];
|
||||
};
|
||||
|
||||
struct MovedPieceTypeCounter : StatisticGathererBase
|
||||
{
|
||||
static inline std::string name = "MovedPieceTypeCounter";
|
||||
|
||||
MovedPieceTypeCounter()
|
||||
{
|
||||
reset();
|
||||
}
|
||||
|
||||
void on_move(const Position& pos, const Move& move) override
|
||||
{
|
||||
m_moved_piece_type_hist[type_of(pos.piece_on(from_sq(move)))] += 1;
|
||||
}
|
||||
|
||||
void reset() override
|
||||
{
|
||||
for (int i = 0; i < PIECE_TYPE_NB; ++i)
|
||||
m_moved_piece_type_hist[i] = 0;
|
||||
}
|
||||
|
||||
[[nodiscard]] const std::string& get_name() const override
|
||||
{
|
||||
return name;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::vector<std::pair<std::string, std::string>> get_formatted_stats() const override
|
||||
{
|
||||
return {
|
||||
{ "Pawn moves", std::to_string(m_moved_piece_type_hist[PAWN]) },
|
||||
{ "Knight moves", std::to_string(m_moved_piece_type_hist[KNIGHT]) },
|
||||
{ "Bishop moves", std::to_string(m_moved_piece_type_hist[BISHOP]) },
|
||||
{ "Rook moves", std::to_string(m_moved_piece_type_hist[ROOK]) },
|
||||
{ "Queen moves", std::to_string(m_moved_piece_type_hist[QUEEN]) },
|
||||
{ "King moves", std::to_string(m_moved_piece_type_hist[KING]) }
|
||||
};
|
||||
}
|
||||
|
||||
private:
|
||||
std::uint64_t m_moved_piece_type_hist[PIECE_TYPE_NB];
|
||||
};
|
||||
|
||||
/*
|
||||
This function provides factories for all possible statistic gatherers.
|
||||
Each new statistic gatherer needs to be added there.
|
||||
*/
|
||||
const auto& get_statistics_gatherers_registry()
|
||||
{
|
||||
static StatisticGathererRegistry s_reg = [](){
|
||||
StatisticGathererRegistry reg;
|
||||
|
||||
reg.add<PositionCounter>("position_count");
|
||||
|
||||
reg.add<KingSquareCounter>("king", "king_square_count");
|
||||
|
||||
reg.add<MoveFromCounter>("move", "move_from_count");
|
||||
reg.add<MoveToCounter>("move", "move_to_count");
|
||||
reg.add<MoveTypeCounter>("move", "move_type");
|
||||
reg.add<MovedPieceTypeCounter>("move", "moved_piece_type");
|
||||
|
||||
reg.add<PieceCountCounter>("piece_count");
|
||||
|
||||
return reg;
|
||||
}();
|
||||
|
||||
return s_reg;
|
||||
}
|
||||
|
||||
void do_gather_statistics(
|
||||
const std::string& filename,
|
||||
StatisticGathererSet& statistic_gatherers,
|
||||
std::uint64_t max_count)
|
||||
{
|
||||
Thread* th = Threads.main();
|
||||
Position& pos = th->rootPos;
|
||||
StateInfo si;
|
||||
|
||||
auto in = Tools::open_sfen_input_file(filename);
|
||||
|
||||
auto on_move = [&](const Position& position, const Move& move) {
|
||||
statistic_gatherers.on_move(position, move);
|
||||
};
|
||||
|
||||
auto on_position = [&](const Position& position) {
|
||||
statistic_gatherers.on_position(position);
|
||||
};
|
||||
|
||||
if (in == nullptr)
|
||||
{
|
||||
std::cerr << "Invalid input file type.\n";
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t num_processed = 0;
|
||||
while (num_processed < max_count)
|
||||
{
|
||||
auto v = in->next();
|
||||
if (!v.has_value())
|
||||
break;
|
||||
|
||||
auto& ps = v.value();
|
||||
|
||||
pos.set_from_packed_sfen(ps.sfen, &si, th);
|
||||
|
||||
on_position(pos);
|
||||
on_move(pos, (Move)ps.move);
|
||||
|
||||
num_processed += 1;
|
||||
if (num_processed % 1'000'000 == 0)
|
||||
{
|
||||
std::cout << "Processed " << num_processed << " positions.\n";
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Finished gathering statistics.\n\n";
|
||||
std::cout << "Results:\n\n";
|
||||
|
||||
for (auto&& [name, value] : statistic_gatherers.get_formatted_stats())
|
||||
{
|
||||
std::cout << name << ": " << value << '\n';
|
||||
}
|
||||
}
|
||||
|
||||
void gather_statistics(std::istringstream& is)
|
||||
{
|
||||
Eval::NNUE::init();
|
||||
|
||||
auto& registry = get_statistics_gatherers_registry();
|
||||
|
||||
StatisticGathererSet statistic_gatherers;
|
||||
|
||||
std::string input_file;
|
||||
std::uint64_t max_count = std::numeric_limits<std::uint64_t>::max();
|
||||
|
||||
while(true)
|
||||
{
|
||||
std::string token;
|
||||
is >> token;
|
||||
|
||||
if (token == "")
|
||||
break;
|
||||
|
||||
if (token == "input_file")
|
||||
is >> input_file;
|
||||
else if (token == "max_count")
|
||||
is >> max_count;
|
||||
else
|
||||
registry.add_statistic_gatherers_by_group(statistic_gatherers, token);
|
||||
}
|
||||
|
||||
do_gather_statistics(input_file, statistic_gatherers, max_count);
|
||||
}
|
||||
|
||||
}
|
||||
12
src/tools/stats.h
Normal file
12
src/tools/stats.h
Normal file
@@ -0,0 +1,12 @@
|
||||
#ifndef _STATS_H_
|
||||
#define _STATS_H_
|
||||
|
||||
#include <sstream>
|
||||
|
||||
namespace Tools::Stats {
|
||||
|
||||
void gather_statistics(std::istringstream& is);
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
514
src/tools/transform.cpp
Normal file
514
src/tools/transform.cpp
Normal file
@@ -0,0 +1,514 @@
|
||||
#include "transform.h"
|
||||
|
||||
#include "sfen_stream.h"
|
||||
#include "packed_sfen.h"
|
||||
#include "sfen_writer.h"
|
||||
|
||||
#include "thread.h"
|
||||
#include "position.h"
|
||||
#include "evaluate.h"
|
||||
#include "search.h"
|
||||
|
||||
#include "nnue/evaluate_nnue.h"
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <iostream>
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
|
||||
namespace Tools
|
||||
{
|
||||
using CommandFunc = void(*)(std::istringstream&);
|
||||
|
||||
enum struct NudgedStaticMode
|
||||
{
|
||||
Absolute,
|
||||
Relative,
|
||||
Interpolate
|
||||
};
|
||||
|
||||
struct NudgedStaticParams
|
||||
{
|
||||
std::string input_filename = "in.binpack";
|
||||
std::string output_filename = "out.binpack";
|
||||
NudgedStaticMode mode = NudgedStaticMode::Absolute;
|
||||
int absolute_nudge = 5;
|
||||
float relative_nudge = 0.1;
|
||||
float interpolate_nudge = 0.1;
|
||||
|
||||
void enforce_constraints()
|
||||
{
|
||||
relative_nudge = std::max(relative_nudge, 0.0f);
|
||||
absolute_nudge = std::max(absolute_nudge, 0);
|
||||
}
|
||||
};
|
||||
|
||||
struct RescoreParams
|
||||
{
|
||||
std::string input_filename = "in.epd";
|
||||
std::string output_filename = "out.binpack";
|
||||
int depth = 3;
|
||||
int research_count = 0;
|
||||
bool keep_moves = true;
|
||||
|
||||
void enforce_constraints()
|
||||
{
|
||||
depth = std::max(1, depth);
|
||||
research_count = std::max(0, research_count);
|
||||
}
|
||||
};
|
||||
|
||||
[[nodiscard]] std::int16_t nudge(NudgedStaticParams& params, std::int16_t static_eval_i16, std::int16_t deep_eval_i16)
|
||||
{
|
||||
auto saturate_i32_to_i16 = [](int v) {
|
||||
return static_cast<std::int16_t>(
|
||||
std::clamp(
|
||||
v,
|
||||
(int)std::numeric_limits<std::int16_t>::min(),
|
||||
(int)std::numeric_limits<std::int16_t>::max()
|
||||
)
|
||||
);
|
||||
};
|
||||
|
||||
auto saturate_f32_to_i16 = [saturate_i32_to_i16](float v) {
|
||||
return saturate_i32_to_i16((int)v);
|
||||
};
|
||||
|
||||
int static_eval = static_eval_i16;
|
||||
int deep_eval = deep_eval_i16;
|
||||
|
||||
switch(params.mode)
|
||||
{
|
||||
case NudgedStaticMode::Absolute:
|
||||
return saturate_i32_to_i16(
|
||||
static_eval + std::clamp(
|
||||
deep_eval - static_eval,
|
||||
-params.absolute_nudge,
|
||||
params.absolute_nudge
|
||||
)
|
||||
);
|
||||
|
||||
case NudgedStaticMode::Relative:
|
||||
return saturate_f32_to_i16(
|
||||
(float)static_eval * std::clamp(
|
||||
(float)deep_eval / (float)static_eval,
|
||||
(1.0f - params.relative_nudge),
|
||||
(1.0f + params.relative_nudge)
|
||||
)
|
||||
);
|
||||
|
||||
case NudgedStaticMode::Interpolate:
|
||||
return saturate_f32_to_i16(
|
||||
(float)static_eval * (1.0f - params.interpolate_nudge)
|
||||
+ (float)deep_eval * params.interpolate_nudge
|
||||
);
|
||||
|
||||
default:
|
||||
assert(false);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
void do_nudged_static(NudgedStaticParams& params)
|
||||
{
|
||||
Thread* th = Threads.main();
|
||||
Position& pos = th->rootPos;
|
||||
StateInfo si;
|
||||
|
||||
auto in = Tools::open_sfen_input_file(params.input_filename);
|
||||
auto out = Tools::create_new_sfen_output(params.output_filename);
|
||||
|
||||
if (in == nullptr)
|
||||
{
|
||||
std::cerr << "Invalid input file type.\n";
|
||||
return;
|
||||
}
|
||||
|
||||
if (out == nullptr)
|
||||
{
|
||||
std::cerr << "Invalid output file type.\n";
|
||||
return;
|
||||
}
|
||||
|
||||
PSVector buffer;
|
||||
uint64_t batch_size = 1'000'000;
|
||||
|
||||
buffer.reserve(batch_size);
|
||||
|
||||
uint64_t num_processed = 0;
|
||||
for (;;)
|
||||
{
|
||||
auto v = in->next();
|
||||
if (!v.has_value())
|
||||
break;
|
||||
|
||||
auto& ps = v.value();
|
||||
|
||||
pos.set_from_packed_sfen(ps.sfen, &si, th);
|
||||
auto static_eval = Eval::evaluate(pos);
|
||||
auto deep_eval = ps.score;
|
||||
ps.score = nudge(params, static_eval, deep_eval);
|
||||
|
||||
buffer.emplace_back(ps);
|
||||
if (buffer.size() >= batch_size)
|
||||
{
|
||||
num_processed += buffer.size();
|
||||
|
||||
out->write(buffer);
|
||||
buffer.clear();
|
||||
|
||||
std::cout << "Processed " << num_processed << " positions.\n";
|
||||
}
|
||||
}
|
||||
|
||||
if (!buffer.empty())
|
||||
{
|
||||
num_processed += buffer.size();
|
||||
|
||||
out->write(buffer);
|
||||
buffer.clear();
|
||||
|
||||
std::cout << "Processed " << num_processed << " positions.\n";
|
||||
}
|
||||
|
||||
std::cout << "Finished.\n";
|
||||
}
|
||||
|
||||
void nudged_static(std::istringstream& is)
|
||||
{
|
||||
NudgedStaticParams params{};
|
||||
|
||||
while(true)
|
||||
{
|
||||
std::string token;
|
||||
is >> token;
|
||||
|
||||
if (token == "")
|
||||
break;
|
||||
|
||||
if (token == "absolute")
|
||||
{
|
||||
params.mode = NudgedStaticMode::Absolute;
|
||||
is >> params.absolute_nudge;
|
||||
}
|
||||
else if (token == "relative")
|
||||
{
|
||||
params.mode = NudgedStaticMode::Relative;
|
||||
is >> params.relative_nudge;
|
||||
}
|
||||
else if (token == "interpolate")
|
||||
{
|
||||
params.mode = NudgedStaticMode::Interpolate;
|
||||
is >> params.interpolate_nudge;
|
||||
}
|
||||
else if (token == "input_file")
|
||||
is >> params.input_filename;
|
||||
else if (token == "output_file")
|
||||
is >> params.output_filename;
|
||||
}
|
||||
|
||||
std::cout << "Performing transform nudged_static with parameters:\n";
|
||||
std::cout << "input_file : " << params.input_filename << '\n';
|
||||
std::cout << "output_file : " << params.output_filename << '\n';
|
||||
std::cout << "\n";
|
||||
if (params.mode == NudgedStaticMode::Absolute)
|
||||
{
|
||||
std::cout << "mode : absolute\n";
|
||||
std::cout << "absolute_nudge : " << params.absolute_nudge << '\n';
|
||||
}
|
||||
else if (params.mode == NudgedStaticMode::Relative)
|
||||
{
|
||||
std::cout << "mode : relative\n";
|
||||
std::cout << "relative_nudge : " << params.relative_nudge << '\n';
|
||||
}
|
||||
else if (params.mode == NudgedStaticMode::Interpolate)
|
||||
{
|
||||
std::cout << "mode : interpolate\n";
|
||||
std::cout << "interpolate_nudge : " << params.interpolate_nudge << '\n';
|
||||
}
|
||||
std::cout << '\n';
|
||||
|
||||
params.enforce_constraints();
|
||||
do_nudged_static(params);
|
||||
}
|
||||
|
||||
void do_rescore_epd(RescoreParams& params)
|
||||
{
|
||||
std::ifstream fens_file(params.input_filename);
|
||||
|
||||
auto next_fen = [&fens_file, mutex = std::mutex{}]() mutable -> std::optional<std::string>{
|
||||
std::string fen;
|
||||
|
||||
std::unique_lock lock(mutex);
|
||||
|
||||
if (std::getline(fens_file, fen) && fen.size() >= 10)
|
||||
{
|
||||
return fen;
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::nullopt;
|
||||
}
|
||||
};
|
||||
|
||||
PSVector buffer;
|
||||
uint64_t batch_size = 10'000;
|
||||
|
||||
buffer.reserve(batch_size);
|
||||
|
||||
auto out = Tools::create_new_sfen_output(params.output_filename);
|
||||
|
||||
std::mutex mutex;
|
||||
uint64_t num_processed = 0;
|
||||
|
||||
// About Search::Limits
|
||||
// Be careful because this member variable is global and affects other threads.
|
||||
auto& limits = Search::Limits;
|
||||
|
||||
// Make the search equivalent to the "go infinite" command. (Because it is troublesome if time management is done)
|
||||
limits.infinite = true;
|
||||
|
||||
// Since PV is an obstacle when displayed, erase it.
|
||||
limits.silent = true;
|
||||
|
||||
// If you use this, it will be compared with the accumulated nodes of each thread. Therefore, do not use it.
|
||||
limits.nodes = 0;
|
||||
|
||||
// depth is also processed by the one passed as an argument of Tools::search().
|
||||
limits.depth = 0;
|
||||
|
||||
Threads.execute_with_workers([&](auto& th){
|
||||
Position& pos = th.rootPos;
|
||||
StateInfo si;
|
||||
|
||||
for(;;)
|
||||
{
|
||||
auto fen = next_fen();
|
||||
if (!fen.has_value())
|
||||
return;
|
||||
|
||||
pos.set(*fen, false, &si, &th);
|
||||
pos.state()->rule50 = 0;
|
||||
|
||||
|
||||
for (int cnt = 0; cnt < params.research_count; ++cnt)
|
||||
Search::search(pos, params.depth, 1);
|
||||
|
||||
auto [search_value, search_pv] = Search::search(pos, params.depth, 1);
|
||||
|
||||
if (search_pv.empty())
|
||||
continue;
|
||||
|
||||
PackedSfenValue ps;
|
||||
pos.sfen_pack(ps.sfen);
|
||||
ps.score = search_value;
|
||||
ps.move = search_pv[0];
|
||||
ps.gamePly = 1;
|
||||
ps.game_result = 0;
|
||||
ps.padding = 0;
|
||||
|
||||
std::unique_lock lock(mutex);
|
||||
buffer.emplace_back(ps);
|
||||
if (buffer.size() >= batch_size)
|
||||
{
|
||||
num_processed += buffer.size();
|
||||
|
||||
out->write(buffer);
|
||||
buffer.clear();
|
||||
|
||||
std::cout << "Processed " << num_processed << " positions.\n";
|
||||
}
|
||||
}
|
||||
});
|
||||
Threads.wait_for_workers_finished();
|
||||
|
||||
if (!buffer.empty())
|
||||
{
|
||||
num_processed += buffer.size();
|
||||
|
||||
out->write(buffer);
|
||||
buffer.clear();
|
||||
|
||||
std::cout << "Processed " << num_processed << " positions.\n";
|
||||
}
|
||||
|
||||
std::cout << "Finished.\n";
|
||||
}
|
||||
|
||||
void do_rescore_data(RescoreParams& params)
|
||||
{
|
||||
// TODO: Use SfenReader once it works correctly in sequential mode. See issue #271
|
||||
auto in = Tools::open_sfen_input_file(params.input_filename);
|
||||
auto readsome = [&in, mutex = std::mutex{}](int n) mutable -> PSVector {
|
||||
|
||||
PSVector psv;
|
||||
psv.reserve(n);
|
||||
|
||||
std::unique_lock lock(mutex);
|
||||
|
||||
for (int i = 0; i < n; ++i)
|
||||
{
|
||||
auto ps_opt = in->next();
|
||||
if (ps_opt.has_value())
|
||||
{
|
||||
psv.emplace_back(*ps_opt);
|
||||
}
|
||||
else
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return psv;
|
||||
};
|
||||
|
||||
auto sfen_format = ends_with(params.output_filename, ".binpack") ? SfenOutputType::Binpack : SfenOutputType::Bin;
|
||||
|
||||
auto out = SfenWriter(
|
||||
params.output_filename,
|
||||
Threads.size(),
|
||||
std::numeric_limits<std::uint64_t>::max(),
|
||||
sfen_format);
|
||||
|
||||
// About Search::Limits
|
||||
// Be careful because this member variable is global and affects other threads.
|
||||
auto& limits = Search::Limits;
|
||||
|
||||
// Make the search equivalent to the "go infinite" command. (Because it is troublesome if time management is done)
|
||||
limits.infinite = true;
|
||||
|
||||
// Since PV is an obstacle when displayed, erase it.
|
||||
limits.silent = true;
|
||||
|
||||
// If you use this, it will be compared with the accumulated nodes of each thread. Therefore, do not use it.
|
||||
limits.nodes = 0;
|
||||
|
||||
// depth is also processed by the one passed as an argument of Tools::search().
|
||||
limits.depth = 0;
|
||||
|
||||
std::atomic<std::uint64_t> num_processed = 0;
|
||||
|
||||
Threads.execute_with_workers([&](auto& th){
|
||||
Position& pos = th.rootPos;
|
||||
StateInfo si;
|
||||
|
||||
for (;;)
|
||||
{
|
||||
PSVector psv = readsome(5000);
|
||||
if (psv.empty())
|
||||
break;
|
||||
|
||||
for(auto& ps : psv)
|
||||
{
|
||||
pos.set_from_packed_sfen(ps.sfen, &si, &th);
|
||||
|
||||
for (int cnt = 0; cnt < params.research_count; ++cnt)
|
||||
Search::search(pos, params.depth, 1);
|
||||
|
||||
auto [search_value, search_pv] = Search::search(pos, params.depth, 1);
|
||||
|
||||
if (search_pv.empty())
|
||||
continue;
|
||||
|
||||
pos.sfen_pack(ps.sfen);
|
||||
ps.score = search_value;
|
||||
if (!params.keep_moves)
|
||||
ps.move = search_pv[0];
|
||||
ps.padding = 0;
|
||||
|
||||
out.write(th.thread_idx(), ps);
|
||||
|
||||
auto p = num_processed.fetch_add(1) + 1;
|
||||
if (p % 10000 == 0)
|
||||
{
|
||||
std::cout << "Processed " << p << " positions.\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
Threads.wait_for_workers_finished();
|
||||
|
||||
std::cout << "Finished.\n";
|
||||
}
|
||||
|
||||
void do_rescore(RescoreParams& params)
|
||||
{
|
||||
if (ends_with(params.input_filename, ".epd"))
|
||||
{
|
||||
do_rescore_epd(params);
|
||||
}
|
||||
else if (ends_with(params.input_filename, ".bin") || ends_with(params.input_filename, ".binpack"))
|
||||
{
|
||||
do_rescore_data(params);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Invalid input file type.\n";
|
||||
}
|
||||
}
|
||||
|
||||
void rescore(std::istringstream& is)
|
||||
{
|
||||
RescoreParams params{};
|
||||
|
||||
while(true)
|
||||
{
|
||||
std::string token;
|
||||
is >> token;
|
||||
|
||||
if (token == "")
|
||||
break;
|
||||
|
||||
if (token == "depth")
|
||||
is >> params.depth;
|
||||
else if (token == "input_file")
|
||||
is >> params.input_filename;
|
||||
else if (token == "output_file")
|
||||
is >> params.output_filename;
|
||||
else if (token == "keep_moves")
|
||||
is >> params.keep_moves;
|
||||
else if (token == "research_count")
|
||||
is >> params.research_count;
|
||||
}
|
||||
|
||||
params.enforce_constraints();
|
||||
|
||||
std::cout << "Performing transform rescore with parameters:\n";
|
||||
std::cout << "depth : " << params.depth << '\n';
|
||||
std::cout << "input_file : " << params.input_filename << '\n';
|
||||
std::cout << "output_file : " << params.output_filename << '\n';
|
||||
std::cout << "keep_moves : " << params.keep_moves << '\n';
|
||||
std::cout << "research_count : " << params.research_count << '\n';
|
||||
std::cout << '\n';
|
||||
|
||||
do_rescore(params);
|
||||
}
|
||||
|
||||
void transform(std::istringstream& is)
|
||||
{
|
||||
const std::map<std::string, CommandFunc> subcommands = {
|
||||
{ "nudged_static", &nudged_static },
|
||||
{ "rescore", &rescore }
|
||||
};
|
||||
|
||||
Eval::NNUE::init();
|
||||
|
||||
std::string subcommand;
|
||||
is >> subcommand;
|
||||
|
||||
auto func = subcommands.find(subcommand);
|
||||
if (func == subcommands.end())
|
||||
{
|
||||
std::cout << "Invalid subcommand " << subcommand << ". Exiting...\n";
|
||||
return;
|
||||
}
|
||||
|
||||
func->second(is);
|
||||
}
|
||||
|
||||
}
|
||||
12
src/tools/transform.h
Normal file
12
src/tools/transform.h
Normal file
@@ -0,0 +1,12 @@
|
||||
#ifndef _TRANSFORM_H_
|
||||
#define _TRANSFORM_H_
|
||||
|
||||
#include <sstream>
|
||||
|
||||
namespace Tools {
|
||||
|
||||
void transform(std::istringstream& is);
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user