diff --git a/src/learn/autograd.h b/src/learn/autograd.h index 2b0eee3a..afbcc41b 100644 --- a/src/learn/autograd.h +++ b/src/learn/autograd.h @@ -7,6 +7,7 @@ #include #include #include +#include namespace Learner { @@ -48,6 +49,11 @@ namespace Learner { return { std::abs(value), std::abs(grad) }; } + + ValueWithGrad clamp_grad(T max) const + { + return { value, std::clamp(grad, -max, max) }; + } }; } diff --git a/src/learn/learn.cpp b/src/learn/learn.cpp index 83229c61..0b04d034 100644 --- a/src/learn/learn.cpp +++ b/src/learn/learn.cpp @@ -230,7 +230,7 @@ namespace Learner auto loss_ = result_ - entropy_; auto args = std::tuple((double)shallow, (double)teacher_signal, (double)result, calculate_lambda(teacher_signal)); - return loss_.eval(args); + return loss_.eval(args).clamp_grad(max_grad); } static auto get_loss(