Merged the training data generator and the machine learning logic from YaneuraOu.

This commit is contained in:
Hisayori Noda
2019-06-18 08:48:05 +09:00
parent 87445881ec
commit bcd6985871
37 changed files with 6306 additions and 139 deletions

View File

@@ -42,6 +42,7 @@ typedef bool(*fun3_t)(HANDLE, CONST GROUP_AFFINITY*, PGROUP_AFFINITY);
#endif
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <sstream>
@@ -316,6 +317,27 @@ void bindThisThread(size_t idx) {
} // namespace WinProcGroup
// 現在時刻を文字列化したもを返す。(評価関数の学習時などに用いる)
std::string now_string()
{
// std::ctime(), localtime()を使うと、MSVCでセキュアでないという警告が出る。
// C++標準的にはそんなことないはずなのだが…。
#if defined(_MSC_VER)
// C4996 : 'ctime' : This function or variable may be unsafe.Consider using ctime_s instead.
#pragma warning(disable : 4996)
#endif
auto now = std::chrono::system_clock::now();
auto tp = std::chrono::system_clock::to_time_t(now);
auto result = string(std::ctime(&tp));
// 末尾に改行コードが含まれているならこれを除去する
while (*result.rbegin() == '\n' || (*result.rbegin() == '\r'))
result.pop_back();
return result;
}
void sleep(int ms)
{
std::this_thread::sleep_for(std::chrono::milliseconds(ms));
@@ -331,3 +353,127 @@ void* aligned_malloc(size_t size, size_t align)
}
return p;
}
int read_file_to_memory(std::string filename, std::function<void* (uint64_t)> callback_func)
{
fstream fs(filename, ios::in | ios::binary);
if (fs.fail())
return 1;
fs.seekg(0, fstream::end);
uint64_t eofPos = (uint64_t)fs.tellg();
fs.clear(); // これをしないと次のseekに失敗することがある。
fs.seekg(0, fstream::beg);
uint64_t begPos = (uint64_t)fs.tellg();
uint64_t file_size = eofPos - begPos;
//std::cout << "filename = " << filename << " , file_size = " << file_size << endl;
// ファイルサイズがわかったのでcallback_funcを呼び出してこの分のバッファを確保してもらい、
// そのポインターをもらう。
void* ptr = callback_func(file_size);
// バッファが確保できなかった場合や、想定していたファイルサイズと異なった場合は、
// nullptrを返すことになっている。このとき、読み込みを中断し、エラーリターンする。
if (ptr == nullptr)
return 2;
// 細切れに読み込む
const uint64_t block_size = 1024 * 1024 * 1024; // 1回のreadで読み込む要素の数(1GB)
for (uint64_t pos = 0; pos < file_size; pos += block_size)
{
// 今回読み込むサイズ
uint64_t read_size = (pos + block_size < file_size) ? block_size : (file_size - pos);
fs.read((char*)ptr + pos, read_size);
// ファイルの途中で読み込みエラーに至った。
if (fs.fail())
return 2;
//cout << ".";
}
fs.close();
return 0;
}
int write_memory_to_file(std::string filename, void* ptr, uint64_t size)
{
fstream fs(filename, ios::out | ios::binary);
if (fs.fail())
return 1;
const uint64_t block_size = 1024 * 1024 * 1024; // 1回のwriteで書き出す要素の数(1GB)
for (uint64_t pos = 0; pos < size; pos += block_size)
{
// 今回書き出すメモリサイズ
uint64_t write_size = (pos + block_size < size) ? block_size : (size - pos);
fs.write((char*)ptr + pos, write_size);
//cout << ".";
}
fs.close();
return 0;
}
// ----------------------------
// mkdir wrapper
// ----------------------------
// カレントフォルダ相対で指定する。成功すれば0、失敗すれば非0が返る。
// フォルダを作成する。日本語は使っていないものとする。
// どうもmsys2環境下のgccだと_wmkdir()だとフォルダの作成に失敗する。原因不明。
// 仕方ないので_mkdir()を用いる。
#if defined(_WIN32)
// Windows用
#if defined(_MSC_VER)
#include <codecvt> // mkdirするのにwstringが欲しいのでこれが必要
#include <locale> // wstring_convertにこれが必要。
namespace Dependency {
int mkdir(std::string dir_name)
{
std::wstring_convert<std::codecvt_utf8<wchar_t>, wchar_t> cv;
return _wmkdir(cv.from_bytes(dir_name).c_str());
// ::CreateDirectory(cv.from_bytes(dir_name).c_str(),NULL);
}
}
#elif defined(__GNUC__)
#include <direct.h>
namespace Dependency {
int mkdir(std::string dir_name)
{
return _mkdir(dir_name.c_str());
}
}
#endif
#elif defined(_LINUX)
// linux環境において、この_LINUXというシンボルはmakefileにて定義されるものとする。
// Linux用のmkdir実装。
#include "sys/stat.h"
namespace Dependency {
int mkdir(std::string dir_name)
{
return ::mkdir(dir_name.c_str(), 0777);
}
}
#else
// Linux環境かどうかを判定するためにはmakefileを分けないといけなくなってくるな..
// linuxでフォルダ掘る機能は、とりあえずナシでいいや..。評価関数ファイルの保存にしか使ってないし…。
namespace Dependency {
int mkdir(std::string dir_name)
{
return 0;
}
}
#endif