# Mirror of https://github.com/hpcaitech/ColossalAI.git
# (synced 2025-04-28 03:43:01 +00:00)
#
# Squashed change log: update help information; update style; minor fixes;
# support PP training; remove unused code; address review conversation; fix
# memory leakage, support TP+PP; move empty-cache calls; add DAPO support;
# remove format reward; fix filtering (still buggy at the time); pre-commit
# auto-fixes; tested multi-node training and fixed a bind_batch bug; fixed
# conversation handling; support sleep mode; support reusing excessive
# samples; add a dynamic-batching control flag; refactored; fix logging.
#
# Co-authored-by: Tong Li <tong.li35271158@gmail.com>
# Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
#
# File stats: 69 lines · 2.3 KiB · Python
from typing import Optional, Tuple

import torch
import torch.nn as nn

from coati.distributed.utils import masked_mean, masked_sum
class PolicyLoss(nn.Module):
    """Clipped policy-gradient loss for PPO/GRPO-style RLHF training.

    Supports asymmetric clipping (``clip_eps_low`` / ``clip_eps_high``, as used
    by DAPO), an optional per-token KL penalty weighted by ``beta``, and two
    normalization schemes selected by ``loss_variation``:

    * ``"sample_level"`` — mean over each sample's tokens, then mean over
      samples.
    * ``"token_level"`` — sum over tokens, normalized by the total number of
      effective tokens in the batch (DAPO-style token-level normalization).
    """

    def __init__(
        self,
        clip_eps_low: float = 0.2,
        clip_eps_high: float = 0.2,
        beta: float = 0.01,
        loss_variation: str = "sample_level",
    ) -> None:
        """
        Args:
            clip_eps_low: lower clip width; the ratio is clamped at
                ``1 - clip_eps_low``.
            clip_eps_high: upper clip width; the ratio is clamped at
                ``1 + clip_eps_high``.
            beta: weight of the per-token KL penalty; ``0`` disables it.
            loss_variation: ``"sample_level"`` or ``"token_level"``.

        Raises:
            ValueError: if ``loss_variation`` is not supported.
        """
        super().__init__()
        # Validate with ValueError (consistent with forward()) rather than
        # `assert`, which is silently stripped under `python -O`.
        if loss_variation not in ("sample_level", "token_level"):
            raise ValueError(f"Unsupported loss variation: {loss_variation}")
        self.clip_eps_low = clip_eps_low
        self.clip_eps_high = clip_eps_high
        self.beta = beta
        self.loss_variation = loss_variation

    def forward(
        self,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        advantages: torch.Tensor,
        per_token_kl: torch.Tensor,
        action_mask: Optional[torch.Tensor] = None,
        loss_mask: Optional[torch.Tensor] = None,
        total_effective_tokens_in_batch: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the clipped policy loss.

        Args:
            log_probs: per-token log-probs under the current policy,
                shape (batch, seq) — assumed; TODO confirm against caller.
            old_log_probs: per-token log-probs under the behavior policy.
                NOTE(review): currently unused — the ratio is built from
                ``log_probs.detach()`` (the GRPO/TRL one-step trick, exact
                only when each batch is used for a single optimizer update).
                Confirm this is intended when multiple updates per batch
                are configured.
            advantages: per-token (broadcastable) advantage estimates.
            per_token_kl: per-token KL penalty term; ignored when
                ``self.beta == 0``.
            action_mask: 1 for effective (response) tokens, 0 elsewhere.
            loss_mask: optional per-sample mask applied after reduction
                over tokens (e.g. to drop filtered samples).
            total_effective_tokens_in_batch: denominator for the
                ``"token_level"`` variation; required in that mode.

        Returns:
            Tuple of (scalar loss, maximum importance ratio in the batch).

        Raises:
            ValueError: if ``self.loss_variation`` is unsupported.
        """
        # Importance ratio. Its forward value is exactly 1 since
        # exp(x - x.detach()) == exp(0), but the gradient w.r.t. log_probs
        # is preserved, which is all the surrogate objective needs.
        if action_mask is None:
            ratio = (log_probs - log_probs.detach()).exp()
        else:
            ratio = ((log_probs - log_probs.detach()) * action_mask).exp()

        # Standard PPO clipped surrogate (pessimistic min of the two).
        surr1 = ratio * advantages
        surr2 = ratio.clamp(1 - self.clip_eps_low, 1 + self.clip_eps_high) * advantages
        if self.beta == 0:
            # Skip the KL term entirely when its coefficient is zero.
            per_token_kl = 0.0
        loss = -torch.min(surr1, surr2) + self.beta * per_token_kl

        if self.loss_variation == "sample_level":
            # Per-sample mean over valid tokens, then mean over samples.
            if action_mask is not None:
                loss = masked_mean(loss, action_mask)
            else:
                loss = loss.mean(dim=1)
            if loss_mask is not None:
                loss = loss * loss_mask
            loss = loss.mean()
        elif self.loss_variation == "token_level":
            # Token sum normalized by the global effective-token count;
            # the 1e-8 guards against a zero denominator.
            if action_mask is not None:
                loss = masked_sum(loss, action_mask)
            else:
                loss = loss.sum(dim=1)
            if loss_mask is not None:
                loss = loss * loss_mask
            loss = loss.sum() / (total_effective_tokens_in_batch + 1e-8)
        else:
            raise ValueError(f"Unsupported loss variation: {self.loss_variation}")

        return loss, ratio.max()