Pass ThreadPool to update_parameters, propagate, and backpropagate.

This commit is contained in:
Tomasz Sobczyk
2020-10-26 15:06:15 +01:00
committed by nodchip
parent f1e96cab55
commit ee0917a345
8 changed files with 53 additions and 28 deletions

View File

@@ -18,6 +18,7 @@
#include "uci.h"
#include "misc.h"
#include "thread_win32_osx.h"
#include "thread.h"
// Code for learning NNUE evaluation function
namespace Eval::NNUE {
@@ -180,6 +181,7 @@ namespace Eval::NNUE {
// update the evaluation function parameters
void update_parameters(
ThreadPool& thread_pool,
uint64_t epoch,
bool verbose,
double learning_rate,
@@ -202,7 +204,7 @@ namespace Eval::NNUE {
std::vector<Example> batch(examples.end() - batch_size, examples.end());
examples.resize(examples.size() - batch_size);
const auto network_output = trainer->propagate(batch);
const auto network_output = trainer->propagate(thread_pool, batch);
std::vector<LearnFloatType> gradients(batch.size());
for (std::size_t b = 0; b < batch.size(); ++b) {
@@ -226,7 +228,7 @@ namespace Eval::NNUE {
}
}
trainer->backpropagate(gradients.data(), learning_rate);
trainer->backpropagate(thread_pool, gradients.data(), learning_rate);
collect_stats = false;
}