From c1e69f450e3446ea75c22101553bd751554cf4c3 Mon Sep 17 00:00:00 2001 From: Tomasz Sobczyk Date: Thu, 24 Dec 2020 13:25:48 +0100 Subject: [PATCH] Prevent q_ in loss calculation from reaching values that would produce NaN --- src/learn/learn.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/learn/learn.cpp b/src/learn/learn.cpp index 4e70f61c..22578ff3 100644 --- a/src/learn/learn.cpp +++ b/src/learn/learn.cpp @@ -191,8 +191,8 @@ namespace Learner static thread_local auto teacher_entropy_ = -(p_ * log(p_ + epsilon) + (1.0 - p_) * log(1.0 - p_ + epsilon)); static thread_local auto outcome_entropy_ = -(t_ * log(t_ + epsilon) + (1.0 - t_) * log(1.0 - t_ + epsilon)); - static thread_local auto teacher_loss_ = -(p_ * log(q_) + (1.0 - p_) * log(1.0 - q_)); - static thread_local auto outcome_loss_ = -(t_ * log(q_) + (1.0 - t_) * log(1.0 - q_)); + static thread_local auto teacher_loss_ = -(p_ * log(q_ + epsilon) + (1.0 - p_) * log(1.0 - q_ + epsilon)); + static thread_local auto outcome_loss_ = -(t_ * log(q_ + epsilon) + (1.0 - t_) * log(1.0 - q_ + epsilon)); static thread_local auto result_ = lambda_ * teacher_loss_ + (1.0 - lambda_) * outcome_loss_; static thread_local auto entropy_ = lambda_ * teacher_entropy_ + (1.0 - lambda_) * outcome_entropy_; static thread_local auto cross_entropy_ = result_ - entropy_;