This commit is contained in:
YeAnbang
2025-03-19 17:07:20 +08:00
parent 7795d4c50d
commit 7ee4452f8c
5 changed files with 172 additions and 24 deletions

View File

@@ -23,6 +23,7 @@ class PolicyLoss(nn.Module):
advantages: torch.Tensor,
per_token_kl: torch.Tensor,
action_mask: Optional[torch.Tensor] = None,
loss_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
skip = False
if action_mask is None:
@@ -38,5 +39,7 @@ class PolicyLoss(nn.Module):
loss = masked_mean(loss, action_mask)
else:
loss = loss.mean(dim=1)
if loss_mask is not None:
loss = loss * loss_mask
loss = loss.mean()
return loss, skip, ratio.max()