Refactor Stats Array

* Limit use of `StatsEntry` wrapper to arithmetic types
* Generalize `Stats` to `MultiArray` by discarding the template parameter `D`
* Allow `MultiArray::fill` to take any type assignable to element type
* Remove now-unused operator overloads on `StatsEntry`

closes https://github.com/official-stockfish/Stockfish/pull/5750

No functional change
This commit is contained in:
Shawn Xu
2025-01-05 17:11:22 -08:00
committed by Disservin
parent c47e6fcf84
commit 5370c3035e
4 changed files with 117 additions and 72 deletions

View File

@@ -28,6 +28,7 @@
#include <limits>
#include <type_traits> // IWYU pragma: keep
#include "misc.h"
#include "position.h"
namespace Stockfish {
@@ -66,13 +67,17 @@ inline int non_pawn_index(const Position& pos) {
return pos.non_pawn_key(c) & (CORRECTION_HISTORY_SIZE - 1);
}
// StatsEntry stores the stat table value. It is usually a number but could
// be a move or even a nested history. We use a class instead of a naked value
// to directly call history update operator<<() on the entry so to use stats
// tables at caller sites as simple multi-dim arrays.
// StatsEntry is the container of various numerical statistics. We use a class
// instead of a naked value to directly call history update operator<<() on
// the entry. The first template parameter T is the base type of the array,
// and the second template parameter D limits the range of updates in [-D, D]
// when we update values with the << operator
template<typename T, int D>
class StatsEntry {
static_assert(std::is_arithmetic<T>::value, "Not an arithmetic type");
static_assert(D <= std::numeric_limits<T>::max(), "D overflows T");
T entry;
public:
@@ -80,13 +85,9 @@ class StatsEntry {
entry = v;
return *this;
}
T* operator&() { return &entry; }
T* operator->() { return &entry; }
operator const T&() const { return entry; }
void operator<<(int bonus) {
static_assert(D <= std::numeric_limits<T>::max(), "D overflows T");
// Make sure that bonus is in range [-D, D]
int clampedBonus = std::clamp(bonus, -D, D);
entry += clampedBonus - entry * std::abs(clampedBonus) / D;
@@ -95,87 +96,39 @@ class StatsEntry {
}
};
template<typename T, int D, std::size_t Size, std::size_t... Sizes>
struct StatsHelper;
// Stats is a generic N-dimensional array used to store various statistics.
// The first template parameter T is the base type of the array, and the second
// template parameter D limits the range of updates in [-D, D] when we update
// values with the << operator, while the last parameters (Size and Sizes)
// encode the dimensions of the array.
template<typename T, int D, std::size_t Size, std::size_t... Sizes>
class Stats {
using child_type = typename StatsHelper<T, D, Size, Sizes...>::child_type;
using array_type = std::array<child_type, Size>;
array_type data;
public:
using size_type = typename array_type::size_type;
auto& operator[](size_type index) { return data[index]; }
const auto& operator[](size_type index) const { return data[index]; }
auto begin() { return data.begin(); }
auto end() { return data.end(); }
auto begin() const { return data.cbegin(); }
auto end() const { return data.cend(); }
auto cbegin() const { return data.cbegin(); }
auto cend() const { return data.cend(); }
void fill(const T& v) {
for (auto& ele : data)
{
if constexpr (sizeof...(Sizes) == 0)
ele = v;
else
ele.fill(v);
}
}
};
template<typename T, int D, std::size_t Size, std::size_t... Sizes>
struct StatsHelper {
using child_type = Stats<T, D, Sizes...>;
};
template<typename T, int D, std::size_t Size>
struct StatsHelper<T, D, Size> {
using child_type = StatsEntry<T, D>;
};
// In stats table, D=0 means that the template parameter is not used
enum StatsParams {
NOT_USED = 0
};
enum StatsType {
NoCaptures,
Captures
};
template<typename T, int D, std::size_t... Sizes>
using Stats = MultiArray<StatsEntry<T, D>, Sizes...>;
// ButterflyHistory records how often quiet moves have been successful or unsuccessful
// during the current search, and is used for reduction and move ordering decisions.
// It uses 2 tables (one for each color) indexed by the move's from and to squares,
// see https://www.chessprogramming.org/Butterfly_Boards (~11 elo)
using ButterflyHistory = Stats<int16_t, 7183, COLOR_NB, int(SQUARE_NB) * int(SQUARE_NB)>;
using ButterflyHistory = Stats<std::int16_t, 7183, COLOR_NB, int(SQUARE_NB) * int(SQUARE_NB)>;
// LowPlyHistory is adressed by play and move's from and to squares, used
// to improve move ordering near the root
using LowPlyHistory = Stats<int16_t, 7183, LOW_PLY_HISTORY_SIZE, int(SQUARE_NB) * int(SQUARE_NB)>;
using LowPlyHistory =
Stats<std::int16_t, 7183, LOW_PLY_HISTORY_SIZE, int(SQUARE_NB) * int(SQUARE_NB)>;
// CapturePieceToHistory is addressed by a move's [piece][to][captured piece type]
using CapturePieceToHistory = Stats<int16_t, 10692, PIECE_NB, SQUARE_NB, PIECE_TYPE_NB>;
using CapturePieceToHistory = Stats<std::int16_t, 10692, PIECE_NB, SQUARE_NB, PIECE_TYPE_NB>;
// PieceToHistory is like ButterflyHistory but is addressed by a move's [piece][to]
using PieceToHistory = Stats<int16_t, 30000, PIECE_NB, SQUARE_NB>;
using PieceToHistory = Stats<std::int16_t, 30000, PIECE_NB, SQUARE_NB>;
// ContinuationHistory is the combined history of a given pair of moves, usually
// the current one given a previous one. The nested history table is based on
// PieceToHistory instead of ButterflyBoards.
// (~63 elo)
using ContinuationHistory = Stats<PieceToHistory, NOT_USED, PIECE_NB, SQUARE_NB>;
using ContinuationHistory = MultiArray<PieceToHistory, PIECE_NB, SQUARE_NB>;
// PawnHistory is addressed by the pawn structure and a move's [piece][to]
using PawnHistory = Stats<int16_t, 8192, PAWN_HISTORY_SIZE, PIECE_NB, SQUARE_NB>;
using PawnHistory = Stats<std::int16_t, 8192, PAWN_HISTORY_SIZE, PIECE_NB, SQUARE_NB>;
// Correction histories record differences between the static evaluation of
// positions and their search score. It is used to improve the static evaluation
@@ -190,23 +143,27 @@ enum CorrHistType {
Continuation, // Combined history of move pairs
};
namespace Detail {
template<CorrHistType _>
struct CorrHistTypedef {
using type = Stats<int16_t, CORRECTION_HISTORY_LIMIT, COLOR_NB, CORRECTION_HISTORY_SIZE>;
using type = Stats<std::int16_t, CORRECTION_HISTORY_LIMIT, COLOR_NB, CORRECTION_HISTORY_SIZE>;
};
template<>
struct CorrHistTypedef<PieceTo> {
using type = Stats<int16_t, CORRECTION_HISTORY_LIMIT, PIECE_NB, SQUARE_NB>;
using type = Stats<std::int16_t, CORRECTION_HISTORY_LIMIT, PIECE_NB, SQUARE_NB>;
};
template<>
struct CorrHistTypedef<Continuation> {
using type = Stats<CorrHistTypedef<PieceTo>::type, NOT_USED, PIECE_NB, SQUARE_NB>;
using type = MultiArray<CorrHistTypedef<PieceTo>::type, PIECE_NB, SQUARE_NB>;
};
}
template<CorrHistType T>
using CorrectionHistory = typename CorrHistTypedef<T>::type;
using CorrectionHistory = typename Detail::CorrHistTypedef<T>::type;
} // namespace Stockfish

View File

@@ -20,6 +20,7 @@
#define MISC_H_INCLUDED
#include <algorithm>
#include <array>
#include <cassert>
#include <chrono>
#include <cstddef>
@@ -142,6 +143,92 @@ class ValueList {
};
template<typename T, std::size_t Size, std::size_t... Sizes>
class MultiArray;
namespace Detail {
template<typename T, std::size_t Size, std::size_t... Sizes>
struct MultiArrayHelper {
using ChildType = MultiArray<T, Sizes...>;
};
template<typename T, std::size_t Size>
struct MultiArrayHelper<T, Size> {
using ChildType = T;
};
}
// MultiArray is a generic N-dimensional array.
// The template parameters (Size and Sizes) encode the dimensions of the array.
template<typename T, std::size_t Size, std::size_t... Sizes>
class MultiArray {
using ChildType = typename Detail::MultiArrayHelper<T, Size, Sizes...>::ChildType;
using ArrayType = std::array<ChildType, Size>;
ArrayType data_;
public:
using value_type = typename ArrayType::value_type;
using size_type = typename ArrayType::size_type;
using difference_type = typename ArrayType::difference_type;
using reference = typename ArrayType::reference;
using const_reference = typename ArrayType::const_reference;
using pointer = typename ArrayType::pointer;
using const_pointer = typename ArrayType::const_pointer;
using iterator = typename ArrayType::iterator;
using const_iterator = typename ArrayType::const_iterator;
using reverse_iterator = typename ArrayType::reverse_iterator;
using const_reverse_iterator = typename ArrayType::const_reverse_iterator;
constexpr auto& at(size_type index) noexcept { return data_.at(index); }
constexpr const auto& at(size_type index) const noexcept { return data_.at(index); }
constexpr auto& operator[](size_type index) noexcept { return data_[index]; }
constexpr const auto& operator[](size_type index) const noexcept { return data_[index]; }
constexpr auto& front() noexcept { return data_.front(); }
constexpr const auto& front() const noexcept { return data_.front(); }
constexpr auto& back() noexcept { return data_.back(); }
constexpr const auto& back() const noexcept { return data_.back(); }
auto* data() { return data_.data(); }
const auto* data() const { return data_.data(); }
constexpr auto begin() noexcept { return data_.begin(); }
constexpr auto end() noexcept { return data_.end(); }
constexpr auto begin() const noexcept { return data_.begin(); }
constexpr auto end() const noexcept { return data_.end(); }
constexpr auto cbegin() const noexcept { return data_.cbegin(); }
constexpr auto cend() const noexcept { return data_.cend(); }
constexpr auto rbegin() noexcept { return data_.rbegin(); }
constexpr auto rend() noexcept { return data_.rend(); }
constexpr auto rbegin() const noexcept { return data_.rbegin(); }
constexpr auto rend() const noexcept { return data_.rend(); }
constexpr auto crbegin() const noexcept { return data_.crbegin(); }
constexpr auto crend() const noexcept { return data_.crend(); }
constexpr bool empty() const noexcept { return data_.empty(); }
constexpr size_type size() const noexcept { return data_.size(); }
constexpr size_type max_size() const noexcept { return data_.max_size(); }
template<typename U>
void fill(const U& v) {
static_assert(std::is_assignable_v<T, U>, "Cannot assign fill value to entry type");
for (auto& ele : data_)
{
if constexpr (sizeof...(Sizes) == 0)
ele = v;
else
ele.fill(v);
}
}
constexpr void swap(MultiArray<T, Size, Sizes...>& other) noexcept { data_.swap(other.data_); }
};
// xorshift64star Pseudo-Random Number Generator
// This class is based on original code written and dedicated
// to the public domain by Sebastiano Vigna (2014).

View File

@@ -22,6 +22,7 @@
#include <limits>
#include "bitboard.h"
#include "misc.h"
#include "position.h"
namespace Stockfish {

View File

@@ -510,13 +510,13 @@ void Search::Worker::clear() {
for (auto& to : continuationCorrectionHistory)
for (auto& h : to)
h->fill(0);
h.fill(0);
for (bool inCheck : {false, true})
for (StatsType c : {NoCaptures, Captures})
for (auto& to : continuationHistory[inCheck][c])
for (auto& h : to)
h->fill(-427);
h.fill(-427);
for (size_t i = 1; i < reductions.size(); ++i)
reductions[i] = int(19.43 * std::log(i));