mirror of
https://github.com/HChaZZY/Stockfish.git
synced 2025-12-24 19:16:49 +08:00
Add transform command. Add transform nudged_static subcommand.
This commit is contained in:
@@ -64,7 +64,8 @@ SRCS = benchmark.cpp bitbase.cpp bitboard.cpp endgame.cpp evaluate.cpp main.cpp
|
||||
learn/learn.cpp \
|
||||
learn/gensfen.cpp \
|
||||
learn/opening_book.cpp \
|
||||
learn/convert.cpp
|
||||
learn/convert.cpp \
|
||||
learn/transform.cpp
|
||||
|
||||
OBJS = $(notdir $(SRCS:.cpp=.o))
|
||||
|
||||
|
||||
@@ -207,6 +207,16 @@ namespace Learner {
|
||||
assert(false);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
inline std::unique_ptr<BasicSfenOutputStream> create_new_sfen_output(const std::string& filename)
|
||||
{
|
||||
if (has_extension(filename, BinSfenOutputStream::extension))
|
||||
return std::make_unique<BinSfenOutputStream>(filename);
|
||||
else if (has_extension(filename, BinpackSfenOutputStream::extension))
|
||||
return std::make_unique<BinpackSfenOutputStream>(filename);
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
242
src/learn/transform.cpp
Normal file
242
src/learn/transform.cpp
Normal file
@@ -0,0 +1,242 @@
|
||||
#include "transform.h"
|
||||
|
||||
#include "sfen_stream.h"
|
||||
#include "packed_sfen.h"
|
||||
|
||||
#include "thread.h"
|
||||
#include "position.h"
|
||||
#include "evaluate.h"
|
||||
|
||||
#include "nnue/evaluate_nnue.h"
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <iostream>
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
|
||||
namespace Learner
|
||||
{
|
||||
using CommandFunc = void(*)(std::istringstream&);
|
||||
|
||||
enum struct NudgedStaticMode
|
||||
{
|
||||
Absolute,
|
||||
Relative,
|
||||
Interpolate
|
||||
};
|
||||
|
||||
struct NudgedStaticParams
|
||||
{
|
||||
std::string input_filename = "in.binpack";
|
||||
std::string output_filename = "out.binpack";
|
||||
NudgedStaticMode mode = NudgedStaticMode::Absolute;
|
||||
int absolute_nudge = 5;
|
||||
float relative_nudge = 0.1;
|
||||
float interpolate_nudge = 0.1;
|
||||
|
||||
void enforce_constraints()
|
||||
{
|
||||
relative_nudge = std::max(relative_nudge, 0.0f);
|
||||
absolute_nudge = std::max(absolute_nudge, 0);
|
||||
}
|
||||
};
|
||||
|
||||
[[nodiscard]] std::int16_t nudge(NudgedStaticParams& params, std::int16_t static_eval_i16, std::int16_t deep_eval_i16)
|
||||
{
|
||||
auto saturate_i32_to_i16 = [](int v) {
|
||||
return static_cast<std::int16_t>(
|
||||
std::clamp(
|
||||
v,
|
||||
(int)std::numeric_limits<std::int16_t>::min(),
|
||||
(int)std::numeric_limits<std::int16_t>::max()
|
||||
)
|
||||
);
|
||||
};
|
||||
|
||||
auto saturate_f32_to_i16 = [saturate_i32_to_i16](float v) {
|
||||
return saturate_i32_to_i16((int)v);
|
||||
};
|
||||
|
||||
int static_eval = static_eval_i16;
|
||||
int deep_eval = deep_eval_i16;
|
||||
|
||||
switch(params.mode)
|
||||
{
|
||||
case NudgedStaticMode::Absolute:
|
||||
return saturate_i32_to_i16(
|
||||
static_eval + std::clamp(
|
||||
deep_eval - static_eval,
|
||||
-params.absolute_nudge,
|
||||
params.absolute_nudge
|
||||
)
|
||||
);
|
||||
|
||||
case NudgedStaticMode::Relative:
|
||||
return saturate_f32_to_i16(
|
||||
(float)static_eval * std::clamp(
|
||||
(float)deep_eval / (float)static_eval,
|
||||
(1.0f - params.relative_nudge),
|
||||
(1.0f + params.relative_nudge)
|
||||
)
|
||||
);
|
||||
|
||||
case NudgedStaticMode::Interpolate:
|
||||
return saturate_f32_to_i16(
|
||||
(float)static_eval * (1.0f - params.interpolate_nudge)
|
||||
+ (float)deep_eval * params.interpolate_nudge
|
||||
);
|
||||
|
||||
default:
|
||||
assert(false);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
void do_nudged_static(NudgedStaticParams& params)
|
||||
{
|
||||
Thread* th = Threads.main();
|
||||
Position& pos = th->rootPos;
|
||||
StateInfo si;
|
||||
|
||||
auto in = Learner::open_sfen_input_file(params.input_filename);
|
||||
auto out = Learner::create_new_sfen_output(params.output_filename);
|
||||
|
||||
if (in == nullptr)
|
||||
{
|
||||
std::cerr << "Invalid input file type.\n";
|
||||
return;
|
||||
}
|
||||
|
||||
if (out == nullptr)
|
||||
{
|
||||
std::cerr << "Invalid output file type.\n";
|
||||
return;
|
||||
}
|
||||
|
||||
PSVector buffer;
|
||||
uint64_t batch_size = 1'000'000;
|
||||
|
||||
buffer.reserve(batch_size);
|
||||
|
||||
uint64_t num_processed = 0;
|
||||
for (;;)
|
||||
{
|
||||
auto v = in->next();
|
||||
if (!v.has_value())
|
||||
break;
|
||||
|
||||
auto& ps = v.value();
|
||||
|
||||
pos.set_from_packed_sfen(ps.sfen, &si, th);
|
||||
auto static_eval = Eval::evaluate(pos);
|
||||
auto deep_eval = ps.score;
|
||||
ps.score = nudge(params, static_eval, deep_eval);
|
||||
|
||||
buffer.emplace_back(ps);
|
||||
if (buffer.size() >= batch_size)
|
||||
{
|
||||
num_processed += buffer.size();
|
||||
|
||||
out->write(buffer);
|
||||
buffer.clear();
|
||||
|
||||
std::cout << "Processed " << num_processed << " positions.\n";
|
||||
}
|
||||
}
|
||||
|
||||
if (!buffer.empty())
|
||||
{
|
||||
num_processed += buffer.size();
|
||||
|
||||
out->write(buffer);
|
||||
buffer.clear();
|
||||
|
||||
std::cout << "Processed " << num_processed << " positions.\n";
|
||||
}
|
||||
|
||||
std::cout << "Finished.\n";
|
||||
}
|
||||
|
||||
void nudged_static(std::istringstream& is)
|
||||
{
|
||||
NudgedStaticParams params{};
|
||||
|
||||
while(true)
|
||||
{
|
||||
std::string token;
|
||||
is >> token;
|
||||
|
||||
if (token == "")
|
||||
break;
|
||||
|
||||
if (token == "absolute")
|
||||
{
|
||||
params.mode = NudgedStaticMode::Absolute;
|
||||
is >> params.absolute_nudge;
|
||||
}
|
||||
else if (token == "relative")
|
||||
{
|
||||
params.mode = NudgedStaticMode::Relative;
|
||||
is >> params.relative_nudge;
|
||||
}
|
||||
else if (token == "interpolate")
|
||||
{
|
||||
params.mode = NudgedStaticMode::Interpolate;
|
||||
is >> params.interpolate_nudge;
|
||||
}
|
||||
else if (token == "input_file")
|
||||
is >> params.input_filename;
|
||||
else if (token == "output_file")
|
||||
is >> params.output_filename;
|
||||
}
|
||||
|
||||
std::cout << "Performing transform nudged_static with parameters:\n";
|
||||
std::cout << "input_file : " << params.input_filename << '\n';
|
||||
std::cout << "output_file : " << params.output_filename << '\n';
|
||||
std::cout << "\n";
|
||||
if (params.mode == NudgedStaticMode::Absolute)
|
||||
{
|
||||
std::cout << "mode : absolute\n";
|
||||
std::cout << "absolute_nudge : " << params.absolute_nudge << '\n';
|
||||
}
|
||||
else if (params.mode == NudgedStaticMode::Relative)
|
||||
{
|
||||
std::cout << "mode : relative\n";
|
||||
std::cout << "relative_nudge : " << params.relative_nudge << '\n';
|
||||
}
|
||||
else if (params.mode == NudgedStaticMode::Interpolate)
|
||||
{
|
||||
std::cout << "mode : interpolate\n";
|
||||
std::cout << "interpolate_nudge : " << params.interpolate_nudge << '\n';
|
||||
}
|
||||
std::cout << '\n';
|
||||
|
||||
params.enforce_constraints();
|
||||
do_nudged_static(params);
|
||||
}
|
||||
|
||||
void transform(std::istringstream& is)
|
||||
{
|
||||
const std::map<std::string, CommandFunc> subcommands = {
|
||||
{ "nudged_static", &nudged_static }
|
||||
};
|
||||
|
||||
Eval::NNUE::init();
|
||||
|
||||
std::string subcommand;
|
||||
is >> subcommand;
|
||||
|
||||
auto func = subcommands.find(subcommand);
|
||||
if (func == subcommands.end())
|
||||
{
|
||||
std::cout << "Invalid subcommand " << subcommand << ". Exiting...\n";
|
||||
return;
|
||||
}
|
||||
|
||||
func->second(is);
|
||||
}
|
||||
|
||||
}
|
||||
12
src/learn/transform.h
Normal file
12
src/learn/transform.h
Normal file
@@ -0,0 +1,12 @@
|
||||
#ifndef _TRANSFORM_H_
|
||||
#define _TRANSFORM_H_
|
||||
|
||||
#include <sstream>
|
||||
|
||||
namespace Learner {
|
||||
|
||||
void transform(std::istringstream& is);
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -38,6 +38,7 @@
|
||||
#include "learn/gensfen.h"
|
||||
#include "learn/learn.h"
|
||||
#include "learn/convert.h"
|
||||
#include "learn/transform.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
@@ -345,6 +346,7 @@ void UCI::loop(int argc, char* argv[]) {
|
||||
else if (token == "convert_bin") Learner::convert_bin(is);
|
||||
else if (token == "convert_plain") Learner::convert_plain(is);
|
||||
else if (token == "convert_bin_from_pgn_extract") Learner::convert_bin_from_pgn_extract(is);
|
||||
else if (token == "transform") Learner::transform(is);
|
||||
|
||||
// Command to call qsearch(),search() directly for testing
|
||||
else if (token == "qsearch") qsearch_cmd(pos);
|
||||
|
||||
Reference in New Issue
Block a user