Apply gradient clipping.

This commit is contained in:
Tomasz Sobczyk
2020-11-29 17:57:06 +01:00
committed by nodchip
parent d103867558
commit a5c20bee5b
2 changed files with 7 additions and 1 deletions

View File

@@ -7,6 +7,7 @@
#include <memory>
#include <tuple>
#include <optional>
#include <algorithm>
namespace Learner
{
@@ -48,6 +49,11 @@ namespace Learner
{
return { std::abs(value), std::abs(grad) };
}
// Returns a copy of this value/gradient pair whose gradient is limited to
// the symmetric range [-|max|, |max|]; the value component is copied through
// unchanged.
//
// std::clamp requires lo <= hi (its behavior is undefined otherwise), so the
// bound is reduced to its magnitude first: a negative `max` clips to the same
// range as its positive counterpart instead of invoking undefined behavior.
ValueWithGrad clamp_grad(T max) const
{
    const T limit = std::abs(max);
    return { value, std::clamp(grad, -limit, limit) };
}
};
}

View File

@@ -230,7 +230,7 @@ namespace Learner
auto loss_ = result_ - entropy_;
auto args = std::tuple((double)shallow, (double)teacher_signal, (double)result, calculate_lambda(teacher_signal));
return loss_.eval(args);
return loss_.eval(args).clamp_grad(max_grad);
}
static auto get_loss(