diff --git a/src/nnue/trainer/trainer_input_slice.h b/src/nnue/trainer/trainer_input_slice.h
index a94cae93..62a761a7 100644
--- a/src/nnue/trainer/trainer_input_slice.h
+++ b/src/nnue/trainer/trainer_input_slice.h
@@ -15,6 +15,19 @@ namespace Eval::NNUE {
 
     // Learning: Input layer
+    // This is tricky. This class exists because, when more than one trainer
+    // sits on top of a single feature transformer, we want to call
+    // propagate/backpropagate on the feature transformer only once. This was
+    // straightforward in the old multithreading scheme, where they were called
+    // just once, from the main thread. But with the current, coarser
+    // multithreading implementation we end up calling each method from every
+    // thread. Therefore we keep num_calls and current_operation on a per-thread
+    // basis, each thread must work on its designated batch slice, and the only
+    // synchronization points are step_start and step_end, for which we use the
+    // state of the first thread. Each thread needs its own bookkeeping because
+    // one thread may still be in propagate of one batch slice while another is
+    // already in backpropagate of a different slice. To avoid false sharing,
+    // each thread's state occupies a full cache line.
     class SharedInputTrainer {
     public:
         // factory function
@@ -34,32 +47,36 @@ namespace Eval::NNUE {
 
         // Set options such as hyperparameters
         void send_message(Message* message) {
-            if (num_calls_[0] == 0) {
-                current_operation_ = Operation::kSendMessage;
+            auto& thread_state = thread_states_[0];
+
+            if (thread_state.num_calls == 0) {
+                thread_state.current_operation = Operation::kSendMessage;
                 feature_transformer_trainer_->send_message(message);
             }
 
-            assert(current_operation_ == Operation::kSendMessage);
+            assert(thread_state.current_operation == Operation::kSendMessage);
 
-            if (++num_calls_[0] == num_referrers_) {
-                num_calls_[0] = 0;
-                current_operation_ = Operation::kNone;
+            if (++thread_state.num_calls == num_referrers_) {
+                thread_state.num_calls = 0;
+                thread_state.current_operation = Operation::kNone;
             }
         }
 
         // Initialize the parameters with random numbers
         template <typename RNG>
         void initialize(RNG& rng) {
-            if (num_calls_[0] == 0) {
-                current_operation_ = Operation::kInitialize;
+            auto& thread_state = thread_states_[0];
+
+            if (thread_state.num_calls == 0) {
+                thread_state.current_operation = Operation::kInitialize;
                 feature_transformer_trainer_->initialize(rng);
             }
 
-            assert(current_operation_ == Operation::kInitialize);
+            assert(thread_state.current_operation == Operation::kInitialize);
 
-            if (++num_calls_[0] == num_referrers_) {
-                num_calls_[0] = 0;
-                current_operation_ = Operation::kNone;
+            if (++thread_state.num_calls == num_referrers_) {
+                thread_state.num_calls = 0;
+                thread_state.current_operation = Operation::kNone;
             }
         }
 
@@ -71,23 +88,25 @@ namespace Eval::NNUE {
                 gradients_.resize(kInputDimensions * size);
             }
 
-            if (num_calls_.size() < thread_pool.size())
+            if (thread_states_.size() < thread_pool.size())
             {
-                num_calls_.resize(thread_pool.size(), 0);
+                thread_states_.resize(thread_pool.size());
             }
 
             batch_size_ = size;
 
-            if (num_calls_[0] == 0) {
-                current_operation_ = Operation::kStepStart;
+            auto& thread_state = thread_states_[0];
+
+            if (thread_state.num_calls == 0) {
+                thread_state.current_operation = Operation::kStepStart;
                 output_ = feature_transformer_trainer_->step_start(thread_pool, batch_begin, batch_end);
             }
 
-            assert(current_operation_ == Operation::kStepStart);
+            assert(thread_state.current_operation == Operation::kStepStart);
 
-            if (++num_calls_[0] == num_referrers_) {
-                num_calls_[0] = 0;
-                current_operation_ = Operation::kNone;
+            if (++thread_state.num_calls == num_referrers_) {
+                thread_state.num_calls = 0;
+                thread_state.current_operation = Operation::kNone;
             }
 
             return output_;
@@ -97,16 +116,18 @@ namespace Eval::NNUE {
         void propagate(Thread& th, uint64_t offset, uint64_t count) {
             const auto thread_id = th.thread_idx();
 
-            if (num_calls_[thread_id] == 0) {
-                current_operation_ = Operation::kPropagate;
+            auto& thread_state = thread_states_[thread_id];
+
+            if (thread_state.num_calls == 0) {
+                thread_state.current_operation = Operation::kPropagate;
                 feature_transformer_trainer_->propagate(th, offset, count);
             }
 
-            assert(current_operation_ == Operation::kPropagate);
+            assert(thread_state.current_operation == Operation::kPropagate);
 
-            if (++num_calls_[thread_id] == num_referrers_) {
-                num_calls_[thread_id] = 0;
-                current_operation_ = Operation::kNone;
+            if (++thread_state.num_calls == num_referrers_) {
+                thread_state.num_calls = 0;
+                thread_state.current_operation = Operation::kNone;
             }
         }
 
@@ -118,13 +139,15 @@ namespace Eval::NNUE {
 
             const auto thread_id = th.thread_idx();
 
+            auto& thread_state = thread_states_[thread_id];
+
             if (num_referrers_ == 1) {
                 feature_transformer_trainer_->backpropagate(th, gradients, offset, count);
                 return;
             }
 
-            if (num_calls_[thread_id] == 0) {
-                current_operation_ = Operation::kBackPropagate;
+            if (thread_state.num_calls == 0) {
+                thread_state.current_operation = Operation::kBackPropagate;
                 for (IndexType b = offset; b < offset + count; ++b) {
                     const IndexType batch_offset = kInputDimensions * b;
                     for (IndexType i = 0; i < kInputDimensions; ++i) {
@@ -133,7 +156,7 @@ namespace Eval::NNUE {
                 }
             }
 
-            assert(current_operation_ == Operation::kBackPropagate);
+            assert(thread_state.current_operation == Operation::kBackPropagate);
 
             for (IndexType b = offset; b < offset + count; ++b) {
                 const IndexType batch_offset = kInputDimensions * b;
@@ -142,25 +165,27 @@ namespace Eval::NNUE {
                 }
             }
 
-            if (++num_calls_[thread_id] == num_referrers_) {
+            if (++thread_state.num_calls == num_referrers_) {
                 feature_transformer_trainer_->backpropagate(
                     th, gradients_.data(), offset, count);
-                num_calls_[thread_id] = 0;
-                current_operation_ = Operation::kNone;
+                thread_state.num_calls = 0;
+                thread_state.current_operation = Operation::kNone;
             }
         }
 
         void step_end(ThreadPool& thread_pool, LearnFloatType learning_rate) {
-            if (num_calls_[0] == 0) {
-                current_operation_ = Operation::kStepEnd;
+            auto& thread_state = thread_states_[0];
+
+            if (thread_state.num_calls == 0) {
+                thread_state.current_operation = Operation::kStepEnd;
                 feature_transformer_trainer_->step_end(thread_pool, learning_rate);
             }
 
-            assert(current_operation_ == Operation::kStepEnd);
+            assert(thread_state.current_operation == Operation::kStepEnd);
 
-            if (++num_calls_[0] == num_referrers_) {
-                num_calls_[0] = 0;
-                current_operation_ = Operation::kNone;
+            if (++thread_state.num_calls == num_referrers_) {
+                thread_state.num_calls = 0;
+                thread_state.current_operation = Operation::kNone;
             }
         }
 
@@ -169,8 +194,7 @@ namespace Eval::NNUE {
         SharedInputTrainer(FeatureTransformer* ft) :
             batch_size_(0),
             num_referrers_(0),
-            num_calls_(1, 0),
-            current_operation_(Operation::kNone),
+            thread_states_(1),
            feature_transformer_trainer_(Trainer<FeatureTransformer>::create(
                ft)),
            output_(nullptr) {
@@ -197,11 +221,16 @@ namespace Eval::NNUE {
        // number of layers sharing this layer as input
        std::uint32_t num_referrers_;
 
-        // Number of times the current process has been called
-        std::vector<std::uint32_t> num_calls_;
+        struct alignas(kCacheLineSize) ThreadState
+        {
+            std::uint32_t num_calls{0};
 
-        // current processing type
-        Operation current_operation_;
+            // current processing type
+            Operation current_operation = Operation::kNone;
+        };
+
+        // Per-thread state of the operation currently being processed
+        std::vector<ThreadState> thread_states_;
 
        // Trainer of input feature converter
        const std::shared_ptr<Trainer<FeatureTransformer>>
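
Note: below is a minimal, standalone sketch of the coordination pattern this patch
introduces, for readers who want the idea without the trainer plumbing. It is an
illustration, not code from the patch: the names SharedGate and run_shared_once are
made up, and it assumes C++17 (so std::vector honors the over-aligned element type)
and a 64-byte cache line. Each thread owns one cache-line-sized slot, so threads
bumping their own counters never invalidate each other's cache lines, and the shared
work runs once per thread per step: the first of num_referrers_ calls triggers it,
the last one resets the slot.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

constexpr std::size_t kCacheLineSize = 64;  // assumed cache-line size (typical on x86-64)

class SharedGate {
public:
    SharedGate(std::uint32_t num_referrers, std::size_t num_threads) :
        num_referrers_(num_referrers),
        thread_states_(num_threads) {
    }

    // Each of the num_referrers_ users calls this once per step from the given
    // thread. Only the first call forwards `work` to the shared resource; the
    // last call resets this thread's slot for the next step.
    void run_shared_once(std::size_t thread_id, const std::function<void()>& work) {
        ThreadState& state = thread_states_[thread_id];

        if (state.num_calls == 0) {
            state.active = true;
            work();  // e.g. propagate/backpropagate on the shared feature transformer
        }

        assert(state.active);

        if (++state.num_calls == num_referrers_) {
            state.num_calls = 0;
            state.active = false;
        }
    }

private:
    // Padding each slot to a full cache line keeps one thread's counter
    // updates from invalidating the cache line holding another thread's slot.
    struct alignas(kCacheLineSize) ThreadState {
        std::uint32_t num_calls = 0;
        bool active = false;
    };

    std::uint32_t num_referrers_;
    std::vector<ThreadState> thread_states_;  // one cache line per thread
};

As in the patch, there are no atomics or locks: correctness relies on each thread
touching only its own slot, and on the single-threaded phases (step_start/step_end,
which use the first thread's state) separating the parallel ones.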