mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
Support overall loss, update KTO logging
This commit is contained in:
@@ -46,7 +46,10 @@ class PolicyLoss(nn.Module):
|
||||
action_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
skip = False
|
||||
ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
|
||||
if action_mask is None:
|
||||
ratio_ = (log_probs - old_log_probs).exp()
|
||||
else:
|
||||
ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
|
||||
|
||||
# note that if dropout is disabled (recommended), ratio will always be 1.
|
||||
if ratio_.mean() > self.skip_threshold:
|
||||
@@ -56,7 +59,10 @@ class PolicyLoss(nn.Module):
|
||||
surr1 = ratio * advantages
|
||||
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
|
||||
loss = -torch.min(surr1, surr2)
|
||||
loss = masked_mean(loss, action_mask)
|
||||
if action_mask is not None:
|
||||
loss = masked_mean(loss, action_mask)
|
||||
else:
|
||||
loss = loss.mean(dim=1)
|
||||
loss = loss.mean()
|
||||
return loss, skip, ratio_.max()
|
||||
|
||||
@@ -81,8 +87,10 @@ class ValueLoss(nn.Module):
|
||||
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
|
||||
surr1 = (values_clipped - returns) ** 2
|
||||
surr2 = (values - returns) ** 2
|
||||
loss = torch.max(surr1, surr2) / torch.sum(action_mask)
|
||||
loss = torch.sum(loss * action_mask)
|
||||
if action_mask is not None:
|
||||
loss = torch.sum(torch.max(surr1, surr2) / torch.sum(action_mask) * action_mask)
|
||||
else:
|
||||
loss = torch.mean(torch.max(surr1, surr2))
|
||||
return 0.5 * loss
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user