From 15c528ca7b6beefa64ba2c0192c7dc3efacc665e Mon Sep 17 00:00:00 2001
From: Tomasz Sobczyk
Date: Sun, 22 Nov 2020 21:38:11 +0100
Subject: [PATCH] Prepare feature transformer learner.

---
 .../trainer/trainer_feature_transformer.h | 486 +++++++++++-------
 1 file changed, 298 insertions(+), 188 deletions(-)

diff --git a/src/nnue/trainer/trainer_feature_transformer.h b/src/nnue/trainer/trainer_feature_transformer.h
index 80f914f2..9686002f 100644
--- a/src/nnue/trainer/trainer_feature_transformer.h
+++ b/src/nnue/trainer/trainer_feature_transformer.h
@@ -89,56 +89,88 @@ namespace Eval::NNUE {
             quantize_parameters();
         }
 
-        // forward propagation
-        const LearnFloatType* propagate(ThreadPool& thread_pool, const std::vector<Example>& batch) {
-            if (output_.size() < kOutputDimensions * batch.size()) {
-                output_.resize(kOutputDimensions * batch.size());
-                gradients_.resize(kOutputDimensions * batch.size());
+        const LearnFloatType* step_start(ThreadPool& thread_pool, const std::vector<Example>& combined_batch)
+        {
+            if (output_.size() < kOutputDimensions * combined_batch.size()) {
+                output_.resize(kOutputDimensions * combined_batch.size());
+                gradients_.resize(kOutputDimensions * combined_batch.size());
             }
 
-            (void)thread_pool;
+            if (thread_stat_states_.size() < thread_pool.size())
+            {
+                thread_stat_states_.resize(thread_pool.size());
+            }
 
-            batch_ = &batch;
-            // affine transform
-            thread_pool.for_each_index_with_workers(
-                0, batch.size(),
-                [&](Thread&, int b) {
-                    const IndexType batch_offset = kOutputDimensions * b;
-                    for (IndexType c = 0; c < 2; ++c) {
-                        const IndexType output_offset = batch_offset + kHalfDimensions * c;
+            if (thread_bias_states_.size() < thread_pool.size())
+            {
+                thread_bias_states_.resize(thread_pool.size());
+            }
+
+            batch_ = &combined_batch;
+
+            auto& main_thread_bias_state = thread_bias_states_[0];
 
 #if defined(USE_BLAS)
 
-                        cblas_scopy(
-                            kHalfDimensions, biases_, 1, &output_[output_offset], 1
-                        );
-
-                        for (const auto& feature : batch[b].training_features[c]) {
-                            const IndexType weights_offset = kHalfDimensions * feature.get_index();
-                            cblas_saxpy(
-                                kHalfDimensions, (float)feature.get_count(),
-                                &weights_[weights_offset], 1, &output_[output_offset], 1
-                            );
-                        }
+            cblas_sscal(
+                kHalfDimensions, momentum_, main_thread_bias_state.biases_diff_, 1
+            );
 
 #else
 
-                        Blas::scopy(
-                            kHalfDimensions, biases_, 1, &output_[output_offset], 1
-                        );
-                        for (const auto& feature : batch[b].training_features[c]) {
-                            const IndexType weights_offset = kHalfDimensions * feature.get_index();
-                            Blas::saxpy(
-                                kHalfDimensions, (float)feature.get_count(),
-                                &weights_[weights_offset], 1, &output_[output_offset], 1
-                            );
-                        }
+            Blas::sscal(
+                kHalfDimensions, momentum_, main_thread_bias_state.biases_diff_, 1
+            );
 
 #endif
+
+            for (IndexType i = 1; i < thread_bias_states_.size(); ++i)
+                thread_bias_states_[i].reset();
+
+            return output_.data();
+        }
+
+        // forward propagation
+        void propagate(Thread& th, uint64_t offset, uint64_t count) {
+
+            auto& thread_stat_state = thread_stat_states_[th.thread_idx()];
+
+            for (IndexType b = offset; b < offset + count; ++b)
+            {
+                const IndexType batch_offset = kOutputDimensions * b;
+                for (IndexType c = 0; c < 2; ++c) {
+                    const IndexType output_offset = batch_offset + kHalfDimensions * c;
+
+#if defined(USE_BLAS)
+
+                    cblas_scopy(
+                        kHalfDimensions, biases_, 1, &output_[output_offset], 1
+                    );
+
+                    for (const auto& feature : (*batch_)[b].training_features[c]) {
+                        const IndexType weights_offset = kHalfDimensions * feature.get_index();
+                        cblas_saxpy(
+                            kHalfDimensions, (float)feature.get_count(),
+                            &weights_[weights_offset], 1, &output_[output_offset], 1
+                        );
+                    }
+
+#else
+
+                    Blas::scopy(
+                        kHalfDimensions, biases_, 1, &output_[output_offset], 1
+                    );
+                    for (const auto& feature : (*batch_)[b].training_features[c]) {
+                        const IndexType weights_offset = kHalfDimensions * feature.get_index();
+                        Blas::saxpy(
+                            kHalfDimensions, (float)feature.get_count(),
+                            &weights_[weights_offset], 1, &output_[output_offset], 1
+                        );
+                    }
+
+#endif
                 }
-            );
-            thread_pool.wait_for_workers_finished();
+            }
 
 #if defined (USE_SSE2)
 
@@ -161,49 +193,51 @@ namespace Eval::NNUE {
                 return _mm_cvtss_f32(_mm_max_ps(max_x_x_13_20, max_x_x_20_13));
             };
 
-            const int total_size = batch.size() * kOutputDimensions;
-
             const __m128 kZero4 = _mm_set1_ps(+kZero);
             const __m128 kOne4 = _mm_set1_ps(+kOne);
 
-            __m128 min_pre_activation0 = _mm_set1_ps(min_pre_activation_);
-            __m128 min_pre_activation1 = _mm_set1_ps(min_pre_activation_);
-            __m128 max_pre_activation0 = _mm_set1_ps(max_pre_activation_);
-            __m128 max_pre_activation1 = _mm_set1_ps(max_pre_activation_);
+            __m128 min_pre_activation0 = _mm_set1_ps(thread_stat_state.min_pre_activation_);
+            __m128 min_pre_activation1 = _mm_set1_ps(thread_stat_state.min_pre_activation_);
+            __m128 max_pre_activation0 = _mm_set1_ps(thread_stat_state.max_pre_activation_);
+            __m128 max_pre_activation1 = _mm_set1_ps(thread_stat_state.max_pre_activation_);
 
-            for (int i = 0; i < total_size; i += 16)
+            for (IndexType b = offset; b < offset + count; ++b)
             {
-                __m128 out0 = _mm_loadu_ps(&output_[i + 0]);
-                __m128 out1 = _mm_loadu_ps(&output_[i + 4]);
-                __m128 out2 = _mm_loadu_ps(&output_[i + 8]);
-                __m128 out3 = _mm_loadu_ps(&output_[i + 12]);
+                const IndexType batch_offset = kOutputDimensions * b;
+                for (IndexType i = 0; i < kOutputDimensions; i += 16)
+                {
+                    __m128 out0 = _mm_loadu_ps(&output_[batch_offset + i + 0]);
+                    __m128 out1 = _mm_loadu_ps(&output_[batch_offset + i + 4]);
+                    __m128 out2 = _mm_loadu_ps(&output_[batch_offset + i + 8]);
+                    __m128 out3 = _mm_loadu_ps(&output_[batch_offset + i + 12]);
 
-                __m128 min01 = _mm_min_ps(out0, out1);
-                __m128 min23 = _mm_min_ps(out2, out3);
+                    __m128 min01 = _mm_min_ps(out0, out1);
+                    __m128 min23 = _mm_min_ps(out2, out3);
 
-                __m128 max01 = _mm_max_ps(out0, out1);
-                __m128 max23 = _mm_max_ps(out2, out3);
+                    __m128 max01 = _mm_max_ps(out0, out1);
+                    __m128 max23 = _mm_max_ps(out2, out3);
 
-                min_pre_activation0 = _mm_min_ps(min_pre_activation0, min01);
-                min_pre_activation1 = _mm_min_ps(min_pre_activation1, min23);
-                max_pre_activation0 = _mm_max_ps(max_pre_activation0, max01);
-                max_pre_activation1 = _mm_max_ps(max_pre_activation1, max23);
+                    min_pre_activation0 = _mm_min_ps(min_pre_activation0, min01);
+                    min_pre_activation1 = _mm_min_ps(min_pre_activation1, min23);
+                    max_pre_activation0 = _mm_max_ps(max_pre_activation0, max01);
+                    max_pre_activation1 = _mm_max_ps(max_pre_activation1, max23);
 
-                out0 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out0));
-                out1 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out1));
-                out2 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out2));
-                out3 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out3));
+                    out0 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out0));
+                    out1 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out1));
+                    out2 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out2));
+                    out3 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out3));
 
-                _mm_storeu_ps(&output_[i + 0], out0);
-                _mm_storeu_ps(&output_[i + 4], out1);
-                _mm_storeu_ps(&output_[i + 8], out2);
-                _mm_storeu_ps(&output_[i + 12], out3);
+                    _mm_storeu_ps(&output_[batch_offset + i + 0], out0);
+                    _mm_storeu_ps(&output_[batch_offset + i + 4], out1);
+                    _mm_storeu_ps(&output_[batch_offset + i + 8], out2);
+                    _mm_storeu_ps(&output_[batch_offset + i + 12], out3);
+                }
             }
 
-            min_pre_activation_ = m128_hmin_ps(_mm_min_ps(min_pre_activation0, min_pre_activation1));
-            max_pre_activation_ = m128_hmax_ps(_mm_max_ps(max_pre_activation0, max_pre_activation1));
+            thread_stat_state.min_pre_activation_ = m128_hmin_ps(_mm_min_ps(min_pre_activation0, min_pre_activation1));
+            thread_stat_state.max_pre_activation_ = m128_hmax_ps(_mm_max_ps(max_pre_activation0, max_pre_activation1));
 
-            for (IndexType b = 0; b < batch.size(); ++b)
+            for (IndexType b = offset; b < offset + count; ++b)
             {
                 const IndexType batch_offset = kOutputDimensions * b;
 
@@ -217,15 +251,15 @@ namespace Eval::NNUE {
                     const __m128 out2 = _mm_loadu_ps(&output_[i + 8 + half_offset]);
                     const __m128 out3 = _mm_loadu_ps(&output_[i + 12 + half_offset]);
 
-                    __m128 minact0 = _mm_loadu_ps(&min_activations_[i + 0]);
-                    __m128 minact1 = _mm_loadu_ps(&min_activations_[i + 4]);
-                    __m128 minact2 = _mm_loadu_ps(&min_activations_[i + 8]);
-                    __m128 minact3 = _mm_loadu_ps(&min_activations_[i + 12]);
+                    __m128 minact0 = _mm_loadu_ps(&thread_stat_state.min_activations_[i + 0]);
+                    __m128 minact1 = _mm_loadu_ps(&thread_stat_state.min_activations_[i + 4]);
+                    __m128 minact2 = _mm_loadu_ps(&thread_stat_state.min_activations_[i + 8]);
+                    __m128 minact3 = _mm_loadu_ps(&thread_stat_state.min_activations_[i + 12]);
 
-                    __m128 maxact0 = _mm_loadu_ps(&max_activations_[i + 0]);
-                    __m128 maxact1 = _mm_loadu_ps(&max_activations_[i + 4]);
-                    __m128 maxact2 = _mm_loadu_ps(&max_activations_[i + 8]);
-                    __m128 maxact3 = _mm_loadu_ps(&max_activations_[i + 12]);
+                    __m128 maxact0 = _mm_loadu_ps(&thread_stat_state.max_activations_[i + 0]);
+                    __m128 maxact1 = _mm_loadu_ps(&thread_stat_state.max_activations_[i + 4]);
+                    __m128 maxact2 = _mm_loadu_ps(&thread_stat_state.max_activations_[i + 8]);
+                    __m128 maxact3 = _mm_loadu_ps(&thread_stat_state.max_activations_[i + 12]);
 
                     minact0 = _mm_min_ps(out0, minact0);
                     minact1 = _mm_min_ps(out1, minact1);
@@ -237,15 +271,15 @@ namespace Eval::NNUE {
                     maxact2 = _mm_max_ps(out2, maxact2);
                     maxact3 = _mm_max_ps(out3, maxact3);
 
-                    _mm_storeu_ps(&min_activations_[i + 0], minact0);
-                    _mm_storeu_ps(&min_activations_[i + 4], minact1);
-                    _mm_storeu_ps(&min_activations_[i + 8], minact2);
-                    _mm_storeu_ps(&min_activations_[i + 12], minact3);
+                    _mm_storeu_ps(&thread_stat_state.min_activations_[i + 0], minact0);
+                    _mm_storeu_ps(&thread_stat_state.min_activations_[i + 4], minact1);
+                    _mm_storeu_ps(&thread_stat_state.min_activations_[i + 8], minact2);
+                    _mm_storeu_ps(&thread_stat_state.min_activations_[i + 12], minact3);
 
-                    _mm_storeu_ps(&max_activations_[i + 0], maxact0);
-                    _mm_storeu_ps(&max_activations_[i + 4], maxact1);
-                    _mm_storeu_ps(&max_activations_[i + 8], maxact2);
-                    _mm_storeu_ps(&max_activations_[i + 12], maxact3);
+                    _mm_storeu_ps(&thread_stat_state.max_activations_[i + 0], maxact0);
+                    _mm_storeu_ps(&thread_stat_state.max_activations_[i + 4], maxact1);
+                    _mm_storeu_ps(&thread_stat_state.max_activations_[i + 8], maxact2);
+                    _mm_storeu_ps(&thread_stat_state.max_activations_[i + 12], maxact3);
                 }
             }
         }
 
@@ -254,33 +288,30 @@ namespace Eval::NNUE {
 
 #else
 
             // clipped ReLU
-            for (IndexType b = 0; b < batch.size(); ++b) {
+            for (IndexType b = offset; b < offset + count; ++b) {
                 const IndexType batch_offset = kOutputDimensions * b;
                 for (IndexType i = 0; i < kOutputDimensions; ++i) {
                     const IndexType index = batch_offset + i;
-                    min_pre_activation_ = std::min(min_pre_activation_, output_[index]);
-                    max_pre_activation_ = std::max(max_pre_activation_, output_[index]);
+                    thread_stat_state.min_pre_activation_ = std::min(thread_stat_state.min_pre_activation_, output_[index]);
+                    thread_stat_state.max_pre_activation_ = std::max(thread_stat_state.max_pre_activation_, output_[index]);
 
                     output_[index] = std::max(+kZero, std::min(+kOne, output_[index]));
 
                     const IndexType t = i % kHalfDimensions;
-                    min_activations_[t] = std::min(min_activations_[t], output_[index]);
-                    max_activations_[t] = std::max(max_activations_[t], output_[index]);
+                    thread_stat_state.min_activations_[t] = std::min(thread_stat_state.min_activations_[t], output_[index]);
+                    thread_stat_state.max_activations_[t] = std::max(thread_stat_state.max_activations_[t], output_[index]);
                 }
             }
 
 #endif
-
-            return output_.data();
         }
 
         // backpropagation
-        void backpropagate(ThreadPool& thread_pool,
+        void backpropagate(Thread& th,
                            const LearnFloatType* gradients,
-                           LearnFloatType learning_rate) {
+                           uint64_t offset,
+                           uint64_t count) {
 
-            (void)thread_pool;
-
-            const LearnFloatType local_learning_rate =
-                learning_rate * learning_rate_scale_;
+            auto& thread_stat_state = thread_stat_states_[th.thread_idx()];
+            auto& thread_bias_state = thread_bias_states_[th.thread_idx()];
 
 #if defined (USE_SSE2)
 
@@ -290,111 +321,134 @@ namespace Eval::NNUE {
             const __m128 kZero4 = _mm_set1_ps(+kZero);
             const __m128 kOne4 = _mm_set1_ps(+kOne);
 
-            const IndexType total_size = batch_->size() * kOutputDimensions;
-
-            for (IndexType i = 0; i < total_size; i += 16)
+            for (IndexType b = offset; b < offset + count; ++b)
             {
-                __m128 out0 = _mm_loadu_ps(&output_[i + 0]);
-                __m128 out1 = _mm_loadu_ps(&output_[i + 4]);
-                __m128 out2 = _mm_loadu_ps(&output_[i + 8]);
-                __m128 out3 = _mm_loadu_ps(&output_[i + 12]);
+                const IndexType batch_offset = kOutputDimensions * b;
+                for (IndexType i = 0; i < kOutputDimensions; i += 16)
+                {
+                    __m128 out0 = _mm_loadu_ps(&output_[batch_offset + i + 0]);
+                    __m128 out1 = _mm_loadu_ps(&output_[batch_offset + i + 4]);
+                    __m128 out2 = _mm_loadu_ps(&output_[batch_offset + i + 8]);
+                    __m128 out3 = _mm_loadu_ps(&output_[batch_offset + i + 12]);
 
-                __m128 clipped0 = _mm_or_ps(_mm_cmple_ps(out0, kZero4), _mm_cmpge_ps(out0, kOne4));
-                __m128 clipped1 = _mm_or_ps(_mm_cmple_ps(out1, kZero4), _mm_cmpge_ps(out1, kOne4));
-                __m128 clipped2 = _mm_or_ps(_mm_cmple_ps(out2, kZero4), _mm_cmpge_ps(out2, kOne4));
-                __m128 clipped3 = _mm_or_ps(_mm_cmple_ps(out3, kZero4), _mm_cmpge_ps(out3, kOne4));
+                    __m128 clipped0 = _mm_or_ps(_mm_cmple_ps(out0, kZero4), _mm_cmpge_ps(out0, kOne4));
+                    __m128 clipped1 = _mm_or_ps(_mm_cmple_ps(out1, kZero4), _mm_cmpge_ps(out1, kOne4));
+                    __m128 clipped2 = _mm_or_ps(_mm_cmple_ps(out2, kZero4), _mm_cmpge_ps(out2, kOne4));
+                    __m128 clipped3 = _mm_or_ps(_mm_cmple_ps(out3, kZero4), _mm_cmpge_ps(out3, kOne4));
 
-                __m128 grad0 = _mm_loadu_ps(&gradients[i + 0]);
-                __m128 grad1 = _mm_loadu_ps(&gradients[i + 4]);
-                __m128 grad2 = _mm_loadu_ps(&gradients[i + 8]);
-                __m128 grad3 = _mm_loadu_ps(&gradients[i + 12]);
+                    __m128 grad0 = _mm_loadu_ps(&gradients[batch_offset + i + 0]);
+                    __m128 grad1 = _mm_loadu_ps(&gradients[batch_offset + i + 4]);
+                    __m128 grad2 = _mm_loadu_ps(&gradients[batch_offset + i + 8]);
+                    __m128 grad3 = _mm_loadu_ps(&gradients[batch_offset + i + 12]);
 
-                grad0 = _mm_andnot_ps(clipped0, grad0);
-                grad1 = _mm_andnot_ps(clipped1, grad1);
-                grad2 = _mm_andnot_ps(clipped2, grad2);
-                grad3 = _mm_andnot_ps(clipped3, grad3);
+                    grad0 = _mm_andnot_ps(clipped0, grad0);
+                    grad1 = _mm_andnot_ps(clipped1, grad1);
+                    grad2 = _mm_andnot_ps(clipped2, grad2);
+                    grad3 = _mm_andnot_ps(clipped3, grad3);
 
-                _mm_storeu_ps(&gradients_[i + 0], grad0);
-                _mm_storeu_ps(&gradients_[i + 4], grad1);
-                _mm_storeu_ps(&gradients_[i + 8], grad2);
-                _mm_storeu_ps(&gradients_[i + 12], grad3);
+                    _mm_storeu_ps(&gradients_[batch_offset + i + 0], grad0);
+                    _mm_storeu_ps(&gradients_[batch_offset + i + 4], grad1);
+                    _mm_storeu_ps(&gradients_[batch_offset + i + 8], grad2);
+                    _mm_storeu_ps(&gradients_[batch_offset + i + 12], grad3);
 
-                const int clipped_mask =
-                    (_mm_movemask_ps(clipped0) << 0)
-                    | (_mm_movemask_ps(clipped1) << 4)
-                    | (_mm_movemask_ps(clipped2) << 8)
-                    | (_mm_movemask_ps(clipped3) << 12);
+                    const int clipped_mask =
+                        (_mm_movemask_ps(clipped0) << 0)
+                        | (_mm_movemask_ps(clipped1) << 4)
+                        | (_mm_movemask_ps(clipped2) << 8)
+                        | (_mm_movemask_ps(clipped3) << 12);
 
-                num_clipped_ += popcount(clipped_mask);
+                    thread_stat_state.num_clipped_ += popcount(clipped_mask);
+                }
             }
         }
 
 #else
 
-            for (IndexType b = 0; b < batch_->size(); ++b) {
+            for (IndexType b = offset; b < offset + count; ++b) {
                 const IndexType batch_offset = kOutputDimensions * b;
                 for (IndexType i = 0; i < kOutputDimensions; ++i) {
                     const IndexType index = batch_offset + i;
                    const bool clipped = (output_[index] <= kZero) | (output_[index] >= kOne);
                    gradients_[index] = gradients[index] * !clipped;
-                    num_clipped_ += clipped;
+                    thread_stat_state.num_clipped_ += clipped;
                 }
             }
 
 #endif
 
-            num_total_ += batch_->size() * kOutputDimensions;
+            thread_stat_state.num_total_ += count * kOutputDimensions;
+
+#if defined(USE_BLAS)
+
+            for (IndexType b = offset; b < offset + count; ++b) {
+                const IndexType batch_offset = kOutputDimensions * b;
+                for (IndexType c = 0; c < 2; ++c) {
+                    const IndexType output_offset = batch_offset + kHalfDimensions * c;
+                    cblas_saxpy(
+                        kHalfDimensions, 1.0,
+                        &gradients_[output_offset], 1, thread_bias_state.biases_diff_, 1
+                    );
+                }
+            }
+
+#else
+
+            for (IndexType b = offset; b < offset + count; ++b) {
+                const IndexType batch_offset = kOutputDimensions * b;
+                for (IndexType c = 0; c < 2; ++c) {
+                    const IndexType output_offset = batch_offset + kHalfDimensions * c;
+                    Blas::saxpy(
+                        kHalfDimensions, 1.0,
+                        &gradients_[output_offset], 1, thread_bias_state.biases_diff_, 1
+                    );
+                }
+            }
+
+#endif
+        }
+
+        void reduce_thread_stat_state()
+        {
+            for (IndexType i = 1; i < thread_stat_states_.size(); ++i)
+            {
+                thread_stat_states_[0] += thread_stat_states_[i];
+            }
+        }
+
+        void reduce_thread_bias_state()
+        {
+            for (IndexType i = 1; i < thread_bias_states_.size(); ++i)
+            {
+                thread_bias_states_[0] += thread_bias_states_[i];
+            }
+        }
+
+        void step_end(ThreadPool& thread_pool, LearnFloatType learning_rate) {
+
+            const LearnFloatType local_learning_rate =
+                learning_rate * learning_rate_scale_;
 
             // Since the weight matrix updates only the columns corresponding to the features that appeared in the input,
             // Correct the learning rate and adjust the scale without using momentum
             const LearnFloatType effective_learning_rate =
                 static_cast<LearnFloatType>(local_learning_rate / (1.0 - momentum_));
 
+            reduce_thread_bias_state();
+
+            auto& main_thread_state = thread_bias_states_[0];
+
 #if defined(USE_BLAS)
 
-            cblas_sscal(
-                kHalfDimensions, momentum_, biases_diff_, 1
-            );
-
-            for (IndexType b = 0; b < batch_->size(); ++b) {
-                const IndexType batch_offset = kOutputDimensions * b;
-                for (IndexType c = 0; c < 2; ++c) {
-                    const IndexType output_offset = batch_offset + kHalfDimensions * c;
-                    cblas_saxpy(
-                        kHalfDimensions, 1.0,
-                        &gradients_[output_offset], 1, biases_diff_, 1
-                    );
-                }
-            }
-
             cblas_saxpy(
                 kHalfDimensions, -local_learning_rate,
-                biases_diff_, 1, biases_, 1
+                main_thread_state.biases_diff_, 1, biases_, 1
             );
 
 #else
 
-            Blas::sscal(
-                thread_pool,
-                kHalfDimensions, momentum_, biases_diff_, 1
-            );
-
-            for (IndexType b = 0; b < batch_->size(); ++b) {
-                const IndexType batch_offset = kOutputDimensions * b;
-                for (IndexType c = 0; c < 2; ++c) {
-                    const IndexType output_offset = batch_offset + kHalfDimensions * c;
-                    Blas::saxpy(
-                        thread_pool,
-                        kHalfDimensions, 1.0,
-                        &gradients_[output_offset], 1, biases_diff_, 1
-                    );
-                }
-            }
-
             Blas::saxpy(
-                thread_pool,
                 kHalfDimensions, -local_learning_rate,
-                biases_diff_, 1, biases_, 1
+                main_thread_state.biases_diff_, 1, biases_, 1
             );
 
 #endif
@@ -464,7 +518,6 @@ namespace Eval::NNUE {
             target_layer_(target_layer),
             biases_(),
             weights_(),
-            biases_diff_(),
             momentum_(0.2),
             learning_rate_scale_(1.0) {
 
@@ -502,16 +555,8 @@ namespace Eval::NNUE {
         }
 
         void reset_stats() {
-            min_pre_activation_ = std::numeric_limits<LearnFloatType>::max();
-            max_pre_activation_ = std::numeric_limits<LearnFloatType>::lowest();
-
-            std::fill(std::begin(min_activations_), std::end(min_activations_),
-                      std::numeric_limits<LearnFloatType>::max());
-            std::fill(std::begin(max_activations_), std::end(max_activations_),
-                      std::numeric_limits<LearnFloatType>::lowest());
-
-            num_clipped_ = 0;
-            num_total_ = 0;
+            for (auto& state : thread_stat_states_)
+                state.reset();
         }
 
         // read parameterized integer
@@ -528,9 +573,10 @@ namespace Eval::NNUE {
                     target_layer_->weights_[i] / kWeightScale);
             }
 
-            std::fill(std::begin(biases_diff_), std::end(biases_diff_), +kZero);
-
             reset_stats();
+
+            for (auto& state : thread_bias_states_)
+                state.reset();
         }
 
         // Set the weight corresponding to the feature that does not appear in the learning data to 0
@@ -552,10 +598,14 @@ namespace Eval::NNUE {
                 std::numeric_limits::max() / kWeightScale;
 
+            reduce_thread_stat_state();
+
+            auto& main_thread_state = thread_stat_states_[0];
+
             const auto largest_min_activation = *std::max_element(
-                std::begin(min_activations_), std::end(min_activations_));
+                std::begin(main_thread_state.min_activations_), std::end(main_thread_state.min_activations_));
             const auto smallest_max_activation = *std::min_element(
-                std::begin(max_activations_), std::end(max_activations_));
+                std::begin(main_thread_state.max_activations_), std::end(main_thread_state.max_activations_));
 
             double abs_bias_sum = 0.0;
             double abs_weight_sum = 0.0;
@@ -578,8 +628,8 @@ namespace Eval::NNUE {
                 << std::endl;
 
             out << " - (min, max) of pre-activations = "
-                << min_pre_activation_ << ", "
-                << max_pre_activation_ << " (limit = "
+                << main_thread_state.min_pre_activation_ << ", "
+                << main_thread_state.max_pre_activation_ << " (limit = "
                 << kPreActivationLimit << ")"
                 << std::endl;
 
@@ -590,7 +640,7 @@ namespace Eval::NNUE {
             out << " - avg_abs_bias = " << abs_bias_sum / std::size(biases_) << std::endl;
             out << " - avg_abs_weight = " << abs_weight_sum / std::size(weights_) << std::endl;
 
-            out << " - clipped " << static_cast<double>(num_clipped_) / num_total_ * 100.0 << "% of outputs"
+            out << " - clipped " << static_cast<double>(main_thread_state.num_clipped_) / main_thread_state.num_total_ * 100.0 << "% of outputs"
                 << std::endl;
 
             out.unlock();
@@ -620,7 +670,6 @@ namespace Eval::NNUE {
         // layer to learn
         LayerType* const target_layer_;
 
-        IndexType num_clipped_;
         IndexType num_total_;
 
         // parameter
@@ -629,7 +678,6 @@ namespace Eval::NNUE {
         LearnFloatType weights_[kHalfDimensions * kInputDimensions];
 
         // Buffer used for updating parameters
-        alignas(kCacheLineSize) LearnFloatType biases_diff_[kHalfDimensions];
         std::vector> gradients_;
 
         // Forward propagation buffer
@@ -643,11 +691,73 @@ namespace Eval::NNUE {
         LearnFloatType momentum_;
         LearnFloatType learning_rate_scale_;
 
-        // Health check statistics
-        LearnFloatType min_pre_activation_;
-        LearnFloatType max_pre_activation_;
-        alignas(kCacheLineSize) LearnFloatType min_activations_[kHalfDimensions];
-        alignas(kCacheLineSize) LearnFloatType max_activations_[kHalfDimensions];
+        struct alignas(kCacheLineSize) ThreadStatState
+        {
+            alignas(kCacheLineSize) LearnFloatType min_activations_[kHalfDimensions];
+            alignas(kCacheLineSize) LearnFloatType max_activations_[kHalfDimensions];
+            LearnFloatType min_pre_activation_;
+            LearnFloatType max_pre_activation_;
+            IndexType num_clipped_;
+            IndexType num_total_;
+
+            ThreadStatState() { reset(); }
+
+            ThreadStatState& operator+=(const ThreadStatState& other)
+            {
+                for (IndexType i = 0; i < kHalfDimensions; ++i)
+                {
+                    min_activations_[i] = std::min(min_activations_[i], other.min_activations_[i]);
+                }
+
+                for (IndexType i = 0; i < kHalfDimensions; ++i)
+                {
+                    max_activations_[i] = std::max(max_activations_[i], other.max_activations_[i]);
+                }
+
+                min_pre_activation_ = std::min(min_pre_activation_, other.min_pre_activation_);
+                max_pre_activation_ = std::max(max_pre_activation_, other.max_pre_activation_);
+
+                num_clipped_ += other.num_clipped_;
+                num_total_ += other.num_total_;
+
+                return *this;
+            }
+
+            void reset()
+            {
+                std::fill(std::begin(min_activations_), std::end(min_activations_), std::numeric_limits<LearnFloatType>::max());
+                std::fill(std::begin(max_activations_), std::end(max_activations_), std::numeric_limits<LearnFloatType>::lowest());
+                min_pre_activation_ = std::numeric_limits<LearnFloatType>::max();
+                max_pre_activation_ = std::numeric_limits<LearnFloatType>::lowest();
+                num_clipped_ = 0;
+                num_total_ = 0;
+            }
+        };
+
+        struct alignas(kCacheLineSize) ThreadBiasState
+        {
+            alignas(kCacheLineSize) LearnFloatType biases_diff_[kHalfDimensions];
+
+            ThreadBiasState() { reset(); }
+
+            ThreadBiasState& operator+=(const ThreadBiasState& other)
+            {
+                for (IndexType i = 0; i < kHalfDimensions; ++i)
+                {
+                    biases_diff_[i] += other.biases_diff_[i];
+                }
+
+                return *this;
+            }
+
+            void reset()
+            {
+                std::fill(std::begin(biases_diff_), std::end(biases_diff_), 0.0f);
+            }
+        };
+
+        std::vector> thread_stat_states_;
+        std::vector> thread_bias_states_;
     };
 
 }  // namespace Eval::NNUE
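Illustration (not part of the patch): the change above makes each worker thread accumulate its health statistics and bias deltas into its own slot (thread_stat_states_ / thread_bias_states_) over a [offset, offset + count) slice of the batch, and has the main thread fold the slots together afterwards (reduce_thread_stat_state / reduce_thread_bias_state, used from step_end and the health check). The self-contained C++ sketch below shows that accumulate-then-reduce pattern under simplified assumptions; LocalStats, worker_count and the synthetic data are invented for this example and are not part of the Stockfish NNUE trainer.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <limits>
#include <thread>
#include <vector>

// Per-thread accumulator, analogous in spirit to ThreadStatState above.
struct LocalStats {
    float min_value = std::numeric_limits<float>::max();
    float max_value = std::numeric_limits<float>::lowest();
    std::uint64_t num_clipped = 0;

    LocalStats& operator+=(const LocalStats& other) {
        min_value = std::min(min_value, other.min_value);
        max_value = std::max(max_value, other.max_value);
        num_clipped += other.num_clipped;
        return *this;
    }
};

int main() {
    const std::size_t worker_count = 4;

    // Synthetic "activations" standing in for the trainer's output_ buffer.
    std::vector<float> outputs(1024);
    for (std::size_t i = 0; i < outputs.size(); ++i)
        outputs[i] = static_cast<float>(i % 7) / 3.0f - 0.5f;

    // One slot per worker: each thread writes only to its own entry,
    // so no locking is needed during the parallel phase.
    std::vector<LocalStats> per_thread(worker_count);
    std::vector<std::thread> workers;

    const std::size_t chunk = outputs.size() / worker_count;
    for (std::size_t t = 0; t < worker_count; ++t) {
        workers.emplace_back([&, t] {
            const std::size_t begin = t * chunk;
            const std::size_t end =
                (t + 1 == worker_count) ? outputs.size() : begin + chunk;
            for (std::size_t i = begin; i < end; ++i) {
                per_thread[t].min_value = std::min(per_thread[t].min_value, outputs[i]);
                per_thread[t].max_value = std::max(per_thread[t].max_value, outputs[i]);
                per_thread[t].num_clipped += (outputs[i] <= 0.0f) | (outputs[i] >= 1.0f);
            }
        });
    }
    for (auto& w : workers)
        w.join();

    // Serial reduction into slot 0, mirroring reduce_thread_stat_state().
    for (std::size_t t = 1; t < worker_count; ++t)
        per_thread[0] += per_thread[t];

    std::cout << "min=" << per_thread[0].min_value
              << " max=" << per_thread[0].max_value
              << " clipped=" << per_thread[0].num_clipped << '\n';
    return 0;
}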