diff --git a/src/thread.cpp b/src/thread.cpp index c81ac43d..e4226769 100644 --- a/src/thread.cpp +++ b/src/thread.cpp @@ -80,6 +80,13 @@ void Thread::start_searching() { cv.notify_one(); // Wake up the thread in idle_loop() } +void Thread::execute_task(std::function t) +{ + std::lock_guard lk(mutex); + task = std::move(t); + cv.notify_one(); // Wake up the thread in idle_loop() +} + /// Thread::wait_for_search_finished() blocks on the condition variable /// until the thread has finished searching. @@ -109,14 +116,22 @@ void Thread::idle_loop() { std::unique_lock lk(mutex); searching = false; cv.notify_one(); // Wake up anyone waiting for search finished - cv.wait(lk, [&]{ return searching; }); + cv.wait(lk, [&]{ return searching || task; }); if (exit) return; lk.unlock(); - search(); + if (task) + { + task(*this); + task = nullptr; + } + else + { + search(); + } } } @@ -162,6 +177,14 @@ void ThreadPool::clear() { } +void ThreadPool::execute_parallel(std::function task) +{ + for(Thread* th : *this) + { + th->execute_task(task); + } +} + /// ThreadPool::start_thinking() wakes up main thread waiting in idle_loop() and /// returns immediately. Main thread will wake up other threads and start the search. diff --git a/src/thread.h b/src/thread.h index 501a6042..8e9e6fba 100644 --- a/src/thread.h +++ b/src/thread.h @@ -24,6 +24,7 @@ #include #include #include +#include #include "material.h" #include "movepick.h" @@ -50,10 +51,12 @@ public: explicit Thread(size_t); virtual ~Thread(); virtual void search(); + virtual void execute_task(std::function t); void clear(); void idle_loop(); void start_searching(); void wait_for_search_finished(); + size_t thread_idx() const { return idx; } Pawns::Table pawnsTable; Material::Table materialTable; @@ -78,6 +81,7 @@ public: bool UseRule50; Depth ProbeDepth; + std::function task; }; @@ -105,6 +109,8 @@ struct MainThread : public Thread { struct ThreadPool : public std::vector { + void execute_parallel(std::function task); + void start_thinking(Position&, StateListPtr&, const Search::LimitsType&, bool = false); void clear(); void set(size_t); diff --git a/src/uci.cpp b/src/uci.cpp index b5a0524c..1aa9f95e 100644 --- a/src/uci.cpp +++ b/src/uci.cpp @@ -345,6 +345,12 @@ void UCI::loop(int argc, char* argv[]) { // Command to call qsearch(),search() directly for testing else if (token == "qsearch") qsearch_cmd(pos); else if (token == "search") search_cmd(pos, is); + else if (token == "tasktest") + { + Threads.execute_parallel([](auto& th) { + std::cout << th.thread_idx() << '\n'; + }); + } // test command else if (token == "test") test_cmd(pos, is);