mirror of
https://github.com/HChaZZY/Stockfish.git
synced 2025-12-24 11:06:58 +08:00
Prevent false sharing of num_calls counter in the shared input trainer. Fix current_operation not being local to the executing thread.
This commit is contained in:
@@ -15,6 +15,19 @@
|
||||
namespace Eval::NNUE {
|
||||
|
||||
// Learning: Input layer
|
||||
// This is tricky. It exists because when there's more than one trainer
|
||||
// on top of a single feature transformer we want to only call propagate/backpropagate
|
||||
// on the feature transformer once. This is straightforward in the old
|
||||
// multithreading case, because propagate/backpropagate is called just once from the
|
||||
// main thread. But with the current implementation of coarser multithreading
|
||||
// we end up calling each method from each thread. Therefore we have to keep
|
||||
// the num_calls and current_operation per thread basis, each thread must work
|
||||
// on its designated batch slice, and the only synchronization points are
|
||||
// step_start and step_end - for which we use state of the first thread.
|
||||
// Each thread requires their own bookkeeping because it's possible that
|
||||
// one thread is still in propagate of some batch slice while the other thread
|
||||
// is doing backpropagate of some other slice. We also ensure the thread state
|
||||
// isn't suspectible to false sharing by using a full cache line for the state.
|
||||
class SharedInputTrainer {
|
||||
public:
|
||||
// factory function
|
||||
@@ -34,32 +47,36 @@ namespace Eval::NNUE {
|
||||
|
||||
// Set options such as hyperparameters
|
||||
void send_message(Message* message) {
|
||||
if (num_calls_[0] == 0) {
|
||||
current_operation_ = Operation::kSendMessage;
|
||||
auto& thread_state = thread_states_[0];
|
||||
|
||||
if (thread_state.num_calls == 0) {
|
||||
thread_state.current_operation = Operation::kSendMessage;
|
||||
feature_transformer_trainer_->send_message(message);
|
||||
}
|
||||
|
||||
assert(current_operation_ == Operation::kSendMessage);
|
||||
assert(thread_state.current_operation == Operation::kSendMessage);
|
||||
|
||||
if (++num_calls_[0] == num_referrers_) {
|
||||
num_calls_[0] = 0;
|
||||
current_operation_ = Operation::kNone;
|
||||
if (++thread_state.num_calls == num_referrers_) {
|
||||
thread_state.num_calls = 0;
|
||||
thread_state.current_operation = Operation::kNone;
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize the parameters with random numbers
|
||||
template <typename RNG>
|
||||
void initialize(RNG& rng) {
|
||||
if (num_calls_[0] == 0) {
|
||||
current_operation_ = Operation::kInitialize;
|
||||
auto& thread_state = thread_states_[0];
|
||||
|
||||
if (thread_state.num_calls == 0) {
|
||||
thread_state.current_operation = Operation::kInitialize;
|
||||
feature_transformer_trainer_->initialize(rng);
|
||||
}
|
||||
|
||||
assert(current_operation_ == Operation::kInitialize);
|
||||
assert(thread_state.current_operation == Operation::kInitialize);
|
||||
|
||||
if (++num_calls_[0] == num_referrers_) {
|
||||
num_calls_[0] = 0;
|
||||
current_operation_ = Operation::kNone;
|
||||
if (++thread_state.num_calls == num_referrers_) {
|
||||
thread_state.num_calls = 0;
|
||||
thread_state.current_operation = Operation::kNone;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,23 +88,25 @@ namespace Eval::NNUE {
|
||||
gradients_.resize(kInputDimensions * size);
|
||||
}
|
||||
|
||||
if (num_calls_.size() < thread_pool.size())
|
||||
if (thread_states_.size() < thread_pool.size())
|
||||
{
|
||||
num_calls_.resize(thread_pool.size(), 0);
|
||||
thread_states_.resize(thread_pool.size());
|
||||
}
|
||||
|
||||
batch_size_ = size;
|
||||
|
||||
if (num_calls_[0] == 0) {
|
||||
current_operation_ = Operation::kStepStart;
|
||||
auto& thread_state = thread_states_[0];
|
||||
|
||||
if (thread_state.num_calls == 0) {
|
||||
thread_state.current_operation = Operation::kStepStart;
|
||||
output_ = feature_transformer_trainer_->step_start(thread_pool, batch_begin, batch_end);
|
||||
}
|
||||
|
||||
assert(current_operation_ == Operation::kStepStart);
|
||||
assert(thread_state.current_operation == Operation::kStepStart);
|
||||
|
||||
if (++num_calls_[0] == num_referrers_) {
|
||||
num_calls_[0] = 0;
|
||||
current_operation_ = Operation::kNone;
|
||||
if (++thread_state.num_calls == num_referrers_) {
|
||||
thread_state.num_calls = 0;
|
||||
thread_state.current_operation = Operation::kNone;
|
||||
}
|
||||
|
||||
return output_;
|
||||
@@ -97,16 +116,18 @@ namespace Eval::NNUE {
|
||||
void propagate(Thread& th, uint64_t offset, uint64_t count) {
|
||||
const auto thread_id = th.thread_idx();
|
||||
|
||||
if (num_calls_[thread_id] == 0) {
|
||||
current_operation_ = Operation::kPropagate;
|
||||
auto& thread_state = thread_states_[thread_id];
|
||||
|
||||
if (thread_state.num_calls == 0) {
|
||||
thread_state.current_operation = Operation::kPropagate;
|
||||
feature_transformer_trainer_->propagate(th, offset, count);
|
||||
}
|
||||
|
||||
assert(current_operation_ == Operation::kPropagate);
|
||||
assert(thread_state.current_operation == Operation::kPropagate);
|
||||
|
||||
if (++num_calls_[thread_id] == num_referrers_) {
|
||||
num_calls_[thread_id] = 0;
|
||||
current_operation_ = Operation::kNone;
|
||||
if (++thread_state.num_calls == num_referrers_) {
|
||||
thread_state.num_calls = 0;
|
||||
thread_state.current_operation = Operation::kNone;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,13 +139,15 @@ namespace Eval::NNUE {
|
||||
|
||||
const auto thread_id = th.thread_idx();
|
||||
|
||||
auto& thread_state = thread_states_[thread_id];
|
||||
|
||||
if (num_referrers_ == 1) {
|
||||
feature_transformer_trainer_->backpropagate(th, gradients, offset, count);
|
||||
return;
|
||||
}
|
||||
|
||||
if (num_calls_[thread_id] == 0) {
|
||||
current_operation_ = Operation::kBackPropagate;
|
||||
if (thread_state.num_calls == 0) {
|
||||
thread_state.current_operation = Operation::kBackPropagate;
|
||||
for (IndexType b = offset; b < offset + count; ++b) {
|
||||
const IndexType batch_offset = kInputDimensions * b;
|
||||
for (IndexType i = 0; i < kInputDimensions; ++i) {
|
||||
@@ -133,7 +156,7 @@ namespace Eval::NNUE {
|
||||
}
|
||||
}
|
||||
|
||||
assert(current_operation_ == Operation::kBackPropagate);
|
||||
assert(thread_state.current_operation == Operation::kBackPropagate);
|
||||
|
||||
for (IndexType b = offset; b < offset + count; ++b) {
|
||||
const IndexType batch_offset = kInputDimensions * b;
|
||||
@@ -142,25 +165,27 @@ namespace Eval::NNUE {
|
||||
}
|
||||
}
|
||||
|
||||
if (++num_calls_[thread_id] == num_referrers_) {
|
||||
if (++thread_state.num_calls == num_referrers_) {
|
||||
feature_transformer_trainer_->backpropagate(
|
||||
th, gradients_.data(), offset, count);
|
||||
num_calls_[thread_id] = 0;
|
||||
current_operation_ = Operation::kNone;
|
||||
thread_state.num_calls = 0;
|
||||
thread_state.current_operation = Operation::kNone;
|
||||
}
|
||||
}
|
||||
|
||||
void step_end(ThreadPool& thread_pool, LearnFloatType learning_rate) {
|
||||
if (num_calls_[0] == 0) {
|
||||
current_operation_ = Operation::kStepEnd;
|
||||
auto& thread_state = thread_states_[0];
|
||||
|
||||
if (thread_state.num_calls == 0) {
|
||||
thread_state.current_operation = Operation::kStepEnd;
|
||||
feature_transformer_trainer_->step_end(thread_pool, learning_rate);
|
||||
}
|
||||
|
||||
assert(current_operation_ == Operation::kStepEnd);
|
||||
assert(thread_state.current_operation == Operation::kStepEnd);
|
||||
|
||||
if (++num_calls_[0] == num_referrers_) {
|
||||
num_calls_[0] = 0;
|
||||
current_operation_ = Operation::kNone;
|
||||
if (++thread_state.num_calls == num_referrers_) {
|
||||
thread_state.num_calls = 0;
|
||||
thread_state.current_operation = Operation::kNone;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -169,8 +194,7 @@ namespace Eval::NNUE {
|
||||
SharedInputTrainer(FeatureTransformer* ft) :
|
||||
batch_size_(0),
|
||||
num_referrers_(0),
|
||||
num_calls_(1, 0),
|
||||
current_operation_(Operation::kNone),
|
||||
thread_states_(1),
|
||||
feature_transformer_trainer_(Trainer<FeatureTransformer>::create(
|
||||
ft)),
|
||||
output_(nullptr) {
|
||||
@@ -197,11 +221,16 @@ namespace Eval::NNUE {
|
||||
// number of layers sharing this layer as input
|
||||
std::uint32_t num_referrers_;
|
||||
|
||||
// Number of times the current process has been called
|
||||
std::vector<std::uint32_t> num_calls_;
|
||||
struct alignas(kCacheLineSize) ThreadState
|
||||
{
|
||||
std::uint32_t num_calls{0};
|
||||
|
||||
// current processing type
|
||||
Operation current_operation_;
|
||||
// current processing type
|
||||
Operation current_operation = Operation::kNone;
|
||||
};
|
||||
|
||||
// Number of times the current process has been called
|
||||
std::vector<ThreadState, CacheLineAlignedAllocator<ThreadState>> thread_states_;
|
||||
|
||||
// Trainer of input feature converter
|
||||
const std::shared_ptr<Trainer<FeatureTransformer>>
|
||||
|
||||
Reference in New Issue
Block a user