Prepare trainer affine transform.

This commit is contained in:
Tomasz Sobczyk
2020-11-22 19:23:30 +01:00
committed by nodchip
parent 4ea8572b6d
commit 0d4b803b08

View File

@@ -91,19 +91,52 @@ namespace Eval::NNUE {
quantize_parameters();
}
// forward propagation
const LearnFloatType* propagate(ThreadPool& thread_pool, const std::vector<Example>& batch) {
if (output_.size() < kOutputDimensions * batch.size()) {
output_.resize(kOutputDimensions * batch.size());
gradients_.resize(kInputDimensions * batch.size());
const LearnFloatType* step_start(ThreadPool& thread_pool, const std::vector<Example>& combined_batch)
{
if (output_.size() < kOutputDimensions * combined_batch.size()) {
output_.resize(kOutputDimensions * combined_batch.size());
gradients_.resize(kInputDimensions * combined_batch.size());
}
batch_size_ = static_cast<IndexType>(batch.size());
batch_input_ = previous_layer_trainer_->propagate(thread_pool, batch);
if (thread_states_.size() < thread_pool.size())
{
thread_states_.resize(thread_pool.size());
}
combined_batch_size_ = static_cast<IndexType>(combined_batch.size());
combined_batch_input_ = previous_layer_trainer_->step_start(thread_pool, combined_batch);
auto& main_thread_state = thread_states_[0];
#if defined(USE_BLAS)
for (IndexType b = 0; b < batch_size_; ++b) {
// update
cblas_sscal(
kOutputDimensions, momentum_, main_thread_state.biases_diff_, 1
);
#else
Blas::sscal(
kOutputDimensions, momentum_, main_thread_state.biases_diff_, 1
);
#endif
for (IndexType i = 1; i < thread_states_.size(); ++i)
thread_states_[i].reset_biases();
return output_.data();
}
// forward propagation
void propagate(Thread& th, const uint64_t offset, const uint64_t count) {
previous_layer_trainer_->propagate(th, offset, count);
#if defined(USE_BLAS)
for (IndexType b = offset; b < offset + count; ++b) {
const IndexType batch_offset = kOutputDimensions * b;
cblas_scopy(
kOutputDimensions, biases_, 1, &output_[batch_offset], 1
@@ -112,149 +145,151 @@ namespace Eval::NNUE {
cblas_sgemm(
CblasColMajor, CblasTrans, CblasNoTrans,
kOutputDimensions, batch_size_, kInputDimensions,
kOutputDimensions, count, kInputDimensions,
1.0,
weights_, kInputDimensions,
batch_input_, kInputDimensions,
combined_batch_input_ + offset * kInputDimensions, kInputDimensions,
1.0,
&output_[0], kOutputDimensions
&output_[offset * kOutputDimensions], kOutputDimensions
);
#else
for (IndexType b = 0; b < batch_size_; ++b) {
for (IndexType b = offset; b < offset + count; ++b) {
const IndexType batch_offset = kOutputDimensions * b;
Blas::scopy(
thread_pool,
kOutputDimensions, biases_, 1, &output_[batch_offset], 1
);
}
Blas::sgemm(
thread_pool,
Blas::MatrixLayout::ColMajor, Blas::MatrixTranspose::Trans, Blas::MatrixTranspose::NoTrans,
kOutputDimensions, batch_size_, kInputDimensions,
kOutputDimensions, count, kInputDimensions,
1.0,
weights_, kInputDimensions,
batch_input_, kInputDimensions,
combined_batch_input_ + offset * kInputDimensions, kInputDimensions,
1.0,
&output_[0], kOutputDimensions
&output_[offset * kOutputDimensions], kOutputDimensions
);
#endif
return output_.data();
}
// backpropagation
void backpropagate(ThreadPool& thread_pool,
void backpropagate(Thread& th,
const LearnFloatType* gradients,
LearnFloatType learning_rate) {
const LearnFloatType local_learning_rate =
learning_rate * learning_rate_scale_;
uint64_t offset,
uint64_t count) {
auto& thread_state = thread_states_[th.thread_idx()];
const auto momentum = th.thread_idx() == 0 ? momentum_ : 0.0f;
#if defined(USE_BLAS)
cblas_sgemm(
CblasColMajor, CblasNoTrans, CblasNoTrans,
kInputDimensions, batch_size_, kOutputDimensions,
kInputDimensions, count, kOutputDimensions,
1.0,
weights_, kInputDimensions,
gradients, kOutputDimensions,
gradients + offset * kOutputDimensions, kOutputDimensions,
0.0,
&gradients_[0], kInputDimensions
&gradients_[offset * kInputDimensions], kInputDimensions
);
// update
cblas_sscal(
kOutputDimensions, momentum_, biases_diff_, 1
);
for (IndexType b = 0; b < batch_size_; ++b) {
for (IndexType b = offset; b < offset + count; ++b) {
const IndexType batch_offset = kOutputDimensions * b;
cblas_saxpy(
kOutputDimensions, 1.0,
&gradients[batch_offset], 1, biases_diff_, 1
&gradients[batch_offset], 1, thread_state.biases_diff_, 1
);
}
cblas_sgemm(
CblasRowMajor, CblasTrans, CblasNoTrans,
kOutputDimensions, kInputDimensions, batch_size_,
kOutputDimensions, kInputDimensions, count,
1.0,
gradients, kOutputDimensions,
batch_input_, kInputDimensions,
momentum_,
weights_diff_, kInputDimensions
gradients + offset * kOutputDimensions, kOutputDimensions,
combined_batch_input_ + offset * kInputDimensions, kInputDimensions,
momentum,
thread_state.weights_diff_, kInputDimensions
);
#else
// backpropagate
Blas::sgemm(
thread_pool,
Blas::MatrixLayout::ColMajor, Blas::MatrixTranspose::NoTrans, Blas::MatrixTranspose::NoTrans,
kInputDimensions, batch_size_, kOutputDimensions,
kInputDimensions, count, kOutputDimensions,
1.0,
weights_, kInputDimensions,
gradients, kOutputDimensions,
gradients + offset * kOutputDimensions, kOutputDimensions,
0.0,
&gradients_[0], kInputDimensions
&gradients_[offset * kInputDimensions], kInputDimensions
);
Blas::sscal(
thread_pool,
kOutputDimensions, momentum_, biases_diff_, 1
);
for (IndexType b = 0; b < batch_size_; ++b) {
for (IndexType b = offset; b < offset + count; ++b) {
const IndexType batch_offset = kOutputDimensions * b;
Blas::saxpy(thread_pool, kOutputDimensions, 1.0,
&gradients[batch_offset], 1, biases_diff_, 1);
Blas::saxpy(kOutputDimensions, 1.0,
&gradients[batch_offset], 1, thread_state.biases_diff_, 1);
}
Blas::sgemm(
thread_pool,
Blas::MatrixLayout::RowMajor, Blas::MatrixTranspose::Trans, Blas::MatrixTranspose::NoTrans,
kOutputDimensions, kInputDimensions, batch_size_,
kOutputDimensions, kInputDimensions, count,
1.0,
gradients, kOutputDimensions,
batch_input_, kInputDimensions,
momentum_,
weights_diff_, kInputDimensions
gradients + offset * kOutputDimensions, kOutputDimensions,
combined_batch_input_ + offset * kInputDimensions, kInputDimensions,
momentum,
thread_state.weights_diff_, kInputDimensions
);
#endif
previous_layer_trainer_->backpropagate(th, gradients_.data(), offset, count);
}
void reduce_thread_state()
{
for (IndexType i = 1; i < thread_states_.size(); ++i)
{
thread_states_[0] += thread_states_[i];
}
}
void step_end(ThreadPool& thread_pool, LearnFloatType learning_rate)
{
const LearnFloatType local_learning_rate =
learning_rate * learning_rate_scale_;
reduce_thread_state();
auto& main_thread_state = thread_states_[0];
for (IndexType i = 0; i < kOutputDimensions; ++i) {
const double d = local_learning_rate * biases_diff_[i];
const double d = local_learning_rate * main_thread_state.biases_diff_[i];
biases_[i] -= d;
abs_biases_diff_sum_ += std::abs(d);
}
num_biases_diffs_ += kOutputDimensions;
for (IndexType i = 0; i < kOutputDimensions * kInputDimensions; ++i) {
const double d = local_learning_rate * weights_diff_[i];
const double d = local_learning_rate * main_thread_state.weights_diff_[i];
weights_[i] -= d;
abs_weights_diff_sum_ += std::abs(d);
}
num_weights_diffs_ += kOutputDimensions * kInputDimensions;
previous_layer_trainer_->backpropagate(thread_pool, gradients_.data(), learning_rate);
previous_layer_trainer_->step_end(thread_pool, learning_rate);
}
private:
// constructor
Trainer(LayerType* target_layer, FeatureTransformer* ft) :
batch_size_(0),
batch_input_(nullptr),
combined_batch_size_(0),
combined_batch_input_(nullptr),
previous_layer_trainer_(Trainer<PreviousLayer>::create(
&target_layer->previous_layer_, ft)),
target_layer_(target_layer),
biases_(),
weights_(),
biases_diff_(),
weights_diff_(),
momentum_(0.2),
learning_rate_scale_(1.0) {
@@ -335,10 +370,12 @@ namespace Eval::NNUE {
}
}
std::fill(std::begin(biases_diff_), std::end(biases_diff_),
static_cast<LearnFloatType>(0.0));
std::fill(std::begin(weights_diff_), std::end(weights_diff_),
static_cast<LearnFloatType>(0.0));
for (auto& state : thread_states_)
{
state.reset_weights();
state.reset_biases();
}
reset_stats();
}
@@ -365,7 +402,7 @@ namespace Eval::NNUE {
std::numeric_limits<typename LayerType::WeightType>::max() / kWeightScale;
// number of samples in mini-batch
IndexType batch_size_;
IndexType combined_batch_size_;
double abs_biases_diff_sum_;
double abs_weights_diff_sum_;
@@ -373,7 +410,7 @@ namespace Eval::NNUE {
uint64_t num_weights_diffs_;
// Input mini batch
const LearnFloatType* batch_input_;
const LearnFloatType* combined_batch_input_;
// Trainer of the previous layer
const std::shared_ptr<Trainer<PreviousLayer>> previous_layer_trainer_;
@@ -382,12 +419,44 @@ namespace Eval::NNUE {
LayerType* const target_layer_;
// parameter
struct alignas(kCacheLineSize) ThreadState
{
// Buffer used for updating parameters
alignas(kCacheLineSize) LearnFloatType biases_diff_[kOutputDimensions];
alignas(kCacheLineSize) LearnFloatType weights_diff_[kOutputDimensions * kInputDimensions];
ThreadState() { reset_weights(); reset_biases(); }
ThreadState& operator+=(const ThreadState& other)
{
for (IndexType i = 0; i < kOutputDimensions; ++i)
{
biases_diff_[i] += other.biases_diff_[i];
}
for (IndexType i = 0; i < kOutputDimensions * kInputDimensions; ++i)
{
weights_diff_[i] += other.weights_diff_[i];
}
return *this;
}
void reset_weights()
{
std::fill(std::begin(weights_diff_), std::end(weights_diff_), 0.0f);
}
void reset_biases()
{
std::fill(std::begin(biases_diff_), std::end(biases_diff_), 0.0f);
}
};
alignas(kCacheLineSize) LearnFloatType biases_[kOutputDimensions];
alignas(kCacheLineSize) LearnFloatType weights_[kOutputDimensions * kInputDimensions];
// Buffer used for updating parameters
alignas(kCacheLineSize) LearnFloatType biases_diff_[kOutputDimensions];
alignas(kCacheLineSize) LearnFloatType weights_diff_[kOutputDimensions * kInputDimensions];
std::vector<ThreadState, CacheLineAlignedAllocator<ThreadState>> thread_states_;
// Forward propagation buffer
std::vector<LearnFloatType, CacheLineAlignedAllocator<LearnFloatType>> output_;