diff --git a/src/nnue/trainer/trainer_affine_transform.h b/src/nnue/trainer/trainer_affine_transform.h index 610805ca..f66f1a65 100644 --- a/src/nnue/trainer/trainer_affine_transform.h +++ b/src/nnue/trainer/trainer_affine_transform.h @@ -91,19 +91,52 @@ namespace Eval::NNUE { quantize_parameters(); } - // forward propagation - const LearnFloatType* propagate(ThreadPool& thread_pool, const std::vector& batch) { - if (output_.size() < kOutputDimensions * batch.size()) { - output_.resize(kOutputDimensions * batch.size()); - gradients_.resize(kInputDimensions * batch.size()); + const LearnFloatType* step_start(ThreadPool& thread_pool, const std::vector& combined_batch) + { + if (output_.size() < kOutputDimensions * combined_batch.size()) { + output_.resize(kOutputDimensions * combined_batch.size()); + gradients_.resize(kInputDimensions * combined_batch.size()); } - batch_size_ = static_cast(batch.size()); - batch_input_ = previous_layer_trainer_->propagate(thread_pool, batch); + if (thread_states_.size() < thread_pool.size()) + { + thread_states_.resize(thread_pool.size()); + } + + combined_batch_size_ = static_cast(combined_batch.size()); + combined_batch_input_ = previous_layer_trainer_->step_start(thread_pool, combined_batch); + + auto& main_thread_state = thread_states_[0]; #if defined(USE_BLAS) - for (IndexType b = 0; b < batch_size_; ++b) { + // update + cblas_sscal( + kOutputDimensions, momentum_, main_thread_state.biases_diff_, 1 + ); + +#else + + Blas::sscal( + kOutputDimensions, momentum_, main_thread_state.biases_diff_, 1 + ); + +#endif + + for (IndexType i = 1; i < thread_states_.size(); ++i) + thread_states_[i].reset_biases(); + + return output_.data(); + } + + // forward propagation + void propagate(Thread& th, const uint64_t offset, const uint64_t count) { + + previous_layer_trainer_->propagate(th, offset, count); + +#if defined(USE_BLAS) + + for (IndexType b = offset; b < offset + count; ++b) { const IndexType batch_offset = kOutputDimensions * b; cblas_scopy( kOutputDimensions, biases_, 1, &output_[batch_offset], 1 @@ -112,149 +145,151 @@ namespace Eval::NNUE { cblas_sgemm( CblasColMajor, CblasTrans, CblasNoTrans, - kOutputDimensions, batch_size_, kInputDimensions, + kOutputDimensions, count, kInputDimensions, 1.0, weights_, kInputDimensions, - batch_input_, kInputDimensions, + combined_batch_input_ + offset * kInputDimensions, kInputDimensions, 1.0, - &output_[0], kOutputDimensions + &output_[offset * kOutputDimensions], kOutputDimensions ); #else - for (IndexType b = 0; b < batch_size_; ++b) { + for (IndexType b = offset; b < offset + count; ++b) { const IndexType batch_offset = kOutputDimensions * b; Blas::scopy( - thread_pool, kOutputDimensions, biases_, 1, &output_[batch_offset], 1 ); } Blas::sgemm( - thread_pool, Blas::MatrixLayout::ColMajor, Blas::MatrixTranspose::Trans, Blas::MatrixTranspose::NoTrans, - kOutputDimensions, batch_size_, kInputDimensions, + kOutputDimensions, count, kInputDimensions, 1.0, weights_, kInputDimensions, - batch_input_, kInputDimensions, + combined_batch_input_ + offset * kInputDimensions, kInputDimensions, 1.0, - &output_[0], kOutputDimensions + &output_[offset * kOutputDimensions], kOutputDimensions ); #endif - return output_.data(); } // backpropagation - void backpropagate(ThreadPool& thread_pool, + void backpropagate(Thread& th, const LearnFloatType* gradients, - LearnFloatType learning_rate) { - - const LearnFloatType local_learning_rate = - learning_rate * learning_rate_scale_; + uint64_t offset, + uint64_t count) { + auto& thread_state = thread_states_[th.thread_idx()]; + const auto momentum = th.thread_idx() == 0 ? momentum_ : 0.0f; #if defined(USE_BLAS) cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans, - kInputDimensions, batch_size_, kOutputDimensions, + kInputDimensions, count, kOutputDimensions, 1.0, weights_, kInputDimensions, - gradients, kOutputDimensions, + gradients + offset * kOutputDimensions, kOutputDimensions, 0.0, - &gradients_[0], kInputDimensions + &gradients_[offset * kInputDimensions], kInputDimensions ); - // update - cblas_sscal( - kOutputDimensions, momentum_, biases_diff_, 1 - ); - - for (IndexType b = 0; b < batch_size_; ++b) { + for (IndexType b = offset; b < offset + count; ++b) { const IndexType batch_offset = kOutputDimensions * b; cblas_saxpy( kOutputDimensions, 1.0, - &gradients[batch_offset], 1, biases_diff_, 1 + &gradients[batch_offset], 1, thread_state.biases_diff_, 1 ); } cblas_sgemm( CblasRowMajor, CblasTrans, CblasNoTrans, - kOutputDimensions, kInputDimensions, batch_size_, + kOutputDimensions, kInputDimensions, count, 1.0, - gradients, kOutputDimensions, - batch_input_, kInputDimensions, - momentum_, - weights_diff_, kInputDimensions + gradients + offset * kOutputDimensions, kOutputDimensions, + combined_batch_input_ + offset * kInputDimensions, kInputDimensions, + momentum, + thread_state.weights_diff_, kInputDimensions ); #else // backpropagate Blas::sgemm( - thread_pool, Blas::MatrixLayout::ColMajor, Blas::MatrixTranspose::NoTrans, Blas::MatrixTranspose::NoTrans, - kInputDimensions, batch_size_, kOutputDimensions, + kInputDimensions, count, kOutputDimensions, 1.0, weights_, kInputDimensions, - gradients, kOutputDimensions, + gradients + offset * kOutputDimensions, kOutputDimensions, 0.0, - &gradients_[0], kInputDimensions + &gradients_[offset * kInputDimensions], kInputDimensions ); - - Blas::sscal( - thread_pool, - kOutputDimensions, momentum_, biases_diff_, 1 - ); - - for (IndexType b = 0; b < batch_size_; ++b) { + for (IndexType b = offset; b < offset + count; ++b) { const IndexType batch_offset = kOutputDimensions * b; - Blas::saxpy(thread_pool, kOutputDimensions, 1.0, - &gradients[batch_offset], 1, biases_diff_, 1); + Blas::saxpy(kOutputDimensions, 1.0, + &gradients[batch_offset], 1, thread_state.biases_diff_, 1); } Blas::sgemm( - thread_pool, Blas::MatrixLayout::RowMajor, Blas::MatrixTranspose::Trans, Blas::MatrixTranspose::NoTrans, - kOutputDimensions, kInputDimensions, batch_size_, + kOutputDimensions, kInputDimensions, count, 1.0, - gradients, kOutputDimensions, - batch_input_, kInputDimensions, - momentum_, - weights_diff_, kInputDimensions + gradients + offset * kOutputDimensions, kOutputDimensions, + combined_batch_input_ + offset * kInputDimensions, kInputDimensions, + momentum, + thread_state.weights_diff_, kInputDimensions ); #endif + previous_layer_trainer_->backpropagate(th, gradients_.data(), offset, count); + } + + void reduce_thread_state() + { + for (IndexType i = 1; i < thread_states_.size(); ++i) + { + thread_states_[0] += thread_states_[i]; + } + } + + void step_end(ThreadPool& thread_pool, LearnFloatType learning_rate) + { + const LearnFloatType local_learning_rate = + learning_rate * learning_rate_scale_; + + reduce_thread_state(); + + auto& main_thread_state = thread_states_[0]; + for (IndexType i = 0; i < kOutputDimensions; ++i) { - const double d = local_learning_rate * biases_diff_[i]; + const double d = local_learning_rate * main_thread_state.biases_diff_[i]; biases_[i] -= d; abs_biases_diff_sum_ += std::abs(d); } num_biases_diffs_ += kOutputDimensions; for (IndexType i = 0; i < kOutputDimensions * kInputDimensions; ++i) { - const double d = local_learning_rate * weights_diff_[i]; + const double d = local_learning_rate * main_thread_state.weights_diff_[i]; weights_[i] -= d; abs_weights_diff_sum_ += std::abs(d); } num_weights_diffs_ += kOutputDimensions * kInputDimensions; - previous_layer_trainer_->backpropagate(thread_pool, gradients_.data(), learning_rate); + previous_layer_trainer_->step_end(thread_pool, learning_rate); } private: // constructor Trainer(LayerType* target_layer, FeatureTransformer* ft) : - batch_size_(0), - batch_input_(nullptr), + combined_batch_size_(0), + combined_batch_input_(nullptr), previous_layer_trainer_(Trainer::create( &target_layer->previous_layer_, ft)), target_layer_(target_layer), biases_(), weights_(), - biases_diff_(), - weights_diff_(), momentum_(0.2), learning_rate_scale_(1.0) { @@ -335,10 +370,12 @@ namespace Eval::NNUE { } } - std::fill(std::begin(biases_diff_), std::end(biases_diff_), - static_cast(0.0)); - std::fill(std::begin(weights_diff_), std::end(weights_diff_), - static_cast(0.0)); + for (auto& state : thread_states_) + { + state.reset_weights(); + state.reset_biases(); + } + reset_stats(); } @@ -365,7 +402,7 @@ namespace Eval::NNUE { std::numeric_limits::max() / kWeightScale; // number of samples in mini-batch - IndexType batch_size_; + IndexType combined_batch_size_; double abs_biases_diff_sum_; double abs_weights_diff_sum_; @@ -373,7 +410,7 @@ namespace Eval::NNUE { uint64_t num_weights_diffs_; // Input mini batch - const LearnFloatType* batch_input_; + const LearnFloatType* combined_batch_input_; // Trainer of the previous layer const std::shared_ptr> previous_layer_trainer_; @@ -382,12 +419,44 @@ namespace Eval::NNUE { LayerType* const target_layer_; // parameter + struct alignas(kCacheLineSize) ThreadState + { + // Buffer used for updating parameters + alignas(kCacheLineSize) LearnFloatType biases_diff_[kOutputDimensions]; + alignas(kCacheLineSize) LearnFloatType weights_diff_[kOutputDimensions * kInputDimensions]; + + ThreadState() { reset_weights(); reset_biases(); } + + ThreadState& operator+=(const ThreadState& other) + { + for (IndexType i = 0; i < kOutputDimensions; ++i) + { + biases_diff_[i] += other.biases_diff_[i]; + } + + for (IndexType i = 0; i < kOutputDimensions * kInputDimensions; ++i) + { + weights_diff_[i] += other.weights_diff_[i]; + } + + return *this; + } + + void reset_weights() + { + std::fill(std::begin(weights_diff_), std::end(weights_diff_), 0.0f); + } + + void reset_biases() + { + std::fill(std::begin(biases_diff_), std::end(biases_diff_), 0.0f); + } + }; + alignas(kCacheLineSize) LearnFloatType biases_[kOutputDimensions]; alignas(kCacheLineSize) LearnFloatType weights_[kOutputDimensions * kInputDimensions]; - // Buffer used for updating parameters - alignas(kCacheLineSize) LearnFloatType biases_diff_[kOutputDimensions]; - alignas(kCacheLineSize) LearnFloatType weights_diff_[kOutputDimensions * kInputDimensions]; + std::vector> thread_states_; // Forward propagation buffer std::vector> output_;