Replicate network weights only to used NUMA nodes

On a system with multiple NUMA nodes, this patch avoids unneeded replicated
(e.g. 8x for a single threaded run), reducting memory use in that case.

Lazy initialization forced before search.

Passed STC:
https://tests.stockfishchess.org/tests/view/66a28c524ff211be9d4ecdd4
LLR: 2.96 (-2.94,2.94) <-1.75,0.25>
Total: 691776 W: 179429 L: 179927 D: 332420
Ptnml(0-2): 2573, 79370, 182547, 78778, 2620

closes https://github.com/official-stockfish/Stockfish/pull/5515

No functional change
This commit is contained in:
Tomasz Sobczyk
2024-07-25 14:37:08 +02:00
committed by Joost VandeVondele
parent 2343f71f3f
commit 8e560c4fd3
7 changed files with 152 additions and 16 deletions

View File

@@ -27,6 +27,7 @@
#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <sstream>
#include <string>
@@ -1136,6 +1137,117 @@ class NumaReplicated: public NumaReplicatedBase {
}
};
// We force boxing with a unique_ptr. If this becomes an issue due to added
// indirection we may need to add an option for a custom boxing type.
template<typename T>
class LazyNumaReplicated: public NumaReplicatedBase {
public:
using ReplicatorFuncType = std::function<T(const T&)>;
LazyNumaReplicated(NumaReplicationContext& ctx) :
NumaReplicatedBase(ctx) {
prepare_replicate_from(T{});
}
LazyNumaReplicated(NumaReplicationContext& ctx, T&& source) :
NumaReplicatedBase(ctx) {
prepare_replicate_from(std::move(source));
}
LazyNumaReplicated(const LazyNumaReplicated&) = delete;
LazyNumaReplicated(LazyNumaReplicated&& other) noexcept :
NumaReplicatedBase(std::move(other)),
instances(std::exchange(other.instances, {})) {}
LazyNumaReplicated& operator=(const LazyNumaReplicated&) = delete;
LazyNumaReplicated& operator=(LazyNumaReplicated&& other) noexcept {
NumaReplicatedBase::operator=(*this, std::move(other));
instances = std::exchange(other.instances, {});
return *this;
}
LazyNumaReplicated& operator=(T&& source) {
prepare_replicate_from(std::move(source));
return *this;
}
~LazyNumaReplicated() override = default;
const T& operator[](NumaReplicatedAccessToken token) const {
assert(token.get_numa_index() < instances.size());
ensure_present(token.get_numa_index());
return *(instances[token.get_numa_index()]);
}
const T& operator*() const { return *(instances[0]); }
const T* operator->() const { return instances[0].get(); }
template<typename FuncT>
void modify_and_replicate(FuncT&& f) {
auto source = std::move(instances[0]);
std::forward<FuncT>(f)(*source);
prepare_replicate_from(std::move(*source));
}
void on_numa_config_changed() override {
// Use the first one as the source. It doesn't matter which one we use,
// because they all must be identical, but the first one is guaranteed to exist.
auto source = std::move(instances[0]);
prepare_replicate_from(std::move(*source));
}
private:
mutable std::vector<std::unique_ptr<T>> instances;
mutable std::mutex mutex;
void ensure_present(NumaIndex idx) const {
assert(idx < instances.size());
if (instances[idx] != nullptr)
return;
assert(idx != 0);
std::unique_lock<std::mutex> lock(mutex);
// Check again for races.
if (instances[idx] != nullptr)
return;
const NumaConfig& cfg = get_numa_config();
cfg.execute_on_numa_node(
idx, [this, idx]() { instances[idx] = std::make_unique<T>(*instances[0]); });
}
void prepare_replicate_from(T&& source) {
instances.clear();
const NumaConfig& cfg = get_numa_config();
if (cfg.requires_memory_replication())
{
assert(cfg.num_numa_nodes() > 0);
// We just need to make sure the first instance is there.
// Note that we cannot move here as we need to reallocate the data
// on the correct NUMA node.
cfg.execute_on_numa_node(
0, [this, &source]() { instances.emplace_back(std::make_unique<T>(source)); });
// Prepare others for lazy init.
instances.resize(cfg.num_numa_nodes());
}
else
{
assert(cfg.num_numa_nodes() == 1);
// We take advantage of the fact that replication is not required
// and reuse the source value, avoiding one copy operation.
instances.emplace_back(std::make_unique<T>(std::move(source)));
}
}
};
class NumaReplicationContext {
public:
NumaReplicationContext(NumaConfig&& cfg) :