mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-21 13:11:27 +00:00
detach
This commit is contained in:
parent
47d6493778
commit
704866a240
@ -26,15 +26,10 @@ class PolicyLoss(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
skip = False
|
skip = False
|
||||||
if action_mask is None:
|
if action_mask is None:
|
||||||
ratio_ = (log_probs - old_log_probs).exp()
|
ratio = (log_probs - log_probs.detach()).exp()
|
||||||
else:
|
else:
|
||||||
ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
|
ratio = ((log_probs - log_probs.detach()) * action_mask).exp()
|
||||||
|
|
||||||
# note that if dropout is disabled (recommanded), ratio will always be 1.
|
|
||||||
if ratio_.mean() > self.skip_threshold:
|
|
||||||
skip = True
|
|
||||||
|
|
||||||
ratio = ratio_.clamp(0.0, 10.0)
|
|
||||||
surr1 = ratio * advantages
|
surr1 = ratio * advantages
|
||||||
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
|
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
|
||||||
loss = -torch.min(surr1, surr2) + self.beta * per_token_kl
|
loss = -torch.min(surr1, surr2) + self.beta * per_token_kl
|
||||||
@ -44,4 +39,4 @@ class PolicyLoss(nn.Module):
|
|||||||
else:
|
else:
|
||||||
loss = loss.mean(dim=1)
|
loss = loss.mean(dim=1)
|
||||||
loss = loss.mean()
|
loss = loss.mean()
|
||||||
return loss, skip, ratio_.max()
|
return loss, skip, ratio.max()
|
||||||
|
Loading…
Reference in New Issue
Block a user