mirror of https://github.com/hpcaitech/ColossalAI.git
add simple grpo
applications/ColossalChat/coati/distributed/loss.py (new file, 44 additions)
@@ -0,0 +1,44 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn
from coati.distributed.utils import masked_mean


class PolicyLoss(nn.Module):
    """
    Policy Loss for PPO
    """

    def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0) -> None:
        super().__init__()
        self.clip_eps = clip_eps
        self.skip_threshold = skip_threshold

    def forward(
        self,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        advantages: torch.Tensor,
        action_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, bool, torch.Tensor]:
        skip = False
        # Importance ratio between the current policy and the rollout (old) policy.
        if action_mask is None:
            ratio_ = (log_probs - old_log_probs).exp()
        else:
            ratio_ = ((log_probs - old_log_probs) * action_mask).exp()

        # note that if dropout is disabled (recommended), ratio will always be 1.
        # Flag the batch to be skipped when the mean ratio explodes.
        if ratio_.mean() > self.skip_threshold:
            skip = True

        # Standard PPO clipped surrogate objective.
        ratio = ratio_.clamp(0.0, 10.0)
        surr1 = ratio * advantages
        surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
        loss = -torch.min(surr1, surr2)
        if action_mask is not None:
            loss = masked_mean(loss, action_mask)
        else:
            loss = loss.mean(dim=1)

        loss = loss.mean()
        # Returns the scalar loss, the skip flag, and the max raw ratio for logging.
        return loss, skip, ratio_.max()
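For reference, the forward pass above is the standard PPO clipped surrogate, loss = -min(r * A, clip(r, 1 - eps, 1 + eps) * A), averaged over unmasked action tokens. Below is a minimal usage sketch, not part of the commit: the tensor shapes, the random inputs, and the names batch_size, seq_len, loss_fn, and max_ratio are illustrative assumptions, it presumes the coati package from this repository is importable under the path shown in the diff, and it assumes masked_mean averages over positions where the mask is 1.

    # Minimal usage sketch (shapes and inputs are illustrative assumptions).
    import torch

    from coati.distributed.loss import PolicyLoss

    batch_size, seq_len = 4, 16
    log_probs = torch.randn(batch_size, seq_len, requires_grad=True)               # per-token log-probs from the current policy
    old_log_probs = (log_probs + 0.01 * torch.randn(batch_size, seq_len)).detach() # log-probs recorded at rollout time
    advantages = torch.randn(batch_size, seq_len)                                  # per-token advantage estimates
    action_mask = (torch.rand(batch_size, seq_len) > 0.1).float()                  # 1 where a generated token exists

    loss_fn = PolicyLoss(clip_eps=0.2, skip_threshold=20.0)
    loss, skip, max_ratio = loss_fn(log_probs, old_log_probs, advantages, action_mask)

    if not skip:  # skip=True flags a batch whose mean ratio exceeded skip_threshold
        loss.backward()
    print(f"loss={loss.item():.4f}, skip={skip}, max_ratio={max_ratio.item():.2f}")

Keeping clip_eps small (0.2 here, the default) bounds how far a single update can move the policy away from the rollout policy, while skip_threshold acts as a coarser guard that discards a whole batch once the ratio has already diverged.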