diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py
index 222540b92..af5776731 100644
--- a/applications/ColossalChat/coati/distributed/loss.py
+++ b/applications/ColossalChat/coati/distributed/loss.py
@@ -26,15 +26,10 @@ class PolicyLoss(nn.Module):
     ) -> torch.Tensor:
         skip = False
         if action_mask is None:
-            ratio_ = (log_probs - old_log_probs).exp()
+            ratio = (log_probs - log_probs.detach()).exp()
         else:
-            ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
+            ratio = ((log_probs - log_probs.detach()) * action_mask).exp()
 
-        # note that if dropout is disabled (recommanded), ratio will always be 1.
-        if ratio_.mean() > self.skip_threshold:
-            skip = True
-
-        ratio = ratio_.clamp(0.0, 10.0)
         surr1 = ratio * advantages
         surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
         loss = -torch.min(surr1, surr2) + self.beta * per_token_kl
@@ -44,4 +39,4 @@ class PolicyLoss(nn.Module):
         else:
             loss = loss.mean(dim=1)
         loss = loss.mean()
-        return loss, skip, ratio_.max()
+        return loss, skip, ratio.max()
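Context for the change (not part of the patch itself): with `old_log_probs` replaced by `log_probs.detach()`, the importance ratio `exp(log_probs - log_probs.detach())` always evaluates to 1 in the forward pass, while its gradient with respect to `log_probs` is preserved. This matches the removed comment ("ratio will always be 1" when dropout is disabled) and appears to be why the skip-threshold check and the `clamp(0.0, 10.0)` guard are dropped as dead code. A minimal standalone sketch of the detach trick, with illustrative values that are not from the repository:

```python
import torch

# The detach trick: exp(x - x.detach()) == 1 numerically, but the gradient
# w.r.t. x is exp(0) == 1, so -(ratio * advantages) still backpropagates the
# usual policy-gradient signal -advantages * d(log_probs)/d(theta).
log_probs = torch.tensor([-1.2, -0.7, -2.3], requires_grad=True)
advantages = torch.tensor([0.5, -0.1, 0.8])

ratio = (log_probs - log_probs.detach()).exp()
print(ratio)           # tensor([1., 1., 1.], grad_fn=...)

loss = -(ratio * advantages).mean()
loss.backward()
print(log_probs.grad)  # -advantages / 3: the gradient survives the cancellation
```

With the forward value of the ratio pinned at 1, the clip bounds `1 - clip_eps` and `1 + clip_eps` are never hit in this single-update, on-policy setting, so the surrogate effectively reduces to a plain policy-gradient term plus the KL penalty.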