From 880d23af1c551e9122e95cd52c9aa155bfe11a38 Mon Sep 17 00:00:00 2001 From: Tomasz Sobczyk Date: Wed, 14 Oct 2020 19:44:15 +0200 Subject: [PATCH] Move sfen input/output streams to sfen_stream.h --- src/learn/gensfen.cpp | 100 +------------------ src/learn/learn.cpp | 112 +-------------------- src/learn/sfen_stream.h | 213 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 217 insertions(+), 208 deletions(-) create mode 100644 src/learn/sfen_stream.h diff --git a/src/learn/gensfen.cpp b/src/learn/gensfen.cpp index 7b135b81..4a6f26dc 100644 --- a/src/learn/gensfen.cpp +++ b/src/learn/gensfen.cpp @@ -2,6 +2,7 @@ #include "packed_sfen.h" #include "multi_think.h" +#include "sfen_stream.h" #include "../syzygy/tbprobe.h" #include "misc.h" @@ -38,107 +39,12 @@ using namespace std; namespace Learner { - enum struct SfenOutputType - { - Bin, - Binpack - }; - static bool write_out_draw_game_in_training_data_generation = true; static bool detect_draw_by_consecutive_low_score = true; static bool detect_draw_by_insufficient_mating_material = true; static SfenOutputType sfen_output_type = SfenOutputType::Bin; - static bool ends_with(const std::string& lhs, const std::string& end) - { - if (end.size() > lhs.size()) return false; - - return std::equal(end.rbegin(), end.rend(), lhs.rbegin()); - } - - static std::string filename_with_extension(const std::string& filename, const std::string& ext) - { - if (ends_with(filename, ext)) - { - return filename; - } - else - { - return filename + "." + ext; - } - } - - struct BasicSfenOutputStream - { - virtual void write(const PSVector& sfens) = 0; - virtual ~BasicSfenOutputStream() {} - }; - - struct BinSfenOutputStream : BasicSfenOutputStream - { - static constexpr auto openmode = ios::out | ios::binary | ios::app; - static inline const std::string extension = "bin"; - - BinSfenOutputStream(std::string filename) : - m_stream(filename_with_extension(filename, extension), openmode) - { - } - - void write(const PSVector& sfens) override - { - m_stream.write(reinterpret_cast(sfens.data()), sizeof(PackedSfenValue) * sfens.size()); - } - - ~BinSfenOutputStream() override {} - - private: - fstream m_stream; - }; - - struct BinpackSfenOutputStream : BasicSfenOutputStream - { - static constexpr auto openmode = ios::out | ios::binary | ios::app; - static inline const std::string extension = "binpack"; - - BinpackSfenOutputStream(std::string filename) : - m_stream(filename_with_extension(filename, extension), openmode) - { - } - - void write(const PSVector& sfens) override - { - static_assert(sizeof(binpack::nodchip::PackedSfenValue) == sizeof(PackedSfenValue)); - - for(auto& sfen : sfens) - { - // The library uses a type that's different but layout-compatibile. - binpack::nodchip::PackedSfenValue e; - std::memcpy(&e, &sfen, sizeof(binpack::nodchip::PackedSfenValue)); - m_stream.addTrainingDataEntry(binpack::packedSfenValueToTrainingDataEntry(e)); - } - } - - ~BinpackSfenOutputStream() override {} - - private: - binpack::CompressedTrainingDataEntryWriter m_stream; - }; - - static std::unique_ptr create_new_sfen_output(const std::string& filename) - { - switch(sfen_output_type) - { - case SfenOutputType::Bin: - return std::make_unique(filename); - case SfenOutputType::Binpack: - return std::make_unique(filename); - } - - assert(false); - return nullptr; - } - // Helper class for exporting Sfen struct SfenWriter { @@ -155,7 +61,7 @@ namespace Learner sfen_buffers_pool.reserve((size_t)thread_num * 10); sfen_buffers.resize(thread_num); - output_file_stream = create_new_sfen_output(filename_); + output_file_stream = create_new_sfen_output(filename_, sfen_output_type); filename = filename_; finished = false; @@ -283,7 +189,7 @@ namespace Learner // Add ios::app in consideration of overwriting. // (Depending on the operation, it may not be necessary.) string new_filename = filename + "_" + std::to_string(n); - output_file_stream = create_new_sfen_output(new_filename); + output_file_stream = create_new_sfen_output(new_filename, sfen_output_type); cout << endl << "output sfen file = " << new_filename << endl; } diff --git a/src/learn/learn.cpp b/src/learn/learn.cpp index 452bd15f..6c865d98 100644 --- a/src/learn/learn.cpp +++ b/src/learn/learn.cpp @@ -21,6 +21,7 @@ #include "convert.h" #include "multi_think.h" +#include "sfen_stream.h" #include "misc.h" #include "position.h" @@ -30,8 +31,6 @@ #include "search.h" #include "timeman.h" -#include "extra/nnue_data_binpack_format.h" - #include "nnue/evaluate_nnue.h" #include "nnue/evaluate_nnue_learner.h" @@ -286,115 +285,6 @@ namespace Learner return calc_grad((Value)psv.score, shallow, psv); } - struct BasicSfenInputStream - { - virtual std::optional next() = 0; - virtual bool eof() const = 0; - virtual ~BasicSfenInputStream() {} - }; - - struct BinSfenInputStream : BasicSfenInputStream - { - static constexpr auto openmode = ios::in | ios::binary; - static inline const std::string extension = "bin"; - - BinSfenInputStream(std::string filename) : - m_stream(filename, openmode), - m_eof(!m_stream) - { - } - - std::optional next() override - { - PackedSfenValue e; - if(m_stream.read(reinterpret_cast(&e), sizeof(PackedSfenValue))) - { - return e; - } - else - { - m_eof = true; - return std::nullopt; - } - } - - bool eof() const override - { - return m_eof; - } - - ~BinSfenInputStream() override {} - - private: - fstream m_stream; - bool m_eof; - }; - - struct BinpackSfenInputStream : BasicSfenInputStream - { - static constexpr auto openmode = ios::in | ios::binary; - static inline const std::string extension = "binpack"; - - BinpackSfenInputStream(std::string filename) : - m_stream(filename, openmode), - m_eof(!m_stream.hasNext()) - { - } - - std::optional next() override - { - static_assert(sizeof(binpack::nodchip::PackedSfenValue) == sizeof(PackedSfenValue)); - - if (!m_stream.hasNext()) - { - m_eof = true; - return std::nullopt; - } - - auto training_data_entry = m_stream.next(); - auto v = binpack::trainingDataEntryToPackedSfenValue(training_data_entry); - PackedSfenValue psv; - // same layout, different types. One is from generic library. - std::memcpy(&psv, &v, sizeof(PackedSfenValue)); - - return psv; - } - - bool eof() const override - { - return m_eof; - } - - ~BinpackSfenInputStream() override {} - - private: - binpack::CompressedTrainingDataEntryReader m_stream; - bool m_eof; - }; - - static bool ends_with(const std::string& lhs, const std::string& end) - { - if (end.size() > lhs.size()) return false; - - return std::equal(end.rbegin(), end.rend(), lhs.rbegin()); - } - - static bool has_extension(const std::string& filename, const std::string& extension) - { - return ends_with(filename, "." + extension); - } - - static std::unique_ptr open_sfen_input_file(const std::string& filename) - { - if (has_extension(filename, BinSfenInputStream::extension)) - return std::make_unique(filename); - else if (has_extension(filename, BinpackSfenInputStream::extension)) - return std::make_unique(filename); - - assert(false); - return nullptr; - } - // Sfen reader struct SfenReader { diff --git a/src/learn/sfen_stream.h b/src/learn/sfen_stream.h new file mode 100644 index 00000000..4d44901b --- /dev/null +++ b/src/learn/sfen_stream.h @@ -0,0 +1,213 @@ +#ifndef _SFEN_STREAM_H_ +#define _SFEN_STREAM_H_ + +#include "packed_sfen.h" + +#include "extra/nnue_data_binpack_format.h" + +#include +#include +#include +#include + +namespace Learner { + + enum struct SfenOutputType + { + Bin, + Binpack + }; + + static bool ends_with(const std::string& lhs, const std::string& end) + { + if (end.size() > lhs.size()) return false; + + return std::equal(end.rbegin(), end.rend(), lhs.rbegin()); + } + + static bool has_extension(const std::string& filename, const std::string& extension) + { + return ends_with(filename, "." + extension); + } + + static std::string filename_with_extension(const std::string& filename, const std::string& ext) + { + if (ends_with(filename, ext)) + { + return filename; + } + else + { + return filename + "." + ext; + } + } + + struct BasicSfenInputStream + { + virtual std::optional next() = 0; + virtual bool eof() const = 0; + virtual ~BasicSfenInputStream() {} + }; + + struct BinSfenInputStream : BasicSfenInputStream + { + static constexpr auto openmode = std::ios::in | std::ios::binary; + static inline const std::string extension = "bin"; + + BinSfenInputStream(std::string filename) : + m_stream(filename, openmode), + m_eof(!m_stream) + { + } + + std::optional next() override + { + PackedSfenValue e; + if(m_stream.read(reinterpret_cast(&e), sizeof(PackedSfenValue))) + { + return e; + } + else + { + m_eof = true; + return std::nullopt; + } + } + + bool eof() const override + { + return m_eof; + } + + ~BinSfenInputStream() override {} + + private: + std::fstream m_stream; + bool m_eof; + }; + + struct BinpackSfenInputStream : BasicSfenInputStream + { + static constexpr auto openmode = std::ios::in | std::ios::binary; + static inline const std::string extension = "binpack"; + + BinpackSfenInputStream(std::string filename) : + m_stream(filename, openmode), + m_eof(!m_stream.hasNext()) + { + } + + std::optional next() override + { + static_assert(sizeof(binpack::nodchip::PackedSfenValue) == sizeof(PackedSfenValue)); + + if (!m_stream.hasNext()) + { + m_eof = true; + return std::nullopt; + } + + auto training_data_entry = m_stream.next(); + auto v = binpack::trainingDataEntryToPackedSfenValue(training_data_entry); + PackedSfenValue psv; + // same layout, different types. One is from generic library. + std::memcpy(&psv, &v, sizeof(PackedSfenValue)); + + return psv; + } + + bool eof() const override + { + return m_eof; + } + + ~BinpackSfenInputStream() override {} + + private: + binpack::CompressedTrainingDataEntryReader m_stream; + bool m_eof; + }; + + struct BasicSfenOutputStream + { + virtual void write(const PSVector& sfens) = 0; + virtual ~BasicSfenOutputStream() {} + }; + + struct BinSfenOutputStream : BasicSfenOutputStream + { + static constexpr auto openmode = std::ios::out | std::ios::binary | std::ios::app; + static inline const std::string extension = "bin"; + + BinSfenOutputStream(std::string filename) : + m_stream(filename_with_extension(filename, extension), openmode) + { + } + + void write(const PSVector& sfens) override + { + m_stream.write(reinterpret_cast(sfens.data()), sizeof(PackedSfenValue) * sfens.size()); + } + + ~BinSfenOutputStream() override {} + + private: + std::fstream m_stream; + }; + + struct BinpackSfenOutputStream : BasicSfenOutputStream + { + static constexpr auto openmode = std::ios::out | std::ios::binary | std::ios::app; + static inline const std::string extension = "binpack"; + + BinpackSfenOutputStream(std::string filename) : + m_stream(filename_with_extension(filename, extension), openmode) + { + } + + void write(const PSVector& sfens) override + { + static_assert(sizeof(binpack::nodchip::PackedSfenValue) == sizeof(PackedSfenValue)); + + for(auto& sfen : sfens) + { + // The library uses a type that's different but layout-compatibile. + binpack::nodchip::PackedSfenValue e; + std::memcpy(&e, &sfen, sizeof(binpack::nodchip::PackedSfenValue)); + m_stream.addTrainingDataEntry(binpack::packedSfenValueToTrainingDataEntry(e)); + } + } + + ~BinpackSfenOutputStream() override {} + + private: + binpack::CompressedTrainingDataEntryWriter m_stream; + }; + + inline std::unique_ptr open_sfen_input_file(const std::string& filename) + { + if (has_extension(filename, BinSfenInputStream::extension)) + return std::make_unique(filename); + else if (has_extension(filename, BinpackSfenInputStream::extension)) + return std::make_unique(filename); + + assert(false); + return nullptr; + } + + inline std::unique_ptr create_new_sfen_output(const std::string& filename, SfenOutputType sfen_output_type) + { + switch(sfen_output_type) + { + case SfenOutputType::Bin: + return std::make_unique(filename); + case SfenOutputType::Binpack: + return std::make_unique(filename); + } + + assert(false); + return nullptr; + } +} + +#endif \ No newline at end of file