Don't unnecessarily copy the batch part.

This commit is contained in:
Tomasz Sobczyk
2020-11-25 22:59:34 +01:00
committed by nodchip
parent e954b14196
commit 0bee8fef64
5 changed files with 61 additions and 43 deletions

View File

@@ -215,27 +215,28 @@ namespace Eval::NNUE {
 std::vector<double> gradient_norm_local(thread_pool.size(), 0.0);
 while (examples.size() >= batch_size) {
-std::vector<Example> batch(examples.end() - batch_size, examples.end());
-examples.resize(examples.size() - batch_size);
-const auto network_output = trainer->step_start(thread_pool, batch);
-std::vector<LearnFloatType> gradients(batch.size());
+auto batch_begin = examples.end() - batch_size;
+auto batch_end = examples.end();
+auto size = batch_end - batch_begin;
+const auto network_output = trainer->step_start(thread_pool, batch_begin, batch_end);
+std::vector<LearnFloatType> gradients(size);
 thread_pool.for_each_index_chunk_with_workers(
-std::size_t(0), batch.size(),
+std::size_t(0), size,
 [&](Thread& th, std::size_t offset, std::size_t count) {
 const auto thread_id = th.thread_idx();
 trainer->propagate(th, offset, count);
 for (std::size_t b = offset; b < offset + count; ++b) {
+const auto& e = *(batch_begin + b);
 const auto shallow = static_cast<Value>(round<std::int32_t>(
-batch[b].sign * network_output[b] * kPonanzaConstant));
-const auto discrete = batch[b].sign * batch[b].discrete_nn_eval;
-const auto& psv = batch[b].psv;
+e.sign * network_output[b] * kPonanzaConstant));
+const auto discrete = e.sign * e.discrete_nn_eval;
+const auto& psv = e.psv;
 const double gradient =
-batch[b].sign * calc_grad(shallow, (Value)psv.score, psv.game_result, psv.gamePly);
-gradients[b] = static_cast<LearnFloatType>(gradient * batch[b].weight);
+e.sign * calc_grad(shallow, (Value)psv.score, psv.game_result, psv.gamePly);
+gradients[b] = static_cast<LearnFloatType>(gradient * e.weight);
 // The discrete eval will only be valid before first backpropagation,
@@ -256,6 +257,8 @@ namespace Eval::NNUE {
 trainer->step_end(thread_pool, learning_rate);
+examples.resize(examples.size() - size);
 collect_stats = false;
 }