diff --git a/src/learn/transform.cpp b/src/learn/transform.cpp index 5687b48b..3e302e21 100644 --- a/src/learn/transform.cpp +++ b/src/learn/transform.cpp @@ -6,6 +6,7 @@ #include "thread.h" #include "position.h" #include "evaluate.h" +#include "search.h" #include "nnue/evaluate_nnue.h" @@ -16,6 +17,8 @@ #include #include #include +#include +#include namespace Learner { @@ -44,6 +47,18 @@ namespace Learner } }; + struct RescoreFenParams + { + std::string input_filename = "in.epd"; + std::string output_filename = "out.binpack"; + int depth = 3; + + void enforce_constraints() + { + depth = std::max(1, depth); + } + }; + [[nodiscard]] std::int16_t nudge(NudgedStaticParams& params, std::int16_t static_eval_i16, std::int16_t deep_eval_i16) { auto saturate_i32_to_i16 = [](int v) { @@ -218,10 +233,141 @@ namespace Learner do_nudged_static(params); } + void do_rescore_fen(RescoreFenParams& params) + { + std::ifstream fens_file(params.input_filename); + + auto next_fen = [&fens_file]() -> std::optional{ + static std::mutex mutex; + + std::string fen; + + std::unique_lock lock(mutex); + + if (std::getline(fens_file, fen) && fen.size() >= 10) + { + return fen; + } + else + { + return std::nullopt; + } + }; + + PSVector buffer; + uint64_t batch_size = 10'000; + + buffer.reserve(batch_size); + + auto out = Learner::create_new_sfen_output(params.output_filename); + + std::mutex mutex; + uint64_t num_processed = 0; + + // About Search::Limits + // Be careful because this member variable is global and affects other threads. + auto& limits = Search::Limits; + + // Make the search equivalent to the "go infinite" command. (Because it is troublesome if time management is done) + limits.infinite = true; + + // Since PV is an obstacle when displayed, erase it. + limits.silent = true; + + // If you use this, it will be compared with the accumulated nodes of each thread. Therefore, do not use it. + limits.nodes = 0; + + // depth is also processed by the one passed as an argument of Learner::search(). + limits.depth = 0; + + Threads.execute_with_workers([&](auto& th){ + Position& pos = th.rootPos; + StateInfo si; + + for(;;) + { + auto fen = next_fen(); + if (!fen.has_value()) + return; + + pos.set(*fen, false, &si, &th); + pos.state()->rule50 = 0; + + auto [search_value, search_pv] = Search::search(pos, params.depth, 1); + if (search_pv.empty()) + continue; + + PackedSfenValue ps; + pos.sfen_pack(ps.sfen); + ps.score = search_value; + ps.move = search_pv[0]; + ps.gamePly = 1; + ps.game_result = 0; + ps.padding = 0; + + std::unique_lock lock(mutex); + buffer.emplace_back(ps); + if (buffer.size() >= batch_size) + { + num_processed += buffer.size(); + + out->write(buffer); + buffer.clear(); + + std::cout << "Processed " << num_processed << " positions.\n"; + } + } + }); + Threads.wait_for_workers_finished(); + + if (!buffer.empty()) + { + num_processed += buffer.size(); + + out->write(buffer); + buffer.clear(); + + std::cout << "Processed " << num_processed << " positions.\n"; + } + + std::cout << "Finished.\n"; + } + + void rescore_fen(std::istringstream& is) + { + RescoreFenParams params{}; + + while(true) + { + std::string token; + is >> token; + + if (token == "") + break; + + if (token == "depth") + is >> params.depth; + else if (token == "input_file") + is >> params.input_filename; + else if (token == "output_file") + is >> params.output_filename; + } + + std::cout << "Performing transform rescore_fen with parameters:\n"; + std::cout << "depth : " << params.depth << '\n'; + std::cout << "input_file : " << params.input_filename << '\n'; + std::cout << "output_file : " << params.output_filename << '\n'; + std::cout << '\n'; + + params.enforce_constraints(); + do_rescore_fen(params); + } + void transform(std::istringstream& is) { const std::map subcommands = { - { "nudged_static", &nudged_static } + { "nudged_static", &nudged_static }, + { "rescore_fen", &rescore_fen } }; Eval::NNUE::init();