mirror of
https://github.com/HChaZZY/Stockfish.git
synced 2025-12-25 19:46:55 +08:00
Merge remote-tracking branch 'remotes/nodchip/master' into trainer
This commit is contained in:
@@ -1,17 +1,23 @@
|
||||
#if defined(EVAL_LEARN)
|
||||
#include "gensfen.h"
|
||||
|
||||
#include "../eval/evaluate_common.h"
|
||||
#include "../misc.h"
|
||||
#include "../nnue/evaluate_nnue_learner.h"
|
||||
#include "../position.h"
|
||||
#include "../syzygy/tbprobe.h"
|
||||
#include "../thread.h"
|
||||
#include "../tt.h"
|
||||
#include "../uci.h"
|
||||
#include "learn.h"
|
||||
#include "packed_sfen.h"
|
||||
#include "multi_think.h"
|
||||
#include "../syzygy/tbprobe.h"
|
||||
|
||||
#include "misc.h"
|
||||
#include "position.h"
|
||||
#include "thread.h"
|
||||
#include "tt.h"
|
||||
#include "uci.h"
|
||||
|
||||
#include "eval/evaluate_common.h"
|
||||
|
||||
#include "extra/nnue_data_binpack_format.h"
|
||||
|
||||
#include "nnue/evaluate_nnue_learner.h"
|
||||
|
||||
#include "syzygy/tbprobe.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
@@ -33,11 +39,107 @@ using namespace std;
|
||||
|
||||
namespace Learner
|
||||
{
|
||||
// On-disk encodings that the sfen writer can produce.
// The active value is consulted by create_new_sfen_output().
enum class SfenOutputType
{
    Bin,     // handled by BinSfenOutputStream
    Binpack  // handled by BinpackSfenOutputStream
};
|
||||
|
||||
// NOTE(review): no use of this flag is visible in this excerpt; judging by the
// name it presumably controls whether drawn games are written out - confirm.
static bool write_out_draw_game_in_training_data_generation = false;
|
||||
// Read from the command stream while the gensfen options are parsed.
static bool detect_draw_by_consecutive_low_score = false;
|
||||
// Read from the command stream when the option token
// "detect_draw_by_insufficient_mating_material" is encountered.
static bool detect_draw_by_insufficient_mating_material = false;
|
||||

||||
// Filled during gensfen setup: either with lines read from a file,
// or with StartFEN alone in the fallback branch.
static std::vector<std::string> bookStart;
|
||||
// Output format used by create_new_sfen_output(). Defaults to Bin and is
// switched by the "sfen_format" option ("bin" or "binpack").
static SfenOutputType sfen_output_type = SfenOutputType::Bin;
|
||||
|
||||
// Returns true when `lhs` finishes with the character sequence `end`.
// An empty `end` matches every string.
static bool ends_with(const std::string& lhs, const std::string& end)
{
    if (end.size() > lhs.size()) return false;

    // Walk both strings back-to-front so that only the trailing
    // end.size() characters of `lhs` are inspected.
    return std::equal(end.rbegin(), end.rend(), lhs.rbegin());
}

// Returns `filename` guaranteed to carry the extension `ext`.
//
// filename : path or base name chosen by the user.
// ext      : extension WITHOUT the leading dot (e.g. "bin", "binpack").
//
// The name is returned unchanged when it already ends with "." + ext;
// otherwise "." + ext is appended. The dot is part of the comparison on
// purpose: matching on `ext` alone would treat names such as "robin" or
// "cabin" as if they already had a ".bin" extension and leave them without
// one. An empty `ext` leaves the name untouched.
static std::string filename_with_extension(const std::string& filename, const std::string& ext)
{
    if (ext.empty())
        return filename;

    const std::string dotted_ext = "." + ext;

    if (ends_with(filename, dotted_ext))
        return filename;

    return filename + dotted_ext;
}
|
||||
|
||||
struct BasicSfenOutputStream
|
||||
{
|
||||
virtual void write(const PSVector& sfens) = 0;
|
||||
virtual ~BasicSfenOutputStream() {}
|
||||
};
|
||||
|
||||
struct BinSfenOutputStream : BasicSfenOutputStream
|
||||
{
|
||||
static constexpr auto openmode = ios::out | ios::binary | ios::app;
|
||||
static inline const std::string extension = "bin";
|
||||
|
||||
BinSfenOutputStream(std::string filename) :
|
||||
m_stream(filename_with_extension(filename, extension), openmode)
|
||||
{
|
||||
}
|
||||
|
||||
void write(const PSVector& sfens) override
|
||||
{
|
||||
m_stream.write(reinterpret_cast<const char*>(sfens.data()), sizeof(PackedSfenValue) * sfens.size());
|
||||
}
|
||||
|
||||
~BinSfenOutputStream() override {}
|
||||
|
||||
private:
|
||||
fstream m_stream;
|
||||
};
|
||||
|
||||
struct BinpackSfenOutputStream : BasicSfenOutputStream
|
||||
{
|
||||
static constexpr auto openmode = ios::out | ios::binary | ios::app;
|
||||
static inline const std::string extension = "binpack";
|
||||
|
||||
BinpackSfenOutputStream(std::string filename) :
|
||||
m_stream(filename_with_extension(filename, extension), openmode)
|
||||
{
|
||||
}
|
||||
|
||||
void write(const PSVector& sfens) override
|
||||
{
|
||||
static_assert(sizeof(binpack::nodchip::PackedSfenValue) == sizeof(PackedSfenValue));
|
||||
|
||||
for(auto& sfen : sfens)
|
||||
{
|
||||
// The library uses a type that's different but layout-compatibile.
|
||||
binpack::nodchip::PackedSfenValue e;
|
||||
std::memcpy(&e, &sfen, sizeof(binpack::nodchip::PackedSfenValue));
|
||||
m_stream.addTrainingDataEntry(binpack::packedSfenValueToTrainingDataEntry(e));
|
||||
}
|
||||
}
|
||||
|
||||
~BinpackSfenOutputStream() override {}
|
||||
|
||||
private:
|
||||
binpack::CompressedTrainingDataEntryWriter m_stream;
|
||||
};
|
||||
|
||||
static std::unique_ptr<BasicSfenOutputStream> create_new_sfen_output(const std::string& filename)
|
||||
{
|
||||
switch(sfen_output_type)
|
||||
{
|
||||
case SfenOutputType::Bin:
|
||||
return std::make_unique<BinSfenOutputStream>(filename);
|
||||
case SfenOutputType::Binpack:
|
||||
return std::make_unique<BinpackSfenOutputStream>(filename);
|
||||
}
|
||||
|
||||
assert(false);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Helper class for exporting Sfen
|
||||
struct SfenWriter
|
||||
@@ -55,7 +157,7 @@ namespace Learner
|
||||
sfen_buffers_pool.reserve((size_t)thread_num * 10);
|
||||
sfen_buffers.resize(thread_num);
|
||||
|
||||
output_file_stream.open(filename_, ios::out | ios::binary | ios::app);
|
||||
output_file_stream = create_new_sfen_output(filename_);
|
||||
filename = filename_;
|
||||
|
||||
finished = false;
|
||||
@@ -65,7 +167,7 @@ namespace Learner
|
||||
{
|
||||
finished = true;
|
||||
file_worker_thread.join();
|
||||
output_file_stream.close();
|
||||
output_file_stream.reset();
|
||||
|
||||
#if defined(_DEBUG)
|
||||
{
|
||||
@@ -134,9 +236,6 @@ namespace Learner
|
||||
{
|
||||
// Also output the current time to console.
|
||||
sync_cout << endl << sfen_write_count << " sfens , at " << now_string() << sync_endl;
|
||||
|
||||
// This is enough for flush().
|
||||
output_file_stream.flush();
|
||||
};
|
||||
|
||||
while (!finished || sfen_buffers_pool.size())
|
||||
@@ -160,7 +259,7 @@ namespace Learner
|
||||
{
|
||||
for (auto& buf : buffers)
|
||||
{
|
||||
output_file_stream.write(reinterpret_cast<const char*>(buf->data()), sizeof(PackedSfenValue) * buf->size());
|
||||
output_file_stream->write(*buf);
|
||||
|
||||
sfen_write_count += buf->size();
|
||||
|
||||
@@ -171,8 +270,6 @@ namespace Learner
|
||||
{
|
||||
sfen_write_count_current_file = 0;
|
||||
|
||||
output_file_stream.close();
|
||||
|
||||
// Sequential number attached to the file
|
||||
int n = (int)(sfen_write_count / save_every);
|
||||
|
||||
@@ -180,7 +277,7 @@ namespace Learner
|
||||
// Add ios::app in consideration of overwriting.
|
||||
// (Depending on the operation, it may not be necessary.)
|
||||
string new_filename = filename + "_" + std::to_string(n);
|
||||
output_file_stream.open(new_filename, ios::out | ios::binary | ios::app);
|
||||
output_file_stream = create_new_sfen_output(new_filename);
|
||||
cout << endl << "output sfen file = " << new_filename << endl;
|
||||
}
|
||||
|
||||
@@ -214,7 +311,7 @@ namespace Learner
|
||||
|
||||
private:
|
||||
|
||||
fstream output_file_stream;
|
||||
std::unique_ptr<BasicSfenOutputStream> output_file_stream;
|
||||
|
||||
// A new net is saved after every save_every sfens are processed.
|
||||
uint64_t save_every = std::numeric_limits<uint64_t>::max();
|
||||
@@ -260,7 +357,8 @@ namespace Learner
|
||||
// It must be 2**N because it will be used as the mask to calculate hash_index.
|
||||
static_assert((GENSFEN_HASH_SIZE& (GENSFEN_HASH_SIZE - 1)) == 0);
|
||||
|
||||
MultiThinkGenSfen(int search_depth_min_, int search_depth_max_, SfenWriter& sw_) :
|
||||
MultiThinkGenSfen(int search_depth_min_, int search_depth_max_, SfenWriter& sw_, const std::string& seed) :
|
||||
MultiThink(seed),
|
||||
search_depth_min(search_depth_min_),
|
||||
search_depth_max(search_depth_max_),
|
||||
sfen_writer(sw_)
|
||||
@@ -759,20 +857,6 @@ namespace Learner
|
||||
break;
|
||||
}
|
||||
|
||||
if (pos.count<ALL_PIECES>() <= 6) {
|
||||
Tablebases::ProbeState probe_state;
|
||||
Tablebases::WDLScore wdl = Tablebases::probe_wdl(pos, &probe_state);
|
||||
assert(wdl != Tablebases::WDLScore::WDLScoreNone);
|
||||
if (wdl == Tablebases::WDLScore::WDLWin) {
|
||||
flush_psv(1);
|
||||
} else if (wdl == Tablebases::WDLScore::WDLLoss) {
|
||||
flush_psv(-1);
|
||||
} else {
|
||||
flush_psv(0);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
{
|
||||
auto [search_value, search_pv] = search(pos, depth, 1, nodes);
|
||||
|
||||
@@ -819,6 +903,25 @@ namespace Learner
|
||||
goto SKIP_SAVE;
|
||||
}
|
||||
|
||||
// Look into the position hashtable to see if the same
|
||||
// position was seen before.
|
||||
// This is a good heuristic to exclude already seen
|
||||
// positions without many false positives.
|
||||
{
|
||||
auto key = pos.key();
|
||||
auto hash_index = (size_t)(key & (GENSFEN_HASH_SIZE - 1));
|
||||
auto old_key = hash[hash_index];
|
||||
if (key == old_key)
|
||||
{
|
||||
goto SKIP_SAVE;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Replace with the current key.
|
||||
hash[hash_index] = key;
|
||||
}
|
||||
}
|
||||
|
||||
// Pack the current position into a packed sfen and save it into the buffer.
|
||||
{
|
||||
a_psv.emplace_back(PackedSfenValue());
|
||||
@@ -916,7 +1019,7 @@ namespace Learner
|
||||
int write_maxply = 400;
|
||||
|
||||
// File name to write
|
||||
string output_file_name = "generated_kifu.bin";
|
||||
string output_file_name = "generated_kifu";
|
||||
|
||||
string token;
|
||||
|
||||
@@ -927,6 +1030,9 @@ namespace Learner
|
||||
// Add a random number to the end of the file name.
|
||||
bool random_file_name = false;
|
||||
|
||||
std::string sfen_format;
|
||||
std::string seed;
|
||||
|
||||
while (true)
|
||||
{
|
||||
token = "";
|
||||
@@ -980,10 +1086,26 @@ namespace Learner
|
||||
is >> detect_draw_by_consecutive_low_score;
|
||||
else if (token == "detect_draw_by_insufficient_mating_material")
|
||||
is >> detect_draw_by_insufficient_mating_material;
|
||||
else if (token == "sfen_format")
|
||||
is >> sfen_format;
|
||||
else if (token == "seed")
|
||||
is >> seed;
|
||||
else
|
||||
cout << "Error! : Illegal token " << token << endl;
|
||||
}
|
||||
|
||||
if (!sfen_format.empty())
|
||||
{
|
||||
if (sfen_format == "bin")
|
||||
sfen_output_type = SfenOutputType::Bin;
|
||||
else if (sfen_format == "binpack")
|
||||
sfen_output_type = SfenOutputType::Binpack;
|
||||
else
|
||||
{
|
||||
cout << "Unknown sfen format `" << sfen_format << "`. Using bin\n";
|
||||
}
|
||||
}
|
||||
|
||||
// If search depth2 is not set, leave it the same as search depth.
|
||||
if (search_depth_max == INT_MIN)
|
||||
search_depth_max = search_depth_min;
|
||||
@@ -994,7 +1116,7 @@ namespace Learner
|
||||
{
|
||||
// Give a random number to output_file_name at this point.
|
||||
// Do not use std::random_device(), because it always returns the same integers on MinGW.
|
||||
PRNG r(std::chrono::system_clock::now().time_since_epoch().count());
|
||||
PRNG r(seed);
|
||||
// Just in case, reassign the random numbers.
|
||||
for (int i = 0; i < 10; ++i)
|
||||
r.rand(1);
|
||||
@@ -1018,6 +1140,8 @@ namespace Learner
|
||||
bookStart.push_back(line);
|
||||
}
|
||||
myfile.close();
|
||||
} else {
|
||||
bookStart.push_back(StartFEN);
|
||||
}
|
||||
}
|
||||
std::cout << "gensfen : " << endl
|
||||
@@ -1048,12 +1172,30 @@ namespace Learner
|
||||
|
||||
Threads.main()->ponder = false;
|
||||
|
||||
// About Search::Limits
|
||||
// Be careful because this member variable is global and affects other threads.
|
||||
{
|
||||
auto& limits = Search::Limits;
|
||||
|
||||
// Make the search equivalent to the "go infinite" command. (Because it is troublesome if time management is done)
|
||||
limits.infinite = true;
|
||||
|
||||
// Since PV is an obstacle when displayed, erase it.
|
||||
limits.silent = true;
|
||||
|
||||
// If you use this, it will be compared with the accumulated nodes of each thread. Therefore, do not use it.
|
||||
limits.nodes = 0;
|
||||
|
||||
// depth is also processed by the one passed as an argument of Learner::search().
|
||||
limits.depth = 0;
|
||||
}
|
||||
|
||||
// Create and execute threads as many as Options["Threads"].
|
||||
{
|
||||
SfenWriter sfen_writer(output_file_name, thread_num);
|
||||
sfen_writer.set_save_interval(save_every);
|
||||
|
||||
MultiThinkGenSfen multi_think(search_depth_min, search_depth_max, sfen_writer);
|
||||
MultiThinkGenSfen multi_think(search_depth_min, search_depth_max, sfen_writer, seed);
|
||||
multi_think.nodes = nodes;
|
||||
multi_think.set_loop_max(loop_max);
|
||||
multi_think.eval_limit = eval_limit;
|
||||
@@ -1074,7 +1216,5 @@ namespace Learner
|
||||
}
|
||||
|
||||
std::cout << "gensfen finished." << endl;
|
||||
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user