From 89294e2e4f44fdf3b4c3e38609c6c9b4c2a3c982 Mon Sep 17 00:00:00 2001 From: Tomasz Sobczyk Date: Thu, 26 Nov 2020 17:28:09 +0100 Subject: [PATCH] Add transform command. Add transform nudged_static subcommand. --- src/Makefile | 3 +- src/learn/sfen_stream.h | 10 ++ src/learn/transform.cpp | 242 ++++++++++++++++++++++++++++++++++++++++ src/learn/transform.h | 12 ++ src/uci.cpp | 2 + 5 files changed, 268 insertions(+), 1 deletion(-) create mode 100644 src/learn/transform.cpp create mode 100644 src/learn/transform.h diff --git a/src/Makefile b/src/Makefile index a5f5f06f..7f00bfff 100644 --- a/src/Makefile +++ b/src/Makefile @@ -64,7 +64,8 @@ SRCS = benchmark.cpp bitbase.cpp bitboard.cpp endgame.cpp evaluate.cpp main.cpp learn/learn.cpp \ learn/gensfen.cpp \ learn/opening_book.cpp \ - learn/convert.cpp + learn/convert.cpp \ + learn/transform.cpp OBJS = $(notdir $(SRCS:.cpp=.o)) diff --git a/src/learn/sfen_stream.h b/src/learn/sfen_stream.h index d25dd41d..da411346 100644 --- a/src/learn/sfen_stream.h +++ b/src/learn/sfen_stream.h @@ -207,6 +207,16 @@ namespace Learner { assert(false); return nullptr; } + + inline std::unique_ptr create_new_sfen_output(const std::string& filename) + { + if (has_extension(filename, BinSfenOutputStream::extension)) + return std::make_unique(filename); + else if (has_extension(filename, BinpackSfenOutputStream::extension)) + return std::make_unique(filename); + + return nullptr; + } } #endif \ No newline at end of file diff --git a/src/learn/transform.cpp b/src/learn/transform.cpp new file mode 100644 index 00000000..5687b48b --- /dev/null +++ b/src/learn/transform.cpp @@ -0,0 +1,242 @@ +#include "transform.h" + +#include "sfen_stream.h" +#include "packed_sfen.h" + +#include "thread.h" +#include "position.h" +#include "evaluate.h" + +#include "nnue/evaluate_nnue.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace Learner +{ + using CommandFunc = void(*)(std::istringstream&); + + enum struct NudgedStaticMode + { + Absolute, + Relative, + Interpolate + }; + + struct NudgedStaticParams + { + std::string input_filename = "in.binpack"; + std::string output_filename = "out.binpack"; + NudgedStaticMode mode = NudgedStaticMode::Absolute; + int absolute_nudge = 5; + float relative_nudge = 0.1; + float interpolate_nudge = 0.1; + + void enforce_constraints() + { + relative_nudge = std::max(relative_nudge, 0.0f); + absolute_nudge = std::max(absolute_nudge, 0); + } + }; + + [[nodiscard]] std::int16_t nudge(NudgedStaticParams& params, std::int16_t static_eval_i16, std::int16_t deep_eval_i16) + { + auto saturate_i32_to_i16 = [](int v) { + return static_cast( + std::clamp( + v, + (int)std::numeric_limits::min(), + (int)std::numeric_limits::max() + ) + ); + }; + + auto saturate_f32_to_i16 = [saturate_i32_to_i16](float v) { + return saturate_i32_to_i16((int)v); + }; + + int static_eval = static_eval_i16; + int deep_eval = deep_eval_i16; + + switch(params.mode) + { + case NudgedStaticMode::Absolute: + return saturate_i32_to_i16( + static_eval + std::clamp( + deep_eval - static_eval, + -params.absolute_nudge, + params.absolute_nudge + ) + ); + + case NudgedStaticMode::Relative: + return saturate_f32_to_i16( + (float)static_eval * std::clamp( + (float)deep_eval / (float)static_eval, + (1.0f - params.relative_nudge), + (1.0f + params.relative_nudge) + ) + ); + + case NudgedStaticMode::Interpolate: + return saturate_f32_to_i16( + (float)static_eval * (1.0f - params.interpolate_nudge) + + (float)deep_eval * params.interpolate_nudge + ); + + default: + assert(false); + return 0; + } + } + + void do_nudged_static(NudgedStaticParams& params) + { + Thread* th = Threads.main(); + Position& pos = th->rootPos; + StateInfo si; + + auto in = Learner::open_sfen_input_file(params.input_filename); + auto out = Learner::create_new_sfen_output(params.output_filename); + + if (in == nullptr) + { + std::cerr << "Invalid input file type.\n"; + return; + } + + if (out == nullptr) + { + std::cerr << "Invalid output file type.\n"; + return; + } + + PSVector buffer; + uint64_t batch_size = 1'000'000; + + buffer.reserve(batch_size); + + uint64_t num_processed = 0; + for (;;) + { + auto v = in->next(); + if (!v.has_value()) + break; + + auto& ps = v.value(); + + pos.set_from_packed_sfen(ps.sfen, &si, th); + auto static_eval = Eval::evaluate(pos); + auto deep_eval = ps.score; + ps.score = nudge(params, static_eval, deep_eval); + + buffer.emplace_back(ps); + if (buffer.size() >= batch_size) + { + num_processed += buffer.size(); + + out->write(buffer); + buffer.clear(); + + std::cout << "Processed " << num_processed << " positions.\n"; + } + } + + if (!buffer.empty()) + { + num_processed += buffer.size(); + + out->write(buffer); + buffer.clear(); + + std::cout << "Processed " << num_processed << " positions.\n"; + } + + std::cout << "Finished.\n"; + } + + void nudged_static(std::istringstream& is) + { + NudgedStaticParams params{}; + + while(true) + { + std::string token; + is >> token; + + if (token == "") + break; + + if (token == "absolute") + { + params.mode = NudgedStaticMode::Absolute; + is >> params.absolute_nudge; + } + else if (token == "relative") + { + params.mode = NudgedStaticMode::Relative; + is >> params.relative_nudge; + } + else if (token == "interpolate") + { + params.mode = NudgedStaticMode::Interpolate; + is >> params.interpolate_nudge; + } + else if (token == "input_file") + is >> params.input_filename; + else if (token == "output_file") + is >> params.output_filename; + } + + std::cout << "Performing transform nudged_static with parameters:\n"; + std::cout << "input_file : " << params.input_filename << '\n'; + std::cout << "output_file : " << params.output_filename << '\n'; + std::cout << "\n"; + if (params.mode == NudgedStaticMode::Absolute) + { + std::cout << "mode : absolute\n"; + std::cout << "absolute_nudge : " << params.absolute_nudge << '\n'; + } + else if (params.mode == NudgedStaticMode::Relative) + { + std::cout << "mode : relative\n"; + std::cout << "relative_nudge : " << params.relative_nudge << '\n'; + } + else if (params.mode == NudgedStaticMode::Interpolate) + { + std::cout << "mode : interpolate\n"; + std::cout << "interpolate_nudge : " << params.interpolate_nudge << '\n'; + } + std::cout << '\n'; + + params.enforce_constraints(); + do_nudged_static(params); + } + + void transform(std::istringstream& is) + { + const std::map subcommands = { + { "nudged_static", &nudged_static } + }; + + Eval::NNUE::init(); + + std::string subcommand; + is >> subcommand; + + auto func = subcommands.find(subcommand); + if (func == subcommands.end()) + { + std::cout << "Invalid subcommand " << subcommand << ". Exiting...\n"; + return; + } + + func->second(is); + } + +} diff --git a/src/learn/transform.h b/src/learn/transform.h new file mode 100644 index 00000000..8a6921a0 --- /dev/null +++ b/src/learn/transform.h @@ -0,0 +1,12 @@ +#ifndef _TRANSFORM_H_ +#define _TRANSFORM_H_ + +#include + +namespace Learner { + + void transform(std::istringstream& is); + +} + +#endif diff --git a/src/uci.cpp b/src/uci.cpp index ae21a3ae..8e64da6b 100644 --- a/src/uci.cpp +++ b/src/uci.cpp @@ -38,6 +38,7 @@ #include "learn/gensfen.h" #include "learn/learn.h" #include "learn/convert.h" +#include "learn/transform.h" using namespace std; @@ -345,6 +346,7 @@ void UCI::loop(int argc, char* argv[]) { else if (token == "convert_bin") Learner::convert_bin(is); else if (token == "convert_plain") Learner::convert_plain(is); else if (token == "convert_bin_from_pgn_extract") Learner::convert_bin_from_pgn_extract(is); + else if (token == "transform") Learner::transform(is); // Command to call qsearch(),search() directly for testing else if (token == "qsearch") qsearch_cmd(pos);