diff --git a/src/Makefile b/src/Makefile index c12a3eb6..c718ba6d 100644 --- a/src/Makefile +++ b/src/Makefile @@ -548,16 +548,16 @@ icc-profile-use: all nnue: config-sanity - $(MAKE) CXXFLAGS='$(CXXFLAGS) -DEVAL_NNUE -DUSE_EVAL_HASH -DUSE_AVX2 -DUSE_SSE2 -fopenmp' LDFLAGS='$(LDFLAGS) -fopenmp' build + $(MAKE) CXXFLAGS='$(CXXFLAGS) -DEVAL_NNUE -DUSE_EVAL_HASH -DUSE_AVX2 -DUSE_SSE2 -DENABLE_TEST_CMD -fopenmp' LDFLAGS='$(LDFLAGS) -fopenmp' build nnue-gen-sfen-from-original-eval: config-sanity - $(MAKE) CXXFLAGS='$(CXXFLAGS) -DEVAL_LEARN -DUSE_EVAL_HASH -DUSE_AVX2 -DUSE_SSE2 -fopenmp' LDFLAGS='$(LDFLAGS) -fopenmp' build + $(MAKE) CXXFLAGS='$(CXXFLAGS) -DEVAL_LEARN -DUSE_EVAL_HASH -DUSE_AVX2 -DUSE_SSE2 -DENABLE_TEST_CMD -fopenmp' LDFLAGS='$(LDFLAGS) -fopenmp' build nnue-learn: config-sanity - $(MAKE) CXXFLAGS='$(CXXFLAGS) -DEVAL_LEARN -DEVAL_NNUE -DUSE_EVAL_HASH -DUSE_AVX2 -DUSE_SSE2 -fopenmp' LDFLAGS='$(LDFLAGS) -fopenmp' build + $(MAKE) CXXFLAGS='$(CXXFLAGS) -DEVAL_LEARN -DEVAL_NNUE -DUSE_EVAL_HASH -DUSE_AVX2 -DUSE_SSE2 -DENABLE_TEST_CMD -fopenmp' LDFLAGS='$(LDFLAGS) -fopenmp' build nnue-learn-use-blas: config-sanity - $(MAKE) CXXFLAGS='$(CXXFLAGS) -DEVAL_LEARN -DEVAL_NNUE -DUSE_EVAL_HASH -DUSE_AVX2 -DUSE_SSE2 -DUSE_BLAS -I/mingw64/include/OpenBLAS -fopenmp' LDFLAGS='$(LDFLAGS) -lopenblas -fopenmp' build + $(MAKE) CXXFLAGS='$(CXXFLAGS) -DEVAL_LEARN -DEVAL_NNUE -DUSE_EVAL_HASH -DUSE_AVX2 -DUSE_SSE2 -DENABLE_TEST_CMD -DUSE_BLAS -I/mingw64/include/OpenBLAS -fopenmp' LDFLAGS='$(LDFLAGS) -lopenblas -fopenmp' build .depend: -@$(CXX) $(DEPENDFLAGS) -MM $(OBJS:.o=.cpp) > $@ 2> /dev/null diff --git a/src/eval/nnue/nnue_test_command.cpp b/src/eval/nnue/nnue_test_command.cpp index a2618b3b..28e44273 100644 --- a/src/eval/nnue/nnue_test_command.cpp +++ b/src/eval/nnue/nnue_test_command.cpp @@ -2,11 +2,16 @@ #if defined(ENABLE_TEST_CMD) && defined(EVAL_NNUE) -#include "../../extra/all.h" +#include "../../thread.h" +#include "../../uci.h" #include "evaluate_nnue.h" #include "nnue_test_command.h" #include +#include + +#define ASSERT(X) { if (!(X)) { std::cout << "\nError : ASSERT(" << #X << "), " << __FILE__ << "(" << __LINE__ << "): " << __func__ << std::endl; \ + std::this_thread::sleep_for(std::chrono::microseconds(3000)); *(int*)1 =0;} } namespace Eval { @@ -18,7 +23,7 @@ namespace { void TestFeatures(Position& pos) { const std::uint64_t num_games = 1000; StateInfo si; - pos.set_hirate(&si,Threads.main()); + pos.set(StartFEN, false, &si, Threads.main()); const int MAX_PLY = 256; // 256手までテスト StateInfo state[MAX_PLY]; // StateInfoを最大手数分だけ @@ -38,7 +43,7 @@ void TestFeatures(Position& pos) { Features::IndexList active_indices[2]; RawFeatures::AppendActiveIndices(pos, kRefreshTriggers[i], active_indices); - for (const auto perspective : COLOR) { + for (const auto perspective : Colors) { for (const auto index : active_indices[perspective]) { ASSERT(index < RawFeatures::kDimensions); ASSERT(index_sets[i][perspective].count(index) == 0); @@ -56,7 +61,7 @@ void TestFeatures(Position& pos) { bool reset[2]; RawFeatures::AppendChangedIndices(pos, kRefreshTriggers[i], removed_indices, added_indices, reset); - for (const auto perspective : COLOR) { + for (const auto perspective : Colors) { if (reset[perspective]) { (*index_sets)[i][perspective].clear(); ++num_resets[i]; @@ -91,7 +96,7 @@ void TestFeatures(Position& pos) { for (std::uint64_t i = 0; i < num_games; ++i) { auto index_sets = make_index_sets(pos); for (ply = 0; ply < MAX_PLY; ++ply) { - MoveList mg(pos); // 全合法手の生成 + MoveList mg(pos); // 全合法手の生成 // 合法な指し手がなかった == 詰み if (mg.size() == 0) @@ -106,7 +111,7 @@ void TestFeatures(Position& pos) { ASSERT(index_sets == make_index_sets(pos)); } - pos.set_hirate(&si,Threads.main()); + pos.set(StartFEN, false, &si, Threads.main()); // 100回に1回ごとに'.'を出力(進んでいることがわかるように) if ((i % 100) == 0) @@ -184,8 +189,8 @@ void TestCommand(Position& pos, std::istream& stream) { PrintInfo(stream); } else { std::cout << "usage:" << std::endl; - std::cout << " test nn test_features" << std::endl; - std::cout << " test nn info [path/to/" << kFileName << "...]" << std::endl; + std::cout << " test nnue test_features" << std::endl; + std::cout << " test nnue info [path/to/" << kFileName << "...]" << std::endl; } } diff --git a/src/eval/nnue/nnue_test_command.h b/src/eval/nnue/nnue_test_command.h index bf5894c9..10f57f6c 100644 --- a/src/eval/nnue/nnue_test_command.h +++ b/src/eval/nnue/nnue_test_command.h @@ -3,8 +3,6 @@ #ifndef _NNUE_TEST_COMMAND_H_ #define _NNUE_TEST_COMMAND_H_ -#include "../../config.h" - #if defined(ENABLE_TEST_CMD) && defined(EVAL_NNUE) namespace Eval { diff --git a/src/uci.cpp b/src/uci.cpp index 19af09a0..d4178879 100644 --- a/src/uci.cpp +++ b/src/uci.cpp @@ -33,6 +33,10 @@ #include "uci.h" #include "syzygy/tbprobe.h" +#if defined(EVAL_NNUE) && defined(ENABLE_TEST_CMD) +#include "eval/nnue/nnue_test_command.h" +#endif + using namespace std; extern vector setup_bench(const Position&, istream&); @@ -64,6 +68,19 @@ namespace Learner } #endif +#if defined(EVAL_NNUE) && defined(ENABLE_TEST_CMD) +void test_cmd(Position& pos, istringstream& is) +{ + // T邩mȂ̂ŏĂB + is_ready(); + + std::string param; + is >> param; + + if (param == "nnue") Eval::NNUE::TestCommand(pos, is); +} +#endif + namespace { // position() is called when engine receives the "position" UCI command. // The function sets up the position described in the given FEN string ("fen") @@ -376,6 +393,12 @@ void UCI::loop(int argc, char* argv[]) { else if (token == "search") search_cmd(pos, is); #endif + +#if defined(EVAL_NNUE) && defined(ENABLE_TEST_CMD) + // eXgR}h + else if (token == "test") test_cmd(pos, is); +#endif + else sync_cout << "Unknown command: " << cmd << sync_endl;