From 75c53890378d3e72b4700a264f52524d4185168a Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Tue, 1 Aug 2023 10:21:45 +0800 Subject: [PATCH] [chat] fix compute_approx_kl (#4338) --- applications/Chat/coati/models/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/Chat/coati/models/utils.py b/applications/Chat/coati/models/utils.py index 772bfc329..8769fb7a8 100644 --- a/applications/Chat/coati/models/utils.py +++ b/applications/Chat/coati/models/utils.py @@ -19,7 +19,7 @@ def compute_approx_kl(log_probs: torch.Tensor, action_mask: Mask for actions. """ - log_ratio = log_probs - log_probs_base + log_ratio = log_probs_base - log_probs approx_kl = (log_ratio.exp() - 1) - log_ratio if action_mask is not None: approx_kl = masked_mean(approx_kl, action_mask, dim=1)