mirror of https://github.com/hpcaitech/ColossalAI.git

detach

parent 47d6493778
commit 704866a240
@@ -26,15 +26,10 @@ class PolicyLoss(nn.Module):
     ) -> torch.Tensor:
         skip = False
         if action_mask is None:
-            ratio_ = (log_probs - old_log_probs).exp()
+            ratio = (log_probs - log_probs.detach()).exp()
         else:
-            ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
+            ratio = ((log_probs - log_probs.detach()) * action_mask).exp()
 
-        # note that if dropout is disabled (recommanded), ratio will always be 1.
-        if ratio_.mean() > self.skip_threshold:
-            skip = True
-
-        ratio = ratio_.clamp(0.0, 10.0)
         surr1 = ratio * advantages
         surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
         loss = -torch.min(surr1, surr2) + self.beta * per_token_kl
@@ -44,4 +39,4 @@ class PolicyLoss(nn.Module):
         else:
             loss = loss.mean(dim=1)
         loss = loss.mean()
-        return loss, skip, ratio_.max()
+        return loss, skip, ratio.max()
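What the change does, in brief: the importance ratio is now computed against log_probs.detach() rather than old_log_probs, so its forward value is exactly exp(0) = 1 whenever dropout is disabled (which is what the deleted comment was pointing at), while its gradient still flows, since d/dx exp(x - x.detach()) = exp(x - x.detach()) = 1. With the ratio pinned at 1, the skip_threshold check and the clamp to [0.0, 10.0] can never trigger, so they are removed; skip stays hard-wired to False only to preserve the return signature. Below is a standalone sketch of this detach trick, not code from the repository: the tensor values are invented and the per_token_kl term is omitted.

import torch

# Toy per-token log-probabilities; values are invented for illustration.
log_probs = torch.tensor([[-1.2, -0.7, -2.3]], requires_grad=True)
advantages = torch.tensor([[0.5, -0.3, 1.0]])
clip_eps = 0.2

# The detach trick from this commit: the forward value is exactly 1,
# but d(ratio)/d(log_probs) = ratio = 1, so gradients still flow.
ratio = (log_probs - log_probs.detach()).exp()
print(ratio)  # tensor([[1., 1., 1.]], grad_fn=<ExpBackward0>)

# With ratio == 1 the clamp is inert and surr1 == surr2, so the clipped
# surrogate collapses to a vanilla policy-gradient objective.
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - clip_eps, 1 + clip_eps) * advantages
loss = -torch.min(surr1, surr2).mean()
loss.backward()

# Same gradient as -(advantages * log_probs).mean() would give: the detached
# ratio reproduces REINFORCE-style gradients, -advantages / num_tokens.
print(log_probs.grad)  # tensor([[-0.1667,  0.1000, -0.3333]])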