mirror of
https://github.com/HChaZZY/Stockfish.git
synced 2025-12-24 19:16:49 +08:00
Parallelize input slice trainer backprop.
This commit is contained in:
@@ -236,17 +236,29 @@ namespace Eval::NNUE {
|
||||
const LearnFloatType* gradients,
|
||||
LearnFloatType learning_rate) {
|
||||
|
||||
for (IndexType b = 0; b < batch_size_; ++b) {
|
||||
const IndexType input_offset = kInputDimensions * b;
|
||||
const IndexType output_offset = kOutputDimensions * b;
|
||||
for (IndexType i = 0; i < kInputDimensions; ++i) {
|
||||
if ((int)i < (int)Offset || i >= Offset + kOutputDimensions) {
|
||||
thread_pool.for_each_index_with_workers(
|
||||
0, batch_size_,
|
||||
[&](Thread&, int b) {
|
||||
const IndexType input_offset = kInputDimensions * b;
|
||||
const IndexType output_offset = kOutputDimensions * b;
|
||||
|
||||
IndexType i = 0;
|
||||
for (; i < Offset; ++i) {
|
||||
gradients_[input_offset + i] = static_cast<LearnFloatType>(0.0);
|
||||
} else {
|
||||
}
|
||||
|
||||
for (; i < Offset + kOutputDimensions; ++i) {
|
||||
gradients_[input_offset + i] = gradients[output_offset + i - Offset];
|
||||
}
|
||||
|
||||
for (; i < kInputDimensions; ++i)
|
||||
{
|
||||
gradients_[input_offset + i] = static_cast<LearnFloatType>(0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
);
|
||||
thread_pool.wait_for_workers_finished();
|
||||
|
||||
shared_input_trainer_->backpropagate(thread_pool, gradients_.data(), learning_rate);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user