Merge remote-tracking branch 'remotes/origin/master' into trainer

This commit is contained in:
noobpwnftw
2020-09-23 18:43:54 +08:00
19 changed files with 304 additions and 107 deletions

View File

@@ -30,7 +30,7 @@
namespace Eval::NNUE {
uint32_t kpp_board_index[PIECE_NB][COLOR_NB] = {
const uint32_t kpp_board_index[PIECE_NB][COLOR_NB] = {
// convention: W - us, B - them
// viewed from other side, W and B are reversed
{ PS_NONE, PS_NONE },
@@ -52,7 +52,7 @@ namespace Eval::NNUE {
};
// Input feature converter
AlignedPtr<FeatureTransformer> feature_transformer;
LargePagePtr<FeatureTransformer> feature_transformer;
// Evaluation function
AlignedPtr<Network> network;
@@ -79,14 +79,22 @@ namespace Eval::NNUE {
std::memset(pointer.get(), 0, sizeof(T));
}
template <typename T>
void Initialize(LargePagePtr<T>& pointer) {
static_assert(alignof(T) <= 4096, "aligned_large_pages_alloc() may fail for such a big alignment requirement of T");
pointer.reset(reinterpret_cast<T*>(aligned_large_pages_alloc(sizeof(T))));
std::memset(pointer.get(), 0, sizeof(T));
}
// Read evaluation function parameters
template <typename T>
bool ReadParameters(std::istream& stream, const AlignedPtr<T>& pointer) {
bool ReadParameters(std::istream& stream, T& reference) {
std::uint32_t header;
header = read_little_endian<std::uint32_t>(stream);
if (!stream || header != T::GetHashValue()) return false;
return pointer->ReadParameters(stream);
return reference.ReadParameters(stream);
}
// write evaluation function parameters
@@ -97,6 +105,13 @@ namespace Eval::NNUE {
return pointer->WriteParameters(stream);
}
template <typename T>
bool WriteParameters(std::ostream& stream, const LargePagePtr<T>& pointer) {
constexpr std::uint32_t header = T::GetHashValue();
stream.write(reinterpret_cast<const char*>(&header), sizeof(header));
return pointer->WriteParameters(stream);
}
} // namespace Detail
// Initialize the evaluation function parameters
@@ -138,8 +153,8 @@ namespace Eval::NNUE {
std::string architecture;
if (!ReadHeader(stream, &hash_value, &architecture)) return false;
if (hash_value != kHashValue) return false;
if (!Detail::ReadParameters(stream, feature_transformer)) return false;
if (!Detail::ReadParameters(stream, network)) return false;
if (!Detail::ReadParameters(stream, *feature_transformer)) return false;
if (!Detail::ReadParameters(stream, *network)) return false;
return stream && stream.peek() == std::ios::traits_type::eof();
}
// write evaluation function parameters
@@ -162,7 +177,7 @@ namespace Eval::NNUE {
}
// Load eval, from a file stream or a memory stream
bool load_eval(std::string streamName, std::istream& stream) {
bool load_eval(std::string name, std::istream& stream) {
Initialize();
@@ -171,7 +186,7 @@ namespace Eval::NNUE {
std::cout << "info string SkipLoadingEval set to true, Net not loaded!" << std::endl;
return true;
}
fileName = streamName;
fileName = name;
return ReadParameters(stream);
}

View File

@@ -40,11 +40,22 @@ namespace Eval::NNUE {
}
};
template <typename T>
struct LargePageDeleter {
void operator()(T* ptr) const {
ptr->~T();
aligned_large_pages_free(ptr);
}
};
template <typename T>
using AlignedPtr = std::unique_ptr<T, AlignedDeleter<T>>;
template <typename T>
using LargePagePtr = std::unique_ptr<T, LargePageDeleter<T>>;
// Input feature converter
extern AlignedPtr<FeatureTransformer> feature_transformer;
extern LargePagePtr<FeatureTransformer> feature_transformer;
// Evaluation function
extern AlignedPtr<Network> network;

View File

@@ -113,7 +113,7 @@ namespace Eval::NNUE {
PS_END2 = 12 * SQUARE_NB + 1
};
extern uint32_t kpp_board_index[PIECE_NB][COLOR_NB];
extern const uint32_t kpp_board_index[PIECE_NB][COLOR_NB];
// Type of input feature after conversion
using TransformedFeatureType = std::uint8_t;