mirror of
https://github.com/HChaZZY/Stockfish.git
synced 2025-12-24 19:16:49 +08:00
Pass ThreadPool to update_parameters, propagate, and backpropagate.
This commit is contained in:
@@ -18,6 +18,7 @@
|
||||
#include "uci.h"
|
||||
#include "misc.h"
|
||||
#include "thread_win32_osx.h"
|
||||
#include "thread.h"
|
||||
|
||||
// Code for learning NNUE evaluation function
|
||||
namespace Eval::NNUE {
|
||||
@@ -180,6 +181,7 @@ namespace Eval::NNUE {
|
||||
|
||||
// update the evaluation function parameters
|
||||
void update_parameters(
|
||||
ThreadPool& thread_pool,
|
||||
uint64_t epoch,
|
||||
bool verbose,
|
||||
double learning_rate,
|
||||
@@ -202,7 +204,7 @@ namespace Eval::NNUE {
|
||||
std::vector<Example> batch(examples.end() - batch_size, examples.end());
|
||||
examples.resize(examples.size() - batch_size);
|
||||
|
||||
const auto network_output = trainer->propagate(batch);
|
||||
const auto network_output = trainer->propagate(thread_pool, batch);
|
||||
|
||||
std::vector<LearnFloatType> gradients(batch.size());
|
||||
for (std::size_t b = 0; b < batch.size(); ++b) {
|
||||
@@ -226,7 +228,7 @@ namespace Eval::NNUE {
|
||||
}
|
||||
}
|
||||
|
||||
trainer->backpropagate(gradients.data(), learning_rate);
|
||||
trainer->backpropagate(thread_pool, gradients.data(), learning_rate);
|
||||
|
||||
collect_stats = false;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user