add SimPO
@@ -88,11 +88,22 @@ class DpoLoss(nn.Module):
     """
     Dpo loss
     Details: https://arxiv.org/pdf/2305.18290.pdf
+
+    SimPO loss:
+    Details: https://arxiv.org/pdf/2405.14734.pdf
     """
 
-    def __init__(self, beta: float = 0.1):
+    def __init__(self, beta: float = 0.1, gamma: float = 0.0):
+        """
+        Args:
+            beta: The temperature parameter in the DPO paper.
+            gamma: The margin parameter in the SimPO paper.
+            length_normalization: Whether to normalize the loss by the length of chosen and rejected responses.
+                Refer to the length normalization in the SimPO paper
+        """
         super().__init__()
         self.beta = beta
+        self.gamma = gamma
 
     def forward(
         self,
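Read as a formula (a sketch of what the updated loss computes, using π_θ for the policy, π_ref for the reference model, and y_w / y_l for the chosen / rejected responses), the new `gamma` acts as a fixed target reward margin inside the sigmoid:

\[
\mathcal{L} = -\log \sigma\!\left(
  \beta \left[ \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)}
             - \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)} \right]
  - \gamma \right)
\]

Setting gamma = 0 recovers the plain DPO objective; when no reference model is supplied, the reference log-ratio term is 0, which is the reference-free setting SimPO uses.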
@@ -103,7 +114,7 @@ class DpoLoss(nn.Module):
         chosen_mask: torch.Tensor,
         reject_mask: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Compute the DPO loss for a batch of policy and reference model log probabilities.
+        """Compute the DPO/SimPO loss for a batch of policy and reference model log probabilities.
 
         # adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L328
 
@@ -112,6 +123,8 @@ class DpoLoss(nn.Module):
             logprob_actor_reject: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
             logprob_ref_chosen: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
             logprob_ref_reject: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
+            chosen_mask: Mask tensor indicating which responses were chosen. Shape: (batch_size,)
+            reject_mask: Mask tensor indicating which responses were rejected. Shape: (batch_size,)
 
         Returns:
             A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
@@ -126,13 +139,12 @@ class DpoLoss(nn.Module):
             if len(logprob_ref_chosen.shape) == 2:
                 ref_logratios = logprob_ref_chosen.sum(-1) - logprob_ref_reject.sum(-1)
             else:
-                ref_logratios = logprob_ref_chosen.squeeze() - logprob_ref_reject.squeeze()
+                ref_logratios = logprob_ref_chosen - logprob_ref_reject
         else:
             # If no reference model is provided
             ref_logratios = 0.0
-
         pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1)
-        logits = pi_logratios - ref_logratios
+        logits = pi_logratios - ref_logratios - self.gamma / self.beta
         losses = -torch.nn.functional.logsigmoid(self.beta * logits)
 
         # Calculate rewards for logging
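For reference, a minimal standalone sketch (not part of the commit) of the logit and margin arithmetic the last hunk implements; the tensor names, shapes, and values are illustrative, and the two-dimensional branch with .sum(-1) is assumed.

# Sketch of the DPO/SimPO logit computation shown in the diff above.
import torch
import torch.nn.functional as F

beta, gamma = 0.1, 0.5                      # DPO temperature, SimPO-style margin
batch, seq = 4, 16

# Illustrative per-token log-probabilities; summed over the sequence dimension
# as in the two-dimensional branch of the diff.
logprob_actor_chosen = torch.randn(batch, seq)
logprob_actor_reject = torch.randn(batch, seq)
logprob_ref_chosen = torch.randn(batch, seq)
logprob_ref_reject = torch.randn(batch, seq)

pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1)
ref_logratios = logprob_ref_chosen.sum(-1) - logprob_ref_reject.sum(-1)

# gamma / beta is subtracted before scaling by beta, so the margin gamma is
# applied directly to the reward difference inside the sigmoid.
logits = pi_logratios - ref_logratios - gamma / beta
losses = -F.logsigmoid(beta * logits)       # shape: (batch,)
print(losses.mean())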