This commit is contained in:
Tong Li 2025-03-11 16:17:02 +08:00
parent 47d6493778
commit 704866a240

View File

@ -26,15 +26,10 @@ class PolicyLoss(nn.Module):
) -> torch.Tensor:
skip = False
if action_mask is None:
ratio_ = (log_probs - old_log_probs).exp()
ratio = (log_probs - log_probs.detach()).exp()
else:
ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
ratio = ((log_probs - log_probs.detach()) * action_mask).exp()
# note that if dropout is disabled (recommanded), ratio will always be 1.
if ratio_.mean() > self.skip_threshold:
skip = True
ratio = ratio_.clamp(0.0, 10.0)
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
loss = -torch.min(surr1, surr2) + self.beta * per_token_kl
@ -44,4 +39,4 @@ class PolicyLoss(nn.Module):
else:
loss = loss.mean(dim=1)
loss = loss.mean()
return loss, skip, ratio_.max()
return loss, skip, ratio.max()