mirror of
https://github.com/HChaZZY/Stockfish.git
synced 2025-12-25 11:36:51 +08:00
Avoid computing gradient for validation loss.
This commit is contained in:
@@ -310,7 +310,8 @@ namespace Learner
|
||||
return perf;
|
||||
}
|
||||
|
||||
[[maybe_unused]] static ValueWithGrad<double> get_loss_noob(Value shallow, Value teacher_signal, int result, int /* ply */)
|
||||
[[maybe_unused]] static ValueWithGrad<double> get_loss_noob(
|
||||
Value shallow, Value teacher_signal, int result, int /* ply */)
|
||||
{
|
||||
using namespace Learner::Autograd::UnivariateStatic;
|
||||
|
||||
@@ -328,7 +329,7 @@ namespace Learner
|
||||
return loss_.eval(args);
|
||||
}
|
||||
|
||||
static ValueWithGrad<double> get_loss_cross_entropy(Value shallow, Value teacher_signal, int result, int /* ply */)
|
||||
static auto& get_loss_cross_entropy_()
|
||||
{
|
||||
using namespace Learner::Autograd::UnivariateStatic;
|
||||
|
||||
@@ -338,18 +339,45 @@ namespace Learner
|
||||
static thread_local auto lambda_ = ConstantParameter<double, 3>{};
|
||||
static thread_local auto& loss_ = cross_entropy_(q_, p_, t_, lambda_);
|
||||
|
||||
auto args = std::tuple(
|
||||
return loss_;
|
||||
}
|
||||
|
||||
static auto get_loss_cross_entropy_args(
|
||||
Value shallow, Value teacher_signal, int result)
|
||||
{
|
||||
return std::tuple(
|
||||
(double)shallow,
|
||||
(double)teacher_signal,
|
||||
(double)result,
|
||||
calculate_lambda(teacher_signal)
|
||||
);
|
||||
}
|
||||
|
||||
// Evaluates the plain cross-entropy loss expression for a single sample.
//
// shallow, teacher_signal and result are forwarded to
// get_loss_cross_entropy_args(); the ply argument is accepted but not used.
// Returns the result of calling eval() on the loss expression.
static ValueWithGrad<double> get_loss_cross_entropy(
    Value shallow, Value teacher_signal, int result, int /* ply */)
{
    using namespace Learner::Autograd::UnivariateStatic;

    // The reference is bound once per thread and reused on later calls.
    static thread_local auto& loss_expr = get_loss_cross_entropy_();

    auto inputs = get_loss_cross_entropy_args(shallow, teacher_signal, result);
    return loss_expr.eval(inputs);
}
|
||||
|
||||
static ValueWithGrad<double> get_loss_cross_entropy_use_wdl(
|
||||
Value shallow, Value teacher_signal, int result, int ply)
|
||||
// Same inputs as get_loss_cross_entropy(), but only value() is called on the
// loss expression; the second field of the returned ValueWithGrad is filled
// with a constant 0.0 instead of being evaluated.
// NOTE(review): the second field is presumably the gradient slot - confirm
// against the ValueWithGrad definition.
//
// The ply argument is accepted but not used.
static ValueWithGrad<double> get_loss_cross_entropy_no_grad(
    Value shallow, Value teacher_signal, int result, int /* ply */)
{
    using namespace Learner::Autograd::UnivariateStatic;

    // The reference is bound once per thread and reused on later calls.
    static thread_local auto& loss_expr = get_loss_cross_entropy_();

    auto inputs = get_loss_cross_entropy_args(shallow, teacher_signal, result);
    return { loss_expr.value(inputs), 0.0 };
}
|
||||
|
||||
static auto& get_loss_cross_entropy_use_wdl_()
|
||||
{
|
||||
using namespace Learner::Autograd::UnivariateStatic;
|
||||
|
||||
@@ -364,7 +392,13 @@ namespace Learner
|
||||
static thread_local auto lambda_ = ConstantParameter<double, 3>{};
|
||||
static thread_local auto& loss_ = cross_entropy_(q_, p_, t_, lambda_);
|
||||
|
||||
auto args = std::tuple(
|
||||
return loss_;
|
||||
}
|
||||
|
||||
static auto get_loss_cross_entropy_use_wdl_args(
|
||||
Value shallow, Value teacher_signal, int result, int ply)
|
||||
{
|
||||
return std::tuple(
|
||||
(double)shallow,
|
||||
// This is required because otherwise MSVC crashes :(
|
||||
expected_perf_use_wdl(scale_score(teacher_signal), ply),
|
||||
@@ -372,10 +406,32 @@ namespace Learner
|
||||
calculate_lambda(teacher_signal),
|
||||
(double)std::min(240, ply)
|
||||
);
|
||||
}
|
||||
|
||||
// Evaluates the WDL variant of the cross-entropy loss expression for a
// single sample.
//
// All four arguments (including ply) are forwarded to
// get_loss_cross_entropy_use_wdl_args(). Returns the result of calling
// eval() on the loss expression.
static ValueWithGrad<double> get_loss_cross_entropy_use_wdl(
    Value shallow, Value teacher_signal, int result, int ply)
{
    using namespace Learner::Autograd::UnivariateStatic;

    // The reference is bound once per thread and reused on later calls.
    static thread_local auto& loss_expr = get_loss_cross_entropy_use_wdl_();

    auto inputs = get_loss_cross_entropy_use_wdl_args(
        shallow, teacher_signal, result, ply);
    return loss_expr.eval(inputs);
}
|
||||
|
||||
// Same inputs as get_loss_cross_entropy_use_wdl(), but only value() is
// called on the loss expression; the second field of the returned
// ValueWithGrad is filled with a constant 0.0 instead of being evaluated.
// NOTE(review): the second field is presumably the gradient slot - confirm
// against the ValueWithGrad definition.
static ValueWithGrad<double> get_loss_cross_entropy_use_wdl_no_grad(
    Value shallow, Value teacher_signal, int result, int ply)
{
    using namespace Learner::Autograd::UnivariateStatic;

    // The reference is bound once per thread and reused on later calls.
    static thread_local auto& loss_expr = get_loss_cross_entropy_use_wdl_();

    auto inputs = get_loss_cross_entropy_use_wdl_args(
        shallow, teacher_signal, result, ply);
    return { loss_expr.value(inputs), 0.0 };
}
|
||||
|
||||
static auto get_loss(Value shallow, Value teacher_signal, int result, int ply)
|
||||
{
|
||||
using namespace Learner::Autograd::UnivariateStatic;
|
||||
@@ -390,7 +446,21 @@ namespace Learner
|
||||
}
|
||||
}
|
||||
|
||||
static auto get_loss(
|
||||
static auto get_loss_no_grad(Value shallow, Value teacher_signal, int result, int ply)
|
||||
{
|
||||
using namespace Learner::Autograd::UnivariateStatic;
|
||||
|
||||
if (use_wdl)
|
||||
{
|
||||
return get_loss_cross_entropy_use_wdl_no_grad(shallow, teacher_signal, result, ply);
|
||||
}
|
||||
else
|
||||
{
|
||||
return get_loss_cross_entropy_no_grad(shallow, teacher_signal, result, ply);
|
||||
}
|
||||
}
|
||||
|
||||
[[maybe_unused]] static auto get_loss(
|
||||
Value teacher_signal,
|
||||
Value shallow,
|
||||
const PackedSfenValue& psv)
|
||||
@@ -398,6 +468,14 @@ namespace Learner
|
||||
return get_loss(shallow, teacher_signal, psv.game_result, psv.gamePly);
|
||||
}
|
||||
|
||||
static auto get_loss_no_grad(
|
||||
Value teacher_signal,
|
||||
Value shallow,
|
||||
const PackedSfenValue& psv)
|
||||
{
|
||||
return get_loss_no_grad(shallow, teacher_signal, psv.game_result, psv.gamePly);
|
||||
}
|
||||
|
||||
// Class to generate sfen with multiple threads
|
||||
struct LearnerThink
|
||||
{
|
||||
@@ -828,11 +906,11 @@ namespace Learner
|
||||
|
||||
if (psv.size() && test_loss_sum.count() > 0)
|
||||
{
|
||||
test_loss_sum.print("val", out);
|
||||
test_loss_sum.print_only_loss("val", out);
|
||||
|
||||
if (learn_loss_sum.count() > 0)
|
||||
{
|
||||
learn_loss_sum.print("train", out);
|
||||
learn_loss_sum.print_with_grad("train", out);
|
||||
}
|
||||
|
||||
out << " - norm = " << sum_norm << endl;
|
||||
@@ -880,7 +958,7 @@ namespace Learner
|
||||
// Evaluation value of deep search
|
||||
const auto deep_value = (Value)ps.score;
|
||||
|
||||
const auto loss = get_loss(
|
||||
const auto loss = get_loss_no_grad(
|
||||
deep_value,
|
||||
shallow_value,
|
||||
ps);
|
||||
|
||||
Reference in New Issue
Block a user