mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-30 23:56:42 +00:00
[chat] fix compute_approx_kl (#4338)
This commit is contained in:
parent
03654c0ce2
commit
75c5389037
@ -19,7 +19,7 @@ def compute_approx_kl(log_probs: torch.Tensor,
|
||||
action_mask: Mask for actions.
|
||||
"""
|
||||
|
||||
log_ratio = log_probs - log_probs_base
|
||||
log_ratio = log_probs_base - log_probs
|
||||
approx_kl = (log_ratio.exp() - 1) - log_ratio
|
||||
if action_mask is not None:
|
||||
approx_kl = masked_mean(approx_kl, action_mask, dim=1)
|
||||
|
Loading…
Reference in New Issue
Block a user