add SimPO

This commit is contained in:
YeAnbang
2024-06-24 02:12:20 +00:00
parent 84eab13078
commit 82aecd6374
14 changed files with 128 additions and 70 deletions

View File

@@ -88,11 +88,22 @@ class DpoLoss(nn.Module):
"""
Dpo loss
Details: https://arxiv.org/pdf/2305.18290.pdf
SimPO loss:
Details: https://arxiv.org/pdf/2405.14734.pdf
"""
def __init__(self, beta: float = 0.1):
def __init__(self, beta: float = 0.1, gamma: float = 0.0):
"""
Args:
beta: The temperature parameter in the DPO paper.
gamma: The margin parameter in the SimPO paper.
length_normalization: Whether to normalize the loss by the length of chosen and rejected responses.
Refer to the length normalization in the SimPO paper
"""
super().__init__()
self.beta = beta
self.gamma = gamma
def forward(
self,
@@ -103,7 +114,7 @@ class DpoLoss(nn.Module):
chosen_mask: torch.Tensor,
reject_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute the DPO loss for a batch of policy and reference model log probabilities.
"""Compute the DPO/SimPO loss for a batch of policy and reference model log probabilities.
# adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L328
@@ -112,6 +123,8 @@ class DpoLoss(nn.Module):
logprob_actor_reject: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
logprob_ref_chosen: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
logprob_ref_reject: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
chosen_mask: Mask tensor indicating which responses were chosen. Shape: (batch_size,)
reject_mask: Mask tensor indicating which responses were rejected. Shape: (batch_size,)
Returns:
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
@@ -126,13 +139,12 @@ class DpoLoss(nn.Module):
if len(logprob_ref_chosen.shape) == 2:
ref_logratios = logprob_ref_chosen.sum(-1) - logprob_ref_reject.sum(-1)
else:
ref_logratios = logprob_ref_chosen.squeeze() - logprob_ref_reject.squeeze()
ref_logratios = logprob_ref_chosen - logprob_ref_reject
else:
# If no reference model is provided
ref_logratios = 0.0
pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1)
logits = pi_logratios - ref_logratios
logits = pi_logratios - ref_logratios - self.gamma / self.beta
losses = -torch.nn.functional.logsigmoid(self.beta * logits)
# Calculate rewards for logging