mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
fix vllm
This commit is contained in:
@@ -23,6 +23,7 @@ class PolicyLoss(nn.Module):
|
||||
advantages: torch.Tensor,
|
||||
per_token_kl: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None,
|
||||
loss_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
skip = False
|
||||
if action_mask is None:
|
||||
@@ -38,5 +39,7 @@ class PolicyLoss(nn.Module):
|
||||
loss = masked_mean(loss, action_mask)
|
||||
else:
|
||||
loss = loss.mean(dim=1)
|
||||
if loss_mask is not None:
|
||||
loss = loss * loss_mask
|
||||
loss = loss.mean()
|
||||
return loss, skip, ratio.max()
|
||||
|
Reference in New Issue
Block a user