diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py index c08acba51..222540b92 100644 --- a/applications/ColossalChat/coati/distributed/loss.py +++ b/applications/ColossalChat/coati/distributed/loss.py @@ -10,16 +10,18 @@ class PolicyLoss(nn.Module): Policy Loss for PPO """ - def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0) -> None: + def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0, beta: float = 0.01) -> None: super().__init__() self.clip_eps = clip_eps self.skip_threshold = skip_threshold + self.beta = beta def forward( self, log_probs: torch.Tensor, old_log_probs: torch.Tensor, advantages: torch.Tensor, + per_token_kl: torch.Tensor, action_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: skip = False @@ -35,7 +37,8 @@ class PolicyLoss(nn.Module): ratio = ratio_.clamp(0.0, 10.0) surr1 = ratio * advantages surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages - loss = -torch.min(surr1, surr2) + loss = -torch.min(surr1, surr2) + self.beta * per_token_kl + if action_mask is not None: loss = masked_mean(loss, action_mask) else: