mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 10:34:41 +00:00
add orpo
This commit is contained in:
@@ -179,3 +179,28 @@ class LogExpLoss(nn.Module):
|
||||
def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
|
||||
loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
|
||||
return loss
|
||||
|
||||
|
||||
class OddsRatioLoss(nn.Module):
|
||||
"""
|
||||
Odds Ratio Loss in ORPO
|
||||
Details: https://arxiv.org/pdf/2403.07691
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
chosen_logp: torch.Tensor,
|
||||
reject_logp: torch.Tensor,
|
||||
chosen_loss_mask: torch.Tensor,
|
||||
reject_loss_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
chosen_logp = chosen_logp.to(dtype=torch.float32)
|
||||
reject_logp = reject_logp.to(dtype=torch.float32)
|
||||
chosen_odds = chosen_logp - torch.log(-torch.exp(chosen_logp) + 1.0001)
|
||||
chosen_odds_masked = torch.sum(chosen_odds * chosen_loss_mask.float()) / torch.sum(chosen_loss_mask)
|
||||
reject_odds = reject_logp - torch.log(-torch.exp(reject_logp) + 1.0001)
|
||||
reject_odds_masked = torch.sum(reject_odds * reject_loss_mask.float()) / torch.sum(reject_loss_mask)
|
||||
# print("chosen_odds_masked", chosen_odds_masked[0], "reject_odds_masked", reject_odds_masked[0])
|
||||
log_odds_ratio = chosen_odds_masked - reject_odds_masked
|
||||
ratio = torch.log(torch.nn.functional.sigmoid(log_odds_ratio))
|
||||
return ratio.to(dtype=torch.bfloat16), log_odds_ratio
|
||||
|
Reference in New Issue
Block a user